找回密码
 立即注册
首页 业界区 业界 张高兴的大模型开发实战:(四)使用 LangGraph 实现多 ...

张高兴的大模型开发实战:(四)使用 LangGraph 实现多智能体应用

连热 5 天前
目录

  • 环境搭建与配置
  • 定义智能体

    • 加载模型
    • 提取关键词
    • 生成回答

  • 连接智能体

    • 定义图的状态
    • 定义节点方法

      • 根据指令路由
      • 生成回答
      • 文件处理
      • 提取关键词
      • 网络搜索

    • 定义图的结构
    • 运行图

  • 运行指南

    • 在控制台中测试程序
    • 使用 Streamlit 构建前端页面


随着大语言模型(LLM)技术的快速发展,人们期望利用 LLM 解决各种复杂问题,在此背景下,构建智能体(Agent)应用受到了广泛关注。用户与 LLM 的交互可以被视为一种 单智能体(Single-Agent) 行为:用户通过提示词(prompt)与通用 LLM 进行对话,LLM 理解问题并提供反馈。然而,单一智能体在处理复杂任务时存在明显局限性,例如需要用户多次引导、缺乏对外部环境的感知能力、对话历史记忆有限等。
试想以下场景:在不同处理阶段调用不同的模型;当 LLM 无法完成任务时,自动查询外部知识库;或者由 LLM 自主纠正生成内容中的幻觉和错误。这些需求如何实现? 多智能体(Multi-Agent) 系统正是解决这类问题的有效工具。通过提示词模板为每个智能体分配角色并规范其行为,多个智能体相互协作,从而完成复杂的任务。
然而,构建多智能体应用并非易事,开发者需要面对智能体设计、通信协议、协调策略等多方面的问题。LangGraph 提供了一种以图(graph)为核心的解决方案,清晰定义了智能体之间的关系与交互规则,并通过内置的通信接口和协调策略,帮助开发者快速构建高效且可扩展的分布式智能系统。
1.png

接下来,我们将通过一个实例展示如何使用 LangGraph 构建一个多智能体应用,并结合 Streamlit 实现用户友好的前端界面。 该应用具备以下功能:

  • 根据对话类型将请求路由到适当的处理节点。
  • 支持联网搜索,获取实时信息。
  • 根据问题和对话历史生成优化的搜索提示词。
  • 支持文件上传与处理。
  • 利用编程专用的 LLM 解决代码相关问题。
  • 基于提供的文档内容,总结生成答案。
2.png

环境搭建与配置

项目结构如下:
  1. .
  2. ├── .streamlit  # Streamlit 配置
  3. │   └── config.toml
  4. ├── chains  # 智能体
  5. │   ├── generate.py
  6. │   ├── models.py
  7. │   └── summary.py
  8. ├── graph   # 图结构
  9. │   ├── graph.py
  10. │   └── graph_state.py
  11. ├── upload_files    # 上传的文件
  12. │   └── .keep
  13. ├── .env   # 环境变量配置
  14. ├── app.py  # Streamlit 应用
  15. ├── main.py # 命令行程序
  16. └── requirements.txt    # 依赖
复制代码
requirements.txt 中列出了程序必要的包,使用命令 pip install -r requirements.txt 安装依赖。
  1. # LangChain 相关包
  2. langchain
  3. langchain-ollama
  4. langchain-chroma
  5. langchain-community
  6. langgraph
  7. chromadb
  8. tavily-python
  9. python-dotenv
  10. # 文档处理相关包
  11. marker-pdf
  12. weasyprint
  13. mammoth
  14. openpyxl
  15. unstructured[all-docs]
  16. libmagic
  17. # Streamlit 相关包
  18. streamlit
  19. streamlit-chat
  20. streamlit-extras
  21. # 文档使用GPU处理时,安装GPU版PyTorch
  22. # use 'pip install -r requirements.txt --proxy=127.0.0.1:23474' to accelerate download speed
  23. # --extra-index-url https://download.pytorch.org/whl/cu124
  24. # torch==2.6.0+cu124
  25. # torchvision==0.21.0+cu124
