참고 파이토치 공식문서
early stopping에서 저장한 가중치를 복사해올때
best_model_wts = copy.deepcopy(model.state_dict())
copy.deepcopy(model.state_dict())
NOTE
만약 (검증 손실(validation loss) 결과에 따라) 가장 성능이 좋은 모델만 유지할 계획이라면, best_model_state = model.state_dict() 은 모델의 복사본이 아닌 모델의 현재 상태에 대한 참조(reference)만 반환한다는 사실을 잊으시면 안됩니다! 따라서 best_model_state 을 직렬화(serialize)하거나, best_model_state = deepcopy(model.state_dict()) 을 사용해야 합니다. 그렇지 않으면, 제일 좋은 성능을 내는 best_model_state 은 계속되는 학습 단계에서 갱신될 것입니다. 결과적으로, 최종 모델의 상태는 과적합(overfit)된 상태가 됩니다.
'AI > pytorch' 카테고리의 다른 글
파이토치 환경에서 randomness를 고정하기 위한 방법 (0) | 2022.09.28 |
---|