From 29d05d5d215c1af5371fbba4f3b9ba5bf2282b1c Mon Sep 17 00:00:00 2001 From: Stephen Belanger Date: Wed, 11 Feb 2026 02:40:00 +0800 Subject: [PATCH] Expose trace interface in scorers --- lib/braintrust/api.rb | 7 + lib/braintrust/api/btql.rb | 86 +++++++ lib/braintrust/eval.rb | 38 +-- lib/braintrust/eval/runner.rb | 29 ++- lib/braintrust/eval/scorer.rb | 28 ++- lib/braintrust/span_cache.rb | 128 ++++++++++ lib/braintrust/state.rb | 4 +- lib/braintrust/trace/span_processor.rb | 66 ++++- lib/braintrust/trace_context.rb | 210 ++++++++++++++++ test/braintrust/eval/scorer_test.rb | 89 ++++++- .../braintrust/eval/trace_integration_test.rb | 234 ++++++++++++++++++ test/braintrust/span_cache_test.rb | 186 ++++++++++++++ test/braintrust/trace_context_test.rb | 180 ++++++++++++++ 13 files changed, 1253 insertions(+), 32 deletions(-) create mode 100644 lib/braintrust/api/btql.rb create mode 100644 lib/braintrust/span_cache.rb create mode 100644 lib/braintrust/trace_context.rb create mode 100644 test/braintrust/eval/trace_integration_test.rb create mode 100644 test/braintrust/span_cache_test.rb create mode 100644 test/braintrust/trace_context_test.rb diff --git a/lib/braintrust/api.rb b/lib/braintrust/api.rb index d808537..27ffb08 100644 --- a/lib/braintrust/api.rb +++ b/lib/braintrust/api.rb @@ -2,6 +2,7 @@ require_relative "api/datasets" require_relative "api/functions" +require_relative "api/btql" module Braintrust # API client for Braintrust REST API @@ -42,5 +43,11 @@ def login def object_permalink(object_type:, object_id:) @state.object_permalink(object_type: object_type, object_id: object_id) end + + # Access to BTQL API + # @return [API::Btql] + def btql + @btql ||= API::Btql.new(self) + end end end diff --git a/lib/braintrust/api/btql.rb b/lib/braintrust/api/btql.rb new file mode 100644 index 0000000..134345a --- /dev/null +++ b/lib/braintrust/api/btql.rb @@ -0,0 +1,86 @@ +# frozen_string_literal: true + +require "net/http" +require "json" +require "uri" +require_relative "../logger" + +module Braintrust + class API + # BTQL API namespace + # Provides methods for querying spans and other data using BTQL + class Btql + def initialize(api) + @api = api + @state = api.state + end + + # Query spans using BTQL + # POST /btql + # @param query [Hash] AST-based query filter + # @param object_type [String] Type of object (e.g., "experiment") + # @param object_id [String] Object ID + # @param fmt [String] Response format (default: "jsonl") + # @return [Hash] Response with :body, :freshness_state + def query(query:, object_type:, object_id:, fmt: "jsonl") + payload = { + query: query, + object_type: object_type, + object_id: object_id, + fmt: fmt + } + + response = http_post_json_raw("/btql", payload) + + { + body: response.body, + freshness_state: response["x-bt-freshness-state"] || "complete" + } + end + + private + + # Core HTTP request method (copied from datasets.rb pattern) + def http_request(method, path, params: {}, payload: nil, base_url: nil, parse_json: true) + base = base_url || @state.api_url + uri = URI("#{base}#{path}") + uri.query = URI.encode_www_form(params) unless params.empty? + + request = case method + when :get + Net::HTTP::Get.new(uri) + when :post + req = Net::HTTP::Post.new(uri) + req["Content-Type"] = "application/json" + req.body = JSON.dump(payload) if payload + req + else + raise ArgumentError, "Unsupported HTTP method: #{method}" + end + + request["Authorization"] = "Bearer #{@state.api_key}" + + start_time = Time.now + Log.debug("[API] #{method.upcase} #{uri}") + + http = Net::HTTP.new(uri.host, uri.port) + http.use_ssl = (uri.scheme == "https") + response = http.request(request) + + duration_ms = ((Time.now - start_time) * 1000).round(2) + Log.debug("[API] #{method.upcase} #{uri} -> #{response.code} (#{duration_ms}ms, #{response.body.bytesize} bytes)") + + unless response.is_a?(Net::HTTPSuccess) + Log.debug("[API] Error response body: #{response.body}") + raise Error, "HTTP #{response.code} for #{method.upcase} #{uri}: #{response.body}" + end + + parse_json ? JSON.parse(response.body) : response + end + + def http_post_json_raw(path, payload) + http_request(:post, path, payload: payload, parse_json: false) + end + end + end +end diff --git a/lib/braintrust/eval.rb b/lib/braintrust/eval.rb index 0f59326..66b04a4 100644 --- a/lib/braintrust/eval.rb +++ b/lib/braintrust/eval.rb @@ -241,23 +241,31 @@ def run(project:, experiment:, task:, scorers:, project_id = project_result["id"] project_name = project_result["name"] - # Instantiate Runner and run evaluation - runner = Runner.new( - experiment_id: experiment_id, - experiment_name: experiment, - project_id: project_id, - project_name: project_name, - task: task, - scorers: scorers, - api: api, - tracer_provider: tracer_provider - ) - result = runner.run(cases, parallelism: parallelism) + # Enable span cache for evaluation + api.state.span_cache.start + + begin + # Instantiate Runner and run evaluation + runner = Runner.new( + experiment_id: experiment_id, + experiment_name: experiment, + project_id: project_id, + project_name: project_name, + task: task, + scorers: scorers, + api: api, + tracer_provider: tracer_provider + ) + result = runner.run(cases, parallelism: parallelism) - # Print result summary unless quiet - print_result(result) unless quiet + # Print result summary unless quiet + print_result(result) unless quiet - result + result + ensure + # Disable and clear span cache after evaluation + api.state.span_cache.stop + end end private diff --git a/lib/braintrust/eval/runner.rb b/lib/braintrust/eval/runner.rb index dfc1b79..ec81fe0 100644 --- a/lib/braintrust/eval/runner.rb +++ b/lib/braintrust/eval/runner.rb @@ -6,6 +6,7 @@ require_relative "result" require_relative "summary" require_relative "../internal/thread_pool" +require_relative "../trace_context" require "opentelemetry/sdk" require "json" @@ -103,8 +104,11 @@ def run_case(test_case, errors) end # Run scorers + # Create TraceContext for scorers (if scorers exist) + trace = scorers.empty? ? nil : create_trace_context(eval_span) + begin - run_scorers(test_case, output) + run_scorers(test_case, output, trace) rescue => e # Error already recorded on score span, set eval span status eval_span.status = OpenTelemetry::Trace::Status.error(e.message) @@ -149,15 +153,16 @@ def run_task(test_case) # Creates single score span for all scorers # @param test_case [Case] The test case # @param output [Object] Task output - def run_scorers(test_case, output) + # @param trace [TraceContext, nil] Optional trace context for scorers + def run_scorers(test_case, output, trace = nil) tracer.in_span("score") do |score_span| score_span.set_attribute("braintrust.parent", parent_attr) - set_json_attr(score_span, "braintrust.span_attributes", {type: "score"}) + set_json_attr(score_span, "braintrust.span_attributes", {type: "score", purpose: "scorer"}) scores = {} scorer_error = nil scorers.each do |scorer| - score_value = scorer.call(test_case.input, test_case.expected, output, test_case.metadata || {}) + score_value = scorer.call(test_case.input, test_case.expected, output, test_case.metadata || {}, trace) scores[scorer.name] = score_value # Collect raw score for summary (thread-safe) @@ -239,6 +244,22 @@ def collect_score(name, value) (@scores[name] ||= []) << value end end + + # Create a TraceContext for scorers to access span data + # @param eval_span [OpenTelemetry::Trace::Span] The eval span + # @return [TraceContext] + def create_trace_context(eval_span) + # Extract root_span_id from the eval span's trace_id + root_span_id = eval_span.context.trace_id.unpack1("H*") + + TraceContext.new( + object_type: "experiment", + object_id: experiment_id, + root_span_id: root_span_id, + state: @api.state, + ensure_spans_flushed: -> { @tracer_provider.force_flush } + ) + end end end end diff --git a/lib/braintrust/eval/scorer.rb b/lib/braintrust/eval/scorer.rb index 16519ba..814382e 100644 --- a/lib/braintrust/eval/scorer.rb +++ b/lib/braintrust/eval/scorer.rb @@ -3,7 +3,8 @@ module Braintrust module Eval # Scorer wraps a scoring function that evaluates task output against expected values - # Scorers can accept 3 params (input, expected, output) or 4 params (input, expected, output, metadata) + # Scorers can accept 3 params (input, expected, output), 4 params (input, expected, output, metadata), + # or 5 params (input, expected, output, metadata, trace) # They can return a float, hash, or array of hashes class Scorer attr_reader :name @@ -43,9 +44,10 @@ def initialize(name_or_callable = nil, callable = nil, &block) # @param expected [Object] The expected output # @param output [Object] The actual output from the task # @param metadata [Hash] Optional metadata + # @param trace [TraceContext, nil] Optional trace context # @return [Float, Hash, Array] Score value(s) - def call(input, expected, output, metadata = {}) - @wrapped_callable.call(input, expected, output, metadata) + def call(input, expected, output, metadata = {}, trace = nil) + @wrapped_callable.call(input, expected, output, metadata, trace) end private @@ -68,25 +70,31 @@ def detect_name(callable) "scorer" end - # Wrap the callable to always accept 4 parameters + # Wrap the callable to always accept 5 parameters # @param callable [#call] The callable to wrap - # @return [Proc] Wrapped callable that accepts 4 params + # @return [Proc] Wrapped callable that accepts 5 params def wrap_callable(callable) arity = callable_arity(callable) case arity when 3 - # Callable takes 3 params - wrap to ignore metadata - ->(input, expected, output, metadata) { + # Callable takes 3 params - wrap to ignore metadata and trace + ->(input, expected, output, metadata, trace) { callable.call(input, expected, output) } - when 4, -4, -1 - # Callable takes 4 params (or variadic with 4+) + when 4, -4 + # Callable takes 4 params - wrap to ignore trace # -4 means optional 4th param + ->(input, expected, output, metadata, trace) { + callable.call(input, expected, output, metadata) + } + when 5, -5, -1 + # Callable takes 5 params (or variadic with 5+) + # -5 means optional 5th param # -1 means variadic (*args) callable else - raise ArgumentError, "Scorer must accept 3 or 4 parameters (got arity #{arity})" + raise ArgumentError, "Scorer must accept 3, 4, or 5 parameters (got arity #{arity})" end end diff --git a/lib/braintrust/span_cache.rb b/lib/braintrust/span_cache.rb new file mode 100644 index 0000000..7c9726e --- /dev/null +++ b/lib/braintrust/span_cache.rb @@ -0,0 +1,128 @@ +# frozen_string_literal: true + +module Braintrust + # Thread-safe in-memory cache for spans during evaluation runs. + # Stores spans indexed by root_span_id to enable fast local lookups + # before falling back to BTQL queries. + class SpanCache + DEFAULT_TTL = 300 # 5 minutes + DEFAULT_MAX_ENTRIES = 1000 + + def initialize(ttl: DEFAULT_TTL, max_entries: DEFAULT_MAX_ENTRIES) + @ttl = ttl + @max_entries = max_entries + @cache = {} # {root_span_id => {spans: {span_id => data}, accessed_at: Time}} + @mutex = Mutex.new + @enabled = false + end + + # Write or merge a span into the cache + # @param root_span_id [String] The root span ID + # @param span_id [String] The span ID + # @param span_data [Hash] Span data (input, output, metadata, etc.) + def write(root_span_id, span_id, span_data) + return unless @enabled + + @mutex.synchronize do + evict_expired + evict_lru if @cache.size >= @max_entries + + @cache[root_span_id] ||= {spans: {}, accessed_at: Time.now} + entry = @cache[root_span_id] + + # Merge: incoming non-nil values override existing + existing = entry[:spans][span_id] || {} + entry[:spans][span_id] = existing.merge(span_data.compact) + entry[:accessed_at] = Time.now + end + end + + # Get all cached spans for a root span + # @param root_span_id [String] The root span ID + # @return [Array, nil] Array of span data hashes, or nil if not cached + def get(root_span_id) + return nil unless @enabled + + @mutex.synchronize do + evict_expired + entry = @cache[root_span_id] + return nil unless entry + + entry[:accessed_at] = Time.now + entry[:spans].values + end + end + + # Check if root span has cached data + # @param root_span_id [String] The root span ID + # @return [Boolean] + def has?(root_span_id) + return false unless @enabled + + @mutex.synchronize do + evict_expired + @cache.key?(root_span_id) + end + end + + # Clear one or all cache entries + # @param root_span_id [String, nil] Specific root span ID, or nil to clear all + def clear(root_span_id = nil) + @mutex.synchronize do + if root_span_id + @cache.delete(root_span_id) + else + @cache.clear + end + end + end + + # Number of cached root spans + # @return [Integer] + def size + @mutex.synchronize { @cache.size } + end + + # Check if cache is enabled + # @return [Boolean] + def enabled? + @enabled + end + + # Enable and clear the cache (called at eval start) + def start + @mutex.synchronize do + @enabled = true + @cache.clear + end + end + + # Disable and clear the cache (called at eval end) + def stop + @mutex.synchronize do + @enabled = false + @cache.clear + end + end + + # Disable the cache without clearing + def disable + @enabled = false + end + + private + + def evict_expired + now = Time.now + @cache.delete_if { |_id, entry| now - entry[:accessed_at] > @ttl } + end + + def evict_lru + return if @cache.size < @max_entries + + # Remove the least recently accessed entry + lru_key = @cache.min_by { |_id, entry| entry[:accessed_at] }&.first + @cache.delete(lru_key) if lru_key + end + end +end diff --git a/lib/braintrust/state.rb b/lib/braintrust/state.rb index f8c6275..42d1f0b 100644 --- a/lib/braintrust/state.rb +++ b/lib/braintrust/state.rb @@ -1,12 +1,13 @@ # frozen_string_literal: true require_relative "api/internal/auth" +require_relative "span_cache" module Braintrust # State object that holds Braintrust configuration # Thread-safe global state management class State - attr_reader :api_key, :org_name, :org_id, :default_project, :app_url, :api_url, :proxy_url, :logged_in, :config + attr_reader :api_key, :org_name, :org_id, :default_project, :app_url, :api_url, :proxy_url, :logged_in, :config, :span_cache @mutex = Mutex.new @global_state = nil @@ -76,6 +77,7 @@ def initialize(api_key: nil, org_name: nil, org_id: nil, default_project: nil, a @api_url = api_url || "https://api.braintrust.dev" @proxy_url = proxy_url @config = config + @span_cache = SpanCache.new # If org_id is provided, we're already "logged in" (useful for testing) # Otherwise, perform login to discover org info diff --git a/lib/braintrust/trace/span_processor.rb b/lib/braintrust/trace/span_processor.rb index 7ff7c77..12519a8 100644 --- a/lib/braintrust/trace/span_processor.rb +++ b/lib/braintrust/trace/span_processor.rb @@ -35,8 +35,11 @@ def on_start(span, parent_context) @wrapped.on_start(span, parent_context) end - # Called when a span ends - apply filters before forwarding + # Called when a span ends - write to cache and apply filters before forwarding def on_finish(span) + # Write to span cache if enabled + write_to_cache(span) if @state.span_cache.enabled? + # Only forward span if it passes filters @wrapped.on_finish(span) if should_forward_span?(span) end @@ -77,6 +80,67 @@ def get_parent_from_context(parent_context) parent_span.attributes&.[](PARENT_ATTR_KEY) end + # Write span data to cache for TraceContext access + # @param span [OpenTelemetry::SDK::Trace::SpanData] The span + def write_to_cache(span) + return unless span.respond_to?(:attributes) + + # Extract root_span_id from trace_id (hex-encoded) + root_span_id = span.trace_id.unpack1("H*") + span_id = span.span_id.unpack1("H*") + + # Extract Braintrust-specific attributes + attrs = span.attributes || {} + cached_data = {} + + # Parse JSON attributes + if attrs["braintrust.input_json"] + cached_data[:input] = begin + JSON.parse(attrs["braintrust.input_json"], symbolize_names: true) + rescue + nil + end + end + + if attrs["braintrust.output_json"] + cached_data[:output] = begin + JSON.parse(attrs["braintrust.output_json"], symbolize_names: true) + rescue + nil + end + end + + if attrs["braintrust.metadata"] + cached_data[:metadata] = begin + JSON.parse(attrs["braintrust.metadata"], symbolize_names: true) + rescue + nil + end + end + + if attrs["braintrust.span_attributes"] + cached_data[:span_attributes] = begin + JSON.parse(attrs["braintrust.span_attributes"], symbolize_names: true) + rescue + nil + end + end + + # Add span_id and span_parents + cached_data[:span_id] = span_id + + # Extract parent span IDs from the span + parent_span_id = span.parent_span_id.unpack1("H*") if span.parent_span_id != OpenTelemetry::Trace::INVALID_SPAN_ID + cached_data[:span_parents] = parent_span_id ? [parent_span_id] : [] + + # Write to cache + @state.span_cache.write(root_span_id, span_id, cached_data) + rescue => e + # Silently ignore cache write errors + require_relative "../logger" + Log.debug("Failed to write span to cache: #{e.message}") + end + # Determine if a span should be forwarded to the wrapped processor # based on configured filters def should_forward_span?(span) diff --git a/lib/braintrust/trace_context.rb b/lib/braintrust/trace_context.rb new file mode 100644 index 0000000..fb73e48 --- /dev/null +++ b/lib/braintrust/trace_context.rb @@ -0,0 +1,210 @@ +# frozen_string_literal: true + +require "json" +require "net/http" +require "uri" + +module Braintrust + # TraceContext provides scorers access to span data from the evaluation trace. + # It first attempts to retrieve spans from the local in-memory cache, then + # falls back to BTQL queries if needed. + class TraceContext + MAX_RETRIES = 8 + INITIAL_BACKOFF = 0.25 # seconds + + def initialize(object_type:, object_id:, root_span_id:, state:, ensure_spans_flushed: nil) + @object_type = object_type + @object_id = object_id + @root_span_id = root_span_id + @state = state + @ensure_spans_flushed = ensure_spans_flushed + @spans_ready_mutex = Mutex.new + @spans_ready = false + end + + # Returns configuration hash + # @return [Hash] Configuration with object_type, object_id, root_span_id + def configuration + { + object_type: @object_type, + object_id: @object_id, + root_span_id: @root_span_id + } + end + + # Get spans for this trace, optionally filtered by span type. + # Filters out scorer spans (purpose == "scorer"). + # @param span_type [Array, String, nil] Types to filter by (e.g., "llm", "score") + # @return [Array] Array of span hashes with keys: input, output, metadata, span_id, span_parents, span_attributes + def get_spans(span_type: nil) + # Normalize span_type to array + types = span_type && Array(span_type) + + # Try cache first + cached = @state.span_cache.get(@root_span_id) + spans = cached || fetch_spans_via_btql(types) + + # Filter out scorer spans + spans = spans.reject { |s| s.dig(:span_attributes, :purpose) == "scorer" } + + # Filter by type if specified + if types + spans = spans.select { |s| types.include?(s.dig(:span_attributes, :type)) } + end + + spans + end + + # Reconstruct message thread from LLM spans. + # Deduplicates input messages by content hash, always includes output messages. + # @return [Array] Array of message hashes + def get_thread + llm_spans = get_spans(span_type: "llm") + + messages = [] + seen_inputs = Set.new + + llm_spans.each do |span| + # Add input messages (deduplicated) + input = span[:input] + if input.is_a?(Hash) && input[:messages].is_a?(Array) + input[:messages].each do |msg| + msg_hash = msg.hash + unless seen_inputs.include?(msg_hash) + messages << msg + seen_inputs.add(msg_hash) + end + end + end + + # Always add output messages + output = span[:output] + if output.is_a?(Hash) && output[:choices].is_a?(Array) + output[:choices].each do |choice| + messages << choice[:message] if choice[:message] + end + end + end + + messages + end + + private + + # Ensure spans are flushed before querying (idempotent, thread-safe) + def ensure_spans_ready + @spans_ready_mutex.synchronize do + return if @spans_ready + + @ensure_spans_flushed&.call + @spans_ready = true + end + end + + # Fetch spans via BTQL with retry logic + # @param types [Array, nil] Span types to filter by + # @return [Array] Array of spans + def fetch_spans_via_btql(types) + ensure_spans_ready + + # Build AST filter + filter = build_btql_filter(types) + + retries = 0 + backoff = INITIAL_BACKOFF + + loop do + result = query_btql(filter) + + # Check freshness + if result[:freshness_state] == "complete" || retries >= MAX_RETRIES + return result[:spans] + end + + # Exponential backoff + sleep backoff + backoff *= 2 + retries += 1 + end + end + + # Build BTQL AST filter + # @param types [Array, nil] Span types to filter by + # @return [Hash] AST filter object + def build_btql_filter(types) + # root_span_id = X + root_filter = { + path: ["root_span_id"], + op: "=", + value: @root_span_id + } + + # (purpose IS NULL OR purpose != 'scorer') + purpose_filter = { + op: "or", + operands: [ + {path: ["span_attributes", "purpose"], op: "is null"}, + {path: ["span_attributes", "purpose"], op: "!=", value: "scorer"} + ] + } + + # Combine with AND + combined = { + op: "and", + operands: [root_filter, purpose_filter] + } + + # Add type filter if specified + if types && !types.empty? + type_filter = { + path: ["span_attributes", "type"], + op: "in", + value: types + } + combined[:operands] << type_filter + end + + combined + end + + # Query BTQL endpoint + # @param filter [Hash] AST filter + # @return [Hash] {spans: Array, freshness_state: String} + def query_btql(filter) + require_relative "api" + api = API.new(state: @state) + response = api.btql.query( + query: filter, + object_type: @object_type, + object_id: @object_id, + fmt: "jsonl" + ) + + # Parse JSONL response + spans = response[:body].lines.map { |line| JSON.parse(line, symbolize_names: true) } + + { + spans: spans.map { |s| normalize_span(s) }, + freshness_state: response[:freshness_state] || "complete" + } + rescue => e + # On error, return empty result + warn "BTQL query failed: #{e.message}" + {spans: [], freshness_state: "complete"} + end + + # Normalize span data from BTQL to match cache format + # @param span [Hash] Raw span data from BTQL + # @return [Hash] Normalized span + def normalize_span(span) + { + input: span[:input], + output: span[:output], + metadata: span[:metadata], + span_id: span[:span_id], + span_parents: span[:span_parents], + span_attributes: span[:span_attributes] + } + end + end +end diff --git a/test/braintrust/eval/scorer_test.rb b/test/braintrust/eval/scorer_test.rb index 6507483..d3825df 100644 --- a/test/braintrust/eval/scorer_test.rb +++ b/test/braintrust/eval/scorer_test.rb @@ -97,7 +97,7 @@ def test_scorer_invalid_arity end end - assert_match(/must accept 3 or 4 parameters/, error.message) + assert_match(/must accept 3, 4, or 5 parameters/, error.message) end def test_scorer_missing_callable @@ -177,4 +177,91 @@ def obj.my_scorer(input, expected, output) assert_equal 1.0, scorer.call("i", "match", "match") assert_equal 0.0, scorer.call("i", "match", "no_match") end + + def test_scorer_with_5_param_block + # Test scorer with 5 params (input, expected, output, metadata, trace) + trace_received = nil + scorer = Braintrust::Eval::Scorer.new("trace_scorer") do |input, expected, output, metadata, trace| + trace_received = trace + 1.0 + end + + assert_equal "trace_scorer", scorer.name + + mock_trace = Object.new + result = scorer.call("a", "b", "c", {}, mock_trace) + assert_equal 1.0, result + assert_equal mock_trace, trace_received + end + + def test_scorer_3_params_ignores_metadata_and_trace + # Test that 3-param scorer ignores metadata and trace + scorer = Braintrust::Eval::Scorer.new("simple") do |input, expected, output| + "#{input}-#{expected}-#{output}" + end + + mock_trace = Object.new + result = scorer.call("a", "b", "c", {foo: "bar"}, mock_trace) + assert_equal "a-b-c", result + end + + def test_scorer_4_params_ignores_trace + # Test that 4-param scorer ignores trace but uses metadata + scorer = Braintrust::Eval::Scorer.new("with_metadata") do |input, expected, output, metadata| + metadata[:key].to_s + end + + mock_trace = Object.new + result = scorer.call("a", "b", "c", {key: "value"}, mock_trace) + assert_equal "value", result + end + + def test_scorer_5_params_with_callable_class + # Test 5-param scorer with callable class + callable = Class.new do + def initialize + @trace_received = nil + end + + attr_reader :trace_received + + def call(input, expected, output, metadata, trace) + @trace_received = trace + {name: "custom", score: 0.9} + end + end.new + + scorer = Braintrust::Eval::Scorer.new("trace_class", callable) + + mock_trace = Object.new + result = scorer.call("a", "b", "c", {}, mock_trace) + assert_equal({name: "custom", score: 0.9}, result) + end + + def test_scorer_variadic_accepts_trace + # Test that variadic scorer (-1 arity) accepts trace + trace_received = nil + scorer = Braintrust::Eval::Scorer.new("variadic") do |*args| + trace_received = args[4] if args.length > 4 + 1.0 + end + + mock_trace = Object.new + result = scorer.call("a", "b", "c", {}, mock_trace) + assert_equal 1.0, result + assert_equal mock_trace, trace_received + end + + def test_scorer_with_nil_trace + # Test that scorer handles nil trace gracefully + trace_received = "not_nil" + scorer = Braintrust::Eval::Scorer.new("nil_trace") do |input, expected, output, metadata, trace| + trace_received = trace + 1.0 + end + + result = scorer.call("a", "b", "c", {}, nil) + assert_equal 1.0, result + assert_nil trace_received + end end diff --git a/test/braintrust/eval/trace_integration_test.rb b/test/braintrust/eval/trace_integration_test.rb new file mode 100644 index 0000000..e2dc5e4 --- /dev/null +++ b/test/braintrust/eval/trace_integration_test.rb @@ -0,0 +1,234 @@ +# frozen_string_literal: true + +require "test_helper" +require "braintrust" + +# Integration tests for trace context feature +class Braintrust::Eval::TraceIntegrationTest < Minitest::Test + def setup + @state = Braintrust::State.new( + api_key: "test-api-key", + org_id: "test-org-id", + api_url: "https://api.braintrust.dev", + enable_tracing: true + ) + end + + def teardown + @state&.span_cache&.stop + end + + def test_scorer_receives_trace_context + trace_received = nil + + scorer = Braintrust::Eval::Scorer.new("trace_aware") do |input, expected, output, metadata, trace| + trace_received = trace + 1.0 + end + + trace_context = Braintrust::TraceContext.new( + object_type: "experiment", + object_id: "exp-123", + root_span_id: "root-abc", + state: @state + ) + + @state.span_cache.start + result = scorer.call("input", "expected", "output", {}, trace_context) + + assert_equal 1.0, result + assert_instance_of Braintrust::TraceContext, trace_received + assert_equal "exp-123", trace_received.configuration[:object_id] + end + + def test_scorer_can_query_cached_spans + spans_queried = nil + + scorer = Braintrust::Eval::Scorer.new("span_query") do |input, expected, output, metadata, trace| + spans_queried = trace&.get_spans if trace + 1.0 + end + + trace_context = Braintrust::TraceContext.new( + object_type: "experiment", + object_id: "exp-123", + root_span_id: "root-abc", + state: @state + ) + + @state.span_cache.start + @state.span_cache.write("root-abc", "span1", { + input: {messages: [{role: "user", content: "Hello"}]}, + output: {choices: [{message: {role: "assistant", content: "Hi"}}]}, + span_attributes: {type: "llm"} + }) + + result = scorer.call("input", "expected", "output", {}, trace_context) + + assert_equal 1.0, result + assert_equal 1, spans_queried.size + assert_equal "llm", spans_queried.first.dig(:span_attributes, :type) + end + + def test_scorer_can_filter_spans_by_type + llm_spans_received = nil + + scorer = Braintrust::Eval::Scorer.new("type_filter") do |input, expected, output, metadata, trace| + llm_spans_received = trace&.get_spans(span_type: "llm") if trace + 1.0 + end + + trace_context = Braintrust::TraceContext.new( + object_type: "experiment", + object_id: "exp-123", + root_span_id: "root-abc", + state: @state + ) + + @state.span_cache.start + @state.span_cache.write("root-abc", "span1", {span_attributes: {type: "llm"}}) + @state.span_cache.write("root-abc", "span2", {span_attributes: {type: "task"}}) + @state.span_cache.write("root-abc", "span3", {span_attributes: {type: "llm"}}) + + scorer.call("input", "expected", "output", {}, trace_context) + + assert_equal 2, llm_spans_received.size + llm_spans_received.each do |span| + assert_equal "llm", span.dig(:span_attributes, :type) + end + end + + def test_scorer_can_get_thread + thread_received = nil + + scorer = Braintrust::Eval::Scorer.new("thread_getter") do |input, expected, output, metadata, trace| + thread_received = trace&.get_thread if trace + 1.0 + end + + trace_context = Braintrust::TraceContext.new( + object_type: "experiment", + object_id: "exp-123", + root_span_id: "root-abc", + state: @state + ) + + @state.span_cache.start + @state.span_cache.write("root-abc", "span1", { + input: {messages: [{role: "user", content: "Hello"}]}, + output: {choices: [{message: {role: "assistant", content: "Hi"}}]}, + span_attributes: {type: "llm"} + }) + + scorer.call("input", "expected", "output", {}, trace_context) + + assert_equal 2, thread_received.size + assert_equal "user", thread_received[0][:role] + assert_equal "assistant", thread_received[1][:role] + end + + def test_scorer_without_trace_parameter_still_works + scorer = Braintrust::Eval::Scorer.new("simple") do |input, expected, output| + (output == expected) ? 1.0 : 0.0 + end + + trace_context = Braintrust::TraceContext.new( + object_type: "experiment", + object_id: "exp-123", + root_span_id: "root-abc", + state: @state + ) + + result = scorer.call("input", "expected", "expected", {}, trace_context) + assert_equal 1.0, result + end + + def test_scorer_filters_out_scorer_spans + spans_received = nil + + scorer = Braintrust::Eval::Scorer.new("filter_scorer") do |input, expected, output, metadata, trace| + spans_received = trace&.get_spans if trace + 1.0 + end + + trace_context = Braintrust::TraceContext.new( + object_type: "experiment", + object_id: "exp-123", + root_span_id: "root-abc", + state: @state + ) + + @state.span_cache.start + @state.span_cache.write("root-abc", "span1", {span_attributes: {type: "llm"}}) + @state.span_cache.write("root-abc", "span2", {span_attributes: {type: "score", purpose: "scorer"}}) + @state.span_cache.write("root-abc", "span3", {span_attributes: {type: "task"}}) + + scorer.call("input", "expected", "output", {}, trace_context) + + assert_equal 2, spans_received.size + spans_received.each do |span| + refute_equal "scorer", span.dig(:span_attributes, :purpose) + end + end + + def test_span_cache_lifecycle + refute @state.span_cache.enabled? + + @state.span_cache.start + assert @state.span_cache.enabled? + + @state.span_cache.write("root1", "span1", {input: "test"}) + spans = @state.span_cache.get("root1") + assert_equal 1, spans.size + + @state.span_cache.stop + refute @state.span_cache.enabled? + assert_equal 0, @state.span_cache.size + end + + def test_trace_context_configuration + trace = Braintrust::TraceContext.new( + object_type: "experiment", + object_id: "exp-456", + root_span_id: "root-xyz", + state: @state + ) + + config = trace.configuration + assert_equal "experiment", config[:object_type] + assert_equal "exp-456", config[:object_id] + assert_equal "root-xyz", config[:root_span_id] + end + + def test_multiple_scorers_with_trace + traces_received = [] + + scorer1 = Braintrust::Eval::Scorer.new("scorer1") do |input, expected, output, metadata, trace| + traces_received << trace + 1.0 + end + + scorer2 = Braintrust::Eval::Scorer.new("scorer2") do |input, expected, output, metadata, trace| + traces_received << trace + 0.5 + end + + trace_context = Braintrust::TraceContext.new( + object_type: "experiment", + object_id: "exp-123", + root_span_id: "root-abc", + state: @state + ) + + @state.span_cache.start + + scorer1.call("input", "expected", "output", {}, trace_context) + scorer2.call("input", "expected", "output", {}, trace_context) + + assert_equal 2, traces_received.size + traces_received.each do |trace| + assert_instance_of Braintrust::TraceContext, trace + assert_equal "root-abc", trace.configuration[:root_span_id] + end + end +end diff --git a/test/braintrust/span_cache_test.rb b/test/braintrust/span_cache_test.rb new file mode 100644 index 0000000..4536dff --- /dev/null +++ b/test/braintrust/span_cache_test.rb @@ -0,0 +1,186 @@ +# frozen_string_literal: true + +require "test_helper" +require "braintrust/span_cache" + +module Braintrust + class SpanCacheTest < Minitest::Test + def setup + @cache = SpanCache.new + end + + def test_disabled_by_default + refute @cache.enabled? + assert_nil @cache.get("root1") + assert_equal 0, @cache.size + end + + def test_write_and_read_when_disabled + @cache.write("root1", "span1", {input: "test"}) + assert_nil @cache.get("root1") + assert_equal 0, @cache.size + end + + def test_start_enables_and_clears + @cache.start + assert @cache.enabled? + assert_equal 0, @cache.size + end + + def test_write_and_read_when_enabled + @cache.start + @cache.write("root1", "span1", {input: "test", output: "result"}) + + spans = @cache.get("root1") + assert_equal 1, spans.size + assert_equal "test", spans.first[:input] + assert_equal "result", spans.first[:output] + end + + def test_write_multiple_spans_same_root + @cache.start + @cache.write("root1", "span1", {input: "test1"}) + @cache.write("root1", "span2", {input: "test2"}) + + spans = @cache.get("root1") + assert_equal 2, spans.size + end + + def test_merge_behavior + @cache.start + @cache.write("root1", "span1", {input: "test", metadata: {a: 1}}) + @cache.write("root1", "span1", {output: "result", metadata: {b: 2}}) + + spans = @cache.get("root1") + assert_equal 1, spans.size + span = spans.first + assert_equal "test", span[:input] + assert_equal "result", span[:output] + assert_equal({b: 2}, span[:metadata]) + end + + def test_merge_with_nil_values + @cache.start + @cache.write("root1", "span1", {input: "test", output: "result"}) + @cache.write("root1", "span1", {input: nil, metadata: {a: 1}}) + + spans = @cache.get("root1") + span = spans.first + assert_equal "test", span[:input] + assert_equal "result", span[:output] + assert_equal({a: 1}, span[:metadata]) + end + + def test_has_returns_true_when_cached + @cache.start + @cache.write("root1", "span1", {input: "test"}) + assert @cache.has?("root1") + refute @cache.has?("root2") + end + + def test_has_returns_false_when_disabled + @cache.start + @cache.write("root1", "span1", {input: "test"}) + @cache.stop + refute @cache.has?("root1") + end + + def test_clear_single_entry + @cache.start + @cache.write("root1", "span1", {input: "test1"}) + @cache.write("root2", "span2", {input: "test2"}) + + @cache.clear("root1") + assert_nil @cache.get("root1") + assert_equal 1, @cache.get("root2").size + end + + def test_clear_all_entries + @cache.start + @cache.write("root1", "span1", {input: "test1"}) + @cache.write("root2", "span2", {input: "test2"}) + + @cache.clear + assert_equal 0, @cache.size + end + + def test_stop_disables_and_clears + @cache.start + @cache.write("root1", "span1", {input: "test"}) + @cache.stop + + refute @cache.enabled? + assert_equal 0, @cache.size + end + + def test_disable_without_clearing + @cache.start + @cache.write("root1", "span1", {input: "test"}) + @cache.disable + + refute @cache.enabled? + assert_equal 1, @cache.size + assert_nil @cache.get("root1") + end + + def test_size_returns_number_of_root_spans + @cache.start + @cache.write("root1", "span1", {input: "test1"}) + @cache.write("root1", "span2", {input: "test2"}) + @cache.write("root2", "span3", {input: "test3"}) + + assert_equal 2, @cache.size + end + + def test_ttl_expiration + cache = SpanCache.new(ttl: 0.1) + cache.start + cache.write("root1", "span1", {input: "test"}) + + assert cache.has?("root1") + sleep 0.15 + + assert_nil cache.get("root1") + assert_equal 0, cache.size + end + + def test_lru_eviction + cache = SpanCache.new(max_entries: 2) + cache.start + + cache.write("root1", "span1", {input: "test1"}) + sleep 0.001 + cache.write("root2", "span2", {input: "test2"}) + sleep 0.001 + cache.get("root1") + sleep 0.001 + cache.write("root3", "span3", {input: "test3"}) + + assert cache.has?("root1") + refute cache.has?("root2") + assert cache.has?("root3") + end + + def test_thread_safety + @cache.start + threads = [] + + 10.times do |i| + threads << Thread.new do + 100.times do |j| + @cache.write("root#{i}", "span#{j}", {input: "test#{i}-#{j}"}) + end + end + end + + threads.each(&:join) + + assert_equal 10, @cache.size + + 10.times do |i| + spans = @cache.get("root#{i}") + assert_equal 100, spans.size + end + end + end +end diff --git a/test/braintrust/trace_context_test.rb b/test/braintrust/trace_context_test.rb new file mode 100644 index 0000000..441b449 --- /dev/null +++ b/test/braintrust/trace_context_test.rb @@ -0,0 +1,180 @@ +# frozen_string_literal: true + +require "test_helper" +require "braintrust/trace_context" +require "braintrust/state" +require "braintrust/span_cache" + +module Braintrust + class TraceContextTest < Minitest::Test + def setup + @state = State.new( + api_key: "test-key", + org_id: "test-org", + api_url: "https://api.braintrust.dev", + enable_tracing: false + ) + @state.span_cache.start + + @trace_context = TraceContext.new( + object_type: "experiment", + object_id: "exp-123", + root_span_id: "root-abc", + state: @state + ) + end + + def teardown + @state&.span_cache&.stop + end + + def test_configuration_returns_correct_hash + config = @trace_context.configuration + assert_equal "experiment", config[:object_type] + assert_equal "exp-123", config[:object_id] + assert_equal "root-abc", config[:root_span_id] + end + + def test_get_spans_returns_cached_spans + @state.span_cache.write("root-abc", "span1", { + input: {messages: [{role: "user", content: "Hello"}]}, + output: {choices: [{message: {role: "assistant", content: "Hi"}}]}, + span_attributes: {type: "llm"} + }) + + spans = @trace_context.get_spans + assert_equal 1, spans.size + assert_equal "llm", spans.first.dig(:span_attributes, :type) + end + + def test_get_spans_filters_by_single_type + @state.span_cache.write("root-abc", "span1", { + span_attributes: {type: "llm"} + }) + @state.span_cache.write("root-abc", "span2", { + span_attributes: {type: "score"} + }) + + llm_spans = @trace_context.get_spans(span_type: "llm") + assert_equal 1, llm_spans.size + assert_equal "llm", llm_spans.first.dig(:span_attributes, :type) + end + + def test_get_spans_filters_by_multiple_types + @state.span_cache.write("root-abc", "span1", { + span_attributes: {type: "llm"} + }) + @state.span_cache.write("root-abc", "span2", { + span_attributes: {type: "score"} + }) + @state.span_cache.write("root-abc", "span3", { + span_attributes: {type: "task"} + }) + + spans = @trace_context.get_spans(span_type: ["llm", "score"]) + assert_equal 2, spans.size + types = spans.map { |s| s.dig(:span_attributes, :type) } + assert_includes types, "llm" + assert_includes types, "score" + end + + def test_get_spans_excludes_scorer_spans + @state.span_cache.write("root-abc", "span1", { + span_attributes: {type: "llm"} + }) + @state.span_cache.write("root-abc", "span2", { + span_attributes: {type: "score", purpose: "scorer"} + }) + + spans = @trace_context.get_spans + assert_equal 1, spans.size + assert_equal "llm", spans.first.dig(:span_attributes, :type) + end + + def test_get_thread_reconstructs_message_thread + @state.span_cache.write("root-abc", "span1", { + input: {messages: [{role: "user", content: "Hello"}]}, + output: {choices: [{message: {role: "assistant", content: "Hi"}}]}, + span_attributes: {type: "llm"} + }) + @state.span_cache.write("root-abc", "span2", { + input: {messages: [{role: "user", content: "How are you?"}]}, + output: {choices: [{message: {role: "assistant", content: "Good"}}]}, + span_attributes: {type: "llm"} + }) + + thread = @trace_context.get_thread + assert_equal 4, thread.size + assert_equal "user", thread[0][:role] + assert_equal "Hello", thread[0][:content] + assert_equal "assistant", thread[1][:role] + assert_equal "Hi", thread[1][:content] + end + + def test_get_thread_deduplicates_input_messages + msg1 = {role: "user", content: "Hello"} + msg2 = {role: "user", content: "Hello"} + + @state.span_cache.write("root-abc", "span1", { + input: {messages: [msg1]}, + output: {choices: [{message: {role: "assistant", content: "Hi"}}]}, + span_attributes: {type: "llm"} + }) + @state.span_cache.write("root-abc", "span2", { + input: {messages: [msg2]}, + output: {choices: [{message: {role: "assistant", content: "Hello again"}}]}, + span_attributes: {type: "llm"} + }) + + thread = @trace_context.get_thread + assert_equal 3, thread.size + user_messages = thread.select { |m| m[:role] == "user" } + assert_equal 1, user_messages.size + end + + def test_get_thread_always_includes_output_messages + @state.span_cache.write("root-abc", "span1", { + input: {messages: [{role: "user", content: "Hello"}]}, + output: {choices: [{message: {role: "assistant", content: "Hi"}}]}, + span_attributes: {type: "llm"} + }) + @state.span_cache.write("root-abc", "span2", { + input: {messages: [{role: "user", content: "Hello"}]}, + output: {choices: [{message: {role: "assistant", content: "Hi"}}]}, + span_attributes: {type: "llm"} + }) + + thread = @trace_context.get_thread + assistant_messages = thread.select { |m| m[:role] == "assistant" } + assert_equal 2, assistant_messages.size + end + + def test_get_thread_handles_missing_input + @state.span_cache.write("root-abc", "span1", { + output: {choices: [{message: {role: "assistant", content: "Hi"}}]}, + span_attributes: {type: "llm"} + }) + + thread = @trace_context.get_thread + assert_equal 1, thread.size + assert_equal "assistant", thread[0][:role] + end + + def test_get_thread_handles_missing_output + @state.span_cache.write("root-abc", "span1", { + input: {messages: [{role: "user", content: "Hello"}]}, + span_attributes: {type: "llm"} + }) + + thread = @trace_context.get_thread + assert_equal 1, thread.size + assert_equal "user", thread[0][:role] + end + + def test_get_spans_returns_empty_when_no_cache + @state.span_cache.stop + spans = @trace_context.get_spans + assert_equal [], spans + end + end +end