11package io .github .timemachinelab .controller ;
22
3- import io .github .timemachinelab .core .qatree .QaTree ;
4- import io .github .timemachinelab .core .qatree .QaTreeDomain ;
5- import io .github .timemachinelab .core .session .application .ConversationService ;
63import io .github .timemachinelab .core .session .application .MessageProcessingService ;
74import io .github .timemachinelab .core .session .application .SessionManagementService ;
5+ import io .github .timemachinelab .core .session .application .SseNotificationService ;
86import io .github .timemachinelab .core .session .domain .entity .ConversationSession ;
9- import io .github .timemachinelab .core .session .infrastructure .ai .QuestionGenerationOperation ;
107import io .github .timemachinelab .core .session .infrastructure .web .dto .UnifiedAnswerRequest ;
11- import io .github .timemachinelab .core .session .infrastructure .web .dto .MessageResponse ;
128import io .github .timemachinelab .entity .req .RetryRequest ;
139import io .github .timemachinelab .entity .resp .ApiResult ;
1410import io .github .timemachinelab .entity .resp .RetryResponse ;
1511import lombok .extern .slf4j .Slf4j ;
16- import org .springframework .beans .factory .annotation .Autowired ;
1712import org .springframework .http .MediaType ;
1813import org .springframework .http .ResponseEntity ;
1914import org .springframework .validation .annotation .Validated ;
2318import javax .annotation .Resource ;
2419import javax .validation .Valid ;
2520import java .io .IOException ;
26- import java .util .HashMap ;
2721import java .util .Map ;
2822import java .util .UUID ;
2923import java .util .concurrent .ConcurrentHashMap ;
4034@ RequestMapping ("/api/user-interaction" )
4135@ Validated
4236public class UserInteractionController {
43-
44- @ Resource
45- private ConversationService conversationService ;
4637 @ Resource
4738 private MessageProcessingService messageProcessingService ;
4839 @ Resource
4940 private SessionManagementService sessionManagementService ;
50- private final Map <String , SseEmitter > sseEmitters = new ConcurrentHashMap <>();
51- @ Autowired
52- private QaTreeDomain qaTreeDomain ;
41+ @ Resource
42+ private SseNotificationService sseNotificationService ;
5343
5444 /**
5545 * 建立SSE连接
@@ -60,7 +50,7 @@ public SseEmitter streamConversation(@RequestParam(required = false) String sess
6050 log .info ("建立SSE连接 - 会话ID: {}, 用户ID: {}" , sessionId , userId );
6151
6252 boolean isNewSession = false ;
63- ConversationSession session = null ;
53+ ConversationSession session ;
6454
6555 try {
6656 if (sessionId == null || sessionId .isEmpty ()) {
@@ -82,7 +72,7 @@ public SseEmitter streamConversation(@RequestParam(required = false) String sess
8272 }
8373
8474 SseEmitter emitter = new SseEmitter (Long .MAX_VALUE );
85- sseEmitters . put (sessionId , emitter );
75+ sseNotificationService . registerSseConnection (sessionId , emitter );
8676
8777 // 连接建立时发送会话信息
8878 Map <String , Object > connectionData = new ConcurrentHashMap <>();
@@ -115,9 +105,7 @@ public SseEmitter streamConversation(@RequestParam(required = false) String sess
115105 log .info ("兜底返回根节点ID: 1 - 会话: {}" , sessionId );
116106 }
117107
118- emitter .send (SseEmitter .event ()
119- .name ("connected" )
120- .data (connectionData ));
108+ sseNotificationService .sendWelcomeMessage (sessionId , connectionData );
121109
122110 // 设置连接事件处理
123111 String finalSessionId = sessionId ;
@@ -127,12 +115,12 @@ public SseEmitter streamConversation(@RequestParam(required = false) String sess
127115
128116 emitter .onTimeout (() -> {
129117 log .info ("SSE连接超时: {}" , finalSessionId );
130- sseEmitters . remove (finalSessionId );
118+ sseNotificationService . removeSseConnection (finalSessionId );
131119 });
132120
133121 emitter .onError ((ex ) -> {
134122 log .error ("SSE连接错误: {} - {}" , finalSessionId , ex .getMessage ());
135- sseEmitters . remove (finalSessionId );
123+ sseNotificationService . removeSseConnection (finalSessionId );
136124 });
137125
138126 return emitter ;
@@ -163,8 +151,45 @@ public ResponseEntity<ApiResult<RetryResponse>> retry(@Valid @RequestBody RetryR
163151 log .info ("收到重试请求 - nodeId: {}, sessionId: {}, whyretry: {}" ,
164152 request .getNodeId (), request .getSessionId (), request .getWhyretry ());
165153
154+ // 使用应用服务验证节点存在性
155+ //todo: 有可能水平越权 不传userId的话
156+ if (!sessionManagementService .validateNodeExists (request .getSessionId (), request .getNodeId ())) {
157+ log .warn ("节点不存在 - nodeId: {}, sessionId: {}" , request .getNodeId (), request .getSessionId ());
158+ return ResponseEntity .badRequest ().body (ApiResult .error ("指定的节点不存在" ));
159+ }
160+
161+ // 使用应用服务获取问题内容
162+ String question = sessionManagementService .getNodeQuestion (request .getSessionId (), request .getNodeId ());
163+ if (question == null ) {
164+ log .warn ("节点问题内容为空 - nodeId: {}, sessionId: {}" , request .getNodeId (), request .getSessionId ());
165+ return ResponseEntity .badRequest ().body (ApiResult .error ("节点问题内容为空" ));
166+ }
166167
167-
168+ // 获取会话对象
169+ ConversationSession session = sessionManagementService .getSessionById (request .getSessionId ());
170+ if (session == null ) {
171+ log .warn ("会话不存在 - sessionId: {}" , request .getSessionId ());
172+ return ResponseEntity .badRequest ().body (ApiResult .error ("会话不存在" ));
173+ }
174+
175+ // 移除要重试的节点(AI会基于parentId重新创建节点)
176+ boolean nodeRemoved = sessionManagementService .removeNode (request .getSessionId (), request .getNodeId ());
177+ if (!nodeRemoved ) {
178+ log .warn ("移除节点失败,但继续处理重试 - sessionId: {}, nodeId: {}" ,
179+ request .getSessionId (), request .getNodeId ());
180+ }
181+
182+ // 使用MessageProcessingService处理重试消息
183+ String processedMessage = messageProcessingService .processRetryMessage (
184+ request .getSessionId (),
185+ request .getNodeId (),
186+ request .getWhyretry (),
187+ session
188+ );
189+
190+ // 发送处理后的消息给AI服务
191+ messageProcessingService .processAndSendMessage (session , processedMessage );
192+
168193 // 构建响应数据
169194 RetryResponse response = RetryResponse .builder ()
170195 .nodeId (request .getNodeId ())
@@ -196,12 +221,6 @@ public ResponseEntity<String> processAnswer(@Validated @RequestBody UnifiedAnswe
196221 request .getNodeId (),
197222 request .getQuestionType ());
198223
199- // 1. 强制要求sessionId
200- if (request .getSessionId () == null || request .getSessionId ().trim ().isEmpty ()) {
201- log .warn ("缺少必需的sessionId参数" );
202- return ResponseEntity .badRequest ().body ("sessionId参数是必需的" );
203- }
204-
205224 // 2. 会话管理和验证
206225 String userId = request .getUserId ();
207226 if (userId == null || userId .trim ().isEmpty ()) {
@@ -248,28 +267,7 @@ public ResponseEntity<String> processAnswer(@Validated @RequestBody UnifiedAnswe
248267 return ResponseEntity .badRequest ().body ("答案格式不正确" );
249268 }
250269
251- QaTree qaTree = session .getQaTree ();
252-
253- // 根据问题类型获取正确的答案数据
254- Object answerData ;
255- switch (request .getQuestionType ().toLowerCase ()) {
256- case "input" :
257- answerData = request .getInputAnswer ();
258- break ;
259- case "single" :
260- case "multi" :
261- answerData = request .getChoiceAnswer ();
262- break ;
263- case "form" :
264- answerData = request .getFormAnswer ();
265- break ;
266- default :
267- log .warn ("未知的问题类型: {}" , request .getQuestionType ());
268- answerData = request .getAnswerString ();
269- break ;
270- }
271-
272- qaTreeDomain .updateNodeAnswer (qaTree , request .getNodeId (), answerData );
270+ // 答案更新逻辑已在MessageProcessingService中处理
273271
274272 // 4. 处理答案并转换为消息
275273 String processedMessage = messageProcessingService .preprocessMessage (
@@ -279,11 +277,7 @@ public ResponseEntity<String> processAnswer(@Validated @RequestBody UnifiedAnswe
279277 );
280278
281279 // 5. 发送处理后的消息给AI服务
282- conversationService .processUserMessage (
283- session .getUserId (),
284- processedMessage ,
285- response -> sendSseMessage (session .getSessionId (), response )
286- );
280+ messageProcessingService .processAndSendMessage (session , processedMessage );
287281
288282
289283 return ResponseEntity .ok ("答案处理成功" );
@@ -294,89 +288,11 @@ public ResponseEntity<String> processAnswer(@Validated @RequestBody UnifiedAnswe
294288 }
295289 }
296290
297- /**
298- * 通过SSE发送消息给客户端
299- * 在AI回复时创建QA节点,填入question,answer留空等用户提交后再更新
300- *
301- * @param sessionId 会话ID
302- * @param response 消息响应对象
303- */
304- private void sendSseMessage (String sessionId , QuestionGenerationOperation .QuestionGenerationResponse response ) {
305- SseEmitter emitter = sseEmitters .get (sessionId );
306- if (emitter != null ) {
307- try {
308- String currentNodeId = null ;
309-
310- // 1. 先将AI生成的新问题添加到QaTree(只填入question,answer留空)
311- ConversationSession session = sessionManagementService .getSessionById (sessionId );
312- if (session != null && session .getQaTree () != null && response .getQuestion () != null ) {
313- // 使用QaTreeDomain添加新节点,answer字段会自动为空
314- // appendNode方法内部会调用session.getNextNodeId()获取新节点ID
315- QaTree qaTree = qaTreeDomain .appendNode (
316- session .getQaTree (),
317- response .getParentId (),
318- response .getQuestion (),
319- session
320- );
321-
322- // 获取刚刚创建的节点ID(当前计数器的值)
323- currentNodeId = String .valueOf (session .getNodeIdCounter ().get ());
324-
325- log .info ("AI问题已添加到QaTree - 会话: {}, 父节点: {}, 新节点ID: {}, 问题类型: {}" ,
326- sessionId , response .getParentId (), currentNodeId , response .getQuestion ().getType ());
327- } else {
328- log .warn ("无法添加问题到QaTree - 会话: {}, session存在: {}, qaTree存在: {}, question存在: {}" ,
329- sessionId , session != null ,
330- session != null && session .getQaTree () != null ,
331- response .getQuestion () != null );
332- }
333-
334- // 2. 创建修改后的响应对象,包含currentNodeId和parentNodeId
335- Map <String , Object > modifiedResponse = new HashMap <>();
336- modifiedResponse .put ("question" , response .getQuestion ());
337- modifiedResponse .put ("currentNodeId" , currentNodeId != null ? currentNodeId : response .getParentId ());
338- modifiedResponse .put ("parentNodeId" , response .getParentId ());
339-
340- // 3. 发送SSE消息给前端
341- emitter .send (SseEmitter .event ()
342- .name ("message" )
343- .data (modifiedResponse ));
344- log .info ("SSE消息发送成功 - 会话: {}, 当前节点ID: {}" , sessionId , currentNodeId );
345- } catch (IOException e ) {
346- log .error ("SSE消息发送失败 - 会话: {}, 错误: {}" , sessionId , e .getMessage ());
347- sseEmitters .remove (sessionId );
348- } catch (Exception e ) {
349- log .error ("添加问题到QaTree失败 - 会话: {}, 错误: {}" , sessionId , e .getMessage ());
350- // 即使QaTree更新失败,仍然发送SSE消息给前端
351- try {
352- Map <String , Object > fallbackResponse = new HashMap <>();
353- fallbackResponse .put ("question" , response .getQuestion ());
354- fallbackResponse .put ("currentNodeId" , response .getParentId ()); // 使用parentId作为fallback
355- fallbackResponse .put ("parentNodeId" , response .getParentId ());
356-
357- emitter .send (SseEmitter .event ()
358- .name ("message" )
359- .data (fallbackResponse ));
360- log .info ("SSE消息发送成功(QaTree更新失败但消息已发送) - 会话: {}" , sessionId );
361- } catch (IOException ioException ) {
362- log .error ("SSE消息发送失败 - 会话: {}, 错误: {}" , sessionId , ioException .getMessage ());
363- sseEmitters .remove (sessionId );
364- }
365- }
366- } else {
367- log .warn ("SSE连接不存在 - 会话: {}" , sessionId );
368- }
369- }
370-
371291 /**
372292 * 获取SSE连接状态
373293 */
374294 @ GetMapping ("/sse-status" )
375295 public ResponseEntity <Map <String , Object >> getSseStatus () {
376- Map <String , Object > status = new ConcurrentHashMap <>();
377- status .put ("connectedSessions" , sseEmitters .keySet ());
378- status .put ("totalConnections" , sseEmitters .size ());
379- status .put ("timestamp" , System .currentTimeMillis ());
380- return ResponseEntity .ok (status );
296+ return ResponseEntity .ok (sseNotificationService .getSseStatus ());
381297 }
382298}
0 commit comments