PyTorch

파이토치 모델 파라메터 저장/불러오기

쉽게가자 2020. 5. 23. 16:52

파이토치로 신경망 모델 파라메터를 학습시키는데는 시간이 많이 걸린다. 그래서 한번 학습해서 얻어진 파라메터를 소중히 보관해야하는 방법을 알아봤다.

 

1. 저장하기

torch.save(model.state_dict(), 'PATH.pth')
  • torch 모듈의 save() 함수와, torch.nn.Module의 state_dict()함수를 사용한다.
  • state_dict()는 모델의 파라메터를 dictionary 타입으로 반환해 주고,
  • save() 함수는 인수로 들어온 오브젝트를 파일로 저장해준다.

2. 불러오기

model.load_state_dict(torch.load('PATH.pth'))
model.eval()
  • torch 모듈의 load() 함수와, torch.nn.Module의 load_state_dict()함수를 사용한다.
  • 먼저, load() 함수가 파일을 읽어서 state_dict 오브젝트를 만들어주고,
  • load_state_dict() 함수가 state_dict 오브젝트의 값을 읽고, model의 파라메터 값을 채워준다.
  • 이때 eval() 함수를 호출하지 않으면, dropout과 batch normalization레이어가 제대로 세팅되지 않아 inference시 비정상적인 결과가 나온다고 한다.

참고: https://pytorch.org/tutorials/beginner/saving_loading_models.html

 

Saving and Loading Models — PyTorch Tutorials 1.5.0 documentation

Note Click here to download the full example code Saving and Loading Models Author: Matthew Inkawhich This document provides solutions to a variety of use cases regarding the saving and loading of PyTorch models. Feel free to read the whole document, or ju

pytorch.org