PyTorch

torchvision으로 MNIST 데이터 로드하기

쉽게가자 2020. 5. 23. 01:45

MNIST 데이터를 쉽게 로드하기 위해서는 torchvision 모듈을 사용해야한다.

torchvision에 관한 문서를 한번 읽어봤다.

 

*참고: https://pytorch.org/docs/stable/torchvision/index.html

 

torchvision — PyTorch 1.5.0 documentation

Shortcuts

pytorch.org


TORCHVISION

The torchvision package consists of popular datasets, model architectures, and common image transformations for computer vision.


그런데 특별한 설명은 없고, 컴퓨터 비전을 위한 이미지 변환, 네트워크 아키텍쳐, 데이터셋 등을 모아놓았다고 한다...

 

무튼 MNIST 데이터를 로드하려면 torchvision.datasets 라는 패키지를 이용하면 되는데, 여기에는 MNIST 외에도 CIFAR, COCO, VOC등 유명한 데이터셋이 클래스로 구현되어있다.

train_data = torchvision.datasets.MNIST('./data', train=True, download=True)

이런식으로 MNIST 데이터셋 오브젝트를 만들 수 있고, torchvision.datasets.MNIST 클래스의 자세한 내용은 이곳을 참조하면 된다: https://pytorch.org/docs/stable/torchvision/datasets.html#mnist

친절하게도 train을 True로 하면 training set이, False로 하면 test set이 로드된다.

게다가 download를 True로 했을때 데이터셋을 인터넷에서 다운로드 해준다!

 

그런데 주의점은, 이렇게 로드를 하면 PIL.Image.Image 타입의 오브젝트가 데이터로 들어간다는 것이다. 이것을 torch.Tensor 타입으로 바꾸려면 아래 코드처럼, transform 인자로, torchvision.transforms.toTensor()를 넣어주면 된다.

train_data = torchvision.datasets.MNIST('./data', train=True, download=True, transform=torchvision.transforms.toTensor())