[PyTorch/에러 해결] Dataparallel이 complex tensor을 real view로 전환시키는 문제
문제:
complex tensor을 input으로 받는 모델을 사용 중이었고,
forward method를 테스트 할 때는 잘 돌아가다가
전체 train 코드를 돌렸더니 tensor 차원이 안맞는다는 에러를 내뱉었다..
RuntimeError: The size of tensor a (2) must match the size of tensor b (232) at non-singleton dimension 3
바로 이렇게..
디버깅을 해보니 complex tensor가 model 내부로 들어가면 float으로 변환되면서 real-imag part가 분리되는 것이었다,,
따로 model forward 코드만 돌릴때는 잘만 돌아갔는데 ?

원인:
결론은.. nn.DataParallel
이 문제였다
nn.DataParallel
로 감싼 모델은 내부로 전달된 input에 torch.view_as_real
을 호출한다고 한다.
(근데 그러면 다시 원래대로 돌려놔야 하는거 아닌가요?)

자세한 설명은 아래 두 링크에서 확인.
Data Parallel splits Complex Parameter · Issue #60931 · pytorch/pytorch
🐛 Bug I am testing a toy model that fourier transforms an image and do pointwise multiplication with a complex tensor and then inverse fourier transform. When I train the model with single GPU ever...
github.com
`DataParallel` (`broadcast_coalesced`) with complex tensors yield real views · Issue #55375 · pytorch/pytorch
🐛 Bug Using DataParallel on complex tensors (either parameters or inputs/outputs) yield real views. The expected behavior would be to obtain complex tensors on each replicate. Casting the views bac...
github.com
해결방법:
DataParallel은 maintainance mode에 있기 때문에 DistributedDataParallel (DDP)를 쓰라고 한다
DataParallel은 여러모로 문제가 많구나
아니면 GPU 하나만 쓰던지..