Agent系列——SPring AI Alibaba Graph初探
文章目录
一、概述
为什么需要Graph

核心概念

二、快速入门
实现如下工作流:
开始节点→node1→node2→结束节点
用node2的值替换node1的值
依赖版本
spring-boot:3.4.0
spring-ai-alibaba:1.0.0.4
pom.xml添加核心依赖
<dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-web</artifactId></dependency><dependency><groupId>org.springframework.ai</groupId><artifactId>spring-ai-starter-model-zhipuai</artifactId></dependency><dependency><groupId>com.alibaba.cloud.ai</groupId><artifactId>spring-ai-alibaba-graph-core</artifactId></dependency>修改配置文件application.yaml
server:port:8889spring:application:name: agent-graph ai:zhipuai:api-key: ${ZHIPU_KEY}# 配置智谱大模型的API Keychat:options:model: glm-4-flash 创建状态图的配置类

GraphConfig.java
@Configuration@Slf4jpublicclassGraphConfig{@Bean("quickStartGraph")publicCompiledGraphquickStartGraph()throwsGraphStateException{KeyStrategyFactory keyStrategyFactory =newKeyStrategyFactory(){@OverridepublicMap<String,KeyStrategy>apply(){// ReplaceStrategy为替换策略returnMap.of("input1",newReplaceStrategy(),"input2",newReplaceStrategy());}};// 定义状态图StateGraphStateGraph stateGraph =newStateGraph("quickStartGraph", keyStrategyFactory);// 添加节点// AsyncNodeAction.node_async为异步执行 stateGraph.addNode("node1",AsyncNodeAction.node_async(newNodeAction(){@OverridepublicMap<String,Object>apply(OverAllState state)throwsException{ log.info("node1 state: {}", state);returnMap.of("input1",1,"input2",1);}})); stateGraph.addNode("node2",AsyncNodeAction.node_async(newNodeAction(){@OverridepublicMap<String,Object>apply(OverAllState state)throwsException{ log.info("node2 state: {}", state);returnMap.of("input1",2,"input2",2);}}));// 定义边 stateGraph.addEdge(StateGraph.START,"node1"); stateGraph.addEdge("node1","node2"); stateGraph.addEdge("node2",StateGraph.END);// 编译状态图return stateGraph.compile();}}创建一个Controller
GraphController.java
@RestController @RequestMapping("/graph") @Slf4j publicclassGraphController{private final CompiledGraph compiledGraph;publicGraphController(CompiledGraph compiledGraph){this.compiledGraph = compiledGraph;} @GetMapping("/quickStartGraph")publicStringquickStartGraph(){Optional<OverAllState> overAllStateOptional = compiledGraph.call(Map.of()); log.info("overAllStateOptional: {}", overAllStateOptional);return"OK";}}启动程序,查看效果
GET方式,http://localhost:8889/graph/quickStartGraph

发现input1和input2的值被成功替换为2
三、API详解
KeyStrategyFactory(键策略工厂)

NodeAction&AsyncNodeAction

stateGraph(状态图)
状态图的抽象,需要配置状态(通过KeyStrategyFactory ),节点,边。
配置好后通过compile方法编译成CompiledGraph后才可以供调用。
CompiledGraph(编译图)
CompiledGraph是StateGraph编译后的结果,CompiledGraph才能用了执行。
一般我们是把StateGraph定义好后调用其compile方法得到一个CompiledGraph放入Spring容器中然后在需要的时候从容器中注入然后再调用。

四、案例:开发一个英语学习小助手
需求
使用Graph开发一个英语学习小助手。
功能如下:输入一个单词,能基于这个单词造句,然后再对句子进行翻译,把造句的译文也返回。
思路分析
我们可以定义一个工作流,工作流中主要有两个节点:
SentenceConstructionNode 造句节点,拿输入的单词让LLM进行造句。
TranslationNode 翻译节点,能够把一个英文句子翻译成中文。最终把造句的结果和翻译的结果返回即可。
流程图
开始节点(输入一个单词)–>造句节点(根据给定的单词进行造句)–>翻译节点(对句子进行翻译)–>结束节点(输出造句和翻译的结果)
代码编写
定义SentenceConstructionNode造句节点

publicclassSentenceConstructionNodeimplementsNodeAction{privatefinalChatClient chatClient;publicSentenceConstructionNode(ChatClient.Builder builder){this.chatClient = builder.build();}@OverridepublicMap<String,Object>apply(OverAllState state)throwsException{// 从stage中获取要造句的单词String word = state.value("word","");// 定义提示词PromptTemplate promptTemplate =newPromptTemplate("你是一个英语造句专家,能够基于给定的单词进行造句。"+"要求只返回最终造好的句子,不要返回其他信息。给定的单词:{word}"); promptTemplate.add("word", word);// 替换占位符String prompt = promptTemplate.render();// 渲染提示词// 模型调用String content = chatClient.prompt().user(prompt).call().content();// 把句子存入stagereturnMap.of("sentence", content);}}定义TranslationNode翻译节点
publicclassTranslationNodeimplementsNodeAction{privatefinalChatClient chatClient;publicTranslationNode(ChatClient.Builder builder){this.chatClient = builder.build();}@OverridepublicMap<String,Object>apply(OverAllState state)throwsException{// 从stage中获取要翻译的句子String sentence = state.value("sentence","");// 定义提示词PromptTemplate promptTemplate =newPromptTemplate("你是一个英语翻译专家,能够把英文翻译成中文。"+"要求只返回翻译的中文结果,不要返回英文原句。要翻译的英文句子:{sentence}"); promptTemplate.add("sentence", sentence);// 替换占位符String prompt = promptTemplate.render();// 渲染提示词// 模型调用String content = chatClient.prompt().user(prompt).call().content();// 把翻译结果存入stagereturnMap.of("translation", content);}}定义状态图
config/GraphConfig.java,在quickStartGraph下面增加如下内容

@Bean("simpleGraph")publicCompiledGraphsimpleGraph(ChatClient.Builder clientBuilder)throwsGraphStateException{KeyStrategyFactory keyStrategyFactory =()->{HashMap<String,KeyStrategy> keyStrategyHashMap =newHashMap<>(); keyStrategyHashMap.put("word",newReplaceStrategy()); keyStrategyHashMap.put("sentence",newReplaceStrategy()); keyStrategyHashMap.put("translation",newReplaceStrategy());return keyStrategyHashMap;};// 创建状态图StateGraph stateGraph =newStateGraph("simpleGraph", keyStrategyFactory);// 添加节点 stateGraph.addNode("SentenceConstructionNode",AsyncNodeAction.node_async(newSentenceConstructionNode(clientBuilder))); stateGraph.addNode("TranslationNode",AsyncNodeAction.node_async(newTranslationNode(clientBuilder)));// 定义边 stateGraph.addEdge(StateGraph.START,"SentenceConstructionNode"); stateGraph.addEdge("SentenceConstructionNode","TranslationNode"); stateGraph.addEdge("TranslationNode",StateGraph.END);// 编译状态图,放入容器return stateGraph.compile();}新增API接口

@RestController@RequestMapping("/graph")@Slf4jpublicclassGraphController{privatefinalCompiledGraph compiledGraph;privatefinalCompiledGraph simpleGraph;publicGraphController(@Qualifier("quickStartGraph")CompiledGraph compiledGraph,@Qualifier("simpleGraph")CompiledGraph simpleGraph){this.compiledGraph = compiledGraph;this.simpleGraph = simpleGraph;}@GetMapping("/quickStartGraph")publicStringquickStartGraph(){Optional<OverAllState> overAllStateOptional = compiledGraph.call(Map.of()); log.info("overAllStateOptional: {}", overAllStateOptional);return"OK";}@GetMapping("/simpleGraph")publicMap<String,Object>simpleGraph(@RequestParam("word")String word){Optional<OverAllState> overAllStateOptional = simpleGraph.call(Map.of("word", word));Map<String,Object> data = overAllStateOptional.map(OverAllState::data).orElse(Map.of());return data;}}启动服务,访问接口
GET方法:http://localhost:8889/graph/simpleGraph?word=sky

五、条件边


代码结构

定义GenerateJokeNode生成笑话节点
publicclassGenerateJokeNodeimplementsNodeAction{privatefinalChatClient chatClient;publicGenerateJokeNode(ChatClient.Builder builder){this.chatClient = builder.build();}@OverridepublicMap<String,Object>apply(OverAllState state)throwsException{// 从stage中获取笑话主题String topic = state.value("topic","");// 定义提示词PromptTemplate promptTemplate =newPromptTemplate("你需要写一个关于指定主题的短笑话。要求返回的结果中只能包含笑话的内容"+"主题:{topic}"); promptTemplate.add("topic", topic);// 替换占位符String prompt = promptTemplate.render();// 渲染提示词// 模型调用String content = chatClient.prompt().user(prompt).call().content();// 把结果存入stagereturnMap.of("joke", content);}}定义EvaluateJokesNode评估笑话节点
publicclassEvaluateJokesNodeimplementsNodeAction{privatefinalChatClient chatClient;publicEvaluateJokesNode(ChatClient.Builder builder){this.chatClient = builder.build();}@OverridepublicMap<String,Object>apply(OverAllState state)throwsException{// 从stage中获取待评估笑话String joke = state.value("joke","");// 定义提示词PromptTemplate promptTemplate =newPromptTemplate("你是一个笑话评分专家,能够对笑话进行评分,基于效果的搞笑程度给出0到10分的打分。"+"0到3分是不够优秀,4到10分是优秀。要求结果只返回优秀或者不够优秀,不能输出其他内容。"+"要评分的笑话:{joke}"); promptTemplate.add("joke", joke);// 替换占位符String prompt = promptTemplate.render();// 渲染提示词// 模型调用String content = chatClient.prompt().user(prompt).call().content();// 把结果存入stagereturnMap.of("result", content.trim());}}定义EnhanceJokeQualityNode优化笑话节点
publicclassEnhanceJokeQualityNodeimplementsNodeAction{privatefinalChatClient chatClient;publicEnhanceJokeQualityNode(ChatClient.Builder builder){this.chatClient = builder.build();}@OverridepublicMap<String,Object>apply(OverAllState state)throwsException{// 从stage中获取待评估笑话String joke = state.value("joke","");// 定义提示词PromptTemplate promptTemplate =newPromptTemplate("你是一个笑话优化专家,你能够优化笑话,让它更加搞笑"+"要优化的话:{joke}"); promptTemplate.add("joke", joke);// 替换占位符String prompt = promptTemplate.render();// 渲染提示词// 模型调用String content = chatClient.prompt().user(prompt).call().content();// 把结果存入stagereturnMap.of("newJoke", content);}}在GraphConfig下面定义图
@Bean("conditionalGraph")publicCompiledGraphconditionalGraph(ChatClient.Builder clientBuilder)throwsGraphStateException{KeyStrategyFactory keyStrategyFactory =()->Map.of("topic",newReplaceStrategy());// 定义状态图StateGraphStateGraph stateGraph =newStateGraph("conditionalGraph", keyStrategyFactory);// 定义节点 stateGraph.addNode("生成笑话",AsyncNodeAction.node_async(newGenerateJokeNode(clientBuilder))); stateGraph.addNode("评估笑话",AsyncNodeAction.node_async(newEvaluateJokesNode(clientBuilder))); stateGraph.addNode("优化笑话",AsyncNodeAction.node_async(newEnhanceJokeQualityNode(clientBuilder)));// 定义边 stateGraph.addEdge(StateGraph.START,"生成笑话"); stateGraph.addEdge("生成笑话","评估笑话"); stateGraph.addConditionalEdges("评估笑话",AsyncEdgeAction.edge_async( state -> state.value("result","优秀")),Map.of("优秀",StateGraph.END,"不够优秀","优化笑话")); stateGraph.addEdge("优化笑话",StateGraph.END);return stateGraph.compile();}在GraphController下创建接口
privatefinalCompiledGraph compiledGraph;privatefinalCompiledGraph simpleGraph;privatefinalCompiledGraph conditionalGraph;publicGraphController(@Qualifier("quickStartGraph")CompiledGraph compiledGraph,@Qualifier("simpleGraph")CompiledGraph simpleGraph,@Qualifier("conditionalGraph")CompiledGraph conditionalGraph){this.compiledGraph = compiledGraph;this.simpleGraph = simpleGraph;this.conditionalGraph = conditionalGraph;}@GetMapping("/conditionalGraph")publicMap<String,Object>conditionalGraph(@RequestParam("topic")String topic){Optional<OverAllState> overAllStateOptional = conditionalGraph.call(Map.of("topic", topic));Map<String,Object> data = overAllStateOptional.map(OverAllState::data).orElse(Map.of());return data;}验证效果
GET方式:http://localhost:8889/graph/conditionalGraph?topic=爱情
评估结果是优秀

直接输出结果

在断点处右击,选择“Evaluate Expression”

篡改评估结果为"不够优秀",回车后关闭

修改成功

就会走优化节点,生成新的笑话

六、循环边

新增LoopEvaluateJokesNode循环评分节点
@Slf4jpublicclassLoopEvaluateJokesNodeimplementsNodeAction{privatefinalChatClient chatClient;privatefinalInteger targetScore;privatefinalInteger maxLoopCount;publicLoopEvaluateJokesNode(ChatClient.Builder builder,Integer targetScore,Integer maxLoopCount){this.chatClient = builder.build();this.targetScore = targetScore;this.maxLoopCount = maxLoopCount;}@OverridepublicMap<String,Object>apply(OverAllState state)throwsException{// 从stage中获取待评估笑话String joke = state.value("joke","");// 循环次数Integer loopCount = state.value("loopCount",0);// 定义提示词PromptTemplate promptTemplate =newPromptTemplate("你是一个笑话评分专家,能够对笑话进行评分,基于效果的搞笑程度给出0到10分的打分。"+"要求结果只返回最后的打分,打分必须是整数,不能输出其他内容。"+"要评分的笑话:{joke}"); promptTemplate.add("joke", joke);// 替换占位符String prompt = promptTemplate.render();// 渲染提示词// 模型调用String content = chatClient.prompt().user(prompt).call().content();// content转为整数Integer score =Integer.parseInt(content.trim()); log.info("joke: {},score: {},循环次数: {}", joke, score, loopCount);// 根据分数判断是否继续循环,循环最多执行5次String result ="loop";if(score >= targetScore || loopCount >= maxLoopCount){ result ="break";} loopCount++;// 把结果存入stagereturnMap.of("result", result,"loopCount", loopCount);}}在GraphConfig下面定义图
@Bean("loopGraph")publicCompiledGraphloopGraph(ChatClient.Builder clientBuilder)throws GraphStateException {KeyStrategyFactory keyStrategyFactory =()-> Map.of("topic",newReplaceStrategy());// 定义状态图StateGraphStateGraph stateGraph =newStateGraph("loopGraph", keyStrategyFactory);// 定义节点 stateGraph.addNode("生成笑话", AsyncNodeAction.node_async(newGenerateJokeNode(clientBuilder))); stateGraph.addNode("评估笑话", AsyncNodeAction.node_async(newLoopEvaluateJokesNode(clientBuilder,8,5)));// 定义边 stateGraph.addEdge(StateGraph.START,"生成笑话"); stateGraph.addEdge("生成笑话","评估笑话"); stateGraph.addConditionalEdges("评估笑话", AsyncEdgeAction.edge_async( state -> state.value("result","loop")), Map.of("loop","生成笑话","break", StateGraph.END));return stateGraph.compile();}在GraphController下创建接口
@GetMapping("/loopGraph")publicMap<String, Object>loopGraph(@RequestParam("topic")String topic){Optional<OverAllState> overAllStateOptional = loopGraph.call(Map.of("topic", topic));Map<String, Object> data = overAllStateOptional.map(OverAllState::data).orElse(Map.of());return data;}测试
GET方式:http://localhost:8889/graph/loopGraph?topic=爱情
当score为8时,退出循环,输出结果


七、状态存储
我们可以把图中的状态数据进行存储。默契情况下Graph会把状态存储到内存中。
在ConfigGraph中创建状态图
@Bean("saveGraph")publicCompiledGraphsaveGraph(ChatClient.Builder clientBuilder)throwsGraphStateException{KeyStrategyFactory keyStrategyFactory =()->Map.of();// 定义状态图 stateGraphStateGraph stateGraph =newStateGraph("saveGraph", keyStrategyFactory); stateGraph.addNode("对话存储",AsyncNodeAction.node_async(newNodeAction(){@OverridepublicMap<String,Object>apply(OverAllState state)throwsException{String msg = state.value("msg","");ArrayList<Object> historyMsg = state.value("historyMsg",newArrayList<>()); historyMsg.add(msg);returnMap.of("historyMsg", historyMsg);}}));// 定义边 stateGraph.addEdge(StateGraph.START,"对话存储"); stateGraph.addEdge("对话存储",StateGraph.END);return stateGraph.compile();}在GraphController中创建接口
@GetMapping("/saveGraph")// 通过conversationId来隔离不同请求者的数据publicMap<String,Object>saveGraph(@RequestParam("msg")String msg,@RequestParam("conversationId")String conversationId){RunnableConfig runnableConfig =RunnableConfig.builder().threadId(conversationId).build();Optional<OverAllState> overAllStateOptional = saveGraph.call(Map.of("msg", msg), runnableConfig);Map<String,Object> data = overAllStateOptional.map(OverAllState::data).orElse(Map.of());return data;}测试
GET方式:http://localhost:8889/graph/saveGraph?msg=你好张三&conversationId=zs
第一次调用

第二次调用,发现前面的值存储了下来

修改会话ID,历史数据只有最新的一条数据

八、打印图
我们可以把定义好的状态图进行打印,更直观的看到当前图的情况
在图的下面添加如下代码:
// 添加PlantUML打印GraphRepresentation representation = stateGraph.getGraph(GraphRepresentation.Type.PLANTUML,"stateGraph"); log.info("\n===打印UML Flow==="); log.info(representation.content()); log.info("====================\n");
启动服务,复制如下内容

打开网址:http://www.plantuml.com/plantuml/
粘贴内容,就可以看到图的效果了

九、资料
视频:https://www.bilibili.com/video/BV1eyWbzEEnw?spm_id_from=333.788.player.switch&vd_source=0467ab39cc5ec5940fee22a0e7797575&p=45