使用Microsoft.SemanticKernel基於本地執行的Ollama大語言模型實現Agent呼叫函式

a1010發表於2024-06-21

大語言模型的發展日新月異,記得在去年這個時候,函式呼叫還是gpt-4的專屬。到今年本地執行的大模型無論是推理能力還是文字的輸出質量都已經非常接近gpt-4了。而在去年gpt-4尚未釋出函式呼叫時,智慧體框架的開發者們依賴構建精巧的提示詞實現了gpt-3.5的函式呼叫。目前在本機執行的大模型,基於這一套邏輯也可以實現函式式呼叫,今天我們就是用本地執行的大模型來實現這個需求。從測試的效果來看,本地大模型對於簡單的函式呼叫成功率已經非常高了,但是受限於本地機器的效能,呼叫的時間還是比較長。如果有NVIDIA顯示卡的CUDA環境,質量應該會好很多,今天就以大家都比較熟悉的LLAMA生態作為起點,基於阿里雲開源的千問7B模型的量化版作為基座透過C#和SemanticKernel來實現函式呼叫的功能。

基本呼叫邏輯參考這張圖:

首先我們需要在本機(windows系統)安裝Ollama作為LLM的API後端。訪問https://ollama.com/,選擇Download。選擇你需要的版本即可,windows使用者請選擇Download for Windows。下載完成後,無腦點選下一步下一步即可安裝完畢。

安裝完畢後,開啟我們的PowerShell即可執行大模型,第一次載入會下載模型檔案到本地磁碟,會比較慢。執行起來後就可以透過控制檯和模型進行簡單的對話,這裡我們以阿里釋出的千問2:7b為例。執行以下命令即可執行起來:

ollama run qwen2:7b

接著我們使用ctrl+D退出對話方塊,並執行ollama serve,看看伺服器是否執行起來了,正常情況下會看到11434這個埠已經執行起來了。接下來我們就可以進入到編碼階段

首先我們建立一個.net8.0的的控制檯,接著我們引入三個必要的包

dotnet add package Microsoft.SemanticKernel --version 1.15.0
dotnet add package Newtonsoft.Json --version 13.0.3
dotnet add package OllamaSharp --version 2.0.1

SemanticKernel是我們主要的代理執行框架,OllamaSharp是一個簡單的面向Ollama本地API服務的請求封裝。避免我們手寫httpclient來與本地伺服器互動。我這裡安裝了Newtonsoft.Json來替代system.text.json,主要是用於後期需要一些序列化模型回撥來使用,因為模型的回撥json可能不是特別標準,使用system.text.json容易導致轉義失敗。

接下來就是編碼階段,首先我們定義一個函式,這個函式是後面LLM會用到的函式,簡單的定義如下:

public class FunctionTest
{
    [KernelFunction, Description("獲取城市的天氣狀況")]
    public object GetWeather([Description("城市名稱")] string CityName, [Description("查詢時段,值可以是[白天,夜晚]")] string DayPart)
    {
        return new { CityName, DayPart, CurrentCondition = "多雲", LaterCondition = "陰", MinTemperature = 19, MaxTemperature = 23 };
    }
}

這裡的KernelFunction和Description特性都是必要的,用於SemanticKernel查詢到對應的函式並封裝處對應的後設資料。

接著我們需要自定義一個繼承自介面IChatCompletionService的實現,因為SemanticKernel是基於openai的gpt系列設計的框架,所以要和本地模型呼叫,我們需要設定獨立的ChatCompletionService來讓SemanticKernel和本機模型API互動。這裡我們主要需要實現的函式是GetChatMessageContentsAsync。因為函式呼叫我們需要接收到模型完整的回撥用於轉換json,所以流式傳輸這裡用不上。

public class CustomChatCompletionService : IChatCompletionService
{
    public IReadOnlyDictionary<string, object?> Attributes => throw new NotImplementedException();

    public Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsync(ChatHistory chatHistory, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default)
    {
        throw new NotImplementedException();
    }

    public IAsyncEnumerable<StreamingChatMessageContent> GetStreamingChatMessageContentsAsync(ChatHistory chatHistory, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default)
    {
        throw new NotImplementedException();
    }
}

