libtorch入門例程

兜尼完發表於2024-03-09

libtorch C++版可以直接在官網下載。自己學習如果沒有合適的顯示卡可以選擇下載CPU版的。下面是官網連結:

  • PyTorch

下載後就可以把開發包包含到VS的專案中使用。注意libtorch官網提供的Release/Debug的開發包,Debug版的程式用Debug版的庫,Release版的程式用Release版的庫,不能混用。另外libtorch更新很快是用最新的C++版本寫的,需要在編譯器設定中設定合適的C++語言版本。比如我的是libtorch1.13.1,它只能在C++14版本下使用,C++11或C++17都不行。在使用時如果編譯報錯有很多“std不明確的符號”,可用的改正方法是:開啟專案屬性→屬性→C/C++→語言→符合模式→改為“否”。

下面給出一個可以執行的程式碼。我的測試環境是VS2017(C++14)和libtorch1.13.1。如果在你的編譯器下能正常編譯執行,那麼說明libtorch是正常的。此程式碼的功能是擬合函式${ z=3x+y+2.5 }$。請注意torch::Tensor loss = lossFunc(predict, c);這一句,predict和c的位置不能反過來。

#include "torch/all.h"

int main()
{
    torch::nn::Linear linear(2, 1);

    /* 30個樣本。在這裡是一行一個樣本 */
    at::Tensor b = torch::rand({ 30, 2 });
    at::Tensor c = torch::zeros({ 30, 1 });
    for (int i = 0; i < 30; i++)
    {
        c[i] = 3 * b[i][0] + b[i][1] + 2.5f;
    }

    cout << b << endl;
    cout << c << endl;

    /* 訓練過程 */
    torch::optim::SGD optim(linear->parameters(), torch::optim::SGDOptions(0.01));
    torch::nn::MSELoss lossFunc;
    linear->train();
    for (int i = 0; i < 10000; i++)
    {
        torch::Tensor predict = linear(b);
        torch::Tensor loss = lossFunc(predict, c);
        optim.zero_grad();
        loss.backward();
        optim.step();
        if (i % 1000 == 0)
        {
            /* 每1000次迴圈輸出一次損失函式值 */
            cout << "LOOP:" << i << ",LOSS=" << loss.item() << endl;
        }
    }
    /* 輸出訓練之後的網路引數 */
    cout << linear->parameters() << endl;

    /* 做個測試 */
    at::Tensor x = torch::tensor({ 1.5f, 2.0f });
    at::Tensor y = linear(x);
    cout << "3*1.5+1*2+2.5=" << y.item();

    return 0;
}

輸出內容是:

 0.0341  0.6551
 0.9524  0.1005
 0.3764  0.5524
 0.8860  0.6767
 0.6554  0.9601
 0.7736  0.0955
 0.4260  0.3402
 0.1248  0.1497
 0.2288  0.2765
 0.4508  0.6151
 0.1954  0.0717
 0.5392  0.5821
 0.8622  0.2375
 0.9371  0.0668
 0.6593  0.2563
 0.1854  0.8515
 0.1299  0.4341
 0.8148  0.6432
 0.7303  0.0794
 0.6853  0.5018
 0.7687  0.8698
 0.6909  0.7306
 0.8921  0.8072
 0.6477  0.0745
 0.5048  0.8875
 0.6906  0.4306
 0.7410  0.6294
 0.0095  0.8609
 0.0862  0.8630
 0.6828  0.5330
[ CPUFloatType{30,2} ]
 3.2576
 5.4577
 4.1815
 5.8348
 5.4263
 4.9163
 4.1181
 3.0239
 3.4628
 4.4676
 3.1578
 4.6996
 5.3240
 5.3781
 4.7343
 3.9076
 3.3239
 5.5875
 4.7702
 5.0576
 5.6757
 5.3035
 5.9835
 4.5176
 4.9018
 5.0025
 5.3523
 3.3893
 3.6218
 5.0814
[ CPUFloatType{30,1} ]
LOOP:0,LOSS=33.1978
LOOP:1000,LOSS=0.0120936
LOOP:2000,LOSS=0.00164465
LOOP:3000,LOSS=0.000271623
LOOP:4000,LOSS=4.59804e-05
LOOP:5000,LOSS=7.80881e-06
LOOP:6000,LOSS=1.33087e-06
LOOP:7000,LOSS=2.33062e-07
LOOP:8000,LOSS=4.16838e-08
LOOP:9000,LOSS=8.05691e-09
 2.9999  0.9999
[ CPUFloatType{1,2} ]  2.5001
[ CPUFloatType{1} ]
3*1.5+1*2+2.5=8.99974

相關文章