Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions py/src/braintrust/functions/invoke.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def invoke(
api_key: str | None = None,
app_url: str | None = None,
force_login: bool = False,
trace_min_xact_id: str | None = None,
) -> T: ...


Expand Down Expand Up @@ -85,6 +86,7 @@ def invoke(
api_key: str | None = None,
app_url: str | None = None,
force_login: bool = False,
trace_min_xact_id: str | None = None,
) -> BraintrustStream: ...


Expand Down Expand Up @@ -112,6 +114,7 @@ def invoke(
api_key: str | None = None,
app_url: str | None = None,
force_login: bool = False,
trace_min_xact_id: str | None = None,
) -> BraintrustStream | T:
"""
Invoke a Braintrust function, returning a `BraintrustStream` or the value as a plain
Expand Down Expand Up @@ -151,6 +154,7 @@ def invoke(
global_function: The name of the global function to invoke.
function_type: The type of the global function to invoke. If unspecified, defaults to 'scorer'
for backward compatibility.
trace_min_xact_id: Optional minimum ingestion xact ID for compacted trace-ref reads.

Returns:
The output of the function. If `stream` is True, returns a `BraintrustStream`,
Expand Down Expand Up @@ -198,6 +202,8 @@ def invoke(
request["mode"] = mode
if strict is not None:
request["strict"] = strict
if trace_min_xact_id is not None:
request["trace_min_xact_id"] = trace_min_xact_id

headers = {
"Accept": "text/event-stream" if stream else "application/json",
Expand Down
26 changes: 26 additions & 0 deletions py/src/braintrust/functions/test_invoke.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,32 @@ def test_invoke_serializes_google_messages():
assert isinstance(parsed, dict) and parsed


def test_invoke_serializes_trace_min_xact_id():
    """`trace_min_xact_id` passed to `invoke` appears as a top-level key of the posted JSON body."""
    trace_ref = {"object_id": "exp-123", "root_span_id": "root-456"}

    response = MagicMock()
    response.status_code = 200
    response.json.return_value = {}
    conn = MagicMock()
    conn.post.return_value = response

    with (
        patch("braintrust.functions.invoke.login"),
        patch("braintrust.functions.invoke.get_span_parent_object") as parent_factory,
        patch("braintrust.functions.invoke.proxy_conn", return_value=conn),
    ):
        parent_factory.return_value.export.return_value = "span-export"
        invoke(
            global_function="project_default",
            function_type="preprocessor",
            input={"trace_ref": trace_ref},
            trace_min_xact_id="12345",
        )

    # The request body is handed to `post` as UTF-8 bytes under the `data` keyword.
    raw_body = conn.post.call_args.kwargs["data"]
    body = json.loads(raw_body.decode("utf-8"))
    assert body["trace_min_xact_id"] == "12345"
    assert "trace_read" not in body


@pytest.mark.vcr
def test_invoke_encodes_body_as_utf8_bytes(monkeypatch):
"""Regression test for BT-4620: non-Latin-1 Unicode must not be corrupted.
Expand Down
59 changes: 55 additions & 4 deletions py/src/braintrust/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import types
import uuid
from abc import ABC, abstractmethod
from collections import OrderedDict
from collections.abc import Callable, Iterator, Mapping, MutableMapping, Sequence
from functools import partial, wraps
from multiprocessing import cpu_count
Expand Down Expand Up @@ -119,6 +120,8 @@ class Logs3OverflowInputRow:
class LogItemWithMeta:
str_value: str
overflow_meta: Logs3OverflowInputRow
root_span_id: str | None = None
object_ids: dict[str, Any] = dataclasses.field(default_factory=dict)


class DatasetRef(TypedDict, total=False):
Expand Down Expand Up @@ -419,7 +422,11 @@ def default_get_api_conn():
# We lazily-initialize the logger so that it does any initialization
# (including reading env variables) upon the first actual usage.
self._global_bg_logger = LazyValue(
lambda: _HTTPBackgroundLogger(LazyValue(default_get_api_conn, use_mutex=True)), use_mutex=True
lambda: _HTTPBackgroundLogger(
LazyValue(default_get_api_conn, use_mutex=True),
record_write_xact_id=self.record_trace_write_xact_id,
),
use_mutex=True,
)

self._id_generator = None
Expand Down Expand Up @@ -462,6 +469,9 @@ def default_get_api_conn():
from braintrust.span_cache import SpanCache