复制代码
相关的环境变量配置在 .env 文件中,在程序中通过 dotenv 包读取。
  1. TAVILY_API_KEY=tvly-dev-xxxxxx  # Tavily API 密钥
  2. OMP_NUM_THREADS=8   # 设置线程数
复制代码
其中 TAVILY_API_KEY 是 Tavily 的 API 密钥,用于网络搜索服务,需要在 https://app.tavily.com/home 注册并获取,每月有 1000 次的免费额度。
3.png

定义智能体

在 LangChain 中,使用 链(chain) 来定义用户与 LLM 交互的行为,即智能体。链是一个可调用的对象,接收输入并返回输出。在 chains 目录下,定义了两个链:summary.py 和 generate.py,分别用于提取关键词和生成回答。
  1. .
  2. ├── chains  # 智能体
  3. │   ├── generate.py
  4. │   ├── models.py
  5. │   └── summary.py
复制代码
加载模型

在定义智能体之前,需要先定义好加载模型的方法。models.py 文件负责根据提供的模型名称加载相应的模型。
  1. from langchain_ollama import ChatOllama, OllamaEmbeddings
  2. from langchain_core.vectorstores import InMemoryVectorStore
  3. def load_model(model_name: str) -> ChatOllama:
  4.     """
  5.     加载语言模型
  6.     参数:
  7.         model_name (str): 模型名称
  8.     返回:
  9.         ChatOllama实例,用于生成文本和回答问题
  10.     """
  11.     return ChatOllama(model=model_name)
  12.    
  13. def load_embeddings(model_name: str) -> OllamaEmbeddings:
  14.     """
  15.     加载嵌入模型
  16.     参数:
  17.         model_name (str): 模型名称
  18.     返回:
  19.         OllamaEmbeddings实例,用于将文本转换为向量表示
  20.     """
  21.     return OllamaEmbeddings(model=model_name)
  22. def load_vector_store(model_name: str) -> InMemoryVectorStore:
  23.     """
  24.     创建内存向量存储
  25.     参数:
  26.         model_name (str): 用于生成嵌入的模型名称
  27.     返回:
  28.         InMemoryVectorStore实例,用于存储和检索向量化的文本
  29.     """
  30.     embeddings = load_embeddings(model_name)
  31.     return InMemoryVectorStore(embeddings)
复制代码
提取关键词

在 summary.py 文件中,定义了 SummaryChain 类,用于从用户问题和聊天记录中提取关键词,并生成高效的搜索查询。
  1. from langchain.prompts import ChatPromptTemplate
  2. from chains.models import load_model
  3. class SummaryChain:
  4.     """
  5.     一个用于生成搜索查询的类。
  6.     它从用户问题和聊天记录中提取关键词,并生成高效的搜索查询。
  7.     """
  8.     def __init__(self, model_name):
  9.         """
  10.         初始化 SummaryChain 类,并加载指定的语言模型。
  11.         参数:
  12.             model_name (str): 要加载的语言模型的名称。
  13.         """
  14.         self.llm = load_model(model_name)
  15.         self.prompt = ChatPromptTemplate.from_template(
  16.             "You are a professional assistant specializing in extracting keywords from user questions and chat histories. Extract keywords and connect them with spaces to output a efficient and precise search query. Be careful not answer the question directly, just output the search query.\n\nHistories: {history}\n\nQuestion: {question}"
  17.         )
  18.         self.chain = self.prompt | self.llm
  19.     def invoke(self, input_data):
  20.         """
  21.         使用提供的输入数据调用链以生成搜索查询。
  22.         参数:
  23.             input_data (dict): 包含 'history' 和 'question' 键的字典。
  24.         返回:
  25.             str: 链生成的搜索查询。
  26.         """
  27.         return self.chain.invoke(input_data)
复制代码
生成回答

