Loading [MathJax]/jax/output/CommonHTML/jax.js

💩 에러 해결

[PyTorch/에러 해결] Dataparallel이 complex tensor을 real view로 전환시키는 문제

복만 2022. 1. 10. 16:31

문제: 

complex tensor을 input으로 받는 모델을 사용 중이었고,

forward method를 테스트 할 때는 잘 돌아가다가

전체 train 코드를 돌렸더니 tensor 차원이 안맞는다는 에러를 내뱉었다..

 

RuntimeError: The size of tensor a (2) must match the size of tensor b (232) at non-singleton dimension 3

 

바로 이렇게..

 

디버깅을 해보니 complex tensor가 model 내부로 들어가면 float으로 변환되면서 real-imag part가 분리되는 것이었다,,

따로 model forward 코드만 돌릴때는 잘만 돌아갔는데 ?

 

 

 

원인:

결론은.. nn.DataParallel이 문제였다 

nn.DataParallel로 감싼 모델은 내부로 전달된 input에 torch.view_as_real을 호출한다고 한다.

(근데 그러면 다시 원래대로 돌려놔야 하는거 아닌가요?)

 

그래서 이렇게 된다.

 

자세한 설명은 아래 두 링크에서 확인.

 

Data Parallel splits Complex Parameter · Issue #60931 · pytorch/pytorch

🐛 Bug I am testing a toy model that fourier transforms an image and do pointwise multiplication with a complex tensor and then inverse fourier transform. When I train the model with single GPU ever...

github.com

 

DataParall (broadccoasced) with complex tensors yield real views · Issue #55375 · pytorch/pytorch

🐛 Bug Using DataParallel on complex tensors (either parameters or inputs/outputs) yield real views. The expected behavior would be to obtain complex tensors on each replicate. Casting the views bac...

github.com

 

해결방법:

DataParallel은 maintainance mode에 있기 때문에 DistributedDataParallel (DDP)를 쓰라고 한다

DataParallel은 여러모로 문제가 많구나

아니면 GPU 하나만 쓰던지..

반응형