Skip to content

Commit f1a7130

Browse files
authored
对接retry接口
对接retry接口
2 parents 28c98e1 + a309372 commit f1a7130

12 files changed

Lines changed: 632 additions & 158 deletions

File tree

prompto-lab-app/pom.xml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,11 @@
1919
<maven.compiler.target>17</maven.compiler.target>
2020
</properties>
2121
<dependencies>
22+
<dependency>
23+
<groupId>commons-lang</groupId>
24+
<artifactId>commons-lang</artifactId>
25+
<version>2.6</version>
26+
</dependency>
2227
<dependency>
2328
<groupId>org.springframework.boot</groupId>
2429
<artifactId>spring-boot-starter-web</artifactId>

prompto-lab-app/src/main/java/io/github/timemachinelab/controller/UserInteractionController.java

Lines changed: 49 additions & 133 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,14 @@
11
package io.github.timemachinelab.controller;
22

3-
import io.github.timemachinelab.core.qatree.QaTree;
4-
import io.github.timemachinelab.core.qatree.QaTreeDomain;
5-
import io.github.timemachinelab.core.session.application.ConversationService;
63
import io.github.timemachinelab.core.session.application.MessageProcessingService;
74
import io.github.timemachinelab.core.session.application.SessionManagementService;
5+
import io.github.timemachinelab.core.session.application.SseNotificationService;
86
import io.github.timemachinelab.core.session.domain.entity.ConversationSession;
9-
import io.github.timemachinelab.core.session.infrastructure.ai.QuestionGenerationOperation;
107
import io.github.timemachinelab.core.session.infrastructure.web.dto.UnifiedAnswerRequest;
11-
import io.github.timemachinelab.core.session.infrastructure.web.dto.MessageResponse;
128
import io.github.timemachinelab.entity.req.RetryRequest;
139
import io.github.timemachinelab.entity.resp.ApiResult;
1410
import io.github.timemachinelab.entity.resp.RetryResponse;
1511
import lombok.extern.slf4j.Slf4j;
16-
import org.springframework.beans.factory.annotation.Autowired;
1712
import org.springframework.http.MediaType;
1813
import org.springframework.http.ResponseEntity;
1914
import org.springframework.validation.annotation.Validated;
@@ -23,7 +18,6 @@
2318
import javax.annotation.Resource;
2419
import javax.validation.Valid;
2520
import java.io.IOException;
26-
import java.util.HashMap;
2721
import java.util.Map;
2822
import java.util.UUID;
2923
import java.util.concurrent.ConcurrentHashMap;
@@ -40,16 +34,12 @@
4034
@RequestMapping("/api/user-interaction")
4135
@Validated
4236
public class UserInteractionController {
43-
44-
@Resource
45-
private ConversationService conversationService;
4637
@Resource
4738
private MessageProcessingService messageProcessingService;
4839
@Resource
4940
private SessionManagementService sessionManagementService;
50-
private final Map<String, SseEmitter> sseEmitters = new ConcurrentHashMap<>();
51-
@Autowired
52-
private QaTreeDomain qaTreeDomain;
41+
@Resource
42+
private SseNotificationService sseNotificationService;
5343

5444
/**
5545
* 建立SSE连接
@@ -60,7 +50,7 @@ public SseEmitter streamConversation(@RequestParam(required = false) String sess
6050
log.info("建立SSE连接 - 会话ID: {}, 用户ID: {}", sessionId, userId);
6151

6252
boolean isNewSession = false;
63-
ConversationSession session = null;
53+
ConversationSession session;
6454

6555
try {
6656
if (sessionId == null || sessionId.isEmpty()) {
@@ -82,7 +72,7 @@ public SseEmitter streamConversation(@RequestParam(required = false) String sess
8272
}
8373

8474
SseEmitter emitter = new SseEmitter(Long.MAX_VALUE);
85-
sseEmitters.put(sessionId, emitter);
75+
sseNotificationService.registerSseConnection(sessionId, emitter);
8676

8777
// 连接建立时发送会话信息
8878
Map<String, Object> connectionData = new ConcurrentHashMap<>();
@@ -115,9 +105,7 @@ public SseEmitter streamConversation(@RequestParam(required = false) String sess
115105
log.info("兜底返回根节点ID: 1 - 会话: {}", sessionId);
116106
}
117107

118-
emitter.send(SseEmitter.event()
119-
.name("connected")
120-
.data(connectionData));
108+
sseNotificationService.sendWelcomeMessage(sessionId, connectionData);
121109

122110
// 设置连接事件处理
123111
String finalSessionId = sessionId;
@@ -127,12 +115,12 @@ public SseEmitter streamConversation(@RequestParam(required = false) String sess
127115

128116
emitter.onTimeout(() -> {
129117
log.info("SSE连接超时: {}", finalSessionId);
130-
sseEmitters.remove(finalSessionId);
118+
sseNotificationService.removeSseConnection(finalSessionId);
131119
});
132120

133121
emitter.onError((ex) -> {
134122
log.error("SSE连接错误: {} - {}", finalSessionId, ex.getMessage());
135-
sseEmitters.remove(finalSessionId);
123+
sseNotificationService.removeSseConnection(finalSessionId);
136124
});
137125

138126
return emitter;
@@ -163,8 +151,45 @@ public ResponseEntity<ApiResult<RetryResponse>> retry(@Valid @RequestBody RetryR
163151
log.info("收到重试请求 - nodeId: {}, sessionId: {}, whyretry: {}",
164152
request.getNodeId(), request.getSessionId(), request.getWhyretry());
165153

154+
// 使用应用服务验证节点存在性
155+
//todo: 有可能水平越权 不传userId的话
156+
if (!sessionManagementService.validateNodeExists(request.getSessionId(), request.getNodeId())) {
157+
log.warn("节点不存在 - nodeId: {}, sessionId: {}", request.getNodeId(), request.getSessionId());
158+
return ResponseEntity.badRequest().body(ApiResult.error("指定的节点不存在"));
159+
}
160+
161+
// 使用应用服务获取问题内容
162+
String question = sessionManagementService.getNodeQuestion(request.getSessionId(), request.getNodeId());
163+
if (question == null) {
164+
log.warn("节点问题内容为空 - nodeId: {}, sessionId: {}", request.getNodeId(), request.getSessionId());
165+
return ResponseEntity.badRequest().body(ApiResult.error("节点问题内容为空"));
166+
}
166167

167-
168+
// 获取会话对象
169+
ConversationSession session = sessionManagementService.getSessionById(request.getSessionId());
170+
if (session == null) {
171+
log.warn("会话不存在 - sessionId: {}", request.getSessionId());
172+
return ResponseEntity.badRequest().body(ApiResult.error("会话不存在"));
173+
}
174+
175+
// 移除要重试的节点(AI会基于parentId重新创建节点)
176+
boolean nodeRemoved = sessionManagementService.removeNode(request.getSessionId(), request.getNodeId());
177+
if (!nodeRemoved) {
178+
log.warn("移除节点失败,但继续处理重试 - sessionId: {}, nodeId: {}",
179+
request.getSessionId(), request.getNodeId());
180+
}
181+
182+
// 使用MessageProcessingService处理重试消息
183+
String processedMessage = messageProcessingService.processRetryMessage(
184+
request.getSessionId(),
185+
request.getNodeId(),
186+
request.getWhyretry(),
187+
session
188+
);
189+
190+
// 发送处理后的消息给AI服务
191+
messageProcessingService.processAndSendMessage(session, processedMessage);
192+
168193
// 构建响应数据
169194
RetryResponse response = RetryResponse.builder()
170195
.nodeId(request.getNodeId())
@@ -196,12 +221,6 @@ public ResponseEntity<String> processAnswer(@Validated @RequestBody UnifiedAnswe
196221
request.getNodeId(),
197222
request.getQuestionType());
198223

199-
// 1. 强制要求sessionId
200-
if (request.getSessionId() == null || request.getSessionId().trim().isEmpty()) {
201-
log.warn("缺少必需的sessionId参数");
202-
return ResponseEntity.badRequest().body("sessionId参数是必需的");
203-
}
204-
205224
// 2. 会话管理和验证
206225
String userId = request.getUserId();
207226
if (userId == null || userId.trim().isEmpty()) {
@@ -248,28 +267,7 @@ public ResponseEntity<String> processAnswer(@Validated @RequestBody UnifiedAnswe
248267
return ResponseEntity.badRequest().body("答案格式不正确");
249268
}
250269

251-
QaTree qaTree = session.getQaTree();
252-
253-
// 根据问题类型获取正确的答案数据
254-
Object answerData;
255-
switch (request.getQuestionType().toLowerCase()) {
256-
case "input":
257-
answerData = request.getInputAnswer();
258-
break;
259-
case "single":
260-
case "multi":
261-
answerData = request.getChoiceAnswer();
262-
break;
263-
case "form":
264-
answerData = request.getFormAnswer();
265-
break;
266-
default:
267-
log.warn("未知的问题类型: {}", request.getQuestionType());
268-
answerData = request.getAnswerString();
269-
break;
270-
}
271-
272-
qaTreeDomain.updateNodeAnswer(qaTree, request.getNodeId(), answerData);
270+
// 答案更新逻辑已在MessageProcessingService中处理
273271

274272
// 4. 处理答案并转换为消息
275273
String processedMessage = messageProcessingService.preprocessMessage(
@@ -279,11 +277,7 @@ public ResponseEntity<String> processAnswer(@Validated @RequestBody UnifiedAnswe
279277
);
280278

281279
// 5. 发送处理后的消息给AI服务
282-
conversationService.processUserMessage(
283-
session.getUserId(),
284-
processedMessage,
285-
response -> sendSseMessage(session.getSessionId(), response)
286-
);
280+
messageProcessingService.processAndSendMessage(session, processedMessage);
287281

288282

289283
return ResponseEntity.ok("答案处理成功");
@@ -294,89 +288,11 @@ public ResponseEntity<String> processAnswer(@Validated @RequestBody UnifiedAnswe
294288
}
295289
}
296290

297-
/**
298-
* 通过SSE发送消息给客户端
299-
* 在AI回复时创建QA节点,填入question,answer留空等用户提交后再更新
300-
*
301-
* @param sessionId 会话ID
302-
* @param response 消息响应对象
303-
*/
304-
private void sendSseMessage(String sessionId, QuestionGenerationOperation.QuestionGenerationResponse response) {
305-
SseEmitter emitter = sseEmitters.get(sessionId);
306-
if (emitter != null) {
307-
try {
308-
String currentNodeId = null;
309-
310-
// 1. 先将AI生成的新问题添加到QaTree(只填入question,answer留空)
311-
ConversationSession session = sessionManagementService.getSessionById(sessionId);
312-
if (session != null && session.getQaTree() != null && response.getQuestion() != null) {
313-
// 使用QaTreeDomain添加新节点,answer字段会自动为空
314-
// appendNode方法内部会调用session.getNextNodeId()获取新节点ID
315-
QaTree qaTree = qaTreeDomain.appendNode(
316-
session.getQaTree(),
317-
response.getParentId(),
318-
response.getQuestion(),
319-
session
320-
);
321-
322-
// 获取刚刚创建的节点ID(当前计数器的值)
323-
currentNodeId = String.valueOf(session.getNodeIdCounter().get());
324-
325-
log.info("AI问题已添加到QaTree - 会话: {}, 父节点: {}, 新节点ID: {}, 问题类型: {}",
326-
sessionId, response.getParentId(), currentNodeId, response.getQuestion().getType());
327-
} else {
328-
log.warn("无法添加问题到QaTree - 会话: {}, session存在: {}, qaTree存在: {}, question存在: {}",
329-
sessionId, session != null,
330-
session != null && session.getQaTree() != null,
331-
response.getQuestion() != null);
332-
}
333-
334-
// 2. 创建修改后的响应对象,包含currentNodeId和parentNodeId
335-
Map<String, Object> modifiedResponse = new HashMap<>();
336-
modifiedResponse.put("question", response.getQuestion());
337-
modifiedResponse.put("currentNodeId", currentNodeId != null ? currentNodeId : response.getParentId());
338-
modifiedResponse.put("parentNodeId", response.getParentId());
339-
340-
// 3. 发送SSE消息给前端
341-
emitter.send(SseEmitter.event()
342-
.name("message")
343-
.data(modifiedResponse));
344-
log.info("SSE消息发送成功 - 会话: {}, 当前节点ID: {}", sessionId, currentNodeId);
345-
} catch (IOException e) {
346-
log.error("SSE消息发送失败 - 会话: {}, 错误: {}", sessionId, e.getMessage());
347-
sseEmitters.remove(sessionId);
348-
} catch (Exception e) {
349-
log.error("添加问题到QaTree失败 - 会话: {}, 错误: {}", sessionId, e.getMessage());
350-
// 即使QaTree更新失败,仍然发送SSE消息给前端
351-
try {
352-
Map<String, Object> fallbackResponse = new HashMap<>();
353-
fallbackResponse.put("question", response.getQuestion());
354-
fallbackResponse.put("currentNodeId", response.getParentId()); // 使用parentId作为fallback
355-
fallbackResponse.put("parentNodeId", response.getParentId());
356-
357-
emitter.send(SseEmitter.event()
358-
.name("message")
359-
.data(fallbackResponse));
360-
log.info("SSE消息发送成功(QaTree更新失败但消息已发送) - 会话: {}", sessionId);
361-
} catch (IOException ioException) {
362-
log.error("SSE消息发送失败 - 会话: {}, 错误: {}", sessionId, ioException.getMessage());
363-
sseEmitters.remove(sessionId);
364-
}
365-
}
366-
} else {
367-
log.warn("SSE连接不存在 - 会话: {}", sessionId);
368-
}
369-
}
370-
371291
/**
372292
* 获取SSE连接状态
373293
*/
374294
@GetMapping("/sse-status")
375295
public ResponseEntity<Map<String, Object>> getSseStatus() {
376-
Map<String, Object> status = new ConcurrentHashMap<>();
377-
status.put("connectedSessions", sseEmitters.keySet());
378-
status.put("totalConnections", sseEmitters.size());
379-
status.put("timestamp", System.currentTimeMillis());
380-
return ResponseEntity.ok(status);
296+
return ResponseEntity.ok(sseNotificationService.getSseStatus());
381297
}
382298
}

