[PyTorch] 모델 로드 시 이슈 및 해결방법 정리
파이토치로 모델을 로드할 때, 특히 오픈소스를 활용할 때 논문의 결과가 정말로 그렇게 나오는지 테스트해보고 싶은데... 모델이 제대로 로드되지 않은 경우가 있었다. 모델 save load 가 제대로 되지 않으니 많은 시간을 허비했다. 그 동안 삽질하며 찾았던 여러가지 모델 로드 방법을 정리한다.
일반적인 모델 로드 방법
먼저 가장 일반적이고 파이토치 공식 문서가 추천하는 방식이다. state_dict() 를 이용해 wegiht 값을 dict 형식으로 save, load 한다.
torch.save(model.state_dict(), PATH)
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()
모델 로드 시 발생 가능한 이슈들
하지만 위 방법으로 모델이 잘 로드되지 않는 경우가 있다. 이전에 겪었던 이슈들과 해결방법은 다음과 같다.
Issue 1. Missing keys & unexpected keys in state_dict
RuntimeError: Error(s) in loading state_dict for VGG: Missing key(s) in state_dict: "features.0.weight", "features.0.bias", "features.2.weight", "features.2.bias", "features.5.weight", "features.5.bias", "features.7.weight", "features.7.bias", "features.10.weight", "features.10.bias", "features.12.weight", "features.12.bias", "features.14.weight", "features.14.bias", "features.17.weight", "features.17.bias", "features.19.weight", "features.19.bias", "features.21.weight", "features.21.bias", "features.24.weight", "features.24.bias", "features.26.weight", "features.26.bias", "features.28.weight", "features.28.bias". Unexpected key(s) in state_dict: "features.module.0.weight", "features.module.0.bias", "features.module.2.weight", "features.module.2.bias", "features.module.5.weight", "features.module.5.bias", "features.module.7.weight", "features.module.7.bias", "features.module.10.weight", "features.module.10.bias", "features.module.12.weight", "features.module.12.bias", "features.module.14.weight", "features.module.14.bias", "features.module.17.weight", "features.module.17.bias", "features.module.19.weight", "features.module.19.bias", "features.module.21.weight", "features.module.21.bias", "features.module.24.weight", "features.module.24.bias", "features.module.26.weight", "features.module.26.bias", "features.module.28.weight", "features.module.28.bias".
모델이 학습되어 state 가 저장된 개발 환경과, 모델 state 를 로드하는 환경이 다른 경우 발생할 수 있다. 저장된 model 의 key 값과 로드 시 key 값이 pytorch 버전 등이 맞지 않는 경우 다른 형태로 저장될 수 있다고 한다.
방법1. strict = False 사용
model.load_state_dict(torch.load(PATH), strict=False)
이 옵션을 사용하면 key 값이 서로 다르더라도 일치하는 key 값에 대한 weight 만 불러온다. 그렇기 때문에 위 문제가 해결된다. 다만, 이 방법을 사용하면 **누락되는 weight 값**이 생길 수 있다. 그래서 학습한 weight 값이 로드되어 같은 모델 아키텍쳐에 적용되고 있는데도, > 모델의 key 값이 다르다면 inference 값은 전혀 다른 값이 나타나는 대참사가 일어날 수 있다. 👉 Issue 2
방법2. key 를 직접 수정
이 이슈를 해결하는 두 번째 방법은 말 그대로 수작업(?)이다. error 코드를 보면서 model 의 state_dict 값을 돌면서 매칭되지 않는 key 값을 모조리 수정하는 것이다.
checkpoint = torch.load(weights_path, map_location=self.device)['model_state_dict']
for key in list(checkpoint.keys()):
if 'model.' in key:
checkpoint[key.replace('model.', '')] = checkpoint[key]
del checkpoint[key] self.model.load_state_dict(checkpoint)
이렇게 checkpoint 에 직접 필요한 key 값을 모델에 맞게 수정해주면 된다.
Issue 2. 모델을 로드했지만 predict 성능이 train, validate 성능과 다름
가끔 모델을 서빙한다거나 pretrain 된 모델을 가져오려고 할 때 겪었던 이슈이다.
방법 1. seed 를 설정
seed = 20220625
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
레이어 내 가중치를 초기화할 때 난수의 영향을 받는 것으로 알고 있다. 이건 사실 tensorflow 를 사용할 때 겪었던 이슈인데 검색해보니 pytorch 에도 비슷한 이슈가 있는 것 같다. 모델 학습 전과 로드 시에 고정된 seed 값을 사용해서 같은 seed 를 기반으로 model 를 로드하는 것으로 위 문제가 해결된다. 다만, 특정 시드와 가장 잘 작동되는 매개변수를 사용하는 방법이 학습될 가능성이 있는 듯 하다.
방법 2. model.eval() 모드를 적용했는지 확인
학습 이후에 필요 없어진 dropout, batchnorm 과 같은 기능을 비활성화해 추론할 수 있도록 모델 모드를 변경해준다. 역시 테스트 코드도 함께 작성하는게 좋을 것 같다.
방법 3. strict = True 옵션 적용
말그대로 strict 하게 key 가 완전히 같아야만 weight 를 똑같이 가져올 수 있다. 근데 ... 그랬으면 여기까지 안왔겠지 ㅠㅠ 이 옵션을 쓰면 마주칠 수 있는 이슈가 바로 Issue 3 이다.
방법 4. 모델 전체를 저장
python torch.save(model, PATH)
model = torch.load(PATH)
model.eval()
사실 여기까지 했는데 다 안돼서 그냥 모델 전체를 저장해버리기로 했다. 이 방법은 공간도 많이 차지하고 모델 아키텍쳐를 조금 수정했을 때 flexible 하게 원하는 key 의 weight 값만 가져올 수가 없다. 그래서 권장되는 방법이 아니다. 하지만 시간은 없는데 결과는 있어야하니.. 어쩔 수 없었다. 여기까지 시도했는데 또 이슈가 발생한다.
Issues 3. Can't pickle local object \'AtomicModel.get_metrics.\<locals\>.\<lambda\>
AttributeError: Can't pickle local object 'AtomicModel.get_metrics..' AttributeError: Can't pickle local object 'AtomicModel._get_metrics.._accuracy_score'
방법. dill package
모델에 python lambda 식을 사용해 layer 가 정의된 경우, 일반적인 방법으로는 모델 전체를 저장할 수 없다.
import dill model_copy=dill.dumps(model)
torch.save(model_copy, ‘my_model.pt’)
model = torch.load(model_name)
model=dill.loads(model)
dill 이라는 패키지를 사용하면 이 문제를 해결할 수 있다. serializable ( 직렬화 : 객체 저장시 데이터를 줄세워 저장 ) 하지않은 변수들을 저장할 때 사용하는 패키지이다. ( pickle 과 반대 )
#참조
The result is different when I apply torch.manual_seed before loading cuda() after loading the model
torch.no_grad()와 model.eval()의 차이
Missing keys & unexpected keys in state_dict when loading self trained model