PyTorch model의 weight값을 바로 바꾸려고 하면 FloatTensor을 parameter로 사용할 수 없다는 에러가 발생한다.
import torch.nn as nn
model = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 4, bias=False))
model[2].weight = model[2].weight/2.
>> cannot assign 'torch.FloatTensor' as parameter 'weight' (torch.nn.Parameter or None expected)
학습 도중 사용하는 weight normalization은 nn.utils.weight_norm
을 사용하면 되는 것 같지만,
바로 weight 값에 접근하고자 하는 경우 다음과 같이 torch.no_grad()
block 내부에서 div_
method를 이용할 수 있다.
출처: https://discuss.pytorch.org/t/how-to-do-weight-normalization-in-last-classification-layer/35193
import torch
import torch.nn as nn
model = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 4, bias=False))
with torch.no_grad():
model[2].weight.div_(2.)
weight normalization (L2 norm) 을 해주고 싶으면 다음과 같이 해주면 된다.
with torch.no_grad():
model[2].weight.div_(torch.norm(model[2].weight, p=2, dim=1, keepdim=True)
반응형
'🐍 Python & library > PyTorch' 카테고리의 다른 글
Numpy & PyTorch로 2D fourier transform, inverse fourier transform하기 (1) | 2022.08.27 |
---|---|
[PyTorch] make_grid로 여러 개의 이미지 한번에 plot하기 (0) | 2022.07.29 |
[PyTorch] nn.Conv의 padding과 padding_mode (2) | 2022.03.24 |
[PyTorch] Weight clipping (0) | 2022.01.29 |
[PyTorch] Enable anomaly detection (torch.autograd.detect_anomaly() / torch.autograd.set_detect_anomaly(True)) (0) | 2022.01.29 |