모델로드
-
파이토치 모델 파라메터 저장/불러오기PyTorch 2020. 5. 23. 16:52
파이토치로 신경망 모델 파라메터를 학습시키는데는 시간이 많이 걸린다. 그래서 한번 학습해서 얻어진 파라메터를 소중히 보관해야하는 방법을 알아봤다. 1. 저장하기 torch.save(model.state_dict(), 'PATH.pth') torch 모듈의 save() 함수와, torch.nn.Module의 state_dict()함수를 사용한다. state_dict()는 모델의 파라메터를 dictionary 타입으로 반환해 주고, save() 함수는 인수로 들어온 오브젝트를 파일로 저장해준다. 2. 불러오기 model.load_state_dict(torch.load('PATH.pth')) model.eval() torch 모듈의 load() 함수와, torch.nn.Module의 load_state_dic..