self.span_cache = SpanCache()
self._trace_write_xact_ids: OrderedDict[tuple[str, str], str] = OrderedDict()
self._trace_write_xact_ids_max_size = int(os.environ.get("BRAINTRUST_TRACE_WRITE_XACT_IDS_MAX_SIZE", "10000"))
self._trace_write_xact_ids_lock = threading.Lock()
self._otel_flush_callback: Any | None = None

def reset_login_info(self):
Expand Down Expand Up @@ -521,6 +531,23 @@ def context_manager(self):

return self._context_manager

def record_trace_write_xact_id(self, object_id: str, root_span_id: str, xact_id: str) -> None:
    """Record the highest ingestion xact ID observed for a trace.

    Entries are keyed by ``(object_id, root_span_id)``. A stored value is only
    replaced by a numerically larger one; the original string form is what is
    kept. When the table grows beyond ``_trace_write_xact_ids_max_size``, the
    entries at the front of the ordered table (least recently updated) are
    evicted. A zero or negative limit therefore retains nothing.

    Args:
        object_id: ID of the object the trace was written to.
        root_span_id: Root span ID identifying the trace.
        xact_id: Ingestion xact ID as an integer string.

    Raises:
        ValueError: If ``xact_id`` cannot be parsed as an integer.
    """
    # Parse up front so malformed input is rejected before shared state is touched.
    parsed_xact_id = int(xact_id)
    key = (object_id, root_span_id)
    with self._trace_write_xact_ids_lock:
        entries = self._trace_write_xact_ids
        current_xact_id = entries.get(key)
        if current_xact_id is None or parsed_xact_id > int(current_xact_id):
            entries[key] = xact_id
            entries.move_to_end(key)
        max_size = self._trace_write_xact_ids_max_size
        # Also stop when the table is empty: with a negative limit the length
        # check alone never becomes false and `popitem` would raise KeyError.
        while entries and len(entries) > max_size:
            entries.popitem(last=False)

def get_trace_write_xact_id(self, object_id: str, root_span_id: str) -> str | None:
"""Return the highest ingestion xact ID recorded for a trace."""
with self._trace_write_xact_ids_lock:
return self._trace_write_xact_ids.get((object_id, root_span_id))

