NL2SQL之DB-GPT-Hub<詳解篇>:text2sql任務的微調框架和基準對比

汀、人工智能發表於2024-10-08

NL2SQL之DB-GPT-Hub<詳解篇>:text2sql任務的微調框架和基準對比

隨著生成式人工智慧(Artificial Intelligence Generated Content,簡寫為 AIGC)時代的到來,使用大規模預訓練語言模型(LLM)來進行 text2sql 任務的 sql 生成也越來越常見。基於 LLM 的 text2SQL 方法通常分為兩種:

  • 基於 prompt 的 In context Learning(ICL)方法;

  • 基於 text2sql 任務構建資料集並且微調開源的 LLM 以適配 text2sql 任務

基於 prompt 的方法相對來說成本較低,方法和效果都有相對成熟的結果;微調 LLM 的方法受限於消耗資源比較大,計算成本過高,沒有得到很好地探索。B-GPT-Hub是一款很好的專案,這是一個基於 LLM 微調的 text2SQL 的訓練推理框架和 benchmark,主要側重於大規模微調 LLM 的方式。

  • 主要貢獻:

    1. 透過微調中型到大型 open source 的 LLM 並對 textSQL 任務進行標準化和全面的評估;

    2. 模組化且易於擴充套件的程式碼庫,支援主流 LLM 和實驗場景,優先考慮微調方法,並擴充套件到基於 prompt 的方式。

工作研究了與基於 promp 方法相比,微調方法的潛在收益和效能邊界,並探索了針對特定場景的最佳解決方案。希望 DB-GPT-Hub 以及這些發現能夠推動進一步的研究和廣泛的應用,否則由於缺乏專門的開放基準,這些研究和應用將很難實現。

  • 具體程式碼:https://github.com/eosphoros-ai/DB-GPT-Hub

  • 文章:https://arxiv.org/abs/2406.11434

  • text2sql榜單:https://github.com/eosphoros-ai/Awesome-Text2SQL

1.DB-GPT-Hub簡介

Text-to-SQL(簡寫為 Text2SQL,或者 NL2SQL)是一項將自然語言描述轉化為對應的結構化查詢語句(Structured Query Language, 簡寫為 SQL)的技術,它能利用簡潔清晰的自然語言描述,有效地輔助人們對海量的資料庫進行查詢,簡化資料查詢和分析的工作。隨著生成式人工智慧(Artificial Intelligence Generated Content,簡寫為 AIGC)時代的到來,使用大規模預訓練語言模型來進行 sql 生成的方式也越來越常見。

然而在實際開發中,當前的 Text-to-SQL 技術並未與 LLM 一些優秀的特性有效結合,例如 self-instruct、思維鏈、分散式計算、量化微調、attention 最佳化等方法,此外,Text2SQL 技術如何結合強大的自然語言理解能力,實現從資料 - 模型 - 微調 - 部署 - 展示的全鏈路工作流程,也是亟待解決的問題。因此,在本專案中構建了一套基於微調的 text2sql 任務全鏈路框架,同時對現有主流的開源模型進行 text2sql 任務的評測,構建了一套 open source LLM 的 benchmark;同時也對其中的一些 insight 進行了分析。

2.程式碼架構設計

為了充分利用大語言模型(Large Language Model,簡寫為 LLM)的語言理解能力,提高 Text2SQL 的模型微調效率和模型精度,在 DB-GPT 框架下提出了一個端到端大模型 Text2SQL 微調子框架 DB-GPT-Hub。在 DB-GPT 框架下,構架了 Text2SQL 領域下的資料預處理 - 模型微調 - 模型預測 - 模型驗證 - 模型評估的全鏈路工作流程,如下圖所示:

圖 1.DB-GPT-Hub 的架構流程圖

如圖一所示:DB-GPT-Hub 專案重點關注在資料預處理 - 資料集構建 - 模型微調 - 模型預測 - 模型驗證部分,微調得到的模型可以無縫銜接部署到 DB-GPT 框架中,然後結合知識問答和資料分析等能力展示模型在 Text2SQL 領域的優越效能。

