解說pytorch中的model=model.to(device)
導讀 | 這篇文章主要介紹了pytorch中的model=model.to(device)使用說明,具有很好的參考價值,希望對大家有所幫助。如有錯誤或未考慮完全的地方,望不吝賜教 |
這代表將模型載入到指定裝置上。
其中,device=torch.device("cpu")代表的使用cpu,而device=torch.device("cuda")則代表的使用GPU。
當我們指定了裝置之後,就需要將模型載入到相應裝置中,此時需要使用model=model.to(device),將模型載入到相應的裝置中。
將由GPU儲存的模型載入到CPU上。
將torch.load()函式中的map_location引數設定為torch.device('cpu')
device = torch.device('cpu') model = TheModelClass(*args, **kwargs) model.load_state_dict(torch.load(PATH, map_location=device))
將由GPU儲存的模型載入到GPU上。確保對輸入的tensors呼叫input = input.to(device)方法。
device = torch.device("cuda") model = TheModelClass(*args, **kwargs) model.load_state_dict(torch.load(PATH)) model.to(device)
將由CPU儲存的模型載入到GPU上。
確保對輸入的tensors呼叫input = input.to(device)方法。map_location是將模型載入到GPU上,model.to(torch.device('cuda'))是將模型引數載入為CUDA的tensor。
最後保證使用.to(torch.device('cuda'))方法將需要使用的引數放入CUDA。
device = torch.device("cuda") model = TheModelClass(*args, **kwargs) model.load_state_dict(torch.load(PATH, map_location="cuda:0")) # Choose whatever GPU device number you want model.to(device)
補充:pytorch中model.to(device)和map_location=device的區別
在已訓練並儲存在CPU上的GPU上載入模型時,載入模型時經常由於訓練和儲存模型時裝置不同出現讀取模型時出現錯誤,在對跨裝置的模型讀取時候涉及到兩個引數的使用,分別是model.to(device)和map_location=devicel兩個引數,簡介一下兩者的不同。
將map_location函式中的引數設定 torch.load()為 cuda:device_id。這會將模型載入到給定的GPU裝置。
呼叫model.to(torch.device('cuda'))將模型的引數張量轉換為CUDA張量,無論在cpu上訓練還是gpu上訓練,儲存的模型引數都是引數張量不是cuda張量,因此,cpu裝置上不需要使用torch.to(torch.device("cpu"))。
瞭解了兩者代表的意義,以下介紹兩者的使用。
儲存:
torch.save(model.state_dict(), PATH)
載入:
device = torch.device('cpu') model = TheModelClass(*args, **kwargs) model.load_state_dict(torch.load(PATH, map_location=device))
解釋:
在使用GPU訓練的CPU上載入模型時,請傳遞 torch.device('cpu')給map_location函式中的 torch.load()引數,使用map_location引數將張量下面的儲存器動態地重新對映到CPU裝置 。
儲存:
torch.save(model.state_dict(), PATH)
載入:
device = torch.device("cuda") model = TheModelClass(*args, **kwargs) model.load_state_dict(torch.load(PATH)) model.to(device) # Make sure to call input = input.to(device) on any input tensors that you feed to the model
解釋:
在GPU上訓練並儲存在GPU上的模型時,只需將初始化model模型轉換為CUDA最佳化模型即可model.to(torch.device('cuda'))。
此外,請務必.to(torch.device('cuda'))在所有模型輸入上使用該 功能來準備模型的資料。
請注意,呼叫my_tensor.to(device) 返回my_tensorGPU上的新副本。
它不會覆蓋 my_tensor。
因此,請記住手動覆蓋張量: my_tensor = my_tensor.to(torch.device('cuda'))
儲存:
torch.save(model.state_dict(), PATH)
載入:
device = torch.device("cuda") model = TheModelClass(*args, **kwargs) model.load_state_dict(torch.load(PATH, map_location="cuda:0")) # Choose whatever GPU device number you want model.to(device) # Make sure to call input = input.to(device) on any input tensors that you feed to the model
解釋:
在已訓練並儲存在CPU上的GPU上載入模型時,請將map_location函式中的引數設定 torch.load()為 cuda:device_id。
這會將模型載入到給定的GPU裝置。
接下來,請務必呼叫model.to(torch.device('cuda'))將模型的引數張量轉換為CUDA張量。
最後,確保.to(torch.device('cuda'))在所有模型輸入上使用該 函式來為CUDA最佳化模型準備資料。
請注意,呼叫 my_tensor.to(device)返回my_tensorGPU上的新副本。
它不會覆蓋my_tensor。
因此,請記住手動覆蓋張量:my_tensor = my_tensor.to(torch.device('cuda'))
原文來自: https://www.linuxprobe.com/pytorch-model-todevice.html
來自 “ ITPUB部落格 ” ,連結:http://blog.itpub.net/69955379/viewspace-2780716/,如需轉載,請註明出處,否則將追究法律責任。
相關文章
- 在pytorch框架下,訓練model過程中,loss=nan問題時該怎麼解決?PyTorch框架NaN
- Laravel中的ModelLaravel
- 聊聊MVX中的Model
- PyTorch 中 loss.grad_fn 解釋PyTorch
- PyTorch 中 torch.matmul() 函式的文件詳解PyTorch函式
- pytorch的 scatter詞解PyTorch
- Pytorch的API詳解PyTorchAPI
- 瞭解Linux系統中的Device Mapper機制:使用者空間LinuxdevAPP
- pytorch中forward的理解PyTorchForward
- 小心 Laravel 中的 Model::incrementLaravelREM
- pytorch中中的模型剪枝方法PyTorch模型
- Pytorch | Pytorch格式 .pt .pth .bin .onnx 詳解PyTorch
- Pytorch中的損失函式PyTorch函式
- Pytorch中backward()的思考記錄PyTorch
- 轉:Pytorch中的register_buffer()PyTorch
- vue中v-model的學習Vue
- Django model select的各種用法詳解Django
- C++中map的使用詳解說明C++
- 【小白學PyTorch】10 pytorch常見運算詳解PyTorch
- 【小白學PyTorch】13 EfficientNet詳解及PyTorch實現PyTorch
- 【小白學PyTorch】12 SENet詳解及PyTorch實現PyTorchSENet
- 說說Flutter中的RepaintBoundaryFlutterAI
- 說說Flutter中的SemanticsFlutter
- iOS Device ID 的前世今生iOSdev
- 前端:說說工作中解決過的印象比較深刻的問題前端
- iOS專案中Json轉Model的坑iOSJSON
- 實踐Pytorch中的模型剪枝方法PyTorch模型
- 說說JavaScript中的事件模型JavaScript事件模型
- Pytorch框架詳解之一PyTorch框架
- pytorch lstm原始碼解讀PyTorch原始碼
- pytorch 中 Tensor 的 pow 方法是幹嘛的?PyTorch
- 淺談 iOS Device ID 的修改iOSdev
- 【大廠面試05期】說一說你對MySQL中鎖的瞭解?面試MySql
- PYTORCH中的學習率怎麼理解PyTorch
- PyTorch中的多程序並行處理PyTorch並行
- Pytorch中stack()方法的總結及理解PyTorch
- Pytorch建模過程中的DataLoader與DatasetPyTorch
- Win10開機提示reboot and select proper boot device的解決方法Win10bootdev