
[PyTorch] What tensor.detach() does, with example code

복만 2022. 10. 30. 03:15

The detach() method, available on any PyTorch tensor, stops gradients from propagating: it returns a tensor with the same values that is cut off from the computation graph.

 

https://pytorch.org/docs/stable/generated/torch.Tensor.detach.html
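Before moving to a full model, here is a minimal sketch of the behavior on plain tensors. The names x, y, z, and loss are chosen only for illustration. The detached tensor keeps its values but no longer requires a gradient, so only the path that bypasses detach() contributes to x.grad.

```python
import torch

# A leaf tensor that autograd tracks.
x = torch.ones(3, requires_grad=True)

y = x * 2        # y is part of the computation graph
z = y.detach()   # z holds the same values but is cut off from the graph

print(y.requires_grad)  # True
print(z.requires_grad)  # False
print(z.grad_fn)        # None

# loss depends on x through y only; z is treated as a constant.
loss = (y + z).sum()
loss.backward()
print(x.grad)  # tensor([2., 2., 2.])
```

If z were not detached, the loss would be 4 * x summed and the gradient would be 4 for every element instead of 2.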

 

torch.Tensor.detach — PyTorch 1.13 documentation


 

The following example makes this easy to understand.

 

import torch
import torch.nn as nn

class TestModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(10, 10)
        self.layer2 = nn.Linear(10, 10)

    def forward(self, x):
        out1 = self.layer1(x)
        # detach() cuts out1 from the graph, so layer2 sees it as a constant input
        out2 = self.layer2(out1.detach())
        return out2

model = TestModel()

 

Drawn as a computation graph, the code above is a simple chain: x → layer1 → out1 → detach() → layer2 → out2.

 

 

 

Because the output of layer1 is detached, the gradient does not flow back to the earlier layer during backpropagation.

 

Printing each layer's weight gradient confirms that no gradient has accumulated in layer1.

x = torch.randn(1, 10)
a = model(x)
a.mean().backward()
print(model.layer1.weight.grad)  # None: no gradient reached layer1
print(model.layer2.weight.grad)  # a 10x10 tensor of gradients
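To see the effect of detach() more directly, it can help to run the same model with and without it. The sketch below is a variation on the model above, not part of the original post: the use_detach flag and the weight_grads helper are hypothetical names added only for this comparison.

```python
import torch
import torch.nn as nn

class TestModel(nn.Module):
    def __init__(self, use_detach):
        super().__init__()
        self.use_detach = use_detach  # hypothetical switch for this comparison
        self.layer1 = nn.Linear(10, 10)
        self.layer2 = nn.Linear(10, 10)

    def forward(self, x):
        out1 = self.layer1(x)
        if self.use_detach:
            out1 = out1.detach()  # block gradients from reaching layer1
        return self.layer2(out1)

def weight_grads(use_detach):
    """Run one forward/backward pass and return both layers' weight gradients."""
    model = TestModel(use_detach)
    model(torch.randn(1, 10)).mean().backward()
    return model.layer1.weight.grad, model.layer2.weight.grad

g1_detached, g2_detached = weight_grads(use_detach=True)
g1_plain, g2_plain = weight_grads(use_detach=False)

print(g1_detached is None)  # True: layer1 never received a gradient
print(g2_detached.shape)    # torch.Size([10, 10])
print(g1_plain.shape)       # torch.Size([10, 10]): without detach, layer1 is updated too
```

In both runs layer2 receives a gradient; only layer1 is affected by the detach() call.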
