인공지능/PyTorch

[PyTorch] no_grad() 와 eval() 의 차이

고등어찌짐 2022. 6. 24. 23:43

pytorch 에서 사용하는 no_grad() 와 eval() 의 차이점을 정리해보자


torch.no_grad( )

gradient 계산을 꺼주는 context-manager 이다. 
python 에서 Context Manager 는 with 문으로 호출할 수 있는 객체로, 리소르를 할당하고 제공하는 역할을 한다. 

해당 모드를 사용하면 requires_grad=True 일 때보다 메모리 소비를 줄일 수 있기 때문에, Tensor.backward() 계산을 하지 않아도 되는 경우라면 해당 모드를 통해 gradient 계산을 하지 않는 것이 inference 를 할 때 효율적이다.

 

즉, gradient 계산이 필요하지 않은 경우 해당 모드를 적용해 메모리를 더 적게 사용하고, 더 빠르게 계산할 수 있도록 한다. 

 

또 이 모드를 켜게 되면 input 이 requires_grad=True 더라도 모든 계산에 대해 no_grad() 가 적용되고, 해당 context manager 는 thread local 하기 때문에 다른 thread 에 영향을 주지 않는다.  

 

데코레이터 함수로도 활용 가능하다.

x = torch.tensor([1], requires_grad=True)
with torch.no_grad():
  y = x * 2
y.requires_grad
@torch.no_grad()
def doubler(x):
    return x * 2
z = doubler(x)
z.requires_grad

eval( )

모델 평가시, 모델 파라미터를 로드하고 eval() 모드를 적용해주어야 한다. 

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

모델 평가 모드에서 꺼야하는 layer 들을 자동으로 꺼주는 역할을 한다. 

평가 시에는 학습 때와는 다르게 BatchNorm, DropOut 과 같은 layer 는 꺼주어야 하는데 이걸 자동으로 해준다. 

해당 모드를 사용하지 않으면, 일관성 없는 inference 결과를 얻게 된다. 

self.train(False) 를 해도 같은 효과를 얻을 수 있다. 


#참조

https://pytorch.org/docs/stable/generated/torch.no_grad.html
https://bloowhale.tistory.com/m/121

https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=eval#torch.nn.Module.eval