Gradient clipping์ ๋๋ฌด ํฌ๊ฑฐ๋ ์์ gradient์ ๊ฐ์ ์ ํํ์ฌ vanishing gradient๋ exploding gradient ํ์์ ๋ฐฉ์งํ๋ ๋ฐฉ๋ฒ์ด๋ค.
ํนํ RNN์์ ์์ฃผ ๋ฐ์ํ๋ ํ์์ธ๋ฐ ์ด์ธ์๋ ๊น์ ๋คํธ์ํฌ์์ ์ ์ฉํ๊ฒ ์ฌ์ฉ๋ ์ ์๋ ๋ฐฉ๋ฒ์ด๋ค.
์ค๊ฐ์ loss๊ฐ ๋๋ฌด ๋ฐ๋ฉด์ weight update๊ฐ ์ด์ํ ๋ฐฉํฅ์ผ๋ก ์งํ๋๋ค๋ฉด ์ฌ์ฉํด๋ณผ ์ ์๋ค.
์๋ ๊ธ์ ์ฐธ๊ณ ํ์๋ค.
Gradient clip by value, by norm
Gradient clipping์ ํ ์ ์๋ ๋ฐฉ๋ฒ์ ๋ ๊ฐ์ง๊ฐ ์๋ค.
์ฒซ ๋ฒ์งธ๋ก, Clipping-by-value๋ ๋จ์ํ ๋ชจ๋ gradient๋ฅผ (min_threshold, max_threshold) ๋ฒ์๋ก clippingํ๋ ๊ฒ์ด๋ค.
๋๋ค๋ฅธ ๋ฐฉ๋ฒ์ธ Clipping-by-norm์ norm์ด threshold ์ด์์ผ ๊ฒฝ์ฐ, threshold์ gradient์ unit vector๋ฅผ ๊ณฑํ ๊ฐ์ผ๋ก ๋ฐ๊ฟ ์ฃผ๋ ๊ฒ์ด๋ค. ์ฆ ๋ค์๊ณผ ๊ฐ๋ค.
$g\leftarrow \text{threshold} * g / ||g||$ if $||g|| \geq \text{threshold}$
์๋์์ Tensorflow์ PyTorch๋ฅผ ์ด์ฉํด ์ด ๋ ๊ฐ์ง gradient clipping์ ๊ตฌํํ๋ ๋ฐฉ๋ฒ์ ์๊ฐํ๋ค. ์ ์ฒด ์ฝ๋๋ ์ ๋งํฌ์์ ํ์ธํ ์ ์๋ค. ๋ณธ ๊ธ์์๋ gradient clipping ๋ถ๋ถ์ ์ฝ๋๋ง ์์ฝํด ๋์๋ค.
Tensorflow (v1)
tf.clip_by_value
๋ฅผ ์ด์ฉํ๋ค.
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
gvs = optimizer.compute_gradients(loss)
capped_gvs = [(tf.clip_by_value(grad, -1., 1.), var) for grad, var in gvs] #min_value์ max_value๋ฅผ ์ค์ ํ ์ ์๋ค.
train_op = optimizer.apply_gradients(capped_gvs)
Clipping_by_norm์ line 3์ ๋ค์๊ณผ ๊ฐ์ด ๋ฐ๊ฟ์ฃผ๋ฉด ๋๋ค.
gradients = [(tf.clip_by_norm(grad, clip_norm=2.0)) for grad in gradients]
Tensorflow (v2)
v1๊ณผ ๋์ผํ๋ค.
with tf.GradientTape() as tape:
predictions= model(inputs, training=True)
loss = get_loss(targets, predictions)
gradients = tape.gradient(loss, model.trainable_variables)
gradients = [(tf.clip_by_value(grad, -1.0, 1.0)) for grad in gradients]
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
Clipping_by_norm์ญ์ v1๊ณผ ๋์ผํ๋ค.
gradients = [(tf.clip_by_norm(grad, clip_norm=2.0)) for grad in gradients]
PyTorch
nn.utils.clip_grad_value_
๋ฅผ ์ด์ฉํ๋ค. PyTorch์ ๊ฒฝ์ฐ (-clip_value, clip_value)์ ๋ฒ์๋ก gradient๋ฅผ clipํด์ค๋ค.
optimizer.zero_grad()
loss.backward()
nn.utils.clip_grad_value_(model.parameters(), clip_value=1.0)
optimizer.step()
Clipping_by_norm์ line 3์ ๋ค์๊ณผ ๊ฐ์ด ๋ฐ๊ฟ์ฃผ๋ฉด ๋๋ค.
nn.utils.clip_grad_norm_(model.parameters(), max_norm=2.0, norm_type=2)