AI/pytorch
copy.deepcopy(model.state_dict())
교 향
2022. 9. 28. 20:24
모델 저장하기 & 불러오기
Author: Matthew Inkawhich, 번역: 박정환,. 이 문서에서는 PyTorch 모델을 저장하고 불러오는 다양한 방법을 제공합니다. 이 문서 전체를 다 읽는 것도 좋은 방법이지만, 필요한 사용 예의 코드만 참고하
tutorials.pytorch.kr
참고 파이토치 공식문서
early stopping에서 저장한 가중치를 복사해올때
best_model_wts = copy.deepcopy(model.state_dict())
copy.deepcopy(model.state_dict())
NOTE
만약 (검증 손실(validation loss) 결과에 따라) 가장 성능이 좋은 모델만 유지할 계획이라면, best_model_state = model.state_dict() 은 모델의 복사본이 아닌 모델의 현재 상태에 대한 참조(reference)만 반환한다는 사실을 잊으시면 안됩니다! 따라서 best_model_state 을 직렬화(serialize)하거나, best_model_state = deepcopy(model.state_dict()) 을 사용해야 합니다. 그렇지 않으면, 제일 좋은 성능을 내는 best_model_state 은 계속되는 학습 단계에서 갱신될 것입니다. 결과적으로, 최종 모델의 상태는 과적합(overfit)된 상태가 됩니다.