PyTorch出現錯誤“RuntimeError: Found dtype Double but expected Float”

songyuc發表於2020-12-04

1 錯誤描述

今天在除錯PyTorch程式碼時出現“RuntimeError: Found dtype Double but expected Float”的錯誤,相關提示資訊如下

File “/home/…/train.py”, line 78, in main
running_loss = trainer.train_one_epoch(epoch, qa=qa)
File “/home/…/model/…py”, line 347, in train_one_epoch
loss.backward()
File “/home/…/software/python/anaconda/anaconda3/envs/conda-general/lib/python3.7/site-packages/torch/tensor.py”, line 221, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph)
File “/home/…/software/python/anaconda/anaconda3/envs/conda-general/lib/python3.7/site-packages/torch/autograd/init.py”, line 132, in backward
allow_unreachable=True) # allow_unreachable flag
RuntimeError: Found dtype Double but expected Float

可以看到這是一個跟資料型別相關的錯誤;

2 相關資料

感謝網友lcqin111提供的資料——《Pytorch: RuntimeError: expected Double tensor (got Float tensor)》
裡面對這個問題進行了解釋。

3 解決方案

其實這個問題產生的原因就是資料型別不一致,比較solid一點的方法,就是從報錯的地方開始一點一點除錯程式碼,看看參與運算的張量是否存在型別不同的情況,例如:
如果a[FloatTensor]和b[DoubleTensor]是兩個參與運算的張量,且有運算程式碼“loss = criterion(a,b)”,則會引發上面的問題。
所以可以從出錯的程式碼位置一步步進行除錯;
其實,最主要的原因還是張量型別不一致,所以實際上將張量型別統一就可以了
可以使用程式碼:

torch_tensor = torch_tensor.float()

3.1 小提示——使用double()則會佔用很多視訊記憶體

南溪自己試過用double()進行運算,不過這樣視訊記憶體佔用會增大許多,而很可能出現視訊記憶體爆炸的情況,所以最後還是使用FloatTensor型別;

相關文章