Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 66 additions & 33 deletions src/main/java/com/llmproxy/controller/LlmProxyController.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import org.springframework.http.ResponseEntity;
import org.springframework.web.bind.annotation.*;

import java.nio.charset.StandardCharsets;
import java.time.Instant;
import java.util.Map;
import java.util.UUID;
Expand All @@ -29,6 +30,17 @@
@Slf4j
public class LlmProxyController {
private static final int MAX_QUERY_LENGTH = 32000;
private static final String RATE_LIMIT_ERROR = "Rate limit exceeded. Please try again later.";
private static final String EMPTY_QUERY_ERROR = "Query cannot be empty";
private static final String VALIDATION_ERROR_TYPE = "validation_error";
private static final String RATE_LIMIT_ERROR_TYPE = "rate_limit";
private static final String INTERNAL_ERROR_TYPE = "internal_error";

// Cache MediaType objects
private static final MediaType TEXT_PLAIN = MediaType.TEXT_PLAIN;
private static final MediaType APPLICATION_PDF = MediaType.APPLICATION_PDF;
private static final MediaType APPLICATION_DOCX = MediaType.parseMediaType(
"application/vnd.openxmlformats-officedocument.wordprocessingml.document");

private final RouterService routerService;
private final LlmClientFactory clientFactory;
Expand All @@ -40,29 +52,32 @@ public ResponseEntity<QueryResponse> query(@RequestBody QueryRequest request, Ht
String clientIp = getClientIp(httpRequest);
if (!rateLimiterService.allowClient(clientIp)) {
log.warn("Rate limit exceeded for client: {}", clientIp);
Instant now = Instant.now();
return ResponseEntity.status(HttpStatus.TOO_MANY_REQUESTS)
.body(QueryResponse.builder()
.error("Rate limit exceeded. Please try again later.")
.errorType("rate_limit")
.timestamp(Instant.now())
.error(RATE_LIMIT_ERROR)
.errorType(RATE_LIMIT_ERROR_TYPE)
.timestamp(now)
.build());
}

if (request.getQuery() == null || request.getQuery().isEmpty()) {
Instant now = Instant.now();
return ResponseEntity.badRequest()
.body(QueryResponse.builder()
.error("Query cannot be empty")
.errorType("validation_error")
.timestamp(Instant.now())
.error(EMPTY_QUERY_ERROR)
.errorType(VALIDATION_ERROR_TYPE)
.timestamp(now)
.build());
}

if (request.getQuery().length() > MAX_QUERY_LENGTH) {
Instant now = Instant.now();
return ResponseEntity.badRequest()
.body(QueryResponse.builder()
.error("Query exceeds maximum length of " + MAX_QUERY_LENGTH + " characters")
.errorType("validation_error")
.timestamp(Instant.now())
.errorType(VALIDATION_ERROR_TYPE)
.timestamp(now)
.build());
}

Expand All @@ -89,11 +104,13 @@ public ResponseEntity<QueryResponse> query(@RequestBody QueryRequest request, Ht

QueryResult result = client.query(request.getQuery(), request.getModelVersion());

Instant responseTime = Instant.now();

QueryResponse response = QueryResponse.builder()
.response(result.getResponse())
.model(modelType)
.responseTimeMs(Instant.now().toEpochMilli() - startTime)
.timestamp(Instant.now())
.responseTimeMs(responseTime.toEpochMilli() - startTime)
.timestamp(responseTime)
.cached(false)
.requestId(request.getRequestId())
.inputTokens(result.getInputTokens())
Expand All @@ -119,12 +136,14 @@ public ResponseEntity<QueryResponse> query(@RequestBody QueryRequest request, Ht
LlmClient fallbackClient = clientFactory.getClient(fallbackModel);
QueryResult result = fallbackClient.query(request.getQuery(), request.getModelVersion());

Instant responseTime = Instant.now();

QueryResponse response = QueryResponse.builder()
.response(result.getResponse())
.model(fallbackModel)
.originalModel(ModelType.fromString(e.getModel()))
.responseTimeMs(Instant.now().toEpochMilli() - startTime)
.timestamp(Instant.now())
.responseTimeMs(responseTime.toEpochMilli() - startTime)
.timestamp(responseTime)
.cached(false)
.requestId(request.getRequestId())
.inputTokens(result.getInputTokens())
Expand All @@ -148,31 +167,43 @@ public ResponseEntity<QueryResponse> query(@RequestBody QueryRequest request, Ht

log.error("Error processing query: {}", e.getMessage());

HttpStatus status = switch (e.getStatusCode()) {
case 401 -> HttpStatus.UNAUTHORIZED;
case 408 -> HttpStatus.REQUEST_TIMEOUT;
case 429 -> HttpStatus.TOO_MANY_REQUESTS;
case 503 -> HttpStatus.SERVICE_UNAVAILABLE;
default -> HttpStatus.INTERNAL_SERVER_ERROR;
};
HttpStatus status;
switch (e.getStatusCode()) {
case 401:
status = HttpStatus.UNAUTHORIZED;
break;
case 408:
status = HttpStatus.REQUEST_TIMEOUT;
break;
case 429:
status = HttpStatus.TOO_MANY_REQUESTS;
break;
case 503:
status = HttpStatus.SERVICE_UNAVAILABLE;
break;
default:
status = HttpStatus.INTERNAL_SERVER_ERROR;
}

Instant errorTime = Instant.now();
return ResponseEntity.status(status)
.body(QueryResponse.builder()
.error(e.getMessage())
.errorType(e.getClass().getSimpleName())
.model(ModelType.fromString(e.getModel()))
.timestamp(Instant.now())
.timestamp(errorTime)
.requestId(request.getRequestId())
.build());

} catch (Exception e) {
log.error("Unexpected error processing query: {}", e.getMessage(), e);

Instant errorTime = Instant.now();
return ResponseEntity.status(HttpStatus.INTERNAL_SERVER_ERROR)
.body(QueryResponse.builder()
.error("Internal server error: " + e.getMessage())
.errorType("internal_error")
.timestamp(Instant.now())
.errorType(INTERNAL_ERROR_TYPE)
.timestamp(errorTime)
.requestId(request.getRequestId())
.build());
}
Expand All @@ -198,9 +229,10 @@ public ResponseEntity<Map<String, Object>> health(HttpServletRequest httpRequest
return ResponseEntity.status(HttpStatus.TOO_MANY_REQUESTS).build();
}

Instant now = Instant.now();
return ResponseEntity.ok(Map.of(
"status", "ok",
"timestamp", Instant.now()
"timestamp", now
));
}

Expand All @@ -212,42 +244,43 @@ public ResponseEntity<byte[]> download(@RequestBody Map<String, String> request,
return ResponseEntity.status(HttpStatus.TOO_MANY_REQUESTS).build();
}

String response = request.get("response");
String format = request.get("format");

if (response == null || response.isEmpty()) {
String responseText = request.get("response");
if (responseText == null || responseText.isEmpty()) {
return ResponseEntity.badRequest().build();
}

String format = request.get("format");
if (format == null) {
format = "txt";
}

byte[] responseBytes = responseText.getBytes(StandardCharsets.UTF_8);

MediaType mediaType;
String filename;

switch (format.toLowerCase()) {
case "txt":
mediaType = MediaType.TEXT_PLAIN;
mediaType = TEXT_PLAIN;
filename = "llm_response.txt";
break;
case "pdf":
mediaType = MediaType.APPLICATION_PDF;
mediaType = APPLICATION_PDF;
filename = "llm_response.pdf";
break;
case "docx":
mediaType = MediaType.parseMediaType("application/vnd.openxmlformats-officedocument.wordprocessingml.document");
mediaType = APPLICATION_DOCX;
filename = "llm_response.docx";
break;
default:
return ResponseEntity.badRequest().body(
"Unsupported format. Supported formats are: txt, pdf, docx.".getBytes());
"Unsupported format. Supported formats are: txt, pdf, docx.".getBytes(StandardCharsets.UTF_8));
}

return ResponseEntity.ok()
.contentType(mediaType)
.header("Content-Disposition", "attachment; filename=" + filename)
.body(response.getBytes());
.body(responseBytes);
}

private String getClientIp(HttpServletRequest request) {
Expand All @@ -257,4 +290,4 @@ private String getClientIp(HttpServletRequest request) {
}
return request.getRemoteAddr();
}
}
}