인공지능/PyTorch

[PyTorch] 모델 로드 시 이슈 및 해결방법 정리

고등어찌짐 2022. 6. 25. 16:09

파이토치로 모델을 로드할 때, 특히 오픈소스를 활용할 때 논문의 결과가 정말로 그렇게 나오는지 테스트해보고 싶은데... 모델이 제대로 로드되지 않은 경우가 있었다. 모델 save load 가 제대로 되지 않으니 많은 시간을 허비했다. 그 동안 삽질하며 찾았던 여러가지 모델 로드 방법을 정리한다.


일반적인 모델 로드 방법

먼저 가장 일반적이고 파이토치 공식 문서가 추천하는 방식이다. state_dict() 를 이용해 wegiht 값을 dict 형식으로 save, load 한다.

torch.save(model.state_dict(), PATH)
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

모델 로드 시 발생 가능한 이슈들

하지만 위 방법으로 모델이 잘 로드되지 않는 경우가 있다. 이전에 겪었던 이슈들과 해결방법은 다음과 같다.

 

Issue 1. Missing keys & unexpected keys in state_dict

RuntimeError: Error(s) in loading state_dict for VGG: Missing key(s) in state_dict: "features.0.weight", "features.0.bias", "features.2.weight", "features.2.bias", "features.5.weight", "features.5.bias", "features.7.weight", "features.7.bias", "features.10.weight", "features.10.bias", "features.12.weight", "features.12.bias", "features.14.weight", "features.14.bias", "features.17.weight", "features.17.bias", "features.19.weight", "features.19.bias", "features.21.weight", "features.21.bias", "features.24.weight", "features.24.bias", "features.26.weight", "features.26.bias", "features.28.weight", "features.28.bias". Unexpected key(s) in state_dict: "features.module.0.weight", "features.module.0.bias", "features.module.2.weight", "features.module.2.bias", "features.module.5.weight", "features.module.5.bias", "features.module.7.weight", "features.module.7.bias", "features.module.10.weight", "features.module.10.bias", "features.module.12.weight", "features.module.12.bias", "features.module.14.weight", "features.module.14.bias", "features.module.17.weight", "features.module.17.bias", "features.module.19.weight", "features.module.19.bias", "features.module.21.weight", "features.module.21.bias", "features.module.24.weight", "features.module.24.bias", "features.module.26.weight", "features.module.26.bias", "features.module.28.weight", "features.module.28.bias".

모델이 학습되어 state 가 저장된 개발 환경과, 모델 state 를 로드하는 환경이 다른 경우 발생할 수 있다. 저장된 model 의 key 값과 로드 시 key 값이 pytorch 버전 등이 맞지 않는 경우 다른 형태로 저장될 수 있다고 한다.

 

방법1. strict = False 사용

model.load_state_dict(torch.load(PATH), strict=False)

이 옵션을 사용하면 key 값이 서로 다르더라도 일치하는 key 값에 대한 weight 만 불러온다. 그렇기 때문에 위 문제가 해결된다. 다만, 이 방법을 사용하면 **누락되는 weight 값**이 생길 수 있다. 그래서 학습한 weight 값이 로드되어 같은 모델 아키텍쳐에 적용되고 있는데도, > 모델의 key 값이 다르다면 inference 값은 전혀 다른 값이 나타나는 대참사가 일어날 수 있다. 👉 Issue 2

 

방법2. key 를 직접 수정

이 이슈를 해결하는 두 번째 방법은 말 그대로 수작업(?)이다. error 코드를 보면서 model 의 state_dict 값을 돌면서 매칭되지 않는 key 값을 모조리 수정하는 것이다.

checkpoint = torch.load(weights_path, map_location=self.device)['model_state_dict'] 
for key in list(checkpoint.keys()): 
	if 'model.' in key: 
    	checkpoint[key.replace('model.', '')] = checkpoint[key] 
		del checkpoint[key] self.model.load_state_dict(checkpoint)

 

이렇게 checkpoint 에 직접 필요한 key 값을 모델에 맞게 수정해주면 된다.


Issue 2. 모델을 로드했지만 predict 성능이 train, validate 성능과 다름

가끔 모델을 서빙한다거나 pretrain 된 모델을 가져오려고 할 때 겪었던 이슈이다.

 

방법 1.  seed 를 설정

seed = 20220625
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)

레이어 내 가중치를 초기화할 때 난수의 영향을 받는 것으로 알고 있다. 이건 사실 tensorflow 를 사용할 때 겪었던 이슈인데 검색해보니 pytorch 에도 비슷한 이슈가 있는 것 같다. 모델 학습 전과 로드 시에 고정된 seed 값을 사용해서 같은 seed 를 기반으로 model 를 로드하는 것으로 위 문제가 해결된다. 다만, 특정 시드와 가장 잘 작동되는 매개변수를 사용하는 방법이 학습될 가능성이 있는 듯 하다.

 

방법 2. model.eval() 모드를 적용했는지 확인

학습 이후에 필요 없어진 dropout, batchnorm 과 같은 기능을 비활성화해 추론할 수 있도록 모델 모드를 변경해준다. 역시 테스트 코드도 함께 작성하는게 좋을 것 같다.

 

방법 3.  strict = True 옵션 적용

말그대로 strict 하게 key 가 완전히 같아야만 weight 를 똑같이 가져올 수 있다. 근데 ... 그랬으면 여기까지 안왔겠지 ㅠㅠ 이 옵션을 쓰면 마주칠 수 있는 이슈가 바로 Issue 3 이다.

 

방법 4. 모델 전체를 저장

python torch.save(model, PATH) 
model = torch.load(PATH) 
model.eval()

사실 여기까지 했는데 다 안돼서 그냥 모델 전체를 저장해버리기로 했다. 이 방법은 공간도 많이 차지하고 모델 아키텍쳐를 조금 수정했을 때 flexible 하게 원하는 key 의 weight 값만 가져올 수가 없다. 그래서 권장되는 방법이 아니다. 하지만 시간은 없는데 결과는 있어야하니.. 어쩔 수 없었다. 여기까지 시도했는데 또 이슈가 발생한다.


Issues 3.  Can't pickle local object \'AtomicModel.get_metrics.\<locals\>.\<lambda\>

AttributeError: Can't pickle local object 'AtomicModel.get_metrics..' AttributeError: Can't pickle local object 'AtomicModel._get_metrics.._accuracy_score'

 

방법. dill package

 모델에 python lambda 식을 사용해 layer 가 정의된 경우, 일반적인 방법으로는 모델 전체를 저장할 수 없다.

import dill model_copy=dill.dumps(model) 
torch.save(model_copy, ‘my_model.pt’) 
model = torch.load(model_name) 
model=dill.loads(model)

dill 이라는 패키지를 사용하면 이 문제를 해결할 수 있다. serializable ( 직렬화 : 객체 저장시 데이터를 줄세워 저장 ) 하지않은 변수들을 저장할 때 사용하는 패키지이다. ( pickle 과 반대 )

 


#참조

파이토치 공식 문서 모델 저장하기 & 불러오기

The result is different when I apply torch.manual_seed before loading cuda() after loading the model

torch.no_grad()와 model.eval()의 차이

pytorch 모델 저장 및 불러오기 관련

Missing keys & unexpected keys in state_dict when loading self trained model

Pytorch: model save and load - 文章整合

“pytorch dill model save” Code Answer