2024 CCF BDCI 小樣本條件下的自然語言至圖查詢語言翻譯大模型微調|Google T5預訓練語言模型訓練與PyTorch框架的使用

KaiInssy發表於2024-11-24

程式碼詳見 https://gitee.com/wang-qiangsy/bdci

目錄
  • 一.賽題介紹
    • 1.賽題背景
    • 2.賽題任務
  • 二.關於Google T5預訓練語言模型
    • 1.T5模型主要特點
    • 2.T5模型與賽題任務的適配性分析
    • 3.模型的最佳化
  • 三.解題思路
    • 1.資料準備
    • 2.資料處理
    • 3.模型訓練
    • 4.模型評估
  • 四.程式碼實現
    • 1.配置類(Config)
    • 2.資料集類 (CypherDataset)
    • 3.訓練函式 (train)
    • 4.預測函式(generate_predictions)
    • 5.主要依賴:
  • 五.不足與分析
    • 1.錯誤的處理機制
    • 2.資料預處理和處理不平衡資料問題的缺乏
  • 六.總結與收穫
    • 1.競賽最終得分
    • 2.感受與收穫

一.賽題介紹

1.賽題背景

現代關係型資料庫使用SQL(Structured Query Language)作為查詢語言,由於SQL語言本身複雜的特性,只有少數研發工程師和資料分析師能夠熟練使用資料庫。但是隨著大語言模型技術的發展,及Text2Sql資料集的不斷完善,經過大量Text2Sql資料集訓練後的大模型已經初步具備了將自然語言翻譯成可執行的SQL語句的能力,極大的降低了關係型資料庫的使用門檻。
同樣的,在圖資料庫領域也存在相似的問題,甚至更為嚴峻。由於圖資料庫本身並沒有統一的查詢語言,目前是多種查詢語法並存的狀態,使用門檻比關係型資料庫更高。即便想要使用大模型技術將自然語言翻譯成可執行的圖查詢語言,依然面臨著缺乏Text2Sql領域海量語料的困難。如何透過每一種圖查詢語言現有的少量語料,微調出一個可以高質量的將自然語言翻譯成對應圖查詢語言的大模型,並以此降低圖資料庫的使用門檻,成為了現階段的一個重要研究方向。

2.賽題任務

參賽者需要使用提供的在TuGraph-DB上可執行的Cypher語料,對一個指定的本地模型進行微調,使得微調後的模型能夠準確的將測試集中的自然語言描述翻譯成對應的Cypher語句,翻譯結果將基於文字相似度和語法正確性兩個方面綜合評分。

二.關於Google T5預訓練語言模型

1.T5模型主要特點

  • 統一框架
    T5將輸入和輸出格式化為純文字字串。
  • 基於Transformer架構
    T5採用標準的Transformer模型架構,包含一個編碼器和一個解碼器。與GPT相比,其雙向編碼器和自迴歸解碼器相結合,更適合生成式任務。
  • 多工學習
    T5在一個包含各種任務的超大資料集上進行預訓練,使模型能夠適應不同任務的切換。
  • 開放的預訓練與微調方式
    預訓練:使用了C4(Colossal Clean Crawled Corpus)資料集,重點清洗了Web文字。
    微調:透過特定任務的資料集進一步最佳化。

2.T5模型與賽題任務的適配性分析

  • 文字到文字統一框架
    由於T5本質是一個將所有任務轉化為文字輸入和文字輸出的模型,具有將輸入和輸出格式化為純文字字串的特點,所以正好與“自然語言描述到Cypher語句翻譯”這一任務匹配。
  • 生成式任務能力
    T5在多工訓練中積累了強大的生成能力,Cypher語句是一種結構化查詢語言,其語法較為固定,T5的自迴歸生成解碼器在確保生成語句語法正確性方面具有優勢。
  • 遷移學習的可擴充套件性
    透過在提供的Cypher語料上微調,T5能夠快速適配新任務,達到較高的準確率和生成質量。

3.模型的最佳化

  • 指令調優
  • 資料增強
  • 知識注入
  • 模型蒸餾

三.解題思路

1.資料準備

  • 載入Schema檔案:從指定路徑載入movie.json,yago.json,the_three_body.json和finbench.json的Schema檔案,並將其儲存在一個字典中。每個Schema檔案描述了一個資料庫的結構,包括節點(VERTEX)和邊(EDGE)的定義及其屬性。
  • 載入訓練資料:從指定路徑載入訓練資料train_cypher,訓練資料包含自然語言描述和對應的Cypher語句。

2.資料處理

  • 定義資料集類:我們先是使用CypherDataset類將訓練資料和Schema結合起來,然後使用Tokenizer將自然語言描述和目標Cypher語句編碼為模型可接受的格式。(詳細程式碼中的__getitem__方法中,將自然語言描述和對應的Schema結合,構建輸入文字。使用Tokenizer對輸入文字和目標文字進行編碼,返回模型所需的張量格式資料。)

3.模型訓練

  • 初始化模型和Tokenizer:使用預訓練的T5模型和對應的Tokenizer。
  • 建立資料集例項:使用CypherDataset類建立訓練資料集,使用Tokenizer將自然語言描述和目標Cypher語句編碼為模型可接受的格式。
  • 設定訓練引數:使用TrainingArguments類設定訓練引數,如訓練輪數、批次大小、學習率等。
  • 建立Trainer例項:使用Trainer類進行模型訓練,Trainer類封裝了訓練過程中的許多細節,如梯度計算、引數更新、模型儲存等。

