diff --git a/base/src/main/java/com/tinyengine/it/common/enums/Enums.java b/base/src/main/java/com/tinyengine/it/common/enums/Enums.java index 3a2f538f..2489eb00 100644 --- a/base/src/main/java/com/tinyengine/it/common/enums/Enums.java +++ b/base/src/main/java/com/tinyengine/it/common/enums/Enums.java @@ -788,7 +788,12 @@ public enum FoundationModel { * Ernie bot turbo e foundation model. */ // 文心一言 - ERNIBOT_TURBO("ERNIE-4.0-8K"); + ERNIBOT_TURBO("ERNIE-4.0-8K"), + /** + * Moonshot v1 8k e foundation model. + */ + // kimi + MOONSHOT_V1_8K("moonshot-v1-8k"); private final String value; FoundationModel(String value) { diff --git a/base/src/main/java/com/tinyengine/it/config/AiChatConfig.java b/base/src/main/java/com/tinyengine/it/config/AiChatConfig.java index 07c14459..14ee6573 100644 --- a/base/src/main/java/com/tinyengine/it/config/AiChatConfig.java +++ b/base/src/main/java/com/tinyengine/it/config/AiChatConfig.java @@ -23,31 +23,26 @@ * @since 2024-10-20 */ public class AiChatConfig { - private static final String OPENAI_API_URL = System.getenv("OPENAI_API_URL") != null - ? System.getenv("OPENAI_API_URL") - : "https://api.openai.com"; - private static final String OPENAI_API_KEY = System.getenv("OPENAI_API_KEY"); + private static final String OPENAI_API_URL = "https://api.openai.com"; + private static final String LOCAL_GPT_API_URL = "https://dashscope.aliyuncs.com/compatible-mode"; - private static final String LOCAL_GPT_API_URL = System.getenv("Local_GPT_API_URL") != null - ? System.getenv("Local_GPT_API_URL") - : "https://dashscope.aliyuncs.com/compatible-mode"; - private static final String LOCAL_GPT_API_KEY = System.getenv("Local_GPT_API_KEY"); - - private static final String WENXIN_ACCESS_TOKEN = System.getenv("WENXIN_ACCESS_TOKEN"); /** * Gets AI chat config. * * @return the AI chat config */ - public static Map getAiChatConfig() { + public static Map getAiChatConfig(String model,String token) { Map config = new HashMap<>(); Map openaiHeaders = new HashMap<>(); - openaiHeaders.put("Authorization", "Bearer " + OPENAI_API_KEY); + // 根据model值判断添加对应的header + String openAiApiKey = Enums.FoundationModel.GPT_35_TURBO.getValue().equals(model) ? token : null; + openaiHeaders.put("Authorization", "Bearer " + openAiApiKey); Map localGptHeaders = new HashMap<>(); - localGptHeaders.put("Authorization", "Bearer " + LOCAL_GPT_API_KEY); + String localGptApiKey = Enums.FoundationModel.LOCAL_GPT.getValue().equals(model) ? token : null; + localGptHeaders.put("Authorization", "Bearer " + localGptApiKey); Map ernieBotHeaders = new HashMap<>(); @@ -57,12 +52,12 @@ public static Map getAiChatConfig() { config.put(Enums.FoundationModel.LOCAL_GPT.getValue(), new AiChatConfigData( LOCAL_GPT_API_URL + "/v1/chat/completions", createCommonRequestOption(), localGptHeaders, "!openai")); + String ernieBotAccessToken = Enums.FoundationModel.ERNIBOT_TURBO.getValue().equals(model) ? token : null; config.put(Enums.FoundationModel.ERNIBOT_TURBO.getValue(), new AiChatConfigData( - "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant?access_token=" - + WENXIN_ACCESS_TOKEN, + "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro?access_token=" + + ernieBotAccessToken, createCommonRequestOption(), ernieBotHeaders, "baidu")); - return config; } diff --git a/base/src/main/java/com/tinyengine/it/controller/PageHistoryController.java b/base/src/main/java/com/tinyengine/it/controller/PageHistoryController.java index a7c58301..018bbed4 100644 --- a/base/src/main/java/com/tinyengine/it/controller/PageHistoryController.java +++ b/base/src/main/java/com/tinyengine/it/controller/PageHistoryController.java @@ -167,7 +167,7 @@ public Result deletePageHistory(@PathVariable Integer historyId) { } /** - * 删除页面历史记录 + * 查询页面历史记录 * * @param app the app * @param name the name diff --git a/base/src/main/java/com/tinyengine/it/gateway/ai/AiChatClient.java b/base/src/main/java/com/tinyengine/it/gateway/ai/AiChatClient.java index 1bea70db..9a3c956f 100644 --- a/base/src/main/java/com/tinyengine/it/gateway/ai/AiChatClient.java +++ b/base/src/main/java/com/tinyengine/it/gateway/ai/AiChatClient.java @@ -17,8 +17,10 @@ import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; +import com.tinyengine.it.common.enums.Enums; import com.tinyengine.it.common.exception.ServiceException; import com.tinyengine.it.config.AiChatConfig; +import com.tinyengine.it.model.dto.AiParam; import com.tinyengine.it.model.dto.OpenAiBodyDto; import lombok.extern.slf4j.Slf4j; @@ -43,8 +45,8 @@ public class AiChatClient { /** * Instantiates a new Ai chat client. */ - public AiChatClient() { - this.config = AiChatConfig.getAiChatConfig(); + public AiChatClient(String model,String token) { + this.config = AiChatConfig.getAiChatConfig(model,token); // Optional: Default base URL this.webClient = WebClient.builder().baseUrl("https://default.api.url").build(); } @@ -55,10 +57,10 @@ public AiChatClient() { * @param openAiBodyDto the open AI body dto * @return the map */ - public Map executeChatRequest(OpenAiBodyDto openAiBodyDto) { - AiChatConfig.AiChatConfigData configData = config.get(openAiBodyDto.getModel()); + public Map executeChatRequest(AiParam openAiBodyDto) { + AiChatConfig.AiChatConfigData configData = config.get(openAiBodyDto.getFoundationModel().get("model")); if (configData == null) { - log.error("No configuration found for model: " + openAiBodyDto.getModel()); + log.error("No configuration found for model: " + openAiBodyDto.getFoundationModel().get("model")); return Collections.emptyMap(); } diff --git a/base/src/main/java/com/tinyengine/it/model/dto/AiParam.java b/base/src/main/java/com/tinyengine/it/model/dto/AiParam.java index 0c4f087f..01168485 100644 --- a/base/src/main/java/com/tinyengine/it/model/dto/AiParam.java +++ b/base/src/main/java/com/tinyengine/it/model/dto/AiParam.java @@ -26,4 +26,13 @@ public class AiParam { private Map foundationModel; private List messages; + + public AiParam(Map foundationModel, List messages) { + this.foundationModel = foundationModel; + this.messages = messages; + } + + public AiParam() { + + } } diff --git a/base/src/main/java/com/tinyengine/it/service/app/impl/AiChatServiceImpl.java b/base/src/main/java/com/tinyengine/it/service/app/impl/AiChatServiceImpl.java index 08191e17..6d06c6fb 100644 --- a/base/src/main/java/com/tinyengine/it/service/app/impl/AiChatServiceImpl.java +++ b/base/src/main/java/com/tinyengine/it/service/app/impl/AiChatServiceImpl.java @@ -46,8 +46,6 @@ public class AiChatServiceImpl implements AiChatService { private static final Pattern PATTERN_TAG_END = Pattern.compile("```|||"); private static final Pattern PATTERN_MESSAGE = Pattern.compile(".*编码时遵从以下几条要求.*"); - private AiChatClient aiChatClient = new AiChatClient(); - /** * Get start and end int [ ]. * @@ -78,14 +76,30 @@ public static int[] getStartAndEnd(String str) { @SystemServiceLog(description = "getAnswerFromAi 获取ai回答") @Override public Result> getAnswerFromAi(AiParam aiParam) { + String token = aiParam.getFoundationModel().get("token"); + if (token == null || token.isEmpty()) { + return Result.failed("The token cannot be empty"); + } + if (!Pattern.matches("^[A-Za-z0-9_.-]+$", token)) { + return Result.failed("Invalid token format"); + } + if (aiParam.getMessages().isEmpty()) { return Result.failed("Not passing the correct message parameter"); } + Map foundationModel = aiParam.getFoundationModel(); String model = aiParam.getFoundationModel().get("model"); if (aiParam.getFoundationModel().get("model").isEmpty()) { model = Enums.FoundationModel.GPT_35_TURBO.getValue(); } - Map data = requestAnswerFromAi(aiParam.getMessages(), model).getData(); + foundationModel.put("model",model); + aiParam.setFoundationModel(foundationModel); + Result> resultData = requestAnswerFromAi(aiParam.getMessages(), aiParam.getFoundationModel()); + // 调用接口失败时且data为null + if(!resultData.isSuccess() && resultData.getData() == null){ + return Result.failed(resultData.getCode(),resultData.getMessage()); + } + Map data = resultData.getData(); if (data.isEmpty()) { return Result.failed("调用AI大模型接口未返回正确数据"); } @@ -119,7 +133,7 @@ public Result> getAnswerFromAi(AiParam aiParam) { // 再次请求AI try { - data = requestAnswerFromAi(aiParam.getMessages(), model).getData(); + data = requestAnswerFromAi(aiParam.getMessages(), aiParam.getFoundationModel()).getData(); } catch (Exception e) { throw new ServiceException(ExceptionEnum.CM001.getResultCode(), ExceptionEnum.CM001.getResultMsg()); } @@ -151,21 +165,24 @@ public Result> getAnswerFromAi(AiParam aiParam) { return Result.success(result); } - private Result> requestAnswerFromAi(List messages, String model) { + private Result> requestAnswerFromAi(List messages, Map foundationModel) { List aiMessages = formatMessage(messages); - OpenAiBodyDto openAiBodyDto = new OpenAiBodyDto(model, aiMessages); - Map response = aiChatClient.executeChatRequest(openAiBodyDto); + AiParam aiParam = new AiParam(foundationModel,aiMessages); + AiChatClient aiChatClient = new AiChatClient(foundationModel.get("model"), foundationModel.get("token")); + Map response = aiChatClient.executeChatRequest(aiParam); // 适配文心一言的响应数据结构,文心的部分异常情况status也是200,需要转为400,以免前端无所适从 if (response.get("error_code") != null) { - return Result.failed(response.get("error_msg").toString()); + String code = response.get("error_code").toString(); + String message = response.get("error_msg").toString(); + return Result.failed(code,message); } if (response.get("error") != null) { String code = (response.get("code") != null) ? response.get("code").toString() : ""; String message = (response.get("message") != null) ? response.get("message").toString() : ""; return Result.failed(code, message); } - if (Enums.FoundationModel.ERNIBOT_TURBO.getValue().equals(model)) { + if (Enums.FoundationModel.ERNIBOT_TURBO.getValue().equals(foundationModel.get("model"))) { return modelResultConvet(response); } return Result.success(response); diff --git a/base/src/test/java/com/tinyengine/it/controller/AiChatControllerTest.java b/base/src/test/java/com/tinyengine/it/controller/AiChatControllerTest.java index 165aba00..3db4aac8 100644 --- a/base/src/test/java/com/tinyengine/it/controller/AiChatControllerTest.java +++ b/base/src/test/java/com/tinyengine/it/controller/AiChatControllerTest.java @@ -15,6 +15,7 @@ import static org.mockito.Mockito.when; import com.tinyengine.it.common.base.Result; +import com.tinyengine.it.model.dto.AiMessages; import com.tinyengine.it.model.dto.AiParam; import com.tinyengine.it.service.app.AiChatService; @@ -25,6 +26,7 @@ import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import java.util.ArrayList; import java.util.HashMap; import java.util.Map; diff --git a/base/src/test/java/com/tinyengine/it/gateway/ai/AiChatClientTest.java b/base/src/test/java/com/tinyengine/it/gateway/ai/AiChatClientTest.java index 68f88e1b..3f0c71fb 100644 --- a/base/src/test/java/com/tinyengine/it/gateway/ai/AiChatClientTest.java +++ b/base/src/test/java/com/tinyengine/it/gateway/ai/AiChatClientTest.java @@ -19,7 +19,7 @@ import com.tinyengine.it.config.AiChatConfig; import com.tinyengine.it.model.dto.AiMessages; -import com.tinyengine.it.model.dto.OpenAiBodyDto; +import com.tinyengine.it.model.dto.AiParam; import reactor.core.publisher.Mono; @@ -34,6 +34,7 @@ import org.springframework.http.HttpMethod; import org.springframework.web.reactive.function.client.WebClient; +import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import java.util.Map; @@ -65,7 +66,7 @@ void testExecuteChatRequest() { put("headers", "headers"); } }; - String modelName = "gpt-3.5-turbo"; + String modelName = "ERNIE-4.0-8K"; AiChatConfig.HttpRequestOption option = new AiChatConfig.HttpRequestOption("POST", "json", "json", 100); AiChatConfig.AiChatConfigData configData = @@ -84,11 +85,32 @@ void testExecuteChatRequest() { Mockito.mock(WebClient.RequestHeadersSpec.class, RETURNS_DEEP_STUBS); when(bodyUriSpec.uri(anyString())).thenReturn(bodyUriSpec); + HashMap foundationModel = new HashMap<>(); + foundationModel.put("model", "ERNIE-4.0-8K"); + foundationModel.put("token","asdf"); + ArrayList messages = new ArrayList<>(); + AiMessages aiMessages = new AiMessages(); + aiMessages.setContent("dddd编码时遵从以下几条要求aaa"); + aiMessages.setName("John"); + aiMessages.setRole("user"); + messages.add(aiMessages); + AiParam param = new AiParam(foundationModel,Arrays.asList(aiMessages)); + Map returnData = aiChatClient.executeChatRequest(param); + Assertions.assertNull(returnData.get("data")); + } + @Test + void testInvalidTokenExecuteChatRequest() { + HashMap foundationModel = new HashMap<>(); + foundationModel.put("model", "gpt-3.5-turbo"); + foundationModel.put("token","你好"); + ArrayList messages = new ArrayList<>(); AiMessages aiMessages = new AiMessages(); - OpenAiBodyDto param = new OpenAiBodyDto(modelName, Arrays.asList(aiMessages)); + aiMessages.setContent("dddd编码时遵从以下几条要求aaa"); + messages.add(aiMessages); + AiParam param = new AiParam(foundationModel,Arrays.asList(aiMessages)); Map returnData = aiChatClient.executeChatRequest(param); - Assertions.assertNull(returnData); + Assertions.assertEquals("Invalid token format",returnData.get("error_message")); } } diff --git a/base/src/test/java/com/tinyengine/it/service/app/impl/AiChatServiceImplTest.java b/base/src/test/java/com/tinyengine/it/service/app/impl/AiChatServiceImplTest.java index 76ad7a7b..f4dd2621 100644 --- a/base/src/test/java/com/tinyengine/it/service/app/impl/AiChatServiceImplTest.java +++ b/base/src/test/java/com/tinyengine/it/service/app/impl/AiChatServiceImplTest.java @@ -69,6 +69,7 @@ void testGetAnswerFromAi() { aiParam.setMessages(messages); HashMap foundationModel = new HashMap<>(); foundationModel.put("model", "ERNIE-4.0-8K"); + foundationModel.put("token","asdf"); aiParam.setFoundationModel(foundationModel); Map dataMap = new HashMap<>(); @@ -88,11 +89,9 @@ void testGetAnswerFromAi() { Map response = new HashMap<>(); response.put("data", dataMap); response.put("result", (String) ""); - when(aiChatClient.executeChatRequest(any(OpenAiBodyDto.class))).thenReturn(response); + when(aiChatClient.executeChatRequest(any(AiParam.class))).thenReturn(response); Result> result = aiChatServiceImpl.getAnswerFromAi(aiParam); - Map resultData = result.getData(); - - Assertions.assertEquals("", resultData.get("replyWithoutCode")); + Assertions.assertEquals("Access token invalid or no longer valid", result.getMessage()); } @Test