具體功能:

  • 資料集構建:將原生的 text2SQL 資料處理成合適的格式(Text Representation Format)以微調 LLM。這包括將問題和資料庫 schema 的描述整合到提示中作為指令(instruction),以及各種問題表示以提高訓練和評估期間的效能。此外,將選擇不同的 few-shot 策略(例如 example selection 和 organization)來構建評估資料集

  • 訓練:的程式碼庫支援使用 PEFT 策略對開源 LLM 進行微調。支援大多數公共架構,模型規模從小到大,例如 Qwen、Llama、Baichuan 和 ChatGLM

  • 預測:的程式碼庫支援開源 LLM 的微調版本和閉源 LLM 的 SQL 查詢推理。支援使用少樣本和零樣本方法來生成特定場景的 SQL

  • 評估:同時,支援不同的評測指標(EX、EM)來從不同角度評估生成的 SQL 的效能。

2.1 資料集構建

以開源資料集 Spider 為例做一個詳細的介紹,Spider 資料集是一個多資料庫、多表、單輪查詢的 Text2SQL 資料集,是 Text2SQL 任務中最具挑戰性的資料集之一,由耶魯大學的 LILY 實驗室於 2018 年釋出,具有如下特點:

  • 規模大:Spider 資料集包含了 10,181 個自然語言問題和 5,693 個唯一的複雜 SQL 查詢,涉及到 200 個具有多個表的資料庫,覆蓋了 138 個不同的領域。

  • 泛化強:Spider 資料集與之前的 Text2SQL 資料集不同的是,它在訓練集和測試集中使用了不同的 SQL 查詢和資料庫模式,這就要求模型不僅能很好地泛化到新的 SQL 查詢,而且也要泛化到新的資料庫模式。

  • 結構好: 與 WikiSQL 在每個資料庫中只有一個表相比,Spider 上的每個資料庫包含多個表,並且它們透過主外來鍵聯絡在一起。 有

  • 挑戰:Spider 資料集包含了 SQL 中幾乎所有常見的高階語法,比如 "ORDER BY", "GROUP BY", "HAVING", "JOIN”,"INSERTION" 和巢狀等,如下圖所示。

圖 2: 不同資料集的語法分佈

spider 資料集將 SQL 生成分成了四個等級:

  • 簡單:

    • Question: What is the number of cars with more than 4 cylinders?

    • SQL:SELECT COUNT (*)FROM cars_dataWHERE cylinders > 4

  • 中等:

    • Question: For each stadium, how many concerts are there?

    • SQL:SELECT T2.name, COUNT (*) FROM concert AS T1 JOIN stadium AS T2ON T1.stadium_id = T2.stadium_idGROUP BY T1.stadium_id

  • 較難

    • Question: Which countries in Europe have at least 3 car manufacturers?

    • SQL:SELECT T1.country name FROM countries AS T1 JOIN continents AS T2 ON T1.continent T2.cont_id JOIN car makers AS T3 ON T1.country_id = T3.country WHERE T2.continent = 'Europe' GROUPBY T1.country_name HAVINGCOUNT (*) >= 3

  • 極難

    • Question: What is the average life expectancy in the countries where English is not the official language?

    • SQL:SELECT AVG(life_expectancy) FROM country WHERE name NOT IN ( SELECT T1.name FROM country AS T1 JOIN country_language AS T2 ON T1.code = T2.country_code WHERE T2.language = "English" AND T2.is_official = "T")

為了充分利用資料庫中的表和欄位等相關資訊,對 Spider 中的原始資料進行處理,用自然語言表示資料庫包含的表結構以及表結構包含的欄位以及相應的主鍵和外來鍵等,經過資料預處理後,可以得到如下的資料格式:


{"instruction": "concert_singer(資料庫名) contains tables(表) such as stadium, singer, concert, singer_in_concert. Table stadium has columns(列) such as stadium_id, location, name, capacity, highest, lowest, average. stadium_id is the primary key(主鍵). Table singer has columns such as singer_id, name, country, song_name, song_release_year, age, is_male. singer_id is the primary key. Table concert has columns such as concert_id, concert_name, theme, stadium_id, year. concert_id is the primary key. Table singer_in_concert has columns such as concert_id, singer_id. concert_id is the primary key. The year of concert is the foreign key(外來鍵)of location of stadium. The stadium_id of singer_in_concert is the foreign key of name of singer. The singer_id of singer_in_concert is the foreign key of concert_name of concert.", 



"input": "How many singers do we have?", 



"response": "select count(*) from singer"}




