official docs : https://pytorch.org/docs/stable/generated/torch.nn.ModuleList.html
nn.ModuleList의 기능
PyTorch에 nn.Sequential()과 비슷한 nn.ModuleList() 모듈이 있다.
nn.Sequential은 input으로 준 module에 대해 순차적으로 forward() method를 호출해주는 역할을 하는데,
nn.ModuleList는 처음 봐서 그 용도에 대해 정리를 해 보았다.
- 우선, nn.ModuleList는 nn.Sequential과 마찬가지로 nn.Module의 list를 input으로 받는다.
- 이는 Python list와 마찬가지로, nn.Module을 저장하는 역할을 한다. index로 접근도 할 수 있다.
- 하지만 nn.Sequential과 다르게 forward() method가 없다.
- 또한, 안에 담긴 module 간의 connection도 없다.
그렇다면 왜 nn.ModuleList를 사용해야 할까?
- 우리는 nn.ModuleList안에 Module들을 넣어 줌으로써 Module의 존재를 PyTorch에게 알려 주어야 한다.
- 만약 nn.ModuleList에 넣어 주지 않고, Python list에만 Module들을 넣어 준다면, PyTorch는 이들의 존재를 알지 못한다.
- 이 경우 optimzier을 선언하고 model.parameter()로 parameter을 넘겨줄 때 "your model has no parameter" 와 같은 error을 받게 된다.
- 따라서 Module들을 Python list에 넣어 보관한다면, 꼭 마지막에 이들을 nn.ModuleList로 wrapping 해줘야 한다.
example) 두 종류의 module이 받는 input이 서로 다르고, 여러 개를 반복적으로 정의해야 할 때 유용하게 사용할 수 있다.
class Network(nn.Module):
def __init__(self, n_blocks):
super(Network, self).__init__()
self.n_blocks = n_blocks
block_A_list = []
block_B_list = []
for _ in range(n_blocks):
block_A_list.append(Block_A())
block_B_list.append(Block_B())
self.block_A_list = nn.ModuleList(block_A_list)
self.block_B_list = nn.ModuleList(block_B_list)
def forward(self, x, k):
for i in range(self.n_blocks):
out = self.block_A_list[i](x)
out = self.block_B_list[i](out, k)
return out
'🐍 Python & library > PyTorch' 카테고리의 다른 글
[PyTorch] Scheduler 시각화하기 (Visualize scheduler) (2) | 2021.11.24 |
---|---|
[PyTorch] ReduceLROnPlateau (0) | 2021.10.26 |
[PyTorch] CosineAnnealingLR, CosineAnnealingWarmRestarts (0) | 2021.10.14 |
[PyTorch] Livelossplot 사용예제 (0) | 2021.04.03 |
[PyTorch] 모델 저장하기 & 불러오기 (0) | 2020.02.03 |