在 generate.py 文件中,定义了 GenerateChain 类,根据用户问题、聊天记录和文档内容生成回答。
  1. from langchain.prompts import ChatPromptTemplate
  2. from chains.models import load_model
  3. class GenerateChain:
  4.     """
  5.     一个用于生成问答任务响应的类。
  6.     它使用语言模型和提示模板来处理输入数据。
  7.     """
  8.     def __init__(self, model_name):
  9.         """
  10.         初始化 GenerateChain 类,并加载指定的语言模型。
  11.         参数:
  12.             model_name (str): 要加载的语言模型的名称。
  13.         """
  14.         self.llm = load_model(model_name)
  15.         self.prompt = ChatPromptTemplate.from_template("You are an assistant for question-answering tasks. Use the following documents or chat histories to answer the question. If the documents or chat histories is empty, answer the question based on your own knowledge. If you don't know the answer, just say that you don't know.\n\nDocuments: {documents}\n\nHistories: {history}\n\nQuestion: {question}")
  16.         self.chain = self.prompt | self.llm
  17.     def invoke(self, input_data):
  18.         """
  19.         使用提供的输入数据调用链以生成响应。
  20.         参数:
  21.             input_data (dict): 包含 'documents'、'history' 和 'question' 键的字典。
  22.         返回:
  23.             str: 链生成的响应。
  24.         """
  25.         return self.chain.invoke(input_data)
复制代码
连接智能体

在 LangGraph 中,智能体之间的连接通过 状态图(graph) 来实现,使用 状态(state) 存储交互的信息。图由节点(node)和边(edge)组成,节点表示智能体,边表示智能体之间的关系。在 graph 目录下定义了两个文件:graph.py 和 graph_state.py。
  1. .
  2. ├── graph   # 图结构
  3. │   ├── graph.py
  4. │   └── graph_state.py
复制代码
定义图的状态

在 graph_state.py 文件中,定义了 GraphState 类,用于存储图的状态信息。
  1. from typing import Literal, Annotated, Optional
  2. from typing_extensions import TypedDict
  3. from langgraph.graph.message import add_messages
  4. class GraphState(TypedDict):
  5.     """
  6.     定义图状态的类型字典。
  7.     用于表示图中的状态信息。
  8.     """
  9.     model_name: str  # 使用的模型名称
  10.     type: Literal["websearch", "file", "chat"]  # 操作类型,包括联网搜索、上传文件和聊天
  11.     messages: Annotated[list, add_messages]  # 消息列表,使用add_messages注解处理消息追加
  12.     documents: Optional[list] = []  # 文档列表,默认为空列表
复制代码
定义节点方法

在 graph.py 文件中,定义了多个方法,表示图的结构和行为,用于处理不同类型的请求。先引入所需要的包。
  1. import os
  2. from langchain.schema import Document
  3. from langchain_core.runnables import RunnableConfig
  4. from langchain_community.document_loaders import TextLoader
  5. from langchain_community.tools.tavily_search import TavilySearchResults
  6. from langchain_text_splitters import MarkdownHeaderTextSplitter, RecursiveCharacterTextSplitter
  7. from langgraph.graph.state import StateGraph, CompiledStateGraph, END
  8. from langgraph.checkpoint.memory import MemorySaver
  9. from marker.converters.pdf import PdfConverter
  10. from marker.models import create_model_dict
  11. from marker.output import text_from_rendered
  12. from graph.graph_state import GraphState
  13. from chains.summary import SummaryChain
  14. from chains.generate import GenerateChain
复制代码
根据指令路由

route_question() 方法根据 GraphState 类中的操作类型将请求路由到相应的处理节点。
  1. def route_question(state: GraphState) -> str:
  2.     """
  3.     根据操作类型路由到相应的处理节点。
  4.     参数:
  5.         state (GraphState): 当前图的状态
  6.     返回:
  7.         str: 下一个要调用的节点名称
  8.     """
  9.     print("--- ROUTE QUESTION ---")
  10.     if state['type'] == 'websearch':
  11.         print("--- ROUTE QUESTION TO EXTRACT KEYWORDS ---")
  12.         return "extract_keywords"
  13.     if state['type'] == 'file':
  14.         print("--- ROUTE QUESTION TO FILE PROCESS ---")
  15.         return "file_process"
  16.     elif state['type'] == 'chat':
  17.         print("--- ROUTE QUESTION TO GENERATE ---")
  18.         return "generate"
