PyTorch로 구현된 논문 코드 리뷰 중, 모델 레이어의 weight 들의 mean, std 계산을 다음과 같이 구현한 코드를 보았다.
vector = torch.cat([
torch.mean(self.conv3.bn.running_mean).view(1), torch.std(self.conv3.bn.running_mean).view(1),
torch.mean(self.conv3.bn.running_var).view(1), torch.std(self.conv3.bn.running_var).view(1),
torch.mean(self.conv5.bn.running_mean).view(1), torch.std(self.conv5.bn.running_mean).view(1),
torch.mean(self.conv5.bn.running_var).view(1), torch.std(self.conv5.bn.running_var).view(1)
])
여기서 bn 은 다른 클래스에서 선언한 nn.BatchNorm2d인데, nn.BatchNorm2d 에 running_mean과 runnging_var라는 것이 있다는 걸 처음 알았다. 그래서 Batch normalization과 Pytorch의 nn.BatchNorm에 대해 좀 더 정리해보려고 한다.
Batch Normalization
Batch Normalization 은 Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift, ( Seguey Ioffe and Christian Szegedy, ICML'2015 )에서 제안되었다. 해당 논문에서 BatchNormalization 은 internal covariate shift를 줄임으로써, 딥한 네트워크를 학습하는 속도를 높이는 방법으로 제시되었다. Batch Normalization 은 학습을 더 빠르게 하고 오버 피팅 방지의 목적으로 사용한다. Batch Normalization 은 무엇을 목적으로 만들어진 걸까?
covariate shift
깊이가 깊은 네트워크를 학습할 때 역전파가 일어나면서 출력층부터 입력층까지, 레이어의 weight 값들이 바뀌게 된다. 이전 레이어의 파라미터 값이 변경되면 그다음 레이어도 영향을 받으면서, 레이어를 통과할 때마다 레이어 입력값의 distribution 이 계속해서 변하게 되는 것이다. 이런 변화를 covariate shift라고 말한다.
internal covariate shift
모델 학습 시 인풋 데이터가 각 레이어를 통과할 때마다, 혹은 활성화 함수를 통과할 때마다 레이어 입력 값의 distribution 이 바뀌기 때문에 ( covariate shift ) 모델이 안정적으로 학습되지 않을 수 있다. 이걸 internal covariate shift라고 한다.
그러니까, Batch normalization 은 internal covariate shift 문제를 해결하기 위해 내부 노드들의 distribution 이 크게 변하지 않도록 각 배치마다 레이어 입력값들을 표준화한다. 그래서 각 레이어마다 정규화 레이어를 통해 배치 데이터들의 distribution 이 큰 폭으로 바뀌지 않도록 조절한다. 이 과정은 internal covariate shift를 줄여서 그래디언트가 이전보다 더 예측 가능하도록 보장해서, 네트워크가 빠르게 수렴할 수 있도록 한다. 즉 훈련 과정이 이전보다 안정화될 수 있고 또 더 빠르게 학습될 수 있는 것이다!
Pytorch의 BatchNorm2 d
pytorch에서도 Batch normalization을 사용할 수 있는데, 위 논문을 바탕으로 함수가 구현되었다. 2, 3차원 입력에는 BatchNorm1 d를 4차원 입력에는 BatchNrom2 d를 사용할 수 있다.
torch.nn.BatchNorm1d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, device=None, dtype=None)
torch.nn.BatchNorm2d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True, device=None, dtype=None)
nn.BatchNorm2d : running_mean / running_var
vector = torch.cat([
torch.mean(self.conv3.bn.running_mean).view(1), torch.std(self.conv3.bn.running_mean).view(1),
torch.mean(self.conv3.bn.running_var).view(1), torch.std(self.conv3.bn.running_var).view(1),
torch.mean(self.conv5.bn.running_mean).view(1), torch.std(self.conv5.bn.running_mean).view(1),
torch.mean(self.conv5.bn.running_var).view(1), torch.std(self.conv5.bn.running_var).view(1)
])
그럼 대체 running_mean과 running_var는 뭘 의미하는 걸까?
Batch normalization 은 각 배치마다 배치 데이터를 정규화할 때 평균값과 분산 값을 이용한다. 배치가 끝날 때마다 측정되는 통계 값 ( 평균, 분산 )을 메모리에 저장해서 정규화를 위해 어떤 통계 값을 계산할 때 이 값을 꺼내 써야 하는데 그게 바로 running_mean과 running_var ( variance ) 값이다. 그리고 이 값들이 다음 통계를 업데이트할 때 얼마나 영향을 받을 수 있게 할지 조절하는 값이 바로 momentum 값이다.
#참조
Pytorch official document - BATCHNORM2D
A Gentle Introduction to Batch Normalization for Deep Neural Networks
What do BatchNorm2d's running_mean / running_var mean in PyTorch?
[Deep Learning] Batch Normalization (배치 정규화)
'인공지능 > PyTorch' 카테고리의 다른 글
[PyTorch] 모델 로드 시 이슈 및 해결방법 정리 (0) | 2022.06.25 |
---|---|
[PyTorch] no_grad() 와 eval() 의 차이 (0) | 2022.06.24 |
[PyTorch] optimizer.zero_grad() 를 하는 이유 (0) | 2022.01.28 |