🐍 Python & library/PyTorch

[PyTorch] model weight 값 조정하기 / weight normalization

복만 2022. 4. 22. 18:18

PyTorch model의 weight값을 바로 바꾸려고 하면 FloatTensor을 parameter로 사용할 수 없다는 에러가 발생한다.

import torch.nn as nn

model = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 4, bias=False))
model[2].weight = model[2].weight/2.

>> cannot assign 'torch.FloatTensor' as parameter 'weight' (torch.nn.Parameter or None expected)

 

 

학습 도중 사용하는 weight normalization은 nn.utils.weight_norm을 사용하면 되는 것 같지만,

바로 weight 값에 접근하고자 하는 경우 다음과 같이 torch.no_grad() block 내부에서 div_ method를 이용할 수 있다.

출처: https://discuss.pytorch.org/t/how-to-do-weight-normalization-in-last-classification-layer/35193

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 4, bias=False))

with torch.no_grad():
    model[2].weight.div_(2.)

 

 

weight normalization (L2 norm) 을 해주고 싶으면 다음과 같이 해주면 된다.

with torch.no_grad():
    model[2].weight.div_(torch.norm(model[2].weight, p=2, dim=1, keepdim=True)
반응형