注:本文使用 LangChain v1.0+

Custom 中间件的两种风格

Custom 中间件提供两种方式来拦截和修改 Agent 行为:

1️⃣ Node-Style Hooks(节点风格)

顺序执行在特定执行点,用于日志记录、验证、状态更新。

import os
from langchain_openai import ChatOpenAI
from langchain.agents import create_agent, AgentState
from langchain.agents.middleware import before_model, after_model
from langchain.tools import tool
from langchain.messages import AIMessage
from langgraph.runtime import Runtime
from typing import TypedDict

model = ChatOpenAI(
    model=os.getenv("MODEL_NAME", "Qwen/Qwen2-7B-Instruct"),
    temperature=0.7,
    base_url=os.getenv("SILICONFLOW_BASE_URL"),
    api_key=os.getenv("SILICONFLOW_API_KEY")
)

class Context(TypedDict):
    user_id: str
    request_id: str

# ===== 钩子 1: 模型调用前 =====
@before_model
def log_request_info(state: AgentState, runtime: Runtime[Context]) -> dict | None:
    """记录请求信息(节点风格)"""
    user_id = runtime.context.user_id
    msg_count = len(state.get("messages", []))
    
    print(f" [用户 {user_id}] 即将调用模型")
    print(f"   消息数量: {msg_count}")
    
    # 不修改状态,返回 None
    return None

# ===== 钩子 2: 模型调用后 =====
@after_model
def log_response_info(state: AgentState, runtime: Runtime[Context]) -> dict | None:
    """记录响应信息(节点风格)"""
    last_msg = state["messages"][-1]
    
    if isinstance(last_msg, AIMessage):
        print(f" 模型已响应,长度: {len(last_msg.content)}")
    
    return None

# ===== 创建 Agent =====
agent = create_agent(
    model=model,
    tools=[],
    middleware=[log_request_info, log_response_info],
    context_schema=Context
)

result = agent.invoke(
    {"messages": [{"role": "user", "content": "你好"}]},
    context={"user_id": "user123", "request_id": "req_001"}
)

2️⃣ Wrap-Style Hooks(环绕风格)

环绕执行,你控制 handler 何时被调用。用于重试、缓存、动态修改。

from langchain.agents.middleware import before_model
from langchain.agents import create_agent

# ===== Wrap 风格中间件:重试逻辑 =====
def create_retry_middleware(max_retries: int = 3):
    """创建重试中间件"""
    
    # 这里用 before_model 演示,但真实场景中应该用 wrap 风格
    @before_model
    def retry_on_error(state: AgentState, runtime: Runtime[Context]) -> dict | None:
        # 在这个例子中,我们可以使用状态追踪重试次数
        # 实际的 wrap 风格需要使用特殊的装饰器
        return None
    
    return retry_on_error

agent = create_agent(
    model=model,
    tools=[],
    middleware=[create_retry_middleware(max_retries=3)],
    context_schema=Context
)

更完整的 Wrap 风格示例:

from langchain.agents.middleware import before_model

def create_model_retry_middleware(max_retries: int = 3):
    """
    使用状态跟踪来实现重试逻辑
    """
    @before_model
    def handle_retry(state: AgentState, runtime: Runtime[Context]) -> dict | None:
        # 检查重试计数
        retry_count = state.get("retry_count", 0)
        
        if retry_count >= max_retries:
            return {
                "messages": [AIMessage(content="已达到最大重试次数")],
                "jumpTo": "end"  # 跳到结束
            }
        
        return None
    
    return handle_retry

修改状态的中间件

中间件可以返回一个字典来修改状态:

from typing import TypedDict
from langchain.agents import AgentState, create_agent
from langchain.agents.middleware import before_model, after_model
from langchain.messages import AIMessage
from langgraph.runtime import Runtime

class Context(TypedDict):
    user_id: str

