ABOUT ME

-

Today
-
Yesterday
-
Total
-
  • torch.optim.lr_scheduler를 이용하여 learning rate 조절하기
    PyTorch 2020. 5. 24. 17:03

    요즘 knowledge distillation 논문[Hinton14]을 읽고 있다. 여기에 나온 대로 3층 퍼셉트론을 구현해서 MNIST 데이터를 학습시켜봤는데, 적혀있는 것 보다 낮은 성능이 나왔다. 네트워크 구조는 제대로 구현한 것 같은데, optimizer 옵션이 다른것이 원인인듯 했다. 그래서 learning rate, batch size 등을 논문에 나온 내용과 똑같이 수정해보려고 한다.

     

    사실 자세한 optimizer 옵션은 이전 논문[Hinton12]에 나와있는데, 다음과 같이 적혀있다.

     

    위 수식의 핵심은, learning rate과 momentum을 epoch수에 따라 변화시키는 것이다.

    • learning rate : 10에서 시작해서, 각 epoch마다 0.998을 곱하여 점점 줄인다
    • momentum  : 0.5에서 시작하여, epoch 500 일때 0.99가 되도록 증가시키고, 그후에는 0.99를 그대로 사용한다
    • clipping gradient norm: 각 레이어의 가중치 벡터가 15보다 큰 경우에는 길이를 줄여서 15가 되도록 한다

    우선은 learning rate 조절 방법을 알아보았다.

     

    torch.optim.lr_scheduler 모듈에 들어있는 스케줄러 클래스들 중 원하는걸 골라서 쓰면 된다.

    처음에 나온 4가지 클래스만 요약하면 다음과 같다.

    • LambdaLR: lr_lambda인자로 넣어준 함수로 계산된 값을 초기 lr에 곱해서 사용한다
    • MultiplicativeLR: lr_lambda인자로 넣어준 함수로 계산된 값을 매 epoch마다 이전 lr에 곱해서 사용한다
    • StepLR: step_size에 지정된 epoch 수 마다 이전 lr에 gamma만큼 곱해서 사용한다
    • ExponentialLR: 매 epoch마다 이전 lr에 gamma만큼 곱해서 사용한다

    참고: https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate

     

    torch.optim — PyTorch 1.5.0 documentation

    torch.optim torch.optim is a package implementing various optimization algorithms. Most commonly used methods are already supported, and the interface is general enough, so that more sophisticated ones can be also easily integrated in the future. How to us

    pytorch.org

    이번 목표에 맞는것은 ExponentialLR 인데, 다음과 같이 사용하면 된다.

    torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.998)

    다음은 momentum을 조절하는 방법인데, 미리 정의된 함수나 클래스가 없어서 직접 구현해야한다. 그런데 마침 github에서 이 기능이 잘 구현된 repository를 찾았다.

     

    torch.optim.lr_scheduler._LRScheduler 클래스를 상속받아 새로운 클래스를 만들고, step() 메소드를 오버라이딩해서 optimizer.param_groups의 각 그룹의 'lr', 'momentum'에 매핑된 값을 갱신하도록 만들면 된다. (코드는 이곳을 참고)


    마지막으로 gradient norm을 clipping 하는 것은 torch.utils.clip_gradient_norm_() 함수를 이용해서 구현할 수 있다.

    참고: pytorch.org/docs/stable/_modules/torch/nn/utils/clip_grad.html

     

    torch.nn.utils.clip_grad — PyTorch 1.5.0 documentation

    Shortcuts

    pytorch.org


    최종 코드는 다음과 같다.

    import numpy as np
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import torch.optim as optim
    import torchvision
    import scheduler
    
    from models import Teacher
    
    MNIST_DIR = '../mnist/'
    
    # Create gpu device
    device = torch.device('cuda')
    print(device)
    
    # Create model
    teacher_model = Teacher()
    
    # Transfer
    teacher_model.to(device)
    
    # Define Loss
    criterion = nn.CrossEntropyLoss()
    
    # Define optimizer
    optimizer = optim.SGD(teacher_model.parameters(), lr=10.0)
    
    # Define schedule for learning rate and momentum
    lr_init = 10.0
    gamma = 0.998
    lrs = np.zeros(shape=(3000,))
    lr = lr_init
    for step in range(3000):
        lrs[step] = lr
        lr *= gamma
    momentums = np.concatenate([np.linspace(0.5, 0.99, 500), np.full(shape=(2500,), fill_value=0.99)])
    list_lr_momentum_scheduler = scheduler.ListScheduler(optimizer, lrs=lrs, momentums=momentums)
    
    ## Load dataset
    train_data = torchvision.datasets.MNIST(MNIST_DIR, train=True, download=True,
                                            transform=torchvision.transforms.Compose([
                                                torchvision.transforms.ToTensor(), # image to Tensor
                                                torchvision.transforms.Normalize((0.1307,), (0.3081,)) # image, label
                                                ]))
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=100, shuffle=True)
    
    teacher_model.train()
    for epoch_count in range(3000):
        print('epoch: {}'.format(epoch_count))
    
        ## Optimize parameters
        total_loss = 0.0
        for step_count, (x, y_gt) in enumerate(train_loader):
            # Initialize gradients with 0
            optimizer.zero_grad()
    
            # Transfer device
            x = x.to(device)
            y_gt = y_gt.to(device)
    
            # Predict
            x = torch.flatten(x, start_dim=1, end_dim=-1)
            y_pred = teacher_model(x)
    
            # Compute loss (foward propagation)
            loss = criterion(y_pred, y_gt)
            
            # Compute gradients (backward propagation)
            loss.backward()
            
            # Clip gradient
            torch.nn.utils.clip_grad_norm_(teacher_model.parameters(), 15.0)
    
            # Update parameters (SGD)
            optimizer.step()
    
            total_loss += loss.item()
            # if step_count % 1000 == 0:
            #     print('progress: {}\t/ {}\tloss: {}'.format(step_count, len(train_loader), loss.item()))
        
        print('loss: {}'.format(total_loss / len(train_loader)))
        list_lr_momentum_scheduler.step()
    
    ## Save model
    torch.save(teacher_model.state_dict(), './data/teacher.pth')
    
Designed by Tistory.