torch.optim.lr_scheduler를 이용하여 learning rate 조절하기
요즘 knowledge distillation 논문[Hinton14]을 읽고 있다. 여기에 나온 대로 3층 퍼셉트론을 구현해서 MNIST 데이터를 학습시켜봤는데, 적혀있는 것 보다 낮은 성능이 나왔다. 네트워크 구조는 제대로 구현한 것 같은데, optimizer 옵션이 다른것이 원인인듯 했다. 그래서 learning rate, batch size 등을 논문에 나온 내용과 똑같이 수정해보려고 한다.
사실 자세한 optimizer 옵션은 이전 논문[Hinton12]에 나와있는데, 다음과 같이 적혀있다.
위 수식의 핵심은, learning rate과 momentum을 epoch수에 따라 변화시키는 것이다.
- learning rate : 10에서 시작해서, 각 epoch마다 0.998을 곱하여 점점 줄인다
- momentum : 0.5에서 시작하여, epoch 500 일때 0.99가 되도록 증가시키고, 그후에는 0.99를 그대로 사용한다
- clipping gradient norm: 각 레이어의 가중치 벡터가 15보다 큰 경우에는 길이를 줄여서 15가 되도록 한다
우선은 learning rate 조절 방법을 알아보았다.
torch.optim.lr_scheduler 모듈에 들어있는 스케줄러 클래스들 중 원하는걸 골라서 쓰면 된다.
처음에 나온 4가지 클래스만 요약하면 다음과 같다.
- LambdaLR: lr_lambda인자로 넣어준 함수로 계산된 값을 초기 lr에 곱해서 사용한다
- MultiplicativeLR: lr_lambda인자로 넣어준 함수로 계산된 값을 매 epoch마다 이전 lr에 곱해서 사용한다
- StepLR: step_size에 지정된 epoch 수 마다 이전 lr에 gamma만큼 곱해서 사용한다
- ExponentialLR: 매 epoch마다 이전 lr에 gamma만큼 곱해서 사용한다
참고: https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
이번 목표에 맞는것은 ExponentialLR 인데, 다음과 같이 사용하면 된다.
torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.998)
다음은 momentum을 조절하는 방법인데, 미리 정의된 함수나 클래스가 없어서 직접 구현해야한다. 그런데 마침 github에서 이 기능이 잘 구현된 repository를 찾았다.
torch.optim.lr_scheduler._LRScheduler 클래스를 상속받아 새로운 클래스를 만들고, step() 메소드를 오버라이딩해서 optimizer.param_groups의 각 그룹의 'lr', 'momentum'에 매핑된 값을 갱신하도록 만들면 된다. (코드는 이곳을 참고)
마지막으로 gradient norm을 clipping 하는 것은 torch.utils.clip_gradient_norm_() 함수를 이용해서 구현할 수 있다.
참고: pytorch.org/docs/stable/_modules/torch/nn/utils/clip_grad.html
최종 코드는 다음과 같다.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import scheduler
from models import Teacher
MNIST_DIR = '../mnist/'
# Create gpu device
device = torch.device('cuda')
print(device)
# Create model
teacher_model = Teacher()
# Transfer
teacher_model.to(device)
# Define Loss
criterion = nn.CrossEntropyLoss()
# Define optimizer
optimizer = optim.SGD(teacher_model.parameters(), lr=10.0)
# Define schedule for learning rate and momentum
lr_init = 10.0
gamma = 0.998
lrs = np.zeros(shape=(3000,))
lr = lr_init
for step in range(3000):
lrs[step] = lr
lr *= gamma
momentums = np.concatenate([np.linspace(0.5, 0.99, 500), np.full(shape=(2500,), fill_value=0.99)])
list_lr_momentum_scheduler = scheduler.ListScheduler(optimizer, lrs=lrs, momentums=momentums)
## Load dataset
train_data = torchvision.datasets.MNIST(MNIST_DIR, train=True, download=True,
transform=torchvision.transforms.Compose([
torchvision.transforms.ToTensor(), # image to Tensor
torchvision.transforms.Normalize((0.1307,), (0.3081,)) # image, label
]))
train_loader = torch.utils.data.DataLoader(train_data, batch_size=100, shuffle=True)
teacher_model.train()
for epoch_count in range(3000):
print('epoch: {}'.format(epoch_count))
## Optimize parameters
total_loss = 0.0
for step_count, (x, y_gt) in enumerate(train_loader):
# Initialize gradients with 0
optimizer.zero_grad()
# Transfer device
x = x.to(device)
y_gt = y_gt.to(device)
# Predict
x = torch.flatten(x, start_dim=1, end_dim=-1)
y_pred = teacher_model(x)
# Compute loss (foward propagation)
loss = criterion(y_pred, y_gt)
# Compute gradients (backward propagation)
loss.backward()
# Clip gradient
torch.nn.utils.clip_grad_norm_(teacher_model.parameters(), 15.0)
# Update parameters (SGD)
optimizer.step()
total_loss += loss.item()
# if step_count % 1000 == 0:
# print('progress: {}\t/ {}\tloss: {}'.format(step_count, len(train_loader), loss.item()))
print('loss: {}'.format(total_loss / len(train_loader)))
list_lr_momentum_scheduler.step()
## Save model
torch.save(teacher_model.state_dict(), './data/teacher.pth')