# ===== 例子 1: 在模型调用前修改消息 =====
@before_model
def trim_long_messages(state: AgentState, runtime: Runtime[Context]) -> dict | None:
    """如果消息过多,删除旧消息"""
    messages = state.get("messages", [])
    
    # 超过 100 条消息则只保留最近 50 条
    if len(messages) > 100:
        print(f"️  消息过多 ({len(messages)}),正在修剪...")
        trimmed_messages = messages[-50:]  # 只保留最后 50 条
        
        return {"messages": trimmed_messages}
    
    return None

# ===== 例子 2: 在模型响应后替换内容 =====
@after_model
def filter_sensitive_content(state: AgentState, runtime: Runtime[Context]) -> dict | None:
    """过滤敏感内容"""
    last_msg = state["messages"][-1]
    
    if isinstance(last_msg, AIMessage):
        content = last_msg.content.lower()
        
        # 检测敏感词
        if "password" in content or "api_key" in content:
            print(" 检测到敏感内容,正在替换...")
            return {
                "messages": [AIMessage(content="无法显示该内容。")]
            }
    
    return None

agent = create_agent(
    model=model,
    tools=[],
    middleware=[trim_long_messages, filter_sensitive_content],
    context_schema=Context
)

使用 jumpTo 控制流程

中间件可以用 jumpTo 提前结束或跳过执行:

from langchain.agents.middleware import before_model

@before_model
def rate_limit_check(state: AgentState, runtime: Runtime[Context]) -> dict | None:
    """检查速率限制"""
    msg_count = len(state.get("messages", []))
    
    # 如果消息太多,直接返回错误并结束
    if msg_count > 1000:
        return {
            "messages": [AIMessage(content="已达到请求限制")],
            "jumpTo": "end"  # 直接跳到 agent 结束
        }
    
    return None

agent = create_agent(
    model=model,
    tools=[],
    middleware=[rate_limit_check],
    context_schema=Context
)

扩展状态 Schema

中间件可以为 Agent 状态添加自定义字段:

from langchain.agents import AgentState, create_agent
from typing import TypedDict
from langchain.agents.middleware import before_model, after_model

# 创建扩展的状态类
class ExtendedAgentState(AgentState):
    """扩展状态,添加自定义字段"""
    call_count: int  # 模型调用次数
    total_tokens: int  # 总 token 数
    user_metadata: dict  # 用户元数据

# ===== 中间件追踪调用次数 =====
@before_model
def track_calls(state: ExtendedAgentState, runtime: Runtime[Context]) -> dict | None:
    """记录模型调用次数"""
    current_count = state.get("call_count", 0)
    print(f" 这是第 {current_count + 1} 次模型调用")
    
    return {"call_count": current_count + 1}

@after_model
def track_tokens(state: ExtendedAgentState, runtime: Runtime[Context]) -> dict | None:
    """记录 token 使用量"""
    last_msg = state["messages"][-1]
    
    # 估算 token 数(简单方法)
    token_count = len(str(last_msg.content).split())
    current_total = state.get("total_tokens", 0)
    
    return {"total_tokens": current_total + token_count}

agent = create_agent(
    model=model,
    tools=[],
    middleware=[track_calls, track_tokens],
    context_schema=Context,
    state_schema=ExtendedAgentState  # 使用扩展的状态
)

result = agent.invoke(
    {
        "messages": [{"role": "user", "content": "你好"}],
        "call_count": 0,
        "total_tokens": 0,
        "user_metadata": {"plan": "premium"}
    },
    context={"user_id": "user123"}
)

print(f"最终调用次数: {result['call_count']}")
print(f"总 token 数: {result['total_tokens']}")

实战:完整的中间件系统

import os
from langchain_openai import ChatOpenAI
from langchain.agents import create_agent, AgentState
from langchain.agents.middleware import before_model, after_model
from langchain.tools import tool
from langchain.messages import AIMessage
from langgraph.runtime import Runtime
from typing import TypedDict
from datetime import datetime
import time

model = ChatOpenAI(
    model=os.getenv("MODEL_NAME", "Qwen/Qwen2-7B-Instruct"),
    temperature=0.7,
    base_url=os.getenv("SILICONFLOW_BASE_URL"),
    api_key=os.getenv("SILICONFLOW_API_KEY")
)