{"instruction": "concert_singer(資料庫名)包含表(表),例如stadium, singer, concert, singer_in_concert。表體育場有列(列),如stadium_id、位置、名稱、容量、最高、最低、平均。Stadium_id是主鍵(主鍵)。表singer有這樣的列:singer_id、name、country、song_name、song_release_year、age、is_male。Singer_id為主鍵。表concert有如下列:concert_id、concert_name、theme、stadium_id、year。Concert_id是主鍵。表singer_in_concert有如下列:concert_id, singer_id。Concert_id是主鍵。演唱會年份是場館位置的外來鍵(外來鍵)。singer_in_concert的stadium_id是歌手名的外來鍵。singer_in_concert的singer_id是concert的concert_name的外來鍵。

"input": "我們有多少歌手?"

"response": "select count(*) from singer"}

同時,為了更好的利用大語言模型的理解能力,定製了 prompt dict 以最佳化輸入,如下所示:


SQL_PROMPT_DICT = {



    "prompt_input": (



        "I want you to act as a SQL terminal in front of an example database. "



        "Below is an instruction that describes a task, Write a response that appropriately completes the request.\n\n"



        "###Instruction:\n{instruction}\n\n###Input:\n{input}\n\n###Response: "



    ),



    "prompt_no_input": (



        "I want you to act as a SQL terminal in front of an example database. "



        "Below is an instruction that describes a task, Write a response that appropriately completes the request.\n\n"



        "###Instruction:\n{instruction}\n\n### Response: "



    ),



}



2.2 模型訓練

將從基礎模型和微調方式來進行

2.2.1基礎模型

目前支援的模型結構如下所示,包含了當下主流的中外開源模型系列,比如 Llama 系列、Baichuan 系列、GLM 系列、Qwen 系列等,覆蓋面廣,同時 benchmark 橫跨 7b/13B/70B 的規模。

圖 5: 不同模型的微調模式

2.2.2 微調方式

Text2SQL微調主要包含以下流程:

  • 搭建環境

  • 資料處理

  • SFT訓練

  • 權重合並

  • 模型預測

  • 效果評估

在大語言模型對特定任務或領域進行微調任務時,重新訓練所有模型引數將會帶來昂貴的訓練成本,因此出現了各種最佳化的微調方案,綜合評估模型微調速度和精度,實現了當下流行的 LoRA(Low-Rank Adaptation 的簡寫) 方法和 QLoRA(量化 + lora)方法。 LoRA 的基本原理是在凍結原模型引數的情況下,透過向模型中加入額外的網路層,並只訓練這些新增的網路層引數。由於這些新增引數數量較少,這樣不僅 finetune 的成本顯著下降,還能獲得和全模型微調類似的效果,如下圖所示:

  • 圖中藍色部分為預訓練好的模型引數,LoRA 在預訓練好的模型結構旁邊加入了 A 和 B 兩個結構,這兩個結構的引數分別初始化為高斯分佈和 0

  • A 的輸入維度和 B 的輸出維度分別與原始模型的輸入輸出維度相同,而 A 的輸出維度和 B 的輸入維度是一個遠小於原始模型輸入輸出維度的值,這就是 low-rank 的體現,可以極大地減少待訓練的引數。

  • 在訓練時只更新 A、B 的引數,預訓練好的模型引數是固定不變的。在推斷時利用重引數思想,將 AB 與 W 合併,這樣在推斷時不會引入額外的計算。而且對於不同的下游任務,只需要在預訓練模型基礎上重新訓練 AB,這樣也能加快大模型的訓練節奏。

圖三. LoRA 微調示意圖

QLoRA 方法使用一種低精度的儲存資料型別(NF4)來壓縮預訓練的語言模型。透過凍結 LM 引數,將相對少量的可訓練引數以 Low-Rank Adapters 的形式新增到模型中,LoRA 層是在訓練期間更新的唯一引數,使得模型體量大幅壓縮同時推理效果幾乎沒有受到影響。從 QLoRA 的名字可以看出,QLoRA 實際上是 Quantize+LoRA 技術。

圖 4:QLora 示意圖

2.3 模型預測

模型微調完後,基於儲存的權重和基座大模型,對 spider 資料集的 dev 測試集進行測試,可以得到模型預測的 sql。 預測的 dev_sql.json 總共有 1034 條資料,同樣需要經過資料預處理後,再拿給模型預測結果。


