-
파이토치 모델 파라메터 저장/불러오기PyTorch 2020. 5. 23. 16:52
파이토치로 신경망 모델 파라메터를 학습시키는데는 시간이 많이 걸린다. 그래서 한번 학습해서 얻어진 파라메터를 소중히 보관해야하는 방법을 알아봤다.
1. 저장하기
torch.save(model.state_dict(), 'PATH.pth')
- torch 모듈의 save() 함수와, torch.nn.Module의 state_dict()함수를 사용한다.
- state_dict()는 모델의 파라메터를 dictionary 타입으로 반환해 주고,
- save() 함수는 인수로 들어온 오브젝트를 파일로 저장해준다.
2. 불러오기
model.load_state_dict(torch.load('PATH.pth')) model.eval()
- torch 모듈의 load() 함수와, torch.nn.Module의 load_state_dict()함수를 사용한다.
- 먼저, load() 함수가 파일을 읽어서 state_dict 오브젝트를 만들어주고,
- load_state_dict() 함수가 state_dict 오브젝트의 값을 읽고, model의 파라메터 값을 채워준다.
- 이때 eval() 함수를 호출하지 않으면, dropout과 batch normalization레이어가 제대로 세팅되지 않아 inference시 비정상적인 결과가 나온다고 한다.
참고: https://pytorch.org/tutorials/beginner/saving_loading_models.html
'PyTorch' 카테고리의 다른 글
GPU 사용하기 (0) 2020.05.23 nn.Dropout 으로 dropout 레이어 넣기 (0) 2020.05.23 torchvision으로 MNIST 데이터 로드하기 (1) 2020.05.23 torch.optim (0) 2020.05.22 torch.nn.Module (0) 2020.05.22