🐍 Python & library/PyTorch
[PyTorch] model weight 값 조정하기 / weight normalization
복만
2022. 4. 22. 18:18
PyTorch model의 weight값을 바로 바꾸려고 하면 FloatTensor을 parameter로 사용할 수 없다는 에러가 발생한다.
import torch.nn as nn
model = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 4, bias=False))
model[2].weight = model[2].weight/2.
>> cannot assign 'torch.FloatTensor' as parameter 'weight' (torch.nn.Parameter or None expected)
학습 도중 사용하는 weight normalization은 nn.utils.weight_norm
을 사용하면 되는 것 같지만,
바로 weight 값에 접근하고자 하는 경우 다음과 같이 torch.no_grad()
block 내부에서 div_
method를 이용할 수 있다.
출처: https://discuss.pytorch.org/t/how-to-do-weight-normalization-in-last-classification-layer/35193
import torch
import torch.nn as nn
model = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 4, bias=False))
with torch.no_grad():
model[2].weight.div_(2.)
weight normalization (L2 norm) 을 해주고 싶으면 다음과 같이 해주면 된다.
with torch.no_grad():
model[2].weight.div_(torch.norm(model[2].weight, p=2, dim=1, keepdim=True)
반응형