๐Ÿ Python & library/PyTorch

[PyTorch] Weight clipping

복만 2022. 1. 29. 19:02

In a previous post I introduced gradient clipping in PyTorch. Sometimes, however, it is the weight values themselves, not the gradients, that have to be kept within a certain range.

 

In that case, you can define a class that performs weight clipping and apply it to the model.
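Before wrapping the operation in a class, it helps to see the difference between the two kinds of clipping in a single training step. The following is a minimal sketch, not code from the original post. The toy `nn.Linear` model, the learning rate, `max_norm=1.0`, and the range [-0.5, 0.5] are all arbitrary assumptions.

`torch.nn.utils.clip_grad_norm_` rescales the gradients before `optimizer.step()` uses them. `clamp_` constrains the parameter values after the update.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 2)                      # toy model (assumption)
optimizer = torch.optim.SGD(model.parameters(), lr=1.0)

x = torch.randn(8, 4)
loss = model(x).pow(2).mean()
loss.backward()

# Gradient clipping: rescale the gradients (here to a total L2 norm of at most 1.0)
# *before* the optimizer uses them.
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()

# Weight clipping: force the parameter values themselves into a range
# *after* the update. The range [-0.5, 0.5] is arbitrary.
with torch.no_grad():
    model.weight.clamp_(-0.5, 0.5)

print(model.weight.min().item(), model.weight.max().item())
```

Gradient clipping limits how far one step can move the weights, but it does not bound where the weights end up. Weight clipping gives that bound.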

 

Source: https://discuss.pytorch.org/t/set-constraints-on-parameters-or-layers/23620

 


 

 

Adding a weight clipper

๋‹ค์Œ๊ณผ ๊ฐ™์ด Weight clipper์„ ์ •์˜ํ•ด์ค„ ์ˆ˜ ์žˆ๋‹ค. torch.clamp method๋ฅผ ์ด์šฉํ•˜๋ฉด tensor์˜ ๊ฐ’์„ ํ•ด๋‹น ๋ฒ”์œ„ ์•ˆ์œผ๋กœ clipํ•ด์ค„ ์ˆ˜ ์žˆ๋‹ค.

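import torch

class WeightClipper(object):
    def __init__(self, param, clip_min, clip_max):
        # name of the parameter to clip (e.g. 'weight') and its allowed range
        self.param, self.clip_min, self.clip_max = param, clip_min, clip_max

    def __call__(self, module):
        # model.apply() calls this with each submodule as the only argument
        p = getattr(module, self.param, None)
        if isinstance(p, torch.Tensor):  # skip modules without the parameter (or with bias=None)
            self.clip(p)

    def clip(self, p):
        # clamp in place, without recording the op in the autograd graph
        with torch.no_grad():
            p.clamp_(self.clip_min, self.clip_max)

 

Usage is as follows.

model = Net()
weightclipper = WeightClipper('weight', -1.0, 1.0)

# after each backward pass and optimizer step
loss.backward()
optimizer.step()
model.apply(weightclipper)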
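The usage snippet above leaves `Net`, `loss`, and `optimizer` undefined, so here is a self-contained sketch of the whole loop. Treat it as one possible setup under stated assumptions:

- `Net` is a hypothetical model whose learnable scalar `alpha` is assumed to belong in [0, 1]. It stands in for "a particular parameter in the model."
- The clipper stores the parameter name and range in `__init__`. This is necessary because `model.apply()` passes only the module to the callable.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class WeightClipper(object):
    """Clamp one named parameter of every submodule into [clip_min, clip_max]."""

    def __init__(self, param, clip_min, clip_max):
        self.param, self.clip_min, self.clip_max = param, clip_min, clip_max

    def __call__(self, module):
        # model.apply() passes each submodule as the only argument
        p = getattr(module, self.param, None)
        if isinstance(p, torch.Tensor):
            with torch.no_grad():
                p.clamp_(self.clip_min, self.clip_max)


class Net(nn.Module):
    """Hypothetical model: a Linear layer scaled by a learnable 'alpha'."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(3, 1)
        self.alpha = nn.Parameter(torch.tensor(0.5))  # assumed to live in [0, 1]

    def forward(self, x):
        return self.alpha * self.fc(x)


torch.manual_seed(0)
model = Net()
optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
clipper = WeightClipper('alpha', 0.0, 1.0)

x, y = torch.randn(16, 3), torch.randn(16, 1)
for _ in range(5):
    optimizer.zero_grad()
    loss = F.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()
    model.apply(clipper)  # clip after every update
    print(model.alpha.item())
```

`model.apply()` visits `fc` and then `Net` itself. Only `Net` has an `alpha`, so only that parameter is clamped, and `fc.weight` is left alone.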

 

 


 

๋‚˜์˜ ๊ฒฝ์šฐ๋Š” model ๋‚ด์˜ ์–ด๋–ค parameter์ด ํŠน์ • ๋ฒ”์œ„ ์•ˆ์œผ๋กœ ์ œํ•œ๋˜๋„๋ก ์œ„์™€ ๊ฐ™์ด weight clipper์„ ์ •์˜ํ–ˆ์ง€๋งŒ,

๋งŒ์•ฝ conv2d ๋“ฑ์˜ ๋ชจ๋“  module๋“ค์˜ weight๋ฅผ [-1, 1] ๋ฒ”์œ„๋กœ clipํ•˜๊ณ  ์‹ถ๋‹ค๋ฉด ๋‹ค์Œ๊ณผ ๊ฐ™์ด ๋งŒ๋“ค ์ˆ˜๋„ ์žˆ๋‹ค.

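import torch

class WeightClipper(object):
    def __call__(self, module):
        # model.apply() calls this for every submodule; only modules with a weight tensor are clipped
        w = getattr(module, 'weight', None)
        if isinstance(w, torch.Tensor):
            self.clip(w)

    def clip(self, w):
        # clamp every weight into [-1, 1] in place, outside autograd
        with torch.no_grad():
            w.clamp_(-1, 1)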
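A clipper that checks only for a `weight` attribute also reaches normalization layers such as `nn.BatchNorm2d`, whose `weight` is the learnable per-channel scale. If only conv and linear layers should be clipped, one option is to filter by module type.

The sketch below is a hypothetical variant, not part of the original post. The class name `TypedWeightClipper`, the default type tuple, and the toy `nn.Sequential` model are all assumptions. Every weight is first filled with 5.0 so the effect is easy to see.

```python
import torch
import torch.nn as nn


class TypedWeightClipper(object):
    """Hypothetical variant: clamp `weight` only for the given module types."""

    def __init__(self, types=(nn.Conv2d, nn.Linear), clip_min=-1.0, clip_max=1.0):
        self.types, self.clip_min, self.clip_max = types, clip_min, clip_max

    def __call__(self, module):
        if isinstance(module, self.types) and module.weight is not None:
            with torch.no_grad():
                module.weight.clamp_(self.clip_min, self.clip_max)


model = nn.Sequential(
    nn.Conv2d(1, 4, kernel_size=3),
    nn.BatchNorm2d(4),
    nn.Flatten(),
    nn.Linear(4 * 6 * 6, 10),
)

# Push every weight out of range to make the effect visible
with torch.no_grad():
    for m in model.modules():
        if getattr(m, 'weight', None) is not None:
            m.weight.fill_(5.0)

model.apply(TypedWeightClipper())
print(model[0].weight.max().item())  # 1.0  (Conv2d: clipped)
print(model[1].weight.max().item())  # 5.0  (BatchNorm2d: skipped by the type filter)
print(model[3].weight.max().item())  # 1.0  (Linear: clipped)
```

Applying the all-weights version from the post to the same model would also clamp the BatchNorm2d scale to 1.0. Whether that is desirable depends on the model.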

 

๋ฐ˜์‘ํ˜•