์ ๋์์์์ PyTorch Autograd๋ฅผ ์ดํดํ๊ธฐ ์ฝ๊ฒ ์ค๋ช ํด์ฃผ๊ณ ์๋ค. ๋ค์์ ์ ๋์์์ ๊ฐ๋จํ ์ ๋ฆฌํ ๊ธ์ด๋ค.
1. torch.Tensor
๊ฐ tensor์ ๋ค์์ attr์ ๊ฐ๋๋ค
data
: tensor์ ๊ฐgrad
: tensor์ gradient ๊ฐ. is_leaf์ธ ๊ฒฝ์ฐ์๋ง gradient๊ฐ ์๋์ผ๋ก ์ ์ฅ๋๋ค.grad_fn
: gradient function. ํด๋น tensor๊ฐ ์ด๋ค ์ฐ์ฐ์ ํตํด forward๋์๋์ง์ ๋ฐ๋ผ ๊ฒฐ์ ๋๋ค.- ex) a * b = c ์ธ ๊ฒฝ์ฐ c์ grad_fn์ MulBackward์ด๋ค.
- is_leaf์ธ ๊ฒฝ์ฐ None
is_leaf
: (backward ๊ธฐ์ค) ๊ฐ์ฅ ๋ง์ง๋ง tensor์ธ์งrequires_grad
: ๊ณ์ฐ ๊ทธ๋ํ์ ์ผ๋ถ๋ก ๋ค์ด๊ฐ ๊ฒ์ธ์ง
2. grad_fn
grad_fn
์ ๋ค์์ attr์ ๊ฐ๋๋ค.
saved_tensors
: forward ์ฐ์ฐ์ผ๋ก๋ถํฐ ๋ฐ์- ๊ณ์ฐ๊ทธ๋ํ์ ํฌํจ๋์ง ์์ in_place ์ฐ์ฐ ๋ฑ์ผ๋ก ์ธํด tensor ๊ฐ์ด ๋ณ๊ฒฝ๋๋ ๊ฒฝ์ฐ๋ฅผ ๋๋นํ์ฌ ๊ณ์ฐ ๋น์์ tensor ๊ฐ์ context ๋ณ์์ ์ ์ฅํด ๋๋๋ค.
- ๋ง์ฝ "Add"์ฒ๋ผ ์ด์ tensor ๊ฐ์ด ํ์ํ์ง ์์ ์ฐ์ฐ์ ๊ฒฝ์ฐ context ๋ณ์๊ฐ ๊ฐ์ ์ ์ฅํด ๋์ง ์์๋ ๋๋ค.
next_functions
: ๋ค์ tuple๋ก ๊ตฌ์ฑ๋ list- backward๊ธฐ์ค์ผ๋ก ๋ค์ tensor์ grad_fn
- is_leaf์ด๊ณ requires_grad์ผ ๊ฒฝ์ฐ
AcummulateGrad
- ๊ณ์ฐ๋ gradient๋ฅผ self.grad์ ์ ์ฅ - is_leaf์ด๊ณ requires_grad๊ฐ ์๋ ๊ฒฝ์ฐ None
- is_leaf์ด๊ณ requires_grad์ผ ๊ฒฝ์ฐ
- grad_fn์ ๋ช๋ฒ์งธ input์ผ๋ก ์ ๋ฌ๋ ๊ฒ์ธ์ง
- ๋ณดํต์ grad_fn์ด ํ๋์ input๋ง ๋ฐ์ง๋ง forward ์ฐ์ฐ์ output์ด ์ฌ๋ฌ๊ฐ์ธ ๊ฒฝ์ฐ grad_fn์ด ์ฌ๋ฌ๊ฐ์ input์ ๋ฐ์ ์ ์๋ค
- backward๊ธฐ์ค์ผ๋ก ๋ค์ tensor์ grad_fn
3. backward()
tensor์ backward() ์ฐ์ฐ์ด ํธ์ถ๋๋ฉด ํด๋น tensor์ gradient 1๋ก ์์ํ๋ค. ์ด ๊ฐ์ด grad_fn
์ ํ๊ณ ํ๋ฌ๊ฐ๋ค.
- ์ด์ tensor์ gradient๊ฐ
MulBackward
๋ก ์ ๋ฌ- 1 →
MulBackward
- 1 →
- ๋ค์ tensor์ gradient๋ฅผ ๊ณ์ฐํ์ฌ
next_function
์ผ๋ก ์ ๋ฌ- ๋ค์ tensor์ gradient = ํ์ฌ ์ฐ์ฐ์์์ gradient x ์ด์ tensor์ gradient (chain rule)
- 4 →
Mulbackward
- 6 →
AccumulateGrad
AccumulateGrad
ํจ์๋ ํด๋น tensor์ grad์ gradient ์ ์ฅ
gradient๋ is_leaf์ธ ๊ฒฝ์ฐ์๋ง ์ ์ฅ๋๋ค. leaf๊ฐ ์๋ tensor์ gradient๋ ์ ์ฅ๋์ง ์๊ณ grad_fn์ ๋ฐ๋ผ ์ ๋ฌ๋๊ธฐ๋ง ํ๋ค.
๊ทธ๋ฌ๋ intermediate tensor์๋ gradient๋ฅผ ์ ์ฅํ๊ณ ์ถ๋ค๋ฉด tensor.retain_grad()
๋ฉ์๋๋ฅผ ์ฌ์ฉํ๋ฉด ๋๋ค.
๋ฐ์ํ
'๐ Python & library > PyTorch' ์นดํ ๊ณ ๋ฆฌ์ ๋ค๋ฅธ ๊ธ
PyTorch 2.0์์ ๋ฌ๋ผ์ง๋ ์ - torch.compile (1) | 2023.05.06 |
---|---|
[PyTorch] tensor.detach()์ ๊ธฐ๋ฅ๊ณผ ์์ ์ฝ๋ (0) | 2022.10.30 |
[PyTorch] nn.Embedding ์ด๊ธฐํํ๊ธฐ (initialization) (0) | 2022.10.27 |
Numpy & PyTorch๋ก 2D fourier transform, inverse fourier transformํ๊ธฐ (1) | 2022.08.27 |
[PyTorch] make_grid๋ก ์ฌ๋ฌ ๊ฐ์ ์ด๋ฏธ์ง ํ๋ฒ์ plotํ๊ธฐ (0) | 2022.07.29 |