🐍 Python & library/PyTorch

[PyTorch] nn.ModuleList 기능과 사용 이유

복만 2021. 8. 4. 21:09

참고 : https://discuss.pytorch.org/t/when-should-i-use-nn-modulelist-and-when-should-i-use-nn-sequential/5463/3

 

When should I use nn.ModuleList and when should I use nn.Sequential?

From what I see, it is interchangeable then? Unless there is some order to be followed, then we should use Sequential. Am I right?

discuss.pytorch.org

official docs : https://pytorch.org/docs/stable/generated/torch.nn.ModuleList.html

 

ModuleList — PyTorch 1.9.0 documentation

Shortcuts

pytorch.org

 

nn.ModuleList의 기능

PyTorch에 nn.Sequential()과 비슷한 nn.ModuleList() 모듈이 있다.

nn.Sequential은 input으로 준 module에 대해 순차적으로 forward() method를 호출해주는 역할을 하는데,

nn.ModuleList는 처음 봐서 그 용도에 대해 정리를 해 보았다.

 

 

- 우선, nn.ModuleListnn.Sequential과 마찬가지로 nn.Module의 list를 input으로 받는다.

- 이는 Python list와 마찬가지로, nn.Module을 저장하는 역할을 한다. index로 접근도 할 수 있다.

- 하지만 nn.Sequential과 다르게 forward() method가 없다.

- 또한, 안에 담긴 module 간의 connection도 없다.

 

 

그렇다면 왜 nn.ModuleList를 사용해야 할까?

- 우리는 nn.ModuleList안에 Module들을 넣어 줌으로써 Module의 존재를 PyTorch에게 알려 주어야 한다.

- 만약 nn.ModuleList에 넣어 주지 않고, Python list에만 Module들을 넣어 준다면, PyTorch는 이들의 존재를 알지 못한다.

- 이 경우 optimzier을 선언하고 model.parameter()로 parameter을 넘겨줄 때 "your model has no parameter" 와 같은 error을 받게 된다.

- 따라서 Module들을 Python list에 넣어 보관한다면, 꼭 마지막에 이들을 nn.ModuleList로 wrapping 해줘야 한다.

 


 

example) 두 종류의 module이 받는 input이 서로 다르고, 여러 개를 반복적으로 정의해야 할 때 유용하게 사용할 수 있다.

class Network(nn.Module):
    def __init__(self, n_blocks):
        super(Network, self).__init__()
        self.n_blocks = n_blocks
        block_A_list = []
        block_B_list = []
        for _ in range(n_blocks):
            block_A_list.append(Block_A())
            block_B_list.append(Block_B())
        self.block_A_list = nn.ModuleList(block_A_list)
        self.block_B_list = nn.ModuleList(block_B_list)

    def forward(self, x, k):
        for i in range(self.n_blocks):
            out = self.block_A_list[i](x)
            out = self.block_B_list[i](out, k)
        return out
반응형