PyTorch

torch.nn.Module.parameters() 는 정확히 어떤 값을 돌려줄까?

쉽게가자 2020. 5. 28. 10:13

신경망 파라메터를 optimizer에 전달해 줄 때, torch.nn.Module 클래스의 parameters() 메소드를 사용한다.

 

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

 

위와 같은 경우, parameters()는 정확히 어떤 값들을 반환해주는지 궁금해졌다.


공식 문서를 보면, parameters() 메소드는 모듈의 파라메터들을 iterator로 반환한다고 적혀있다.

 

torch.nn.Module.parameters() 공식 문서

 

Example 부분을 보면 파라메터 오브젝트의 타입이 torch.Tensor인걸 알 수 있다.

파라메터 오브젝트에 관한 더 자세한 설명 설명은 공식 문서의 torch.nn.Parameter 클래스 부분에 나와있는데, torch.nn.Parameter 클래스는 torch.Tensor 클래스를 상속받아 만들어졌고, torch.nn.Module 클래스의 attribute로 할당하면, 자동으로 파라메터 리스트에 추가되는 것이 기존의 torch.Tensor 클래스와의 차이점이라고 한다.

 


그런데 신경망 구조를 구현할때 보통 아래와 같이 하위 모듈을 추가하는식으로 구현을 하기 때문에, 파라메터를 직접 추가할 일이 거의 없었다.

 

class MyModule(torch.nn.Module):
    def __init__(self):
    	super(MyModule, self).__init__()
        self.A = torch.nn.Linear(100, 200)
        self.B = torch.nn.Linear(200, 10)
 	
    def forward(input):
    	return self.B(self.A(input))

 

그렇다면 하위 모듈만 추가해도 자동으로 파라메터가 설정되는 마법(?)은 어떻게 구현된 것일까?

공식 문서에는 설명이 없었기 때문에 torch.nn.Module 클래스의 코드를 직접 열어서 읽어보았다.

 

    def __init__(self):
        """
        Initializes internal Module state, shared by both nn.Module and ScriptModule.
        """
        torch._C._log_api_usage_once("python.nn_module")

        self.training = True
        self._parameters = OrderedDict()
        self._buffers = OrderedDict()
        self._backward_hooks = OrderedDict()
        self._forward_hooks = OrderedDict()
        self._forward_pre_hooks = OrderedDict()
        self._state_dict_hooks = OrderedDict()
        self._load_state_dict_pre_hooks = OrderedDict()
        self._modules = OrderedDict()

 

위 코드는 torch.nn.Module 클래스의 생성자를 정의하는데, _parameters_modules 라는 딕셔너리 타입의 멤버변수들이 있는것을 볼 수 있다. 당연하지만 _paramters는 파라메터들이, _modules에는 하위 모듈들이 들어간다. (key: 이름, value: 인스턴스)

 

조금 더 코드를 읽어본 결과, parameters() 메소드는 먼저 모든 하위 모듈들을 탐색하고(recursive=True), 각 모듈의 _parameters에 들어있는 파라메터들을 하나씩 반환해주는 함수였다.

 

아래의 그림은 모듈의 한 예시인데, 이 경우 A.parameters()를 호출하면, [A.B.weight, A.B.bias, A.C.D.weight, A.C.D.bias]가 반환된다. (물론 recursive=False를 지정하면 직접 A에 속한 파라메터만 반환하는데, 아래 경우에는 아무것도 반환되지 않게된다.)

 

모듈의 트리구조 (예시)


파이토치 모듈을 상속받아 신경망 구조를 구현할때, 인터페이스가 너무 편리해서 내부 동작에 관해서 생각해 본 적이 없었는데, 구현 코드를 직접 들여다보니 생각보다 읽기도 쉽고, API 이해도가 높아졌다. 다음에도 궁금한 부분이 있으면 하나씩 코드를 까봐야겠다.