prompto-lab-app/src/main/java/io/github/timemachinelab/core/qatree/QaTree.java

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,5 +29,59 @@ public void addNode(String parentId, QaTreeNode node) {
2929
public QaTreeNode getNodeById(String id) {
3030
return nodeMap.get(id);
3131
}
32+
33+
/**
34+
* 移除指定节点及其所有子节点
35+
* @param nodeId 要移除的节点ID
36+
* @return 是否移除成功
37+
*/
38+
public boolean removeNode(String nodeId) {
39+
QaTreeNode nodeToRemove = nodeMap.get(nodeId);
40+
if (nodeToRemove == null) {
41+
return false;
42+
}
43+
44+
// 递归移除所有子节点
45+
removeNodeAndChildren(nodeToRemove);
46+
47+
// 从父节点的children中移除该节点
48+
removeFromParent(nodeToRemove);
49+
50+
return true;
51+
}
52+
53+
/**
54+
* 递归移除节点及其所有子节点
55+
* @param node 要移除的节点
56+
*/
57+
private void removeNodeAndChildren(QaTreeNode node) {
58+
if (node == null) {
59+
return;
60+
}
61+
62+
// 递归移除所有子节点
63+
if (node.getChildren() != null) {
64+
for (QaTreeNode child : node.getChildren().values()) {
65+
removeNodeAndChildren(child);
66+
}
67+
}
68+
69+
// 从nodeMap中移除当前节点
70+
nodeMap.remove(node.getId());
71+
}
72+
73+
/**
74+
* 从父节点的children中移除指定节点
75+
* @param nodeToRemove 要移除的节点
76+
*/
77+
private void removeFromParent(QaTreeNode nodeToRemove) {
78+
// 遍历所有节点找到父节点
79+
for (QaTreeNode node : nodeMap.values()) {
80+
if (node.getChildren() != null && node.getChildren().containsKey(nodeToRemove.getId())) {
81+
node.removeChild(nodeToRemove.getId());
82+
break;
83+
}
84+
}
85+
}
3286

3387
}

0 commit comments

Comments
 (0)