์ด์ ๊ธ์์ PyTorch๋ฅผ ์ด์ฉํ gradient clipping์ ์๊ฐํ๋๋ฐ,
gradient๊ฐ ์๋๋ผ weight ๊ฐ ์์ฒด๋ฅผ ์ผ์ ๋ฒ์ ์์ผ๋ก ์ ํํด์ผ ํ๋ ๊ฒฝ์ฐ๊ฐ ์๋ค.
์ด ๊ฒฝ์ฐ Weight clipping์ ์ํํด์ฃผ๋ class๋ฅผ ์ ์ํ์ฌ ์ฌ์ฉํ ์ ์๋ค.
์ถ์ฒ: https://discuss.pytorch.org/t/set-constraints-on-parameters-or-layers/23620
Weight clipper ์ถ๊ฐํ๊ธฐ
๋ค์๊ณผ ๊ฐ์ด Weight clipper์ ์ ์ํด์ค ์ ์๋ค. torch.clamp method๋ฅผ ์ด์ฉํ๋ฉด tensor์ ๊ฐ์ ํด๋น ๋ฒ์ ์์ผ๋ก clipํด์ค ์ ์๋ค.
class WeightClipper(object):
def __call__(self, module, param, clip_min, clip_max):
if hasattr(module, param):
self.clip(module, param, clip_min, clip_max)
def clip(self, module, param, clip_min, clip_max):
p = getattr(module, param).data
p = p.clamp(clip_max, clip_max)
getattr(module, param).data = p
์ฌ์ฉ๋ฒ์ ์๋์ ๊ฐ๋ค.
model = Net()
weightclipper = WeightClipper()
#after each backward operation
loss.backward()
optimizer.step()
model.apply(weightclipper)
๋์ ๊ฒฝ์ฐ๋ model ๋ด์ ์ด๋ค parameter์ด ํน์ ๋ฒ์ ์์ผ๋ก ์ ํ๋๋๋ก ์์ ๊ฐ์ด weight clipper์ ์ ์ํ์ง๋ง,
๋ง์ฝ conv2d ๋ฑ์ ๋ชจ๋ module๋ค์ weight๋ฅผ [-1, 1] ๋ฒ์๋ก clipํ๊ณ ์ถ๋ค๋ฉด ๋ค์๊ณผ ๊ฐ์ด ๋ง๋ค ์๋ ์๋ค.
class WeightClipper(object):
def __call__(self, module, param):
if hasattr(module, param):
self.clip(module, param)
def clip(self, module, param):
p = getattr(module, param).data
p = p.clamp(-1, 1)
getattr(module, param).data = p