From 831fdf4664e5af6afb8a748de96415a802a56394 Mon Sep 17 00:00:00 2001 From: minimAluminiumalism Date: Tue, 27 Jan 2026 14:32:25 +0800 Subject: [PATCH 1/5] fix(agno): fix aresponse missing await and double wrapped() calls in stream methods --- .gitignore | 3 + CHANGELOG-loongsuite.md | 3 + .../instrumentation/agno/_wrapper.py | 52 +++---- .../tests/test_wrapper.py | 139 ++++++++++++++++++ 4 files changed, 162 insertions(+), 35 deletions(-) create mode 100644 instrumentation-loongsuite/loongsuite-instrumentation-agno/tests/test_wrapper.py diff --git a/.gitignore b/.gitignore index 2248762ed..cabe7c72f 100644 --- a/.gitignore +++ b/.gitignore @@ -59,6 +59,9 @@ _build/ .mypy_cache/ target +# pyright local config +pyrightconfig.json + # Benchmark result files *-benchmark.json diff --git a/CHANGELOG-loongsuite.md b/CHANGELOG-loongsuite.md index 8f8888539..445a71587 100644 --- a/CHANGELOG-loongsuite.md +++ b/CHANGELOG-loongsuite.md @@ -13,6 +13,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed +- `loongsuite-instrumentation-agno`: fix aresponse missing await and double wrapped() calls in stream methods + ([#107](https://github.com/alibaba/loongsuite-python-agent/pull/107)) + - `loongsuite-instrumentation-mem0`: fix unittest ([#98](https://github.com/alibaba/loongsuite-python-agent/pull/98)) diff --git a/instrumentation-loongsuite/loongsuite-instrumentation-agno/src/opentelemetry/instrumentation/agno/_wrapper.py b/instrumentation-loongsuite/loongsuite-instrumentation-agno/src/opentelemetry/instrumentation/agno/_wrapper.py index 6c951ac69..5a3ad09c3 100644 --- a/instrumentation-loongsuite/loongsuite-instrumentation-agno/src/opentelemetry/instrumentation/agno/_wrapper.py +++ b/instrumentation-loongsuite/loongsuite-instrumentation-agno/src/opentelemetry/instrumentation/agno/_wrapper.py @@ -494,20 +494,8 @@ def response_stream( instance, arguments ), ) as with_span: + responses = [] try: - response = wrapped(*args, **kwargs) - 
except Exception as exception: - with_span.record_exception(exception) - status = trace_api.Status( - status_code=trace_api.StatusCode.ERROR, - # Follow the format in OTEL SDK for description, see: - # https://github.com/open-telemetry/opentelemetry-python/blob/2b9dcfc5d853d1c10176937a6bcaade54cda1a31/opentelemetry-api/src/opentelemetry/trace/__init__.py#L588 # noqa E501 - description=f"{type(exception).__name__}: {exception}", - ) - with_span.finish_tracing(status=status) - raise - try: - responses = [] for response in wrapped(*args, **kwargs): responses.append(response) yield response @@ -521,11 +509,14 @@ def response_stream( attributes=dict(resp_attr), extra_attributes=dict(resp_attr), ) - except Exception: - logger.exception( - f"Failed to finalize response of type {type(response)}" + except Exception as exception: + with_span.record_exception(exception) + status = trace_api.Status( + status_code=trace_api.StatusCode.ERROR, + description=f"{type(exception).__name__}: {exception}", ) - with_span.finish_tracing() + with_span.finish_tracing(status=status) + raise async def aresponse( self, @@ -536,7 +527,7 @@ async def aresponse( ) -> Any: arguments = bind_arguments(wrapped, *args, **kwargs) if not self._enable_genai_capture() or instance is None: - return wrapped(*args, **kwargs) + return await wrapped(*args, **kwargs) with self._start_as_current_span( span_name="Model.aresponse", attributes=self._request_attributes_extractor.extract( @@ -597,20 +588,8 @@ async def aresponse_stream( instance, arguments ), ) as with_span: + responses = [] try: - response = wrapped(*args, **kwargs) - except Exception as exception: - with_span.record_exception(exception) - status = trace_api.Status( - status_code=trace_api.StatusCode.ERROR, - # Follow the format in OTEL SDK for description, see: - # https://github.com/open-telemetry/opentelemetry-python/blob/2b9dcfc5d853d1c10176937a6bcaade54cda1a31/opentelemetry-api/src/opentelemetry/trace/__init__.py#L588 # noqa E501 - 
description=f"{type(exception).__name__}: {exception}", - ) - with_span.finish_tracing(status=status) - raise - try: - responses = [] async for response in wrapped(*args, **kwargs): responses.append(response) yield response @@ -624,8 +603,11 @@ async def aresponse_stream( attributes=dict(resp_attr), extra_attributes=dict(resp_attr), ) - except Exception: - logger.exception( - f"Failed to finalize response of type {type(response)}" + except Exception as exception: + with_span.record_exception(exception) + status = trace_api.Status( + status_code=trace_api.StatusCode.ERROR, + description=f"{type(exception).__name__}: {exception}", ) - with_span.finish_tracing() + with_span.finish_tracing(status=status) + raise diff --git a/instrumentation-loongsuite/loongsuite-instrumentation-agno/tests/test_wrapper.py b/instrumentation-loongsuite/loongsuite-instrumentation-agno/tests/test_wrapper.py new file mode 100644 index 000000000..1ff29394a --- /dev/null +++ b/instrumentation-loongsuite/loongsuite-instrumentation-agno/tests/test_wrapper.py @@ -0,0 +1,139 @@ +""" +Unit tests for _wrapper.py +""" +import asyncio +from typing import AsyncIterator, Iterator +from unittest.mock import MagicMock + +from opentelemetry.instrumentation.agno._wrapper import AgnoModelWrapper + + +class TestAresponse: + """Tests for aresponse() method.""" + + def test_aresponse_returns_result_not_coroutine(self): + """aresponse() should await wrapped() on early return path.""" + mock_instance = MagicMock() + mock_instance.id = "test-model" + + async def run_test(): + wrapper = AgnoModelWrapper(tracer=MagicMock()) + wrapper._enable_genai_capture = lambda: False + + async def mock_wrapped(*args, **kwargs): + return "expected_result" + + result = await wrapper.aresponse( + mock_wrapped, mock_instance, (), {"messages": []} + ) + assert not asyncio.iscoroutine(result), \ + "aresponse() returned coroutine instead of awaited result" + + asyncio.run(run_test()) + + +class TestResponseStream: + """Tests for 
response_stream() method.""" + + def test_response_stream_calls_wrapped_once(self): + """response_stream() should call wrapped() exactly once.""" + call_count = [0] + original_method = AgnoModelWrapper.response_stream + + def patched_method(self, wrapped, instance, args, kwargs): + original_wrapped = wrapped + + def counting_wrapped(*a, **kw): + call_count[0] += 1 + return original_wrapped(*a, **kw) + return original_method(self, counting_wrapped, instance, args, kwargs) + + AgnoModelWrapper.response_stream = patched_method + try: + mock_instance = MagicMock() + mock_instance.id = "test-model" + + def mock_generator(*args, **kwargs) -> Iterator[str]: + yield "chunk1" + yield "chunk2" + + wrapper = AgnoModelWrapper(tracer=MagicMock()) + wrapper._enable_genai_capture = lambda: True + wrapper._request_attributes_extractor = MagicMock() + wrapper._request_attributes_extractor.extract.return_value = {} + wrapper._response_attributes_extractor = MagicMock() + wrapper._response_attributes_extractor.extract.return_value = {} + + with_span_mock = MagicMock() + with_span_mock.__enter__ = MagicMock(return_value=with_span_mock) + with_span_mock.__exit__ = MagicMock(return_value=False) + with_span_mock.finish_tracing = MagicMock() + wrapper._start_as_current_span = MagicMock(return_value=with_span_mock) + + results = list(wrapper.response_stream( + mock_generator, mock_instance, (), {"messages": []} + )) + + assert call_count[0] == 1, \ + f"wrapped() called {call_count[0]} times, expected 1" + assert results == ["chunk1", "chunk2"] + finally: + AgnoModelWrapper.response_stream = original_method + + +class TestAresponseStream: + """Tests for aresponse_stream() method.""" + + def test_aresponse_stream_calls_wrapped_once(self): + """aresponse_stream() should call wrapped() exactly once.""" + call_count = [0] + original_method = AgnoModelWrapper.aresponse_stream + + async def patched_method(self, wrapped, instance, args, kwargs): + original_wrapped = wrapped + + def 
counting_wrapped(*a, **kw): + call_count[0] += 1 + return original_wrapped(*a, **kw) + async for item in original_method( + self, counting_wrapped, instance, args, kwargs + ): + yield item + + AgnoModelWrapper.aresponse_stream = patched_method + try: + mock_instance = MagicMock() + mock_instance.id = "test-model" + + async def mock_async_generator(*args, **kwargs) -> AsyncIterator[str]: + yield "async_chunk1" + yield "async_chunk2" + + async def run_test(): + wrapper = AgnoModelWrapper(tracer=MagicMock()) + wrapper._enable_genai_capture = lambda: True + wrapper._request_attributes_extractor = MagicMock() + wrapper._request_attributes_extractor.extract.return_value = {} + wrapper._response_attributes_extractor = MagicMock() + wrapper._response_attributes_extractor.extract.return_value = {} + + with_span_mock = MagicMock() + with_span_mock.__enter__ = MagicMock(return_value=with_span_mock) + with_span_mock.__exit__ = MagicMock(return_value=False) + with_span_mock.finish_tracing = MagicMock() + wrapper._start_as_current_span = MagicMock(return_value=with_span_mock) + + results = [] + async for chunk in wrapper.aresponse_stream( + mock_async_generator, mock_instance, (), {"messages": []} + ): + results.append(chunk) + return results + + results = asyncio.run(run_test()) + + assert call_count[0] == 1, \ + f"wrapped() called {call_count[0]} times, expected 1" + assert results == ["async_chunk1", "async_chunk2"] + finally: + AgnoModelWrapper.aresponse_stream = original_method From b2a3efb250326c99dc2bbb433972f8f269a4ac4f Mon Sep 17 00:00:00 2001 From: minimAluminiumalism Date: Wed, 28 Jan 2026 19:31:04 +0800 Subject: [PATCH 2/5] fix cicd pipeline failures and apply agents suggestions --- .../tests/test_wrapper.py | 40 ++++++++++++++----- .../genai/_multimodal_upload/pre_uploader.py | 2 +- 2 files changed, 30 insertions(+), 12 deletions(-) diff --git a/instrumentation-loongsuite/loongsuite-instrumentation-agno/tests/test_wrapper.py 
b/instrumentation-loongsuite/loongsuite-instrumentation-agno/tests/test_wrapper.py index 1ff29394a..f630f3085 100644 --- a/instrumentation-loongsuite/loongsuite-instrumentation-agno/tests/test_wrapper.py +++ b/instrumentation-loongsuite/loongsuite-instrumentation-agno/tests/test_wrapper.py @@ -1,6 +1,7 @@ """ Unit tests for _wrapper.py """ + import asyncio from typing import AsyncIterator, Iterator from unittest.mock import MagicMock @@ -26,8 +27,9 @@ async def mock_wrapped(*args, **kwargs): result = await wrapper.aresponse( mock_wrapped, mock_instance, (), {"messages": []} ) - assert not asyncio.iscoroutine(result), \ + assert not asyncio.iscoroutine(result), ( "aresponse() returned coroutine instead of awaited result" + ) asyncio.run(run_test()) @@ -46,7 +48,10 @@ def patched_method(self, wrapped, instance, args, kwargs): def counting_wrapped(*a, **kw): call_count[0] += 1 return original_wrapped(*a, **kw) - return original_method(self, counting_wrapped, instance, args, kwargs) + + return original_method( + self, counting_wrapped, instance, args, kwargs + ) AgnoModelWrapper.response_stream = patched_method try: @@ -68,14 +73,19 @@ def mock_generator(*args, **kwargs) -> Iterator[str]: with_span_mock.__enter__ = MagicMock(return_value=with_span_mock) with_span_mock.__exit__ = MagicMock(return_value=False) with_span_mock.finish_tracing = MagicMock() - wrapper._start_as_current_span = MagicMock(return_value=with_span_mock) + wrapper._start_as_current_span = MagicMock( + return_value=with_span_mock + ) - results = list(wrapper.response_stream( - mock_generator, mock_instance, (), {"messages": []} - )) + results = list( + wrapper.response_stream( + mock_generator, mock_instance, (), {"messages": []} + ) + ) - assert call_count[0] == 1, \ + assert call_count[0] == 1, ( f"wrapped() called {call_count[0]} times, expected 1" + ) assert results == ["chunk1", "chunk2"] finally: AgnoModelWrapper.response_stream = original_method @@ -95,6 +105,7 @@ async def 
patched_method(self, wrapped, instance, args, kwargs): def counting_wrapped(*a, **kw): call_count[0] += 1 return original_wrapped(*a, **kw) + async for item in original_method( self, counting_wrapped, instance, args, kwargs ): @@ -105,7 +116,9 @@ def counting_wrapped(*a, **kw): mock_instance = MagicMock() mock_instance.id = "test-model" - async def mock_async_generator(*args, **kwargs) -> AsyncIterator[str]: + async def mock_async_generator( + *args, **kwargs + ) -> AsyncIterator[str]: yield "async_chunk1" yield "async_chunk2" @@ -118,10 +131,14 @@ async def run_test(): wrapper._response_attributes_extractor.extract.return_value = {} with_span_mock = MagicMock() - with_span_mock.__enter__ = MagicMock(return_value=with_span_mock) + with_span_mock.__enter__ = MagicMock( + return_value=with_span_mock + ) with_span_mock.__exit__ = MagicMock(return_value=False) with_span_mock.finish_tracing = MagicMock() - wrapper._start_as_current_span = MagicMock(return_value=with_span_mock) + wrapper._start_as_current_span = MagicMock( + return_value=with_span_mock + ) results = [] async for chunk in wrapper.aresponse_stream( @@ -132,8 +149,9 @@ async def run_test(): results = asyncio.run(run_test()) - assert call_count[0] == 1, \ + assert call_count[0] == 1, ( f"wrapped() called {call_count[0]} times, expected 1" + ) assert results == ["async_chunk1", "async_chunk2"] finally: AgnoModelWrapper.aresponse_stream = original_method diff --git a/util/opentelemetry-util-genai/src/opentelemetry/util/genai/_multimodal_upload/pre_uploader.py b/util/opentelemetry-util-genai/src/opentelemetry/util/genai/_multimodal_upload/pre_uploader.py index 9259db8c8..7e403c39c 100644 --- a/util/opentelemetry-util-genai/src/opentelemetry/util/genai/_multimodal_upload/pre_uploader.py +++ b/util/opentelemetry-util-genai/src/opentelemetry/util/genai/_multimodal_upload/pre_uploader.py @@ -52,7 +52,7 @@ # Try importing audio processing libraries (optional dependencies) try: - import numpy as np + import numpy as 
np # pyright: ignore[reportMissingImports] import soundfile as sf # pyright: ignore[reportMissingImports] _audio_libs_available = True From bd59ce66d320f65694f1d155eb529caa26047441 Mon Sep 17 00:00:00 2001 From: minimAluminiumalism Date: Sat, 21 Mar 2026 15:42:19 +0800 Subject: [PATCH 3/5] feat(loongsuite-instrumentation-langchain): add rerank/document-compressor span support Automatically instrument BaseDocumentCompressor subclasses to emit rerank spans with OpenTelemetry semantic conventions. Uses __init_subclass__ hook to cover classes defined after instrumentation. --- .../CHANGELOG.md | 5 + .../README.md | 1 + .../instrumentation/langchain/__init__.py | 25 + .../langchain/internal/patch_rerank.py | 448 +++++++++++++++ .../tests/test_rerank_spans.py | 544 ++++++++++++++++++ 5 files changed, 1023 insertions(+) create mode 100644 instrumentation-loongsuite/loongsuite-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/internal/patch_rerank.py create mode 100644 instrumentation-loongsuite/loongsuite-instrumentation-langchain/tests/test_rerank_spans.py diff --git a/instrumentation-loongsuite/loongsuite-instrumentation-langchain/CHANGELOG.md b/instrumentation-loongsuite/loongsuite-instrumentation-langchain/CHANGELOG.md index bde79c856..73dad1702 100644 --- a/instrumentation-loongsuite/loongsuite-instrumentation-langchain/CHANGELOG.md +++ b/instrumentation-loongsuite/loongsuite-instrumentation-langchain/CHANGELOG.md @@ -7,6 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## Unreleased +### Added + +- Rerank / document-compressor span support + ([#149](https://github.com/alibaba/loongsuite-python-agent/pull/149)) + ### Changed - Set `run_inline = True` on the tracer so LangChain callbacks run inline for correct OpenTelemetry context propagation diff --git a/instrumentation-loongsuite/loongsuite-instrumentation-langchain/README.md b/instrumentation-loongsuite/loongsuite-instrumentation-langchain/README.md 
index 8a7a3f132..fc9d93293 100644 --- a/instrumentation-loongsuite/loongsuite-instrumentation-langchain/README.md +++ b/instrumentation-loongsuite/loongsuite-instrumentation-langchain/README.md @@ -130,6 +130,7 @@ loongsuite-instrument | ReAct Step | `STEP` | `gen_ai.operation.name=react`, `gen_ai.react.round`, `gen_ai.react.finish_reason` | | Tool | `TOOL` | `gen_ai.operation.name=execute_tool` | | Retriever | `RETRIEVER` | `gen_ai.operation.name=retrieval` | +| Reranker | `RERANKER` | `gen_ai.operation.name=rerank_documents`, `gen_ai.request.model`, `gen_ai.rerank.documents.count`, `gen_ai.request.top_k`, `gen_ai.rerank.input_documents`, `gen_ai.rerank.output_documents` (when content capture enabled) | ReAct Step spans are created for each Reasoning-Acting iteration, with the hierarchy: Agent > ReAct Step > LLM/Tool. Supported agent types: diff --git a/instrumentation-loongsuite/loongsuite-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/__init__.py b/instrumentation-loongsuite/loongsuite-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/__init__.py index a8b258e5a..b67577c46 100644 --- a/instrumentation-loongsuite/loongsuite-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/__init__.py +++ b/instrumentation-loongsuite/loongsuite-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/__init__.py @@ -208,6 +208,29 @@ def _uninstrument_create_agent() -> None: _patched_create_agent_locations.clear() +# ------------------------------------------------------------------ +# BaseDocumentCompressor patch (rerank / compression) +# ------------------------------------------------------------------ + + +def _instrument_document_compressor(handler: Any) -> None: + """Wrap all current and future ``BaseDocumentCompressor`` subclasses.""" + from opentelemetry.instrumentation.langchain.internal.patch_rerank import ( # noqa: PLC0415 + instrument_document_compressor, + ) + + 
instrument_document_compressor(handler) + + +def _uninstrument_document_compressor() -> None: + """Restore original ``BaseDocumentCompressor`` methods.""" + from opentelemetry.instrumentation.langchain.internal.patch_rerank import ( # noqa: PLC0415 + uninstrument_document_compressor, + ) + + uninstrument_document_compressor() + + class LangChainInstrumentor(BaseInstrumentor): """An instrumentor for LangChain.""" @@ -244,6 +267,7 @@ def _instrument(self, **kwargs: Any) -> None: _instrument_agent_executor() _instrument_create_agent() + _instrument_document_compressor(handler) def _uninstrument(self, **kwargs: Any) -> None: try: @@ -256,6 +280,7 @@ def _uninstrument(self, **kwargs: Any) -> None: _uninstrument_agent_executor() _uninstrument_create_agent() + _uninstrument_document_compressor() class _BaseCallbackManagerInit: diff --git a/instrumentation-loongsuite/loongsuite-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/internal/patch_rerank.py b/instrumentation-loongsuite/loongsuite-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/internal/patch_rerank.py new file mode 100644 index 000000000..dfee59c2b --- /dev/null +++ b/instrumentation-loongsuite/loongsuite-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/internal/patch_rerank.py @@ -0,0 +1,448 @@ +# Copyright The OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Rerank instrumentation patch for BaseDocumentCompressor. 
+ +Because ``compress_documents`` is abstract and every subclass overrides +it, ``wrapt.wrap_function_wrapper`` on the base class won't intercept +subclass calls. Instead we: + +1. Retroactively wrap ``compress_documents`` / ``acompress_documents`` + on **every existing subclass** at instrumentation time. +2. Install a ``__init_subclass__`` hook on ``BaseDocumentCompressor`` + so that any subclass defined **after** instrumentation is also + wrapped automatically. +""" + +from __future__ import annotations + +import contextvars +import json +import logging +from typing import TYPE_CHECKING, Any + +import wrapt + +# Depth counter to avoid duplicate spans when a proxy/wrapper compressor +# delegates to an inner compressor (both subclasses are patched). +# Only the outermost call (depth == 0) creates a telemetry span. +_COMPRESSOR_CALL_DEPTH: contextvars.ContextVar[int] = contextvars.ContextVar( + "opentelemetry_langchain_compressor_call_depth", + default=0, +) + +if TYPE_CHECKING: + from opentelemetry.util.genai.extended_handler import ( + ExtendedTelemetryHandler, + ) + +from opentelemetry.instrumentation.langchain.internal._tracer import ( + LoongsuiteTracer, +) +from opentelemetry.util.genai.extended_types import RerankInvocation +from opentelemetry.util.genai.types import Error + +logger = logging.getLogger(__name__) + +# Module-level state for uninstrumentation. +_original_init_subclass: Any = None +_patched_classes: set[type] = set() + +_WRAPPER_TAG = "_loongsuite_rerank_wrapped" + + +# --------------------------------------------------------------------------- +# Helpers — context and metadata extraction +# --------------------------------------------------------------------------- + + +def _find_tracer_from_callbacks(callbacks: Any) -> LoongsuiteTracer | None: + """Find ``LoongsuiteTracer`` from a ``callbacks`` parameter. + + ``callbacks`` may be a ``BaseCallbackManager``, a list of handlers, + or ``None``. 
+ """ + if callbacks is None: + return None + + # BaseCallbackManager (has handlers / inheritable_handlers attrs) + for attr in ("inheritable_handlers", "handlers"): + handlers = getattr(callbacks, attr, None) + if handlers: + for h in handlers: + if isinstance(h, LoongsuiteTracer): + return h + + # Plain list of handlers + if isinstance(callbacks, list): + for h in callbacks: + if isinstance(h, LoongsuiteTracer): + return h + + return None + + +def _get_parent_context(callbacks: Any) -> Any: + """Extract the parent OpenTelemetry ``Context`` from *callbacks*. + + When ``compress_documents`` is invoked from + ``ContextualCompressionRetriever``, ``callbacks`` is a child + ``CallbackManager`` whose ``parent_run_id`` points to the + retriever run. We look up the corresponding ``_RunData`` in the + tracer to get its ``Context`` so the rerank span is parented + correctly. + """ + tracer = _find_tracer_from_callbacks(callbacks) + if tracer is None: + return None + + parent_run_id = getattr(callbacks, "parent_run_id", None) + if parent_run_id is None: + return None + + with tracer._lock: + rd = tracer._runs.get(parent_run_id) + if rd is not None: + return rd.context + return None + + +def _extract_compressor_provider(instance: Any) -> str: + """Infer a provider name from a compressor instance.""" + cls_name = type(instance).__name__ + module = type(instance).__module__ or "" + + _HINTS = [ + ("cohere", "cohere"), + ("jina", "jina"), + ("flashrank", "flashrank"), + ("cross_encoder", "sentence_transformers"), + ("crossencoder", "sentence_transformers"), + ("bge", "bge"), + ] + lower = (module + "." 
+ cls_name).lower() + for hint, provider in _HINTS: + if hint in lower: + return provider + + return "langchain" + + +def _extract_compressor_model(instance: Any) -> str | None: + """Extract a model name from a compressor instance (if available).""" + for attr in ("model_name", "model", "model_id"): + val = getattr(instance, attr, None) + if val and isinstance(val, str): + return val + return None + + +def _extract_top_n(instance: Any) -> int | None: + """Extract ``top_n`` / ``top_k`` from a compressor instance.""" + for attr in ("top_n", "top_k"): + val = getattr(instance, attr, None) + if val is not None and isinstance(val, int): + return val + return None + + +def _documents_to_json(documents: Any) -> str | None: + """Serialise LangChain ``Document`` objects to a JSON string.""" + if not documents: + return None + try: + result = [] + for doc in documents: + entry: dict[str, Any] = {} + content = getattr(doc, "page_content", None) or getattr( + doc, "content", None + ) + if content: + entry["content"] = content + meta = getattr(doc, "metadata", None) or {} + if meta: + entry["metadata"] = meta + doc_id = getattr(doc, "id", None) + if doc_id: + entry["id"] = doc_id + result.append(entry) + return json.dumps(result, ensure_ascii=False, default=str) + except Exception: + logger.debug("Failed to serialize documents", exc_info=True) + return None + + +# --------------------------------------------------------------------------- +# Wrapper factories +# --------------------------------------------------------------------------- + + +def _make_compress_documents_wrapper( + handler: "ExtendedTelemetryHandler", +) -> Any: + """Return a ``wrapt``-style wrapper for ``compress_documents``.""" + + def wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any) -> Any: + parent_depth = _COMPRESSOR_CALL_DEPTH.get() + depth_token = _COMPRESSOR_CALL_DEPTH.set(parent_depth + 1) + try: + if parent_depth > 0: + # Inner call in a proxy/wrapper chain — skip instrumentation. 
+ return wrapped(*args, **kwargs) + + documents = args[0] if args else kwargs.get("documents", []) + callbacks = kwargs.get("callbacks") or ( + args[2] if len(args) > 2 else None + ) + + parent_ctx = _get_parent_context(callbacks) + + invocation = RerankInvocation( + provider=_extract_compressor_provider(instance), + request_model=_extract_compressor_model(instance), + documents_count=len(documents) if documents else None, + top_k=_extract_top_n(instance), + input_documents=_documents_to_json(documents), + ) + + try: + handler.start_rerank(invocation, context=parent_ctx) + except Exception: + logger.debug("Failed to start rerank span", exc_info=True) + return wrapped(*args, **kwargs) + + try: + result = wrapped(*args, **kwargs) + invocation.output_documents = _documents_to_json(result) + handler.stop_rerank(invocation) + return result + except Exception as exc: + handler.fail_rerank( + invocation, Error(message=str(exc), type=type(exc)) + ) + raise + finally: + _COMPRESSOR_CALL_DEPTH.reset(depth_token) + + return wrapper + + +def _make_acompress_documents_wrapper( + handler: "ExtendedTelemetryHandler", +) -> Any: + """Return a ``wrapt``-style wrapper for ``acompress_documents``. + + Returns a coroutine so the caller can ``await`` it. 
+ """ + + def wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any) -> Any: + # All ContextVar ops must happen inside the coroutine because + # asyncio.run() copies the context + + async def _instrumented() -> Any: + parent_depth = _COMPRESSOR_CALL_DEPTH.get() + depth_token = _COMPRESSOR_CALL_DEPTH.set(parent_depth + 1) + try: + if parent_depth > 0: + return await wrapped(*args, **kwargs) + + documents = args[0] if args else kwargs.get("documents", []) + callbacks = kwargs.get("callbacks") or ( + args[2] if len(args) > 2 else None + ) + + parent_ctx = _get_parent_context(callbacks) + + invocation = RerankInvocation( + provider=_extract_compressor_provider(instance), + request_model=_extract_compressor_model(instance), + documents_count=len(documents) if documents else None, + top_k=_extract_top_n(instance), + input_documents=_documents_to_json(documents), + ) + + try: + handler.start_rerank(invocation, context=parent_ctx) + except Exception: + logger.debug( + "Failed to start rerank span", exc_info=True + ) + return await wrapped(*args, **kwargs) + + try: + result = await wrapped(*args, **kwargs) + invocation.output_documents = _documents_to_json(result) + handler.stop_rerank(invocation) + return result + except Exception as exc: + handler.fail_rerank( + invocation, Error(message=str(exc), type=type(exc)) + ) + raise + finally: + _COMPRESSOR_CALL_DEPTH.reset(depth_token) + + return _instrumented() + + return wrapper + + +# --------------------------------------------------------------------------- +# Subclass discovery +# --------------------------------------------------------------------------- + + +def _all_subclasses(cls: type) -> set[type]: + """Recursively collect every subclass of *cls*.""" + result: set[type] = set() + queue = list(cls.__subclasses__()) + while queue: + sub = queue.pop() + if sub not in result: + result.add(sub) + queue.extend(sub.__subclasses__()) + return result + + +# 
# ---------------------------------------------------------------------------
# Per-class patching / unpatching
# ---------------------------------------------------------------------------


def _patch_class(
    cls: type,
    sync_wrapper: Any,
    async_wrapper: Any,
) -> None:
    """Wrap ``compress_documents`` and ``acompress_documents`` on *cls*.

    Only wraps methods that are defined directly in *cls* (i.e. present
    in ``cls.__dict__``). Skips classes that are already wrapped.
    """
    # Check the class's OWN dict rather than getattr(): getattr() also
    # sees the tag inherited from an already-patched base class, which
    # would wrongly skip a subclass whose own overrides are not wrapped.
    if cls.__dict__.get(_WRAPPER_TAG, False):
        return

    if "compress_documents" in cls.__dict__:
        original = cls.__dict__["compress_documents"]
        if not isinstance(original, wrapt.FunctionWrapper):
            cls.compress_documents = wrapt.FunctionWrapper(
                original, sync_wrapper
            )

    if "acompress_documents" in cls.__dict__:
        original = cls.__dict__["acompress_documents"]
        if not isinstance(original, wrapt.FunctionWrapper):
            cls.acompress_documents = wrapt.FunctionWrapper(
                original, async_wrapper
            )

    setattr(cls, _WRAPPER_TAG, True)
    _patched_classes.add(cls)


def _unpatch_class(cls: type) -> None:
    """Restore original methods on *cls*."""
    for method_name in ("compress_documents", "acompress_documents"):
        method = cls.__dict__.get(method_name)
        if isinstance(method, wrapt.FunctionWrapper):
            # FunctionWrapper keeps the original under __wrapped__.
            setattr(cls, method_name, method.__wrapped__)

    try:
        delattr(cls, _WRAPPER_TAG)
    except AttributeError:
        pass


def instrument_document_compressor(
    handler: "ExtendedTelemetryHandler",
) -> None:
    """Wrap all current and future ``BaseDocumentCompressor`` subclasses."""
    global _original_init_subclass  # noqa: PLW0603

    try:
        from langchain_core.documents.compressor import (  # noqa: PLC0415
            BaseDocumentCompressor,
        )
    except ImportError as exc:
        logger.debug(
            "BaseDocumentCompressor not available, "
            "skipping rerank instrumentation: %s",
            exc,
        )
        return

    sync_wrapper = _make_compress_documents_wrapper(handler)
    async_wrapper = _make_acompress_documents_wrapper(handler)

    # 1. Retroactively patch every existing subclass.
    for cls in _all_subclasses(BaseDocumentCompressor):
        _patch_class(cls, sync_wrapper, async_wrapper)

    # 2. Install an __init_subclass__ hook so future subclasses are
    #    patched automatically.
    _original_init_subclass = BaseDocumentCompressor.__dict__.get(
        "__init_subclass__"
    )

    @classmethod  # type: ignore[misc]
    def _patched_init_subclass(cls: type, **kwargs: Any) -> None:
        # Chain to the pre-existing hook (classmethod object or plain
        # callable), or fall back to the default object.__init_subclass__.
        if _original_init_subclass is not None:
            if isinstance(_original_init_subclass, classmethod):
                _original_init_subclass.__func__(cls, **kwargs)
            else:
                _original_init_subclass(**kwargs)
        else:
            super(BaseDocumentCompressor, cls).__init_subclass__(**kwargs)
        _patch_class(cls, sync_wrapper, async_wrapper)

    BaseDocumentCompressor.__init_subclass__ = _patched_init_subclass  # type: ignore[assignment]

    logger.debug(
        "Patched BaseDocumentCompressor (%d existing subclass(es))",
        len(_patched_classes),
    )


def uninstrument_document_compressor() -> None:
    """Restore original methods on all patched compressor classes."""
    global _original_init_subclass  # noqa: PLW0603

    # Restore __init_subclass__
    try:
        from langchain_core.documents.compressor import (  # noqa: PLC0415
            BaseDocumentCompressor,
        )

        if _original_init_subclass is not None:
            BaseDocumentCompressor.__init_subclass__ = _original_init_subclass  # type: ignore[assignment]
        else:
            # Remove our hook — fall back to the default behaviour.
            if "__init_subclass__" in BaseDocumentCompressor.__dict__:
                delattr(BaseDocumentCompressor, "__init_subclass__")
    except Exception:
        logger.debug(
            "Failed to restore BaseDocumentCompressor.__init_subclass__",
            exc_info=True,
        )

    for cls in list(_patched_classes):
        try:
            _unpatch_class(cls)
        except Exception:
            logger.debug("Failed to unpatch %s", cls, exc_info=True)
    _patched_classes.clear()
    _original_init_subclass = None

    logger.debug("Restored BaseDocumentCompressor subclasses")


# Copyright The OpenTelemetry Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+ +"""Tests for rerank / document-compressor span creation and attributes.""" + +from __future__ import annotations + +import asyncio +from typing import Sequence + +import pytest +from langchain_core.callbacks import Callbacks +from langchain_core.documents import Document +from langchain_core.documents.compressor import BaseDocumentCompressor + +from opentelemetry.trace import StatusCode + +# --------------------------------------------------------------------------- +# Fake compressors for testing +# --------------------------------------------------------------------------- + + +class FakeReranker(BaseDocumentCompressor): + """A fake reranker that returns documents with relevance scores.""" + + model_name: str = "fake-rerank-model" + top_n: int = 2 + + def compress_documents( + self, + documents: Sequence[Document], + query: str, + callbacks: Callbacks | None = None, + ) -> Sequence[Document]: + scored = [] + for i, doc in enumerate(documents): + score = 1.0 / (i + 1) + scored.append( + Document( + page_content=doc.page_content, + metadata={**doc.metadata, "relevance_score": score}, + ) + ) + scored.sort( + key=lambda d: d.metadata.get("relevance_score", 0), reverse=True + ) + return scored[: self.top_n] + + +class FakeErrorReranker(BaseDocumentCompressor): + """A fake reranker that always fails.""" + + def compress_documents( + self, + documents: Sequence[Document], + query: str, + callbacks: Callbacks | None = None, + ) -> Sequence[Document]: + raise ValueError("rerank failure") + + +class FakeSimpleCompressor(BaseDocumentCompressor): + """A compressor with no model_name attribute.""" + + def compress_documents( + self, + documents: Sequence[Document], + query: str, + callbacks: Callbacks | None = None, + ) -> Sequence[Document]: + return [doc for doc in documents if len(doc.page_content) > 5] + + +class FakeProxyCompressor(BaseDocumentCompressor): + """A proxy compressor that delegates to an inner compressor.""" + + inner: FakeReranker = None # type: 
ignore[assignment] + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.inner = FakeReranker() + + def compress_documents( + self, + documents: Sequence[Document], + query: str, + callbacks: Callbacks | None = None, + ) -> Sequence[Document]: + return self.inner.compress_documents(documents, query, callbacks) + + async def acompress_documents( + self, + documents: Sequence[Document], + query: str, + callbacks: Callbacks | None = None, + ) -> Sequence[Document]: + return self.inner.compress_documents(documents, query, callbacks) + + +class FakeAsyncReranker(BaseDocumentCompressor): + """A fake reranker with both sync and async implementations.""" + + model_name: str = "fake-async-model" + top_n: int = 2 + + def compress_documents( + self, + documents: Sequence[Document], + query: str, + callbacks: Callbacks | None = None, + ) -> Sequence[Document]: + return list(documents[: self.top_n]) + + async def acompress_documents( + self, + documents: Sequence[Document], + query: str, + callbacks: Callbacks | None = None, + ) -> Sequence[Document]: + scored = [] + for i, doc in enumerate(documents): + score = 1.0 / (i + 1) + scored.append( + Document( + page_content=doc.page_content, + metadata={**doc.metadata, "relevance_score": score}, + ) + ) + scored.sort( + key=lambda d: d.metadata.get("relevance_score", 0), reverse=True + ) + return scored[: self.top_n] + + +class FakeAsyncErrorReranker(BaseDocumentCompressor): + """A fake async reranker that always fails.""" + + def compress_documents( + self, + documents: Sequence[Document], + query: str, + callbacks: Callbacks | None = None, + ) -> Sequence[Document]: + raise ValueError("sync rerank failure") + + async def acompress_documents( + self, + documents: Sequence[Document], + query: str, + callbacks: Callbacks | None = None, + ) -> Sequence[Document]: + raise ValueError("async rerank failure") + + +# --------------------------------------------------------------------------- +# Helpers +# 
--------------------------------------------------------------------------- + +_RERANK_SPAN_NAME = "rerank_documents" + + +def _find_rerank_spans(span_exporter): + spans = span_exporter.get_finished_spans() + return [s for s in spans if _RERANK_SPAN_NAME in s.name.lower()] + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestRerankSpanCreation: + def test_reranker_creates_span(self, instrument, span_exporter): + reranker = FakeReranker() + docs = [ + Document(page_content="doc1"), + Document(page_content="doc2"), + Document(page_content="doc3"), + ] + result = reranker.compress_documents(docs, "test query") + assert len(result) == 2 + + rerank_spans = _find_rerank_spans(span_exporter) + assert len(rerank_spans) >= 1 + + def test_reranker_error_span(self, instrument, span_exporter): + reranker = FakeErrorReranker() + docs = [Document(page_content="doc1")] + with pytest.raises(ValueError, match="rerank failure"): + reranker.compress_documents(docs, "fail query") + + spans = span_exporter.get_finished_spans() + error_spans = [ + s for s in spans if s.status.status_code == StatusCode.ERROR + ] + assert len(error_spans) >= 1 + + def test_simple_compressor_creates_span(self, instrument, span_exporter): + compressor = FakeSimpleCompressor() + docs = [ + Document(page_content="short"), + Document(page_content="this is a longer document"), + ] + result = compressor.compress_documents(docs, "test") + assert len(result) == 1 + + rerank_spans = _find_rerank_spans(span_exporter) + assert len(rerank_spans) >= 1 + + +class TestRerankSpanAttributes: + """Verify rerank span attributes are captured correctly.""" + + def test_operation_name(self, instrument, span_exporter): + reranker = FakeReranker() + docs = [Document(page_content="doc1")] + reranker.compress_documents(docs, "query") + + rerank_spans = _find_rerank_spans(span_exporter) + assert 
len(rerank_spans) >= 1 + attrs = dict(rerank_spans[0].attributes) + + assert attrs.get("gen_ai.operation.name") == "rerank_documents" + + def test_span_kind_attribute(self, instrument, span_exporter): + reranker = FakeReranker() + docs = [Document(page_content="doc1")] + reranker.compress_documents(docs, "query") + + rerank_spans = _find_rerank_spans(span_exporter) + assert len(rerank_spans) >= 1 + attrs = dict(rerank_spans[0].attributes) + + assert attrs.get("gen_ai.span.kind") == "RERANKER" + + def test_provider_attribute(self, instrument, span_exporter): + reranker = FakeReranker() + docs = [Document(page_content="doc1")] + reranker.compress_documents(docs, "query") + + rerank_spans = _find_rerank_spans(span_exporter) + assert len(rerank_spans) >= 1 + attrs = dict(rerank_spans[0].attributes) + + assert attrs.get("gen_ai.provider.name") == "langchain" + + def test_model_attribute(self, instrument, span_exporter): + reranker = FakeReranker() + docs = [Document(page_content="doc1")] + reranker.compress_documents(docs, "query") + + rerank_spans = _find_rerank_spans(span_exporter) + assert len(rerank_spans) >= 1 + attrs = dict(rerank_spans[0].attributes) + + assert attrs.get("gen_ai.request.model") == "fake-rerank-model" + + def test_documents_count(self, instrument, span_exporter): + reranker = FakeReranker() + docs = [ + Document(page_content="doc1"), + Document(page_content="doc2"), + Document(page_content="doc3"), + ] + reranker.compress_documents(docs, "query") + + rerank_spans = _find_rerank_spans(span_exporter) + assert len(rerank_spans) >= 1 + attrs = dict(rerank_spans[0].attributes) + + assert attrs.get("gen_ai.rerank.documents.count") == 3 + + def test_top_k_attribute(self, instrument, span_exporter): + reranker = FakeReranker(top_n=5) + docs = [Document(page_content="doc1")] + reranker.compress_documents(docs, "query") + + rerank_spans = _find_rerank_spans(span_exporter) + assert len(rerank_spans) >= 1 + attrs = dict(rerank_spans[0].attributes) + + assert 
attrs.get("gen_ai.request.top_k") == 5 + + def test_span_name_includes_model(self, instrument, span_exporter): + reranker = FakeReranker() + docs = [Document(page_content="doc1")] + reranker.compress_documents(docs, "query") + + rerank_spans = _find_rerank_spans(span_exporter) + assert len(rerank_spans) >= 1 + + assert "fake-rerank-model" in rerank_spans[0].name + + +class TestRerankDocumentContent: + """Verify input/output document content in span attributes.""" + + def test_input_documents_captured(self, instrument, span_exporter): + reranker = FakeReranker() + docs = [ + Document(page_content="Machine learning basics"), + Document(page_content="Deep learning overview"), + ] + reranker.compress_documents(docs, "ML query") + + rerank_spans = _find_rerank_spans(span_exporter) + assert len(rerank_spans) >= 1 + attrs = dict(rerank_spans[0].attributes) + + input_docs = attrs.get("gen_ai.rerank.input_documents", "") + assert "Machine learning basics" in input_docs + assert "Deep learning overview" in input_docs + + def test_output_documents_captured(self, instrument, span_exporter): + reranker = FakeReranker() + docs = [ + Document(page_content="doc1"), + Document(page_content="doc2"), + Document(page_content="doc3"), + ] + reranker.compress_documents(docs, "query") + + rerank_spans = _find_rerank_spans(span_exporter) + assert len(rerank_spans) >= 1 + attrs = dict(rerank_spans[0].attributes) + + output_docs = attrs.get("gen_ai.rerank.output_documents", "") + assert "relevance_score" in output_docs + + def test_no_content_when_disabled( + self, instrument_no_content, span_exporter + ): + """Input/output documents should NOT be captured when content capture is disabled.""" + reranker = FakeReranker() + docs = [Document(page_content="secret doc")] + reranker.compress_documents(docs, "query") + + rerank_spans = _find_rerank_spans(span_exporter) + assert len(rerank_spans) >= 1 + attrs = dict(rerank_spans[0].attributes) + + assert "gen_ai.rerank.input_documents" not in 
attrs, ( + "Input documents should NOT be captured when content capture is disabled" + ) + assert "gen_ai.rerank.output_documents" not in attrs, ( + "Output documents should NOT be captured when content capture is disabled" + ) + + +class TestAsyncRerankSpans: + """Verify that async acompress_documents is instrumented correctly.""" + + def test_async_reranker_creates_span(self, instrument, span_exporter): + reranker = FakeAsyncReranker() + docs = [ + Document(page_content="doc1"), + Document(page_content="doc2"), + Document(page_content="doc3"), + ] + result = asyncio.run(reranker.acompress_documents(docs, "async query")) + assert len(result) == 2 + + rerank_spans = _find_rerank_spans(span_exporter) + assert len(rerank_spans) >= 1 + + def test_async_reranker_span_attributes(self, instrument, span_exporter): + reranker = FakeAsyncReranker() + docs = [Document(page_content="doc1")] + asyncio.run(reranker.acompress_documents(docs, "query")) + + rerank_spans = _find_rerank_spans(span_exporter) + assert len(rerank_spans) >= 1 + attrs = dict(rerank_spans[0].attributes) + + assert attrs.get("gen_ai.operation.name") == "rerank_documents" + assert attrs.get("gen_ai.span.kind") == "RERANKER" + assert attrs.get("gen_ai.request.model") == "fake-async-model" + + def test_async_reranker_error_span(self, instrument, span_exporter): + reranker = FakeAsyncErrorReranker() + docs = [Document(page_content="doc1")] + with pytest.raises(ValueError, match="async rerank failure"): + asyncio.run(reranker.acompress_documents(docs, "fail query")) + + spans = span_exporter.get_finished_spans() + error_spans = [ + s for s in spans if s.status.status_code == StatusCode.ERROR + ] + assert len(error_spans) >= 1 + + def test_async_output_documents_captured(self, instrument, span_exporter): + reranker = FakeAsyncReranker() + docs = [ + Document(page_content="doc1"), + Document(page_content="doc2"), + ] + asyncio.run(reranker.acompress_documents(docs, "query")) + + rerank_spans = 
_find_rerank_spans(span_exporter) + assert len(rerank_spans) >= 1 + attrs = dict(rerank_spans[0].attributes) + + output_docs = attrs.get("gen_ai.rerank.output_documents", "") + assert "relevance_score" in output_docs + + +class TestRerankInitSubclassHook: + """Verify that subclasses defined AFTER instrumentation are auto-patched.""" + + def test_post_instrumentation_subclass_creates_span( + self, instrument, span_exporter + ): + # Define a NEW compressor class AFTER instrumentation has been applied. + class LateDefinedCompressor(BaseDocumentCompressor): + model_name: str = "late-model" + + def compress_documents( + self, + documents: Sequence[Document], + query: str, + callbacks: Callbacks | None = None, + ) -> Sequence[Document]: + return list(documents) + + compressor = LateDefinedCompressor() + docs = [Document(page_content="hello")] + compressor.compress_documents(docs, "query") + + rerank_spans = _find_rerank_spans(span_exporter) + assert len(rerank_spans) >= 1 + attrs = dict(rerank_spans[0].attributes) + assert attrs.get("gen_ai.request.model") == "late-model" + + def test_post_instrumentation_async_subclass_creates_span( + self, instrument, span_exporter + ): + class LateAsyncCompressor(BaseDocumentCompressor): + model_name: str = "late-async-model" + + def compress_documents( + self, + documents: Sequence[Document], + query: str, + callbacks: Callbacks | None = None, + ) -> Sequence[Document]: + return list(documents) + + async def acompress_documents( + self, + documents: Sequence[Document], + query: str, + callbacks: Callbacks | None = None, + ) -> Sequence[Document]: + return list(documents) + + compressor = LateAsyncCompressor() + docs = [Document(page_content="hello")] + asyncio.run(compressor.acompress_documents(docs, "query")) + + rerank_spans = _find_rerank_spans(span_exporter) + assert len(rerank_spans) >= 1 + attrs = dict(rerank_spans[0].attributes) + assert attrs.get("gen_ai.request.model") == "late-async-model" + + +class 
TestRerankDeduplication: + """Verify that proxy/wrapper compressors do NOT produce duplicate spans.""" + + def test_proxy_compressor_single_span(self, instrument, span_exporter): + """A proxy that delegates to an inner compressor should produce + exactly one rerank span, not two.""" + proxy = FakeProxyCompressor() + docs = [ + Document(page_content="doc1"), + Document(page_content="doc2"), + Document(page_content="doc3"), + ] + result = proxy.compress_documents(docs, "test query") + assert len(result) == 2 + + rerank_spans = _find_rerank_spans(span_exporter) + assert len(rerank_spans) == 1, ( + f"Expected exactly 1 rerank span, got {len(rerank_spans)}" + ) + + def test_async_proxy_compressor_single_span( + self, instrument, span_exporter + ): + """Async proxy should also produce exactly one span.""" + proxy = FakeProxyCompressor() + docs = [ + Document(page_content="doc1"), + Document(page_content="doc2"), + ] + result = asyncio.run(proxy.acompress_documents(docs, "query")) + assert len(result) == 2 + + rerank_spans = _find_rerank_spans(span_exporter) + assert len(rerank_spans) == 1, ( + f"Expected exactly 1 rerank span, got {len(rerank_spans)}" + ) + + def test_direct_compressor_still_creates_span( + self, instrument, span_exporter + ): + """A direct (non-proxy) call should still produce a span.""" + reranker = FakeReranker() + docs = [Document(page_content="doc1")] + reranker.compress_documents(docs, "query") + + rerank_spans = _find_rerank_spans(span_exporter) + assert len(rerank_spans) == 1 + + +class TestRerankUninstrumentation: + """Verify that uninstrument removes rerank spans.""" + + def test_no_spans_after_uninstrument(self, instrument, span_exporter): + reranker = FakeReranker() + docs = [Document(page_content="doc1")] + reranker.compress_documents(docs, "query") + + rerank_spans = _find_rerank_spans(span_exporter) + assert len(rerank_spans) >= 1 + + # Uninstrument + instrument.uninstrument() + span_exporter.clear() + + # Should not produce spans after 
uninstrumentation + reranker.compress_documents(docs, "query") + rerank_spans = _find_rerank_spans(span_exporter) + assert len(rerank_spans) == 0 From b3d08e367a781508e8d38ce81446bc00cda1e586 Mon Sep 17 00:00:00 2001 From: minimAluminiumalism Date: Mon, 23 Mar 2026 19:41:25 +0800 Subject: [PATCH 4/5] chore: lint fix --- .../langchain/internal/patch_rerank.py | 28 +++++++++---------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/instrumentation-loongsuite/loongsuite-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/internal/patch_rerank.py b/instrumentation-loongsuite/loongsuite-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/internal/patch_rerank.py index dfee59c2b..3bd7efa41 100644 --- a/instrumentation-loongsuite/loongsuite-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/internal/patch_rerank.py +++ b/instrumentation-loongsuite/loongsuite-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/internal/patch_rerank.py @@ -35,27 +35,27 @@ import wrapt -# Depth counter to avoid duplicate spans when a proxy/wrapper compressor -# delegates to an inner compressor (both subclasses are patched). -# Only the outermost call (depth == 0) creates a telemetry span. 
-_COMPRESSOR_CALL_DEPTH: contextvars.ContextVar[int] = contextvars.ContextVar( - "opentelemetry_langchain_compressor_call_depth", - default=0, +from opentelemetry.instrumentation.langchain.internal._tracer import ( + LoongsuiteTracer, ) +from opentelemetry.util.genai.extended_types import RerankInvocation +from opentelemetry.util.genai.types import Error if TYPE_CHECKING: from opentelemetry.util.genai.extended_handler import ( ExtendedTelemetryHandler, ) -from opentelemetry.instrumentation.langchain.internal._tracer import ( - LoongsuiteTracer, -) -from opentelemetry.util.genai.extended_types import RerankInvocation -from opentelemetry.util.genai.types import Error - logger = logging.getLogger(__name__) +# Depth counter to avoid duplicate spans when a proxy/wrapper compressor +# delegates to an inner compressor (both subclasses are patched). +# Only the outermost call (depth == 0) creates a telemetry span. +_COMPRESSOR_CALL_DEPTH: contextvars.ContextVar[int] = contextvars.ContextVar( + "opentelemetry_langchain_compressor_call_depth", + default=0, +) + # Module-level state for uninstrumentation. 
_original_init_subclass: Any = None _patched_classes: set[type] = set() @@ -276,9 +276,7 @@ async def _instrumented() -> Any: try: handler.start_rerank(invocation, context=parent_ctx) except Exception: - logger.debug( - "Failed to start rerank span", exc_info=True - ) + logger.debug("Failed to start rerank span", exc_info=True) return await wrapped(*args, **kwargs) try: From 7a72725b9bf114db5c9a1f3ab18f4485c837fefc Mon Sep 17 00:00:00 2001 From: minimAluminiumalism Date: Thu, 26 Mar 2026 11:39:34 +0800 Subject: [PATCH 5/5] feat: add support for langchain embedding --- .../CHANGELOG.md | 1 + .../README.md | 1 + .../instrumentation/langchain/__init__.py | 25 + .../langchain/internal/patch_embedding.py | 489 ++++++++++++++++++ .../tests/test_embedding_spans.py | 412 +++++++++++++++ 5 files changed, 928 insertions(+) create mode 100644 instrumentation-loongsuite/loongsuite-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/internal/patch_embedding.py create mode 100644 instrumentation-loongsuite/loongsuite-instrumentation-langchain/tests/test_embedding_spans.py diff --git a/instrumentation-loongsuite/loongsuite-instrumentation-langchain/CHANGELOG.md b/instrumentation-loongsuite/loongsuite-instrumentation-langchain/CHANGELOG.md index 73dad1702..49d6e2bd4 100644 --- a/instrumentation-loongsuite/loongsuite-instrumentation-langchain/CHANGELOG.md +++ b/instrumentation-loongsuite/loongsuite-instrumentation-langchain/CHANGELOG.md @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- Langchain embedding span support([#157](https://github.com/alibaba/loongsuite-python-agent/pull/157)) - Rerank / document-compressor span support ([#149](https://github.com/alibaba/loongsuite-python-agent/pull/149)) diff --git a/instrumentation-loongsuite/loongsuite-instrumentation-langchain/README.md b/instrumentation-loongsuite/loongsuite-instrumentation-langchain/README.md index fc9d93293..fc4b3ad50 100644 --- 
a/instrumentation-loongsuite/loongsuite-instrumentation-langchain/README.md +++ b/instrumentation-loongsuite/loongsuite-instrumentation-langchain/README.md @@ -130,6 +130,7 @@ loongsuite-instrument | ReAct Step | `STEP` | `gen_ai.operation.name=react`, `gen_ai.react.round`, `gen_ai.react.finish_reason` | | Tool | `TOOL` | `gen_ai.operation.name=execute_tool` | | Retriever | `RETRIEVER` | `gen_ai.operation.name=retrieval` | +| Embedding | `EMBEDDING` | `gen_ai.operation.name=embeddings`, `gen_ai.request.model`, `gen_ai.provider.name`, `server.address`, `server.port`, `gen_ai.embeddings.dimension.count`, `gen_ai.request.encoding_formats` | | Reranker | `RERANKER` | `gen_ai.operation.name=rerank_documents`, `gen_ai.request.model`, `gen_ai.rerank.documents.count`, `gen_ai.request.top_k`, `gen_ai.rerank.input_documents`, `gen_ai.rerank.output_documents` (when content capture enabled) | ReAct Step spans are created for each Reasoning-Acting iteration, with the hierarchy: Agent > ReAct Step > LLM/Tool. 
Supported agent types: diff --git a/instrumentation-loongsuite/loongsuite-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/__init__.py b/instrumentation-loongsuite/loongsuite-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/__init__.py index b67577c46..e0c390c42 100644 --- a/instrumentation-loongsuite/loongsuite-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/__init__.py +++ b/instrumentation-loongsuite/loongsuite-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/__init__.py @@ -208,6 +208,29 @@ def _uninstrument_create_agent() -> None: _patched_create_agent_locations.clear() +# ------------------------------------------------------------------ +# Embeddings patch +# ------------------------------------------------------------------ + + +def _instrument_embeddings(handler: Any) -> None: + """Wrap all current and future ``Embeddings`` subclasses.""" + from opentelemetry.instrumentation.langchain.internal.patch_embedding import ( # noqa: PLC0415 + instrument_embeddings, + ) + + instrument_embeddings(handler) + + +def _uninstrument_embeddings() -> None: + """Restore original ``Embeddings`` methods.""" + from opentelemetry.instrumentation.langchain.internal.patch_embedding import ( # noqa: PLC0415 + uninstrument_embeddings, + ) + + uninstrument_embeddings() + + # ------------------------------------------------------------------ # BaseDocumentCompressor patch (rerank / compression) # ------------------------------------------------------------------ @@ -268,6 +291,7 @@ def _instrument(self, **kwargs: Any) -> None: _instrument_agent_executor() _instrument_create_agent() _instrument_document_compressor(handler) + _instrument_embeddings(handler) def _uninstrument(self, **kwargs: Any) -> None: try: @@ -281,6 +305,7 @@ def _uninstrument(self, **kwargs: Any) -> None: _uninstrument_agent_executor() _uninstrument_create_agent() _uninstrument_document_compressor() + _uninstrument_embeddings() 
class _BaseCallbackManagerInit: diff --git a/instrumentation-loongsuite/loongsuite-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/internal/patch_embedding.py b/instrumentation-loongsuite/loongsuite-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/internal/patch_embedding.py new file mode 100644 index 000000000..41bb6d271 --- /dev/null +++ b/instrumentation-loongsuite/loongsuite-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/internal/patch_embedding.py @@ -0,0 +1,489 @@ +# Copyright The OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Embedding instrumentation patch for LangChain Embeddings. + +Because ``embed_documents`` and ``embed_query`` are abstract, every +subclass overrides them. We use the same strategy as ``patch_rerank.py``: + +1. Retroactively wrap methods on **every existing subclass** at + instrumentation time. +2. Install a ``__init_subclass__`` hook on ``Embeddings`` so that + any subclass defined **after** instrumentation is also wrapped + automatically. 
+""" + +from __future__ import annotations + +import contextvars +import logging +from typing import TYPE_CHECKING, Any + +import wrapt + +from opentelemetry.util.genai.extended_types import EmbeddingInvocation +from opentelemetry.util.genai.types import Error + +if TYPE_CHECKING: + from opentelemetry.util.genai.extended_handler import ( + ExtendedTelemetryHandler, + ) + +logger = logging.getLogger(__name__) + +# Depth counter to avoid duplicate spans when a proxy embeddings +# delegates to an inner embeddings (both subclasses are patched), +# or when default aembed_* calls patched embed_* via run_in_executor. +_EMBEDDING_CALL_DEPTH: contextvars.ContextVar[int] = contextvars.ContextVar( + "opentelemetry_langchain_embedding_call_depth", + default=0, +) + +# Module-level state for uninstrumentation. +_original_init_subclass: Any = None +_patched_classes: set[type] = set() + +_WRAPPER_TAG = "_loongsuite_embedding_wrapped" + +_SYNC_METHODS = ("embed_documents", "embed_query") +_ASYNC_METHODS = ("aembed_documents", "aembed_query") + + +# --------------------------------------------------------------------------- +# Helpers — metadata extraction +# --------------------------------------------------------------------------- + + +def _extract_embedding_provider(instance: Any) -> str: + """Infer a provider name from an Embeddings instance.""" + cls_name = type(instance).__name__ + module = type(instance).__module__ or "" + + _HINTS = [ + ("openai", "openai"), + ("azure", "azure"), + ("cohere", "cohere"), + ("huggingface", "huggingface"), + ("sentence_transformers", "sentence_transformers"), + ("google", "google"), + ("bedrock", "aws_bedrock"), + ("ollama", "ollama"), + ("jina", "jina"), + ("voyage", "voyageai"), + ("mistral", "mistral"), + ("dashscope", "dashscope"), + ("together", "together"), + ("fireworks", "fireworks"), + ] + lower = (module + "." 
+ cls_name).lower() + for hint, provider in _HINTS: + if hint in lower: + return provider + + return "langchain" + + +def _extract_embedding_model(instance: Any) -> str: + """Extract a model name from an Embeddings instance (if available).""" + for attr in ("model", "model_name", "model_id", "deployment_name"): + val = getattr(instance, attr, None) + if val and isinstance(val, str): + return val + return "" + + +def _extract_server_address_port( + instance: Any, +) -> tuple[str | None, int | None]: + """Extract server address and port from an Embeddings instance.""" + from urllib.parse import urlparse # noqa: PLC0415 + + for attr in ( + "openai_api_base", + "base_url", + "api_base", + "endpoint", + "endpoint_url", + ): + val = getattr(instance, attr, None) + if val and isinstance(val, str): + try: + parsed = urlparse(val) + host = parsed.hostname + port = parsed.port + return host, port + except Exception: + continue + + # Some providers store a client object with base_url + client = getattr(instance, "client", None) + if client is not None: + client_base = getattr(client, "base_url", None) + if client_base: + url_str = str(client_base) + try: + parsed = urlparse(url_str) + return parsed.hostname, parsed.port + except Exception: + pass + + return None, None + + +def _extract_dimension_count(instance: Any) -> int | None: + """Extract embedding dimension count from an Embeddings instance.""" + for attr in ("dimensions", "dimension", "embedding_dim"): + val = getattr(instance, attr, None) + if val is not None and isinstance(val, int) and val > 0: + return val + return None + + +def _extract_encoding_formats(instance: Any) -> list[str] | None: + """Extract encoding formats from an Embeddings instance.""" + for attr in ("encoding_format", "embedding_format"): + val = getattr(instance, attr, None) + if val and isinstance(val, str): + return [val] + if val and isinstance(val, list): + return val + return None + + +# 
# ---------------------------------------------------------------------------
# Wrapper factories
# ---------------------------------------------------------------------------


def _build_invocation(instance: Any) -> EmbeddingInvocation:
    """Build an ``EmbeddingInvocation`` with all extractable attributes."""
    server_address, server_port = _extract_server_address_port(instance)
    return EmbeddingInvocation(
        request_model=_extract_embedding_model(instance),
        provider=_extract_embedding_provider(instance),
        server_address=server_address,
        server_port=server_port,
        dimension_count=_extract_dimension_count(instance),
        encoding_formats=_extract_encoding_formats(instance),
    )


def _make_sync_embedding_wrapper(handler: "ExtendedTelemetryHandler") -> Any:
    """Shared ``wrapt``-style wrapper for the sync embedding methods.

    Only the outermost call (depth == 0) creates a span; nested calls
    (proxy delegation, or default async implementations falling back to
    the patched sync methods) run un-instrumented to avoid duplicates.
    A failure inside the telemetry handler never breaks the user call.
    """

    def wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any) -> Any:
        parent_depth = _EMBEDDING_CALL_DEPTH.get()
        depth_token = _EMBEDDING_CALL_DEPTH.set(parent_depth + 1)
        try:
            if parent_depth > 0:
                return wrapped(*args, **kwargs)

            invocation = _build_invocation(instance)

            try:
                handler.start_embedding(invocation)
            except Exception:
                logger.debug("Failed to start embedding span", exc_info=True)
                return wrapped(*args, **kwargs)

            try:
                result = wrapped(*args, **kwargs)
                handler.stop_embedding(invocation)
                return result
            except Exception as exc:
                handler.fail_embedding(
                    invocation, Error(message=str(exc), type=type(exc))
                )
                raise
        finally:
            _EMBEDDING_CALL_DEPTH.reset(depth_token)

    return wrapper


def _make_async_embedding_wrapper(handler: "ExtendedTelemetryHandler") -> Any:
    """Shared ``wrapt``-style wrapper for the async embedding methods.

    The depth counter is manipulated inside the coroutine so it runs in
    the awaiting context, mirroring the sync wrapper's semantics.
    """

    def wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any) -> Any:
        async def _instrumented() -> Any:
            parent_depth = _EMBEDDING_CALL_DEPTH.get()
            depth_token = _EMBEDDING_CALL_DEPTH.set(parent_depth + 1)
            try:
                if parent_depth > 0:
                    return await wrapped(*args, **kwargs)

                invocation = _build_invocation(instance)

                try:
                    handler.start_embedding(invocation)
                except Exception:
                    logger.debug(
                        "Failed to start embedding span", exc_info=True
                    )
                    return await wrapped(*args, **kwargs)

                try:
                    result = await wrapped(*args, **kwargs)
                    handler.stop_embedding(invocation)
                    return result
                except Exception as exc:
                    handler.fail_embedding(
                        invocation, Error(message=str(exc), type=type(exc))
                    )
                    raise
            finally:
                _EMBEDDING_CALL_DEPTH.reset(depth_token)

        return _instrumented()

    return wrapper


