如何在langchain中對大模型的輸出進行格式化

flydean發表於2023-11-27

簡介

我們知道在大語言模型中, 不管模型的能力有多強大,他的輸入和輸出基本上都是文字格式的,文字格式的輸入輸出雖然對人來說非常的友好,但是如果我們想要進行一些結構化處理的話還是會有一點點的不方便。

不用擔心,langchain已經為我們想到了這個問題,並且提出了完滿的解決方案。

langchain中的output parsers

langchain中所有的output parsers都是繼承自BaseOutputParser。這個基礎類提供了對LLM大模型輸出的格式化方法,是一個優秀的工具類。

我們先來看下他的實現:

class BaseOutputParser(BaseModel, ABC, Generic[T]):

    @abstractmethod
    def parse(self, text: str) -> T:
        """Parse the output of an LLM call.

        A method which takes in a string (assumed output of a language model )
        and parses it into some structure.

        Args:
            text: output of language model

        Returns:
            structured output
        """

    def parse_with_prompt(self, completion: str, prompt: PromptValue) -> Any:
        """Optional method to parse the output of an LLM call with a prompt.

        The prompt is largely provided in the event the OutputParser wants
        to retry or fix the output in some way, and needs information from
        the prompt to do so.

        Args:
            completion: output of language model
            prompt: prompt value

        Returns:
            structured output
        """
        return self.parse(completion)

    def get_format_instructions(self) -> str:
        """Instructions on how the LLM output should be formatted."""
        raise NotImplementedError

    @property
    def _type(self) -> str:
        """Return the type key."""
        raise NotImplementedError(
            f"_type property is not implemented in class {self.__class__.__name__}."
            " This is required for serialization."
        )

    def dict(self, **kwargs: Any) -> Dict:
        """Return dictionary representation of output parser."""
        output_parser_dict = super().dict()
        output_parser_dict["_type"] = self._type
        return output_parser_dict

BaseOutputParser 是一個基礎的類,可能被其他特定的輸出解析器繼承,以實現特定語言模型的輸出解析。

這個類使用了Python的ABC模組,表明它是一個抽象基類(Abstract Base Class),不能被直接例項化,而是需要子類繼承並實現抽象方法。

Generic[T] 表示這個類是一個泛型類,其中T 是一個型別變數,它表示解析後的輸出資料的型別。

@abstractmethod 裝飾器標記了 parse 方法,說明它是一個抽象方法,必須在子類中實現。parse 方法接受一個字串引數 text,通常是語言模型的輸出文字,然後將其解析成特定的資料結構,並返回。

parse_with_prompt 方法也是一個抽象方法,接受兩個引數,completion 是語言模型的輸出,prompt 是與輸出相關的提示資訊。這個方法是可選的,可以用於在需要時解析輸出,可能根據提示資訊來調整輸出。

get_format_instructions 方法返回關於如何格式化語言模型輸出的說明。這個方法可以用於提供解析後資料的格式化資訊。

_type 是一個屬性,可能用於標識這個解析器的型別,用於後續的序列化或其他操作。

dict 方法返回一個包含輸出解析器資訊的字典,這個字典可以用於序列化或其他操作。

其中子類必須要實現的方法就是parse。其他的都做為輔助作用。

langchain中有哪些Output Parser

那麼langchain中有哪些Output Parser的具體實現呢?具體對應我們應用中的什麼場景呢?

接下來我們將會一一道來。

List parser

ListOutputParser的作用就是把LLM的輸出轉成一個list。ListOutputParser也是一個基類,我們具體使用的是他的子類:CommaSeparatedListOutputParser。

看一下他的parse方法:

    def parse(self, text: str) -> List[str]:
        """Parse the output of an LLM call."""
        return text.strip().split(", ")

還有一個get_format_instructions:

    def get_format_instructions(self) -> str:
        return (
            "Your response should be a list of comma separated values, "
            "eg: `foo, bar, baz`"
        )

get_format_instructions是告訴LLM以什麼樣的格式進行資料的返回。

就是把LLM的輸出用逗號進行分割。

