本地執行 Gemma 的 pytorch 整合

yyintech發表於2024-03-01


Gemma 是 Google 在2024年2月21日釋出的一款輕量的開源大模型,採用了和 Google Gemini 模型一樣的技術。有猜測 Google 在毫無預告的情況下急忙釋出 Gemma 是對 Meta 的 Llama3 的截胡,但不管怎麼說作為名廠名牌的大模型,自然要上手嘗試嘗試。

這次釋出的 Gemma 有 2B 引數和 7B 引數兩個版本,兩個版本又分別提供了預訓練 (Pretrained) 和指令除錯 (Instruction tuned) 兩個版本。預訓練版本做了基礎訓練,而指令除錯版本做了根據人類語言互動的特定訓練調整,所以如果直接拿來做會話使用可以下載 it 版本。2B 和 7B 在於引數量的多少,7B 需要更多的資源去執行。

好了,前面囉嗦了一堆背景,為了引出這裡介紹 2b-it 版本地部署的原因——耗資源少且可以本地使用會話。

準備環境

  • 安裝 python venv,命名 gemma-torch conda env create -n "gemma-torch"
  • 啟用虛擬環境 conda activate gemma-torch
  • 安裝依賴的庫 pip install torch immutabledict sentencepiece numpy packaging 後面兩個庫不是官方文件裡要求的,但是根據我執行報錯,需要安裝。另外上面命令也取消了-q -U 簡單粗暴也方便觀察。

為了後續用程式碼連線 kaggle 下載模型,還需要安裝 kagglehub 包:

pip install kagglehub

連線 kaggle

這一步的目的是從 kaggle 上面下載模型。

  • 首先獲取 kaggle 的訪問許可權 登入 kaggle,在設定頁面 的 API 一節點選按鈕 “Create New Token”,會觸發下載 kaggle.json。 ​​
  • 配置環境 將 kaggle.json 檔案複製到~/.kaggle/目錄下。並在~/.bash_profile 中設定環境變數 KAGGLE_CONFIG_DIR 為~/.kaggle。

這樣就可以透過下面程式碼訪問 (後面的程式碼寫到一塊,不需要此處執行)。

import kagglehub

kagglehub.login()

執行程式碼

經過前面的配置後,可以程式碼本地執行 2b-it 模型了。不過載入模型還需要 gemma_pytorch 包。
從 github 倉庫 clone 到本地:

# NOTE: The "installation" is just cloning the repo.
git clone https://github.com/google/gemma_pytorch.git

將下載好的 gemma_pytorch 資料夾放到下面指令碼檔案同一級目錄下 ,並在~/.bash_profile 中設定 PYTHONPATH 環境變數包含該資料夾路徑。

最後執行指令碼 (gemma_torch.py):

# Choose variant and machine type
import kagglehub
import os
import sys
from gemma_pytorch.gemma.config import get_config_for_7b, get_config_for_2b
from gemma_pytorch.gemma.model import GemmaForCausalLM
import torch


VARIANT = '2b-it'
#如果是cpu執行,將下面cuda改成cpu,不過巨慢
MACHINE_TYPE = 'cuda'

# Load model weights
# 模型下載到了~/.cache目錄下
weights_dir = kagglehub.model_download(f'google/gemma/pyTorch/{VARIANT}')

# Ensure that the tokenizer is present
tokenizer_path = os.path.join(weights_dir, 'tokenizer.model')
assert os.path.isfile(tokenizer_path), 'Tokenizer not found!'

# Ensure that the checkpoint is present
ckpt_path = os.path.join(weights_dir, f'gemma-{VARIANT}.ckpt')
assert os.path.isfile(ckpt_path), 'PyTorch checkpoint not found!'

# Set up model config.
model_config = get_config_for_2b() if "2b" in VARIANT else get_config_for_7b()
model_config.tokenizer = tokenizer_path
model_config.quant = 'quant' in VARIANT

# Instantiate the model and load the weights.
torch.set_default_dtype(model_config.get_dtype())
device = torch.device(MACHINE_TYPE)
model = GemmaForCausalLM(model_config)
model.load_weights(ckpt_path)
model = model.to(device).eval()

# Generate with one request in chat mode

# Chat templates
USER_CHAT_TEMPLATE = '<start_of_turn>user\n{prompt}<end_of_turn>\n'
MODEL_CHAT_TEMPLATE = '<start_of_turn>model\n{prompt}<end_of_turn>\n'

# Sample formatted prompt
prompt = (
    USER_CHAT_TEMPLATE.format(
        prompt='What is a good place for travel in the US?'
    )
    + MODEL_CHAT_TEMPLATE.format(prompt='California.')
    + USER_CHAT_TEMPLATE.format(prompt='What can I do in California?')
    + '<start_of_turn>model\n'
)
print('Chat prompt:\n', prompt)

model.generate(
    USER_CHAT_TEMPLATE.format(prompt=prompt),
    device=device,
    output_len=100,
)

# Generate sample
model.generate(
    'Write a poem about an llm writing a poem.',
    device=device,
    output_len=60,
)

一點後話

能用 GPU 還是上 GPU 吧,我本地用的 CPU 筆記本跑的巨慢。

可以線上使用 colab,具體步驟參考這個帖子 (昨天 Google 釋出了最新的開源模型 Gemma,今天我來體驗一下_gemma_lm.generate-CSDN 部落格)。

不過我在使用過程中發現 T4 經常在預測執行時報 OOM,導致無法產出結果。

參考資料:
pytorch 中使用 Gemma: https://ai.google.dev/gemma/docs/pytorch_gemma

官方文件地址:https://ai.google.dev/gemma/docs

相關文章