# The four public factory names are kept for interface compatibility;
# their original bodies were copy-paste identical apart from sync/async.


def _make_embed_documents_wrapper(
    handler: "ExtendedTelemetryHandler",
) -> Any:
    """Return a ``wrapt``-style wrapper for ``embed_documents``."""
    return _make_sync_embedding_wrapper(handler)


def _make_embed_query_wrapper(
    handler: "ExtendedTelemetryHandler",
) -> Any:
    """Return a ``wrapt``-style wrapper for ``embed_query``."""
    return _make_sync_embedding_wrapper(handler)


def _make_aembed_documents_wrapper(
    handler: "ExtendedTelemetryHandler",
) -> Any:
    """Return a ``wrapt``-style wrapper for ``aembed_documents``."""
    return _make_async_embedding_wrapper(handler)


def _make_aembed_query_wrapper(
    handler: "ExtendedTelemetryHandler",
) -> Any:
    """Return a ``wrapt``-style wrapper for ``aembed_query``."""
    return _make_async_embedding_wrapper(handler)
handler.start_embedding(invocation) + except Exception: + logger.debug( + "Failed to start embedding span", exc_info=True + ) + return await wrapped(*args, **kwargs) + + try: + result = await wrapped(*args, **kwargs) + handler.stop_embedding(invocation) + return result + except Exception as exc: + handler.fail_embedding( + invocation, Error(message=str(exc), type=type(exc)) + ) + raise + finally: + _EMBEDDING_CALL_DEPTH.reset(depth_token) + + return _instrumented() + + return wrapper + + +# --------------------------------------------------------------------------- +# Subclass discovery +# --------------------------------------------------------------------------- + + +def _all_subclasses(cls: type) -> set[type]: + """Recursively collect every subclass of *cls*.""" + result: set[type] = set() + queue = list(cls.__subclasses__()) + while queue: + sub = queue.pop() + if sub not in result: + result.add(sub) + queue.extend(sub.__subclasses__()) + return result + + +# --------------------------------------------------------------------------- +# Per-class patching / unpatching +# --------------------------------------------------------------------------- + + +def _patch_class( + cls: type, + sync_doc_wrapper: Any, + sync_query_wrapper: Any, + async_doc_wrapper: Any, + async_query_wrapper: Any, +) -> None: + """Wrap embedding methods on *cls*. + + Only wraps methods that are defined directly in *cls* (i.e. present + in ``cls.__dict__``). Skips classes that are already wrapped. 
+ """ + if getattr(cls, _WRAPPER_TAG, False): + return + + _method_wrappers = { + "embed_documents": sync_doc_wrapper, + "embed_query": sync_query_wrapper, + "aembed_documents": async_doc_wrapper, + "aembed_query": async_query_wrapper, + } + + for method_name, wrapper_fn in _method_wrappers.items(): + if method_name in cls.__dict__: + original = cls.__dict__[method_name] + if not isinstance(original, wrapt.FunctionWrapper): + setattr( + cls, + method_name, + wrapt.FunctionWrapper(original, wrapper_fn), + ) + + setattr(cls, _WRAPPER_TAG, True) + _patched_classes.add(cls) + + +def _unpatch_class(cls: type) -> None: + """Restore original methods on *cls*.""" + for method_name in (*_SYNC_METHODS, *_ASYNC_METHODS): + method = cls.__dict__.get(method_name) + if isinstance(method, wrapt.FunctionWrapper): + setattr(cls, method_name, method.__wrapped__) + + try: + delattr(cls, _WRAPPER_TAG) + except AttributeError: + pass + + +def instrument_embeddings( + handler: "ExtendedTelemetryHandler", +) -> None: + """Wrap all current and future ``Embeddings`` subclasses.""" + global _original_init_subclass # noqa: PLW0603 + + try: + from langchain_core.embeddings import Embeddings # noqa: PLC0415 + except ImportError as exc: + logger.debug( + "Embeddings not available, skipping embedding instrumentation: %s", + exc, + ) + return + + sync_doc_wrapper = _make_embed_documents_wrapper(handler) + sync_query_wrapper = _make_embed_query_wrapper(handler) + async_doc_wrapper = _make_aembed_documents_wrapper(handler) + async_query_wrapper = _make_aembed_query_wrapper(handler) + + # 1. Retroactively patch every existing subclass. + for cls in _all_subclasses(Embeddings): + _patch_class( + cls, + sync_doc_wrapper, + sync_query_wrapper, + async_doc_wrapper, + async_query_wrapper, + ) + + # 2. Install an __init_subclass__ hook so future subclasses are + # patched automatically. 
+ _original_init_subclass = Embeddings.__dict__.get("__init_subclass__") + + @classmethod # type: ignore[misc] + def _patched_init_subclass(cls: type, **kwargs: Any) -> None: + if _original_init_subclass is not None: + if isinstance(_original_init_subclass, classmethod): + _original_init_subclass.__func__(cls, **kwargs) + else: + _original_init_subclass(**kwargs) + else: + super(Embeddings, cls).__init_subclass__(**kwargs) + _patch_class( + cls, + sync_doc_wrapper, + sync_query_wrapper, + async_doc_wrapper, + async_query_wrapper, + ) + + Embeddings.__init_subclass__ = _patched_init_subclass # type: ignore[assignment] + + logger.debug( + "Patched Embeddings (%d existing subclass(es))", + len(_patched_classes), + ) + + +def uninstrument_embeddings() -> None: + """Restore original methods on all patched embeddings classes.""" + global _original_init_subclass # noqa: PLW0603 + + try: + from langchain_core.embeddings import Embeddings # noqa: PLC0415 + + if _original_init_subclass is not None: + Embeddings.__init_subclass__ = _original_init_subclass # type: ignore[assignment] + else: + if "__init_subclass__" in Embeddings.__dict__: + delattr(Embeddings, "__init_subclass__") + except Exception: + logger.debug( + "Failed to restore Embeddings.__init_subclass__", + exc_info=True, + ) + + for cls in list(_patched_classes): + try: + _unpatch_class(cls) + except Exception: + logger.debug("Failed to unpatch %s", cls, exc_info=True) + _patched_classes.clear() + _original_init_subclass = None + + logger.debug("Restored Embeddings subclasses") diff --git a/instrumentation-loongsuite/loongsuite-instrumentation-langchain/tests/test_embedding_spans.py b/instrumentation-loongsuite/loongsuite-instrumentation-langchain/tests/test_embedding_spans.py new file mode 100644 index 000000000..1fbae1175 --- /dev/null +++ b/instrumentation-loongsuite/loongsuite-instrumentation-langchain/tests/test_embedding_spans.py @@ -0,0 +1,412 @@ +# Copyright The OpenTelemetry Authors +# +# Licensed under 
the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for embedding span creation and attributes.""" + +from __future__ import annotations + +import asyncio + +import pytest +from langchain_core.embeddings import Embeddings + +from opentelemetry.trace import StatusCode + +# --------------------------------------------------------------------------- +# Fake embeddings for testing +# --------------------------------------------------------------------------- + + +class FakeEmbeddings(Embeddings): + """Basic fake embeddings with model_name.""" + + model_name: str = "fake-embed-model" + + def __init__(self, model_name: str = "fake-embed-model"): + self.model_name = model_name + + def embed_documents(self, texts: list[str]) -> list[list[float]]: + return [[0.1, 0.2, 0.3] for _ in texts] + + def embed_query(self, text: str) -> list[float]: + return [0.1, 0.2, 0.3] + + +class FakeOpenAIEmbeddings(Embeddings): + """Fake embeddings with OpenAI-style attributes for server/dimension tests.""" + + model_name: str = "text-embedding-3-small" + openai_api_base: str = "https://api.openai.com:443/v1" + dimensions: int = 1536 + + def __init__(self): + self.model_name = "text-embedding-3-small" + self.openai_api_base = "https://api.openai.com:443/v1" + self.dimensions = 1536 + + def embed_documents(self, texts: list[str]) -> list[list[float]]: + return [[0.1, 0.2, 0.3] for _ in texts] + + def embed_query(self, text: str) -> list[float]: + return [0.1, 0.2, 0.3] + + +class 
FakeErrorEmbeddings(Embeddings): + """Embeddings that always fail.""" + + def embed_documents(self, texts: list[str]) -> list[list[float]]: + raise ValueError("embedding failure") + + def embed_query(self, text: str) -> list[float]: + raise ValueError("embedding failure") + + +class FakeAsyncEmbeddings(Embeddings): + """Embeddings with native async implementations.""" + + model_name: str = "fake-async-embed-model" + + def __init__(self, model_name: str = "fake-async-embed-model"): + self.model_name = model_name + + def embed_documents(self, texts: list[str]) -> list[list[float]]: + return [[0.1, 0.2] for _ in texts] + + def embed_query(self, text: str) -> list[float]: + return [0.1, 0.2] + + async def aembed_documents(self, texts: list[str]) -> list[list[float]]: + return [[0.4, 0.5] for _ in texts] + + async def aembed_query(self, text: str) -> list[float]: + return [0.4, 0.5] + + +class FakeAsyncErrorEmbeddings(Embeddings): + """Async embeddings that always fail.""" + + def embed_documents(self, texts: list[str]) -> list[list[float]]: + raise ValueError("sync embedding failure") + + def embed_query(self, text: str) -> list[float]: + raise ValueError("sync embedding failure") + + async def aembed_documents(self, texts: list[str]) -> list[list[float]]: + raise ValueError("async embedding failure") + + async def aembed_query(self, text: str) -> list[float]: + raise ValueError("async embedding failure") + + +class FakeProxyEmbeddings(Embeddings): + """A proxy that delegates to an inner Embeddings instance.""" + + def __init__(self): + self.inner = FakeEmbeddings() + + def embed_documents(self, texts: list[str]) -> list[list[float]]: + return self.inner.embed_documents(texts) + + def embed_query(self, text: str) -> list[float]: + return self.inner.embed_query(text) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +_EMBEDDING_SPAN_NAME = 
"embeddings" + + +def _find_embedding_spans(span_exporter): + spans = span_exporter.get_finished_spans() + return [s for s in spans if s.name.startswith(_EMBEDDING_SPAN_NAME)] + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestEmbeddingSpanCreation: + def test_embed_documents_creates_span(self, instrument, span_exporter): + emb = FakeEmbeddings() + result = emb.embed_documents(["hello", "world"]) + assert len(result) == 2 + + spans = _find_embedding_spans(span_exporter) + assert len(spans) >= 1 + + def test_embed_query_creates_span(self, instrument, span_exporter): + emb = FakeEmbeddings() + result = emb.embed_query("hello") + assert isinstance(result, list) + + spans = _find_embedding_spans(span_exporter) + assert len(spans) >= 1 + + def test_embed_documents_error_span(self, instrument, span_exporter): + emb = FakeErrorEmbeddings() + with pytest.raises(ValueError, match="embedding failure"): + emb.embed_documents(["fail"]) + + spans = span_exporter.get_finished_spans() + error_spans = [ + s for s in spans if s.status.status_code == StatusCode.ERROR + ] + assert len(error_spans) >= 1 + + def test_embed_query_error_span(self, instrument, span_exporter): + emb = FakeErrorEmbeddings() + with pytest.raises(ValueError, match="embedding failure"): + emb.embed_query("fail") + + spans = span_exporter.get_finished_spans() + error_spans = [ + s for s in spans if s.status.status_code == StatusCode.ERROR + ] + assert len(error_spans) >= 1 + + +class TestEmbeddingSpanAttributes: + def test_operation_name(self, instrument, span_exporter): + emb = FakeEmbeddings() + emb.embed_documents(["test"]) + + spans = _find_embedding_spans(span_exporter) + assert len(spans) >= 1 + attrs = dict(spans[0].attributes) + assert attrs.get("gen_ai.operation.name") == "embeddings" + + def test_span_kind_attribute(self, instrument, span_exporter): + emb = 
FakeEmbeddings() + emb.embed_documents(["test"]) + + spans = _find_embedding_spans(span_exporter) + assert len(spans) >= 1 + attrs = dict(spans[0].attributes) + assert attrs.get("gen_ai.span.kind") == "EMBEDDING" + + def test_provider_attribute(self, instrument, span_exporter): + emb = FakeEmbeddings() + emb.embed_documents(["test"]) + + spans = _find_embedding_spans(span_exporter) + assert len(spans) >= 1 + attrs = dict(spans[0].attributes) + assert attrs.get("gen_ai.provider.name") == "langchain" + + def test_model_attribute(self, instrument, span_exporter): + emb = FakeEmbeddings() + emb.embed_documents(["test"]) + + spans = _find_embedding_spans(span_exporter) + assert len(spans) >= 1 + attrs = dict(spans[0].attributes) + assert attrs.get("gen_ai.request.model") == "fake-embed-model" + + def test_span_name_includes_model(self, instrument, span_exporter): + emb = FakeEmbeddings() + emb.embed_documents(["test"]) + + spans = _find_embedding_spans(span_exporter) + assert len(spans) >= 1 + assert "fake-embed-model" in spans[0].name + + def test_server_address_and_port(self, instrument, span_exporter): + emb = FakeOpenAIEmbeddings() + emb.embed_documents(["test"]) + + spans = _find_embedding_spans(span_exporter) + assert len(spans) >= 1 + attrs = dict(spans[0].attributes) + assert attrs.get("server.address") == "api.openai.com" + assert attrs.get("server.port") == 443 + + def test_dimension_count(self, instrument, span_exporter): + emb = FakeOpenAIEmbeddings() + emb.embed_documents(["test"]) + + spans = _find_embedding_spans(span_exporter) + assert len(spans) >= 1 + attrs = dict(spans[0].attributes) + assert attrs.get("gen_ai.embeddings.dimension.count") == 1536 + + def test_no_server_attrs_when_absent(self, instrument, span_exporter): + emb = FakeEmbeddings() + emb.embed_documents(["test"]) + + spans = _find_embedding_spans(span_exporter) + assert len(spans) >= 1 + attrs = dict(spans[0].attributes) + assert "server.address" not in attrs + assert "server.port" not in 
attrs + assert "gen_ai.embeddings.dimension.count" not in attrs + + +class TestAsyncEmbeddingSpans: + def test_async_embed_documents_creates_span( + self, instrument, span_exporter + ): + emb = FakeAsyncEmbeddings() + result = asyncio.run(emb.aembed_documents(["hello", "world"])) + assert len(result) == 2 + + spans = _find_embedding_spans(span_exporter) + assert len(spans) >= 1 + + def test_async_embed_query_creates_span(self, instrument, span_exporter): + emb = FakeAsyncEmbeddings() + result = asyncio.run(emb.aembed_query("hello")) + assert isinstance(result, list) + + spans = _find_embedding_spans(span_exporter) + assert len(spans) >= 1 + + def test_async_embed_documents_attributes(self, instrument, span_exporter): + emb = FakeAsyncEmbeddings() + asyncio.run(emb.aembed_documents(["test"])) + + spans = _find_embedding_spans(span_exporter) + assert len(spans) >= 1 + attrs = dict(spans[0].attributes) + assert attrs.get("gen_ai.operation.name") == "embeddings" + assert attrs.get("gen_ai.span.kind") == "EMBEDDING" + assert attrs.get("gen_ai.request.model") == "fake-async-embed-model" + + def test_async_embed_documents_error_span(self, instrument, span_exporter): + emb = FakeAsyncErrorEmbeddings() + with pytest.raises(ValueError, match="async embedding failure"): + asyncio.run(emb.aembed_documents(["fail"])) + + spans = span_exporter.get_finished_spans() + error_spans = [ + s for s in spans if s.status.status_code == StatusCode.ERROR + ] + assert len(error_spans) >= 1 + + def test_async_embed_query_error_span(self, instrument, span_exporter): + emb = FakeAsyncErrorEmbeddings() + with pytest.raises(ValueError, match="async embedding failure"): + asyncio.run(emb.aembed_query("fail")) + + spans = span_exporter.get_finished_spans() + error_spans = [ + s for s in spans if s.status.status_code == StatusCode.ERROR + ] + assert len(error_spans) >= 1 + + +class TestEmbeddingDeduplication: + def test_proxy_embed_documents_single_span( + self, instrument, span_exporter + ): + 
"""A proxy that delegates to an inner embeddings should produce + exactly one embedding span, not two.""" + proxy = FakeProxyEmbeddings() + result = proxy.embed_documents(["test"]) + assert len(result) == 1 + + spans = _find_embedding_spans(span_exporter) + assert len(spans) == 1, ( + f"Expected exactly 1 embedding span, got {len(spans)}" + ) + + def test_proxy_embed_query_single_span(self, instrument, span_exporter): + proxy = FakeProxyEmbeddings() + result = proxy.embed_query("test") + assert isinstance(result, list) + + spans = _find_embedding_spans(span_exporter) + assert len(spans) == 1, ( + f"Expected exactly 1 embedding span, got {len(spans)}" + ) + + def test_direct_embeddings_still_creates_span( + self, instrument, span_exporter + ): + emb = FakeEmbeddings() + emb.embed_documents(["test"]) + + spans = _find_embedding_spans(span_exporter) + assert len(spans) == 1 + + +class TestEmbeddingInitSubclassHook: + def test_post_instrumentation_subclass_creates_span( + self, instrument, span_exporter + ): + class LateDefinedEmbeddings(Embeddings): + model_name: str = "late-embed-model" + + def __init__(self): + self.model_name = "late-embed-model" + + def embed_documents(self, texts: list[str]) -> list[list[float]]: + return [[1.0] for _ in texts] + + def embed_query(self, text: str) -> list[float]: + return [1.0] + + emb = LateDefinedEmbeddings() + emb.embed_documents(["hello"]) + + spans = _find_embedding_spans(span_exporter) + assert len(spans) >= 1 + attrs = dict(spans[0].attributes) + assert attrs.get("gen_ai.request.model") == "late-embed-model" + + def test_post_instrumentation_async_subclass_creates_span( + self, instrument, span_exporter + ): + class LateAsyncEmbeddings(Embeddings): + model_name: str = "late-async-embed-model" + + def __init__(self): + self.model_name = "late-async-embed-model" + + def embed_documents(self, texts: list[str]) -> list[list[float]]: + return [[1.0] for _ in texts] + + def embed_query(self, text: str) -> list[float]: + return 
[1.0] + + async def aembed_documents( + self, texts: list[str] + ) -> list[list[float]]: + return [[2.0] for _ in texts] + + emb = LateAsyncEmbeddings() + asyncio.run(emb.aembed_documents(["hello"])) + + spans = _find_embedding_spans(span_exporter) + assert len(spans) >= 1 + attrs = dict(spans[0].attributes) + assert attrs.get("gen_ai.request.model") == "late-async-embed-model" + + +class TestEmbeddingUninstrumentation: + def test_no_spans_after_uninstrument(self, instrument, span_exporter): + emb = FakeEmbeddings() + emb.embed_documents(["test"]) + + spans = _find_embedding_spans(span_exporter) + assert len(spans) >= 1 + + instrument.uninstrument() + span_exporter.clear() + + emb.embed_documents(["test"]) + spans = _find_embedding_spans(span_exporter) + assert len(spans) == 0