接下來我們需要定義一個SemanticKernel的例項,這個例項會伴隨本次呼叫貫穿全程。SemanticKernel使用了簡單的鏈式構建。基本程式碼如下:

var builder = Kernel.CreateBuilder();
//這裡我們需要增加剛才我們定義的例項CustomChatCompletionService,有點類似IOC的設計
builder.Services.AddKeyedSingleton<IChatCompletionService>("ollamaChat", new CustomChatCompletionService());
//這裡我們需要插入之前定義的外掛
builder.Plugins.AddFromType<FunctionTest>();
var kernel = builder.Build();

可以看到基本的構建鏈式呼叫程式碼部分還是比較簡單的,接下來就是呼叫的部分,這裡主要的部分就是將LLM可用的函式插入到系統提示詞,來引導LLM去呼叫特定函式:

//定義一個對話歷史
ChatHistory history = [];
//獲取剛才定義的外掛函式的後設資料,用於後續建立prompt
var plugins = kernel.Plugins.GetFunctionsMetadata();
//生成函式呼叫提示詞,引導模型根據使用者請求去呼叫函式
var functionsPrompt = CreateFunctionsMetaObject(plugins);
//建立系統提示詞,插入剛才生成的提示詞
var prompt = $"""
                  You have access to the following functions. Use them if required:
                  {functionsPrompt}
                  If function calls are used, ensure the output is in JSON format; otherwise, output should be in text format.
                  """;
//新增系統提示詞
history.AddSystemMessage(prompt);
//建立一個對話服務例項
var chatCompletionService = kernel.GetRequiredService<IChatCompletionService>();
//新增使用者的提問
history.AddUserMessage(question);
//鏈式執行kernel
var result = await chatCompletionService.GetChatMessageContentAsync(
    history,
    executionSettings: null,
    kernel: kernel);
//列印回撥內容
Console.WriteLine($"Assistant> {result}");

在這裡我們可以debug看看生成的系統提示詞細節:

當程式碼執行到GetChatMessageContentAsync這裡時,就會跳轉到我們的CustomChatCompletionService的GetChatMessageContentsAsync函式,在這裡我們需要進行ollama的呼叫來達成目的。

這裡比較核心的部分就是將LLM回撥的內容使用JSON序列化來檢測是否涉及到函式呼叫,簡單來講由於類似qwen這樣沒有專門針對function calling專項微調過的(glm-4-9b原生支援function calling)模型,其function calling並不是每次都能準確的回撥,所以這裡我們需要對回撥的內容進行反序列化和資訊抽取,確保模型的呼叫符合回撥函式的格式標準。具體程式碼如下

public async Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsync(ChatHistory chatHistory, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default)
{
    GetDicSearchResult(kernel);
    var prompt = HistoryToText(chatHistory);
    var ollama = new OllamaApiClient("http://127.0.0.1:11434", "qwen2:7b");
    var chat = new Chat(ollama, _ => { });
    sw.Start();
    var history = (await chat.Send(prompt, CancellationToken.None)).ToArray();
    sw.Stop();
    Console.WriteLine($"呼叫耗時:{Math.Round(sw.Elapsed.TotalSeconds,2)}秒");
    var last = history.Last();
    var chatResponse = last.Content;
    try
    {
        JToken jToken = JToken.Parse(chatResponse);
        jToken = ConvertStringToJson(jToken);
        var searchs = DicSearchResult.Values.ToList();
        if (TryFindValues(jToken, ref searchs))
        {
            var firstFunc = searchs.First();
            var funcCallResult = await firstFunc.KernelFunction.InvokeAsync(kernel, firstFunc.FunctionParams);
            chatHistory.AddMessage(AuthorRole.Assistant, chatResponse);
            chatHistory.AddMessage(AuthorRole.Tool, funcCallResult.ToString());
            return await GetChatMessageContentsAsync(chatHistory, kernel: kernel);
        }
        else
        {

        }
    }
    catch(Exception e)
    {

    }
    return new List<ChatMessageContent> { new ChatMessageContent(AuthorRole.Assistant, chatResponse) };
}

這裡我們首先使用SemanticKernel的kernel的函式後設資料透過GetDicSearchResult構建了一個字典,這部分程式碼如下:

