libtorch使用model.forward報std::runtime_error錯誤

珠峰上吹泡泡發表於2024-05-03

1、原因

模型向GPU複製時發生異常

	model = torch::jit::load(ptFile);
	if (isHalf)
	{
		model.to(torch::kHalf);
	}
	model.to(device);//GPU版異常,可能模型並沒有完全放到GPU上

2、解決方法

model = torch::jit::load(ptFile, torch::kCUDA);

參考:https://github.com/pytorch/pytorch/issues/19302

相關文章