# 1.3 节点(Node)与边(Edge)编程

本节导读:通过本节学习,你将掌握LangGraph中节点和边的编程技巧,学会设计高效的执行单元和灵活的流转路径,能够构建出复杂的智能体工作流。

## 学习目标

- 理解节点和边的概念与作用
- 掌握节点的创建方法和编程模式
- 学会使用不同类型的边(普通边、条件边、跳跃边)
- 理解节点间的状态传递机制
- 能够设计高效的节点协作模式

## 核心概念

### 节点(Nodes)

节点是LangGraph中的基本执行单元,每个节点包含具体的业务逻辑。节点的设计质量直接影响整个智能体的性能和可维护性。

### 边(Edges)

边定义了节点间的执行路径,控制数据在节点间的流转。
本节导读:通过本节学习,你将掌握LangGraph中节点和边的编程技巧,学会设计高效的执行单元和灵活的流转路径,能够构建出复杂的智能体工作流。
节点是LangGraph中的基本执行单元,每个节点包含具体的业务逻辑。节点的设计质量直接影响整个智能体的性能和可维护性。
from typing import Dict, Any
from langgraph.graph import StateGraph


# Basic shape of a node: receive the current state, return a partial update.
def business_node(state: Dict[str, Any]) -> Dict[str, Any]:
    """Run one unit of business logic against the graph state.

    Args:
        state: The current graph state, carrying every value the node needs.

    Returns:
        A partial state update containing only the ``updated_field`` key.
    """
    return {"updated_field": process_business_logic(state)}


# Node design principles:
# 1. Single responsibility: each node does exactly one thing.
# 2. Stateless design: a node does not depend on anything outside `state`.
# 3. Idempotence: the same input produces the same output.
# 4. Error handling: a node handles its own exceptions internally.
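边定义了节点间的执行路径,控制数据在节点间的流转。LangGraph支持多种边类型:普通边(`add_edge`,固定地从一个节点流向下一个节点)、条件边(`add_conditional_edges`,根据路由函数的返回值选择下一个节点),以及跳跃边(让条件边的目标指向前序节点,实现回跳与循环,如下文的重试与质量回流示例)。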
# Install (or upgrade to the latest version of) the packages used in this tutorial.
# NOTE(review): the code in this section imports only `langgraph`; langchain and
# langchain-openai are presumably needed by later sections — confirm.
pip install -U langgraph langchain langchain-openai
from langgraph.graph import StateGraph, START, END
from typing import TypedDict, Annotated, Sequence, List, Dict
from langgraph.graph import add_messages
import time
import uuid


# State shared by the example graphs in this section. `messages` is merged
# through the add_messages reducer; the other keys have no reducer, so the
# value a node returns for them replaces the previous one.
class ProcessingState(TypedDict):
    messages: Annotated[Sequence[dict], add_messages]
    data: List[Dict]
    current_step: str
    results: Dict
    metadata: Dict


# Basic node pattern
def data_input_node(state: ProcessingState) -> ProcessingState:
    """Load the records to process (partial state update).

    Records already supplied in ``state["data"]`` are kept. Only when that list
    is empty or missing does the node fall back to three built-in sample
    projects, so the demo also works from an empty initial state.
    """
    print("📥 开始接收数据...")
    # Simulated input: prefer caller-supplied data over the hard-coded sample.
    input_data = state.get("data") or [
        {"id": 1, "name": "项目A", "status": "pending"},
        {"id": 2, "name": "项目B", "status": "pending"},
        {"id": 3, "name": "项目C", "status": "pending"},
    ]
    return {
        "messages": [{"role": "assistant", "content": f"收到 {len(input_data)} 个项目数据"}],
        "data": input_data,
        "current_step": "input_complete",
        "results": {},
        "metadata": {"input_count": len(input_data)},
    }


def validation_node(state: ProcessingState) -> ProcessingState:
    """Keep records that have an id and a name and whose id is not a duplicate.

    Returns a partial update. ``data`` holds the valid records. When any record
    is rejected, ``current_step`` is ``"validation_error"`` and
    ``results["errors"]`` lists the reasons. Otherwise ``current_step`` is
    ``"validation_complete"``.
    """
    print("🔍 验证数据格式...")
    validated_data = []
    errors = []
    seen_ids = set()  # assumes ids are hashable (ints in these examples)
    for item in state["data"]:
        item_id = item.get("id")
        # Check for absence explicitly: a falsy test would wrongly reject id 0.
        if item_id is None or item_id == "":
            errors.append(f"项目缺少ID: {item}")
            continue
        if not item.get("name"):
            errors.append(f"项目缺少名称: {item}")
            continue
        # O(1) set lookup instead of rescanning validated_data for every item.
        if item_id in seen_ids:
            errors.append(f"重复ID: {item_id}")
            continue
        seen_ids.add(item_id)
        validated_data.append(item)

    if errors:
        return {
            "messages": [{"role": "assistant", "content": f"验证发现 {len(errors)} 个错误"}],
            "data": validated_data,
            "current_step": "validation_error",
            "results": {"errors": errors},
            "metadata": {"validation_passed": False},
        }
    return {
        "messages": [{"role": "assistant", "content": f"验证通过,共 {len(validated_data)} 个有效项目"}],
        "data": validated_data,
        "current_step": "validation_complete",
        "results": {"validated_count": len(validated_data)},
        "metadata": {"validation_passed": True},
    }


def processing_node(state: ProcessingState) -> ProcessingState:
    """Mark every record as processed; results are keyed by record id.

    ``metadata["processing_time"]`` is the elapsed wall-clock duration of the
    loop in seconds. Previously it held an absolute timestamp despite its name.
    """
    print("⚙️ 处理数据...")
    data = state["data"]
    started = time.perf_counter()
    results = {}
    for item in data:
        # Simulated processing
        results[item["id"]] = {
            "original": item,
            "processed": True,
            "timestamp": time.time(),
            "processor_id": str(uuid.uuid4())[:8],
        }
        time.sleep(0.1)  # simulated processing time
    elapsed = time.perf_counter() - started
    return {
        "messages": [{"role": "assistant", "content": f"数据处理完成,共处理 {len(data)} 个项目"}],
        "current_step": "processing_complete",
        "results": results,
        "metadata": {"processing_time": elapsed},
    }


def output_node(state: ProcessingState) -> ProcessingState:
    """Turn the per-record processing results into a flat output list."""
    print("📤 输出处理结果...")
    output_data = [
        {
            "project_id": item_id,
            "status": "completed",
            "processed_at": result["timestamp"],
            "processor": result["processor_id"],
        }
        for item_id, result in state["results"].items()
    ]
    return {
        "messages": [{"role": "assistant", "content": "处理结果已生成完成"}],
        "current_step": "output_complete",
        "results": {"output": output_data},
        "metadata": {"output_count": len(output_data)},
    }


# Build the basic linear graph
basic_graph = StateGraph(ProcessingState)
basic_graph.add_node("data_input", data_input_node)
basic_graph.add_node("validation", validation_node)
basic_graph.add_node("processing", processing_node)
basic_graph.add_node("output", output_node)

# Plain edges. Note: validation always continues to processing. On validation
# errors only the valid subset is processed, and the error list stored in
# `results` is replaced by the processing results.
basic_graph.add_edge(START, "data_input")
basic_graph.add_edge("data_input", "validation")
basic_graph.add_edge("validation", "processing")
basic_graph.add_edge("processing", "output")
basic_graph.add_edge("output", END)

# Compile and run
compiled_basic_graph = basic_graph.compile()

initial_state = {
    "messages": [],
    "data": [],
    "current_step": "start",
    "results": {},
    "metadata": {},
}
result = compiled_basic_graph.invoke(initial_state)
print("处理完成,最终状态:", result["current_step"])
def router_node(state: ProcessingState) -> ProcessingState:
    """Classify the input by size and record the chosen route in ``current_step``.

    Returns a partial update with a status message. ``current_step`` is set to
    ``"empty_data"`` (0 items), ``"small_data"`` (1-2 items) or
    ``"large_data"`` (3 or more items); the conditional edge below routes on it.
    """
    data_count = len(state["data"])
    if data_count == 0:
        return {
            "messages": [{"role": "assistant", "content": "数据为空,直接结束"}],
            "current_step": "empty_data",
        }
    elif data_count < 3:
        return {
            "messages": [{"role": "assistant", "content": "数据量较小,使用快速处理"}],
            "current_step": "small_data",
        }
    else:
        return {
            "messages": [{"role": "assistant", "content": "数据量较大,使用标准处理流程"}],
            "current_step": "large_data",
        }


def fast_processing_node(state: ProcessingState) -> ProcessingState:
    """Label each item's name with a quick-processing marker."""
    processed = [f"快速处理: {item['name']}" for item in state["data"]]
    return {
        "messages": [{"role": "assistant", "content": "快速处理完成"}],
        "current_step": "fast_complete",
        "results": {"fast_processed": processed},
    }


def standard_processing_node(state: ProcessingState) -> ProcessingState:
    """Produce a detailed label (name and id) per item, sleeping 0.2 s per item."""
    processed = []
    for item in state["data"]:
        # Simulated detailed processing
        processed.append(f"详细处理: {item['name']} (ID: {item['id']})")
        time.sleep(0.2)
    return {
        "messages": [{"role": "assistant", "content": "标准处理完成"}],
        "current_step": "standard_complete",
        "results": {"standard_processed": processed},
    }


# Build the conditional-routing graph
conditional_graph = StateGraph(ProcessingState)
conditional_graph.add_node("router", router_node)
conditional_graph.add_node("fast_processing", fast_processing_node)
conditional_graph.add_node("standard_processing", standard_processing_node)

# Conditional edge: the router's `current_step` value selects the next node.
# Terminal branches go straight to END instead of through a pass-through node.
conditional_graph.add_edge(START, "router")
conditional_graph.add_conditional_edges(
    "router",
    lambda state: state["current_step"],
    {
        "empty_data": END,
        "small_data": "fast_processing",
        "large_data": "standard_processing",
    },
)
conditional_graph.add_edge("fast_processing", END)
conditional_graph.add_edge("standard_processing", END)

# This graph must be compiled and invoked itself. The test loop previously ran
# compiled_basic_graph, so conditional routing was never exercised.
compiled_conditional_graph = conditional_graph.compile()

# Try different data volumes
test_cases = [
    {"name": "空数据", "data": []},
    {"name": "小数据", "data": [{"id": 1, "name": "项目A"}]},
    {"name": "大数据", "data": [
        {"id": 1, "name": "项目A"},
        {"id": 2, "name": "项目B"},
        {"id": 3, "name": "项目C"},
        {"id": 4, "name": "项目D"},
    ]},
]

for test_case in test_cases:
    print(f"\n--- 测试{test_case['name']} ---")
    result = compiled_conditional_graph.invoke({
        "messages": [],
        "data": test_case["data"],
        "current_step": "start",
        "results": {},
        "metadata": {},
    })
    print(f"处理步骤: {result['current_step']}")
# 1. Async node
import asyncio
import random
from typing import Coroutine


async def async_data_fetch(state: ProcessingState) -> ProcessingState:
    """Simulate fetching two records from external APIs asynchronously.

    This is a coroutine function, so it needs an async graph invocation
    (e.g. ``ainvoke``). It is not wired into ``advanced_graph`` below.
    """
    print("🔄 异步获取数据...")
    await asyncio.sleep(1)  # stands in for a real async API call
    external_data = [
        {"source": "api_1", "data": "外部数据1"},
        {"source": "api_2", "data": "外部数据2"},
    ]
    return {
        "messages": [{"role": "assistant", "content": f"异步获取完成,共 {len(external_data)} 条数据"}],
        "current_step": "async_fetch_complete",
        "results": {"external_data": external_data},
    }


# 2. Batch-processing node
def batch_processing_node(state: ProcessingState) -> ProcessingState:
    """Split ``state["data"]`` into batches of two and label every item.

    ``results["batch_results"]`` gets one entry per batch with its 1-based
    number, the labelled items, and the batch size.
    """
    data = state["data"]
    batch_size = 2  # items per batch
    batches = [data[i:i + batch_size] for i in range(0, len(data), batch_size)]

    batch_results = []
    for i, batch in enumerate(batches):
        print(f"📦 处理第 {i+1} 批,包含 {len(batch)} 个项目")
        batch_results.append({
            "batch_number": i + 1,
            "items": [f"批处理: {item['name']} (批次{i+1})" for item in batch],
            "size": len(batch),
        })
    return {
        "messages": [{"role": "assistant", "content": f"批处理完成,共处理 {len(batches)} 批"}],
        "current_step": "batch_complete",
        "results": {"batch_results": batch_results},
    }


# 3. Retry node
class RetryState(ProcessingState):
    """ProcessingState plus the attempt counter used by ``retry_node``.

    LangGraph builds state channels from the schema's keys. ``retry_count``
    must therefore be declared here; otherwise the node's counter update is
    dropped and ``max_retries`` is never reached.
    """

    retry_count: int


def retry_node(state: RetryState) -> RetryState:
    """Attempt an operation that fails 70% of the time, tracking retries.

    Sets ``current_step`` to ``"retry"`` (and increments ``retry_count``) while
    retries remain. It becomes ``"retry_failed"`` once ``retry_count`` has
    reached 3, and ``"retry_success"`` on success. The parameter is annotated
    with RetryState so the node's input includes ``retry_count``.
    """
    max_retries = 3
    retry_count = state.get("retry_count", 0)
    print(f"🔄 重试第 {retry_count + 1} 次...")

    # Simulated unreliable operation: 70% failure rate.
    if random.random() < 0.7:
        if retry_count < max_retries:
            return {
                "messages": [{"role": "assistant", "content": "操作失败,准备重试"}],
                "current_step": "retry",
                "retry_count": retry_count + 1,
            }
        return {
            "messages": [{"role": "assistant", "content": "重试次数已达上限"}],
            "current_step": "retry_failed",
            "retry_count": retry_count,
        }
    return {
        "messages": [{"role": "assistant", "content": "操作成功完成"}],
        "current_step": "retry_success",
        "retry_count": retry_count,
    }


# Build the advanced graph on the extended schema so retry_count persists.
advanced_graph = StateGraph(RetryState)
advanced_graph.add_node("data_input", data_input_node)
advanced_graph.add_node("batch_processing", batch_processing_node)
advanced_graph.add_node("retry_node", retry_node)

advanced_graph.add_edge(START, "data_input")
advanced_graph.add_edge("data_input", "batch_processing")
advanced_graph.add_edge("batch_processing", "retry_node")

# Retry loop: go back to retry_node while it asks for a retry, otherwise finish.
advanced_graph.add_conditional_edges(
    "retry_node",
    lambda state: "retry_node" if state["current_step"] == "retry" else "end",
    {"retry_node": "retry_node", "end": END},
)

compiled_advanced_graph = advanced_graph.compile()

# Exercise batching and retrying
test_data = [
    {"id": 1, "name": "项目A"},
    {"id": 2, "name": "项目B"},
    {"id": 3, "name": "项目C"},
    {"id": 4, "name": "项目D"},
    {"id": 5, "name": "项目E"},
]
result = compiled_advanced_graph.invoke({
    "messages": [],
    "data": test_data,
    "current_step": "start",
    "results": {},
    "metadata": {},
    "retry_count": 0,
})
print(f"最终状态: {result['current_step']}")
# Maximum number of processing rounds before giving up on the quality check.
# The processing below is deterministic, so looping back without a bound
# would repeat the same failure until LangGraph's recursion limit raises.
MAX_QUALITY_ATTEMPTS = 3


class ComplexState(TypedDict):
    messages: Annotated[Sequence[dict], add_messages]
    primary_data: List[Dict]
    secondary_data: List[Dict]
    processing_results: Dict
    coordination_metadata: Dict


def parallel_processing_node(state: ComplexState) -> ComplexState:
    """Label the primary and secondary task lists in a single node.

    Both lists are processed sequentially here; the node only simulates
    parallel work. ``coordination_metadata["attempts"]`` counts how many times
    this node has run, which bounds the quality-check loop.
    """
    print("🚀 开始并行处理...")
    primary_data = state["primary_data"]
    secondary_data = state["secondary_data"]
    attempts = (state.get("coordination_metadata") or {}).get("attempts", 0) + 1

    primary_results = [f"主处理: {item['name']}" for item in primary_data]
    secondary_results = [f"副处理: {item['name']}" for item in secondary_data]

    return {
        "messages": [{"role": "assistant", "content": "并行处理完成"}],
        "processing_results": {
            "primary": primary_results,
            "secondary": secondary_results,
        },
        "coordination_metadata": {
            "primary_count": len(primary_data),
            "secondary_count": len(secondary_data),
            "total_processed": len(primary_data) + len(secondary_data),
            "attempts": attempts,
        },
    }


def merge_results_node(state: ComplexState) -> ComplexState:
    """Merge primary and secondary results, sorted by the text after the first ':'.

    Earlier coordination metadata is kept; ``merged_count`` is added to it.
    """
    print("🔄 合并处理结果...")
    metadata = state.get("coordination_metadata") or {}
    results = state["processing_results"]
    merged_results = results.get("primary", []) + results.get("secondary", [])
    # partition() never raises, unlike split(":")[1] on a string without ':'.
    sorted_results = sorted(merged_results, key=lambda text: text.partition(":")[2])
    return {
        "messages": [{"role": "assistant", "content": f"结果合并完成,共 {len(sorted_results)} 条"}],
        "processing_results": {"merged": sorted_results},
        "coordination_metadata": {**metadata, "merged_count": len(sorted_results)},
    }


def quality_check_node(state: ComplexState) -> ComplexState:
    """Flag merged results that mix processing types or are shorter than 3 chars.

    Adds ``quality_passed`` and ``issues`` to the existing coordination
    metadata. With the strings built by parallel_processing_node (one prefix
    each, always longer than 3 characters) neither check can trigger. They
    are illustrative.
    """
    print("🔍 执行质量检查...")
    metadata = state.get("coordination_metadata") or {}
    merged_results = state["processing_results"].get("merged", [])

    quality_issues = []
    for i, result in enumerate(merged_results, start=1):
        # Simulated quality checks
        if "主处理" in result and "副处理" in result:
            quality_issues.append(f"结果{i}: 混合处理类型")
        elif len(result) < 3:
            quality_issues.append(f"结果{i}: 内容过短")

    if quality_issues:
        return {
            "messages": [{"role": "assistant", "content": f"质量检查发现问题: {len(quality_issues)} 个"}],
            "coordination_metadata": {**metadata, "quality_passed": False, "issues": quality_issues},
        }
    return {
        "messages": [{"role": "assistant", "content": "质量检查通过"}],
        "coordination_metadata": {**metadata, "quality_passed": True, "issues": []},
    }


def route_after_quality_check(state: ComplexState) -> str:
    """Finish on success or after MAX_QUALITY_ATTEMPTS rounds; otherwise reprocess."""
    metadata = state["coordination_metadata"]
    if metadata.get("quality_passed", False):
        return "end"
    if metadata.get("attempts", 0) >= MAX_QUALITY_ATTEMPTS:
        return "end"
    return "parallel_processing"


# Build the collaboration graph
complex_graph = StateGraph(ComplexState)
complex_graph.add_node("parallel_processing", parallel_processing_node)
complex_graph.add_node("merge_results", merge_results_node)
complex_graph.add_node("quality_check", quality_check_node)

complex_graph.add_edge(START, "parallel_processing")
complex_graph.add_edge("parallel_processing", "merge_results")
complex_graph.add_edge("merge_results", "quality_check")

# After the quality check: finish, or loop back for another (bounded) round.
complex_graph.add_conditional_edges(
    "quality_check",
    route_after_quality_check,
    {"end": END, "parallel_processing": "parallel_processing"},
)

compiled_complex_graph = complex_graph.compile()

# Exercise the collaboration graph
test_complex_state = {
    "messages": [],
    "primary_data": [
        {"id": 1, "name": "主要任务A"},
        {"id": 2, "name": "主要任务B"},
    ],
    "secondary_data": [
        {"id": 3, "name": "辅助任务X"},
        {"id": 4, "name": "辅助任务Y"},
        {"id": 5, "name": "辅助任务Z"},
    ],
    "processing_results": {},
    "coordination_metadata": {},
}
result = compiled_complex_graph.invoke(test_complex_state)
print(f"最终质量状态: {result['coordination_metadata'].get('quality_passed', 'unknown')}")
# Cache results to avoid repeating expensive computations.
from functools import lru_cache


@lru_cache(maxsize=128)
def expensive_computation(data: str) -> str:
    """Return ``complex_processing(data)``, memoised for up to 128 distinct inputs.

    NOTE(review): ``complex_processing`` is a placeholder not defined in this
    section. Caching is only valid if it is deterministic and side-effect free.
    """
    return complex_processing(data)


# Number of items handed to each batch in batch_optimized_node.
BATCH_SIZE = 10


def _chunks(items, size):
    """Yield consecutive slices of ``items`` of length ``size`` (the last may be shorter)."""
    for start in range(0, len(items), size):
        yield items[start:start + size]


def batch_optimized_node(state: ProcessingState) -> ProcessingState:
    """Run ``process_item`` over ``state["data"]`` in batches of BATCH_SIZE.

    As written, ``process_item`` is still called once per item, so batching
    does not by itself reduce call overhead. The batch loop is the place to
    swap in a bulk operation. Previously this called an undefined ``chunks``
    helper and raised NameError.

    NOTE(review): ``process_item`` is a placeholder not defined in this
    section. The node returns a list under ``results`` although
    ProcessingState declares ``results: Dict`` — confirm the intended shape.
    """
    data = state["data"]
    batch_results = []
    for batch in _chunks(data, BATCH_SIZE):
        batch_results.extend(process_item(item) for item in batch)
    return {"results": batch_results}
# Path selection based on input size.
def optimized_router(state: "ProcessingState", *, fast_max: int = 5) -> str:
    """Pick a processing path from the number of items in ``state["data"]``.

    Returns ``"empty_path"`` for no items, ``"fast_path"`` for at most
    ``fast_max`` items (default 5), and ``"standard_path"`` otherwise. This is
    a plain ordered threshold check, cheapest case first, not a binary search.
    The annotation is a forward reference so the function can be defined
    without ProcessingState in scope.
    """
    data_count = len(state["data"])
    if data_count == 0:
        return "empty_path"
    if data_count <= fast_max:
        return "fast_path"
    return "standard_path"
通过本节学习,我们深入掌握了LangGraph中节点和边的编程技巧。节点作为基本执行单元,需要遵循单一职责、无状态设计等原则;边作为连接路径,支持普通边、条件边、跳跃边等多种类型,能够构建出灵活的执行流程。
熟练掌握节点和边的编程,是构建复杂智能体工作流的基础。接下来我们将进入第2章,学习LangGraph的状态管理机制。
关键词:节点, 边, 条件路由, 并行处理, 批处理, 重试机制
难度:高级
预计阅读:35分钟