Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion base/src/main/java/com/tinyengine/it/common/enums/Enums.java
Original file line number Diff line number Diff line change
Expand Up @@ -788,7 +788,12 @@ public enum FoundationModel {
* Ernie bot turbo foundation model.
*/
// 文心一言
ERNIBOT_TURBO("ERNIE-4.0-8K");
ERNIBOT_TURBO("ERNIE-4.0-8K"),
/**
* Moonshot v1 8k foundation model.
*/
// kimi
MOONSHOT_V1_8K("moonshot-v1-8k");
private final String value;

FoundationModel(String value) {
Expand Down
27 changes: 11 additions & 16 deletions base/src/main/java/com/tinyengine/it/config/AiChatConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,31 +23,26 @@
* @since 2024-10-20
*/
public class AiChatConfig {
private static final String OPENAI_API_URL = System.getenv("OPENAI_API_URL") != null
? System.getenv("OPENAI_API_URL")
: "https://api.openai.com";
private static final String OPENAI_API_KEY = System.getenv("OPENAI_API_KEY");
private static final String OPENAI_API_URL = "https://api.openai.com";
private static final String LOCAL_GPT_API_URL = "https://dashscope.aliyuncs.com/compatible-mode";

private static final String LOCAL_GPT_API_URL = System.getenv("Local_GPT_API_URL") != null
? System.getenv("Local_GPT_API_URL")
: "https://dashscope.aliyuncs.com/compatible-mode";
private static final String LOCAL_GPT_API_KEY = System.getenv("Local_GPT_API_KEY");

private static final String WENXIN_ACCESS_TOKEN = System.getenv("WENXIN_ACCESS_TOKEN");

/**
* Gets AI chat config.
*
* @return the AI chat config
*/
public static Map<String, AiChatConfigData> getAiChatConfig() {
public static Map<String, AiChatConfigData> getAiChatConfig(String model,String token) {
Map<String, AiChatConfigData> config = new HashMap<>();

Map<String, String> openaiHeaders = new HashMap<>();
openaiHeaders.put("Authorization", "Bearer " + OPENAI_API_KEY);
// 根据model值判断添加对应的header
String openAiApiKey = Enums.FoundationModel.GPT_35_TURBO.getValue().equals(model) ? token : null;
openaiHeaders.put("Authorization", "Bearer " + openAiApiKey);

Map<String, String> localGptHeaders = new HashMap<>();
localGptHeaders.put("Authorization", "Bearer " + LOCAL_GPT_API_KEY);
String localGptApiKey = Enums.FoundationModel.LOCAL_GPT.getValue().equals(model) ? token : null;
localGptHeaders.put("Authorization", "Bearer " + localGptApiKey);

Map<String, String> ernieBotHeaders = new HashMap<>();

Expand All @@ -57,12 +52,12 @@ public static Map<String, AiChatConfigData> getAiChatConfig() {
config.put(Enums.FoundationModel.LOCAL_GPT.getValue(), new AiChatConfigData(
LOCAL_GPT_API_URL + "/v1/chat/completions", createCommonRequestOption(), localGptHeaders, "!openai"));

String ernieBotAccessToken = Enums.FoundationModel.ERNIBOT_TURBO.getValue().equals(model) ? token : null;
config.put(Enums.FoundationModel.ERNIBOT_TURBO.getValue(),
new AiChatConfigData(
"https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/eb-instant?access_token="
+ WENXIN_ACCESS_TOKEN,
"https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro?access_token="
+ ernieBotAccessToken,
createCommonRequestOption(), ernieBotHeaders, "baidu"));

return config;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ public Result<PageHistory> deletePageHistory(@PathVariable Integer historyId) {
}

/**
* 删除页面历史记录
* 查询页面历史记录
*
* @param app the app
* @param name the name
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.tinyengine.it.common.enums.Enums;
import com.tinyengine.it.common.exception.ServiceException;
import com.tinyengine.it.config.AiChatConfig;
import com.tinyengine.it.model.dto.AiParam;
import com.tinyengine.it.model.dto.OpenAiBodyDto;

import lombok.extern.slf4j.Slf4j;
Expand All @@ -43,8 +45,8 @@ public class AiChatClient {
/**
* Instantiates a new Ai chat client.
*/
public AiChatClient() {
this.config = AiChatConfig.getAiChatConfig();
public AiChatClient(String model,String token) {
this.config = AiChatConfig.getAiChatConfig(model,token);
// Optional: Default base URL
this.webClient = WebClient.builder().baseUrl("https://default.api.url").build();
}
Expand All @@ -55,10 +57,10 @@ public AiChatClient() {
* @param openAiBodyDto the open AI body dto
* @return the map
*/
public Map<String, Object> executeChatRequest(OpenAiBodyDto openAiBodyDto) {
AiChatConfig.AiChatConfigData configData = config.get(openAiBodyDto.getModel());
public Map<String, Object> executeChatRequest(AiParam openAiBodyDto) {
AiChatConfig.AiChatConfigData configData = config.get(openAiBodyDto.getFoundationModel().get("model"));
if (configData == null) {
log.error("No configuration found for model: " + openAiBodyDto.getModel());
log.error("No configuration found for model: " + openAiBodyDto.getFoundationModel().get("model"));
return Collections.emptyMap();
}

Expand Down
9 changes: 9 additions & 0 deletions base/src/main/java/com/tinyengine/it/model/dto/AiParam.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,13 @@
public class AiParam {
private Map<String, String> foundationModel;
private List<AiMessages> messages;

/**
 * Creates a fully populated AI request parameter object.
 *
 * @param foundationModel model-selection map; callers in this change read the
 *                        keys "model" and "token" from it
 * @param messages        the chat messages to send to the AI service
 */
public AiParam(Map<String, String> foundationModel, List<AiMessages> messages) {
this.foundationModel = foundationModel;
this.messages = messages;
}

/**
 * No-argument constructor for frameworks (e.g. JSON deserialization) that
 * instantiate the object reflectively and populate fields afterwards.
 */
public AiParam() {

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,6 @@ public class AiChatServiceImpl implements AiChatService {
private static final Pattern PATTERN_TAG_END = Pattern.compile("```|</template>|</script>|</style>");
private static final Pattern PATTERN_MESSAGE = Pattern.compile(".*编码时遵从以下几条要求.*");

private AiChatClient aiChatClient = new AiChatClient();

/**
* Get start and end int [ ].
*
Expand Down Expand Up @@ -78,14 +76,30 @@ public static int[] getStartAndEnd(String str) {
@SystemServiceLog(description = "getAnswerFromAi 获取ai回答")
@Override
public Result<Map<String, Object>> getAnswerFromAi(AiParam aiParam) {
String token = aiParam.getFoundationModel().get("token");
if (token == null || token.isEmpty()) {
return Result.failed("The token cannot be empty");
}
if (!Pattern.matches("^[A-Za-z0-9_.-]+$", token)) {
return Result.failed("Invalid token format");
}

if (aiParam.getMessages().isEmpty()) {
return Result.failed("Not passing the correct message parameter");
}
Map<String, String> foundationModel = aiParam.getFoundationModel();
String model = aiParam.getFoundationModel().get("model");
if (aiParam.getFoundationModel().get("model").isEmpty()) {
model = Enums.FoundationModel.GPT_35_TURBO.getValue();
}
Map<String, Object> data = requestAnswerFromAi(aiParam.getMessages(), model).getData();
foundationModel.put("model",model);
aiParam.setFoundationModel(foundationModel);
Result<Map<String, Object>> resultData = requestAnswerFromAi(aiParam.getMessages(), aiParam.getFoundationModel());
// 调用接口失败时且data为null
if(!resultData.isSuccess() && resultData.getData() == null){
return Result.failed(resultData.getCode(),resultData.getMessage());
}
Map<String, Object> data = resultData.getData();
if (data.isEmpty()) {
return Result.failed("调用AI大模型接口未返回正确数据");
}
Expand Down Expand Up @@ -119,7 +133,7 @@ public Result<Map<String, Object>> getAnswerFromAi(AiParam aiParam) {

// 再次请求AI
try {
data = requestAnswerFromAi(aiParam.getMessages(), model).getData();
data = requestAnswerFromAi(aiParam.getMessages(), aiParam.getFoundationModel()).getData();
} catch (Exception e) {
throw new ServiceException(ExceptionEnum.CM001.getResultCode(), ExceptionEnum.CM001.getResultMsg());
}
Expand Down Expand Up @@ -151,21 +165,24 @@ public Result<Map<String, Object>> getAnswerFromAi(AiParam aiParam) {
return Result.success(result);
}

private Result<Map<String, Object>> requestAnswerFromAi(List<AiMessages> messages, String model) {
private Result<Map<String, Object>> requestAnswerFromAi(List<AiMessages> messages, Map<String, String> foundationModel) {
List<AiMessages> aiMessages = formatMessage(messages);

OpenAiBodyDto openAiBodyDto = new OpenAiBodyDto(model, aiMessages);
Map<String, Object> response = aiChatClient.executeChatRequest(openAiBodyDto);
AiParam aiParam = new AiParam(foundationModel,aiMessages);
AiChatClient aiChatClient = new AiChatClient(foundationModel.get("model"), foundationModel.get("token"));
Map<String, Object> response = aiChatClient.executeChatRequest(aiParam);
// 适配文心一言的响应数据结构,文心的部分异常情况status也是200,需要转为400,以免前端无所适从
if (response.get("error_code") != null) {
return Result.failed(response.get("error_msg").toString());
String code = response.get("error_code").toString();
String message = response.get("error_msg").toString();
return Result.failed(code,message);
}
if (response.get("error") != null) {
String code = (response.get("code") != null) ? response.get("code").toString() : "";
String message = (response.get("message") != null) ? response.get("message").toString() : "";
return Result.failed(code, message);
}
if (Enums.FoundationModel.ERNIBOT_TURBO.getValue().equals(model)) {
if (Enums.FoundationModel.ERNIBOT_TURBO.getValue().equals(foundationModel.get("model"))) {
return modelResultConvet(response);
}
return Result.success(response);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import static org.mockito.Mockito.when;

import com.tinyengine.it.common.base.Result;
import com.tinyengine.it.model.dto.AiMessages;
import com.tinyengine.it.model.dto.AiParam;
import com.tinyengine.it.service.app.AiChatService;

Expand All @@ -25,6 +26,7 @@
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

import com.tinyengine.it.config.AiChatConfig;
import com.tinyengine.it.model.dto.AiMessages;
import com.tinyengine.it.model.dto.OpenAiBodyDto;
import com.tinyengine.it.model.dto.AiParam;

import reactor.core.publisher.Mono;

Expand All @@ -34,6 +34,7 @@
import org.springframework.http.HttpMethod;
import org.springframework.web.reactive.function.client.WebClient;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
Expand Down Expand Up @@ -65,7 +66,7 @@ void testExecuteChatRequest() {
put("headers", "headers");
}
};
String modelName = "gpt-3.5-turbo";
String modelName = "ERNIE-4.0-8K";
AiChatConfig.HttpRequestOption option =
new AiChatConfig.HttpRequestOption("POST", "json", "json", 100);
AiChatConfig.AiChatConfigData configData =
Expand All @@ -84,11 +85,32 @@ void testExecuteChatRequest() {
Mockito.mock(WebClient.RequestHeadersSpec.class, RETURNS_DEEP_STUBS);

when(bodyUriSpec.uri(anyString())).thenReturn(bodyUriSpec);
HashMap<String, String> foundationModel = new HashMap<>();
foundationModel.put("model", "ERNIE-4.0-8K");
foundationModel.put("token","asdf");
ArrayList<AiMessages> messages = new ArrayList<>();
AiMessages aiMessages = new AiMessages();
aiMessages.setContent("dddd编码时遵从以下几条要求aaa");
aiMessages.setName("John");
aiMessages.setRole("user");
messages.add(aiMessages);
AiParam param = new AiParam(foundationModel,Arrays.asList(aiMessages));
Map<String, Object> returnData = aiChatClient.executeChatRequest(param);
Assertions.assertNull(returnData.get("data"));
}

@Test
void testInvalidTokenExecuteChatRequest() {
    // A token containing characters outside [A-Za-z0-9_.-] is expected to be
    // rejected by the chat client.
    HashMap<String, String> foundationModel = new HashMap<>();
    foundationModel.put("model", "gpt-3.5-turbo");
    foundationModel.put("token", "你好");
    ArrayList<AiMessages> messages = new ArrayList<>();
    AiMessages aiMessages = new AiMessages();
    aiMessages.setContent("dddd编码时遵从以下几条要求aaa");
    messages.add(aiMessages);
    // Pass the list we actually built (the original built `messages` and then
    // ignored it, passing a fresh Arrays.asList instead — dead code).
    AiParam param = new AiParam(foundationModel, messages);
    Map<String, Object> returnData = aiChatClient.executeChatRequest(param);
    // BUG FIX: the original asserted assertNull(returnData) and then
    // dereferenced returnData on the very next line — mutually exclusive
    // assertions (a guaranteed NPE if the first one passed). Assert non-null,
    // then verify the error payload.
    Assertions.assertNotNull(returnData);
    Assertions.assertEquals("Invalid token format", returnData.get("error_message"));
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ void testGetAnswerFromAi() {
aiParam.setMessages(messages);
HashMap<String, String> foundationModel = new HashMap<>();
foundationModel.put("model", "ERNIE-4.0-8K");
foundationModel.put("token","asdf");
aiParam.setFoundationModel(foundationModel);

Map<String, Object> dataMap = new HashMap<>();
Expand All @@ -88,11 +89,9 @@ void testGetAnswerFromAi() {
Map<String, Object> response = new HashMap<>();
response.put("data", dataMap);
response.put("result", (String) "<template>str</template>");
when(aiChatClient.executeChatRequest(any(OpenAiBodyDto.class))).thenReturn(response);
when(aiChatClient.executeChatRequest(any(AiParam.class))).thenReturn(response);
Result<Map<String, Object>> result = aiChatServiceImpl.getAnswerFromAi(aiParam);
Map<String, Object> resultData = result.getData();

Assertions.assertEquals("<template><代码在画布中展示></template>", resultData.get("replyWithoutCode"));
Assertions.assertEquals("Access token invalid or no longer valid", result.getMessage());
}

@Test
Expand Down
Loading