public static Dictionary<string, SearchResult> DicSearchResult = new Dictionary<string, SearchResult>();
public static void GetDicSearchResult(Kernel kernel)
{
    DicSearchResult = new Dictionary<string, SearchResult>();
    foreach (var functionMetaData in kernel.Plugins.GetFunctionsMetadata())
    {
        string functionName = functionMetaData.Name;
        if (DicSearchResult.ContainsKey(functionName))
            continue;
        var searchResult = new SearchResult
        {
            FunctionName = functionName,
            KernelFunction = kernel.Plugins.GetFunction(null, functionName)
        };
        functionMetaData.Parameters.ToList().ForEach(x => searchResult.FunctionParams.Add(x.Name, null));
        DicSearchResult.Add(functionName, searchResult);
    }
}

接著使用HistoryToText將歷史對話資訊組裝成一個單一的prompt傳送給模型,大概會組裝成如下內容,其實就是系統提示詞+使用者提示片語合成一個單一文字:

接著我們使用OllamaSharp的SDK提供的OllamaApiClient傳送資訊給模型,等待模型回撥後,從模型回撥的內容中抽取chatResponse,接著我們需要透過一個try catch來處理,當chatResponse可以被正確的解析成標準JToken後,說明模型的回撥是一段json,否則會丟擲異常,代表模型輸出的是一段文字。如果是文字,我們就直接返回模型輸出的內容,如果是json則繼續向下處理,透過一個TryFindValues函式從模型中抽取我們所需要的回撥函式名、引數,並賦值到一個臨時變數中。最後透過SemanticKernel的KernelFunction的InvokeAsync進行真正的函式呼叫,獲取到函式的回撥內容,接著我們需要將模型的原始輸出和回撥內容一同新增到chatHistory後,再度遞迴發起GetChatMessageContentsAsync呼叫,這一次模型就會拿到前一次回撥的城市天氣內容來進行回答了。

第二次回撥前的prompt如下,可以看到模型的輸出雖然是json,但是並沒有規範的格式,不過使用我們的抽取函式還是獲取到了需要的資訊,從而正確的構建了底部的回撥:

透過這一輪迴調再次餵給llm,llm就可以正確的輸出結果了:

以上就是整個文章的內容了,可以看到在這個過程中我們主要做的工作就是透過系統提示詞誘導模型輸出回撥函式json,解析json獲取引數,呼叫本地的函式後再次回撥給模型,這個過程其實有點類似的RAG,只不過RAG是透過使用者的提示詞直接進行近似度搜尋獲取到近似度相關的文字組合到系統提示詞,而函式呼叫給了模型更大的自由度,可以讓模型自行決策是否呼叫函式,從而使本地Agent代理可以實現諸如幫你操控電腦,列印檔案,編寫郵件等等助手性質的功能。

下面是核心部分的程式碼,請大家自取

program.cs:

using ConsoleApp4;
using Microsoft.SemanticKernel.ChatCompletion;
using Microsoft.SemanticKernel;
using Microsoft.Extensions.DependencyInjection;
using System.ComponentModel;
using Newtonsoft.Json.Linq;




await Ollama("我想知道西安今天晚上的天氣情況");