复制代码
当然,也可以将路由交给 LLM 决定,只需要写好相应的提示词即可,例如下面的提示词将由 LLM 决定是进行知识库查询还是网络搜索。
  1. from langchain_core.output_parsers import JsonOutputParser
  2. prompt = ChatPromptTemplate.from_template("You are an expert at routing a user question to a vectorstore or web search. Use the vectorstore for questions on LangChain and LangGraph. You do not need to be stringent with the keywords in the question related to these topics. Otherwise, use web-search. Give a binary choice 'web_search' or 'vectorstore' based on the question. Return the a JSON with a single key 'datasource' and no premable or explaination. Question to route: {question}")
  3. router = prompt | llm | JsonOutputParser()
  4. source = router.invoke({"question": question})
  5. if source['datasource'] == 'web_search':
  6.     # TODO: route to web search
  7. elif source['datasource'] == 'vectorstore':
  8.     # TODO: route to vectorstore
复制代码
生成回答

generate() 方法根据用户问题、聊天记录和文档内容生成回答。
  1. def generate(state: GraphState) -> GraphState:
  2.     """
  3.     根据文档和对话历史生成答案。
  4.     参数:
  5.         state (GraphState): 当前图的状态
  6.     返回:
  7.         state (GraphState): 返回添加了LLM生成内容的新状态
  8.     """
  9.     print("--- GENERATE ---")
  10.     chain = GenerateChain(state["model_name"])
  11.     messages = state["messages"]
  12.     state["messages"] = chain.invoke({"question": messages[-1].content, "history": messages[:-1], "documents": state["documents"]})
  13.     return state
复制代码
文件处理

file_process() 方法处理上传的文件,提取文本内容并进行词嵌入(embedding),然后将向量存储至内存数据库中。config 是一个字典,存储 LLM 运行时的配置参数,会在调用 LLM 时传入。
  1. def file_process(state: GraphState, config: RunnableConfig) -> GraphState:
  2.     """
  3.     处理文件。
  4.     参数:
  5.         state (GraphState): 当前图的状态
  6.         config (RunnableConfig): 可运行配置
  7.     返回:
  8.         state (GraphState): 返回图状态,将文档添加 config 中的向量存储
  9.     """
  10.     print("--- FILE PROCESS ---")
  11.     vector_store = config["configurable"]["vectorstore"]
  12.     for doc in state["documents"]:
  13.         file_path: str = doc.page_content
  14.         if os.path.exists(file_path):
  15.             split_docs: list[Document] = None
  16.             if file_path.endswith(".txt") or file_path.endswith(".md"):
  17.                 # 处理文本或Markdown文件
  18.                 docs = TextLoader(file_path, autodetect_encoding=True).load()
  19.                 # 文本分割
  20.                 splitter = RecursiveCharacterTextSplitter(separators=["\n\n", "\n", " ", ".", ",", "\u200B", "\uff0c", "\u3001", "\uff0e", "\u3002", ""], chunk_size=1000, chunk_overlap=100, add_start_index=True)
  21.                 split_docs = splitter.split_documents(docs)
  22.             else:
  23.                 # 使用 marker-pdf 处理其他文件
  24.                 converter = PdfConverter(artifact_dict=create_model_dict())
  25.                 rendered = converter(file_path)
  26.                 docs, _, _ = text_from_rendered(rendered)
  27.                 splitter = MarkdownHeaderTextSplitter([("#", "Header 1"), ("##", "Header 2"), ("###", "Header 3")], strip_headers = False)
  28.                 split_docs = splitter.split_text(docs)
  29.             # 将处理后的文档添加到向量存储中
  30.             vector_store.add_documents(split_docs)
  31.     return state
复制代码
提取关键词

