diff --git a/ci/vale/styles/config/vocabularies/nat/accept.txt b/ci/vale/styles/config/vocabularies/nat/accept.txt index 853bd3a841..bc541eae8c 100644 --- a/ci/vale/styles/config/vocabularies/nat/accept.txt +++ b/ci/vale/styles/config/vocabularies/nat/accept.txt @@ -151,6 +151,7 @@ Pareto Patronus PCIe PDF(s?) +Perplexity [Pp]luggable [Pp]ostprocess [Pp]ostprocessing diff --git a/docs/source/build-workflows/embedders.md b/docs/source/build-workflows/embedders.md index ab114b68e7..e53522454c 100644 --- a/docs/source/build-workflows/embedders.md +++ b/docs/source/build-workflows/embedders.md @@ -25,6 +25,7 @@ NeMo Agent Toolkit supports the following embedder providers: | [OpenAI](https://openai.com) | `openai` | OpenAI API | | [Azure OpenAI](https://learn.microsoft.com/en-us/azure/ai-foundry/openai/quickstart) | `azure_openai` | Azure OpenAI API | | [Hugging Face](https://huggingface.co) | `huggingface` | Local sentence-transformers or remote Inference Endpoints (TEI) | +| [Perplexity](https://docs.perplexity.ai/api-reference/embeddings-post) | `perplexity` | Perplexity Embeddings API (`pplx-embed-v1-0.6b`, `pplx-embed-v1-4b`, contextualized variants) | ## Embedder Configuration @@ -41,6 +42,9 @@ embedders: azure_openai_embedder: _type: azure_openai azure_deployment: text-embedding-3-small + perplexity_embedder: + _type: perplexity + model_name: pplx-embed-v1-0.6b ``` ### NVIDIA NIM @@ -120,3 +124,28 @@ embedders: endpoint_url: http://localhost:8081 api_key: ${HF_TOKEN} ``` + +### Perplexity + +Perplexity exposes a dedicated embeddings endpoint at `POST https://api.perplexity.ai/v1/embeddings`. The toolkit ships a native client that batches inputs, decodes the on-wire base64 payload locally, and forwards an `X-Pplx-Integration: nemo-agent-toolkit/` attribution header on every request. 
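The local decoding step described above can be sketched roughly as follows. This is a minimal illustration under stated assumptions, not the toolkit's actual client: the helper names `decode_base64_int8` and `decode_base64_binary` are hypothetical, and the most-significant-bit-first order assumed for `base64_binary` should be checked against the Perplexity API reference before relying on it.

```python
import base64
from array import array


def decode_base64_int8(payload: str) -> list[int]:
    """Decode a base64 string into signed 8-bit integers, one per dimension."""
    raw = base64.b64decode(payload)
    # Type code "b" is a signed char, so the byte 0xFF decodes to -1.
    return array("b", raw).tolist()


def decode_base64_binary(payload: str) -> list[int]:
    """Unpack a base64 string into a flat list of 0/1 bits.

    ASSUMPTION: bits are packed most-significant-bit first within each byte.
    """
    raw = base64.b64decode(payload)
    return [(byte >> shift) & 1 for byte in raw for shift in range(7, -1, -1)]


# Round-trip a hand-made payload instead of calling the API.
sample = base64.b64encode(bytes([1, 255, 128, 127])).decode("ascii")
print(decode_base64_int8(sample))  # [1, -1, -128, 127]
```

The `base64_int8` form keeps one byte per dimension, while `base64_binary` keeps one bit per dimension, which is why the latter is described as suited to large-scale retrieval.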
+ +You can use the following environment variables to configure the Perplexity embedder provider: + +* `PERPLEXITY_API_KEY` - The API key to access the Perplexity Embeddings API + +The Perplexity embedder provider is defined by the {py:class}`~nat.embedder.perplexity_embedder.PerplexityEmbedderModelConfig` class. + +* `model_name` - Embedding model identifier. Standard embeddings: `pplx-embed-v1-0.6b` (1024-dim, default) or `pplx-embed-v1-4b` (2560-dim). Document-aware: `pplx-embed-context-v1-0.6b` or `pplx-embed-context-v1-4b` +* `api_key` - Perplexity API key (falls back to `PERPLEXITY_API_KEY`) +* `base_url` - Base URL for the Perplexity API (default: `https://api.perplexity.ai/v1`) +* `dimensions` - Optional Matryoshka output dimension. Range 128-1024 for `0.6b` models and 128-2560 for `4b` models. Omit for full dimensions +* `encoding_format` - On-wire encoding: `base64_int8` (default, signed int8) or `base64_binary` (packed bits for large-scale retrieval) +* `batch_size` - Maximum inputs per request (1-512). Defaults to `64` + +```yaml +embedders: + perplexity_embedder: + _type: perplexity + model_name: pplx-embed-v1-0.6b + dimensions: 512 # optional Matryoshka truncation +``` diff --git a/docs/source/components/integrations/frameworks.md b/docs/source/components/integrations/frameworks.md index 536da5f03f..c73938b859 100644 --- a/docs/source/components/integrations/frameworks.md +++ b/docs/source/components/integrations/frameworks.md @@ -40,7 +40,7 @@ NeMo Agent Toolkit provides different levels of support for each framework acros The ability to use various large language model providers with a framework, including NVIDIA NIM, OpenAI, Azure OpenAI, AWS Bedrock, LiteLLM, and Hugging Face. ### Embedder Provider Support -The ability to use embedding model providers for vector representations, including NVIDIA NIM embeddings, OpenAI embeddings, and Azure OpenAI embeddings. 
+The ability to use embedding model providers for vector representations, including NVIDIA NIM embeddings, OpenAI embeddings, Azure OpenAI embeddings, and Perplexity embeddings. ### Retriever Provider Support The ability to integrate with vector databases and retrieval systems, such as NeMo Retriever and Milvus. @@ -154,7 +154,7 @@ For more information, visit the [LangChain documentation](https://docs.langchain | Capability | Providers / Details | |-------------------------|-------------------------------------------------------------------------------------| | **LLM Providers** | NVIDIA NIM, OpenAI, Azure OpenAI, AWS Bedrock, LiteLLM, Hugging Face | -| **Embedder Providers** | NVIDIA NIM, OpenAI, Azure OpenAI | +| **Embedder Providers** | NVIDIA NIM, OpenAI, Azure OpenAI, Perplexity | | **Retriever Providers** | NeMo Retriever, Milvus | | **Tool Calling** | Fully supported through LangChain's `StructuredTool` interface | | **Profiling** | Comprehensive profiling support with callback handlers | @@ -174,7 +174,7 @@ For more information, visit the [LlamaIndex website](https://www.llamaindex.ai/) | Capability | Providers / Details | |-------------------------|-------------------------------------------------------------------------------------| | **LLM Providers** | NVIDIA NIM, OpenAI, Azure OpenAI, AWS Bedrock, LiteLLM | -| **Embedder Providers** | NVIDIA NIM, OpenAI, Azure OpenAI | +| **Embedder Providers** | NVIDIA NIM, OpenAI, Azure OpenAI, Perplexity | | **Retriever Providers** | None (Use LlamaIndex native retrievers) | | **Tool Calling** | Fully supported through LlamaIndex's `FunctionTool` interface | | **Profiling** | Comprehensive profiling support with callback handlers | diff --git a/docs/source/get-started/tutorials/add-tools-to-a-workflow.md b/docs/source/get-started/tutorials/add-tools-to-a-workflow.md index 3f1c477ed4..57dc7d8230 100644 --- a/docs/source/get-started/tutorials/add-tools-to-a-workflow.md +++ 
b/docs/source/get-started/tutorials/add-tools-to-a-workflow.md @@ -109,26 +109,26 @@ Workflow Result: ``` ## Alternate Method Using a Web Search Tool -Adding individual web pages to a workflow can be cumbersome, especially when dealing with multiple web pages. An alternative method is to use a web search tool. NeMo Agent Toolkit provides two web search tools: `tavily_internet_search` which utilizes the [Tavily Search API](https://tavily.com/), and `exa_internet_search` which utilizes the [Exa Search API](https://exa.ai/). +Adding individual web pages to a workflow can be cumbersome, especially when dealing with multiple web pages. An alternative method is to use a web search tool. NeMo Agent Toolkit provides several web search tools: `perplexity_internet_search`, which uses the [Perplexity Search API](https://docs.perplexity.ai/api-reference/search-post), `tavily_internet_search`, which uses the [Tavily Search API](https://tavily.com/), and `exa_internet_search`, which uses the [Exa Search API](https://exa.ai/). -### Using Tavily Search +### Using Perplexity Search -The `tavily_internet_search` tool is part of the `nvidia-nat[langchain]` package, to install the package run: +The `perplexity_internet_search` tool ships with the core `nvidia-nat` package and is framework-agnostic: it can be used with any of the agent frameworks supported by the toolkit (`langchain`, `llama_index`, `crewai`, `semantic_kernel`, `agno`, `adk`, `strands`, and `autogen`). No framework-specific extra is required to install it: ```bash # local package install from source -uv pip install -e ".[langchain]" +uv pip install -e . ``` -Prior to using the `tavily_internet_search` tool, create an account at [`tavily.com`](https://tavily.com/) and obtain an API key. 
Once obtained, set the `TAVILY_API_KEY` environment variable to the API key: +Prior to using the `perplexity_internet_search` tool, create a Perplexity account and obtain an API key from the [API key page](https://www.perplexity.ai/account/api/keys). Once obtained, set the `PERPLEXITY_API_KEY` environment variable to the API key: ```bash -export TAVILY_API_KEY= +export PERPLEXITY_API_KEY= ``` -We will now update the `functions` section of the configuration file replacing the two `webpage_query` tools with a single `tavily_internet_search` tool entry: +We will now update the `functions` section of the configuration file, replacing the two `webpage_query` tools with a single `perplexity_internet_search` tool entry: ```yaml functions: internet_search: - _type: tavily_internet_search + _type: perplexity_internet_search current_datetime: _type: current_datetime ``` @@ -140,20 +140,56 @@ workflow: tool_names: [internet_search, current_datetime] ``` -The resulting configuration file is located at `examples/documentation_guides/workflows/custom_workflow/search_config.yml` in the NeMo Agent Toolkit repository. - When you re-run the workflow with the updated configuration file: ```bash -nat run --config_file examples/documentation_guides/workflows/custom_workflow/search_config.yml \ +nat run --config_file \ --input "How do I trace only specific parts of my LangChain application?" 
``` -Which will then yield a slightly different result to the same question: +The `perplexity_internet_search` tool supports additional configuration options: +```yaml +functions: + internet_search: + _type: perplexity_internet_search + max_results: 5 + max_retries: 3 + max_query_length: 2000 # queries longer than this are truncated + search_recency_filter: week # 'hour', 'day', 'week', 'month', or 'year' + country: US # ISO 3166-1 alpha-2 country code + max_tokens_per_page: 4096 ``` -Workflow Result: -['To trace only specific parts of a LangChain application, users can use the `@traceable` decorator to mark specific functions or methods as traceable. Additionally, users can configure the tracing functionality to log traces to a specific project, add metadata and tags to traces, and customize the run name and ID. Users can also use the `LangChainTracer` class to trace specific invocations or parts of their application. Furthermore, users can use the `tracing_v2_enabled` context manager to trace a specific block of code.'] + +### Using Tavily Search + +The `tavily_internet_search` tool is part of the `nvidia-nat[langchain]` package, to install the package run: +```bash +# local package install from source +uv pip install -e ".[langchain]" ``` +Prior to using the `tavily_internet_search` tool, create an account at [`tavily.com`](https://tavily.com/) and obtain an API key. 
Once obtained, set the `TAVILY_API_KEY` environment variable to the API key: +```bash +export TAVILY_API_KEY= +``` + +You can use the `tavily_internet_search` tool by updating the `functions` section of the configuration file: +```yaml +functions: + internet_search: + _type: tavily_internet_search + current_datetime: + _type: current_datetime +``` + +Then ensure the tool is included in the workflow tool list: +```yaml +workflow: + _type: react_agent + tool_names: [internet_search, current_datetime] +``` + +A sample configuration file using `tavily_internet_search` is located at `examples/documentation_guides/workflows/custom_workflow/search_config.yml` in the NeMo Agent Toolkit repository. + ### Using Exa Search The `exa_internet_search` tool is also part of the `nvidia-nat[langchain]` package. If you haven't already installed it: @@ -167,7 +203,7 @@ Prior to using the `exa_internet_search` tool, create an account at [`exa.ai`](h export EXA_API_KEY= ``` -You can use the `exa_internet_search` tool in the same way as `tavily_internet_search` by updating the `functions` section of the configuration file: +You can use the `exa_internet_search` tool in the same way as the other web search tools by updating the `functions` section of the configuration file: ```yaml functions: internet_search: diff --git a/packages/nvidia_nat_core/src/nat/embedder/perplexity_embedder.py b/packages/nvidia_nat_core/src/nat/embedder/perplexity_embedder.py new file mode 100644 index 0000000000..4f33c0df2d --- /dev/null +++ b/packages/nvidia_nat_core/src/nat/embedder/perplexity_embedder.py @@ -0,0 +1,98 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import typing + +from pydantic import AliasChoices +from pydantic import ConfigDict +from pydantic import Field + +from nat.builder.builder import Builder +from nat.builder.embedder import EmbedderProviderInfo +from nat.cli.register_workflow import register_embedder_provider +from nat.data_models.common import OptionalSecretStr +from nat.data_models.embedder import EmbedderBaseConfig +from nat.data_models.retry_mixin import RetryMixin +from nat.data_models.ssl_verification_mixin import SSLVerificationMixin + +# Supported model identifiers for the Perplexity Embeddings API. +# Standard embeddings are for independent texts (queries/documents). +# Contextualized embeddings are document-aware; chunks from the same document share context. +PerplexityEmbeddingModel = typing.Literal[ + "pplx-embed-v1-0.6b", + "pplx-embed-v1-4b", + "pplx-embed-context-v1-0.6b", + "pplx-embed-context-v1-4b", +] + + +class PerplexityEmbedderModelConfig(EmbedderBaseConfig, RetryMixin, SSLVerificationMixin, name="perplexity"): + """A Perplexity Embeddings API provider to be used with an embedder client. + + Perplexity exposes a dedicated embeddings endpoint at + ``https://api.perplexity.ai/v1/embeddings`` (standard) and + ``https://api.perplexity.ai/v1/contextualizedembeddings`` (contextualized). + Authentication uses ``PERPLEXITY_API_KEY``. 
+ + Reference: https://docs.perplexity.ai/api-reference/embeddings-post + """ + + model_config = ConfigDict(protected_namespaces=(), extra="allow") + + api_key: OptionalSecretStr = Field( + default=None, + description="Perplexity API key to interact with the embeddings endpoint. " + "Falls back to the ``PERPLEXITY_API_KEY`` environment variable when unset.", + ) + base_url: str = Field( + default="https://api.perplexity.ai/v1", + description="Base URL for the Perplexity API. The embedder appends ``/embeddings`` " + "for standard models and ``/contextualizedembeddings`` for context models.", + ) + model_name: PerplexityEmbeddingModel = Field( + default="pplx-embed-v1-0.6b", + validation_alias=AliasChoices("model_name", "model"), + serialization_alias="model", + description="Perplexity embedding model. Standard: ``pplx-embed-v1-0.6b`` (1024-dim) " + "or ``pplx-embed-v1-4b`` (2560-dim). Contextualized: ``pplx-embed-context-v1-0.6b`` " + "or ``pplx-embed-context-v1-4b``.", + ) + dimensions: int | None = Field( + default=None, + ge=128, + le=2560, + description="Matryoshka output dimensions. Range is 128–1024 for ``0.6b`` models and " + "128–2560 for ``4b`` models. Defaults to full dimensions when unset.", + ) + batch_size: int = Field( + default=64, + ge=1, + le=512, + description="Maximum number of input texts to send per request. The Perplexity API " + "accepts up to 512 inputs per call (subject to a 120,000 total-token cap).", + ) + encoding_format: typing.Literal["base64_int8", "base64_binary"] = Field( + default="base64_int8", + description="On-wire encoding for the embedding payload. 
``base64_int8`` (default) " + "returns signed int8 values; ``base64_binary`` returns 1-bit-per-dimension packed bits.", + ) + + +@register_embedder_provider(config_type=PerplexityEmbedderModelConfig) +async def perplexity_embedder_model(config: PerplexityEmbedderModelConfig, _builder: Builder): + yield EmbedderProviderInfo( + config=config, + description="A Perplexity Embeddings API model for use with an Embedder client.", + ) diff --git a/packages/nvidia_nat_core/src/nat/embedder/register.py b/packages/nvidia_nat_core/src/nat/embedder/register.py index ac82ccb1d7..ff21dbe349 100644 --- a/packages/nvidia_nat_core/src/nat/embedder/register.py +++ b/packages/nvidia_nat_core/src/nat/embedder/register.py @@ -21,3 +21,4 @@ from . import huggingface_embedder from . import nim_embedder from . import openai_embedder +from . import perplexity_embedder diff --git a/packages/nvidia_nat_core/src/nat/tool/perplexity_internet_search.py b/packages/nvidia_nat_core/src/nat/tool/perplexity_internet_search.py new file mode 100644 index 0000000000..3b437e4037 --- /dev/null +++ b/packages/nvidia_nat_core/src/nat/tool/perplexity_internet_search.py @@ -0,0 +1,213 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from __future__ import annotations + +import asyncio +import logging +from collections.abc import AsyncIterator +from typing import Literal + +import httpx +from pydantic import Field + +from nat.builder.builder import Builder +from nat.builder.framework_enum import LLMFrameworkEnum +from nat.builder.function_info import FunctionInfo +from nat.cli.register_workflow import register_function +from nat.data_models.common import SerializableSecretStr +from nat.data_models.common import get_secret_value +from nat.data_models.function import FunctionBaseConfig + +logger = logging.getLogger(__name__) + +PERPLEXITY_SEARCH_URL = "https://api.perplexity.ai/search" + + +# Internet Search tool +class PerplexityInternetSearchToolConfig(FunctionBaseConfig, name="perplexity_internet_search"): + """ + Tool that retrieves relevant contexts from web search (using Perplexity) for the given question. + Requires a PERPLEXITY_API_KEY. + """ + + max_results: int = Field(default=5, ge=1, le=20, description="Maximum number of search results to return.") + api_key: SerializableSecretStr = Field( + default_factory=lambda: SerializableSecretStr(""), description="The API key for the Perplexity service." + ) + max_retries: int = Field(default=3, ge=1, description="Maximum number of retries for the search request") + max_query_length: int = Field( + default=2000, + ge=1, + description="Maximum query length in characters. Queries exceeding this limit will be truncated.", + ) + search_recency_filter: Literal["hour", "day", "week", "month", "year"] | None = Field( + default=None, description="Filter search results by recency - 'hour', 'day', 'week', 'month', or 'year'." + ) + country: str | None = Field( + default=None, + min_length=2, + max_length=2, + pattern=r"^[A-Z]{2}$", + description="Country to filter search results by ISO 3166-1 alpha-2 code (uppercase, e.g. 
'US', 'GB').", + ) + max_tokens_per_page: int = Field( + default=4096, ge=1, description="Maximum number of tokens to retrieve per search result page." + ) + + +def _get_integration_header() -> str: + """Build the ``X-Pplx-Integration`` header value for outbound Perplexity requests. + + Returns: + str: A ``"nemo-agent-toolkit/"`` slug. Falls back to ``"unknown"`` + when the ``nvidia-nat`` package metadata cannot be resolved. + """ + from importlib import metadata + + try: + package_version = metadata.version("nvidia-nat") + except metadata.PackageNotFoundError: + package_version = "unknown" + return f"nemo-agent-toolkit/{package_version}" + + +@register_function( + config_type=PerplexityInternetSearchToolConfig, + framework_wrappers=[ + LLMFrameworkEnum.LANGCHAIN, + LLMFrameworkEnum.LLAMA_INDEX, + LLMFrameworkEnum.CREWAI, + LLMFrameworkEnum.SEMANTIC_KERNEL, + LLMFrameworkEnum.AGNO, + LLMFrameworkEnum.ADK, + LLMFrameworkEnum.STRANDS, + LLMFrameworkEnum.AUTOGEN, + ], +) +async def perplexity_internet_search( + tool_config: PerplexityInternetSearchToolConfig, builder: Builder +) -> AsyncIterator[FunctionInfo]: + """Register the Perplexity internet search tool with the NAT runtime. + + Resolves the Perplexity API key from the tool config or the + ``PERPLEXITY_API_KEY`` environment variable, then yields a ``FunctionInfo`` + wrapping an async coroutine that calls Perplexity's Search API + (``POST https://api.perplexity.ai/search``) with retry/backoff and formats + results as ```` blocks suitable for LLM consumption. + + The tool is framework-agnostic (uses raw ``httpx``) and is registered with + every framework wrapper in :class:`LLMFrameworkEnum` so it can be consumed + by any NAT-supported agent framework. + + Args: + tool_config: ``PerplexityInternetSearchToolConfig`` carrying API key, + result/retry caps, recency/country filters, and per-page token cap. + builder: NAT ``Builder`` instance (unused by this tool but required by + the ``@register_function`` contract). 
+ + Yields: + FunctionInfo: The registered async search function. + """ + import os + + api_key = get_secret_value(tool_config.api_key) if tool_config.api_key else "" + resolved_api_key = api_key or os.environ.get("PERPLEXITY_API_KEY", "") + + async def _perplexity_internet_search(question: str) -> str: + """This tool retrieves relevant contexts from web search (using Perplexity) for the given question. + + Args: + question (str): The question to be answered. + + Returns: + str: The web search results. + """ + if not resolved_api_key: + return "Web search is unavailable: 'PERPLEXITY_API_KEY' is not configured." + + # Truncate long queries to the configured limit + max_len = tool_config.max_query_length + if len(question) > max_len: + logger.warning("Perplexity query truncated from %d to %d characters", len(question), max_len) + question = question[: max_len - 3] + "..." if max_len > 3 else question[:max_len] + + request_body = { + "query": question, + "max_results": tool_config.max_results, + "max_tokens_per_page": tool_config.max_tokens_per_page, + } + if tool_config.search_recency_filter is not None: + request_body["search_recency_filter"] = tool_config.search_recency_filter + if tool_config.country is not None: + request_body["country"] = tool_config.country + + headers = { + "Authorization": f"Bearer {resolved_api_key}", + "Content-Type": "application/json", + "X-Pplx-Integration": _get_integration_header(), + } + + async with httpx.AsyncClient() as client: + for attempt in range(tool_config.max_retries): + try: + response = await client.post(PERPLEXITY_SEARCH_URL, headers=headers, json=request_body) + # Short-circuit on non-retriable client errors + if response.status_code in {401, 403, 404}: + logger.error( + "Perplexity search returned non-retriable status %d; aborting without retry", + response.status_code, + ) + return ( + f"Web search failed with non-retriable status {response.status_code}. " + f"Check that PERPLEXITY_API_KEY is set correctly." 
+ ) + response.raise_for_status() + search_response = response.json() + results = search_response.get("results") if isinstance(search_response, dict) else None + if not results: + return f"No web search results found for: {question}" + + web_search_results = "\n\n---\n\n".join( + [ + f'\n{doc.get("snippet", "")}\n' + for doc in results + ] + ) + return web_search_results or f"No web search results found for: {question}" + except httpx.HTTPStatusError as exc: + # raise_for_status() raises HTTPStatusError for 4xx/5xx not handled above (e.g. 5xx) + logger.warning( + "Perplexity search attempt %d of %d failed with status %d", + attempt + 1, + tool_config.max_retries, + exc.response.status_code, + ) + if attempt == tool_config.max_retries - 1: + return f"Web search failed after {tool_config.max_retries} attempts for: {question}" + await asyncio.sleep(2**attempt) + except Exception: + # Return a graceful message instead of raising, so the agent can + # continue reasoning without web search rather than failing entirely. + logger.exception("Perplexity search attempt %d of %d failed", attempt + 1, tool_config.max_retries) + if attempt == tool_config.max_retries - 1: + return f"Web search failed after {tool_config.max_retries} attempts for: {question}" + await asyncio.sleep(2**attempt) + return f"Web search failed after {tool_config.max_retries} attempts for: {question}" + + # Create a Generic NAT tool that can be used with any supported LLM framework + yield FunctionInfo.from_fn( + _perplexity_internet_search, + description=_perplexity_internet_search.__doc__, + ) diff --git a/packages/nvidia_nat_core/src/nat/tool/register.py b/packages/nvidia_nat_core/src/nat/tool/register.py index 0f658d45fe..9b0643de21 100644 --- a/packages/nvidia_nat_core/src/nat/tool/register.py +++ b/packages/nvidia_nat_core/src/nat/tool/register.py @@ -21,6 +21,7 @@ from . import document_search from . import github_tools from . import nvidia_rag +from . import perplexity_internet_search from . 
import retriever from . import server_tools from .code_execution import register diff --git a/packages/nvidia_nat_core/tests/nat/embedder/test_perplexity_embedder.py b/packages/nvidia_nat_core/tests/nat/embedder/test_perplexity_embedder.py new file mode 100644 index 0000000000..112315a5e3 --- /dev/null +++ b/packages/nvidia_nat_core/tests/nat/embedder/test_perplexity_embedder.py @@ -0,0 +1,93 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest +from pydantic import SecretStr +from pydantic import ValidationError + +from nat.embedder.perplexity_embedder import PerplexityEmbedderModelConfig + + +def test_defaults(): + """Default config picks the small standard model and full dimensions.""" + config = PerplexityEmbedderModelConfig() + assert config.type == "perplexity" + assert config.model_name == "pplx-embed-v1-0.6b" + assert config.base_url == "https://api.perplexity.ai/v1" + assert config.dimensions is None + assert config.batch_size == 64 + assert config.encoding_format == "base64_int8" + + +def test_accepts_supported_models(): + """All four model identifiers documented by the Perplexity API are accepted.""" + for model in ( + "pplx-embed-v1-0.6b", + "pplx-embed-v1-4b", + "pplx-embed-context-v1-0.6b", + "pplx-embed-context-v1-4b", + ): + cfg = PerplexityEmbedderModelConfig(model_name=model) + assert cfg.model_name == model + + +def test_rejects_unsupported_model(): + """Unsupported model identifiers raise a validation error.""" + with pytest.raises(ValidationError): + PerplexityEmbedderModelConfig(model_name="text-embedding-3-small") + + +def test_dimensions_bounds(): + """Matryoshka ``dimensions`` is bounded to the documented range [128, 2560].""" + PerplexityEmbedderModelConfig(model_name="pplx-embed-v1-4b", dimensions=128) + PerplexityEmbedderModelConfig(model_name="pplx-embed-v1-4b", dimensions=2560) + with pytest.raises(ValidationError): + PerplexityEmbedderModelConfig(dimensions=64) + with pytest.raises(ValidationError): + PerplexityEmbedderModelConfig(dimensions=4096) + + +def test_batch_size_bounds(): + """``batch_size`` is bounded by Perplexity's documented 512-input-per-request cap.""" + PerplexityEmbedderModelConfig(batch_size=1) + PerplexityEmbedderModelConfig(batch_size=512) + with pytest.raises(ValidationError): + PerplexityEmbedderModelConfig(batch_size=0) + with pytest.raises(ValidationError): + PerplexityEmbedderModelConfig(batch_size=1024) + + +def 
test_encoding_format_choices(): + """Only the two on-wire encoding formats supported by Perplexity are accepted.""" + PerplexityEmbedderModelConfig(encoding_format="base64_int8") + PerplexityEmbedderModelConfig(encoding_format="base64_binary") + with pytest.raises(ValidationError): + # ``float`` is *not* supported by Perplexity's embeddings endpoint. + PerplexityEmbedderModelConfig(encoding_format="float") + + +def test_api_key_secret_str(): + """``api_key`` is stored as a SecretStr-style field, not a plain string.""" + cfg = PerplexityEmbedderModelConfig(api_key="pplx-test-key") + assert cfg.api_key is not None + # Pydantic ``SecretStr`` masks the value in ``repr`` / ``str``. + assert "pplx-test-key" not in repr(cfg.api_key) + assert isinstance(cfg.api_key, SecretStr) + + +def test_model_alias_accepted(): + """The ``model`` alias is accepted in addition to ``model_name`` and round-trips.""" + cfg = PerplexityEmbedderModelConfig(model="pplx-embed-v1-4b") + assert cfg.model_name == "pplx-embed-v1-4b" diff --git a/packages/nvidia_nat_core/tests/nat/tools/test_perplexity_internet_search.py b/packages/nvidia_nat_core/tests/nat/tools/test_perplexity_internet_search.py new file mode 100644 index 0000000000..06ebfa1ffb --- /dev/null +++ b/packages/nvidia_nat_core/tests/nat/tools/test_perplexity_internet_search.py @@ -0,0 +1,218 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from unittest.mock import AsyncMock +from unittest.mock import MagicMock +from unittest.mock import patch + +import pytest +from pydantic import SecretStr +from pydantic import ValidationError + +# -- Config validation tests -- + + +@pytest.mark.parametrize( + "constructor_args", + [{}, {"api_key": ""}, {"api_key": "my_api_key"}], + ids=["default", "empty_api_key", "provided_api_key"], +) +def test_api_key_is_secret_str(constructor_args: dict): + from nat.tool.perplexity_internet_search import PerplexityInternetSearchToolConfig + + expected_api_key = constructor_args.get("api_key", "") + + config = PerplexityInternetSearchToolConfig(**constructor_args) + assert isinstance(config.api_key, SecretStr) + + api_key = config.api_key.get_secret_value() + assert api_key == expected_api_key + + +def test_default_api_key_is_unique_instance(): + from nat.tool.perplexity_internet_search import PerplexityInternetSearchToolConfig + + config1 = PerplexityInternetSearchToolConfig() + config2 = PerplexityInternetSearchToolConfig() + + assert config1.api_key is not config2.api_key + + +def test_max_retries_rejects_zero(): + from nat.tool.perplexity_internet_search import PerplexityInternetSearchToolConfig + + with pytest.raises(ValidationError): + PerplexityInternetSearchToolConfig(max_retries=0) + + +def test_max_results_rejects_zero(): + from nat.tool.perplexity_internet_search import PerplexityInternetSearchToolConfig + + with pytest.raises(ValidationError): + PerplexityInternetSearchToolConfig(max_results=0) + + +def test_max_results_rejects_above_20(): + from nat.tool.perplexity_internet_search import PerplexityInternetSearchToolConfig + + with pytest.raises(ValidationError): + PerplexityInternetSearchToolConfig(max_results=21) + + +def test_invalid_search_recency_filter_rejected(): + from nat.tool.perplexity_internet_search import 
PerplexityInternetSearchToolConfig + + with pytest.raises(ValidationError): + PerplexityInternetSearchToolConfig(search_recency_filter="invalid") + + +# -- Tool behavior tests -- + + +@pytest.fixture +def tool_config(): + from nat.tool.perplexity_internet_search import PerplexityInternetSearchToolConfig + + return PerplexityInternetSearchToolConfig(api_key="test-key", max_retries=2, max_query_length=50) + + +def _mock_response(results: list[dict] | None): + response = MagicMock() + response.raise_for_status.return_value = None + response.json.return_value = {"results": results} + return response + + +def _mock_async_client(post_mock: AsyncMock): + mock_client = MagicMock() + mock_client.post = post_mock + mock_context_manager = MagicMock() + mock_context_manager.__aenter__ = AsyncMock(return_value=mock_client) + mock_context_manager.__aexit__ = AsyncMock(return_value=None) + return mock_context_manager, mock_client + + +async def test_empty_key_returns_unavailable(): + from nat.tool.perplexity_internet_search import PerplexityInternetSearchToolConfig + from nat.tool.perplexity_internet_search import perplexity_internet_search + + config = PerplexityInternetSearchToolConfig(api_key="") + with patch.dict(os.environ, {"PERPLEXITY_API_KEY": ""}): + async with perplexity_internet_search(config, None) as func_info: + result = await func_info.single_fn("test query") + assert "unavailable" in result.lower() + assert "PERPLEXITY_API_KEY" in result + + +async def test_query_truncation(tool_config): + from nat.tool.perplexity_internet_search import perplexity_internet_search + + long_query = "a" * 100 # exceeds max_query_length=50 + post_mock = AsyncMock(return_value=_mock_response([])) + mock_context_manager, mock_client = _mock_async_client(post_mock) + + with patch( + "nat.tool.perplexity_internet_search.httpx.AsyncClient", return_value=mock_context_manager + ): + async with perplexity_internet_search(tool_config, None) as func_info: + await func_info.single_fn(long_query) 
+ + # Verify the query was truncated + call_args = mock_client.post.call_args + truncated_query = call_args.kwargs["json"]["query"] + assert len(truncated_query) <= 50 + assert truncated_query.endswith("...") + + +async def test_empty_results(tool_config): + from nat.tool.perplexity_internet_search import perplexity_internet_search + + post_mock = AsyncMock(return_value=_mock_response([])) + mock_context_manager, _ = _mock_async_client(post_mock) + + with patch( + "nat.tool.perplexity_internet_search.httpx.AsyncClient", return_value=mock_context_manager + ): + async with perplexity_internet_search(tool_config, None) as func_info: + result = await func_info.single_fn("test query") + assert "No web search results found" in result + + +async def test_retries_on_exception(tool_config): + from nat.tool.perplexity_internet_search import perplexity_internet_search + + post_mock = AsyncMock(side_effect=Exception("API error")) + mock_context_manager, mock_client = _mock_async_client(post_mock) + + with ( + patch( + "nat.tool.perplexity_internet_search.httpx.AsyncClient", + return_value=mock_context_manager, + ), + patch("nat.tool.perplexity_internet_search.asyncio.sleep", new_callable=AsyncMock), + ): + async with perplexity_internet_search(tool_config, None) as func_info: + result = await func_info.single_fn("test query") + + # Should have retried max_retries times (2) + assert mock_client.post.call_count == 2 + assert "Web search failed after 2 attempts" in result + + +async def test_attribution_header_sent(tool_config): + from nat.tool.perplexity_internet_search import perplexity_internet_search + + post_mock = AsyncMock(return_value=_mock_response([])) + mock_context_manager, mock_client = _mock_async_client(post_mock) + + with patch( + "nat.tool.perplexity_internet_search.httpx.AsyncClient", return_value=mock_context_manager + ): + async with perplexity_internet_search(tool_config, None) as func_info: + await func_info.single_fn("test query") + + call_args = 
mock_client.post.call_args + assert call_args.kwargs["headers"]["X-Pplx-Integration"].startswith("nemo-agent-toolkit/") + + +async def test_results_formatted_as_documents(tool_config): + from nat.tool.perplexity_internet_search import perplexity_internet_search + + post_mock = AsyncMock( + return_value=_mock_response( + [ + { + "url": "https://example.com/one", + "snippet": "First result.", + }, + { + "url": "https://example.com/two", + "snippet": "Second result.", + }, + ] + ) + ) + mock_context_manager, _ = _mock_async_client(post_mock) + + with patch( + "nat.tool.perplexity_internet_search.httpx.AsyncClient", return_value=mock_context_manager + ): + async with perplexity_internet_search(tool_config, None) as func_info: + result = await func_info.single_fn("test query") + + assert '' in result + assert '' in result + assert "\n\n---\n\n" in result diff --git a/packages/nvidia_nat_langchain/src/nat/plugins/langchain/embedder.py b/packages/nvidia_nat_langchain/src/nat/plugins/langchain/embedder.py index 03b99e26b3..4b135f0ffe 100644 --- a/packages/nvidia_nat_langchain/src/nat/plugins/langchain/embedder.py +++ b/packages/nvidia_nat_langchain/src/nat/plugins/langchain/embedder.py @@ -25,6 +25,7 @@ from nat.embedder.huggingface_embedder import HuggingFaceEmbedderConfig from nat.embedder.nim_embedder import NIMEmbedderModelConfig from nat.embedder.openai_embedder import OpenAIEmbedderModelConfig +from nat.embedder.perplexity_embedder import PerplexityEmbedderModelConfig from nat.llm.utils.http_client import http_clients from nat.utils.exception_handlers.automatic_retries import patch_with_retry @@ -96,6 +97,48 @@ async def openai_langchain(embedder_config: OpenAIEmbedderModelConfig, builder: yield client +@register_embedder_client(config_type=PerplexityEmbedderModelConfig, wrapper_type=LLMFrameworkEnum.LANGCHAIN) +async def perplexity_langchain(embedder_config: PerplexityEmbedderModelConfig, _builder: Builder): + """LangChain client for the Perplexity Embeddings API. 
+ + Resolves the API key from the config or ``PERPLEXITY_API_KEY``, then yields a + :class:`~nat.plugins.langchain.perplexity_embeddings_client.PerplexityEmbeddings` + instance configured with the requested model, dimensions, and encoding format. + Outbound requests carry an ``X-Pplx-Integration: nemo-agent-toolkit/`` + attribution header. + """ + import os + + from nat.plugins.langchain.perplexity_embeddings_client import PerplexityEmbeddings + + raw_key = get_secret_value(embedder_config.api_key) if embedder_config.api_key else "" + resolved_key = raw_key or os.environ.get("PERPLEXITY_API_KEY", "") + if not resolved_key: + raise ValueError( + "Perplexity embedder requires a non-empty API key. Set ``api_key`` in the " + "workflow config or export ``PERPLEXITY_API_KEY``." + ) + + client = PerplexityEmbeddings( + api_key=resolved_key, + base_url=embedder_config.base_url, + model=embedder_config.model_name, + dimensions=embedder_config.dimensions, + encoding_format=embedder_config.encoding_format, + batch_size=embedder_config.batch_size, + max_retries=getattr(embedder_config, "num_retries", 3), + verify_ssl=getattr(embedder_config, "verify_ssl", True), + ) + + if isinstance(embedder_config, RetryMixin): + client = patch_with_retry(client, + retries=embedder_config.num_retries, + retry_codes=embedder_config.retry_on_status_codes, + retry_on_messages=embedder_config.retry_on_errors) + + yield client + + @register_embedder_client(config_type=HuggingFaceEmbedderConfig, wrapper_type=LLMFrameworkEnum.LANGCHAIN) async def huggingface_langchain(embedder_config: HuggingFaceEmbedderConfig, _builder: Builder) -> AsyncIterator[Any]: """LangChain client for HuggingFace embedder - local or remote based on endpoint_url.""" diff --git a/packages/nvidia_nat_langchain/src/nat/plugins/langchain/perplexity_embeddings_client.py b/packages/nvidia_nat_langchain/src/nat/plugins/langchain/perplexity_embeddings_client.py new file mode 100644 index 0000000000..7e51d85f99 --- /dev/null +++ 
b/packages/nvidia_nat_langchain/src/nat/plugins/langchain/perplexity_embeddings_client.py @@ -0,0 +1,267 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""LangChain ``Embeddings`` client for the Perplexity Embeddings API. + +Perplexity's embeddings endpoint (``POST /v1/embeddings``) returns base64-encoded +quantized values (``base64_int8`` or ``base64_binary``) rather than the JSON float +arrays returned by OpenAI-compatible providers. This module provides a thin +LangChain :class:`~langchain_core.embeddings.Embeddings` implementation that +performs the decoding so the resulting vectors plug into the standard NAT +retriever/RAG stack. +""" +from __future__ import annotations + +import base64 +import logging +from collections.abc import Iterable + +import httpx +import numpy as np +from langchain_core.embeddings import Embeddings + +logger = logging.getLogger(__name__) + + +def _decode_int8(payload: str) -> list[float]: + """Decode a ``base64_int8`` embedding into a float32 list. + + Args: + payload: Base64-encoded signed int8 buffer returned by Perplexity. + + Returns: + The decoded vector as a list of Python floats. 
+ """ + return np.frombuffer(base64.b64decode(payload), dtype=np.int8).astype(np.float32).tolist() + + +def _decode_binary(payload: str) -> list[float]: + """Decode a ``base64_binary`` embedding into a 0/1 float list. + + Args: + payload: Base64-encoded packed bits (LSB first, 1 bit per dimension). + + Returns: + The unpacked bit vector as a list of Python floats (0.0 or 1.0). + """ + packed = np.frombuffer(base64.b64decode(payload), dtype=np.uint8) + return np.unpackbits(packed, bitorder="little").astype(np.float32).tolist() + + +class PerplexityEmbeddings(Embeddings): + """LangChain ``Embeddings`` client for the Perplexity Embeddings API. + + The client batches inputs (``batch_size`` per request, default 64), decodes the + base64 payload locally, and forwards the ``X-Pplx-Integration`` attribution + header so Perplexity can identify NeMo Agent Toolkit traffic. + + Args: + api_key: Perplexity API key. Required. + base_url: Base URL for the Perplexity API. Defaults to ``https://api.perplexity.ai/v1``. + model: Embedding model identifier (e.g. ``pplx-embed-v1-0.6b``). + dimensions: Optional Matryoshka output dimension. + encoding_format: ``base64_int8`` (default) or ``base64_binary``. + batch_size: Maximum inputs per request (1–512). Defaults to 64. + max_retries: Number of retry attempts on transient failures. Defaults to 3. + verify_ssl: Whether to verify TLS certificates. Defaults to True. + integration_header: Optional ``X-Pplx-Integration`` header value. Defaults + to ``"nemo-agent-toolkit/"`` resolved at instantiation time. + """ + + def __init__( + self, + api_key: str, + *, + base_url: str = "https://api.perplexity.ai/v1", + model: str = "pplx-embed-v1-0.6b", + dimensions: int | None = None, + encoding_format: str = "base64_int8", + batch_size: int = 64, + max_retries: int = 3, + verify_ssl: bool = True, + integration_header: str | None = None, + ) -> None: + if not api_key: + raise ValueError( + "PerplexityEmbeddings requires a non-empty api_key. 
" + "Set the PERPLEXITY_API_KEY environment variable or pass api_key explicitly." + ) + if encoding_format not in ("base64_int8", "base64_binary"): + raise ValueError( + f"encoding_format must be 'base64_int8' or 'base64_binary', got {encoding_format!r}." + ) + + self._api_key = api_key + self._base_url = base_url.rstrip("/") + self._model = model + self._dimensions = dimensions + self._encoding_format = encoding_format + self._batch_size = max(1, min(int(batch_size), 512)) + self._max_retries = max(1, int(max_retries)) + self._verify_ssl = verify_ssl + self._integration_header = integration_header or _default_integration_header() + + # ------------------------------------------------------------------ + # LangChain interface + # ------------------------------------------------------------------ + def embed_documents(self, texts: list[str]) -> list[list[float]]: # type: ignore[override] + return self._embed(texts) + + def embed_query(self, text: str) -> list[float]: # type: ignore[override] + results = self._embed([text]) + return results[0] + + async def aembed_documents(self, texts: list[str]) -> list[list[float]]: # type: ignore[override] + return await self._aembed(texts) + + async def aembed_query(self, text: str) -> list[float]: # type: ignore[override] + results = await self._aembed([text]) + return results[0] + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + def _endpoint(self) -> str: + return f"{self._base_url}/embeddings" + + def _headers(self) -> dict[str, str]: + return { + "Authorization": f"Bearer {self._api_key}", + "Content-Type": "application/json", + "X-Pplx-Integration": self._integration_header, + } + + def _build_body(self, inputs: list[str]) -> dict: + body: dict = { + "input": inputs, + "model": self._model, + "encoding_format": self._encoding_format, + } + if self._dimensions is not None: + body["dimensions"] = 
self._dimensions + return body + + def _decode_response(self, payload: dict) -> list[list[float]]: + data: Iterable[dict] = payload.get("data") or [] + decoder = _decode_int8 if self._encoding_format == "base64_int8" else _decode_binary + # Preserve input order via the ``index`` field, which the API echoes back. + # Python's sort is stable, so items that share an ``index`` keep their arrival order. + items = sorted(data, key=lambda item: item.get("index", 0)) + decoded = [decoder(item.get("embedding", "")) for item in items] + return decoded + + def _batches(self, texts: list[str]) -> Iterable[list[str]]: + for start in range(0, len(texts), self._batch_size): + yield texts[start:start + self._batch_size] + + def _embed(self, texts: list[str]) -> list[list[float]]: + if not texts: + return [] + embeddings: list[list[float]] = [] + with httpx.Client(verify=self._verify_ssl, timeout=60.0) as client: + for batch in self._batches(texts): + payload = self._post_with_retry_sync(client, batch) + embeddings.extend(self._decode_response(payload)) + return embeddings + + async def _aembed(self, texts: list[str]) -> list[list[float]]: + if not texts: + return [] + embeddings: list[list[float]] = [] + async with httpx.AsyncClient(verify=self._verify_ssl, timeout=60.0) as client: + for batch in self._batches(texts): + payload = await self._post_with_retry_async(client, batch) + embeddings.extend(self._decode_response(payload)) + return embeddings + + def _post_with_retry_sync(self, client: httpx.Client, batch: list[str]) -> dict: + import time + + last_error: Exception | None = None + for attempt in range(self._max_retries): + try: + response = client.post(self._endpoint(), headers=self._headers(), json=self._build_body(batch)) + if response.status_code in {401, 403, 404}: + response.raise_for_status() + response.raise_for_status() + return response.json() + except httpx.HTTPStatusError as exc: + last_error = exc + if exc.response.status_code in {401, 403, 404} or attempt == self._max_retries - 1: +
raise + logger.warning( + "Perplexity embeddings attempt %d/%d failed with status %d", + attempt + 1, + self._max_retries, + exc.response.status_code, + ) + except httpx.RequestError as exc: + last_error = exc + if attempt == self._max_retries - 1: + raise + logger.warning( + "Perplexity embeddings attempt %d/%d failed: %s", + attempt + 1, + self._max_retries, + exc, + ) + time.sleep(2**attempt) + raise RuntimeError("Perplexity embeddings request failed") from last_error + + async def _post_with_retry_async(self, client: httpx.AsyncClient, batch: list[str]) -> dict: + import asyncio + + last_error: Exception | None = None + for attempt in range(self._max_retries): + try: + response = await client.post( + self._endpoint(), headers=self._headers(), json=self._build_body(batch) + ) + if response.status_code in {401, 403, 404}: + response.raise_for_status() + response.raise_for_status() + return response.json() + except httpx.HTTPStatusError as exc: + last_error = exc + if exc.response.status_code in {401, 403, 404} or attempt == self._max_retries - 1: + raise + logger.warning( + "Perplexity embeddings attempt %d/%d failed with status %d", + attempt + 1, + self._max_retries, + exc.response.status_code, + ) + except httpx.RequestError as exc: + last_error = exc + if attempt == self._max_retries - 1: + raise + logger.warning( + "Perplexity embeddings attempt %d/%d failed: %s", + attempt + 1, + self._max_retries, + exc, + ) + await asyncio.sleep(2**attempt) + raise RuntimeError("Perplexity embeddings request failed") from last_error + + +def _default_integration_header() -> str: + """Return the default ``X-Pplx-Integration`` header value for outbound requests.""" + from importlib import metadata + + try: + package_version = metadata.version("nvidia-nat") + except metadata.PackageNotFoundError: + package_version = "unknown" + return f"nemo-agent-toolkit/{package_version}" diff --git a/packages/nvidia_nat_langchain/tests/test_perplexity_embeddings.py 
b/packages/nvidia_nat_langchain/tests/test_perplexity_embeddings.py new file mode 100644 index 0000000000..6031eab70e --- /dev/null +++ b/packages/nvidia_nat_langchain/tests/test_perplexity_embeddings.py @@ -0,0 +1,209 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import base64 +from unittest.mock import AsyncMock +from unittest.mock import MagicMock +from unittest.mock import patch + +import httpx +import numpy as np +import pytest + +from nat.embedder.perplexity_embedder import PerplexityEmbedderModelConfig +from nat.plugins.langchain.embedder import perplexity_langchain +from nat.plugins.langchain.perplexity_embeddings_client import PerplexityEmbeddings +from nat.plugins.langchain.perplexity_embeddings_client import _decode_binary +from nat.plugins.langchain.perplexity_embeddings_client import _decode_int8 +from nat.plugins.langchain.perplexity_embeddings_client import _default_integration_header + + +# --------------------------------------------------------------------------- +# Decoding helpers +# --------------------------------------------------------------------------- + + +def _encode_int8(values: list[int]) -> str: + return base64.b64encode(np.array(values, dtype=np.int8).tobytes()).decode() + + +def _make_response(status_code: int, payload: dict) -> httpx.Response: + """Build an ``httpx.Response`` with a bound 
dummy request so ``raise_for_status`` works in tests.""" + return httpx.Response( + status_code, + json=payload, + request=httpx.Request("POST", "https://api.perplexity.ai/v1/embeddings"), + ) + + +def _encode_binary(bits: list[int]) -> str: + packed = np.packbits(np.array(bits, dtype=np.uint8), bitorder="little") + return base64.b64encode(packed.tobytes()).decode() + + +def test_decode_int8_round_trip(): + payload = _encode_int8([-128, -1, 0, 1, 127]) + decoded = _decode_int8(payload) + assert decoded == [-128.0, -1.0, 0.0, 1.0, 127.0] + + +def test_decode_binary_round_trip(): + bits = [1, 0, 1, 1, 0, 0, 1, 0] + payload = _encode_binary(bits) + decoded = _decode_binary(payload) + assert decoded == [float(b) for b in bits] + + +def test_default_integration_header_uses_nemo_slug(): + header = _default_integration_header() + assert header.startswith("nemo-agent-toolkit/") + + +# --------------------------------------------------------------------------- +# PerplexityEmbeddings client +# --------------------------------------------------------------------------- + + +class TestPerplexityEmbeddingsClient: + + def test_requires_api_key(self): + with pytest.raises(ValueError): + PerplexityEmbeddings(api_key="") + + def test_rejects_unsupported_encoding(self): + with pytest.raises(ValueError): + PerplexityEmbeddings(api_key="pplx-x", encoding_format="float") + + def test_batches_inputs_and_decodes_int8(self): + """The client should split inputs into batches of ``batch_size`` and decode int8 payloads.""" + client = PerplexityEmbeddings(api_key="pplx-x", batch_size=2) + + responses = [ + _make_response( + 200, + { + "object": "list", + "data": [ + {"object": "embedding", "index": 0, "embedding": _encode_int8([1, 2])}, + {"object": "embedding", "index": 1, "embedding": _encode_int8([3, 4])}, + ], + "model": "pplx-embed-v1-0.6b", + }, + ), + _make_response( + 200, + { + "object": "list", + "data": [{"object": "embedding", "index": 0, "embedding": _encode_int8([5, 6])}], + 
"model": "pplx-embed-v1-0.6b", + }, + ), + ] + + with patch("httpx.Client") as mock_client_cls: + instance = MagicMock() + instance.__enter__.return_value = instance + instance.post.side_effect = responses + mock_client_cls.return_value = instance + + vectors = client.embed_documents(["a", "b", "c"]) + + assert vectors == [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]] + # Two batches of size 2 (last one is size 1). + assert instance.post.call_count == 2 + + def test_request_body_includes_dimensions_and_attribution_header(self): + client = PerplexityEmbeddings(api_key="pplx-x", dimensions=256, integration_header="test-suite/1.2.3") + + with patch("httpx.Client") as mock_client_cls: + instance = MagicMock() + instance.__enter__.return_value = instance + instance.post.return_value = _make_response( + 200, + { + "object": "list", + "data": [{"object": "embedding", "index": 0, "embedding": _encode_int8([0])}], + "model": "pplx-embed-v1-0.6b", + }, + ) + mock_client_cls.return_value = instance + + client.embed_query("hello") + + call = instance.post.call_args + assert call.kwargs["json"]["dimensions"] == 256 + assert call.kwargs["json"]["encoding_format"] == "base64_int8" + assert call.kwargs["json"]["model"] == "pplx-embed-v1-0.6b" + assert call.kwargs["headers"]["X-Pplx-Integration"] == "test-suite/1.2.3" + assert call.kwargs["headers"]["Authorization"] == "Bearer pplx-x" + assert call.args[0].endswith("/v1/embeddings") + + def test_non_retriable_status_raises(self): + client = PerplexityEmbeddings(api_key="pplx-x", max_retries=3) + with patch("httpx.Client") as mock_client_cls: + instance = MagicMock() + instance.__enter__.return_value = instance + instance.post.return_value = _make_response(401, {"error": "unauthorized"}) + mock_client_cls.return_value = instance + + with pytest.raises(httpx.HTTPStatusError): + client.embed_query("hello") + + # 401 short-circuits — exactly one POST attempt. 
+ assert instance.post.call_count == 1 + + async def test_async_embed_query(self): + client = PerplexityEmbeddings(api_key="pplx-x") + + async_response = _make_response( + 200, + { + "object": "list", + "data": [{"object": "embedding", "index": 0, "embedding": _encode_int8([7, 8, 9])}], + "model": "pplx-embed-v1-0.6b", + }, + ) + with patch("httpx.AsyncClient") as mock_client_cls: + instance = MagicMock() + instance.__aenter__ = AsyncMock(return_value=instance) + instance.__aexit__ = AsyncMock(return_value=None) + instance.post = AsyncMock(return_value=async_response) + mock_client_cls.return_value = instance + + vector = await client.aembed_query("hi") + + assert vector == [7.0, 8.0, 9.0] + + +# --------------------------------------------------------------------------- +# perplexity_langchain registration +# --------------------------------------------------------------------------- + + +class TestPerplexityLangChainRegistration: + + async def test_requires_api_key_in_env_or_config(self, monkeypatch, mock_builder): + monkeypatch.delenv("PERPLEXITY_API_KEY", raising=False) + cfg = PerplexityEmbedderModelConfig() + with pytest.raises(ValueError, match="non-empty API key"): + async with perplexity_langchain(cfg, mock_builder): + pass + + async def test_uses_environment_api_key(self, monkeypatch, mock_builder): + monkeypatch.setenv("PERPLEXITY_API_KEY", "env-key") + cfg = PerplexityEmbedderModelConfig() + async with perplexity_langchain(cfg, mock_builder) as client: + assert hasattr(client, "embed_documents") + assert hasattr(client, "embed_query") diff --git a/packages/nvidia_nat_llama_index/src/nat/plugins/llama_index/embedder.py b/packages/nvidia_nat_llama_index/src/nat/plugins/llama_index/embedder.py index b15c9ce017..f1cd07c756 100644 --- a/packages/nvidia_nat_llama_index/src/nat/plugins/llama_index/embedder.py +++ b/packages/nvidia_nat_llama_index/src/nat/plugins/llama_index/embedder.py @@ -17,9 +17,11 @@ from nat.builder.framework_enum import LLMFrameworkEnum 
from nat.cli.register_workflow import register_embedder_client from nat.data_models.retry_mixin import RetryMixin +from nat.data_models.common import get_secret_value from nat.embedder.azure_openai_embedder import AzureOpenAIEmbedderModelConfig from nat.embedder.nim_embedder import NIMEmbedderModelConfig from nat.embedder.openai_embedder import OpenAIEmbedderModelConfig +from nat.embedder.perplexity_embedder import PerplexityEmbedderModelConfig from nat.llm.utils.http_client import http_clients from nat.utils.exception_handlers.automatic_retries import patch_with_retry @@ -95,3 +97,45 @@ async def openai_llama_index(embedder_config: OpenAIEmbedderModelConfig, _builde retry_on_messages=embedder_config.retry_on_errors) yield client + + +@register_embedder_client(config_type=PerplexityEmbedderModelConfig, wrapper_type=LLMFrameworkEnum.LLAMA_INDEX) +async def perplexity_llama_index(embedder_config: PerplexityEmbedderModelConfig, _builder: Builder): + """LlamaIndex client for the Perplexity Embeddings API. + + Builds a ``PerplexityLlamaIndexEmbedding`` adapter around the shared + ``PerplexityEmbeddings`` HTTP client. Outbound requests carry the + ``X-Pplx-Integration: nemo-agent-toolkit/`` attribution header. + """ + import os + + from nat.plugins.langchain.perplexity_embeddings_client import PerplexityEmbeddings + from nat.plugins.llama_index.perplexity_embeddings_client import PerplexityLlamaIndexEmbedding + + raw_key = get_secret_value(embedder_config.api_key) if embedder_config.api_key else "" + resolved_key = raw_key or os.environ.get("PERPLEXITY_API_KEY", "") + if not resolved_key: + raise ValueError( + "Perplexity embedder requires a non-empty API key. Set ``api_key`` in the " + "workflow config or export ``PERPLEXITY_API_KEY``." 
+ ) + + inner = PerplexityEmbeddings( + api_key=resolved_key, + base_url=embedder_config.base_url, + model=embedder_config.model_name, + dimensions=embedder_config.dimensions, + encoding_format=embedder_config.encoding_format, + batch_size=embedder_config.batch_size, + max_retries=getattr(embedder_config, "num_retries", 3), + verify_ssl=getattr(embedder_config, "verify_ssl", True), + ) + client = PerplexityLlamaIndexEmbedding(client=inner) + + if isinstance(embedder_config, RetryMixin): + client = patch_with_retry(client, + retries=embedder_config.num_retries, + retry_codes=embedder_config.retry_on_status_codes, + retry_on_messages=embedder_config.retry_on_errors) + + yield client diff --git a/packages/nvidia_nat_llama_index/src/nat/plugins/llama_index/perplexity_embeddings_client.py b/packages/nvidia_nat_llama_index/src/nat/plugins/llama_index/perplexity_embeddings_client.py new file mode 100644 index 0000000000..0f72c424ec --- /dev/null +++ b/packages/nvidia_nat_llama_index/src/nat/plugins/llama_index/perplexity_embeddings_client.py @@ -0,0 +1,67 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""LlamaIndex ``BaseEmbedding`` client for the Perplexity Embeddings API. 
+ +Wraps the framework-agnostic ``PerplexityEmbeddings`` LangChain client to satisfy +LlamaIndex's :class:`~llama_index.core.embeddings.BaseEmbedding` interface so the +same NAT config can be consumed by LlamaIndex-powered retrievers. +""" +from __future__ import annotations + +from typing import Any + +from llama_index.core.bridge.pydantic import PrivateAttr +from llama_index.core.embeddings import BaseEmbedding + +from nat.plugins.langchain.perplexity_embeddings_client import PerplexityEmbeddings + + +class PerplexityLlamaIndexEmbedding(BaseEmbedding): + """LlamaIndex embedding adapter for the Perplexity Embeddings API.""" + + _client: PerplexityEmbeddings = PrivateAttr() + + def __init__(self, *, client: PerplexityEmbeddings, **data: Any) -> None: + data.setdefault("model_name", getattr(client, "_model", "pplx-embed-v1-0.6b")) + super().__init__(**data) + self._client = client + + @classmethod + def class_name(cls) -> str: + return "PerplexityLlamaIndexEmbedding" + + # ------------------------------------------------------------------ + # Sync interface + # ------------------------------------------------------------------ + def _get_query_embedding(self, query: str) -> list[float]: + return self._client.embed_query(query) + + def _get_text_embedding(self, text: str) -> list[float]: + return self._client.embed_query(text) + + def _get_text_embeddings(self, texts: list[str]) -> list[list[float]]: + return self._client.embed_documents(texts) + + # ------------------------------------------------------------------ + # Async interface + # ------------------------------------------------------------------ + async def _aget_query_embedding(self, query: str) -> list[float]: + return await self._client.aembed_query(query) + + async def _aget_text_embedding(self, text: str) -> list[float]: + return await self._client.aembed_query(text) + + async def _aget_text_embeddings(self, texts: list[str]) -> list[list[float]]: + return await self._client.aembed_documents(texts) 
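The adapter above hands `_get_text_embeddings` straight to `embed_documents`, and the shared client then splits the inputs into requests of at most `batch_size` items. A minimal standalone sketch of that splitting, plus one way the items of a single response could be put back in `index` order, might look like the following. The names `batches` and `order_by_index` are hypothetical and exist only for this illustration; they are not part of the toolkit.

```python
from collections.abc import Iterator


def batches(texts: list[str], batch_size: int) -> Iterator[list[str]]:
    """Yield consecutive slices of at most ``batch_size`` items."""
    # Same clamp the client applies in its constructor: 1 <= size <= 512.
    size = max(1, min(int(batch_size), 512))
    for start in range(0, len(texts), size):
        yield texts[start:start + size]


def order_by_index(items: list[dict]) -> list[dict]:
    """Hypothetical helper: sort one response's items by their ``index`` field."""
    # Assumption: every item carries an integer ``index``; missing ones default to 0.
    return sorted(items, key=lambda item: item.get("index", 0))


if __name__ == "__main__":
    for batch in batches(["a", "b", "c"], batch_size=2):
        print(batch)
    print(order_by_index([{"index": 1, "embedding": "y"}, {"index": 0, "embedding": "x"}]))
```

Because each request restarts `index` at zero, the sort only makes sense within one response; results from successive batches are simply appended in request order.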
diff --git a/packages/nvidia_nat_llama_index/tests/test_perplexity_embedder_llama_index.py b/packages/nvidia_nat_llama_index/tests/test_perplexity_embedder_llama_index.py new file mode 100644 index 0000000000..bd5ad23aa9 --- /dev/null +++ b/packages/nvidia_nat_llama_index/tests/test_perplexity_embedder_llama_index.py @@ -0,0 +1,75 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import AsyncMock +from unittest.mock import MagicMock + +import pytest + +from nat.embedder.perplexity_embedder import PerplexityEmbedderModelConfig +from nat.plugins.llama_index.embedder import perplexity_llama_index + + +class TestPerplexityLlamaIndexRegistration: + + async def test_requires_api_key(self, monkeypatch, mock_builder): + monkeypatch.delenv("PERPLEXITY_API_KEY", raising=False) + cfg = PerplexityEmbedderModelConfig() + with pytest.raises(ValueError, match="non-empty API key"): + async with perplexity_llama_index(cfg, mock_builder): + pass + + async def test_yields_llama_index_embedding(self, monkeypatch, mock_builder): + monkeypatch.setenv("PERPLEXITY_API_KEY", "env-key") + cfg = PerplexityEmbedderModelConfig() + async with perplexity_llama_index(cfg, mock_builder) as client: + # LlamaIndex's ``BaseEmbedding`` exposes ``get_text_embedding``. 
+ assert hasattr(client, "get_text_embedding") + assert hasattr(client, "get_query_embedding") + + +class TestPerplexityLlamaIndexAdapter: + + def test_get_text_embedding_delegates_to_client(self): + from nat.plugins.langchain.perplexity_embeddings_client import PerplexityEmbeddings + from nat.plugins.llama_index.perplexity_embeddings_client import PerplexityLlamaIndexEmbedding + + inner = MagicMock(spec=PerplexityEmbeddings) + inner._model = "pplx-embed-v1-0.6b" + inner.embed_query.return_value = [0.1, 0.2, 0.3] + inner.embed_documents.return_value = [[0.1], [0.2]] + inner.aembed_query = AsyncMock(return_value=[0.5]) + inner.aembed_documents = AsyncMock(return_value=[[0.6]]) + + adapter = PerplexityLlamaIndexEmbedding(client=inner) + + assert adapter._get_query_embedding("q") == [0.1, 0.2, 0.3] + assert adapter._get_text_embedding("t") == [0.1, 0.2, 0.3] + assert adapter._get_text_embeddings(["a", "b"]) == [[0.1], [0.2]] + + async def test_async_methods_delegate(self): + from nat.plugins.langchain.perplexity_embeddings_client import PerplexityEmbeddings + from nat.plugins.llama_index.perplexity_embeddings_client import PerplexityLlamaIndexEmbedding + + inner = MagicMock(spec=PerplexityEmbeddings) + inner._model = "pplx-embed-v1-0.6b" + inner.aembed_query = AsyncMock(return_value=[0.5]) + inner.aembed_documents = AsyncMock(return_value=[[0.6]]) + + adapter = PerplexityLlamaIndexEmbedding(client=inner) + + assert await adapter._aget_query_embedding("q") == [0.5] + assert await adapter._aget_text_embedding("t") == [0.5] + assert await adapter._aget_text_embeddings(["x"]) == [[0.6]]
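For readers without NumPy at hand, the two on-wire formats handled by `_decode_int8` and `_decode_binary` in the LangChain client can be illustrated with the standard library alone. This is only a sketch, assuming the payload layout described in those docstrings (signed int8 bytes, or packed bits with the least significant bit of each byte first); `decode_int8_stdlib` and `decode_binary_stdlib` are hypothetical names, not toolkit APIs.

```python
import base64
from array import array


def decode_int8_stdlib(payload: str) -> list[float]:
    """Decode base64-encoded signed int8 bytes into a list of floats."""
    raw = base64.b64decode(payload)
    # Type code "b" is a signed 1-byte integer, matching the int8 layout.
    return [float(value) for value in array("b", raw)]


def decode_binary_stdlib(payload: str) -> list[float]:
    """Unpack base64-encoded bits, least significant bit of each byte first."""
    raw = base64.b64decode(payload)
    return [float((byte >> bit) & 1) for byte in raw for bit in range(8)]


# Build sample payloads the same way the unit tests do, then decode them.
int8_payload = base64.b64encode(array("b", [-128, -1, 0, 1, 127]).tobytes()).decode()
binary_payload = base64.b64encode(bytes([0b01001101])).decode()

if __name__ == "__main__":
    print(decode_int8_stdlib(int8_payload))
    print(decode_binary_stdlib(binary_payload))
```

A `base64_binary` vector decoded this way always has a length that is a multiple of eight, since the bits arrive packed into whole bytes.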