diff --git a/py/src/braintrust/logger.py b/py/src/braintrust/logger.py index 1221b5c6..3618cdeb 100644 --- a/py/src/braintrust/logger.py +++ b/py/src/braintrust/logger.py @@ -1026,6 +1026,11 @@ def __init__(self, api_conn: LazyValue[HTTPConnection]): except: self.all_publish_payloads_dir = None + try: + disable_atexit_flush = os.environ["BRAINTRUST_DISABLE_ATEXIT_FLUSH"].lower() in ("true", "1", "yes") + except: + disable_atexit_flush = False + self.start_thread_lock = threading.RLock() self.thread = threading.Thread(target=self._publisher, daemon=True) self.started = False @@ -1036,7 +1041,8 @@ def __init__(self, api_conn: LazyValue[HTTPConnection]): # Counter for tracking overflow uploads (useful for testing) self._overflow_upload_count = 0 - atexit.register(self._finalize) + if not disable_atexit_flush: + atexit.register(self._finalize) def enforce_queue_size_limit(self, enforce: bool) -> None: """ diff --git a/py/src/braintrust/test_logger.py b/py/src/braintrust/test_logger.py index 9e8829c7..b172ede2 100644 --- a/py/src/braintrust/test_logger.py +++ b/py/src/braintrust/test_logger.py @@ -124,6 +124,41 @@ def test_init_with_repo_info_does_not_raise(self): assert metadata.project.id == "test-project-id" assert metadata.experiment.name == "test-exp" + def test_init_enable_atexit_flush(self): + from braintrust.logger import _HTTPBackgroundLogger + + api_con_response = lambda: { + "project": {"id": "test-project-id", "name": "test-project"}, + "experiment": {"id": "test-exp-id", "name": "test-exp"}, + } + + with patch("atexit.register") as mock_register: + _HTTPBackgroundLogger(LazyValue(api_con_response, use_mutex=False)) # type: ignore + mock_register.assert_called() + + def test_init_disable_atexit_flush(self): + from braintrust.logger import _HTTPBackgroundLogger + + api_con_response = lambda: { + "project": {"id": "test-project-id", "name": "test-project"}, + "experiment": {"id": "test-exp-id", "name": "test-exp"}, + } + + with patch.dict(os.environ, {"BRAINTRUST_DISABLE_ATEXIT_FLUSH": "True"}): + with patch("atexit.register") as mock_register: + _HTTPBackgroundLogger(LazyValue(api_con_response, use_mutex=False)) # type: ignore + mock_register.assert_not_called() + + with patch.dict(os.environ, {"BRAINTRUST_DISABLE_ATEXIT_FLUSH": "1"}): + with patch("atexit.register") as mock_register: + _HTTPBackgroundLogger(LazyValue(api_con_response, use_mutex=False)) # type: ignore + mock_register.assert_not_called() + + with patch.dict(os.environ, {"BRAINTRUST_DISABLE_ATEXIT_FLUSH": "yes"}): + with patch("atexit.register") as mock_register: + _HTTPBackgroundLogger(LazyValue(api_con_response, use_mutex=False)) # type: ignore + mock_register.assert_not_called() + class TestLogger(TestCase): def test_extract_attachments_no_op(self):