4.模型評估

  • 文字相似度:對生成的Cypher語句與參考答案進行文字相似度計算,評估模型的翻譯準確性。
  • 語法正確性:檢查生成的Cypher語句的語法正確性,確保其能夠在TuGraph-DB上正確執行。

四.程式碼實現

1.配置類(Config)

class Config:
    def __init__(self):
        self.model_name = "t5-base"  # 使用T5基礎模型
        self.cache_dir = "./model_cache"  # 模型快取目錄
        self.output_dir = "./results"  # 輸出目錄
        self.num_train_epochs = 3  # 訓練輪數
        self.batch_size = 4  # 批次大小
        self.learning_rate = 5e-5  # 學習率
        self.max_length = 512  # 最大序列長度
        self.warmup_steps = 100  # 預熱步數
        self.save_steps = 1000  # 儲存檢查點的步數間隔
        self.eval_steps = 1000  # 評估的步數間隔

2.資料集類 (CypherDataset)

class CypherDataset(Dataset):
    # 資料處理的核心類,繼承自PyTorch的Dataset
    def __init__(self, data, schemas, tokenizer, max_length):
        # 初始化資料集,接收原始資料、schema定義、分詞器和最大長度
        
    def __getitem__(self, idx):
        # 構建輸入格式:Schema + Question
        # 返回經過編碼的輸入資料、注意力掩碼和標籤

3.訓練函式 (train)

關鍵程式碼段

def train():
    # 載入schema檔案
    schemas = {}
    # ...
    
    # 初始化模型和tokenizer
    tokenizer = T5Tokenizer.from_pretrained(...)
    model = T5ForConditionalGeneration.from_pretrained(...)
    
    # 建立資料集和訓練器
    train_dataset = CypherDataset(...)
    trainer = Trainer(...)
    
    # 訓練和儲存
    trainer.train()
    trainer.save_model("./cypher_model")

4.預測函式(generate_predictions)

關鍵程式碼段

def generate_predictions():
    # 載入模型
    model = T5ForConditionalGeneration.from_pretrained(...)
    tokenizer = T5Tokenizer.from_pretrained(...)
    
    # 生成預測
    predictions = []
    for item in test_data:
        input_text = f"Schema: {schema}\nQuestion: {item['question']}"
        outputs = model.generate(...)
        predicted_text = tokenizer.decode(...)
        predictions.append(...)

5.主要依賴:

  • torch: PyTorch深度學習框架
  • transformers: Hugging Face的轉換器庫
  • numpy: 數值計算庫
  • json: JSON資料處理

五.不足與分析

1.錯誤的處理機制

  • 缺乏日誌管理,無法更好地對程式碼各種報錯資訊進行除錯處理,在訓練cypher語料時,無法及時獲取相關資訊反饋。
  • 進行錯誤處理機制的完善,引入日誌系統。
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

try:
    with open(file_path, 'r', encoding='utf-8') as f:
        schema = json.load(f)
except FileNotFoundError:
    logger.error(f"檔案不存在: {file_path}")
    continue
except json.JSONDecodeError as e:
    logger.error(f"JSON解析錯誤 {file_path}: {str(e)}")

    continue
except Exception as e:
    logger.error(f"載入schema時發生未知錯誤: {str(e)}")
    continue

2.資料預處理和處理不平衡資料問題的缺乏

  • 對資料的預處理不夠充分,可能導致資料質量和資料格式達不到預期。訓練語料資訊的缺乏,在訓練任務中,不同類別的資料樣本數量差異較大。
  • 進行資料清洗和資料格式化進行資料預處理,透過重取樣,重新定義損失函式解決不平衡資料的處理。
class CypherDataset(Dataset):
    def __init__(self, data, schemas, tokenizer, max_length):
        self.data = self._preprocess_data(data)  # 新增預處理
        
    def _preprocess_data(self, data):
        processed_data = []
        for item in data:
            # 資料清洗
            if self._validate_item(item):
                # 資料增強
                augmented_items = self._augment_data(item)
                processed_data.extend(augmented_items)
        return processed_data

六.總結與收穫

1.競賽最終得分

2.感受與收穫

  • 資料預處理:小組學習瞭如何載入和處理JSON格式的訓練和測試資料。並透過編寫自定義的Dataset類,掌握瞭如何將資料轉換為模型可以接受的格式。
  • 模型微調:小組瞭解如何使用Hugging Face的Transformers庫進行模型微調。並且對T5模型進行微調後用於特定任務。
  • 圖資料庫與Cypher語句:在透過處理不同的schema檔案中,理解了圖資料庫的結構和Cypher查詢語言。
  • 透過這個專案,我們小組不僅提升了自然語言處理和深度學習的技能,還對圖資料庫和Cypher查詢語言有了更深入的理解。這些收穫將對我們未來的學習框架的使用和大模型微調帶來積極的影響。總的來說,這次專案實踐讓我們在理論和實踐上都有了顯著的提升。

相關文章