{"instruction": "concert_singer contains tables such as stadium, singer, concert, singer_in_concert. Table stadium has columns such as stadium_id, location, name, capacity, highest, lowest, average. stadium_id is the primary key. Table singer has columns such as singer_id, name, country, song_name, song_release_year, age, is_male. singer_id is the primary key. Table concert has columns such as concert_id, concert_name, theme, stadium_id, year. concert_id is the primary key. Table singer_in_concert has columns such as concert_id, singer_id. concert_id is the primary key. The stadium_id of concert is the foreign key of stadium_id of stadium. The singer_id of singer_in_concert is the foreign key of singer_id of singer. The concert_id of singer_in_concert is the foreign key of concert_id of concert.", "input": "How many singers do we have?", "output": "select count(*) from singer"}



模型預測的核心程式碼如下:


def inference(model: ChatModel, predict_data: List[Dict], **input_kwargs):



    res = []



    # test



    # for item in predict_data[:20]:



    for item in tqdm(predict_data, desc="Inference Progress", unit="item"):



        response, _ = model.chat(query=item["input"], history=[], **input_kwargs)



        res.append(response)



    return res



2.4 模型評估

模型預測得到 sql 後,需要和 spider 資料集的標準答案對比,使用 EX(execution accuracy)和 EM(Exact Match)指標進行評估 EX 指標是計算 SQL 執行結果正確的數量在資料集中的比例,公示如下所示:

$$

\mathrm{EX}=\frac{\Sigma_{n=1}^N \operatorname{score}\left(\hat{Y}_n, Y_n\right)}{N}

$$

EM 指標是計算模型生成的 SQL 和標註 SQL 的匹配程度。

$$

\mathrm{EM}=\frac{\sum_{n=1}^N s \operatorname{core}\left(\hat{Y}_n, Y_n\right)}{N}

$$

3.benchmark 設計

3.1 資料集

的 benchmark 在 bird 和 spirder 兩個資料上構建:

  • Spider:是一個大規模跨域資料集,包含 10,181 個自然語言查詢、5,693 個跨 200 個資料庫的獨特複雜 SQL 查詢,涵蓋 138 個域。此資料集的標準協議將其分為 8,659 個訓練示例和 34 個資料庫中的 2,147 個測試示例。SQL 查詢分為四個難度級別,即簡單、中等、困難和極難。

  • BIRD:它包含一個廣泛的資料集,其中包含 12,751 個獨特的問題 - SQL 對,涵蓋 95 個大型資料庫。SQL 查詢分為三個難度級別,即簡單、中等和挑戰。值得注意的是,BIRD 資料集中的 SQL 查詢往往比 Spider 資料集中的 SQL 查詢更復雜。

整體程式碼適配 WikiSQL,CoSQL 等資料集。

更多內容參考:NL2SQL基礎系列(1):業界頂尖排行榜、權威測評資料集及LLM大模型(Spider vs BIRD)全面對比優劣分析[Text2SQL、Text2DSL]

3.1.1 spider

表 1.Spider 的 EX 準確率表,L 代表 LoRA,QL 代表 QLoRA

表 2.Spider 的 EM 準確率表,L 代表 LoRA,QL 代表 QLoRA

3.1.2 BIRD

表 3.BIRD 的 EX 準確率表,L 代表 LoRA,QL 代表 QLoRA

表 4.BIRD 的 EM 準確率表,L 代表 LoRA,QL 代表 QLoRA

4. 實驗 Insight

4.1 不同難易程度任務的效果差異

如下圖所示,以三個 7B 模型為例,展示了調整後的 LLM 針對一系列 SQL 生成難度級別的有效性。對於所有三個微調後的模型,結果都表明效能提升的大小與 SQL 複雜性呈負相關,並且微調對簡單 SQL 的改進更為顯著。

4.2 LoRA 和 QLoRA 對比

如下表所示,總結 Lora 和 QLora 在 EX、EM、時間成本和 GPU 記憶體指標之間的差異。首先,發現使用 LoRA 和 QLoRA 調整的模型在生成效能(以 EX 和 EM 衡量)方面差異有限。其次,與量化機制一致,QLoRA 需要更多時間才能收斂,而 GPU 記憶體較少。例如,與 Qwen-14B-LoRA 相比,其 QLoRA 對應模型僅需要 2 倍的時間和 50%GPU 記憶體

更多優質內容請關注公號:汀丶人工智慧;會提供一些相關的資源和優質文章,免費獲取閱讀。

相關文章