extract_keywords() 方法从用户问题和聊天记录中提取关键词,并生成高效的搜索查询。
  1. def extract_keywords(state: GraphState, config: RunnableConfig) -> GraphState:
  2.     """
  3.     从问题中提取关键词。
  4.     参数:
  5.         state (GraphState): 当前图的状态
  6.         config (RunnableConfig): 可运行配置
  7.     返回:
  8.         state (GraphState): 返回添加了提取关键词的新状态
  9.     """
  10.     print("--- EXTRACT KEYWORDS ---")
  11.     chain = SummaryChain(state["model_name"])
  12.     messages = state["messages"]
  13.     query = chain.invoke({"question": messages[-1].content, "history": messages[:-1]})
  14.     print(query.content)
  15.     if state["type"] == "websearch":
  16.         # 将生成的搜索查询添加到消息列表中,下一个节点将会使用
  17.         state["messages"] = query
  18.     elif state["type"] == "file":
  19.         # 使用生成的搜索查询在向量数据库中搜索
  20.         docs = config["configurable"]["vectorstore"].max_marginal_relevance_search(query.content)
  21.         state["documents"] = docs
  22.     return state
复制代码
对于“上传文件”,“提取关键词”时已经进行了查询处理,可以直接进行“生成回答”;对于“网络搜索”,“提取关键词”进行搜索后,才能进行“生成回答”。执行路径不同,还需要进行判断。
  1. def decide_to_generate(state: GraphState) -> str:
  2.     """
  3.     决定是进行网络搜索还是直接生成回答。
  4.     参数:
  5.         state (GraphState): 当前图的状态
  6.     返回:
  7.         str: 下一个要调用的节点名称
  8.     """
  9.     if state["type"] == "websearch":
  10.         print("--- DECIDE TO WEB SEARCH ---")
  11.         return "websearch"
  12.     elif state["type"] == "file":
  13.         print("--- DECIDE TO GENERATE ---")
  14.         return "generate"
复制代码
网络搜索

web_search() 方法使用 Tavily API 进行网络搜索,获取实时信息。
  1. def web_search(state: GraphState) -> GraphState:
  2.     """
  3.     基于问题进行网络搜索。
  4.     参数:
  5.         state (GraphState): 当前图的状态
  6.     返回:
  7.         state (GraphState): 返回添加了网络搜索结果的新状态
  8.     """
  9.     print("--- WEB SEARCH ---")
  10.     web_search_tool = TavilySearchResults(k=3)
  11.     documents = state["documents"]
  12.     try:
  13.         docs = web_search_tool.invoke({"query": state["messages"][-1].content})
  14.         web_results = "\n".join([d["content"] for d in docs])
  15.         web_results = Document(page_content=web_results)
  16.         documents.append(web_results)
  17.         state["documents"] = documents
  18.     except:
  19.         pass
  20.     return state
复制代码
定义图的结构

在 LangGraph 中,图的边分为普通边和条件边。普通边表示两个节点之间的直接连接,而条件边则根据特定条件决定是否连接两个节点。在 create_graph() 方法中定义了图的结构:add_node() 方法将定义的节点方法添加到图中;add_edge() 方法定义了节点之间的连接关系,也就是普通边;add_conditional_edges() 方法定义了条件边的连接关系;set_conditional_entry_point() 方法定义了图的条件入口节点。
4.png
  1. def create_graph() -> CompiledStateGraph:
  2.     """
  3.     创建并配置状态图工作流。
  4.     返回:
  5.         CompiledStateGraph: 编译好的状态图
  6.     """
  7.     workflow = StateGraph(GraphState)
  8.     # 添加节点
  9.     workflow.add_node("websearch", web_search)
  10.     workflow.add_node("extract_keywords", extract_keywords)
  11.     workflow.add_node("file_process", file_process)
  12.     workflow.add_node("generate", generate)
  13.     # 添加边
  14.     workflow.set_conditional_entry_point(
  15.         route_question,
  16.         {
  17.             "extract_keywords": "extract_keywords",
  18.             "generate": "generate",
  19.             "file_process": "file_process",
  20.         },
  21.     )
  22.     workflow.add_edge("file_process", "extract_keywords")
  23.     workflow.add_conditional_edges(
  24.         "extract_keywords",
  25.         decide_to_generate,
  26.         {
  27.             "websearch": "websearch",
  28.             "generate": "generate",
  29.         },
  30.     )
  31.     workflow.add_edge("websearch", "generate")
  32.     workflow.add_edge("generate", END)
  33.     # 创建图,并使用 `MemorySaver()` 在内存中保存状态
  34.     return workflow.compile(checkpointer=MemorySaver())