# ===== 上下文和状态 =====
class Context(TypedDict):
    user_id: str
    user_role: str  # "admin" 或 "user"

class ExtendedState(AgentState):
    call_count: int
    start_time: float
    execution_logs: list

# ===== 工具 =====
@tool
def search_tool(query: str) -> str:
    """搜索工具"""
    return f"找到关于 '{query}' 的 3 个结果"

# ===== 中间件 1: 请求验证 =====
@before_model
def validate_request(state: ExtendedState, runtime: Runtime[Context]) -> dict | None:
    """验证请求权限"""
    user_role = runtime.context.get("user_role", "user")
    
    # 仅管理员可以连续发送超过 10 条消息
    msg_count = len(state.get("messages", []))
    if msg_count > 10 and user_role != "admin":
        return {
            "messages": [AIMessage(content="您已达到消息限制。")],
            "jumpTo": "end"
        }
    
    return None

# ===== 中间件 2: 请求日志 =====
@before_model
def log_request(state: ExtendedState, runtime: Runtime[Context]) -> dict | None:
    """记录请求详情"""
    user_id = runtime.context.user_id
    call_count = state.get("call_count", 0)
    
    log_msg = f"[{datetime.now().strftime('%H:%M:%S')}] 用户 {user_id} - 调用 #{call_count + 1}"
    
    logs = state.get("execution_logs", [])
    logs.append(log_msg)
    
    return {
        "call_count": call_count + 1,
        "execution_logs": logs
    }

# ===== 中间件 3: 响应验证 =====
@after_model
def validate_response(state: ExtendedState, runtime: Runtime[Context]) -> dict | None:
    """检查响应质量"""
    last_msg = state["messages"][-1]
    
    if isinstance(last_msg, AIMessage):
        content = last_msg.content
        
        # 内容过短可能是错误
        if len(content) < 10:
            return {
                "messages": [AIMessage(content="模型响应异常,请重试。")]
            }
    
    return None

# ===== 中间件 4: 性能监控 =====
@after_model
def monitor_performance(state: ExtendedState, runtime: Runtime[Context]) -> dict | None:
    """监控执行时间"""
    start_time = state.get("start_time", time.time())
    elapsed = time.time() - start_time
    
    logs = state.get("execution_logs", [])
    logs.append(f"⏱️  耗时: {elapsed:.2f}秒")
    
    return {"execution_logs": logs}

# ===== 创建 Agent =====
agent = create_agent(
    model=model,
    tools=[search_tool],
    middleware=[
        validate_request,    # ① 验证权限
        log_request,        # ② 记录请求
        validate_response,  # ③ 验证响应
        monitor_performance # ④ 监控性能
    ],
    context_schema=Context,
    state_schema=ExtendedState
)

# ===== 测试 =====
if __name__ == "__main__":
    result = agent.invoke(
        {
            "messages": [{"role": "user", "content": "搜索 Python"}],
            "call_count": 0,
            "start_time": time.time(),
            "execution_logs": []
        },
        context={"user_id": "user_001", "user_role": "admin"}
    )
    
    print("执行日志:")
    for log in result.get("execution_logs", []):
        print(f"  {log}")

关键概念总结

特性Node-StyleWrap-Style
执行时机顺序执行在特定点环绕执行
用途日志、验证、更新重试、缓存、动态选择
控制权框架控制你控制 handler
钩子before_*, after_*wrap_*
返回值dictNone直接返回结果

作者:世界那么哒哒
链接:juejin.cn/post/757240…
来源:稀土掘金
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。

相关文档:

  • Custom Middleware 详解
  • Runtime 对象
  • [Agents 中间件
本站提供的所有下载资源均来自互联网,仅提供学习交流使用,版权归原作者所有。如需商业使用,请联系原作者获得授权。 如您发现有涉嫌侵权的内容,请联系我们 邮箱:[email protected]