AI/tensorflow

학습 중지된 경우 이어서 하려면?

교 향 2023. 9. 19. 17:05
EPOCHS = 100

for epoch in range(EPOCHS):
    for image_batch, _ in train_ds: # _ is label
        train_step(image_batch)
    
    if (epoch + 1) % 15 == 0:
    
        checkpoint.save(file_prefix=checkpoint_prefix)
        
        generate_and_save_images(generator, epoch + 1, seed)

checkpoint.save를 하도록 설정해 두었다면

컴퓨터가 다운되거나 학습이 중지되었어도

 

아래와 같이 체크포인트 가중치를 불러와서,

# 체크포인트 불러오기
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

# 멈춘 부분부터 다시 실행
EPOCHS = 100

for epoch in range(EPOCHS):
    for image_batch, _ in train_ds: # _ is label
        train_step(image_batch)
    
    if (epoch + 1) % 15 == 0:
        checkpoint.save(file_prefix=checkpoint_prefix)
        generate_and_save_images(generator, epoch + 1, seed)


checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))를 통해

 

후에 훈련을 재개할 때 이전에 멈춘 시점(epoch)부터 다시 훈련이 진행되게 할 수 있음.