复制代码
运行图

最后通过 stream_graph_updates() 方法运行图,并流式返回结果内容。
  1. def stream_graph_updates(graph: CompiledStateGraph, user_input: GraphState, config: dict):
  2.     """
  3.     流式处理图更新并返回最终结果。
  4.     参数:
  5.         graph (CompiledStateGraph): 编译好的状态图
  6.         user_input (GraphState): 用户输入的状态
  7.         config (dict): 配置字典
  8.     返回:
  9.         generator: 生成器对象,逐步返回图更新的内容
  10.     """
  11.     for chunk, _ in graph.stream(user_input, config, stream_mode="messages"):
  12.         yield chunk.content
复制代码
运行指南

在控制台中测试程序

在 main.py 文件中,定义了一个命令行程序,用户可以通过输入问题与智能体进行交互。
  1. import uuid
  2. from dotenv import load_dotenv
  3. from langchain.schema import Document
  4. from langchain_core.messages import AIMessage, HumanMessage
  5. from chains.models import load_vector_store
  6. from graph.graph import create_graph, stream_graph_updates, GraphState
  7. def main():
  8.     # langchain.debug = True  # 启用langchain调试模式,可以获得如完整提示词等信息
  9.     load_dotenv(verbose=True)  # 加载环境变量配置
  10.     # 创建状态图以及对话相关的设置
  11.     config = {"configurable": {"thread_id": uuid.uuid4().hex, "vectorstore": load_vector_store("nomic-embed-text")}}  
  12.     state = GraphState(
  13.         model_name="qwen2.5:7b",
  14.         type="chat",
  15.         documents=[Document(page_content="upload_files/test.pdf")],
  16.     )
  17.     graph = create_graph()
  18.     # 对话
  19.     while True:
  20.         user_input = input("User: ")
  21.         if user_input.lower() in ["exit", "quit"]:
  22.             break
  23.         state["messages"] = HumanMessage(user_input)
  24.         # 流式获取AI的回复
  25.         for answer in stream_graph_updates(graph, state, config):
  26.             print(answer, end="")
  27.         print()
  28.     # 打印对话历史
  29.     print("\nHistory: ")
  30.     for message in graph.get_state(config).values["messages"]:
  31.         if isinstance(message, AIMessage):
  32.             prefix = "AI"
  33.         else:
  34.             prefix = "User"
  35.         print(f"{prefix}: {message.content}")
  36. if __name__ == "__main__":
  37.     main()
复制代码
使用 Streamlit 构建前端页面

在 app.py 文件中,使用 Streamlit 构建了一个简单的前端界面,用户可以通过输入框与智能体进行交互。完整的程序代码如下:
5.png

[code]import uuidimport datetimefrom dotenv import load_dotenvfrom langchain.schema import Documentimport streamlit as stfrom streamlit_extras.bottom_container import bottomfrom chains.models import load_vector_storefrom graph.graph import create_graph, stream_graph_updates, GraphState# 设置上传文件的存储路径file_path = "upload_files/"# 加载环境变量load_dotenv(verbose=True)def upload_pdf(file):    """保存上传的文件并返回文件路径"""    with open(file_path + file.name, "wb") as f:        f.write(file.getbuffer())        return file_path + file.name# 设置页面配置信息st.set_page_config(    page_title="AI-Powerwd Assistant",    page_icon="
来源:程序园用户自行投稿发布,如果侵权,请联系站长删除
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!
您需要登录后才可以回帖 登录 | 立即注册