Comments (1)
Hi, I tried modifying the original code to fit the case you mentioned:
import torch
import torch.nn as nn


class SwitchNorm1d(nn.Module):
    def __init__(self, num_features, eps=1e-5, momentum=0.997, using_moving_average=True):
        super(SwitchNorm1d, self).__init__()
        self.eps = eps
        self.momentum = momentum
        self.using_moving_average = using_moving_average
        # Per-channel affine parameters, broadcast over batch and length.
        self.weight = nn.Parameter(torch.ones(1, num_features, 1))
        self.bias = nn.Parameter(torch.zeros(1, num_features, 1))
        # Importance logits for the IN, LN and BN statistics (softmaxed in forward).
        self.mean_weight = nn.Parameter(torch.ones(3))
        self.var_weight = nn.Parameter(torch.ones(3))
        self.register_buffer('running_mean', torch.zeros(1, num_features, 1))
        self.register_buffer('running_var', torch.zeros(1, num_features, 1))
        self.reset_parameters()

    def reset_parameters(self):
        self.running_mean.zero_()
        self.running_var.zero_()
        self.weight.data.fill_(1)
        self.bias.data.zero_()

    def _check_input_dim(self, input):
        if input.dim() != 3:
            raise ValueError('expected 3D input (got {}D input)'.format(input.dim()))

    def forward(self, x):  # (B, C, L)
        self._check_input_dim(x)
        mean_ln = x.mean(1, keepdim=True)   # (B, 1, L)
        var_ln = x.var(1, keepdim=True)
        mean_in = x.mean(-1, keepdim=True)  # (B, C, 1)
        var_in = x.var(-1, keepdim=True)
        temp = var_in + mean_in ** 2        # per-sample second moment

        if self.training:
            mean_bn = mean_in.mean(0, keepdim=True)  # (1, C, 1)
            var_bn = temp.mean(0, keepdim=True) - mean_bn ** 2
            # Update the running statistics without tracking gradients.
            if self.using_moving_average:
                self.running_mean.mul_(self.momentum)
                self.running_mean.add_((1 - self.momentum) * mean_bn.detach())
                self.running_var.mul_(self.momentum)
                self.running_var.add_((1 - self.momentum) * var_bn.detach())
            else:
                self.running_mean.add_(mean_bn.detach())
                self.running_var.add_(mean_bn.detach() ** 2 + var_bn.detach())
        else:
            # Buffers are plain tensors; the deprecated Variable wrapper is not needed.
            mean_bn = self.running_mean
            var_bn = self.running_var

        mean_weight = torch.softmax(self.mean_weight, dim=0)
        var_weight = torch.softmax(self.var_weight, dim=0)

        mean = mean_weight[0] * mean_in + mean_weight[1] * mean_ln + mean_weight[2] * mean_bn
        var = var_weight[0] * var_in + var_weight[1] * var_ln + var_weight[2] * var_bn

        x = (x - mean) / (var + self.eps).sqrt()
        return x * self.weight + self.bias
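For temporal data, LN is computed slightly differently from the image case (SwitchNorm2d): here the LN statistics are taken over the channel dimension at each time step, giving shape (B, 1, L). It should work as desired :)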
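To make the shapes concrete, here is a minimal sketch (not part of the original code) that recomputes the three sets of statistics from the forward pass above on a random (B, C, L) tensor. The helper name `temporal_stats` and the tensor sizes are made up for illustration. It uses the biased variance (`unbiased=False`), an assumption made only so that the BN variance assembled from per-sample moments matches a direct computation exactly; the module above uses PyTorch's default unbiased estimate.

```python
import torch


def temporal_stats(x):
    """Hypothetical helper: IN, LN and BN statistics for a (B, C, L) tensor."""
    mean_in = x.mean(-1, keepdim=True)                  # (B, C, 1)
    var_in = x.var(-1, unbiased=False, keepdim=True)    # biased, for the check below
    mean_ln = x.mean(1, keepdim=True)                   # (B, 1, L)
    var_ln = x.var(1, unbiased=False, keepdim=True)
    # BN statistics assembled from the per-sample first and second moments.
    mean_bn = mean_in.mean(0, keepdim=True)             # (1, C, 1)
    var_bn = (var_in + mean_in ** 2).mean(0, keepdim=True) - mean_bn ** 2
    return (mean_in, var_in), (mean_ln, var_ln), (mean_bn, var_bn)


torch.manual_seed(0)
x = torch.randn(4, 8, 16)  # B=4, C=8, L=16 (arbitrary sizes)
(mean_in, var_in), (mean_ln, var_ln), (mean_bn, var_bn) = temporal_stats(x)

# Direct per-channel variance over batch and length, for comparison.
var_bn_direct = ((x - mean_bn) ** 2).mean(dim=(0, 2), keepdim=True)

# Equal importance logits, as at initialisation, softmax to equal weights.
w = torch.softmax(torch.ones(3), dim=0)
mean = w[0] * mean_in + w[1] * mean_ln + w[2] * mean_bn  # broadcasts to (B, C, L)

print(mean_in.shape, mean_ln.shape, mean_bn.shape, mean.shape)
print(torch.allclose(var_bn, var_bn_direct, atol=1e-5))
```

The three means have different shapes, so the weighted sum relies on broadcasting to produce one value per element of the input.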
from switchable-normalization.