๐Ÿ Python & library/PyTorch

[PyTorch] Autograd ์ž‘๋™๋ฐฉ์‹ ์•Œ์•„๋ณด๊ธฐ

๋ณต๋งŒ 2023. 12. 2. 23:20

 

 

 

 

์œ„ ๋™์˜์ƒ์—์„œ PyTorch Autograd๋ฅผ ์ดํ•ดํ•˜๊ธฐ ์‰ฝ๊ฒŒ ์„ค๋ช…ํ•ด์ฃผ๊ณ  ์žˆ๋‹ค. ๋‹ค์Œ์€ ์œ„ ๋™์˜์ƒ์„ ๊ฐ„๋‹จํžˆ ์ •๋ฆฌํ•œ ๊ธ€์ด๋‹ค.

 

 

 

 

1. torch.Tensor

๊ฐ tensor์€ ๋‹ค์Œ์˜ attr์„ ๊ฐ–๋Š”๋‹ค

  • data: tensor์˜ ๊ฐ’
  • grad: tensor์˜ gradient ๊ฐ’. is_leaf์ธ ๊ฒฝ์šฐ์—๋งŒ gradient๊ฐ€ ์ž๋™์œผ๋กœ ์ €์žฅ๋œ๋‹ค.
  • grad_fn: gradient function. ํ•ด๋‹น tensor๊ฐ€ ์–ด๋–ค ์—ฐ์‚ฐ์„ ํ†ตํ•ด forward๋˜์—ˆ๋Š”์ง€์— ๋”ฐ๋ผ ๊ฒฐ์ •๋œ๋‹ค.
    • ex) a * b = c ์ธ ๊ฒฝ์šฐ c์˜ grad_fn์€ MulBackward์ด๋‹ค.
    • is_leaf์ธ ๊ฒฝ์šฐ None
  • is_leaf: (backward ๊ธฐ์ค€) ๊ฐ€์žฅ ๋งˆ์ง€๋ง‰ tensor์ธ์ง€
  • requires_grad: ๊ณ„์‚ฐ ๊ทธ๋ž˜ํ”„์˜ ์ผ๋ถ€๋กœ ๋“ค์–ด๊ฐˆ ๊ฒƒ์ธ์ง€

 

 

2. grad_fn

grad_fn์€ ๋‹ค์Œ์˜ attr์„ ๊ฐ–๋Š”๋‹ค.

  • saved_tensors: forward ์—ฐ์‚ฐ์œผ๋กœ๋ถ€ํ„ฐ ๋ฐ›์Œ
    • ๊ณ„์‚ฐ๊ทธ๋ž˜ํ”„์— ํฌํ•จ๋˜์ง€ ์•Š์€ in_place ์—ฐ์‚ฐ ๋“ฑ์œผ๋กœ ์ธํ•ด tensor ๊ฐ’์ด ๋ณ€๊ฒฝ๋˜๋Š” ๊ฒฝ์šฐ๋ฅผ ๋Œ€๋น„ํ•˜์—ฌ ๊ณ„์‚ฐ ๋‹น์‹œ์˜ tensor ๊ฐ’์„ context ๋ณ€์ˆ˜์— ์ €์žฅํ•ด ๋†“๋Š”๋‹ค.
    • ๋งŒ์•ฝ "Add"์ฒ˜๋Ÿผ ์ด์ „ tensor ๊ฐ’์ด ํ•„์š”ํ•˜์ง€ ์•Š์€ ์—ฐ์‚ฐ์˜ ๊ฒฝ์šฐ context ๋ณ€์ˆ˜๊ฐ€ ๊ฐ’์„ ์ €์žฅํ•ด ๋‘์ง€ ์•Š์•„๋„ ๋œ๋‹ค.
  • next_functions: ๋‹ค์Œ tuple๋กœ ๊ตฌ์„ฑ๋œ list
    • backward๊ธฐ์ค€์œผ๋กœ ๋‹ค์Œ tensor์˜ grad_fn
      • is_leaf์ด๊ณ  requires_grad์ผ ๊ฒฝ์šฐ AcummulateGrad - ๊ณ„์‚ฐ๋œ gradient๋ฅผ self.grad์— ์ €์žฅ
      • is_leaf์ด๊ณ  requires_grad๊ฐ€ ์•„๋‹ ๊ฒฝ์šฐ None
    • grad_fn์˜ ๋ช‡๋ฒˆ์งธ input์œผ๋กœ ์ „๋‹ฌ๋  ๊ฒƒ์ธ์ง€
      • ๋ณดํ†ต์€ grad_fn์ด ํ•˜๋‚˜์˜ input๋งŒ ๋ฐ›์ง€๋งŒ forward ์—ฐ์‚ฐ์˜ output์ด ์—ฌ๋Ÿฌ๊ฐœ์ธ ๊ฒฝ์šฐ grad_fn์ด ์—ฌ๋Ÿฌ๊ฐœ์˜ input์„ ๋ฐ›์„ ์ˆ˜ ์žˆ๋‹ค

 

 

3. backward()

tensor์˜ backward() ์—ฐ์‚ฐ์ด ํ˜ธ์ถœ๋˜๋ฉด ํ•ด๋‹น tensor์€ gradient 1๋กœ ์‹œ์ž‘ํ•œ๋‹ค. ์ด ๊ฐ’์ด grad_fn์„ ํƒ€๊ณ  ํ˜๋Ÿฌ๊ฐ„๋‹ค.

 

  • ์ด์ „ tensor์˜ gradient๊ฐ€ MulBackward๋กœ ์ „๋‹ฌ
    • 1 → MulBackward
  • ๋‹ค์Œ tensor์˜ gradient๋ฅผ ๊ณ„์‚ฐํ•˜์—ฌ next_function์œผ๋กœ ์ „๋‹ฌ
    • ๋‹ค์Œ tensor์˜ gradient = ํ˜„์žฌ ์—ฐ์‚ฐ์—์„œ์˜ gradient x ์ด์ „ tensor์˜ gradient (chain rule)
    • 4 → Mulbackward
    • 6 → AccumulateGrad
  • AccumulateGrad ํ•จ์ˆ˜๋Š” ํ•ด๋‹น tensor์˜ grad์— gradient ์ €์žฅ

 

gradient๋Š” is_leaf์ธ ๊ฒฝ์šฐ์—๋งŒ ์ €์žฅ๋œ๋‹ค. leaf๊ฐ€ ์•„๋‹Œ tensor์˜ gradient๋Š” ์ €์žฅ๋˜์ง€ ์•Š๊ณ  grad_fn์„ ๋”ฐ๋ผ ์ „๋‹ฌ๋˜๊ธฐ๋งŒ ํ•œ๋‹ค.

 

๊ทธ๋Ÿฌ๋‚˜ intermediate tensor์—๋„ gradient๋ฅผ ์ €์žฅํ•˜๊ณ  ์‹ถ๋‹ค๋ฉด tensor.retain_grad() ๋ฉ”์†Œ๋“œ๋ฅผ ์‚ฌ์šฉํ•˜๋ฉด ๋œ๋‹ค.

๋ฐ˜์‘ํ˜•