def register_otel_flush(self, callback: Any) -> None:
"""
Register an OTEL flush callback. This is called by the OTEL integration
Expand Down Expand Up @@ -554,6 +581,9 @@ def copy_state(self, other: "BraintrustState"):
"_context_manager",
"_last_otel_setting",
"_context_manager_lock",
"_trace_write_xact_ids",
"_trace_write_xact_ids_max_size",
"_trace_write_xact_ids_lock",
)
}
)
Expand Down Expand Up @@ -864,14 +894,17 @@ def pick_logs3_overflow_object_ids(row: Mapping[str, Any]) -> dict[str, Any]:

def stringify_with_overflow_meta(item: dict[str, Any]) -> LogItemWithMeta:
    """Serialize a log row and bundle it with the metadata derived from it.

    Args:
        item: The log row to serialize.

    Returns:
        A `LogItemWithMeta` holding the serialized row, its overflow metadata
        (object IDs, comment/delete flags, UTF-8 byte size), the row's
        `root_span_id` when it is a string (otherwise `None`), and the same
        object-ID mapping that is stored in the overflow metadata.
    """
    str_value = bt_dumps(item)
    # Computed once and shared by the overflow metadata and the item itself.
    object_ids = pick_logs3_overflow_object_ids(item)
    root_span_id = item.get("root_span_id")
    return LogItemWithMeta(
        str_value=str_value,
        overflow_meta=Logs3OverflowInputRow(
            object_ids=object_ids,
            has_comment="comment" in item,
            is_delete=item.get(OBJECT_DELETE_FIELD) is True,
            byte_size=utf8_byte_length(str_value),
        ),
        # Non-string values are dropped rather than propagated.
        root_span_id=root_span_id if isinstance(root_span_id, str) else None,
        object_ids=object_ids,
    )


Expand Down Expand Up @@ -1004,8 +1037,13 @@ def pop(self):
# instances of this class, because concurrent _BackgroundLoggers will not log to
# the backend in a deterministic order.
class _HTTPBackgroundLogger:
def __init__(self, api_conn: LazyValue[HTTPConnection]):
def __init__(
self,
api_conn: LazyValue[HTTPConnection],
record_write_xact_id: Callable[[str, str, str], None] | None = None,
):
self.api_conn = api_conn
self._record_write_xact_id = record_write_xact_id
self.masking_function: Callable[[Any], Any] | None = None
self.outfile = sys.stderr
self.flush_lock = threading.RLock()
Expand Down Expand Up @@ -1383,6 +1421,7 @@ def _submit_logs_request(self, items: Sequence[LogItemWithMeta], max_request_siz
if error is None and resp is not None and resp.ok:
if overflow_rows:
self._overflow_upload_count += 1
self._record_batch_write_xact_id(items, resp.headers.get("x-bt-write-xact-id"))
return
if error is None and resp is not None:
resp_errmsg = f"{resp.status_code}: {resp.text}"
Expand Down Expand Up @@ -1410,6 +1449,16 @@ def _submit_logs_request(self, items: Sequence[LogItemWithMeta], max_request_siz

print(f"log request failed after {self.num_tries} retries. Dropping batch", file=self.outfile)

def _record_batch_write_xact_id(self, items: Sequence[LogItemWithMeta], xact_id: str | None) -> None:
    """Report a batch's write xact ID to the recorder callback for every trace in the batch.

    Each distinct ``(object_id, root_span_id)`` pair found in ``items`` is
    reported once, in first-seen order. Items without a root span ID and
    object-ID values that are not strings are skipped. Nothing happens when
    ``xact_id`` is empty or no recorder callback was configured.

    Args:
        items: The log items that were part of the batch.
        xact_id: The write xact ID reported for the batch, if any.
    """
    record = self._record_write_xact_id
    if not xact_id or record is None:
        return
    # Many rows in a batch usually belong to the same trace; report each pair once.
    seen: set[tuple[str, str]] = set()
    try:
        for item in items:
            root_span_id = item.root_span_id
            if not root_span_id:
                continue
            for object_id in item.object_ids.values():
                if not isinstance(object_id, str):
                    continue
                pair = (object_id, root_span_id)
                if pair in seen:
                    continue
                seen.add(pair)
                record(object_id, root_span_id, xact_id)
    except ValueError as e:
        # This is bookkeeping for a batch that was already accepted; a malformed
        # ID must not surface as a failure of the log request itself.
        print(f"Ignoring invalid write xact id {xact_id!r}: {e}", file=self.outfile)

def _dump_dropped_events(self, wrapped_items):
publish_payloads_dir = [x for x in [self.all_publish_payloads_dir, self.failed_publish_payloads_dir] if x]
if not (wrapped_items and publish_payloads_dir):
Expand Down Expand Up @@ -1480,7 +1529,9 @@ def _internal_get_global_state() -> BraintrustState:

@contextlib.contextmanager
def _internal_with_custom_background_logger():
custom_logger = _HTTPBackgroundLogger(LazyValue(lambda: _state.api_conn(), use_mutex=True))
custom_logger = _HTTPBackgroundLogger(
LazyValue(lambda: _state.api_conn(), use_mutex=True), record_write_xact_id=_state.record_trace_write_xact_id
)
_state._override_bg_logger.logger = custom_logger
try:
yield custom_logger
Expand Down
63 changes: 63 additions & 0 deletions py/src/braintrust/test_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,69 @@ def test_init_enable_atexit_flush(self):
_HTTPBackgroundLogger(LazyValue(api_con_response, use_mutex=False)) # type: ignore
mock_register.assert_called()

def test_records_write_xact_id_from_logs3_response(self):
    """A successful logs request forwards the `x-bt-write-xact-id` response header to the recorder."""
    from braintrust.logger import _HTTPBackgroundLogger, stringify_with_overflow_meta

    class FakeResponse:
        ok = True
        headers = {"x-bt-write-xact-id": "12345"}

    recorded = []

    def remember(object_id, root_span_id, xact_id):
        recorded.append((object_id, root_span_id, xact_id))

    mock_conn = MagicMock()
    mock_conn.post.return_value = FakeResponse()
    bg_logger = _HTTPBackgroundLogger(
        LazyValue(lambda: mock_conn, use_mutex=False),
        record_write_xact_id=remember,
    )

    row = stringify_with_overflow_meta(
        {
            "experiment_id": "exp-123",
            "root_span_id": "root-456",
            "span_id": "span-789",
        }
    )
    limits = {"max_request_size": 1024 * 1024, "can_use_overflow": False}
    bg_logger._submit_logs_request([row], limits)

    assert recorded == [("exp-123", "root-456", "12345")]

def test_trace_write_xact_id_keeps_high_watermark(self):
    """A lower xact ID never replaces a higher one, and traces are tracked independently."""
    from braintrust.logger import BraintrustState

    state = BraintrustState()
    observations = [
        ("root-456", "200"),
        ("root-456", "100"),
        ("root-other", "50"),
    ]
    for root_span_id, xact_id in observations:
        state.record_trace_write_xact_id("exp-123", root_span_id, xact_id)

    assert state.get_trace_write_xact_id("exp-123", "root-456") == "200"
    assert state.get_trace_write_xact_id("exp-123", "root-other") == "50"

def test_trace_write_xact_id_rejects_non_numeric_values(self):
    """Recording an xact ID that is not an integer string raises `ValueError`."""
    from braintrust.logger import BraintrustState

    fresh_state = BraintrustState()
    bad_xact_id = "not-numeric"
    with pytest.raises(ValueError):
        fresh_state.record_trace_write_xact_id("exp-123", "root-456", bad_xact_id)

def test_trace_write_xact_ids_are_bounded(self):
    """With a size limit of 2, recording a third trace evicts the oldest one."""
    from braintrust.logger import BraintrustState

    size_limit = {"BRAINTRUST_TRACE_WRITE_XACT_IDS_MAX_SIZE": "2"}
    with patch.dict(os.environ, size_limit):
        state = BraintrustState()
        for n in ("1", "2", "3"):
            state.record_trace_write_xact_id("exp-123", f"root-{n}", n)

    assert state.get_trace_write_xact_id("exp-123", "root-1") is None
    assert state.get_trace_write_xact_id("exp-123", "root-2") == "2"
    assert state.get_trace_write_xact_id("exp-123", "root-3") == "3"

def test_init_disable_atexit_flush(self):
from braintrust.logger import _HTTPBackgroundLogger

Expand Down
31 changes: 30 additions & 1 deletion py/src/braintrust/test_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,12 +306,16 @@ def get_by_root_span_id(self, root_span_id: str):


class _DummyState:
    """Minimal stand-in for the state object handed to `LocalTrace` in these tests."""

    def __init__(self, xact_id: str | None = None):
        self.span_cache = _DummySpanCache()
        # Value returned by `get_trace_write_xact_id` for every trace.
        self.xact_id = xact_id

    def login(self):
        # No-op: the tests never perform a real login.
        return None

    def get_trace_write_xact_id(self, object_id: str, root_span_id: str):
        # The arguments are ignored; the configured value is returned as-is.
        return self.xact_id


class TestLocalTraceGetThread:
@pytest.mark.asyncio
Expand Down Expand Up @@ -349,8 +353,33 @@ def fake_invoke(**kwargs):
"root_span_id": "root-456",
}
}
assert calls[0]["trace_min_xact_id"] is None
assert result == mock_thread

@pytest.mark.asyncio
async def test_passes_trace_min_xact_id_with_recorded_xact_id(self, monkeypatch):
    """The xact ID reported by the state is forwarded to `invoke` as `trace_min_xact_id`."""
    invoke_kwargs = []

    def recording_invoke(**kwargs):
        invoke_kwargs.append(kwargs)
        return []

    monkeypatch.setattr("braintrust.trace.invoke", recording_invoke)

    state = _DummyState(xact_id="12345")
    trace = LocalTrace(
        object_type="experiment",
        object_id="exp-123",
        root_span_id="root-456",
        ensure_spans_flushed=None,
        state=state,
    )

    await trace.get_thread()

    first_call = invoke_kwargs[0]
    assert first_call["trace_min_xact_id"] == "12345"
    assert "trace_read" not in first_call
    assert "skip_realtime" not in first_call["input"]["trace_ref"]

@pytest.mark.asyncio
async def test_uses_custom_preprocessor(self, monkeypatch):
calls = []
Expand Down
3 changes: 3 additions & 0 deletions py/src/braintrust/trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,8 @@ async def _fetch_thread(self, options: GetThreadOptions | None = None) -> list[A
await asyncio.get_event_loop().run_in_executor(None, lambda: self._state.login())
preprocessor = options.get("preprocessor") if options and options.get("preprocessor") else None

trace_min_xact_id = self._state.get_trace_write_xact_id(self._object_id, self._root_span_id)

result = await asyncio.get_event_loop().run_in_executor(
None,
lambda: invoke(
Expand All @@ -420,6 +422,7 @@ async def _fetch_thread(self, options: GetThreadOptions | None = None) -> list[A
"root_span_id": self._root_span_id,
}
},
trace_min_xact_id=trace_min_xact_id,
),
)

Expand Down
Loading