async Task Ollama(string question)
{
    Console.WriteLine($"User> {question}");
    var builder = Kernel.CreateBuilder();
    //這裡我們需要增加剛才我們定義的例項CustomChatCompletionService,有點類似IOC的設計
    builder.Services.AddKeyedSingleton<IChatCompletionService>("ollamaChat", new CustomChatCompletionService());
    //這裡我們需要插入之前定義的外掛
    builder.Plugins.AddFromType<FunctionTest>();
    var kernel = builder.Build();
    //定義一個對話歷史
    ChatHistory history = [];
    //獲取剛才定義的外掛函式的後設資料,用於後續建立prompt
    var plugins = kernel.Plugins.GetFunctionsMetadata();
    //生成函式呼叫提示詞,引導模型根據使用者請求去呼叫函式
    var functionsPrompt = CreateFunctionsMetaObject(plugins);
    //建立系統提示詞,插入剛才生成的提示詞
    var prompt = $"""
                      You have access to the following functions. Use them if required:
                      {functionsPrompt}
                      If function calls are used, ensure the output is in JSON format; otherwise, output should be in text format.
                      """;
    //新增系統提示詞
    history.AddSystemMessage(prompt);
    //建立一個對話服務例項
    var chatCompletionService = kernel.GetRequiredService<IChatCompletionService>();
    //新增使用者的提問
    history.AddUserMessage(question);
    //鏈式執行kernel
    var result = await chatCompletionService.GetChatMessageContentAsync(
        history,
        executionSettings: null,
        kernel: kernel);
    //列印回撥內容
    Console.WriteLine($"Assistant> {result}");
}
static JToken? CreateFunctionsMetaObject(IList<KernelFunctionMetadata> plugins)
{
    if (plugins.Count < 1) return null;
    if (plugins.Count == 1) return CreateFunctionMetaObject(plugins[0]);

    JArray promptFunctions = [];
    foreach (var plugin in plugins)
    {
        var pluginFunctionWrapper = CreateFunctionMetaObject(plugin);
        promptFunctions.Add(pluginFunctionWrapper);
    }

    return promptFunctions;
}
static JObject CreateFunctionMetaObject(KernelFunctionMetadata plugin)
{
    var pluginFunctionWrapper = new JObject()
        {
            { "type", "function" },
        };

    var pluginFunction = new JObject()
        {
            { "name", plugin.Name },
            { "description", plugin.Description },
        };

    var pluginFunctionParameters = new JObject()
        {
            { "type", "object" },
        };
    var pluginProperties = new JObject();
    foreach (var parameter in plugin.Parameters)
    {
        var property = new JObject()
            {
                { "type", parameter.ParameterType?.ToString() },
                { "description", parameter.Description },
            };

        pluginProperties.Add(parameter.Name, property);
    }

    pluginFunctionParameters.Add("properties", pluginProperties);
    pluginFunction.Add("parameters", pluginFunctionParameters);
    pluginFunctionWrapper.Add("function", pluginFunction);

    return pluginFunctionWrapper;
}
public class FunctionTest
{
    [KernelFunction, Description("獲取城市的天氣狀況")]
    public object GetWeather([Description("城市名稱")] string CityName, [Description("查詢時段,值可以是[白天,夜晚]")] string DayPart)
    {
        return new { CityName, DayPart, CurrentCondition = "多雲", LaterCondition = "陰", MinTemperature = 19, MaxTemperature = 23 };
    }
}

CustomChatCompletionService.cs:

using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;
using Newtonsoft.Json.Linq;
using OllamaSharp;
using System.Diagnostics;
using System.Text;

