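This module lets an LLM query a database and answer our questions based on the results.
Architecture
At a high level, SQL chains and agents follow these steps:
- Convert the question to a SQL query: the LLM turns the user input into a SQL query
- Execute the SQL query: run the generated statement against the database
- Answer the question: the LLM responds to the user based on the query results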
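The three steps above can be sketched without any LLM at all. The following is a minimal illustration that uses only Python's built-in sqlite3 module; the two `fake_llm_*` functions are hypothetical stand-ins for real model calls, and the toy Employee table is an assumption made for the example, not the Chinook schema.

```python
import sqlite3


def fake_llm_write_query(question: str) -> str:
    """Step 1 stand-in: a real LLM would generate this SQL from the question."""
    return "SELECT COUNT(*) FROM Employee"


def execute_sql(conn: sqlite3.Connection, sql: str) -> list:
    """Step 2: execute the SQL statement and return all rows."""
    return conn.execute(sql).fetchall()


def fake_llm_answer(question: str, rows: list) -> str:
    """Step 3 stand-in: a real LLM would phrase the answer from the rows."""
    return f"There are {rows[0][0]} employees."


def answer_question(conn: sqlite3.Connection, question: str) -> str:
    sql = fake_llm_write_query(question)
    rows = execute_sql(conn, sql)
    return fake_llm_answer(question, rows)


# Toy in-memory database (hypothetical data, for illustration only).
pipeline_conn = sqlite3.connect(":memory:")
pipeline_conn.execute(
    "CREATE TABLE Employee (EmployeeId INTEGER PRIMARY KEY, LastName TEXT)"
)
pipeline_conn.executemany(
    "INSERT INTO Employee (LastName) VALUES (?)",
    [("Adams",), ("Edwards",), ("Peacock",)],
)

pipeline_answer = answer_question(pipeline_conn, "How many employees are there")
print(pipeline_answer)
```

The rest of this article replaces the two stand-ins with LangChain components backed by a real model.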
Environment setup
Install the required packages:
pip install --upgrade --quiet langchain langchain-community langchain-openai
This article uses OpenAI models by default.
import getpass
import os
os.environ["OPENAI_API_KEY"] = getpass.getpass()
# Uncomment the below to use LangSmith. Not required.
# os.environ["LANGCHAIN_API_KEY"] = getpass.getpass()
# os.environ["LANGCHAIN_TRACING_V2"] = "true"
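The examples below all connect to the Chinook database through SQLite. Follow these installation steps to create a Chinook.db file in the current directory:
- Save this file as Chinook_Sqlite.sql
- Run sqlite3 Chinook.db
- Run .read Chinook_Sqlite.sql
- Test with SELECT * FROM Artist LIMIT 10;
Now that Chinook.db is in our directory, we can interact with it using the SQLAlchemy-based SQLDatabase class: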
from langchain_community.utilities import SQLDatabase
db = SQLDatabase.from_uri("sqlite:///Chinook.db")
print(db.dialect)
print(db.get_usable_table_names())
db.run("SELECT * FROM Artist LIMIT 10;")
sqlite
['Album', 'Artist', 'Customer', 'Employee', 'Genre', 'Invoice', 'InvoiceLine', 'MediaType', 'Playlist', 'PlaylistTrack', 'Track']
"[(1, 'AC/DC'), (2, 'Accept'), (3, 'Aerosmith'), (4, 'Alanis Morissette'), (5, 'Alice In Chains'), (6, 'Antônio Carlos Jobim'), (7, 'Apocalyptica'), (8, 'Audioslave'), (9, 'BackBeat'), (10, 'Billy Cobham')]"
Chain
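Let's create a simple chain that takes a question, converts it to a SQL query, executes the query, and uses the result to answer the original question.
Converting the question to a SQL query
The first step in a SQL chain or agent is to take the user input and convert it to a SQL query. LangChain ships with a built-in chain for this purpose: create_sql_query_chain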
from langchain.chains import create_sql_query_chain
from langchain_openai import ChatOpenAI
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0)
chain = create_sql_query_chain(llm, db)
response = chain.invoke({"question": "How many employees are there"})
response
'SELECT COUNT(*) FROM Employee'
db.run(response)
'[(8,)]'
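We can look at the LangSmith trace to better understand what this chain is doing, or inspect the chain's prompt directly. The prompt:
- Is dialect-specific. In this case it explicitly references SQLite.
- Contains the definitions of all available tables.
- Includes three example rows for each table.
This technique is inspired by papers suggesting that showing example rows and being explicit about table definitions improves performance. The full prompt can also be inspected directly.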
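In LangChain itself, the quickstart listed in the references prints the real prompt with `chain.get_prompts()[0].pretty_print()`. To make the three properties above concrete without a model or API key, here is a minimal sketch that assembles a prompt of the same shape using only sqlite3. The wording of the prompt, the helper names `describe_tables` and `build_prompt`, and the toy Artist table are all assumptions for illustration; this is not LangChain's actual prompt.

```python
import sqlite3


def describe_tables(conn: sqlite3.Connection, sample_rows: int = 3) -> str:
    """Return each table's CREATE statement followed by a few sample rows."""
    tables = conn.execute(
        "SELECT name, sql FROM sqlite_master "
        "WHERE type = 'table' AND name NOT LIKE 'sqlite_%' ORDER BY name"
    ).fetchall()
    parts = []
    for name, create_sql in tables:
        cursor = conn.execute(f'SELECT * FROM "{name}" LIMIT ?', (sample_rows,))
        header = "\t".join(col[0] for col in cursor.description)
        rows = ["\t".join(str(value) for value in row) for row in cursor.fetchall()]
        sample = "\n".join([header, *rows])
        parts.append(
            f"{create_sql}\n/*\n{sample_rows} rows from {name} table:\n{sample}\n*/"
        )
    return "\n\n".join(parts)


def build_prompt(conn: sqlite3.Connection, question: str, dialect: str = "sqlite") -> str:
    """Assemble a dialect-specific prompt (illustrative wording only)."""
    return (
        f"You are a {dialect} expert. Given an input question, write a "
        f"syntactically correct {dialect} query.\n\n"
        f"Only use the following tables:\n{describe_tables(conn)}\n\n"
        f"Question: {question}"
    )


# Toy database with four rows, so the three-row sample limit is visible.
demo_conn = sqlite3.connect(":memory:")
demo_conn.execute("CREATE TABLE Artist (ArtistId INTEGER PRIMARY KEY, Name TEXT)")
demo_conn.executemany(
    "INSERT INTO Artist (Name) VALUES (?)",
    [("AC/DC",), ("Accept",), ("Aerosmith",), ("Alanis Morissette",)],
)

demo_prompt = build_prompt(demo_conn, "How many artists are there?")
print(demo_prompt)
```

Only the first three artists appear in the printed prompt, which keeps the prompt short while still showing the model what the data looks like.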
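Answering the question
Now that we have a way to generate and execute queries automatically, we only need to combine the original question with the SQL query result to produce a final answer. We do this by passing the question and the result to the LLM once more.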
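from operator import itemgetter
# QuerySQLDataBaseTool executes a SQL string against db
# (newer langchain-community releases may spell it QuerySQLDatabaseTool).
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnablePassthrough
# write_query turns the question into SQL; execute_query runs that SQL.
write_query = create_sql_query_chain(llm, db)
execute_query = QuerySQLDataBaseTool(db=db)
answer_prompt = PromptTemplate.from_template(
    """Given the following user question, corresponding SQL query, and SQL result, answer the user question.
Question: {question}
SQL Query: {query}
SQL Result: {result}
Answer: """
)
answer = answer_prompt | llm | StrOutputParser()
# Add "query", then "result", to the input dict before generating the answer.
chain = (
    RunnablePassthrough.assign(query=write_query).assign(
        result=itemgetter("query") | execute_query
    )
    | answer
)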
chain.invoke({"question": "How many employees are there"})
'There are 8 employees.'
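Next steps
For more complex query generation, we may want to create few-shot prompts or add a query-checking step. To learn more about these advanced techniques, see:
- Prompting strategies: advanced prompt engineering techniques.
- Query checking: adding query validation and error handling.
- Large databases: techniques for working with large databases.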
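As a rough illustration of what a query-checking step can look like, the sketch below validates a generated query on the database side before running it. This is a swapped-in technique, not the LLM-based checker from the LangChain guides: it allows only SELECT statements and asks SQLite to compile the statement with EXPLAIN, which surfaces syntax errors and unknown tables without returning data. The helper name `check_query` and the toy table are assumptions.

```python
import sqlite3


def check_query(conn: sqlite3.Connection, sql: str) -> tuple:
    """Return (ok, message) for a candidate query without running it for real."""
    stripped = sql.strip()
    if not stripped.lower().startswith("select"):
        return False, "only SELECT statements are allowed"
    try:
        # EXPLAIN makes SQLite compile the statement; errors surface here.
        conn.execute(f"EXPLAIN {stripped}")
    except (sqlite3.Error, sqlite3.Warning) as exc:
        return False, str(exc)
    return True, "ok"


# Toy database (hypothetical data, for illustration only).
check_conn = sqlite3.connect(":memory:")
check_conn.execute(
    "CREATE TABLE Employee (EmployeeId INTEGER PRIMARY KEY, LastName TEXT)"
)

good = check_query(check_conn, "SELECT COUNT(*) FROM Employee")
typo = check_query(check_conn, "SELECT COUNT(*) FROM Employes")
unsafe = check_query(check_conn, "DROP TABLE Employee")
print(good, typo, unsafe)
```

A check like this is cheap and can run before every execution; the error message it returns can also be fed back to the model when asking for a corrected query.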
Agent
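LangChain's SQL Agent offers a more flexible way of interacting with SQL databases. Its main advantages are:
- Answering questions based on the database's schema and content: the agent can answer questions about the schema (for example, describing a specific table) as well as questions about the data itself.
- Recovering from errors: the agent runs the generated query, and if it fails, catches the traceback and regenerates the query correctly.
- Handling questions that depend on multiple queries: for questions that need several steps, the agent can run as many dependent queries as required.
- Saving tokens: the agent only considers the schema of the relevant tables, which avoids spending tokens on the rest.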
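The error-recovery behaviour described in the list above can be sketched as a simple retry loop: run the query, and on failure hand the error message back to the query writer. This is a minimal illustration using sqlite3 only; `run_with_retry` and `fake_write_query` are hypothetical names, and the fake writer stands in for an LLM call that would see the error text in its prompt.

```python
import sqlite3


def run_with_retry(conn, question, write_query, max_attempts=3):
    """Call write_query(question, error) until the SQL it returns executes."""
    error = None
    for _ in range(max_attempts):
        sql = write_query(question, error)
        try:
            return sql, conn.execute(sql).fetchall()
        except sqlite3.Error as exc:
            # Keep the error text so the next attempt can take it into account.
            error = str(exc)
    raise RuntimeError(f"no valid query after {max_attempts} attempts: {error}")


def fake_write_query(question, error):
    """Stand-in for an LLM: wrong table name first, corrected after an error."""
    if error is None:
        return "SELECT COUNT(*) FROM Employees"
    return "SELECT COUNT(*) FROM Employee"


# Toy database (hypothetical data, for illustration only).
retry_conn = sqlite3.connect(":memory:")
retry_conn.execute(
    "CREATE TABLE Employee (EmployeeId INTEGER PRIMARY KEY, LastName TEXT)"
)
retry_conn.executemany(
    "INSERT INTO Employee (LastName) VALUES (?)", [("Adams",), ("Edwards",)]
)

final_sql, final_rows = run_with_retry(
    retry_conn, "How many employees are there", fake_write_query
)
print(final_sql, final_rows)
```

The first attempt fails with a "no such table" error, and the second attempt succeeds. A real agent does the same thing implicitly, because the tool's error output becomes part of its next reasoning step.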
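The agent is initialized with the create_sql_agent function. It includes the SQLDatabaseToolkit, which provides tools to:
- Create and execute queries
- Check query syntax
- Retrieve table descriptions
- ... and more
With the SQL Agent and the built-in SQLDatabaseToolkit, developers can interact with a SQL database more flexibly, whether they are running complex queries or working through multi-step database tasks. This is most valuable for applications that talk to a database frequently, especially when queries must be generated dynamically or the interaction logic is complex.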
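To give a feel for the kind of tools such a toolkit exposes, here is a minimal sketch of three plain functions over sqlite3: one lists tables, one returns table definitions, and one runs a query. The function names and the `TOOLS` registry are hypothetical and are not the SQLDatabaseToolkit API. One design choice worth noting is that `run_query` returns the error text instead of raising, so that an agent loop can read the message and try again.

```python
import sqlite3


def list_tables(conn: sqlite3.Connection) -> str:
    """Return a comma-separated list of user tables."""
    rows = conn.execute(
        "SELECT name FROM sqlite_master "
        "WHERE type = 'table' AND name NOT LIKE 'sqlite_%' ORDER BY name"
    ).fetchall()
    return ", ".join(name for (name,) in rows)


def table_schema(conn: sqlite3.Connection, table_names: str) -> str:
    """Return CREATE statements for a comma-separated list of table names."""
    parts = []
    for name in (n.strip() for n in table_names.split(",")):
        row = conn.execute(
            "SELECT sql FROM sqlite_master WHERE type = 'table' AND name = ?",
            (name,),
        ).fetchone()
        parts.append(row[0] if row else f"Error: table {name} not found")
    return "\n\n".join(parts)


def run_query(conn: sqlite3.Connection, sql: str) -> str:
    """Run a query and return the rows as text, or the error message."""
    try:
        return str(conn.execute(sql).fetchall())
    except sqlite3.Error as exc:
        return f"Error: {exc}"


# A registry an agent loop could dispatch on by tool name.
TOOLS = {"list_tables": list_tables, "table_schema": table_schema, "run_query": run_query}

# Toy database (hypothetical data, for illustration only).
tool_conn = sqlite3.connect(":memory:")
tool_conn.execute("CREATE TABLE Customer (CustomerId INTEGER PRIMARY KEY, Country TEXT)")
tool_conn.execute(
    "CREATE TABLE Invoice (InvoiceId INTEGER PRIMARY KEY, CustomerId INTEGER, Total NUMERIC)"
)
tool_conn.executemany("INSERT INTO Customer VALUES (?, ?)", [(1, "Brazil"), (2, "Germany")])
tool_conn.executemany("INSERT INTO Invoice VALUES (?, ?, ?)", [(1, 1, 4), (2, 2, 2), (3, 1, 1)])

toolkit_tables = TOOLS["list_tables"](tool_conn)
toolkit_result = TOOLS["run_query"](
    tool_conn,
    "SELECT c.Country, SUM(i.Total) AS TotalSales FROM Invoice i "
    "JOIN Customer c ON i.CustomerId = c.CustomerId "
    "GROUP BY c.Country ORDER BY TotalSales DESC",
)
print(toolkit_tables)
print(toolkit_result)
```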
Initializing the agent
from langchain_community.agent_toolkits import create_sql_agent
agent_executor = create_sql_agent(llm, db=db, agent_type="openai-tools", verbose=True)
agent_executor.invoke(
{
"input": "List the total sales per country. Which country's customers spent the most?"
}
)
> Entering new AgentExecutor chain...
Invoking: `sql_db_list_tables` with `{}`
Album, Artist, Customer, Employee, Genre, Invoice, InvoiceLine, MediaType, Playlist, PlaylistTrack, Track
Invoking: `sql_db_schema` with `Invoice,Customer`
CREATE TABLE "Customer" (
"CustomerId" INTEGER NOT NULL,
"FirstName" NVARCHAR(40) NOT NULL,
"LastName" NVARCHAR(20) NOT NULL,
"Company" NVARCHAR(80),
"Address" NVARCHAR(70),
"City" NVARCHAR(40),
"State" NVARCHAR(40),
"Country" NVARCHAR(40),
"PostalCode" NVARCHAR(10),
"Phone" NVARCHAR(24),
"Fax" NVARCHAR(24),
"Email" NVARCHAR(60) NOT NULL,
"SupportRepId" INTEGER,
PRIMARY KEY ("CustomerId"),
FOREIGN KEY("SupportRepId") REFERENCES "Employee" ("EmployeeId")
)
/*
3 rows from Customer table:
CustomerId FirstName LastName Company Address City State Country PostalCode Phone Fax Email SupportRepId
1 Luís Gonçalves Embraer - Empresa Brasileira de Aeronáutica S.A. Av. Brigadeiro Faria Lima, 2170 São José dos Campos SP Brazil 12227-000 +55 (12) 3923-5555 +55 (12) 3923-5566 luisg@embraer.com.br 3
2 Leonie Köhler None Theodor-Heuss-Straße 34 Stuttgart None Germany 70174 +49 0711 2842222 None leonekohler@surfeu.de 5
3 François Tremblay None 1498 rue Bélanger Montréal QC Canada H2G 1A7 +1 (514) 721-4711 None ftremblay@gmail.com 3
*/
CREATE TABLE "Invoice" (
"InvoiceId" INTEGER NOT NULL,
"CustomerId" INTEGER NOT NULL,
"InvoiceDate" DATETIME NOT NULL,
"BillingAddress" NVARCHAR(70),
"BillingCity" NVARCHAR(40),
"BillingState" NVARCHAR(40),
"BillingCountry" NVARCHAR(40),
"BillingPostalCode" NVARCHAR(10),
"Total" NUMERIC(10, 2) NOT NULL,
PRIMARY KEY ("InvoiceId"),
FOREIGN KEY("CustomerId") REFERENCES "Customer" ("CustomerId")
)
/*
3 rows from Invoice table:
InvoiceId CustomerId InvoiceDate BillingAddress BillingCity BillingState BillingCountry BillingPostalCode Total
1 2 2009-01-01 00:00:00 Theodor-Heuss-Straße 34 Stuttgart None Germany 70174 1.98
2 4 2009-01-02 00:00:00 Ullevålsveien 14 Oslo None Norway 0171 3.96
3 8 2009-01-03 00:00:00 Grétrystraat 63 Brussels None Belgium 1000 5.94
*/
Invoking: `sql_db_query` with `SELECT c.Country, SUM(i.Total) AS TotalSales FROM Invoice i JOIN Customer c ON i.CustomerId = c.CustomerId GROUP BY c.Country ORDER BY TotalSales DESC LIMIT 10;`
responded: To list the total sales per country, I can query the "Invoice" and "Customer" tables. I will join these tables on the "CustomerId" column and group the results by the "BillingCountry" column. Then, I will calculate the sum of the "Total" column to get the total sales per country. Finally, I will order the results in descending order of the total sales.
Here is the SQL query:
```sql
SELECT c.Country, SUM(i.Total) AS TotalSales
FROM Invoice i
JOIN Customer c ON i.CustomerId = c.CustomerId
GROUP BY c.Country
ORDER BY TotalSales DESC
LIMIT 10;
```
Now, I will execute this query to get the total sales per country.
[('USA', 523.0600000000003), ('Canada', 303.9599999999999), ('France', 195.09999999999994), ('Brazil', 190.09999999999997), ('Germany', 156.48), ('United Kingdom', 112.85999999999999), ('Czech Republic', 90.24000000000001), ('Portugal', 77.23999999999998), ('India', 75.25999999999999), ('Chile', 46.62)]The total sales per country are as follows:
1. USA: $523.06
2. Canada: $303.96
3. France: $195.10
4. Brazil: $190.10
5. Germany: $156.48
6. United Kingdom: $112.86
7. Czech Republic: $90.24
8. Portugal: $77.24
9. India: $75.26
10. Chile: $46.62
To answer the second question, the country whose customers spent the most is the USA, with a total sales of $523.06.
> Finished chain.
Example: connecting Qwen1.5 to Hive and querying a table
pyhive and xinference need to be installed beforehand.
from langchain_community.utilities import SQLDatabase
db = SQLDatabase.from_uri("hive://user:password@hive-server.example.com:10000/mydatabase")
print(db.dialect)
print(db.get_usable_table_names())
db.run("SELECT * FROM dim_static.dim_loan_part_number_df LIMIT 10")
Serving Qwen1.5 with xinference
from langchain.chains import create_sql_query_chain
from langchain_community.llms import Xinference
# model_uid is the UID of the model launched on the xinference server.
llm = Xinference(server_url="http://localhost:9997", model_uid="Qwen1.5-14B-Chat", temperature=0.3)
chain = create_sql_query_chain(llm, db)
response = chain.invoke({"question": "How many employees are there"})
response
db.run(response)
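Depending on the model and library version, a chat-tuned model may return more than bare SQL, for example a Markdown code fence or a "SQLQuery:" prefix, and passing that text straight to db.run would fail. The helper below is a minimal sketch of cleaning such a response first. The name `extract_sql` and the specific patterns it strips are assumptions; adjust them to whatever the model actually emits.

```python
import re

# Three backticks, built programmatically to keep this example readable.
FENCE = "`" * 3
FENCED_BLOCK = re.compile(
    FENCE + r"(?:sql)?\s*(.*?)" + FENCE, re.DOTALL | re.IGNORECASE
)


def extract_sql(text: str) -> str:
    """Pull a SQL statement out of a possibly chatty model response."""
    match = FENCED_BLOCK.search(text)
    if match:
        # Keep only the contents of the first fenced block.
        text = match.group(1)
    # Drop a leading "SQLQuery:" label if the model echoed one.
    text = re.sub(r"^\s*SQLQuery:\s*", "", text, flags=re.IGNORECASE)
    return text.strip()


fenced_reply = "Here is the query:\n" + FENCE + "sql\nSELECT COUNT(*) FROM t;\n" + FENCE
print(extract_sql(fenced_reply))
print(extract_sql("SQLQuery: SELECT COUNT(*) FROM t"))
```

With a helper like this, the last step would become db.run(extract_sql(response)).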
References
https://python.langchain.com/docs/use_cases/sql/quickstart