diff --git a/kv_cache_benchmark/kv_cache/benchmark.py b/kv_cache_benchmark/kv_cache/benchmark.py index f345891..8219019 100755 --- a/kv_cache_benchmark/kv_cache/benchmark.py +++ b/kv_cache_benchmark/kv_cache/benchmark.py @@ -514,19 +514,19 @@ def process_requests(self, stop_event: threading.Event): storage_latency += read_lat request.context_tokens = remaining_tokens - # 2. For multi-turn conversations, access cache from previous turn. - if self.conversation_manager and request.turn_number > 1: - prev_turn_key = f"{request.conversation_id}_turn_{request.turn_number - 1}" - location, read_latency = self.cache.access_cache(prev_turn_key, InferencePhase.DECODE, 'multi_turn') - if location is not None: - storage_latency += read_latency - with self.results_lock: self.results['multi_turn_cache_hits'] += 1 - else: - with self.results_lock: self.results['multi_turn_cache_misses'] += 1 - - # 3. Perform the main PREFILL operation (a cache WRITE). # Skip if decode_only mode (disaggregated decode node) if not self.decode_only: + # 2. For multi-turn conversations, access cache from previous turn. + if self.conversation_manager and request.turn_number > 1: + prev_turn_key = f"{request.conversation_id}_turn_{request.turn_number - 1}" + location, read_latency = self.cache.access_cache(prev_turn_key, InferencePhase.DECODE, 'multi_turn') + if location is not None: + storage_latency += read_latency + with self.results_lock: self.results['multi_turn_cache_hits'] += 1 + else: + with self.results_lock: self.results['multi_turn_cache_misses'] += 1 + + # 3. Perform the main PREFILL operation (a cache WRITE). if request.phase == InferencePhase.PREFILL or request.phase == InferencePhase.PREFILL_DECODE: success, location, write_latency = self.cache.allocate_cache( request.cache_key, request.context_tokens, InferencePhase.PREFILL