- Batch normalization overview
- A computational graph view of Batch Normalization
- Piece the pieces together, then simplify
- The code
Batch normalization overview
Batch normalization is a technique that stabilizes the input distribution of intermediate layers in a deep network during training. This reduces the effect of information morphing and thus helps speed up training, especially during the first few steps. Batch normalization is also claimed to reduce the need for dropout thanks to its regularizing effect. As the name *normalization* suggests, you simply normalize the output of the layer to zero mean and unit variance:

$$\hat{x} = \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x]}}$$
This requires expensive computation of the covariance matrix $\mathrm{Cov}[x]$ and its inverse square root, so an approximation over each mini-batch during training is proposed. For a layer with input $x$, besides its original parameters to learn, batch normalization introduces two other learnable parameters: $\gamma$ and $\beta$. The forward pass proceeds as follows:

$$\mu = \frac{1}{n}\sum_{i=1}^{n} x_i \qquad \sigma^2 = \frac{1}{n}\sum_{i=1}^{n}(x_i-\mu)^2$$

$$\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}} \qquad y_i = \gamma\,\hat{x}_i + \beta$$
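The forward pass can be sketched in a few lines of numpy. This is a minimal illustration for the training case, assuming a 2-D batch of shape $(n, d)$; the function name and the $\epsilon$ value are my own choices:

```python
import numpy as np

def batchnorm_forward(x, gamma, beta, eps=1e-8):
    """Normalize each feature (column) of a batch x of shape (n, d)."""
    mu = x.mean(axis=0)                    # per-feature mean, shape (d,)
    var = x.var(axis=0)                    # per-feature variance, shape (d,)
    x_hat = (x - mu) / np.sqrt(var + eps)  # zero mean, unit variance
    return gamma * x_hat + beta            # learnable scale and shift

x = np.random.randn(64, 3) * 5.0 + 2.0     # batch far from zero mean / unit variance
y = batchnorm_forward(x, gamma=np.ones(3), beta=np.zeros(3))
print(y.mean(axis=0), y.std(axis=0))       # close to 0 and 1
```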
The above equations describe how BN handles scalar inputs. Real implementations, however, deal with much higher dimensional vectors/matrices. As an example, an information volume flowing through a convolutional layer is 4-dimensional, and we take the normalization steps over each feature map separately.
Let the volume of interest be $k$-dimensional, with the normalization taken over its first $k-1$ dimensions. I will derive a sequence of vectorized operations on $\frac{\partial L}{\partial y}$ - the gradient propagated back to this layer - to produce $\frac{\partial L}{\partial x}$ and $\frac{\partial L}{\partial \gamma}$. $\beta$ will not be considered as it can be cast as a bias-adding operation and does not belong to the atomic view of batch normalization.
For ease of notation, I denote the gradient $\frac{\partial L}{\partial t}$ of the loss with respect to any tensor $t$ as $\nabla t$, and consider the case of one-dimensional vectors: $x_{ij}$ being the $j$th feature of the $i$th training example in the batch. The design matrix $x$ is then of size $n \times d$, where $n$ is the batch size and $d$ is the number of features. Although we are limiting the analysis to one-dimensional vectors, the later code is applicable to an arbitrarily bigger number of dimensions.
For example, a 4-dimensional volume with shape $(N, H, W, C)$ can be considered as $(N \times H \times W, C)$ after reshaping. Thanks to numpy's broadcasting ability, the reshaping is not even necessary.
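As a quick sanity check of that claim (the shapes here are my own example), normalizing a 4-D volume over its first three axes via broadcasting matches the reshape-first version:

```python
import numpy as np

x = np.random.randn(2, 4, 4, 3)            # an (N, H, W, C) conv volume
eps = 1e-8

# Normalize over every axis except channels, keeping dims for broadcasting.
mu = x.mean(axis=(0, 1, 2), keepdims=True)           # shape (1, 1, 1, 3)
var = x.var(axis=(0, 1, 2), keepdims=True)
normed = (x - mu) / np.sqrt(var + eps)

# Same thing via an explicit reshape to the (n, d) design-matrix view.
flat = x.reshape(-1, 3)                              # (N*H*W, C)
flat_normed = (flat - flat.mean(0)) / np.sqrt(flat.var(0) + eps)

print(np.allclose(normed, flat_normed.reshape(x.shape)))  # True
```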
A computational graph view of Batch Normalization
First we break the process into simpler parts:

$$\mu_j = \frac{1}{n}\sum_i x_{ij},\qquad c_{ij} = x_{ij}-\mu_j,\qquad s_{ij} = c_{ij}^2$$

$$\sigma_j^2 = \frac{1}{n}\sum_i s_{ij},\qquad \hat{x}_{ij} = c_{ij}\,(\sigma_j^2+\epsilon)^{-1/2},\qquad y_{ij} = \gamma_j\,\hat{x}_{ij}$$
In our setting $\nabla y$ is available. Gradients flow backwards, so we first consider $\nabla\hat{x}$. Each entry $\hat{x}_{ij}$ contributes to the loss only through $y_{ij}$, so according to the chain rule:

$$\nabla\hat{x}_{ij} = \nabla y_{ij}\,\gamma_j \quad\Longleftrightarrow\quad \nabla\hat{x} = \nabla y\odot\gamma$$
Next, we can do either $\nabla\sigma^2$ or $\nabla c$; $\nabla\sigma^2$ is simpler since $\sigma^2$ contributes to the loss only through $\hat{x}$ as shown in the graph (while $c$ also contributes to the loss through $\sigma^2$). Consider a single entry $\sigma_j^2$: it contributes to the loss through $\hat{x}_{ij}$ for all $i$, so according to the chain rule:

$$\nabla\sigma_j^2 = \sum_i \nabla\hat{x}_{ij}\,\frac{\partial\hat{x}_{ij}}{\partial\sigma_j^2} = -\frac{1}{2}(\sigma_j^2+\epsilon)^{-3/2}\sum_i \nabla\hat{x}_{ij}\,c_{ij}$$

In matrix form:

$$\nabla\sigma^2 = -\frac{1}{2}(\sigma^2+\epsilon)^{-3/2}\odot\sum_i\left(\nabla\hat{x}\odot c\right)_{i\cdot}$$

where $\odot$ denotes element-wise multiplication and the power of $-3/2$ is applied element-wise. Moving on to $s$, with $\sigma^2$ being its mean over the batch, the gradient can easily be shown to be uniformly spread out from $\nabla\sigma^2$ as follows:

$$\nabla s_{ij} = \frac{1}{n}\nabla\sigma_j^2$$
We are now ready to calculate $\nabla c$. Since $c$ contributes to the loss through both $s$ and $\hat{x}$, its gradient consists of two parts, one coming from $\nabla s$ and the other from $\nabla\hat{x}$. Let's do the first: since the square $s_{ij} = c_{ij}^2$ is applied element-wise, there is no summing in the derivative chain:

$$\nabla c^{(s)}_{ij} = \nabla s_{ij}\cdot 2c_{ij} \quad\Longleftrightarrow\quad \nabla c^{(s)} = 2\,c\odot\nabla s$$

For the part coming from $\hat{x}$, $c_{ij}$ contributes to the loss through only $\hat{x}_{ij}$, so there is also no summing in the chain:

$$\nabla c^{(\hat{x})}_{ij} = \nabla\hat{x}_{ij}\,(\sigma_j^2+\epsilon)^{-1/2}$$

There is no matrix-wide equation this time; however, if we extend the definition of $\odot$ from element-wise to broadcasted multiplication, then:

$$\nabla c^{(\hat{x})} = \nabla\hat{x}\odot(\sigma^2+\epsilon)^{-1/2}$$

Taking the sum of $\nabla c^{(s)}$ and $\nabla c^{(\hat{x})}$, we have:

$$\nabla c = 2\,c\odot\nabla s + \nabla\hat{x}\odot(\sigma^2+\epsilon)^{-1/2}$$
Now for $\nabla\mu$: each entry $\mu_j$ contributes to the loss through the whole column $c_{\cdot j}$, with $\partial c_{ij}/\partial\mu_j = -1$, so:

$$\nabla\mu_j = -\sum_i\nabla c_{ij}$$

$x$ contributes to the loss through $c$ and $\mu$, so its gradient is the sum of two parts. The part corresponding to $\mu$ is analogous to that of $s$ and $\sigma^2$, in the sense that one is the mean of the other over the batch. Therefore we can quickly derive that part to be $\frac{1}{n}\nabla\mu_j$, uniformly spread down each column.

The other part is also simple: as $c_{ij} = x_{ij} - \mu_j$, there is no interaction between $x_{ij}$ and any other entry of $c$, hence $\partial c_{ij}/\partial x_{ij} = 1$ and this part is just $\nabla c$. So finally:

$$\nabla x_{ij} = \nabla c_{ij} + \frac{1}{n}\nabla\mu_j$$
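The whole backward chain above can be transcribed step by step into numpy and verified against a numerical gradient. The function names and the test loss $L = \sum_{ij} y_{ij} w_{ij}$ are my own scaffolding, not from the post's repo:

```python
import numpy as np

def bn_forward_steps(x, gamma, eps=1e-8):
    mu = x.mean(axis=0)                   # batch mean per feature
    c = x - mu                            # centered input
    s = c * c                             # element-wise square
    var = s.mean(axis=0)                  # sigma^2
    x_hat = c * (var + eps) ** -0.5
    return gamma * x_hat, (c, var, x_hat)

def bn_backward_steps(g_y, gamma, cache, eps=1e-8):
    c, var, x_hat = cache
    n = c.shape[0]
    g_xhat = g_y * gamma                                        # through y
    g_var = -0.5 * (var + eps) ** -1.5 * (g_xhat * c).sum(axis=0)
    g_s = np.ones_like(c) * g_var / n                           # spread from the mean
    g_c = 2 * c * g_s + g_xhat * (var + eps) ** -0.5            # two parts summed
    g_mu = -g_c.sum(axis=0)
    g_x = g_c + g_mu / n
    g_gamma = (g_y * x_hat).sum(axis=0)
    return g_x, g_gamma

rng = np.random.default_rng(0)
x = rng.normal(size=(8, 4))
gamma = rng.normal(size=4)
w = rng.normal(size=(8, 4))               # dL/dy for the loss L = sum(y * w)

y, cache = bn_forward_steps(x, gamma)
g_x, g_gamma = bn_backward_steps(w, gamma, cache)

# Central-difference check of g_x.
num, h = np.zeros_like(x), 1e-5
for idx in np.ndindex(*x.shape):
    xp, xm = x.copy(), x.copy()
    xp[idx] += h
    xm[idx] -= h
    num[idx] = ((bn_forward_steps(xp, gamma)[0] * w).sum()
                - (bn_forward_steps(xm, gamma)[0] * w).sum()) / (2 * h)
print(np.abs(g_x - num).max())            # tiny: analytic matches numerical
```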
Piece the pieces together, then simplify
Remember that the goal is to derive $\nabla x$ and $\nabla\gamma$; we'll do it now using the results derived above. For $\gamma$, each $\gamma_j$ contributes to the loss through the whole column $y_{\cdot j}$, so:

$$\nabla\gamma_j = \sum_i\nabla y_{ij}\,\hat{x}_{ij}$$

For $\nabla x$, substitute the pieces back, writing $r_j = (\sigma_j^2+\epsilon)^{-1/2}$ and noting that $c_{ij}\,r_j = \hat{x}_{ij}$:

$$2c_{ij}\nabla s_{ij} = \frac{2c_{ij}}{n}\nabla\sigma_j^2 = -\frac{c_{ij}\,r_j^3}{n}\sum_k\nabla\hat{x}_{kj}\,c_{kj} = -\frac{r_j}{n}\,\hat{x}_{ij}\sum_k\nabla\hat{x}_{kj}\,\hat{x}_{kj}$$

so that, using $\nabla\hat{x}_{ij} = \nabla y_{ij}\,\gamma_j$:

$$\nabla c_{ij} = r_j\left(\nabla\hat{x}_{ij} - \frac{1}{n}\hat{x}_{ij}\sum_k\nabla\hat{x}_{kj}\,\hat{x}_{kj}\right) = \gamma_j\,r_j\left(\nabla y_{ij} - \frac{1}{n}\hat{x}_{ij}\,\nabla\gamma_j\right)$$

and, since $\nabla\mu_j = -\sum_k\nabla c_{kj}$:

$$\nabla x_{ij} = \nabla c_{ij} - \frac{1}{n}\sum_k\nabla c_{kj}$$

Interestingly, this last step is the action of centering $\nabla c$ around zero, precisely what the subtraction of $\mu$ did to $x$.
We are done here. For efficient computation, in the forward pass we will save the values of $\hat{x}$ and $(\sigma^2+\epsilon)^{-1/2}$. This adds nothing to the computational complexity of the forward pass.
```python
def forward(self, x, gamma, is_training):
    if is_training:
        # Use batch statistics and update the moving averages.
        mean = x.mean(self._fd)
        var = x.var(self._fd)
        self._mv_mean.apply_update(mean)
        self._mv_var.apply_update(var)
    else:
        # At inference time, use the moving averages instead.
        mean = self._mv_mean.val
        var = self._mv_var.val
    self._rstd = 1. / np.sqrt(var + 1e-8)
    self._normed = (x - mean) * self._rstd
    self._gamma = gamma
    return self._normed * gamma
```
```python
def backward(self, grad):
    # Number of elements normalized over (the batch axes).
    N = np.prod(grad.shape[:-1])
    g_gamma = np.multiply(grad, self._normed)
    g_gamma = g_gamma.sum(self._fd)
    # grad_x = rstd * gamma * (grad - normed * g_gamma / N), then centered.
    x_ = grad - self._normed * g_gamma * 1. / N
    x_ = self._rstd * self._gamma * x_
    return x_ - x_.mean(self._fd), g_gamma
```
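To see `forward` in action across training and inference, here is a minimal harness; the `MovingAverage` helper and the class wiring are my own stubs guessed from the code above, not the repo's actual API:

```python
import numpy as np

class MovingAverage:
    """Stub exponential moving average (assumed behavior, not the repo's class)."""
    def __init__(self, momentum=0.9):
        self.val, self._m = None, momentum
    def apply_update(self, new):
        self.val = new if self.val is None else self._m * self.val + (1 - self._m) * new

class BatchNorm:
    def __init__(self, fd=(0,)):
        self._fd = fd                      # axes to normalize over
        self._mv_mean, self._mv_var = MovingAverage(), MovingAverage()

    def forward(self, x, gamma, is_training):  # same logic as the post's forward
        if is_training:
            mean, var = x.mean(self._fd), x.var(self._fd)
            self._mv_mean.apply_update(mean)
            self._mv_var.apply_update(var)
        else:
            mean, var = self._mv_mean.val, self._mv_var.val
        self._rstd = 1. / np.sqrt(var + 1e-8)
        self._normed = (x - mean) * self._rstd
        self._gamma = gamma
        return self._normed * gamma

rng = np.random.default_rng(0)
bn, gamma = BatchNorm(), np.ones(3)

# Training: batch statistics are used and the moving averages accumulate.
for _ in range(300):
    bn.forward(rng.normal(loc=2.0, size=(256, 3)), gamma, is_training=True)

# Inference: the accumulated statistics replace the batch statistics.
y = bn.forward(rng.normal(loc=2.0, size=(256, 3)), gamma, is_training=False)
print(bn._mv_mean.val)      # close to the true mean, 2.0
print(y.mean(axis=0))       # close to 0: inference output is normalized
```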
The code is taken from a GitHub repo of mine where I am building something similar to an auto-diff DAG graph. Visit if you are interested. I conclude the post here.