diff --git a/instrumentation-loongsuite/loongsuite-instrumentation-langchain/CHANGELOG.md b/instrumentation-loongsuite/loongsuite-instrumentation-langchain/CHANGELOG.md index 73dad1702..49d6e2bd4 100644 --- a/instrumentation-loongsuite/loongsuite-instrumentation-langchain/CHANGELOG.md +++ b/instrumentation-loongsuite/loongsuite-instrumentation-langchain/CHANGELOG.md @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- LangChain embedding span support ([#157](https://github.com/alibaba/loongsuite-python-agent/pull/157)) - Rerank / document-compressor span support ([#149](https://github.com/alibaba/loongsuite-python-agent/pull/149)) diff --git a/instrumentation-loongsuite/loongsuite-instrumentation-langchain/README.md b/instrumentation-loongsuite/loongsuite-instrumentation-langchain/README.md index fc9d93293..fc4b3ad50 100644 --- a/instrumentation-loongsuite/loongsuite-instrumentation-langchain/README.md +++ b/instrumentation-loongsuite/loongsuite-instrumentation-langchain/README.md @@ -130,6 +130,7 @@ loongsuite-instrument | ReAct Step | `STEP` | `gen_ai.operation.name=react`, `gen_ai.react.round`, `gen_ai.react.finish_reason` | | Tool | `TOOL` | `gen_ai.operation.name=execute_tool` | | Retriever | `RETRIEVER` | `gen_ai.operation.name=retrieval` | +| Embedding | `EMBEDDING` | `gen_ai.operation.name=embeddings`, `gen_ai.request.model`, `gen_ai.provider.name`, `server.address`, `server.port`, `gen_ai.embeddings.dimension.count`, `gen_ai.request.encoding_formats` | | Reranker | `RERANKER` | `gen_ai.operation.name=rerank_documents`, `gen_ai.request.model`, `gen_ai.rerank.documents.count`, `gen_ai.request.top_k`, `gen_ai.rerank.input_documents`, `gen_ai.rerank.output_documents` (when content capture enabled) | ReAct Step spans are created for each Reasoning-Acting iteration, with the hierarchy: Agent > ReAct Step > LLM/Tool. 
Supported agent types: diff --git a/instrumentation-loongsuite/loongsuite-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/__init__.py b/instrumentation-loongsuite/loongsuite-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/__init__.py index b67577c46..e0c390c42 100644 --- a/instrumentation-loongsuite/loongsuite-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/__init__.py +++ b/instrumentation-loongsuite/loongsuite-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/__init__.py @@ -208,6 +208,29 @@ def _uninstrument_create_agent() -> None: _patched_create_agent_locations.clear() +# ------------------------------------------------------------------ +# Embeddings patch +# ------------------------------------------------------------------ + + +def _instrument_embeddings(handler: Any) -> None: + """Wrap all current and future ``Embeddings`` subclasses.""" + from opentelemetry.instrumentation.langchain.internal.patch_embedding import ( # noqa: PLC0415 + instrument_embeddings, + ) + + instrument_embeddings(handler) + + +def _uninstrument_embeddings() -> None: + """Restore original ``Embeddings`` methods.""" + from opentelemetry.instrumentation.langchain.internal.patch_embedding import ( # noqa: PLC0415 + uninstrument_embeddings, + ) + + uninstrument_embeddings() + + # ------------------------------------------------------------------ # BaseDocumentCompressor patch (rerank / compression) # ------------------------------------------------------------------ @@ -268,6 +291,7 @@ def _instrument(self, **kwargs: Any) -> None: _instrument_agent_executor() _instrument_create_agent() _instrument_document_compressor(handler) + _instrument_embeddings(handler) def _uninstrument(self, **kwargs: Any) -> None: try: @@ -281,6 +305,7 @@ def _uninstrument(self, **kwargs: Any) -> None: _uninstrument_agent_executor() _uninstrument_create_agent() _uninstrument_document_compressor() + _uninstrument_embeddings() 
class _BaseCallbackManagerInit: diff --git a/instrumentation-loongsuite/loongsuite-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/internal/patch_embedding.py b/instrumentation-loongsuite/loongsuite-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/internal/patch_embedding.py new file mode 100644 index 000000000..41bb6d271 --- /dev/null +++ b/instrumentation-loongsuite/loongsuite-instrumentation-langchain/src/opentelemetry/instrumentation/langchain/internal/patch_embedding.py @@ -0,0 +1,489 @@ +# Copyright The OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Embedding instrumentation patch for LangChain Embeddings. + +Because ``embed_documents`` and ``embed_query`` are abstract, every +subclass overrides them. We use the same strategy as ``patch_rerank.py``: + +1. Retroactively wrap methods on **every existing subclass** at + instrumentation time. +2. Install a ``__init_subclass__`` hook on ``Embeddings`` so that + any subclass defined **after** instrumentation is also wrapped + automatically. 
+""" + +from __future__ import annotations + +import contextvars +import logging +from typing import TYPE_CHECKING, Any + +import wrapt + +from opentelemetry.util.genai.extended_types import EmbeddingInvocation +from opentelemetry.util.genai.types import Error + +if TYPE_CHECKING: + from opentelemetry.util.genai.extended_handler import ( + ExtendedTelemetryHandler, + ) + +logger = logging.getLogger(__name__) + +# Depth counter to avoid duplicate spans when a proxy embeddings +# delegates to an inner embeddings (both subclasses are patched), +# or when default aembed_* calls patched embed_* via run_in_executor. +_EMBEDDING_CALL_DEPTH: contextvars.ContextVar[int] = contextvars.ContextVar( + "opentelemetry_langchain_embedding_call_depth", + default=0, +) + +# Module-level state for uninstrumentation. +_original_init_subclass: Any = None +_patched_classes: set[type] = set() + +_WRAPPER_TAG = "_loongsuite_embedding_wrapped" + +_SYNC_METHODS = ("embed_documents", "embed_query") +_ASYNC_METHODS = ("aembed_documents", "aembed_query") + + +# --------------------------------------------------------------------------- +# Helpers — metadata extraction +# --------------------------------------------------------------------------- + + +def _extract_embedding_provider(instance: Any) -> str: + """Infer a provider name from an Embeddings instance.""" + cls_name = type(instance).__name__ + module = type(instance).__module__ or "" + + _HINTS = [ + ("openai", "openai"), + ("azure", "azure"), + ("cohere", "cohere"), + ("huggingface", "huggingface"), + ("sentence_transformers", "sentence_transformers"), + ("google", "google"), + ("bedrock", "aws_bedrock"), + ("ollama", "ollama"), + ("jina", "jina"), + ("voyage", "voyageai"), + ("mistral", "mistral"), + ("dashscope", "dashscope"), + ("together", "together"), + ("fireworks", "fireworks"), + ] + lower = (module + "." 
+ cls_name).lower() + for hint, provider in _HINTS: + if hint in lower: + return provider + + return "langchain" + + +def _extract_embedding_model(instance: Any) -> str: + """Extract a model name from an Embeddings instance (if available).""" + for attr in ("model", "model_name", "model_id", "deployment_name"): + val = getattr(instance, attr, None) + if val and isinstance(val, str): + return val + return "" + + +def _extract_server_address_port( + instance: Any, +) -> tuple[str | None, int | None]: + """Extract server address and port from an Embeddings instance.""" + from urllib.parse import urlparse # noqa: PLC0415 + + for attr in ( + "openai_api_base", + "base_url", + "api_base", + "endpoint", + "endpoint_url", + ): + val = getattr(instance, attr, None) + if val and isinstance(val, str): + try: + parsed = urlparse(val) + host = parsed.hostname + port = parsed.port + return host, port + except Exception: + continue + + # Some providers store a client object with base_url + client = getattr(instance, "client", None) + if client is not None: + client_base = getattr(client, "base_url", None) + if client_base: + url_str = str(client_base) + try: + parsed = urlparse(url_str) + return parsed.hostname, parsed.port + except Exception: + pass + + return None, None + + +def _extract_dimension_count(instance: Any) -> int | None: + """Extract embedding dimension count from an Embeddings instance.""" + for attr in ("dimensions", "dimension", "embedding_dim"): + val = getattr(instance, attr, None) + if val is not None and isinstance(val, int) and val > 0: + return val + return None + + +def _extract_encoding_formats(instance: Any) -> list[str] | None: + """Extract encoding formats from an Embeddings instance.""" + for attr in ("encoding_format", "embedding_format"): + val = getattr(instance, attr, None) + if val and isinstance(val, str): + return [val] + if val and isinstance(val, list): + return val + return None + + +# 
--------------------------------------------------------------------------- +# Wrapper factories +# --------------------------------------------------------------------------- + + +def _build_invocation(instance: Any) -> EmbeddingInvocation: + """Build an ``EmbeddingInvocation`` with all extractable attributes.""" + server_address, server_port = _extract_server_address_port(instance) + return EmbeddingInvocation( + request_model=_extract_embedding_model(instance), + provider=_extract_embedding_provider(instance), + server_address=server_address, + server_port=server_port, + dimension_count=_extract_dimension_count(instance), + encoding_formats=_extract_encoding_formats(instance), + ) + + +def _make_embed_documents_wrapper( + handler: "ExtendedTelemetryHandler", +) -> Any: + """Return a ``wrapt``-style wrapper for ``embed_documents``.""" + + def wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any) -> Any: + parent_depth = _EMBEDDING_CALL_DEPTH.get() + depth_token = _EMBEDDING_CALL_DEPTH.set(parent_depth + 1) + try: + if parent_depth > 0: + return wrapped(*args, **kwargs) + + invocation = _build_invocation(instance) + + try: + handler.start_embedding(invocation) + except Exception: + logger.debug("Failed to start embedding span", exc_info=True) + return wrapped(*args, **kwargs) + + try: + result = wrapped(*args, **kwargs) + handler.stop_embedding(invocation) + return result + except Exception as exc: + handler.fail_embedding( + invocation, Error(message=str(exc), type=type(exc)) + ) + raise + finally: + _EMBEDDING_CALL_DEPTH.reset(depth_token) + + return wrapper + + +def _make_embed_query_wrapper( + handler: "ExtendedTelemetryHandler", +) -> Any: + """Return a ``wrapt``-style wrapper for ``embed_query``.""" + + def wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any) -> Any: + parent_depth = _EMBEDDING_CALL_DEPTH.get() + depth_token = _EMBEDDING_CALL_DEPTH.set(parent_depth + 1) + try: + if parent_depth > 0: + return wrapped(*args, **kwargs) + + 
invocation = _build_invocation(instance) + + try: + handler.start_embedding(invocation) + except Exception: + logger.debug("Failed to start embedding span", exc_info=True) + return wrapped(*args, **kwargs) + + try: + result = wrapped(*args, **kwargs) + handler.stop_embedding(invocation) + return result + except Exception as exc: + handler.fail_embedding( + invocation, Error(message=str(exc), type=type(exc)) + ) + raise + finally: + _EMBEDDING_CALL_DEPTH.reset(depth_token) + + return wrapper + + +def _make_aembed_documents_wrapper( + handler: "ExtendedTelemetryHandler", +) -> Any: + """Return a ``wrapt``-style wrapper for ``aembed_documents``.""" + + def wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any) -> Any: + async def _instrumented() -> Any: + parent_depth = _EMBEDDING_CALL_DEPTH.get() + depth_token = _EMBEDDING_CALL_DEPTH.set(parent_depth + 1) + try: + if parent_depth > 0: + return await wrapped(*args, **kwargs) + + invocation = _build_invocation(instance) + + try: + handler.start_embedding(invocation) + except Exception: + logger.debug( + "Failed to start embedding span", exc_info=True + ) + return await wrapped(*args, **kwargs) + + try: + result = await wrapped(*args, **kwargs) + handler.stop_embedding(invocation) + return result + except Exception as exc: + handler.fail_embedding( + invocation, Error(message=str(exc), type=type(exc)) + ) + raise + finally: + _EMBEDDING_CALL_DEPTH.reset(depth_token) + + return _instrumented() + + return wrapper + + +def _make_aembed_query_wrapper( + handler: "ExtendedTelemetryHandler", +) -> Any: + """Return a ``wrapt``-style wrapper for ``aembed_query``.""" + + def wrapper(wrapped: Any, instance: Any, args: Any, kwargs: Any) -> Any: + async def _instrumented() -> Any: + parent_depth = _EMBEDDING_CALL_DEPTH.get() + depth_token = _EMBEDDING_CALL_DEPTH.set(parent_depth + 1) + try: + if parent_depth > 0: + return await wrapped(*args, **kwargs) + + invocation = _build_invocation(instance) + + try: + 
handler.start_embedding(invocation) + except Exception: + logger.debug( + "Failed to start embedding span", exc_info=True + ) + return await wrapped(*args, **kwargs) + + try: + result = await wrapped(*args, **kwargs) + handler.stop_embedding(invocation) + return result + except Exception as exc: + handler.fail_embedding( + invocation, Error(message=str(exc), type=type(exc)) + ) + raise + finally: + _EMBEDDING_CALL_DEPTH.reset(depth_token) + + return _instrumented() + + return wrapper + + +# --------------------------------------------------------------------------- +# Subclass discovery +# --------------------------------------------------------------------------- + + +def _all_subclasses(cls: type) -> set[type]: + """Recursively collect every subclass of *cls*.""" + result: set[type] = set() + queue = list(cls.__subclasses__()) + while queue: + sub = queue.pop() + if sub not in result: + result.add(sub) + queue.extend(sub.__subclasses__()) + return result + + +# --------------------------------------------------------------------------- +# Per-class patching / unpatching +# --------------------------------------------------------------------------- + + +def _patch_class( + cls: type, + sync_doc_wrapper: Any, + sync_query_wrapper: Any, + async_doc_wrapper: Any, + async_query_wrapper: Any, +) -> None: + """Wrap embedding methods on *cls*. + + Only wraps methods that are defined directly in *cls* (i.e. present + in ``cls.__dict__``). Skips classes that are already wrapped. 
+ """ + if getattr(cls, _WRAPPER_TAG, False): + return + + _method_wrappers = { + "embed_documents": sync_doc_wrapper, + "embed_query": sync_query_wrapper, + "aembed_documents": async_doc_wrapper, + "aembed_query": async_query_wrapper, + } + + for method_name, wrapper_fn in _method_wrappers.items(): + if method_name in cls.__dict__: + original = cls.__dict__[method_name] + if not isinstance(original, wrapt.FunctionWrapper): + setattr( + cls, + method_name, + wrapt.FunctionWrapper(original, wrapper_fn), + ) + + setattr(cls, _WRAPPER_TAG, True) + _patched_classes.add(cls) + + +def _unpatch_class(cls: type) -> None: + """Restore original methods on *cls*.""" + for method_name in (*_SYNC_METHODS, *_ASYNC_METHODS): + method = cls.__dict__.get(method_name) + if isinstance(method, wrapt.FunctionWrapper): + setattr(cls, method_name, method.__wrapped__) + + try: + delattr(cls, _WRAPPER_TAG) + except AttributeError: + pass + + +def instrument_embeddings( + handler: "ExtendedTelemetryHandler", +) -> None: + """Wrap all current and future ``Embeddings`` subclasses.""" + global _original_init_subclass # noqa: PLW0603 + + try: + from langchain_core.embeddings import Embeddings # noqa: PLC0415 + except ImportError as exc: + logger.debug( + "Embeddings not available, skipping embedding instrumentation: %s", + exc, + ) + return + + sync_doc_wrapper = _make_embed_documents_wrapper(handler) + sync_query_wrapper = _make_embed_query_wrapper(handler) + async_doc_wrapper = _make_aembed_documents_wrapper(handler) + async_query_wrapper = _make_aembed_query_wrapper(handler) + + # 1. Retroactively patch every existing subclass. + for cls in _all_subclasses(Embeddings): + _patch_class( + cls, + sync_doc_wrapper, + sync_query_wrapper, + async_doc_wrapper, + async_query_wrapper, + ) + + # 2. Install an __init_subclass__ hook so future subclasses are + # patched automatically. 
+ _original_init_subclass = Embeddings.__dict__.get("__init_subclass__") + + @classmethod # type: ignore[misc] + def _patched_init_subclass(cls: type, **kwargs: Any) -> None: + if _original_init_subclass is not None: + if isinstance(_original_init_subclass, classmethod): + _original_init_subclass.__func__(cls, **kwargs) + else: + _original_init_subclass(**kwargs) + else: + super(Embeddings, cls).__init_subclass__(**kwargs) + _patch_class( + cls, + sync_doc_wrapper, + sync_query_wrapper, + async_doc_wrapper, + async_query_wrapper, + ) + + Embeddings.__init_subclass__ = _patched_init_subclass # type: ignore[assignment] + + logger.debug( + "Patched Embeddings (%d existing subclass(es))", + len(_patched_classes), + ) + + +def uninstrument_embeddings() -> None: + """Restore original methods on all patched embeddings classes.""" + global _original_init_subclass # noqa: PLW0603 + + try: + from langchain_core.embeddings import Embeddings # noqa: PLC0415 + + if _original_init_subclass is not None: + Embeddings.__init_subclass__ = _original_init_subclass # type: ignore[assignment] + else: + if "__init_subclass__" in Embeddings.__dict__: + delattr(Embeddings, "__init_subclass__") + except Exception: + logger.debug( + "Failed to restore Embeddings.__init_subclass__", + exc_info=True, + ) + + for cls in list(_patched_classes): + try: + _unpatch_class(cls) + except Exception: + logger.debug("Failed to unpatch %s", cls, exc_info=True) + _patched_classes.clear() + _original_init_subclass = None + + logger.debug("Restored Embeddings subclasses") diff --git a/instrumentation-loongsuite/loongsuite-instrumentation-langchain/tests/test_embedding_spans.py b/instrumentation-loongsuite/loongsuite-instrumentation-langchain/tests/test_embedding_spans.py new file mode 100644 index 000000000..1fbae1175 --- /dev/null +++ b/instrumentation-loongsuite/loongsuite-instrumentation-langchain/tests/test_embedding_spans.py @@ -0,0 +1,412 @@ +# Copyright The OpenTelemetry Authors +# +# Licensed under 
the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for embedding span creation and attributes.""" + +from __future__ import annotations + +import asyncio + +import pytest +from langchain_core.embeddings import Embeddings + +from opentelemetry.trace import StatusCode + +# --------------------------------------------------------------------------- +# Fake embeddings for testing +# --------------------------------------------------------------------------- + + +class FakeEmbeddings(Embeddings): + """Basic fake embeddings with model_name.""" + + model_name: str = "fake-embed-model" + + def __init__(self, model_name: str = "fake-embed-model"): + self.model_name = model_name + + def embed_documents(self, texts: list[str]) -> list[list[float]]: + return [[0.1, 0.2, 0.3] for _ in texts] + + def embed_query(self, text: str) -> list[float]: + return [0.1, 0.2, 0.3] + + +class FakeOpenAIEmbeddings(Embeddings): + """Fake embeddings with OpenAI-style attributes for server/dimension tests.""" + + model_name: str = "text-embedding-3-small" + openai_api_base: str = "https://api.openai.com:443/v1" + dimensions: int = 1536 + + def __init__(self): + self.model_name = "text-embedding-3-small" + self.openai_api_base = "https://api.openai.com:443/v1" + self.dimensions = 1536 + + def embed_documents(self, texts: list[str]) -> list[list[float]]: + return [[0.1, 0.2, 0.3] for _ in texts] + + def embed_query(self, text: str) -> list[float]: + return [0.1, 0.2, 0.3] + + +class 
FakeErrorEmbeddings(Embeddings): + """Embeddings that always fail.""" + + def embed_documents(self, texts: list[str]) -> list[list[float]]: + raise ValueError("embedding failure") + + def embed_query(self, text: str) -> list[float]: + raise ValueError("embedding failure") + + +class FakeAsyncEmbeddings(Embeddings): + """Embeddings with native async implementations.""" + + model_name: str = "fake-async-embed-model" + + def __init__(self, model_name: str = "fake-async-embed-model"): + self.model_name = model_name + + def embed_documents(self, texts: list[str]) -> list[list[float]]: + return [[0.1, 0.2] for _ in texts] + + def embed_query(self, text: str) -> list[float]: + return [0.1, 0.2] + + async def aembed_documents(self, texts: list[str]) -> list[list[float]]: + return [[0.4, 0.5] for _ in texts] + + async def aembed_query(self, text: str) -> list[float]: + return [0.4, 0.5] + + +class FakeAsyncErrorEmbeddings(Embeddings): + """Async embeddings that always fail.""" + + def embed_documents(self, texts: list[str]) -> list[list[float]]: + raise ValueError("sync embedding failure") + + def embed_query(self, text: str) -> list[float]: + raise ValueError("sync embedding failure") + + async def aembed_documents(self, texts: list[str]) -> list[list[float]]: + raise ValueError("async embedding failure") + + async def aembed_query(self, text: str) -> list[float]: + raise ValueError("async embedding failure") + + +class FakeProxyEmbeddings(Embeddings): + """A proxy that delegates to an inner Embeddings instance.""" + + def __init__(self): + self.inner = FakeEmbeddings() + + def embed_documents(self, texts: list[str]) -> list[list[float]]: + return self.inner.embed_documents(texts) + + def embed_query(self, text: str) -> list[float]: + return self.inner.embed_query(text) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +_EMBEDDING_SPAN_NAME = 
"embeddings" + + +def _find_embedding_spans(span_exporter): + spans = span_exporter.get_finished_spans() + return [s for s in spans if s.name.startswith(_EMBEDDING_SPAN_NAME)] + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestEmbeddingSpanCreation: + def test_embed_documents_creates_span(self, instrument, span_exporter): + emb = FakeEmbeddings() + result = emb.embed_documents(["hello", "world"]) + assert len(result) == 2 + + spans = _find_embedding_spans(span_exporter) + assert len(spans) >= 1 + + def test_embed_query_creates_span(self, instrument, span_exporter): + emb = FakeEmbeddings() + result = emb.embed_query("hello") + assert isinstance(result, list) + + spans = _find_embedding_spans(span_exporter) + assert len(spans) >= 1 + + def test_embed_documents_error_span(self, instrument, span_exporter): + emb = FakeErrorEmbeddings() + with pytest.raises(ValueError, match="embedding failure"): + emb.embed_documents(["fail"]) + + spans = span_exporter.get_finished_spans() + error_spans = [ + s for s in spans if s.status.status_code == StatusCode.ERROR + ] + assert len(error_spans) >= 1 + + def test_embed_query_error_span(self, instrument, span_exporter): + emb = FakeErrorEmbeddings() + with pytest.raises(ValueError, match="embedding failure"): + emb.embed_query("fail") + + spans = span_exporter.get_finished_spans() + error_spans = [ + s for s in spans if s.status.status_code == StatusCode.ERROR + ] + assert len(error_spans) >= 1 + + +class TestEmbeddingSpanAttributes: + def test_operation_name(self, instrument, span_exporter): + emb = FakeEmbeddings() + emb.embed_documents(["test"]) + + spans = _find_embedding_spans(span_exporter) + assert len(spans) >= 1 + attrs = dict(spans[0].attributes) + assert attrs.get("gen_ai.operation.name") == "embeddings" + + def test_span_kind_attribute(self, instrument, span_exporter): + emb = 
FakeEmbeddings() + emb.embed_documents(["test"]) + + spans = _find_embedding_spans(span_exporter) + assert len(spans) >= 1 + attrs = dict(spans[0].attributes) + assert attrs.get("gen_ai.span.kind") == "EMBEDDING" + + def test_provider_attribute(self, instrument, span_exporter): + emb = FakeEmbeddings() + emb.embed_documents(["test"]) + + spans = _find_embedding_spans(span_exporter) + assert len(spans) >= 1 + attrs = dict(spans[0].attributes) + assert attrs.get("gen_ai.provider.name") == "langchain" + + def test_model_attribute(self, instrument, span_exporter): + emb = FakeEmbeddings() + emb.embed_documents(["test"]) + + spans = _find_embedding_spans(span_exporter) + assert len(spans) >= 1 + attrs = dict(spans[0].attributes) + assert attrs.get("gen_ai.request.model") == "fake-embed-model" + + def test_span_name_includes_model(self, instrument, span_exporter): + emb = FakeEmbeddings() + emb.embed_documents(["test"]) + + spans = _find_embedding_spans(span_exporter) + assert len(spans) >= 1 + assert "fake-embed-model" in spans[0].name + + def test_server_address_and_port(self, instrument, span_exporter): + emb = FakeOpenAIEmbeddings() + emb.embed_documents(["test"]) + + spans = _find_embedding_spans(span_exporter) + assert len(spans) >= 1 + attrs = dict(spans[0].attributes) + assert attrs.get("server.address") == "api.openai.com" + assert attrs.get("server.port") == 443 + + def test_dimension_count(self, instrument, span_exporter): + emb = FakeOpenAIEmbeddings() + emb.embed_documents(["test"]) + + spans = _find_embedding_spans(span_exporter) + assert len(spans) >= 1 + attrs = dict(spans[0].attributes) + assert attrs.get("gen_ai.embeddings.dimension.count") == 1536 + + def test_no_server_attrs_when_absent(self, instrument, span_exporter): + emb = FakeEmbeddings() + emb.embed_documents(["test"]) + + spans = _find_embedding_spans(span_exporter) + assert len(spans) >= 1 + attrs = dict(spans[0].attributes) + assert "server.address" not in attrs + assert "server.port" not in 
attrs + assert "gen_ai.embeddings.dimension.count" not in attrs + + +class TestAsyncEmbeddingSpans: + def test_async_embed_documents_creates_span( + self, instrument, span_exporter + ): + emb = FakeAsyncEmbeddings() + result = asyncio.run(emb.aembed_documents(["hello", "world"])) + assert len(result) == 2 + + spans = _find_embedding_spans(span_exporter) + assert len(spans) >= 1 + + def test_async_embed_query_creates_span(self, instrument, span_exporter): + emb = FakeAsyncEmbeddings() + result = asyncio.run(emb.aembed_query("hello")) + assert isinstance(result, list) + + spans = _find_embedding_spans(span_exporter) + assert len(spans) >= 1 + + def test_async_embed_documents_attributes(self, instrument, span_exporter): + emb = FakeAsyncEmbeddings() + asyncio.run(emb.aembed_documents(["test"])) + + spans = _find_embedding_spans(span_exporter) + assert len(spans) >= 1 + attrs = dict(spans[0].attributes) + assert attrs.get("gen_ai.operation.name") == "embeddings" + assert attrs.get("gen_ai.span.kind") == "EMBEDDING" + assert attrs.get("gen_ai.request.model") == "fake-async-embed-model" + + def test_async_embed_documents_error_span(self, instrument, span_exporter): + emb = FakeAsyncErrorEmbeddings() + with pytest.raises(ValueError, match="async embedding failure"): + asyncio.run(emb.aembed_documents(["fail"])) + + spans = span_exporter.get_finished_spans() + error_spans = [ + s for s in spans if s.status.status_code == StatusCode.ERROR + ] + assert len(error_spans) >= 1 + + def test_async_embed_query_error_span(self, instrument, span_exporter): + emb = FakeAsyncErrorEmbeddings() + with pytest.raises(ValueError, match="async embedding failure"): + asyncio.run(emb.aembed_query("fail")) + + spans = span_exporter.get_finished_spans() + error_spans = [ + s for s in spans if s.status.status_code == StatusCode.ERROR + ] + assert len(error_spans) >= 1 + + +class TestEmbeddingDeduplication: + def test_proxy_embed_documents_single_span( + self, instrument, span_exporter + ): + 
"""A proxy that delegates to an inner embeddings should produce + exactly one embedding span, not two.""" + proxy = FakeProxyEmbeddings() + result = proxy.embed_documents(["test"]) + assert len(result) == 1 + + spans = _find_embedding_spans(span_exporter) + assert len(spans) == 1, ( + f"Expected exactly 1 embedding span, got {len(spans)}" + ) + + def test_proxy_embed_query_single_span(self, instrument, span_exporter): + proxy = FakeProxyEmbeddings() + result = proxy.embed_query("test") + assert isinstance(result, list) + + spans = _find_embedding_spans(span_exporter) + assert len(spans) == 1, ( + f"Expected exactly 1 embedding span, got {len(spans)}" + ) + + def test_direct_embeddings_still_creates_span( + self, instrument, span_exporter + ): + emb = FakeEmbeddings() + emb.embed_documents(["test"]) + + spans = _find_embedding_spans(span_exporter) + assert len(spans) == 1 + + +class TestEmbeddingInitSubclassHook: + def test_post_instrumentation_subclass_creates_span( + self, instrument, span_exporter + ): + class LateDefinedEmbeddings(Embeddings): + model_name: str = "late-embed-model" + + def __init__(self): + self.model_name = "late-embed-model" + + def embed_documents(self, texts: list[str]) -> list[list[float]]: + return [[1.0] for _ in texts] + + def embed_query(self, text: str) -> list[float]: + return [1.0] + + emb = LateDefinedEmbeddings() + emb.embed_documents(["hello"]) + + spans = _find_embedding_spans(span_exporter) + assert len(spans) >= 1 + attrs = dict(spans[0].attributes) + assert attrs.get("gen_ai.request.model") == "late-embed-model" + + def test_post_instrumentation_async_subclass_creates_span( + self, instrument, span_exporter + ): + class LateAsyncEmbeddings(Embeddings): + model_name: str = "late-async-embed-model" + + def __init__(self): + self.model_name = "late-async-embed-model" + + def embed_documents(self, texts: list[str]) -> list[list[float]]: + return [[1.0] for _ in texts] + + def embed_query(self, text: str) -> list[float]: + return 
[1.0] + + async def aembed_documents( + self, texts: list[str] + ) -> list[list[float]]: + return [[2.0] for _ in texts] + + emb = LateAsyncEmbeddings() + asyncio.run(emb.aembed_documents(["hello"])) + + spans = _find_embedding_spans(span_exporter) + assert len(spans) >= 1 + attrs = dict(spans[0].attributes) + assert attrs.get("gen_ai.request.model") == "late-async-embed-model" + + +class TestEmbeddingUninstrumentation: + def test_no_spans_after_uninstrument(self, instrument, span_exporter): + emb = FakeEmbeddings() + emb.embed_documents(["test"]) + + spans = _find_embedding_spans(span_exporter) + assert len(spans) >= 1 + + instrument.uninstrument() + span_exporter.clear() + + emb.embed_documents(["test"]) + spans = _find_embedding_spans(span_exporter) + assert len(spans) == 0