Transformers Pipeline + Mistral-7B-Instruct-v0.x: Modifying the Chat Template

Posted by 一蓑烟雨度平生 on 2024-07-17

When running the "Generate with transformers" sample code provided at https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3, the following error is raised:

from transformers import pipeline

messages = [
    {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
    {"role": "user", "content": "Who are you?"},
]
chatbot = pipeline("text-generation", model="mistralai/Mistral-7B-Instruct-v0.3")
chatbot(messages)  # raises the TemplateError below
TemplateError: Conversation roles must alternate user/assistant/user/assistant/...

This error occurs because Mistral's chat template does not support a system prompt: it only knows the user and assistant roles, and raises an exception whenever the conversation does not strictly alternate between them, so a leading system message trips the check.
Separately, looking at the source of tokenizer.apply_chat_template, the default chat template (the fallback used when a tokenizer defines none) formats inputs in ChatML:
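You can confirm this by inspecting the template that ships with the tokenizer; tokenizer.chat_template holds the Jinja source (a minimal sketch):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")
# Print the Jinja template bundled in tokenizer_config.json; note that it
# only has branches for the 'user' and 'assistant' roles, not 'system'.
print(tokenizer.chat_template)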

def default_chat_template(self):
    """
    This template formats inputs in the standard ChatML format. See
    https://github.com/openai/openai-python/blob/main/chatml.md
    """
    return (
        "{% for message in messages %}"
        "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
        "{% endfor %}"
        "{% if add_generation_prompt %}"
        "{{ '<|im_start|>assistant\n' }}"
        "{% endif %}"
    )
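Rendering the sample messages above through this ChatML template (with add_generation_prompt=True) would produce:

<|im_start|>system
You are a pirate chatbot who always responds in pirate speak!<|im_end|>
<|im_start|>user
Who are you?<|im_end|>
<|im_start|>assistant

Note that ChatML itself accepts the system role; the alternation check that raises our error lives in Mistral's own template, which overrides this default.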

To support a system prompt when using the Transformers pipeline with a Mistral model, we need to replace the default chat template with a custom one. The `-` whitespace-control markers below keep the template's own indentation and line breaks out of the rendered prompt:

{%- if messages[0]['role'] == 'system' -%}
    {%- set system_message = messages[0]['content'] | trim + '\n\n' -%}
    {%- set messages = messages[1:] -%}
{%- else -%}
    {%- set system_message = '' -%}
{%- endif -%}

{{- bos_token + system_message -}}
{%- for message in messages -%}
    {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}
        {{- raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') -}}
    {%- endif -%}

    {%- if message['role'] == 'user' -%}
        {{- '[INST] ' + message['content'] | trim + ' [/INST]' -}}
    {%- elif message['role'] == 'assistant' -%}
        {{- ' ' + message['content'] | trim + eos_token -}}
    {%- endif -%}
{%- endfor -%}
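For the sample messages above, and assuming Mistral's bos_token of <s>, this template renders the system message as a plain prefix before the first [INST] block:

<s>You are a pirate chatbot who always responds in pirate speak!

[INST] Who are you? [/INST]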

In code, store the template above in a string (here mistral_chat_template) and pass it in to override the default chat_template:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")

# mistral_chat_template holds the Jinja template string shown above.
prompt = tokenizer.apply_chat_template(
    messages,
    chat_template=mistral_chat_template,
    tokenize=False,
    add_generation_prompt=True,
)
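Passing chat_template per call fixes direct apply_chat_template usage, but the original pipeline call builds the prompt internally. A minimal sketch of one way to make that path work as well (again assuming the template string is stored in mistral_chat_template) is to assign the template to tokenizer.chat_template and hand that tokenizer to the pipeline:

from transformers import AutoTokenizer, pipeline

model_id = "mistralai/Mistral-7B-Instruct-v0.3"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Replace the shipped template so the pipeline's internal
# apply_chat_template call accepts the leading system message.
tokenizer.chat_template = mistral_chat_template

chatbot = pipeline("text-generation", model=model_id, tokenizer=tokenizer)
chatbot(messages)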

With the template overridden, inference now runs without the TemplateError.