下面是一個基本的使用例子:

output_parser = CommaSeparatedListOutputParser()

format_instructions = output_parser.get_format_instructions()
prompt = PromptTemplate(
    template="列出幾種{subject}.\n{format_instructions}",
    input_variables=["subject"],
    partial_variables={"format_instructions": format_instructions}
)

_input = prompt.format(subject="水果")
output = model(_input)
print(output)
print(output_parser.parse(output))

我們可以得到下面的輸出:

Apple, Orange, Banana, Grape, Watermelon, Strawberry, Pineapple, Peach, Mango, Cherry
['Apple', 'Orange', 'Banana', 'Grape', 'Watermelon', 'Strawberry', 'Pineapple', 'Peach', 'Mango', 'Cherry']

看到這裡,大家可能有疑問了, 為什麼我們問的是中文,返回的卻是因為呢?

這是因為output_parser.get_format_instructions就是用英文描述的,所以LLM會自然的用英文來回答。

別急,我們可以稍微修改下執行程式碼,如下:

output_parser = CommaSeparatedListOutputParser()

format_instructions = output_parser.get_format_instructions()
prompt = PromptTemplate(
    template="列出幾種{subject}.\n{format_instructions}",
    input_variables=["subject"],
    partial_variables={"format_instructions": format_instructions + "用中文回答"}
)

_input = prompt.format(subject="水果")
output = model(_input)
print(output)
print(output_parser.parse(output))

我們在format_instructions之後,提示LLM需要用中文來回答問題。這樣我們就可以得到下面的結果:

蘋果,橘子,香蕉,梨,葡萄,芒果,檸檬,桃
['蘋果,橘子,香蕉,梨,葡萄,芒果,檸檬,桃']

是不是很棒?

Datetime parser

DatetimeOutputParser用來將LLM的輸出進行時間的格式化。

class DatetimeOutputParser(BaseOutputParser[datetime]):
    format: str = "%Y-%m-%dT%H:%M:%S.%fZ"

    def get_format_instructions(self) -> str:
        examples = comma_list(_generate_random_datetime_strings(self.format))
        return f"""Write a datetime string that matches the 
            following pattern: "{self.format}". Examples: {examples}"""

    def parse(self, response: str) -> datetime:
        try:
            return datetime.strptime(response.strip(), self.format)
        except ValueError as e:
            raise OutputParserException(
                f"Could not parse datetime string: {response}"
            ) from e

    @property
    def _type(self) -> str:
        return "datetime"

在get_format_instructions中,他告訴LLM返回的結果是一個日期的字串。

然後在parse方法中對這個LLM的輸出進行格式化,最後返回datetime。

我們看下具體的應用:

output_parser = DatetimeOutputParser()
template = """回答下面問題:
{question}
{format_instructions}"""
prompt = PromptTemplate.from_template(
    template,
    partial_variables={"format_instructions": output_parser.get_format_instructions()},
)
chain = LLMChain(prompt=prompt, llm=model)
output = chain.run("中華人民共和國是什麼時候成立的?")
print(output)
print(output_parser.parse(output))
1949-10-01T00:00:00.000000Z
1949-10-01 00:00:00

回答的還不錯,給他點個贊。

Enum parser

如果你有列舉的型別,那麼可以嘗試使用EnumOutputParser.

EnumOutputParser的建構函式需要傳入一個Enum,我們主要看下他的兩個方法:

    @property
    def _valid_values(self) -> List[str]:
        return [e.value for e in self.enum]

    def parse(self, response: str) -> Any:
        try:
            return self.enum(response.strip())
        except ValueError:
            raise OutputParserException(
                f"Response '{response}' is not one of the "
                f"expected values: {self._valid_values}"
            )

    def get_format_instructions(self) -> str:
        return f"Select one of the following options: {', '.join(self._valid_values)}"

parse方法接收一個字串 response,嘗試將其解析為列舉型別的一個成員。如果解析成功,它會返回該列舉成員;如果解析失敗,它會丟擲一個 OutputParserException 異常,異常資訊中包含了所有有效值的列表。

