
Batch normalization overview

Batch normalization is a technique that stabilizes the input distribution of intermediate layers in a deep network during training. This reduces the effect of information morphing and thus helps speed up training (especially during the first few steps). Batch norm is also claimed to reduce the need for dropout thanks to its regularizing effect. As the name Normalization suggests, you simply normalize the output of the layer to zero mean and unit variance:

$$\hat{x} = \text{Norm}(x, \mathcal{X})$$

This requires expensive computation of $\text{Cov}[x]$ and its inverse square root $\text{Cov}[x]^{-1/2}$, so an approximation over each mini-batch during training is proposed instead. For a layer with input $\mathcal{B} = \{x_{1..m}\}$, besides its original parameters to learn, batch normalization introduces two other learnable parameters, $\gamma$ and $\beta$. The forward pass proceeds as follows:

$$\mu_B \leftarrow \frac{1}{m}\sum_{i=1}^{m} x_i$$
$$\sigma_B^2 \leftarrow \frac{1}{m}\sum_{i=1}^{m} (x_i - \mu_B)^2$$
$$\hat{x}_i \leftarrow \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}$$
$$y_i \leftarrow \gamma \hat{x}_i + \beta \equiv \text{BN}_{\gamma,\beta}(x_i)$$
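
For concreteness, here is a minimal numpy sketch of these four steps. The function name batchnorm_forward, the batch shape (64, 8), and the eps value are illustrative choices of mine, not taken from any particular library:

import numpy as np

def batchnorm_forward(x, gamma, beta, eps=1e-8):
    # x: mini-batch of shape (m, f); gamma, beta: per-feature parameters of shape (f,)
    mu = x.mean(axis=0)                      # mu_B
    var = x.var(axis=0)                      # sigma_B^2
    x_hat = (x - mu) / np.sqrt(var + eps)    # normalize
    return gamma * x_hat + beta              # scale and shift

# the normalized activations have (approximately) zero mean and unit variance
x = np.random.randn(64, 8) * 3. + 5.
y = batchnorm_forward(x, np.ones(8), np.zeros(8))
print(y.mean(axis=0), y.var(axis=0))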

Goal

The above equations describe how BN handles scalar inputs. Real implementations, however, deal with much higher-dimensional vectors and matrices. As an example, the volume flowing through a convolutional layer is 4-dimensional, and we take the normalization steps over each feature map separately.

Let the volume of interest be $k$-dimensional, with normalization taken over the first $k-1$ dimensions. I will derive a sequence of vectorized operations on $\nabla_y \text{Loss}$ - the gradient propagated back to this layer - to produce $\nabla_x \text{Loss}$ and $\nabla_\gamma \text{Loss}$. $\nabla_\beta \text{Loss}$ will not be considered, as it can be cast as a bias-adding operation and does not belong to the atomic view of batch normalization.

Notation

For ease of notation, I denote $\nabla_x \text{Loss}$ as $\delta x$ and consider the case of one-dimensional vectors: $x_{ij}$ is the $j$th feature of the $i$th training example in the batch. The design matrix is then $x$ of size $n \times f$, where $n$ is the batch size and $f$ is the number of features. Although we are limiting the analysis to one-dimensional vectors, the later code is applicable to an arbitrarily larger number of dimensions.

For example, a 4-dimensional volume with shape $(n, h, w, f)$ can be treated as $(nhw, f)$ after reshaping. Thanks to numpy's broadcasting, the reshaping is not even necessary.
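
The following quick check (an illustration of mine, not code from the repo) confirms that normalizing a (n, h, w, f) volume over its first three axes gives the same result as reshaping to (n*h*w, f) and normalizing over the batch axis:

import numpy as np

n, h, w, f = 2, 3, 3, 4
x = np.random.randn(n, h, w, f)

# normalize each feature map over the batch and spatial axes
mu = x.mean(axis=(0, 1, 2))
var = x.var(axis=(0, 1, 2))
normed_4d = (x - mu) / np.sqrt(var + 1e-8)

# same thing after flattening to (n*h*w, f)
x2 = x.reshape(-1, f)
normed_2d = (x2 - x2.mean(0)) / np.sqrt(x2.var(0) + 1e-8)

assert np.allclose(normed_4d.reshape(-1, f), normed_2d)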

A computational graph view of Batch Normalization

First, we break the process into simpler parts.

Namely,

$$m = \frac{1}{n}\sum_{i=1}^{n} x_i$$
$$\bar{x}_i = x_i - m$$
$$v = \frac{1}{n}\sum_{i=1}^{n} \bar{x}_i^2$$
$$\hat{x}_{ij} = \frac{\bar{x}_{ij}}{\sqrt{v_j}}$$
$$x^{BN} = \gamma \odot \hat{x}$$
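
A direct numpy translation of this decomposition might look as follows (bn_forward_graph is a hypothetical helper name of mine; epsilon is omitted, exactly as in the equations above):

import numpy as np

def bn_forward_graph(x, gamma):
    # forward pass broken into the intermediate nodes above; x has shape (n, f)
    m = x.mean(axis=0)             # m     = (1/n) sum_i x_i
    x_bar = x - m                  # x_bar = x - m
    v = (x_bar ** 2).mean(axis=0)  # v     = (1/n) sum_i x_bar_i^2  (biased variance)
    x_hat = x_bar / np.sqrt(v)     # x_hat = x_bar / sqrt(v)
    x_bn = gamma * x_hat           # x_BN  = gamma ⊙ x_hat
    return x_bn, (x_bar, v, x_hat)

x_bn, (x_bar, v, x_hat) = bn_forward_graph(np.random.randn(10, 3), np.ones(3))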

In our setting, $\delta x^{BN}$ is available. Gradients flow backwards, so we first consider $\delta \hat{x}$. Each entry $\hat{x}_{ij}$ contributes to the loss only through $x^{BN}_{ij}$, so according to the chain rule:

$$\delta \hat{x}_{ij} = \delta x^{BN}_{ij} \, \partial x^{BN}_{ij} / \partial \hat{x}_{ij} = \gamma_j \, \delta x^{BN}_{ij}$$
$$\delta \hat{x} = \gamma \odot \delta x^{BN}$$

Next, we can do either $v$ or $\bar{x}$; $v$ is simpler since it contributes to the loss only through $\hat{x}$, as shown in the graph (while $\bar{x}$ also contributes to the loss through $v$). Consider a single entry $v_j$: it contributes to the loss through $\hat{x}_{ij}$ for all $i$, so according to the chain rule:

$$\delta v_j = \sum_i \delta \hat{x}_{ij} \, \frac{\partial \hat{x}_{ij}}{\partial v_j} = \sum_i \delta \hat{x}_{ij} \, \frac{\partial (\bar{x}_{ij} v_j^{-1/2})}{\partial v_j} = -\frac{1}{2} v_j^{-3/2} \sum_i \delta \hat{x}_{ij} \bar{x}_{ij}$$
$$\delta v_j = -\frac{1}{2} v_j^{-3/2} \left(\sum_i \delta \hat{x}_i \odot \bar{x}_i\right)_j$$

Where $\odot$ denotes element-wise multiplication and the power of $-3/2$ is applied element-wise. Moving on to $\bar{x}^2$, with $v$ being its mean, the gradient can easily be shown to be uniformly spread out from $v$ as follows:

$$\delta \bar{x}^2_i = \frac{1}{n} \delta v$$
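
In numpy, the backward steps derived so far could be sketched like this (the variable names and the random test data are mine; d_xbn stands for $\delta x^{BN}$):

import numpy as np

n, f = 32, 4
x = np.random.randn(n, f)
gamma = np.random.randn(f)
d_xbn = np.random.randn(n, f)       # incoming gradient, delta x_BN

# forward intermediates (as in the graph above)
m = x.mean(0)
x_bar = x - m
v = (x_bar ** 2).mean(0)
x_hat = x_bar / np.sqrt(v)

# backward steps derived so far
d_xhat = gamma * d_xbn                               # delta x_hat = gamma ⊙ delta x_BN
d_v = -0.5 * v ** -1.5 * (d_xhat * x_bar).sum(0)     # delta v
d_xbar_sq = np.ones((n, f)) * d_v / n                # delta x_bar^2 = delta v / n, spread over rows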

We are now ready to calculate $\delta \bar{x}$. Since $\bar{x}$ contributes to the loss through $\hat{x}$ and $\bar{x}^2$, its gradient consists of two parts, one coming from $\hat{x}$ and the other from $\bar{x}^2$. Let's do the $\bar{x}^2$ part first; since the square is applied element-wise, there is no summing in the derivative chain:

$$\delta_{\bar{x}^2} \bar{x}_{ij} = \delta \bar{x}^2_{ij} \, \partial \bar{x}^2_{ij} / \partial \bar{x}_{ij} = 2 \, \delta \bar{x}^2_{ij} \, \bar{x}_{ij}$$
$$\delta_{\bar{x}^2} \bar{x} = 2 \, \delta \bar{x}^2 \odot \bar{x}$$

For $\hat{x}$, $\bar{x}_{ij}$ contributes to the loss only through $\hat{x}_{ij}$, so there is also no summing in the chain:

$$\delta_{\hat{x}} \bar{x}_{ij} = \delta \hat{x}_{ij} / v_j^{1/2}$$

There is no matrix-wide equation this time; however, if we extend the definition of $\odot$ from element-wise to broadcasted multiplication, then:

$$\delta_{\hat{x}} \bar{x} = v^{-1/2} \odot \delta \hat{x}$$

Taking the sum of $\delta_{\bar{x}^2} \bar{x}$ and $\delta_{\hat{x}} \bar{x}$, we have

$$\delta \bar{x} = 2 \, \delta \bar{x}^2 \odot \bar{x} + v^{-1/2} \odot \delta \hat{x}$$

Now for $m$: each entry $m_j$ contributes to the loss through the whole $j$th column of $\bar{x}$, so:

$$\delta m_j = \sum_i \delta \bar{x}_{ij} \, \partial \bar{x}_{ij} / \partial m_j = \sum_i \delta \bar{x}_{ij} \, \partial (x_{ij} - m_j) / \partial m_j = -\sum_i \delta \bar{x}_{ij}$$
$$\delta m = -\sum_i \delta \bar{x}_i$$

$x$ contributes to the loss through $m$ and $\bar{x}$, so its gradient is the sum of two parts. The part corresponding to $m$ is analogous to that of $\bar{x}^2$ and $v$, in the sense that one is the batch mean of the other. Therefore we can quickly derive that part to be $\delta_m x = \frac{1}{n} \delta m$.

The other part is also simple: since $\bar{x} = x - m$, there is no interaction between $x$ and $m$, hence $\delta_{\bar{x}} x = \delta \bar{x}$. So finally $\delta x = \delta \bar{x} + \frac{1}{n} \delta m$.
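
Collecting all of the node-by-node steps, a sketch of the full graph backward pass could look like this (bn_backward_graph is a hypothetical name of mine; it recomputes the forward intermediates for brevity rather than caching them):

import numpy as np

def bn_backward_graph(d_xbn, x, gamma):
    # step-by-step backward pass, mirroring the computational graph above
    # x, d_xbn: (n, f); gamma: (f,); returns delta x of shape (n, f)
    n = x.shape[0]
    # forward intermediates
    m = x.mean(0)
    x_bar = x - m
    v = (x_bar ** 2).mean(0)
    x_hat = x_bar / np.sqrt(v)
    # backward, node by node
    d_xhat = gamma * d_xbn                                # delta x_hat
    d_v = -0.5 * v ** -1.5 * (d_xhat * x_bar).sum(0)      # delta v
    d_xbar_sq = np.ones((n, 1)) * d_v / n                 # delta x_bar^2 (uniform spread)
    d_xbar = 2 * d_xbar_sq * x_bar + d_xhat / np.sqrt(v)  # two contributions to delta x_bar
    d_m = -d_xbar.sum(0)                                  # delta m
    return d_xbar + d_m / n                               # delta x = delta x_bar + delta m / n

dx = bn_backward_graph(np.random.randn(10, 3), np.random.randn(10, 3), np.ones(3))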

Put the pieces together, then simplify

Remember that the goal is to derive $\delta x$; we'll do it now using the results derived above:

$$\delta x = \delta \bar{x} + \frac{1}{n} \delta m$$
$$= \delta \bar{x} - \frac{1}{n} \sum_i \delta \bar{x}_i$$

Interestingly, this is the action of centering $\delta \bar{x}$ around zero, precisely what the $\bar{x}$ step did to $x$.

$$\delta \bar{x} = \delta \hat{x} \odot v^{-1/2} + 2 \, \delta \bar{x}^2 \odot \bar{x}$$
$$= \gamma \odot v^{-1/2} \odot \delta x^{BN} + \frac{2}{n} \delta v \odot \bar{x}$$
$$= \gamma \odot v^{-1/2} \odot \delta x^{BN} - \frac{1}{n} v^{-3/2} \odot \sum_i (\delta \hat{x} \odot \bar{x})_i \odot \bar{x}$$
$$= \gamma \odot v^{-1/2} \odot \delta x^{BN} - \frac{1}{n} \gamma \odot v^{-3/2} \odot \sum_i (\delta x^{BN} \odot \bar{x})_i \odot \bar{x}$$
$$= \gamma \odot v^{-1/2} \odot \left(\delta x^{BN} - \frac{1}{n} \sum_i (\delta x^{BN} \odot \bar{x} \odot v^{-1/2})_i \odot \bar{x} \odot v^{-1/2}\right)$$
$$= \gamma \odot v^{-1/2} \odot \left(\delta x^{BN} - \frac{1}{n} \sum_i (\delta x^{BN} \odot \hat{x})_i \odot \hat{x}\right)$$

We are done here. For efficient computation, in the forward pass we save the values of $v^{-1/2}$ and $\hat{x}$; this adds nothing to the computational complexity of the forward pass.
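
As a sanity check (my own addition, not part of the original repo), the simplified formula can be compared against a central-difference numerical gradient of the loss $L = \sum_{ij} w_{ij} x^{BN}_{ij}$ for a random weight $w$, so that $\delta x^{BN} = w$:

import numpy as np

def bn_forward(x, gamma):
    # forward pass, also returning x_hat; biased variance as in the derivation
    x_hat = (x - x.mean(0)) / np.sqrt(x.var(0) + 1e-8)
    return gamma * x_hat, x_hat

def bn_backward(d_xbn, x_hat, rstd, gamma):
    # delta x_bar = gamma ⊙ v^{-1/2} ⊙ (delta x_BN - mean_i(delta x_BN ⊙ x_hat) ⊙ x_hat)
    d_xbar = gamma * rstd * (d_xbn - (d_xbn * x_hat).mean(0) * x_hat)
    # delta x = delta x_bar centered around zero
    return d_xbar - d_xbar.mean(0)

np.random.seed(0)
n, f = 16, 3
x = np.random.randn(n, f)
gamma = np.random.randn(f)
w = np.random.randn(n, f)          # loss = sum(w * x_BN), so delta x_BN = w

y, x_hat = bn_forward(x, gamma)
rstd = 1. / np.sqrt(x.var(0) + 1e-8)
analytic = bn_backward(w, x_hat, rstd, gamma)

# central-difference numerical gradient of the same loss w.r.t. x
numeric = np.zeros_like(x)
eps = 1e-5
for idx in np.ndindex(*x.shape):
    xp, xm = x.copy(), x.copy()
    xp[idx] += eps
    xm[idx] -= eps
    numeric[idx] = ((bn_forward(xp, gamma)[0] * w).sum() -
                    (bn_forward(xm, gamma)[0] * w).sum()) / (2 * eps)

print(np.abs(analytic - numeric).max())   # should be tiny, on the order of 1e-7 or less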

The code

Forward pass

def forward(self, x, gamma, is_training):
    # self._fd: the axes to normalize over (all axes except the feature axis)
    if is_training:
        # use the statistics of the current mini-batch
        mean = x.mean(self._fd)
        var = x.var(self._fd)
        # update the moving averages used at inference time
        self._mv_mean.apply_update(mean)
        self._mv_var.apply_update(var)
    else:
        # at inference time, use the accumulated moving averages
        mean = self._mv_mean.val
        var  = self._mv_var.val

    # cache v^{-1/2} and x_hat for the backward pass
    self._rstd = 1. / np.sqrt(var + 1e-8)
    self._normed = (x - mean) * self._rstd
    self._gamma = gamma
    return self._normed * gamma

Backward pass

def backward(self, grad):
    # grad is delta x_BN; N is the number of entries normalized per feature
    N = np.prod(grad.shape[:-1])
    # gamma gradient: sum of delta x_BN ⊙ x_hat over the normalized axes
    g_gamma = np.multiply(grad, self._normed)
    g_gamma = g_gamma.sum(self._fd)
    # delta x_bar = gamma * v^{-1/2} * (delta x_BN - mean_i(delta x_BN ⊙ x_hat) ⊙ x_hat)
    x_ = grad - self._normed * g_gamma * 1. / N
    x_ = self._rstd * self._gamma * x_
    # delta x = delta x_bar centered around zero
    return x_ - x_.mean(self._fd), g_gamma

The code is taken from a GitHub repo of mine where I am building something similar to an auto-diff DAG graph. Visit it if you are interested. I conclude the post here.

 
