Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/tabpfn_common_utils/telemetry/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from .events import (
BaseTelemetryEvent,
ExtensionEntryEvent,
ModelLoadEvent,
PingEvent,
DatasetEvent,
Expand All @@ -24,6 +25,7 @@
# Public exports
__all__ = [
"BaseTelemetryEvent",
"ExtensionEntryEvent",
"PingEvent",
"DatasetEvent",
"ModelLoadEvent",
Expand Down
3 changes: 2 additions & 1 deletion src/tabpfn_common_utils/telemetry/core/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from functools import wraps
from typing import Any, Callable, Dict, Literal, Optional, Tuple, Union

from .events import FitEvent, PredictEvent
from .events import ExtensionEntryEvent, FitEvent, PredictEvent
from .service import capture_event
from tabpfn_common_utils.utils import shape_of

Expand Down Expand Up @@ -214,6 +214,7 @@ def wrapped(*args, **kwargs):
if _get_context_var("tabpfn_current_extension").get() is not None:
return fn(*args, **kwargs)
with _extension_context(extension_name):
capture_event(ExtensionEntryEvent(extension_name=extension_name))
return fn(*args, **kwargs)

setattr(wrapped, _MARKER_ATTR, extension_name)
Expand Down
18 changes: 18 additions & 0 deletions src/tabpfn_common_utils/telemetry/core/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -499,3 +499,21 @@ class PredictEvent(ModelCallEvent):
@property
def name(self) -> str:
return "predict_called"


@dataclass
class ExtensionEntryEvent(BaseTelemetryEvent):
"""
Event emitted once per user-facing extension entry point call.

Unlike FitEvent/PredictEvent which fire per downstream model call,
this fires exactly once when the outermost @set_extension decorator
is entered, giving an unbiased count of extension usage.
"""

# Name of the extension that was entered
extension_name: str = ""

@property
def name(self) -> str:
return "extension_entry"
115 changes: 115 additions & 0 deletions tests/telemetry/core/test_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from __future__ import annotations

from unittest.mock import patch

import pytest

from tabpfn_common_utils.telemetry.core.decorators import (
Expand All @@ -10,6 +12,7 @@
get_current_extension,
_extension_context,
)
from tabpfn_common_utils.telemetry.core.events import ExtensionEntryEvent


class TestSetExtensionDecorator:
Expand Down Expand Up @@ -268,3 +271,115 @@ def test_round_dims_special_cases(self) -> None:
# Test some intermediate values
assert _round_dims((1234, 67)) == (1200, 75) # 1234 -> 1200, 67 -> 75
assert _round_dims((876, 89)) == (1000, 100) # 876 -> 1000, 89 -> 100


class TestExtensionEntryEventEmission:
"""Test that set_extension emits ExtensionEntryEvent correctly."""

@patch("tabpfn_common_utils.telemetry.core.decorators.capture_event")
def test_single_call_emits_one_event(self, mock_capture):
"""A single decorated function call emits exactly one ExtensionEntryEvent."""

@set_extension("test_ext")
def my_func():
return 42

result = my_func()

assert result == 42
assert mock_capture.call_count == 1
event = mock_capture.call_args[0][0]
assert isinstance(event, ExtensionEntryEvent)
assert event.extension_name == "test_ext"

@patch("tabpfn_common_utils.telemetry.core.decorators.capture_event")
def test_nested_calls_emit_one_event(self, mock_capture):
"""Nested decorated calls emit only one event for the outermost extension."""

@set_extension("outer")
def outer():
return inner()

@set_extension("inner")
def inner():
return get_current_extension()

result = outer()

# Inner should see the outer context
assert result == "outer"
# Only one event emitted (for the outer entry)
assert mock_capture.call_count == 1
event = mock_capture.call_args[0][0]
assert isinstance(event, ExtensionEntryEvent)
assert event.extension_name == "outer"

@patch("tabpfn_common_utils.telemetry.core.decorators.capture_event")
def test_sequential_calls_emit_separate_events(self, mock_capture):
"""Two sequential (non-nested) calls emit two separate events."""

@set_extension("ext_a")
def func_a():
return "a"

@set_extension("ext_b")
def func_b():
return "b"

func_a()
func_b()

assert mock_capture.call_count == 2
assert mock_capture.call_args_list[0][0][0].extension_name == "ext_a"
assert mock_capture.call_args_list[1][0][0].extension_name == "ext_b"

@patch("tabpfn_common_utils.telemetry.core.decorators.capture_event")
def test_class_decorator_emits_one_event_per_public_method_call(self, mock_capture):
"""A class decorated with set_extension emits one event per public method call.
__init__ is private (starts with _) so it's not wrapped by default."""

@set_extension("cls_ext")
class MyClass:
def do_work(self):
return "done"

obj = MyClass()
obj.do_work()

# Only do_work is public; __init__ starts with _ so not wrapped
assert mock_capture.call_count == 1
event = mock_capture.call_args[0][0]
assert isinstance(event, ExtensionEntryEvent)
assert event.extension_name == "cls_ext"

@patch("tabpfn_common_utils.telemetry.core.decorators.capture_event")
def test_class_nested_method_calls_emit_one_event(self, mock_capture):
"""When a class method calls another decorated function, only the outer emits."""

@set_extension("inner_ext")
def helper():
return "helped"

@set_extension("cls_ext")
class MyClass:
def do_work(self):
return helper()

obj = MyClass()
result = obj.do_work()

assert result == "helped"
# do_work emits one event, helper() is nested so no event
assert mock_capture.call_count == 1
assert mock_capture.call_args[0][0].extension_name == "cls_ext"

@patch("posthog.Posthog.capture", side_effect=RuntimeError("PostHog down"))
def test_capture_event_resilient_to_posthog_failure(self, _mock_posthog):
"""PostHog client failure doesn't prevent the wrapped function from running."""

@set_extension("fail_ext")
def my_func():
return "still works"

result = my_func()
assert result == "still works"
46 changes: 46 additions & 0 deletions tests/telemetry/core/test_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from tabpfn_common_utils.telemetry.core.events import (
BaseTelemetryEvent,
DatasetEvent,
ExtensionEntryEvent,
FitEvent,
ModelLoadEvent,
PingEvent,
Expand Down Expand Up @@ -422,6 +423,51 @@ def test_model_load_event_properties_method(self):
assert "install_id" in props


class TestExtensionEntryEvent:
"""Test ExtensionEntryEvent class"""

def test_extension_entry_event_initialization(self):
"""Test ExtensionEntryEvent initialization with extension name"""
event = ExtensionEntryEvent(extension_name="post_hoc_ensembles")

assert event.extension_name == "post_hoc_ensembles"
assert event.name == "extension_entry"

def test_extension_entry_event_default_extension_name(self):
"""Test ExtensionEntryEvent default extension_name is empty string"""
event = ExtensionEntryEvent()

assert event.extension_name == ""
assert event.name == "extension_entry"

def test_extension_entry_event_inherits_base_properties(self):
"""Test that ExtensionEntryEvent inherits base telemetry properties"""
event = ExtensionEntryEvent(extension_name="rf_pfn")

assert isinstance(event.python_version, str)
assert isinstance(event.tabpfn_version, str)
assert isinstance(event.timestamp, datetime)
assert event.source == "sdk"

def test_extension_entry_event_properties_method(self):
"""Test ExtensionEntryEvent properties method"""
event = ExtensionEntryEvent(extension_name="interpretability")

props = event.properties

assert "name" not in props
assert props["extension_name"] == "interpretability"
assert "python_version" in props
assert "tabpfn_version" in props

def test_extension_entry_event_with_colon_separated_name(self):
"""Test ExtensionEntryEvent with sub-extension names like unsupervised:impute"""
event = ExtensionEntryEvent(extension_name="unsupervised:impute")

assert event.extension_name == "unsupervised:impute"
assert event.name == "extension_entry"


class TestEventIntegration:
"""Integration tests for all event types"""

Expand Down
Loading
Loading