PyTorch tensor์ ์ฌ์ฉํ ์ ์๋ detach() method๋ gradient์ ์ ํ๋ฅผ ๋ฉ์ถ๋ ์ญํ ์ ํ๋ค.
https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
๋ค์ ์์๋ฅผ ํตํด ์ฝ๊ฒ ์ดํดํ ์ ์๋ค.
import torch
import torch.nn as nn
class TestModel(nn.Module):
def __init__(self):
super().__init__()
self.layer1 = nn.Linear(10, 10)
self.layer2 = nn.Linear(10, 10)
def forward(self, x):
out1 = self.layer1(x)
out2 = self.layer2(out1.detach())
return out2
model = TestModel()
์์ ์ฝ๋๋ฅผ ๊ทธ๋ํ๋ก ๋ํ๋ด๋ฉด ๋ค์๊ณผ ๊ฐ๋ค.
layer 1์์ ๋์จ output์ด detach๋์๊ธฐ ๋๋ฌธ์, ์ญ์ ํ ์ gradient๊ฐ ๊ทธ ์ด์ layer๋ก ํ๋ฌ๊ฐ์ง ์๋๋ค.
๊ฐ layer์ weight gradient๋ฅผ ์ถ๋ ฅํด๋ณด๋ฉด, layer1์๋ gradient๊ฐ ์ถ์ ๋์ง ์์ ๊ฒ์ ํ์ธํ ์ ์๋ค.
x = torch.randn(1, 10)
a = model(x)
a.mean().backward()
print(model.layer1.weight.grad)
print(model.layer2.weight.grad)
๋ฐ์ํ
'๐ Python & library > PyTorch' ์นดํ ๊ณ ๋ฆฌ์ ๋ค๋ฅธ ๊ธ
[PyTorch] Autograd ์๋๋ฐฉ์ ์์๋ณด๊ธฐ (1) | 2023.12.02 |
---|---|
PyTorch 2.0์์ ๋ฌ๋ผ์ง๋ ์ - torch.compile (1) | 2023.05.06 |
[PyTorch] nn.Embedding ์ด๊ธฐํํ๊ธฐ (initialization) (0) | 2022.10.27 |
Numpy & PyTorch๋ก 2D fourier transform, inverse fourier transformํ๊ธฐ (1) | 2022.08.27 |
[PyTorch] make_grid๋ก ์ฌ๋ฌ ๊ฐ์ ์ด๋ฏธ์ง ํ๋ฒ์ plotํ๊ธฐ (0) | 2022.07.29 |