AI/pytorch

copy.deepcopy(model.state_dict())

교 향 2022. 9. 28. 20:24
copy.deepcopy(model.state_dict())

NOTE

만약 (검증 손실(validation loss) 결과에 따라) 가장 성능이 좋은 모델만 유지할 계획이라면, best_model_state = model.state_dict() 은 모델의 복사본이 아닌 모델의 현재 상태에 대한 참조(reference)만 반환한다는 사실을 잊으시면 안됩니다! 따라서 best_model_state 을 직렬화(serialize)하거나, best_model_state = deepcopy(model.state_dict()) 을 사용해야 합니다. 그렇지 않으면, 제일 좋은 성능을 내는 best_model_state 은 계속되는 학습 단계에서 갱신될 것입니다. 결과적으로, 최종 모델의 상태는 과적합(overfit)된 상태가 됩니다.