인공지능/PyTorch

[PyTorch] Batch Normalization 와 running_mean / running_var

고등어찌짐 2022. 7. 9. 23:17

 

PyTorch로 구현된 논문 코드 리뷰 중,  모델 레이어의 weight 들의 mean, std 계산을 다음과 같이 구현한 코드를 보았다. 

 

vector = torch.cat([
    torch.mean(self.conv3.bn.running_mean).view(1), torch.std(self.conv3.bn.running_mean).view(1),
    torch.mean(self.conv3.bn.running_var).view(1), torch.std(self.conv3.bn.running_var).view(1),
    torch.mean(self.conv5.bn.running_mean).view(1), torch.std(self.conv5.bn.running_mean).view(1),
    torch.mean(self.conv5.bn.running_var).view(1), torch.std(self.conv5.bn.running_var).view(1)
])

 

여기서 bn 은 다른 클래스에서 선언한 nn.BatchNorm2d인데, nn.BatchNorm2d 에  running_mean과 runnging_var라는 것이 있다는 걸 처음 알았다.  그래서 Batch normalization과 Pytorch의 nn.BatchNorm에 대해 좀 더 정리해보려고 한다. 


Batch Normalization 

Batch Normalization 은 Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift, ( Seguey Ioffe and Christian SzegedyICML'2015 )에서 제안되었다.  해당 논문에서 BatchNormalization 은 internal covariate shift를 줄임으로써, 딥한 네트워크를 학습하는 속도를 높이는 방법으로 제시되었다.  Batch Normalization 은 학습을 더 빠르게 하고 오버 피팅 방지의 목적으로 사용한다.  Batch Normalization 은 무엇을 목적으로 만들어진 걸까? 

 

covariate shift 

깊이가 깊은 네트워크를 학습할 때 역전파가 일어나면서 출력층부터 입력층까지, 레이어의 weight 값들이 바뀌게 된다.  이전 레이어의 파라미터 값이 변경되면 그다음 레이어도 영향을 받으면서, 레이어를 통과할 때마다 레이어 입력값의 distribution 이 계속해서 변하게 되는 것이다.  이런 변화를 covariate shift라고 말한다. 

 

internal covariate shift

모델 학습 시 인풋 데이터가 각 레이어를 통과할 때마다, 혹은 활성화 함수를 통과할 때마다 레이어 입력 값의 distribution 이 바뀌기 때문에 ( covariate shift ) 모델이 안정적으로 학습되지 않을 수 있다. 이걸 internal covariate shift라고 한다. 

 

그러니까, Batch normalization 은 internal covariate shift  문제를 해결하기 위해 내부 노드들의 distribution 이 크게 변하지 않도록  각 배치마다 레이어 입력값들을 표준화한다. 그래서 각 레이어마다 정규화 레이어를 통해 배치 데이터들의  distribution 이 큰 폭으로 바뀌지 않도록 조절한다. 이 과정은 internal covariate shift를 줄여서 그래디언트가 이전보다 더 예측 가능하도록 보장해서, 네트워크가 빠르게 수렴할 수 있도록 한다.  즉 훈련 과정이 이전보다 안정화될 수 있고 또 더 빠르게 학습될 수 있는 것이다! 


Pytorch의 BatchNorm2 d

pytorch에서도 Batch normalization을 사용할 수 있는데, 위 논문을 바탕으로 함수가 구현되었다. 2, 3차원 입력에는 BatchNorm1 d를 4차원 입력에는 BatchNrom2 d를 사용할 수 있다. 

 

torch.nn.BatchNorm1d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, device=None, dtype=None)
torch.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, device=None, dtype=None)

 

nn.BatchNorm2d : running_mean / running_var

vector = torch.cat([
    torch.mean(self.conv3.bn.running_mean).view(1), torch.std(self.conv3.bn.running_mean).view(1),
    torch.mean(self.conv3.bn.running_var).view(1), torch.std(self.conv3.bn.running_var).view(1),
    torch.mean(self.conv5.bn.running_mean).view(1), torch.std(self.conv5.bn.running_mean).view(1),
    torch.mean(self.conv5.bn.running_var).view(1), torch.std(self.conv5.bn.running_var).view(1)
])

 

그럼 대체 running_mean과 running_var는 뭘 의미하는 걸까? 

 

Batch normalization 은 각 배치마다 배치 데이터를 정규화할 때 평균값과 분산 값을 이용한다. 배치가 끝날 때마다 측정되는 통계 값 ( 평균, 분산 )을 메모리에 저장해서 정규화를 위해 어떤 통계 값을 계산할 때 이 값을 꺼내 써야 하는데 그게 바로 running_mean과 running_var ( variance ) 값이다. 그리고 이 값들이 다음 통계를 업데이트할 때 얼마나 영향을 받을 수 있게 할지 조절하는 값이 바로 momentum 값이다. 


#참조

Pytorch official document - BATCHNORM2D

A Gentle Introduction to Batch Normalization for Deep Neural Networks
What do BatchNorm2d's running_mean / running_var mean in PyTorch?

[Deep Learning] Batch Normalization (배치 정규화)