PyTorch

torch.optim.lr_scheduler를 이용하여 learning rate 조절하기

쉽게가자 2020. 5. 24. 17:03

요즘 knowledge distillation 논문[Hinton14]을 읽고 있다. 여기에 나온 대로 3층 퍼셉트론을 구현해서 MNIST 데이터를 학습시켜봤는데, 적혀있는 것 보다 낮은 성능이 나왔다. 네트워크 구조는 제대로 구현한 것 같은데, optimizer 옵션이 다른것이 원인인듯 했다. 그래서 learning rate, batch size 등을 논문에 나온 내용과 똑같이 수정해보려고 한다.

 

사실 자세한 optimizer 옵션은 이전 논문[Hinton12]에 나와있는데, 다음과 같이 적혀있다.

 

위 수식의 핵심은, learning rate과 momentum을 epoch수에 따라 변화시키는 것이다.

  • learning rate : 10에서 시작해서, 각 epoch마다 0.998을 곱하여 점점 줄인다
  • momentum  : 0.5에서 시작하여, epoch 500 일때 0.99가 되도록 증가시키고, 그후에는 0.99를 그대로 사용한다
  • clipping gradient norm: 각 레이어의 가중치 벡터가 15보다 큰 경우에는 길이를 줄여서 15가 되도록 한다

우선은 learning rate 조절 방법을 알아보았다.

 

torch.optim.lr_scheduler 모듈에 들어있는 스케줄러 클래스들 중 원하는걸 골라서 쓰면 된다.

처음에 나온 4가지 클래스만 요약하면 다음과 같다.

  • LambdaLR: lr_lambda인자로 넣어준 함수로 계산된 값을 초기 lr에 곱해서 사용한다
  • MultiplicativeLR: lr_lambda인자로 넣어준 함수로 계산된 값을 매 epoch마다 이전 lr에 곱해서 사용한다
  • StepLR: step_size에 지정된 epoch 수 마다 이전 lr에 gamma만큼 곱해서 사용한다
  • ExponentialLR: 매 epoch마다 이전 lr에 gamma만큼 곱해서 사용한다

참고: https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate

 

torch.optim — PyTorch 1.5.0 documentation

torch.optim torch.optim is a package implementing various optimization algorithms. Most commonly used methods are already supported, and the interface is general enough, so that more sophisticated ones can be also easily integrated in the future. How to us

pytorch.org

이번 목표에 맞는것은 ExponentialLR 인데, 다음과 같이 사용하면 된다.

torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.998)

다음은 momentum을 조절하는 방법인데, 미리 정의된 함수나 클래스가 없어서 직접 구현해야한다. 그런데 마침 github에서 이 기능이 잘 구현된 repository를 찾았다.

 

torch.optim.lr_scheduler._LRScheduler 클래스를 상속받아 새로운 클래스를 만들고, step() 메소드를 오버라이딩해서 optimizer.param_groups의 각 그룹의 'lr', 'momentum'에 매핑된 값을 갱신하도록 만들면 된다. (코드는 이곳을 참고)


마지막으로 gradient norm을 clipping 하는 것은 torch.utils.clip_gradient_norm_() 함수를 이용해서 구현할 수 있다.

참고: pytorch.org/docs/stable/_modules/torch/nn/utils/clip_grad.html

 

torch.nn.utils.clip_grad — PyTorch 1.5.0 documentation

Shortcuts

pytorch.org


최종 코드는 다음과 같다.

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import scheduler

from models import Teacher

MNIST_DIR = '../mnist/'

# Create gpu device
device = torch.device('cuda')
print(device)

# Create model
teacher_model = Teacher()

# Transfer
teacher_model.to(device)

# Define Loss
criterion = nn.CrossEntropyLoss()

# Define optimizer
optimizer = optim.SGD(teacher_model.parameters(), lr=10.0)

# Define schedule for learning rate and momentum
lr_init = 10.0
gamma = 0.998
lrs = np.zeros(shape=(3000,))
lr = lr_init
for step in range(3000):
    lrs[step] = lr
    lr *= gamma
momentums = np.concatenate([np.linspace(0.5, 0.99, 500), np.full(shape=(2500,), fill_value=0.99)])
list_lr_momentum_scheduler = scheduler.ListScheduler(optimizer, lrs=lrs, momentums=momentums)

## Load dataset
train_data = torchvision.datasets.MNIST(MNIST_DIR, train=True, download=True,
                                        transform=torchvision.transforms.Compose([
                                            torchvision.transforms.ToTensor(), # image to Tensor
                                            torchvision.transforms.Normalize((0.1307,), (0.3081,)) # image, label
                                            ]))
train_loader = torch.utils.data.DataLoader(train_data, batch_size=100, shuffle=True)

teacher_model.train()
for epoch_count in range(3000):
    print('epoch: {}'.format(epoch_count))

    ## Optimize parameters
    total_loss = 0.0
    for step_count, (x, y_gt) in enumerate(train_loader):
        # Initialize gradients with 0
        optimizer.zero_grad()

        # Transfer device
        x = x.to(device)
        y_gt = y_gt.to(device)

        # Predict
        x = torch.flatten(x, start_dim=1, end_dim=-1)
        y_pred = teacher_model(x)

        # Compute loss (foward propagation)
        loss = criterion(y_pred, y_gt)
        
        # Compute gradients (backward propagation)
        loss.backward()
        
        # Clip gradient
        torch.nn.utils.clip_grad_norm_(teacher_model.parameters(), 15.0)

        # Update parameters (SGD)
        optimizer.step()

        total_loss += loss.item()
        # if step_count % 1000 == 0:
        #     print('progress: {}\t/ {}\tloss: {}'.format(step_count, len(train_loader), loss.item()))
    
    print('loss: {}'.format(total_loss / len(train_loader)))
    list_lr_momentum_scheduler.step()

## Save model
torch.save(teacher_model.state_dict(), './data/teacher.pth')