Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
import com.spots.domain.ai.dto.response.WeeklyRecommendResponse;
import com.spots.domain.ai.service.RecommendService;
import com.spots.global.exception.ApiResponse;
import java.util.concurrent.CompletableFuture;
import lombok.RequiredArgsConstructor;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;

import java.util.concurrent.CompletableFuture;

@RestController
@RequestMapping("/api/v1/recommend")
@RequiredArgsConstructor
Expand All @@ -26,4 +27,5 @@ public CompletableFuture<ApiResponse<WeeklyRecommendResponse>> recommendWeeklyWo
.recommendWeeklyRoutine(request.toServiceRequest())
.thenApply(ApiResponse::success);
}

}
41 changes: 30 additions & 11 deletions src/main/java/com/spots/domain/ai/service/RecommendLLMService.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,24 +14,32 @@
import java.io.InputStream;
import java.util.concurrent.Semaphore;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import org.springframework.util.StopWatch;

@Service
@RequiredArgsConstructor
@Transactional(readOnly = true)
@Slf4j
public class RecommendLLMService {

private final ChatClient chatClient;
private final ObjectMapper objectMapper;
private final Semaphore llmSemaphore = new Semaphore(5);

public WeeklyRecommendResponse createWeeklyPlan(RecommendLLMRequest request) {
StopWatch stopWatch = new StopWatch("LLM_Generation_Task");
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

AOP 방식도 있을텐데, StopWatch를 사용하는 이유가 있을까요? 메서드 내 작업단위마다 측정하기 위해서 일까요 !?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. AOP는 메서드 단위로 동작하기 때문에, 메서드 내 어느 부분에서 시간이 오래걸리는지는 확인하기 어렵다.
  2. AOP는 public method에서만 동작하기 때문에 private으로 되어있는 convertToRecommendProgramData 내에서 병목지점을 파악하기 힘들다.

와 같은 이유가 있다고 찾았는데, 혹시 잘못된 내용있으면 수정 부탁드립니다!

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

오호 새로운 사실을 알게 되었습니다 :>


try {
stopWatch.start("1. Semaphore Acquire");
llmSemaphore.acquire();
stopWatch.stop();

String systemMessage = loadPrompt("prompt/routineV2.prompt");
stopWatch.start("2. Prompt & JSON Prep");
String systemMessage = loadPrompt("prompt/routineV4.prompt");
String userMessage = """
아래는 사용자 정보와 후보 운동 프로그램 목록입니다.
이를 기반으로 일주일 운동 루틴 포토카드를 만들 JSON을 생성해주세요.
Expand All @@ -40,22 +48,33 @@ public WeeklyRecommendResponse createWeeklyPlan(RecommendLLMRequest request) {
%s
</user_data>
""".formatted(toJson(request));
stopWatch.stop();

stopWatch.start("3. LLM API Call (External)");
String llmResponse = chatClient
.prompt()
.system(systemMessage)
.user(userMessage)
.call()
.content();
.prompt()
.system(systemMessage)
.user(userMessage)
.call()
.content();
stopWatch.stop();

stopWatch.start("4. Response Parsing");
WeeklyRecommendResponse response = objectMapper.readValue(sanitize(llmResponse), WeeklyRecommendResponse.class);
stopWatch.stop();

log.info(stopWatch.prettyPrint());

return objectMapper.readValue(sanitize(llmResponse), WeeklyRecommendResponse.class);
return response;

} catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new CustomException(LLM_INTERRUPT_ERROR);
} catch (JsonProcessingException e) {
log.error("JSON 파싱 실패. Raw Response: {}", e.getMessage());
throw new CustomException(INVALID_JSON_RESPONSE);
} catch (Exception e) {
log.error("LLM 호출 중 알 수 없는 에러 발생", e);
throw new CustomException(LLM_INTERRUPT_ERROR);
} finally {
llmSemaphore.release();
Expand All @@ -73,16 +92,16 @@ private String loadPrompt(String filename) {

private String toJson(Object obj) {
try {
return objectMapper.writerWithDefaultPrettyPrinter().writeValueAsString(obj);
return objectMapper.writeValueAsString(obj);
} catch (Exception e) {
throw new CustomException(JSON_CONVERSION_ERROR);
}
}

private String sanitize(String raw) {
return raw
.replaceAll("(?i)```json", "")
.replaceAll("```", "")
.trim();
.replaceAll("(?i)```json", "")
.replaceAll("```", "")
.trim();
}
}
117 changes: 85 additions & 32 deletions src/main/java/com/spots/domain/ai/service/RecommendService.java
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,25 @@
import com.spots.domain.program.dto.response.ProgramDetailInfoResponse;
import com.spots.domain.program.dto.response.ProgramInfoResponse;
import com.spots.domain.program.dto.response.TransportData;
import com.spots.domain.program.entity.Program;
import com.spots.domain.program.repository.ProgramRepository;
import com.spots.domain.program.service.ProgramService;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Function;
import java.util.stream.Collectors;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Service;
import org.springframework.transaction.annotation.Transactional;
import org.springframework.util.StopWatch;

@Service
@RequiredArgsConstructor
@Slf4j
public class RecommendService {

private final ProgramService programService;
Expand All @@ -44,7 +52,7 @@ public CompletableFuture<WeeklyRecommendResponse> recommendWeeklyRoutine(
null
)
.getContent();

log.info("programs size: {}", programs.size());
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

기존에 협의했던 코드 포맷인 구글 형식이 아닌 것 같습니다.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

죄송합니다. IDE 설정이 풀려있었습니다. 재설정하여 올리겠습니다.

List<RecommendProgramData> recommendProgramDataList = convertToRecommendProgramData(programs, request);

RecommendLLMRequest llmRequest = RecommendLLMRequest.from(request, recommendProgramDataList);
Expand All @@ -53,37 +61,81 @@ public CompletableFuture<WeeklyRecommendResponse> recommendWeeklyRoutine(
return CompletableFuture.completedFuture(response);
}

private List<RecommendProgramData> convertToRecommendProgramData(
List<ProgramInfoResponse> programs,
UserInfoServiceRequest request
) {
return programs.stream()
.map(programInfoResponse -> {

ProgramDetailInfoResponse programDetailInfoResponse =
programService.getProgram(programInfoResponse.programId());

List<TransportData> transports =
programDetailInfoResponse.transportData();

double distance =
programRepository.findById(programInfoResponse.programId())
.map(program -> calculateDistance(
request.latitude(),
request.longitude(),
program.getFacility().getFcltyLa(),
program.getFacility().getFcltyLo()
))
.orElse(0.0);

return new RecommendProgramData(
programInfoResponse,
transports,
distance
);
})
.toList();
}
private List<RecommendProgramData> convertToRecommendProgramData(
List<ProgramInfoResponse> programs,
UserInfoServiceRequest request
) {
// 성능 측정을 위한 스톱워치 생성
StopWatch stopWatch = new StopWatch("DataConversionTask");

// 1. ID 추출
List<Long> programIds = programs.stream()
.map(ProgramInfoResponse::programId)
.toList();

// 2. 배치 조회 측정
stopWatch.start("Bulk DB Fetch (findAllById)");
List<Program> programEntities = programRepository.findAllById(programIds);
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

매번 findById 하지 않고, findAllById로 모두 불러온 다음에 Map 형식으로 O(1) 접근하는 방식이 정말 인상적입니다 :>

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

program을 50개씩만 가져오지만, 그래도 N + 1 문제를 해결해보았다는 점에서 많은 공부를 해봤습니다.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

저 또한, 공부가 된 것 같습니다 :> 고생하셨습니다 !

stopWatch.stop();

// Map 변환
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

이런 주석은 필요 없어보입니다.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

기초적인 주석은 빼도록 하겠습니다 ;ㅅ;

Map<Long, Program> programMap = programEntities.stream()
.collect(Collectors.toMap(Program::getId, Function.identity()));

// 3. 루프 처리 시간 측정 준비
// stream 내부에서 시간을 누적하기 위해 AtomicLong 사용
AtomicLong totalGetProgramTime = new AtomicLong(0);
AtomicLong totalCalcDistanceTime = new AtomicLong(0);

stopWatch.start("Stream Processing (Loop)");

List<RecommendProgramData> result = programs.stream()
.map(programInfoResponse -> {

// A. 상세 정보 조회 시간 측정 (가장 의심되는 구간)
long startService = System.nanoTime();
ProgramDetailInfoResponse programDetailInfoResponse =
programService.getProgram(programInfoResponse.programId());
long endService = System.nanoTime();
totalGetProgramTime.addAndGet(endService - startService); // 누적

List<TransportData> transports = programDetailInfoResponse.transportData();

Program program = programMap.get(programInfoResponse.programId());

// B. 거리 계산 시간 측정
double distance = 0.0;
if (program != null && program.getFacility() != null) {
long startCalc = System.nanoTime();
distance = calculateDistance(
request.latitude(),
request.longitude(),
program.getFacility().getFcltyLa(),
program.getFacility().getFcltyLo()
);
long endCalc = System.nanoTime();
totalCalcDistanceTime.addAndGet(endCalc - startCalc);
}

return new RecommendProgramData(
programInfoResponse,
transports,
distance
);
})
.toList();

stopWatch.stop();

// 4. 결과 로그 출력
log.info("================ 성능 측정 결과 ================");
log.info(stopWatch.prettyPrint()); // 전체적인 요약 출력
log.info(">> [상세] programService.getProgram() 총 소요 시간: {} ms", totalGetProgramTime.get() / 1_000_000);
log.info(">> [상세] 거리 계산 총 소요 시간: {} ms", totalCalcDistanceTime.get() / 1_000_000);
log.info("==============================================");

return result;
}

private Double calculateDistance(
Double userLa,
Expand Down Expand Up @@ -114,3 +166,4 @@ private Double calculateDistance(
.doubleValue();
}
}

1 change: 1 addition & 0 deletions src/main/resources/application.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ spring:
chat:
options:
model: gpt-4o
temperature: 1.0

# Async Config
mvc:
Expand Down
Loading
Loading