TiDB v5.1 Hands-On: I Trained a Machine Learning Model in TiDB

Published by PingCAP on 2021-10-26

About the Author

Han Mingcong, TiDB Contributor, is a PhD student at the IPADS lab of Shanghai Jiao Tong University, where his research focuses on system software. This article describes how to train a machine learning model in TiDB using pure SQL.

Introduction

As is well known, TiDB 5.1 introduces many new features, among them Common Table Expressions (CTE) from the ANSI SQL-99 standard. Ordinarily, a CTE serves as a temporary view scoped to a single statement, decoupling a complex SQL query and improving development efficiency. But CTEs have another important form, the recursive CTE, which allows a CTE to reference itself; it is the last core piece needed to round out SQL's capabilities. A StackOverflow discussion, "Is SQL or even TSQL Turing Complete", contains this line in its top-voted answer:

"In this set of slides Andrew Gierth proves that with CTE and Windowing SQL is Turing Complete, by constructing a cyclic tag system, which has been proved to be Turing Complete. The CTE feature is the important part however – it allows you to create named sub-expressions that can refer to themselves, and thereby recursively solve problems."

In other words, CTEs and window functions make SQL a Turing-complete language. This reminded me of an article I read years ago, Deep Neural Network implemented in pure SQL over BigQuery, in which the author claims to implement a DNN model in pure SQL. But on opening the repo, I found the title to be clickbait: the iterative training is actually done in Python. So, since recursive CTEs give us the ability to iterate, I wanted to take on a challenge: can we train and run inference for a machine learning model in TiDB using nothing but SQL?

Iris Dataset

First we need a simple machine learning model and task, so let's start with the introductory dataset from sklearn, the iris dataset. It contains 150 records in 3 classes, 50 per class. Each record has 4 features: sepal length, sepal width, petal length, and petal width, from which we can predict whether an iris belongs to the species Iris-setosa, Iris-versicolour, or Iris-virginica.

After downloading the data (already in CSV format), we import it into TiDB:

mysql> create table iris(sl float, sw float, pl float, pw float, type varchar(16));
mysql> LOAD DATA LOCAL INFILE 'iris.csv' INTO TABLE iris FIELDS TERMINATED BY ',' LINES TERMINATED BY '\n';
mysql> select * from iris limit 10;
+------+------+------+------+-------------+
| sl   | sw   | pl   | pw   | type        |
+------+------+------+------+-------------+
|  5.1 |  3.5 |  1.4 |  0.2 | Iris-setosa |
|  4.9 |    3 |  1.4 |  0.2 | Iris-setosa |
|  4.7 |  3.2 |  1.3 |  0.2 | Iris-setosa |
|  4.6 |  3.1 |  1.5 |  0.2 | Iris-setosa |
|    5 |  3.6 |  1.4 |  0.2 | Iris-setosa |
|  5.4 |  3.9 |  1.7 |  0.4 | Iris-setosa |
|  4.6 |  3.4 |  1.4 |  0.3 | Iris-setosa |
|    5 |  3.4 |  1.5 |  0.2 | Iris-setosa |
|  4.4 |  2.9 |  1.4 |  0.2 | Iris-setosa |
|  4.9 |  3.1 |  1.5 |  0.1 | Iris-setosa |
+------+------+------+------+-------------+
10 rows in set (0.00 sec)

mysql> select type, count(*) from iris group by type;
+-----------------+----------+
| type            | count(*) |
+-----------------+----------+
| Iris-versicolor |       50 |
| Iris-setosa     |       50 |
| Iris-virginica  |       50 |
+-----------------+----------+
3 rows in set (0.00 sec)

Softmax Logistic Regression

Here we choose a simple machine learning model, softmax logistic regression, for multi-class classification. (The figures and description below are adapted from Baidu Baike.)

In softmax regression, the probability that input $x$ belongs to class $j$ is:

$$p(y = j \mid x; w) = \frac{e^{w_j^T x}}{\sum_{k=1}^{K} e^{w_k^T x}}$$

The cost function is:

$$J(w) = -\frac{1}{m} \sum_{i=1}^{m} \sum_{j=1}^{K} 1\{y^{(i)} = j\} \log p(y^{(i)} = j \mid x^{(i)}; w)$$

Its gradient with respect to $w_j$ is:

$$\nabla_{w_j} J(w) = -\frac{1}{m} \sum_{i=1}^{m} x^{(i)} \left(1\{y^{(i)} = j\} - p(y^{(i)} = j \mid x^{(i)}; w)\right)$$

So gradient descent can be used to update the weights on every step:

$$w_j := w_j + \frac{\alpha}{m} \sum_{i=1}^{m} x^{(i)} \left(1\{y^{(i)} = j\} - p(y^{(i)} = j \mid x^{(i)}; w)\right)$$
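As a sanity check on the gradient-descent update above, here is a minimal pure-Python sketch of one full-batch step for softmax regression (all names here are illustrative, not part of the SQL that follows):

```python
import math

def softmax(scores):
    """Numerically stable softmax of a list of scores."""
    hi = max(scores)
    exps = [math.exp(s - hi) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def gradient_step(W, X, Y, lr):
    """One full-batch update w_j := w_j + (lr/m) * sum_i (1{y_i=j} - p_j(x_i)) * x_i.

    W: K weight vectors of length d; X: m samples of length d (bias included);
    Y: m one-hot labels of length K.
    """
    K, d, m = len(W), len(W[0]), len(X)
    grad = [[0.0] * d for _ in range(K)]
    for x, y in zip(X, Y):
        p = softmax([sum(w * v for w, v in zip(W[k], x)) for k in range(K)])
        for k in range(K):
            coef = y[k] - p[k]          # (1{y=k} - p_k) term of the gradient
            for j in range(d):
                grad[k][j] += coef * x[j]
    return [[W[k][j] + lr * grad[k][j] / m for j in range(d)] for k in range(K)]
```

A single step should push the probability of each sample's true class upward, which is easy to verify on a toy example.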

Model Inference

Let's first write a SQL query for inference. Given the model and data defined above, each input X has five dimensions (sl, sw, pl, pw, plus a constant 1.0), and the output uses one-hot encoding:

    mysql> create table data(
        x0 decimal(35, 30), x1 decimal(35, 30), x2 decimal(35, 30), x3 decimal(35, 30), x4 decimal(35, 30),
        y0 decimal(35, 30), y1 decimal(35, 30), y2 decimal(35, 30));

    mysql> insert into data
    select
        sl, sw, pl, pw, 1.0,
        case when type = 'Iris-setosa' then 1 else 0 end,
        case when type = 'Iris-versicolor' then 1 else 0 end,
        case when type = 'Iris-virginica' then 1 else 0 end
    from iris;
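In Python terms, the INSERT ... SELECT above amounts to the following per-row transformation (a sketch; the function name is made up for illustration):

```python
def to_training_row(sl, sw, pl, pw, species):
    """Mirror the INSERT ... SELECT: the 4 features plus a constant bias
    term 1.0, followed by a one-hot encoding of the species label."""
    classes = ("Iris-setosa", "Iris-versicolor", "Iris-virginica")
    x = (sl, sw, pl, pw, 1.0)
    y = tuple(1 if species == c else 0 for c in classes)
    return x + y
```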

The model has 3 classes × 5 dimensions = 15 parameters:

    mysql> create table weight(
        w00 decimal(35, 30), w01 decimal(35, 30), w02 decimal(35, 30), w03 decimal(35, 30), w04 decimal(35, 30),
        w10 decimal(35, 30), w11 decimal(35, 30), w12 decimal(35, 30), w13 decimal(35, 30), w14 decimal(35, 30),
        w20 decimal(35, 30), w21 decimal(35, 30), w22 decimal(35, 30), w23 decimal(35, 30), w24 decimal(35, 30));

We initialize them to 0.1, 0.2, and 0.3 (different values per class are chosen only to make the demo easier to follow; initializing everything to 0.1 would work too):

    mysql> insert into weight values (
        0.1, 0.1, 0.1, 0.1, 0.1,
        0.2, 0.2, 0.2, 0.2, 0.2,
        0.3, 0.3, 0.3, 0.3, 0.3);

Next, let's write a SQL query that runs inference on all rows of data and computes the resulting accuracy.

To make it easier to follow, here is the process in pseudocode:

    weight = (
        w00, w01, w02, w03, w04,
        w10, w11, w12, w13, w14,
        w20, w21, w22, w23, w24)
    for data(x0, x1, x2, x3, x4, y0, y1, y2) in all Data:
        exp0 = exp(x0 * w00 + x1 * w01 + x2 * w02 + x3 * w03 + x4 * w04)
        exp1 = exp(x0 * w10 + x1 * w11 + x2 * w12 + x3 * w13 + x4 * w14)
        exp2 = exp(x0 * w20 + x1 * w21 + x2 * w22 + x3 * w23 + x4 * w24)
        sum_exp = exp0 + exp1 + exp2
        // softmax
        p0 = exp0 / sum_exp
        p1 = exp1 / sum_exp
        p2 = exp2 / sum_exp
        // inference result
        r0 = p0 > p1 and p0 > p2
        r1 = p1 > p0 and p1 > p2
        r2 = p2 > p0 and p2 > p1
        data.correct = (y0 == r0 and y1 == r1 and y2 == r2)
    return sum(Data.correct) / count(Data)

The code above computes, for each row in Data, the exponential of each of the three dot products, then the softmax, and finally sets the largest of p0, p1, p2 to 1 and the others to 0, completing inference for one sample. If a sample's inference result matches its actual class, the prediction is correct; summing the number of correct predictions over all samples gives the final accuracy.
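The same inference-and-accuracy computation can be sketched in pure Python, mirroring the pseudocode (names here are illustrative):

```python
import math

def accuracy(W, rows):
    """Fraction of rows whose predicted class matches the one-hot label.

    W: 3 weight vectors of length 5; rows: (x0..x4, y0, y1, y2) tuples
    shaped like the `data` table. As in the SQL, a class is predicted (r=1)
    only when its softmax probability is strictly the largest.
    """
    correct = 0
    for row in rows:
        x, y = row[:5], row[5:]
        e = [math.exp(sum(w * v for w, v in zip(W[k], x))) for k in range(3)]
        s = sum(e)
        p = [ek / s for ek in e]
        r = [1 if all(p[k] > p[o] for o in range(3) if o != k) else 0
             for k in range(3)]
        correct += int(r == list(y))
    return correct / len(rows)
```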

Here is the SQL implementation. We join every row of data with weight (which has only one row), compute each row's inference result, and then count the correctly classified samples:

    select sum(y0 = r0 and y1 = r1 and y2 = r2) / count(*)
    from
        (select
            y0, y1, y2,
            p0 > p1 and p0 > p2 as r0, p1 > p0 and p1 > p2 as r1, p2 > p0 and p2 > p1 as r2
        from
            (select
                y0, y1, y2,
                e0/(e0+e1+e2) as p0, e1/(e0+e1+e2) as p1, e2/(e0+e1+e2) as p2
            from
                (select
                    y0, y1, y2,
                    exp(
                        w00 * x0 + w01 * x1 + w02 * x2 + w03 * x3 + w04 * x4
                    ) as e0,
                    exp(
                        w10 * x0 + w11 * x1 + w12 * x2 + w13 * x3 + w14 * x4
                    ) as e1,
                    exp(
                        w20 * x0 + w21 * x1 + w22 * x2 + w23 * x3 + w24 * x4
                    ) as e2
                from data, weight) t1
            )t2
        )t3;

As you can see, the SQL follows the pseudocode almost step by step, and produces:

    +-----------------------------------------------+
    | sum(y0 = r0 and y1 = r1 and y2 = r2)/count(*) |
    +-----------------------------------------------+
    |                                        0.3333 |
    +-----------------------------------------------+
    1 row in set (0.01 sec)

Next, let's train the model's parameters.

Model Training

Note: to keep things simple, we ignore the usual split into training and validation sets and simply train on all of the data.

Again, let's write the pseudocode first and then derive the SQL from it:

    weight = (
        w00, w01, w02, w03, w04,
        w10, w11, w12, w13, w14,
        w20, w21, w22, w23, w24)
    for iter in iterations:
        sum00 = 0
        sum01 = 0
        ...
        sum23 = 0
        sum24 = 0
        for data(x0, x1, x2, x3, x4, y0, y1, y2) in all Data:
            exp0 = exp(x0 * w00 + x1 * w01 + x2 * w02 + x3 * w03 + x4 * w04)
            exp1 = exp(x0 * w10 + x1 * w11 + x2 * w12 + x3 * w13 + x4 * w14)
            exp2 = exp(x0 * w20 + x1 * w21 + x2 * w22 + x3 * w23 + x4 * w24)
            sum_exp = exp0 + exp1 + exp2
            // softmax
            p0 = y0 - exp0 / sum_exp
            p1 = y1 - exp1 / sum_exp
            p2 = y2 - exp2 / sum_exp
            sum00 += p0 * x0
            sum01 += p0 * x1
            sum02 += p0 * x2
            ...
            sum23 += p2 * x3
            sum24 += p2 * x4
        w00 = w00 + learning_rate * sum00 / Data.size
        w01 = w01 + learning_rate * sum01 / Data.size
        ...
        w23 = w23 + learning_rate * sum23 / Data.size
        w24 = w24 + learning_rate * sum24 / Data.size

This looks verbose because the sum and w vectors are unrolled by hand.
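For reference, the whole training loop above collapses into a short pure-Python sketch (the same update rule, with the hand-unrolling replaced by loops; names are illustrative):

```python
import math

def train(rows, lr=0.1, iterations=1000):
    """Batch gradient descent for 3-class softmax regression on 5-dim inputs.

    rows: (x0..x4, y0, y1, y2) tuples as in the `data` table.
    Returns the 3x5 weight matrix after `iterations` full-batch updates.
    """
    W = [[0.1] * 5 for _ in range(3)]
    m = len(rows)
    for _ in range(iterations):
        grad = [[0.0] * 5 for _ in range(3)]
        for row in rows:
            x, y = row[:5], row[5:]
            e = [math.exp(sum(w * v for w, v in zip(W[k], x))) for k in range(3)]
            s = sum(e)
            for k in range(3):
                coef = y[k] - e[k] / s      # pK = yK - expK / sum_exp, as above
                for j in range(5):
                    grad[k][j] += coef * x[j]
        W = [[W[k][j] + lr * grad[k][j] / m for j in range(5)] for k in range(3)]
    return W
```

On a tiny linearly separable dataset, a few hundred iterations are enough for the argmax of the scores to match the labels.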

Now let's write the training SQL, starting with a version that runs a single iteration.

Set the learning rate and the sample count:

    mysql> set @lr = 0.1;
    Query OK, 0 rows affected (0.00 sec)
    mysql> set @dsize = 150;
    Query OK, 0 rows affected (0.00 sec)

Run one iteration:

    select
        w00 + @lr * sum(d00) / @dsize as w00, w01 + @lr * sum(d01) / @dsize as w01, w02 + @lr * sum(d02) / @dsize as w02, w03 + @lr * sum(d03) / @dsize as w03, w04 + @lr * sum(d04) / @dsize as w04,
        w10 + @lr * sum(d10) / @dsize as w10, w11 + @lr * sum(d11) / @dsize as w11, w12 + @lr * sum(d12) / @dsize as w12, w13 + @lr * sum(d13) / @dsize as w13, w14 + @lr * sum(d14) / @dsize as w14,
        w20 + @lr * sum(d20) / @dsize as w20, w21 + @lr * sum(d21) / @dsize as w21, w22 + @lr * sum(d22) / @dsize as w22, w23 + @lr * sum(d23) / @dsize as w23, w24 + @lr * sum(d24) / @dsize as w24
    from
        (select
            w00, w01, w02, w03, w04,
            w10, w11, w12, w13, w14,
            w20, w21, w22, w23, w24,
            p0 * x0 as d00, p0 * x1 as d01, p0 * x2 as d02, p0 * x3 as d03, p0 * x4 as d04,
            p1 * x0 as d10, p1 * x1 as d11, p1 * x2 as d12, p1 * x3 as d13, p1 * x4 as d14,
            p2 * x0 as d20, p2 * x1 as d21, p2 * x2 as d22, p2 * x3 as d23, p2 * x4 as d24
        from
            (select
                w00, w01, w02, w03, w04,
                w10, w11, w12, w13, w14,
                w20, w21, w22, w23, w24,
                x0, x1, x2, x3, x4,
                y0 - e0/(e0+e1+e2) as p0, y1 - e1/(e0+e1+e2) as p1, y2 - e2/(e0+e1+e2) as p2
            from
                (select
                    w00, w01, w02, w03, w04,
                    w10, w11, w12, w13, w14,
                    w20, w21, w22, w23, w24,
                    x0, x1, x2, x3, x4, y0, y1, y2,
                    exp(
                        w00 * x0 + w01 * x1 + w02 * x2 + w03 * x3 + w04 * x4
                    ) as e0,
                    exp(
                        w10 * x0 + w11 * x1 + w12 * x2 + w13 * x3 + w14 * x4
                    ) as e1,
                    exp(
                        w20 * x0 + w21 * x1 + w22 * x2 + w23 * x3 + w24 * x4
                    ) as e2
                from data, weight) t1
            )t2
        )t3;

The result is the model parameters after one iteration:

    +----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+
    | w00                              | w01                              | w02                              | w03                              | w04                              | w10                              | w11                              | w12                              | w13                              | w14                              | w20                              | w21                              | w22                              | w23                              | w24                              |
    +----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+
    | 0.242000022455130986666666666667 | 0.199736070114635900000000000000 | 0.135689102774125773333333333333 | 0.104372938417325687333333333333 | 0.128775320011717430666666666667 | 0.296128284590438133333333333333 | 0.237124925707748246666666666667 | 0.281477497498236260000000000000 | 0.225631554555397960000000000000 | 0.215390025342499213333333333333 | 0.061871692954430866666666666667 | 0.163139004177615846666666666667 | 0.182833399727637980000000000000 | 0.269995507027276353333333333333 | 0.255834654645783353333333333333 |
    +----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+----------------------------------+
    1 row in set (0.03 sec)

Now for the key part: using a recursive CTE for iterative training.

    mysql> set @num_iterations = 1000;
    Query OK, 0 rows affected (0.00 sec)

The core idea is that each iteration's input is the previous iteration's output, and an increasing iteration counter controls the number of iterations. The overall skeleton is:

with recursive cte(iter, weight) as
(
    select 1, init_weight
    union all
    select iter + 1, new_weight
    from cte
    where iter < @num_iterations
)
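Semantically, a recursive CTE of this shape is a loop in which each round's rows are computed from the rows produced by the previous round. A tiny Python sketch of that evaluation model (illustrative only, not TiDB's actual implementation):

```python
def recursive_cte(seed_rows, step):
    """Evaluate a `WITH RECURSIVE ... UNION ALL` loop: start from the seed
    rows, repeatedly apply the recursive part to the rows produced by the
    previous round, and stop once a round produces no rows."""
    result, last = list(seed_rows), list(seed_rows)
    while last:
        last = [r for r in (step(row) for row in last) if r is not None]
        result.extend(last)
    return result

# Mirrors: with recursive cte(iter) as
#          (select 1 union all select iter + 1 from cte where iter < 5)
rows = recursive_cte([(1,)], lambda r: (r[0] + 1,) if r[0] < 5 else None)
```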

Next we combine the single-iteration SQL with this iteration skeleton (some type casts are added to the intermediate results to improve precision):

    with recursive weight(
        iter,
        w00, w01, w02, w03, w04,
        w10, w11, w12, w13, w14,
        w20, w21, w22, w23, w24) as
    (
    select 1,
        cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)),
        cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)),
        cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30)), cast(0.1 as DECIMAL(35, 30))
    union all
    select
        iter + 1,
        w00 + @lr * cast(sum(d00) as DECIMAL(35, 30)) / @dsize as w00, w01 + @lr * cast(sum(d01) as DECIMAL(35, 30)) / @dsize as w01, w02 + @lr * cast(sum(d02) as DECIMAL(35, 30)) / @dsize as w02, w03 + @lr * cast(sum(d03) as DECIMAL(35, 30)) / @dsize as w03, w04 + @lr * cast(sum(d04) as DECIMAL(35, 30)) / @dsize as w04,
        w10 + @lr * cast(sum(d10) as DECIMAL(35, 30)) / @dsize as w10, w11 + @lr * cast(sum(d11) as DECIMAL(35, 30)) / @dsize as w11, w12 + @lr * cast(sum(d12) as DECIMAL(35, 30)) / @dsize as w12, w13 + @lr * cast(sum(d13) as DECIMAL(35, 30)) / @dsize as w13, w14 + @lr * cast(sum(d14) as DECIMAL(35, 30)) / @dsize as w14,
        w20 + @lr * cast(sum(d20) as DECIMAL(35, 30)) / @dsize as w20, w21 + @lr * cast(sum(d21) as DECIMAL(35, 30)) / @dsize as w21, w22 + @lr * cast(sum(d22) as DECIMAL(35, 30)) / @dsize as w22, w23 + @lr * cast(sum(d23) as DECIMAL(35, 30)) / @dsize as w23, w24 + @lr * cast(sum(d24) as DECIMAL(35, 30)) / @dsize as w24
    from
        (select
            iter, w00, w01, w02, w03, w04,
            w10, w11, w12, w13, w14,
            w20, w21, w22, w23, w24,
            p0 * x0 as d00, p0 * x1 as d01, p0 * x2 as d02, p0 * x3 as d03, p0 * x4 as d04,
            p1 * x0 as d10, p1 * x1 as d11, p1 * x2 as d12, p1 * x3 as d13, p1 * x4 as d14,
            p2 * x0 as d20, p2 * x1 as d21, p2 * x2 as d22, p2 * x3 as d23, p2 * x4 as d24
        from
            (select
                iter, w00, w01, w02, w03, w04,
                w10, w11, w12, w13, w14,
                w20, w21, w22, w23, w24,
                x0, x1, x2, x3, x4,
                y0 - e0/(e0+e1+e2) as p0, y1 - e1/(e0+e1+e2) as p1, y2 - e2/(e0+e1+e2) as p2
            from
                (select
                    iter, w00, w01, w02, w03, w04,
                    w10, w11, w12, w13, w14,
                    w20, w21, w22, w23, w24,
                    x0, x1, x2, x3, x4, y0, y1, y2,
                    exp(
                        w00 * x0 + w01 * x1 + w02 * x2 + w03 * x3 + w04 * x4
                    ) as e0,
                    exp(
                        w10 * x0 + w11 * x1 + w12 * x2 + w13 * x3 + w14 * x4
                    ) as e1,
                    exp(
                        w20 * x0 + w21 * x1 + w22 * x2 + w23 * x3 + w24 * x4
                    ) as e2
                from data, weight where iter < @num_iterations) t1
            )t2
        )t3
    having count(*) > 0
    )
    select * from weight where iter = @num_iterations;

This version differs from the single-iteration one in two ways:

After joining data with weight, we add where iter < @num_iterations to bound the number of iterations, and the final output gains an extra column, iter + 1 as iter.
We also add having count(*) > 0 to stop the aggregation from emitting a row when there is no input data left, which would prevent the iteration from terminating.

Then we get the result:

 ERROR 3577 (HY000): In recursive query block of Recursive Common Table Expression 'weight', the recursive table must be referenced only once, and not in any subquery

Ouch... recursive CTEs don't allow subqueries in the recursive part! Still, all of these subqueries can be flattened into one, so let me merge them by hand and try again:

 ERROR 3575 (HY000): Recursive Common Table Expression 'cte' can contain neither aggregation nor window functions in recursive query block

I can rewrite the SQL by hand to avoid subqueries, but there is truly no way around the ban on aggregate functions!

At this point we can only declare the challenge failed... but wait, why can't I just go change TiDB's implementation?

According to the proposal, the recursive CTE implementation does not depart from TiDB's basic execution framework. After consulting @wjhuang2016, I learned that subqueries and aggregate functions are disallowed for roughly two reasons:

  • MySQL doesn't allow them either.
  • Allowing them would introduce many corner cases to handle, making the implementation very complex.

Since we only want to experiment with the functionality here, there is no harm in temporarily removing the check; the diff deletes the checks for subqueries and aggregate functions in the recursive part.

Now let's run it again:

    +------+----------------------------------+----------------------------------+-----------------------------------+-----------------------------------+----------------------------------+----------------------------------+-----------------------------------+-----------------------------------+-----------------------------------+----------------------------------+-----------------------------------+-----------------------------------+----------------------------------+----------------------------------+-----------------------------------+  
| iter | w00                              | w01                              | w02                               | w03                               | w04                              | w10                              | w11                               | w12                               | w13                               | w14                              | w20                               | w21                               | w22                              | w23                              | w24                               |  
+------+----------------------------------+----------------------------------+-----------------------------------+-----------------------------------+----------------------------------+----------------------------------+-----------------------------------+-----------------------------------+-----------------------------------+----------------------------------+-----------------------------------+-----------------------------------+----------------------------------+----------------------------------+-----------------------------------+  
| 1000 | 0.988746701341992382020000000002 | 2.154387045383744124308666666676 | -2.717791657467537500866666666671 | -1.219905459264249309799999999999 | 0.523764101056271250025665250523 | 0.822804724410132626693333333336 | -0.100577045244777709968533333327 | -0.033359805866941626546666666669 | -1.046591158370568595420000000005 | 0.757865074561280001352887284083 | -1.511551425752124944953333333333 | -1.753810000138966371560000000008 | 3.051151463334479351666666666650 | 2.566496617634817948266666666655 | -0.981629175617551201349829226980 |  
+------+----------------------------------+----------------------------------+-----------------------------------+-----------------------------------+----------------------------------+----------------------------------+-----------------------------------+-----------------------------------+-----------------------------------+----------------------------------+-----------------------------------+-----------------------------------+----------------------------------+----------------------------------+-----------------------------------+

It worked! We got the parameters after 1000 iterations!

Now let's recompute the accuracy with the new parameters:

| sum(y0 = r0 and y1 = r1 and y2 = r2) / count(*) |  
+-------------------------------------------------+  
|                                          0.9867 |  
+-------------------------------------------------+  
1 row in set (0.02 sec)

This time the accuracy reaches 98.67%.

Conclusion

We successfully trained a softmax logistic regression model in TiDB using pure SQL, relying mainly on the recursive CTE feature of TiDB v5.1. During the experiment we found that TiDB's recursive CTEs currently allow neither subqueries nor aggregate functions; after a small change to TiDB's code to bypass this restriction, we trained the model to completion and reached 98% accuracy on the iris dataset.

Discussion

  • Some testing shows that neither PostgreSQL nor MySQL supports aggregate functions in recursive CTEs; there may indeed be hard-to-handle corner cases in the implementation, and this is open for discussion.
  • This attempt unrolls every dimension by hand. I also wrote an implementation that avoids the unrolling (e.g. with a data schema of (idx, dim, value)), but it needs to join the weight table twice, i.e. access the CTE recursively twice, which would require further changes to TiDB's executor, so I have left it out here. That approach is more general, though: a single SQL statement can handle a model of any dimensionality (my original ambition was to train on MNIST with TiDB).