get_format_instructions告訴LLM需要從Enum的有效value中選擇一個輸出。這樣parse才能接受到正確的輸入值。

具體使用的例子可以參考前面兩個parser的用法。篇幅起見,這裡就不列了。

Pydantic (JSON) parser

JSON可能是我們在日常程式碼中最常用的資料結構了,這個資料結構很重要。

在langchain中,提供的JSON parser叫做:PydanticOutputParser。

既然要進行JSON轉換,必須得先定義一個JSON的型別物件,然後告訴LLM將文字輸出轉換成JSON格式,最後呼叫parse方法把json字串轉換成JSON物件。

我們來看一個例子:


class Student(BaseModel):
    name: str = Field(description="學生的姓名")
    age: str = Field(description="學生的年齡")

student_query = "告訴我一個學生的資訊"

parser = PydanticOutputParser(pydantic_object=Student)

prompt = PromptTemplate(
    template="回答下面問題.\n{format_instructions}\n{query}\n",
    input_variables=["query"],
    partial_variables={"format_instructions": parser.get_format_instructions()+"用中文回答"},
)

_input = prompt.format_prompt(query=student_query)

output = model(_input.to_string())
print(output)
print(parser.parse(output))

這裡我們定義了一個Student的結構體,然後讓LLM給我一個學生的資訊,並用json的格式進行返回。

之後我們使用parser.parse來解析這個json,生成最後的Student資訊。

我們可以得到下面的輸出:

示例輸出:{"name": "張三", "age": "18"}
name='張三' age='18'

Structured output parser

雖然PydanticOutputParser非常強大, 但是有時候我們只是需要一些簡單的結構輸出,那麼可以考慮StructuredOutputParser.

我們看一個具體的例子:

response_schemas = [
    ResponseSchema(name="name", description="學生的姓名"),
    ResponseSchema(name="age", description="學生的年齡")
]
output_parser = StructuredOutputParser.from_response_schemas(response_schemas)

format_instructions = output_parser.get_format_instructions()
prompt = PromptTemplate(
    template="回答下面問題.\n{format_instructions}\n{question}",
    input_variables=["question"],
    partial_variables={"format_instructions": format_instructions}
)

_input = prompt.format_prompt(question="給我一個女孩的名字?")
output = model(_input.to_string())
print(output)
print(output_parser.parse(output))

這個例子是上面的PydanticOutputParser的改寫,但是更加簡單。

我們可以得到下面的結果:

 ` ` `json
{
	"name": "Jane",
	"age": "18"
}
 ` ` `
{'name': 'Jane', 'age': '18'}

output返回的是一個markdown格式的json字串,然後透過output_parser.parse得到最後的json。

其他的一些parser

除了json,xml格式也是比較常用的格式,langchain中提供的XML parser叫做XMLOutputParser。

另外,如果我們在使用parser的過程中出現了格式問題,langchain還貼心的提供了一個OutputFixingParser。也就是說當第一個parser報錯的時候,或者說不能解析LLM輸出的時候,就會換成OutputFixingParser來嘗試修正格式問題:

from langchain.output_parsers import OutputFixingParser

new_parser = OutputFixingParser.from_llm(parser=parser, llm=ChatOpenAI())

new_parser.parse(misformatted)

如果錯誤不是因為格式引起的,那麼langchain還提供了一個RetryOutputParser,來嘗試重試:

from langchain.output_parsers import RetryWithErrorOutputParser

retry_parser = RetryWithErrorOutputParser.from_llm(
    parser=parser, llm=OpenAI(temperature=0)
)

retry_parser.parse_with_prompt(bad_response, prompt_value)

這幾個parser都非常有用,大家可以自行嘗試。

總結

雖然langchain中的有些parser我們可以自行藉助python語言的各種工具來實現。但是有一些parser實際上是要結合LLM一起來使用的,比如OutputFixingParser和RetryOutputParser。

所以大家還是儘可能的使用langchain提供的parser為好。畢竟輪子都給你造好了,還要啥腳踏車。

相關文章