- Batch normalization overview
- A computational graph view of Batch Normalization
- Putting the pieces together, then simplifying
- The code
Batch normalization overview
Batch normalization is a technique that stabilizes the input distribution of intermediate layers of a deep network during training. This reduces the effect of information morphing and thus helps speed up training, especially during the first few training steps. Batch norm is also claimed to reduce the need for dropout thanks to its regularizing effect. As the name Normalization suggests, you simply normalize the output of the layer to zero mean and unit variance:

$$
\hat{x} = \mathrm{Cov}[x]^{-1/2}\left(x - \mathrm{E}[x]\right)
$$
This requires expensive computation of $\mathrm{Cov}[x]$ and its inverse square root $\mathrm{Cov}[x]^{-1/2}$, so an approximation over each mini-batch during training is proposed instead. For a layer with input $\mathcal{B} = \{x_{1..m}\}$, besides its original parameters to learn, Batch Normalization introduces two extra learnable parameters, $\gamma$ and $\beta$. The forward pass proceeds as follows:
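Concretely, using the mini-batch statistics (with a small constant $\epsilon$ added for numerical stability), this is the standard transform from the original Batch Normalization paper:

$$
\mu_\mathcal{B} = \frac{1}{m}\sum_{i=1}^{m} x_i, \qquad
\sigma^2_\mathcal{B} = \frac{1}{m}\sum_{i=1}^{m}\left(x_i - \mu_\mathcal{B}\right)^2, \qquad
\hat{x}_i = \frac{x_i - \mu_\mathcal{B}}{\sqrt{\sigma^2_\mathcal{B} + \epsilon}}, \qquad
y_i = \gamma\,\hat{x}_i + \beta
$$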
Goal
The above equations describe how BN handles scalar inputs. Real implementations, however, deal with much higher-dimensional vectors and matrices. For example, the activation volume flowing through a convolutional layer is 4-dimensional, and the normalization is taken over each feature map separately.
Let the volume of interest be $k$-dimensional and suppose we normalize over the first $k-1$ dimensions. I will derive a sequence of vectorized operations on $\nabla_y \mathrm{Loss}$ (the gradient propagated back to this layer) that produce $\nabla_x \mathrm{Loss}$ and $\nabla_\gamma \mathrm{Loss}$. $\nabla_\beta \mathrm{Loss}$ will not be considered, since adding $\beta$ can be cast as a bias-adding operation and does not belong to the atomic view of batch normalization.
Notation
For ease of notation, I write $\nabla_x \mathrm{Loss}$ as $\delta x$ and consider the one-dimensional case: $x_{ij}$ is the $j$-th feature of the $i$-th training example in the batch. The design matrix is then $x$ of size $n \times f$, where $n$ is the batch size and $f$ is the number of features. Although we limit the analysis to one-dimensional feature vectors, the code given later applies to an arbitrarily larger number of dimensions.
For example, a 4-dimensional volume with shape $(n, h, w, f)$ can be viewed as $(n \cdot h \cdot w, f)$ after reshaping. Thanks to numpy broadcasting, the reshape is not even necessary.
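As a quick illustration (a minimal numpy sketch with made-up shapes), normalizing a 4-D volume over its first three axes gives the same per-feature statistics as reshaping it to 2-D first, and broadcasting handles the rest:

import numpy as np

x = np.random.randn(8, 4, 4, 16)                  # (n, h, w, f) activation volume

# Per-feature statistics taken over the first three axes...
mean_4d = x.mean(axis=(0, 1, 2))                  # shape (16,)
var_4d = x.var(axis=(0, 1, 2))                    # shape (16,)

# ...match those of the reshaped (n*h*w, f) design matrix.
x2d = x.reshape(-1, 16)
assert np.allclose(mean_4d, x2d.mean(axis=0))
assert np.allclose(var_4d, x2d.var(axis=0))

# Broadcasting lets us normalize without reshaping at all.
normed = (x - mean_4d) / np.sqrt(var_4d + 1e-8)   # shape (8, 4, 4, 16)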
A computational graph view of Batch Normalization
First, we break the process into a graph of simpler operations. Namely,

$$
m = \frac{1}{n}\sum_{i=1}^{n} x_i, \qquad
\bar{x} = x - m, \qquad
v = \frac{1}{n}\sum_{i=1}^{n} \bar{x}_i^2, \qquad
x^* = \bar{x}\, v^{-1/2}, \qquad
x^{BN} = \gamma\, x^*
$$

where $m$ and $v$ are the row vectors of per-feature means and variances, and the element-wise square $\bar{x}^2$ (whose per-feature mean is $v$) is treated as its own node in the graph. In our setting $\delta x^{BN}$ is available. Gradients flow backwards, so we first consider $\delta x^*$. Each entry $x^*_{ij}$ contributes to the loss only through $x^{BN}_{ij}$, so according to the chain rule:
$$
\delta x^*_{ij} = \delta x^{BN}_{ij}\,\frac{\partial x^{BN}_{ij}}{\partial x^*_{ij}} = \gamma\, \delta x^{BN}_{ij}
$$

Next, we can do either $v$ or $\bar{x}$; $v$ is simpler since it contributes to the loss only through $x^*$ as shown in the graph (while $\bar{x}$ also contributes to the loss through $v$). Consider a single entry $v_j$: it contributes to the loss through $x^*_{ij}$ for all $i$, so according to the chain rule:
$$
\delta v_j
= \sum_i \delta x^*_{ij}\,\frac{\partial x^*_{ij}}{\partial v_j}
= \sum_i \delta x^*_{ij}\,\frac{\partial\left(\bar{x}_{ij}\, v_j^{-1/2}\right)}{\partial v_j}
= -\frac{1}{2}\, v_j^{-3/2} \sum_i \delta x^*_{ij}\, \bar{x}_{ij}
$$

or, in matrix form,

$$
\delta v = -\frac{1}{2}\, v^{-3/2} \odot \sum_i \left(\delta x^* \odot \bar{x}\right)_i
$$

where $\odot$ denotes element-wise multiplication and the power of $-3/2$ is applied element-wise. Moving on to $\bar{x}^2$, with $v$ being its mean over the batch dimension, the gradient can easily be shown to be spread out uniformly from $v$ as follows:
$$
\delta \bar{x}^2_i = \frac{1}{n}\,\delta v \quad \text{for every row } i
$$

We are now ready to calculate $\delta \bar{x}$. Since $\bar{x}$ contributes to the loss through both $x^*$ and $\bar{x}^2$, its gradient consists of two parts: one coming from $x^*$ and the other from $\bar{x}^2$. Let's do the $\bar{x}^2$ part first; since the square is applied element-wise, there is no summing in the derivative chain:
$$
\delta_{\bar{x}^2}\,\bar{x}_{ij} = \delta \bar{x}^2_{ij}\,\frac{\partial \bar{x}^2_{ij}}{\partial \bar{x}_{ij}} = 2\,\delta \bar{x}^2_{ij}\,\bar{x}_{ij}
$$

For $x^*$, $\bar{x}_{ij}$ contributes to the loss only through $x^*_{ij}$, so there is also no summing in the chain:
$$
\delta_{x^*}\,\bar{x}_{ij} = \delta x^*_{ij}\, v_j^{-1/2}
$$

There is no matrix-wide equation this time; however, if we extend the definition of $\odot$ from element-wise to broadcasted multiplication, then:
$$
\delta_{x^*}\,\bar{x} = v^{-1/2} \odot \delta x^*
$$

Taking the sum of $\delta_{\bar{x}^2}\,\bar{x}$ and $\delta_{x^*}\,\bar{x}$, we have:
$$
\delta \bar{x} = 2\,\delta \bar{x}^2 \odot \bar{x} + v^{-1/2} \odot \delta x^*
$$

Now for $m$: each entry $m_j$ contributes to the loss through the whole $j$-th column of $\bar{x}$, so:
$$
\delta m_j
= \sum_i \delta \bar{x}_{ij}\,\frac{\partial \bar{x}_{ij}}{\partial m_j}
= \sum_i \delta \bar{x}_{ij}\,\frac{\partial \left(x_{ij} - m_j\right)}{\partial m_j}
= -\sum_i \delta \bar{x}_{ij}
$$

The input $x$ contributes to the loss through $m$ and $\bar{x}$, so its gradient is the sum of two parts. The part corresponding to $m$ is analogous to that of $\bar{x}^2$ and $v$, in the sense that one is the mean of the other over the batch dimension. Therefore we can quickly derive that part to be $\delta_m x = \frac{1}{n}\,\delta m$.
The other part is also simple: since $\bar{x} = x - m$ involves no further interaction between $x$ and $m$, we have $\delta_{\bar{x}}\, x = \delta \bar{x}$. So finally $\delta x = \delta \bar{x} + \frac{1}{n}\,\delta m$.
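To make the mapping from math to code concrete, here is a minimal numpy sketch (my own, not the repo's code) of the backward pass computed node by node, exactly as derived above and before any simplification. A small eps is added inside the square root, as the implementation below also does.

import numpy as np

def batchnorm_backward_naive(dxbn, x, gamma, eps=1e-8):
    # Recompute the forward-pass nodes of the graph.
    n = x.shape[0]
    m = x.mean(axis=0)                                    # per-feature mean
    xbar = x - m                                          # centered input
    v = (xbar ** 2).mean(axis=0)                          # per-feature variance
    rstd = 1. / np.sqrt(v + eps)                          # v^{-1/2}
    xstar = xbar * rstd                                   # normalized input

    # Walk the graph backwards, one node at a time.
    dxstar = gamma * dxbn                                 # delta x*
    dv = -0.5 * rstd ** 3 * (dxstar * xbar).sum(axis=0)   # delta v
    dxbar2 = np.ones_like(x) * dv / n                     # delta xbar^2, spread uniformly
    dxbar = 2. * dxbar2 * xbar + rstd * dxstar            # two parts of delta xbar
    dm = -dxbar.sum(axis=0)                               # delta m
    dx = dxbar + dm / n                                   # delta x
    dgamma = (dxbn * xstar).sum(axis=0)                   # delta gamma
    return dx, dgamma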
Putting the pieces together, then simplifying
Recall that the goal is to derive $\delta x$; we'll do it now using the results derived above:

$$
\delta x = \delta \bar{x} + \frac{1}{n}\,\delta m = \delta \bar{x} - \frac{1}{n}\sum_i \delta \bar{x}_i
$$

Interestingly, this is the action of centering $\delta \bar{x}$ around zero, precisely what the forward pass did to $x$ to produce $\bar{x}$. As for $\delta \bar{x}$ itself:

$$
\delta \bar{x} = v^{-1/2} \odot \delta x^* + 2\,\delta \bar{x}^2 \odot \bar{x}
$$

We are done here. For efficient computation, in the forward pass we will save the values of $v^{-1/2}$ and $x^*$. This adds nothing to the computational complexity of the forward pass.
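To see how this turns into the compact backward code below, substitute $\delta \bar{x}^2 = \frac{1}{n}\,\delta v$ and $\delta v = -\frac{1}{2}\, v^{-3/2} \odot \sum_i \left(\delta x^* \odot \bar{x}\right)_i$ into the expression for $\delta \bar{x}$, and use $\bar{x} = v^{1/2} \odot x^*$ together with $\delta x^* = \gamma\, \delta x^{BN}$:

$$
\delta \bar{x}
= v^{-1/2} \odot \left( \delta x^* - \frac{1}{n}\, x^* \odot \sum_i \left(\delta x^* \odot x^*\right)_i \right)
= \gamma\, v^{-1/2} \odot \left( \delta x^{BN} - \frac{1}{n}\, x^* \odot \delta\gamma \right)
$$

where $\delta\gamma = \sum_i \left(\delta x^{BN} \odot x^*\right)_i$ is the gradient with respect to $\gamma$ promised at the start (g_gamma in the code), and the final $\delta x$ is obtained by centering $\delta \bar{x}$ over the batch. This is exactly what the backward pass below computes.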
The code
Forward pass
def forward(self, x, gamma, is_training):
    if is_training:
        # Use the statistics of the current mini-batch and
        # update the moving averages kept for inference.
        mean = x.mean(self._fd)
        var = x.var(self._fd)
        self._mv_mean.apply_update(mean)
        self._mv_var.apply_update(var)
    else:
        # At inference time, fall back to the moving averages.
        mean = self._mv_mean.val
        var = self._mv_var.val
    # Save v^{-1/2} and x* for the backward pass; self._fd holds the
    # axes being normalized over (everything except the feature axis).
    self._rstd = 1. / np.sqrt(var + 1e-8)
    self._normed = (x - mean) * self._rstd
    self._gamma = gamma
    return self._normed * gamma
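The moving-average objects self._mv_mean and self._mv_var are defined elsewhere in the repo and not shown here. A minimal stand-in, assuming a simple exponential moving average with a made-up momentum value, could look like this:

class MovingAverage(object):
    # Hypothetical stand-in for the repo's moving-average helper;
    # the real class may differ.
    def __init__(self, shape, momentum=0.9):
        self.val = np.zeros(shape)
        self._momentum = momentum

    def apply_update(self, new_val):
        # Exponential moving average of the mini-batch statistics.
        self.val = self._momentum * self.val + (1. - self._momentum) * new_val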
Backward pass
def backward(self, grad):
    # Number of entries averaged over per feature:
    # every axis except the last (feature) axis.
    N = np.prod(grad.shape[:-1])
    # Gradient w.r.t. gamma: sum of grad * x* over the normalized axes.
    g_gamma = np.multiply(grad, self._normed)
    g_gamma = g_gamma.sum(self._fd)
    # delta x_bar = gamma * v^{-1/2} * (grad - x* * g_gamma / N)
    x_ = grad - self._normed * g_gamma * 1. / N
    x_ = self._rstd * self._gamma * x_
    # delta x: center delta x_bar around zero over the batch.
    return x_ - x_.mean(self._fd), g_gamma
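As a quick sanity check (my own sketch, not part of the repo), the same forward/backward pair can be written functionally and compared against a numerical gradient of the scalar loss $\sum w \odot x^{BN}$ for a random weight matrix $w$:

import numpy as np

def bn_forward(x, gamma, eps=1e-8):
    # Training-mode forward pass: normalize over every axis but the last.
    fd = tuple(range(x.ndim - 1))
    rstd = 1. / np.sqrt(x.var(fd) + eps)
    normed = (x - x.mean(fd)) * rstd
    return normed * gamma, normed, rstd

def bn_backward(grad, normed, rstd, gamma):
    # Same computation as backward() above, written without the class state.
    fd = tuple(range(grad.ndim - 1))
    N = np.prod(grad.shape[:-1])
    g_gamma = (grad * normed).sum(fd)
    x_ = rstd * gamma * (grad - normed * g_gamma / N)
    return x_ - x_.mean(fd), g_gamma

np.random.seed(0)
x = np.random.randn(6, 5)
gamma = np.random.randn(5)
w = np.random.randn(6, 5)                       # plays the role of the incoming gradient

out, normed, rstd = bn_forward(x, gamma)
dx, dgamma = bn_backward(w, normed, rstd, gamma)

# Central-difference numerical gradient of sum(w * out) with respect to x.
num_dx = np.zeros_like(x)
h = 1e-5
for idx in np.ndindex(*x.shape):
    xp, xm = x.copy(), x.copy()
    xp[idx] += h
    xm[idx] -= h
    num_dx[idx] = (np.sum(w * bn_forward(xp, gamma)[0]) -
                   np.sum(w * bn_forward(xm, gamma)[0])) / (2 * h)

print(np.max(np.abs(dx - num_dx)))              # should be vanishingly small (~1e-9)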
The code is taken from a GitHub repo of mine where I am building something similar to an auto-diff DAG graph. Visit it if you are interested. I conclude the post here.