解說pytorch中的model=model.to(device)
導讀 | 這篇文章主要介紹了pytorch中的model=model.to(device)使用說明,具有很好的參考價值,希望對大家有所幫助。如有錯誤或未考慮完全的地方,望不吝賜教 |
這代表將模型載入到指定裝置上。
其中,device=torch.device("cpu")代表的使用cpu,而device=torch.device("cuda")則代表的使用GPU。
當我們指定了裝置之後,就需要將模型載入到相應裝置中,此時需要使用model=model.to(device),將模型載入到相應的裝置中。
將由GPU儲存的模型載入到CPU上。
將torch.load()函式中的map_location引數設定為torch.device('cpu')
device = torch.device('cpu') model = TheModelClass(*args, **kwargs) model.load_state_dict(torch.load(PATH, map_location=device))
將由GPU儲存的模型載入到GPU上。確保對輸入的tensors呼叫input = input.to(device)方法。
device = torch.device("cuda") model = TheModelClass(*args, **kwargs) model.load_state_dict(torch.load(PATH)) model.to(device)
將由CPU儲存的模型載入到GPU上。
確保對輸入的tensors呼叫input = input.to(device)方法。map_location是將模型載入到GPU上,model.to(torch.device('cuda'))是將模型引數載入為CUDA的tensor。
最後保證使用.to(torch.device('cuda'))方法將需要使用的引數放入CUDA。
device = torch.device("cuda") model = TheModelClass(*args, **kwargs) model.load_state_dict(torch.load(PATH, map_location="cuda:0")) # Choose whatever GPU device number you want model.to(device)
補充:pytorch中model.to(device)和map_location=device的區別
在已訓練並儲存在CPU上的GPU上載入模型時,載入模型時經常由於訓練和儲存模型時裝置不同出現讀取模型時出現錯誤,在對跨裝置的模型讀取時候涉及到兩個引數的使用,分別是model.to(device)和map_location=devicel兩個引數,簡介一下兩者的不同。
將map_location函式中的引數設定 torch.load()為 cuda:device_id。這會將模型載入到給定的GPU裝置。
呼叫model.to(torch.device('cuda'))將模型的引數張量轉換為CUDA張量,無論在cpu上訓練還是gpu上訓練,儲存的模型引數都是引數張量不是cuda張量,因此,cpu裝置上不需要使用torch.to(torch.device("cpu"))。
瞭解了兩者代表的意義,以下介紹兩者的使用。
儲存:
torch.save(model.state_dict(), PATH)
載入:
device = torch.device('cpu') model = TheModelClass(*args, **kwargs) model.load_state_dict(torch.load(PATH, map_location=device))
解釋:
在使用GPU訓練的CPU上載入模型時,請傳遞 torch.device('cpu')給map_location函式中的 torch.load()引數,使用map_location引數將張量下面的儲存器動態地重新對映到CPU裝置 。
儲存:
torch.save(model.state_dict(), PATH)
載入:
device = torch.device("cuda") model = TheModelClass(*args, **kwargs) model.load_state_dict(torch.load(PATH)) model.to(device) # Make sure to call input = input.to(device) on any input tensors that you feed to the model
解釋:
在GPU上訓練並儲存在GPU上的模型時,只需將初始化model模型轉換為CUDA最佳化模型即可model.to(torch.device('cuda'))。
此外,請務必.to(torch.device('cuda'))在所有模型輸入上使用該 功能來準備模型的資料。
請注意,呼叫my_tensor.to(device) 返回my_tensorGPU上的新副本。
它不會覆蓋 my_tensor。
因此,請記住手動覆蓋張量: my_tensor = my_tensor.to(torch.device('cuda'))
儲存:
torch.save(model.state_dict(), PATH)
載入:
device = torch.device("cuda") model = TheModelClass(*args, **kwargs) model.load_state_dict(torch.load(PATH, map_location="cuda:0")) # Choose whatever GPU device number you want model.to(device) # Make sure to call input = input.to(device) on any input tensors that you feed to the model
解釋:
在已訓練並儲存在CPU上的GPU上載入模型時,請將map_location函式中的引數設定 torch.load()為 cuda:device_id。
這會將模型載入到給定的GPU裝置。
接下來,請務必呼叫model.to(torch.device('cuda'))將模型的引數張量轉換為CUDA張量。
最後,確保.to(torch.device('cuda'))在所有模型輸入上使用該 函式來為CUDA最佳化模型準備資料。
請注意,呼叫 my_tensor.to(device)返回my_tensorGPU上的新副本。
它不會覆蓋my_tensor。
因此,請記住手動覆蓋張量:my_tensor = my_tensor.to(torch.device('cuda'))
原文來自: https://www.linuxprobe.com/pytorch-model-todevice.html
來自 “ ITPUB部落格 ” ,連結:http://blog.itpub.net/69955379/viewspace-2780716/,如需轉載,請註明出處,否則將追究法律責任。
相關文章
- 在pytorch框架下,訓練model過程中,loss=nan問題時該怎麼解決?PyTorch框架NaN
- Laravel中的ModelLaravel
- VMWARE WORKSATTION 中 “the device is curreently in use” 解決一例dev
- python 中model.py詳解Python
- 不在models.py中的models
- pytorch中forward的理解PyTorchForward
- pytorch中中的模型剪枝方法PyTorch模型
- Pytorch的API詳解PyTorchAPI
- pytorch的 scatter詞解PyTorch
- PyTorch 中 loss.grad_fn 解釋PyTorch
- 聊聊MVX中的Model
- Eclipse中EventModel轉Model的錯誤Eclipse
- php中curl的詳細解說PHP
- PyTorch 中 torch.matmul() 函式的文件詳解PyTorch函式
- Spring中Model、ModelMap及ModelAndView之間的區別SpringView
- 小心 Laravel 中的 Model::incrementLaravelREM
- Pytorch中的損失函式PyTorch函式
- Pytorch中backward()的思考記錄PyTorch
- 轉:Pytorch中的register_buffer()PyTorch
- Pytorch | Pytorch格式 .pt .pth .bin .onnx 詳解PyTorch
- 說說Flutter中的SemanticsFlutter
- 說說Flutter中的RepaintBoundaryFlutterAI
- C++中map的使用詳解說明C++
- 解決Qt中ui->tableView->setModel(model);導致程式崩潰 問題QTUIView
- 瞭解Linux系統中的Device Mapper機制:使用者空間LinuxdevAPP
- 轉:AIX中The largest dump device is too small的處理AIdev
- 前端:說說工作中解決過的印象比較深刻的問題前端
- 實踐Pytorch中的模型剪枝方法PyTorch模型
- 說說JavaScript中的事件模型JavaScript事件模型
- statsmodels中的summary解讀(以linear regression模型為例)模型
- 【大廠面試05期】說一說你對MySQL中鎖的瞭解?面試MySql
- vue中v-model的學習Vue
- 【小白學PyTorch】10 pytorch常見運算詳解PyTorch
- Chrome中的Device模組調式響應性設計Chromedev
- 說說Python中的閉包Python
- 關於DPM(Deformable Part Model)演算法中模型結構的解釋ORM演算法模型
- Pytorch中stack()方法的總結及理解PyTorch
- PYTORCH中的學習率怎麼理解PyTorch