PyTorch์ nn.Embedding
layer์ ์ด๊ธฐํํ๋ ๋ฐฉ๋ฒ์๋ ๋ ๊ฐ์ง๊ฐ ์๋ค.
embedding = nn.Embedding(num_embeddings, embedding_dim)
1. torch.tensor์ ๋ด์ฅ method ์ด์ฉํ๊ธฐ
embedding.weight.data.uniform_(-1, 1)
torch.tensor
์ uniform_
๋ฑ์ ๋ด์ฅ method๋ฅผ ๊ฐ์ง๊ณ ์์ด ์ด๋ฅผ ํตํด ๊ฐ์ ์ด๊ธฐํํ ์ ์๋ค.
2. torch.nn.init ์ด์ฉํ๊ธฐ
nn.init.uniform_(embedding.weight, -1.0, 1.0)
torch.nn.init์ method๋ค์ ์ด์ฉํ ์๋ ์๋ค.
์ด ๋ฐฉ๋ฒ์ ์ด์ฉํ๋ฉด uniform_
์ด์ธ์๋ xavier_uniform_
๋ฑ ๋ณด๋ค ๋ค์ํ initialization ๋ฐฉ๋ฒ๋ค์ ์ฌ์ฉํ ์ ์๋ค.
๋ฐ์ํ
'๐ Python & library > PyTorch' ์นดํ ๊ณ ๋ฆฌ์ ๋ค๋ฅธ ๊ธ
PyTorch 2.0์์ ๋ฌ๋ผ์ง๋ ์ - torch.compile (1) | 2023.05.06 |
---|---|
[PyTorch] tensor.detach()์ ๊ธฐ๋ฅ๊ณผ ์์ ์ฝ๋ (0) | 2022.10.30 |
Numpy & PyTorch๋ก 2D fourier transform, inverse fourier transformํ๊ธฐ (1) | 2022.08.27 |
[PyTorch] make_grid๋ก ์ฌ๋ฌ ๊ฐ์ ์ด๋ฏธ์ง ํ๋ฒ์ plotํ๊ธฐ (0) | 2022.07.29 |
[PyTorch] model weight ๊ฐ ์กฐ์ ํ๊ธฐ / weight normalization (0) | 2022.04.22 |