namespace ConsoleApp4
{
    public class CustomChatCompletionService : IChatCompletionService
    {
        public static Dictionary<string, SearchResult> DicSearchResult = new Dictionary<string, SearchResult>();
        public static void GetDicSearchResult(Kernel kernel)
        {
            DicSearchResult = new Dictionary<string, SearchResult>();
            foreach (var functionMetaData in kernel.Plugins.GetFunctionsMetadata())
            {
                string functionName = functionMetaData.Name;
                if (DicSearchResult.ContainsKey(functionName))
                    continue;
                var searchResult = new SearchResult
                {
                    FunctionName = functionName,
                    KernelFunction = kernel.Plugins.GetFunction(null, functionName)
                };
                functionMetaData.Parameters.ToList().ForEach(x => searchResult.FunctionParams.Add(x.Name, null));
                DicSearchResult.Add(functionName, searchResult);
            }
        }
        public IReadOnlyDictionary<string, object?> Attributes => throw new NotImplementedException();
        static Stopwatch sw = new Stopwatch();
        public async Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsync(ChatHistory chatHistory, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default)
        {
            GetDicSearchResult(kernel);
            var prompt = HistoryToText(chatHistory);
            var ollama = new OllamaApiClient("http://127.0.0.1:11434", "qwen2:7b");
            var chat = new Chat(ollama, _ => { });
            sw.Start();
            var history = (await chat.Send(prompt, CancellationToken.None)).ToArray();
            sw.Stop();
            Console.WriteLine($"呼叫耗時:{Math.Round(sw.Elapsed.TotalSeconds,2)}秒");
            var last = history.Last();
            var chatResponse = last.Content;
            try
            {
                JToken jToken = JToken.Parse(chatResponse);
                jToken = ConvertStringToJson(jToken);
                var searchs = DicSearchResult.Values.ToList();
                if (TryFindValues(jToken, ref searchs))
                {
                    var firstFunc = searchs.First();
                    var funcCallResult = await firstFunc.KernelFunction.InvokeAsync(kernel, firstFunc.FunctionParams);
                    chatHistory.AddMessage(AuthorRole.Assistant, chatResponse);
                    chatHistory.AddMessage(AuthorRole.Tool, funcCallResult.ToString());
                    return await GetChatMessageContentsAsync(chatHistory, kernel: kernel);
                }
                else
                {

                }
            }
            catch(Exception e)
            {

            }
            return new List<ChatMessageContent> { new ChatMessageContent(AuthorRole.Assistant, chatResponse) };
        }
        JToken ConvertStringToJson(JToken token)
        {
            if (token.Type == JTokenType.Object)
            {
                // 遍歷物件的每個屬性
                JObject obj = new JObject();
                foreach (JProperty prop in token.Children<JProperty>())
                {
                    obj.Add(prop.Name, ConvertStringToJson(prop.Value));
                }
                return obj;
            }
            else if (token.Type == JTokenType.Array)
            {
                // 遍歷陣列的每個元素
                JArray array = new JArray();
                foreach (JToken item in token.Children())
                {
                    array.Add(ConvertStringToJson(item));
                }
                return array;
            }
            else if (token.Type == JTokenType.String)
            {
                // 嘗試將字串解析為 JSON
                string value = token.ToString();
                try
                {
                    return JToken.Parse(value);
                }
                catch (Exception)
                {
                    // 解析失敗時返回原始字串
                    return token;
                }
            }
            else
            {
                // 其他型別直接返回
                return token;
            }
        }
        bool TryFindValues(JToken token, ref List<SearchResult> searches)
        {
            if (token.Type == JTokenType.Object)
            {
                foreach (var child in token.Children<JProperty>())
                {
                    foreach (var search in searches)
                    {
                        if (child.Value.ToString().ToLower().Equals(search.FunctionName.ToLower()) && search.SearchFunctionNameSucc != true)
                            search.SearchFunctionNameSucc = true;
                        foreach (var par in search.FunctionParams)
                        {
                            if (child.Name.ToLower().Equals(par.Key.ToLower()) && par.Value == null)
                                search.FunctionParams[par.Key] = child.Value.ToString().ToLower();
                        }
                    }
                    if (searches.Any(x => x.SearchFunctionNameSucc == false || x.FunctionParams.Any(x => x.Value == null)))
                        TryFindValues(child.Value, ref searches);
                }
            }
            else if (token.Type == JTokenType.Array)
            {
                foreach (var item in token.Children())
                {
                    if (searches.Any(x => x.SearchFunctionNameSucc == false || x.FunctionParams.Any(x => x.Value == null)))
                        TryFindValues(item, ref searches);
                }
            }
            return searches.Any(x => x.SearchFunctionNameSucc && x.FunctionParams.All(x => x.Value != null));
        }
        public virtual string HistoryToText(ChatHistory history)
        {
            StringBuilder sb = new();
            foreach (var message in history)
            {
                if (message.Role == AuthorRole.User)
                {
                    sb.AppendLine($"User: {message.Content}");
                }
                else if (message.Role == AuthorRole.System)
                {
                    sb.AppendLine($"System: {message.Content}");
                }
                else if (message.Role == AuthorRole.Assistant)
                {
                    sb.AppendLine($"Assistant: {message.Content}");
                }
                else if (message.Role == AuthorRole.Tool)
                {
                    sb.AppendLine($"Tool: {message.Content}");
                }
            }
            return sb.ToString();
        }
        public IAsyncEnumerable<StreamingChatMessageContent> GetStreamingChatMessageContentsAsync(ChatHistory chatHistory, PromptExecutionSettings? executionSettings = null, Kernel? kernel = null, CancellationToken cancellationToken = default)
        {
            throw new NotImplementedException();
        }
    }
    public class SearchResult
    {
        public string FunctionName { get; set; }
        public bool SearchFunctionNameSucc { get; set; }
        public KernelArguments FunctionParams { get; set; } = new KernelArguments();
        public KernelFunction KernelFunction { get; set; }
    }
}

相關文章