diff --git a/py/src/braintrust/functions/invoke.py b/py/src/braintrust/functions/invoke.py index 5f48f2cd..ba8b25ef 100644 --- a/py/src/braintrust/functions/invoke.py +++ b/py/src/braintrust/functions/invoke.py @@ -57,6 +57,7 @@ def invoke( api_key: str | None = None, app_url: str | None = None, force_login: bool = False, + trace_min_xact_id: str | None = None, ) -> T: ... @@ -85,6 +86,7 @@ def invoke( api_key: str | None = None, app_url: str | None = None, force_login: bool = False, + trace_min_xact_id: str | None = None, ) -> BraintrustStream: ... @@ -112,6 +114,7 @@ def invoke( api_key: str | None = None, app_url: str | None = None, force_login: bool = False, + trace_min_xact_id: str | None = None, ) -> BraintrustStream | T: """ Invoke a Braintrust function, returning a `BraintrustStream` or the value as a plain @@ -151,6 +154,7 @@ def invoke( global_function: The name of the global function to invoke. function_type: The type of the global function to invoke. If unspecified, defaults to 'scorer' for backward compatibility. + trace_min_xact_id: Optional minimum ingestion xact ID for compacted trace-ref reads. Returns: The output of the function. If `stream` is True, returns a `BraintrustStream`, @@ -198,6 +202,8 @@ def invoke( request["mode"] = mode if strict is not None: request["strict"] = strict + if trace_min_xact_id is not None: + request["trace_min_xact_id"] = trace_min_xact_id headers = { "Accept": "text/event-stream" if stream else "application/json", diff --git a/py/src/braintrust/functions/test_invoke.py b/py/src/braintrust/functions/test_invoke.py index 650f5c81..1397be79 100644 --- a/py/src/braintrust/functions/test_invoke.py +++ b/py/src/braintrust/functions/test_invoke.py @@ -118,6 +118,32 @@ def test_invoke_serializes_google_messages(): assert isinstance(parsed, dict) and parsed +def test_invoke_serializes_trace_min_xact_id(): + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.json.return_value = {} + mock_conn = MagicMock() + mock_conn.post.return_value = mock_resp + + with ( + patch("braintrust.functions.invoke.login"), + patch("braintrust.functions.invoke.get_span_parent_object") as mock_parent, + patch("braintrust.functions.invoke.proxy_conn", return_value=mock_conn), + ): + mock_parent.return_value.export.return_value = "span-export" + invoke( + global_function="project_default", + function_type="preprocessor", + input={"trace_ref": {"object_id": "exp-123", "root_span_id": "root-456"}}, + trace_min_xact_id="12345", + ) + + data = mock_conn.post.call_args.kwargs["data"] + parsed = json.loads(data.decode("utf-8")) + assert parsed["trace_min_xact_id"] == "12345" + assert "trace_read" not in parsed + + @pytest.mark.vcr def test_invoke_encodes_body_as_utf8_bytes(monkeypatch): """Regression test for BT-4620: non-Latin-1 Unicode must not be corrupted. diff --git a/py/src/braintrust/logger.py b/py/src/braintrust/logger.py index 522b31e1..22ebf21e 100644 --- a/py/src/braintrust/logger.py +++ b/py/src/braintrust/logger.py @@ -18,6 +18,7 @@ import types import uuid from abc import ABC, abstractmethod +from collections import OrderedDict from collections.abc import Callable, Iterator, Mapping, MutableMapping, Sequence from functools import partial, wraps from multiprocessing import cpu_count @@ -119,6 +120,8 @@ class Logs3OverflowInputRow: class LogItemWithMeta: str_value: str overflow_meta: Logs3OverflowInputRow + root_span_id: str | None = None + object_ids: dict[str, Any] = dataclasses.field(default_factory=dict) class DatasetRef(TypedDict, total=False): @@ -419,7 +422,11 @@ def default_get_api_conn(): # We lazily-initialize the logger so that it does any initialization # (including reading env variables) upon the first actual usage. self._global_bg_logger = LazyValue( - lambda: _HTTPBackgroundLogger(LazyValue(default_get_api_conn, use_mutex=True)), use_mutex=True + lambda: _HTTPBackgroundLogger( + LazyValue(default_get_api_conn, use_mutex=True), + record_write_xact_id=self.record_trace_write_xact_id, + ), + use_mutex=True, ) self._id_generator = None @@ -462,6 +469,9 @@ def default_get_api_conn(): from braintrust.span_cache import SpanCache self.span_cache = SpanCache() + self._trace_write_xact_ids: OrderedDict[tuple[str, str], str] = OrderedDict() + self._trace_write_xact_ids_max_size = int(os.environ.get("BRAINTRUST_TRACE_WRITE_XACT_IDS_MAX_SIZE", "10000")) + self._trace_write_xact_ids_lock = threading.Lock() self._otel_flush_callback: Any | None = None def reset_login_info(self): @@ -521,6 +531,23 @@ def context_manager(self): return self._context_manager + def record_trace_write_xact_id(self, object_id: str, root_span_id: str, xact_id: str) -> None: + """Record the highest ingestion xact ID observed for a trace.""" + parsed_xact_id = int(xact_id) + key = (object_id, root_span_id) + with self._trace_write_xact_ids_lock: + current_xact_id = self._trace_write_xact_ids.get(key) + if current_xact_id is None or parsed_xact_id > int(current_xact_id): + self._trace_write_xact_ids[key] = xact_id + self._trace_write_xact_ids.move_to_end(key) + while len(self._trace_write_xact_ids) > self._trace_write_xact_ids_max_size: + self._trace_write_xact_ids.popitem(last=False) + + def get_trace_write_xact_id(self, object_id: str, root_span_id: str) -> str | None: + """Return the highest ingestion xact ID recorded for a trace.""" + with self._trace_write_xact_ids_lock: + return self._trace_write_xact_ids.get((object_id, root_span_id)) + def register_otel_flush(self, callback: Any) -> None: """ Register an OTEL flush callback. This is called by the OTEL integration @@ -554,6 +581,9 @@ def copy_state(self, other: "BraintrustState"): "_context_manager", "_last_otel_setting", "_context_manager_lock", + "_trace_write_xact_ids", + "_trace_write_xact_ids_max_size", + "_trace_write_xact_ids_lock", ) } ) @@ -864,14 +894,17 @@ def pick_logs3_overflow_object_ids(row: Mapping[str, Any]) -> dict[str, Any]: def stringify_with_overflow_meta(item: dict[str, Any]) -> LogItemWithMeta: str_value = bt_dumps(item) + object_ids = pick_logs3_overflow_object_ids(item) return LogItemWithMeta( str_value=str_value, overflow_meta=Logs3OverflowInputRow( - object_ids=pick_logs3_overflow_object_ids(item), + object_ids=object_ids, has_comment="comment" in item, is_delete=item.get(OBJECT_DELETE_FIELD) is True, byte_size=utf8_byte_length(str_value), ), + root_span_id=item.get("root_span_id") if isinstance(item.get("root_span_id"), str) else None, + object_ids=object_ids, ) @@ -1004,8 +1037,13 @@ def pop(self): # instances of this class, because concurrent _BackgroundLoggers will not log to # the backend in a deterministic order. class _HTTPBackgroundLogger: - def __init__(self, api_conn: LazyValue[HTTPConnection]): + def __init__( + self, + api_conn: LazyValue[HTTPConnection], + record_write_xact_id: Callable[[str, str, str], None] | None = None, + ): self.api_conn = api_conn + self._record_write_xact_id = record_write_xact_id self.masking_function: Callable[[Any], Any] | None = None self.outfile = sys.stderr self.flush_lock = threading.RLock() @@ -1383,6 +1421,7 @@ def _submit_logs_request(self, items: Sequence[LogItemWithMeta], max_request_siz if error is None and resp is not None and resp.ok: if overflow_rows: self._overflow_upload_count += 1 + self._record_batch_write_xact_id(items, resp.headers.get("x-bt-write-xact-id")) return if error is None and resp is not None: resp_errmsg = f"{resp.status_code}: {resp.text}" @@ -1410,6 +1449,16 @@ def _submit_logs_request(self, items: Sequence[LogItemWithMeta], max_request_siz print(f"log request failed after {self.num_tries} retries. Dropping batch", file=self.outfile) + def _record_batch_write_xact_id(self, items: Sequence[LogItemWithMeta], xact_id: str | None) -> None: + if not xact_id or self._record_write_xact_id is None: + return + for item in items: + if not item.root_span_id: + continue + for object_id in item.object_ids.values(): + if isinstance(object_id, str): + self._record_write_xact_id(object_id, item.root_span_id, xact_id) + def _dump_dropped_events(self, wrapped_items): publish_payloads_dir = [x for x in [self.all_publish_payloads_dir, self.failed_publish_payloads_dir] if x] if not (wrapped_items and publish_payloads_dir): @@ -1480,7 +1529,9 @@ def _internal_get_global_state() -> BraintrustState: @contextlib.contextmanager def _internal_with_custom_background_logger(): - custom_logger = _HTTPBackgroundLogger(LazyValue(lambda: _state.api_conn(), use_mutex=True)) + custom_logger = _HTTPBackgroundLogger( + LazyValue(lambda: _state.api_conn(), use_mutex=True), record_write_xact_id=_state.record_trace_write_xact_id + ) _state._override_bg_logger.logger = custom_logger try: yield custom_logger diff --git a/py/src/braintrust/test_logger.py b/py/src/braintrust/test_logger.py index e8c22bdc..ae25a089 100644 --- a/py/src/braintrust/test_logger.py +++ b/py/src/braintrust/test_logger.py @@ -137,6 +137,69 @@ def test_init_enable_atexit_flush(self): _HTTPBackgroundLogger(LazyValue(api_con_response, use_mutex=False)) # type: ignore mock_register.assert_called() + def test_records_write_xact_id_from_logs3_response(self): + from braintrust.logger import _HTTPBackgroundLogger, stringify_with_overflow_meta + + class FakeResponse: + ok = True + headers = {"x-bt-write-xact-id": "12345"} + + mock_conn = MagicMock() + mock_conn.post.return_value = FakeResponse() + recorded = [] + bg_logger = _HTTPBackgroundLogger( + LazyValue(lambda: mock_conn, use_mutex=False), + record_write_xact_id=lambda object_id, root_span_id, xact_id: recorded.append( + (object_id, root_span_id, xact_id) + ), + ) + + bg_logger._submit_logs_request( + [ + stringify_with_overflow_meta( + { + "experiment_id": "exp-123", + "root_span_id": "root-456", + "span_id": "span-789", + } + ) + ], + {"max_request_size": 1024 * 1024, "can_use_overflow": False}, + ) + + assert recorded == [("exp-123", "root-456", "12345")] + + def test_trace_write_xact_id_keeps_high_watermark(self): + from braintrust.logger import BraintrustState + + state = BraintrustState() + state.record_trace_write_xact_id("exp-123", "root-456", "200") + state.record_trace_write_xact_id("exp-123", "root-456", "100") + state.record_trace_write_xact_id("exp-123", "root-other", "50") + + assert state.get_trace_write_xact_id("exp-123", "root-456") == "200" + assert state.get_trace_write_xact_id("exp-123", "root-other") == "50" + + def test_trace_write_xact_id_rejects_non_numeric_values(self): + from braintrust.logger import BraintrustState + + state = BraintrustState() + with pytest.raises(ValueError): + state.record_trace_write_xact_id("exp-123", "root-456", "not-numeric") + + def test_trace_write_xact_ids_are_bounded(self): + from braintrust.logger import BraintrustState + + with patch.dict(os.environ, {"BRAINTRUST_TRACE_WRITE_XACT_IDS_MAX_SIZE": "2"}): + state = BraintrustState() + state.record_trace_write_xact_id("exp-123", "root-1", "1") + state.record_trace_write_xact_id("exp-123", "root-2", "2") + state.record_trace_write_xact_id("exp-123", "root-3", "3") + + assert state.get_trace_write_xact_id("exp-123", "root-1") is None + assert state.get_trace_write_xact_id("exp-123", "root-2") == "2" + assert state.get_trace_write_xact_id("exp-123", "root-3") == "3" + def test_init_disable_atexit_flush(self): from braintrust.logger import _HTTPBackgroundLogger diff --git a/py/src/braintrust/test_trace.py b/py/src/braintrust/test_trace.py index 577d4e58..16a105ab 100644 --- a/py/src/braintrust/test_trace.py +++ b/py/src/braintrust/test_trace.py @@ -306,12 +306,16 @@ def get_by_root_span_id(self, root_span_id: str): class _DummyState: - def __init__(self): + def __init__(self, xact_id: str | None = None): self.span_cache = _DummySpanCache() + self.xact_id = xact_id def login(self): return None + def get_trace_write_xact_id(self, object_id: str, root_span_id: str): + return self.xact_id + class TestLocalTraceGetThread: @pytest.mark.asyncio @@ -349,8 +353,33 @@ def fake_invoke(**kwargs): "root_span_id": "root-456", } } + assert calls[0]["trace_min_xact_id"] is None assert result == mock_thread + @pytest.mark.asyncio + async def test_passes_trace_min_xact_id_with_recorded_xact_id(self, monkeypatch): + calls = [] + + def fake_invoke(**kwargs): + calls.append(kwargs) + return [] + + monkeypatch.setattr("braintrust.trace.invoke", fake_invoke) + + trace = LocalTrace( + object_type="experiment", + object_id="exp-123", + root_span_id="root-456", + ensure_spans_flushed=None, + state=_DummyState(xact_id="12345"), + ) + + await trace.get_thread() + + assert calls[0]["trace_min_xact_id"] == "12345" + assert "trace_read" not in calls[0] + assert "skip_realtime" not in calls[0]["input"]["trace_ref"] + @pytest.mark.asyncio async def test_uses_custom_preprocessor(self, monkeypatch): calls = [] diff --git a/py/src/braintrust/trace.py b/py/src/braintrust/trace.py index 24bcefa2..b9216bf1 100644 --- a/py/src/braintrust/trace.py +++ b/py/src/braintrust/trace.py @@ -407,6 +407,8 @@ async def _fetch_thread(self, options: GetThreadOptions | None = None) -> list[A await asyncio.get_event_loop().run_in_executor(None, lambda: self._state.login()) preprocessor = options.get("preprocessor") if options and options.get("preprocessor") else None + trace_min_xact_id = self._state.get_trace_write_xact_id(self._object_id, self._root_span_id) + result = await asyncio.get_event_loop().run_in_executor( None, lambda: invoke( @@ -420,6 +422,7 @@ async def _fetch_thread(self, options: GetThreadOptions | None = None) -> list[A "root_span_id": self._root_span_id, } }, + trace_min_xact_id=trace_min_xact_id, ), )