pytorch,訓練模型時記憶體佔用不斷上升

kksk43發表於2024-10-28

重點關注帶有梯度的變數,特別是累積它們或收集它們的地方
我的問題是,在訓練的時候收集logits時,忘記加上.detach(),導致梯度資訊也跟著收集,然後記憶體佔用不斷上升甚至超過90G
加上.detach()後,就固定在44G不變了

相關文章