PyTorch
GPU 사용하기
쉽게가자
2020. 5. 23. 22:55
파이토치에서 GPU를 이용한 forward/backward 연산을 하는 방법을 알아봤다.
- 사용할 GPU정보를 입력하여 torch.device() 클래스를 생성한다
device = torch.device('cuda')
- torch.nn.Module 클래스의 to() 메소드를 이용하여 신경망 파라메터 변수들을 device에 옮긴다
model.to(device)
- 학습 진행시, 넣어줄 입력 텐서도 device로 옯겨준다
x, y_pred = data[0].to(device), data[1].to(device)