From 9b8f95fcb79db8ceff21e8673d4f1594784fdb8f Mon Sep 17 00:00:00 2001 From: James Date: Thu, 30 Apr 2026 13:25:06 +0000 Subject: [PATCH 1/7] Add Perplexity Search API as internet search tool Signed-off-by: PSI Bot Signed-off-by: James --- .../styles/config/vocabularies/nat/accept.txt | 1 + .../tutorials/add-tools-to-a-workflow.md | 44 ++++- .../tools/perplexity_internet_search.py | 142 ++++++++++++++ .../nat/plugins/langchain/tools/register.py | 1 + .../tests/test_perplexity_internet_search.py | 178 ++++++++++++++++++ 5 files changed, 365 insertions(+), 1 deletion(-) create mode 100644 packages/nvidia_nat_langchain/src/nat/plugins/langchain/tools/perplexity_internet_search.py create mode 100644 packages/nvidia_nat_langchain/tests/test_perplexity_internet_search.py diff --git a/ci/vale/styles/config/vocabularies/nat/accept.txt b/ci/vale/styles/config/vocabularies/nat/accept.txt index b38fae5e93..a834509a6d 100644 --- a/ci/vale/styles/config/vocabularies/nat/accept.txt +++ b/ci/vale/styles/config/vocabularies/nat/accept.txt @@ -146,6 +146,7 @@ Pareto Patronus PCIe PDF(s?) +Perplexity [Pp]luggable [Pp]ostprocess [Pp]ostprocessing diff --git a/docs/source/get-started/tutorials/add-tools-to-a-workflow.md b/docs/source/get-started/tutorials/add-tools-to-a-workflow.md index 3f1c477ed4..969b7fa4cc 100644 --- a/docs/source/get-started/tutorials/add-tools-to-a-workflow.md +++ b/docs/source/get-started/tutorials/add-tools-to-a-workflow.md @@ -109,7 +109,7 @@ Workflow Result: ``` ## Alternate Method Using a Web Search Tool -Adding individual web pages to a workflow can be cumbersome, especially when dealing with multiple web pages. An alternative method is to use a web search tool. NeMo Agent Toolkit provides two web search tools: `tavily_internet_search` which utilizes the [Tavily Search API](https://tavily.com/), and `exa_internet_search` which utilizes the [Exa Search API](https://exa.ai/). 
+Adding individual web pages to a workflow can be cumbersome, especially when dealing with multiple web pages. An alternative method is to use a web search tool. NeMo Agent Toolkit provides web search tools including: `tavily_internet_search` which utilizes the [Tavily Search API](https://tavily.com/), `exa_internet_search` which utilizes the [Exa Search API](https://exa.ai/), and `perplexity_internet_search` which utilizes the [Perplexity Search API](https://docs.perplexity.ai/api-reference/search-post). ### Using Tavily Search @@ -196,3 +196,45 @@ workflow: _type: react_agent tool_names: [internet_search, current_datetime] ``` + +### Using Perplexity Search + +The `perplexity_internet_search` tool is also part of the `nvidia-nat[langchain]` package. If you haven't already installed it: +```bash +# local package install from source +uv pip install -e ".[langchain]" +``` + +Prior to using the `perplexity_internet_search` tool, create a Perplexity account and obtain an API key from the [API key page](https://www.perplexity.ai/account/api/keys). 
Once obtained, set the `PERPLEXITY_API_KEY` environment variable to the API key: +```bash +export PERPLEXITY_API_KEY=<YOUR_PERPLEXITY_API_KEY> +``` + +You can use the `perplexity_internet_search` tool in the same way as the other web search tools by updating the `functions` section of the configuration file: +```yaml +functions: + internet_search: + _type: perplexity_internet_search + current_datetime: + _type: current_datetime +``` + +The `perplexity_internet_search` tool supports additional configuration options: +```yaml +functions: + internet_search: + _type: perplexity_internet_search + max_results: 5 + max_retries: 3 + max_query_length: 2000 # queries longer than this are truncated + search_recency_filter: week # 'hour', 'day', 'week', 'month', or 'year' + country: US # ISO 3166-1 alpha-2 country code + max_tokens_per_page: 4096 +``` + +Then ensure the tool is included in the workflow tool list: +```yaml +workflow: + _type: react_agent + tool_names: [internet_search, current_datetime] +``` diff --git a/packages/nvidia_nat_langchain/src/nat/plugins/langchain/tools/perplexity_internet_search.py b/packages/nvidia_nat_langchain/src/nat/plugins/langchain/tools/perplexity_internet_search.py new file mode 100644 index 0000000000..a3f13e7445 --- /dev/null +++ b/packages/nvidia_nat_langchain/src/nat/plugins/langchain/tools/perplexity_internet_search.py @@ -0,0 +1,142 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +import asyncio +import logging +import os +from importlib import metadata +from typing import Literal + +import httpx +from nat.builder.builder import Builder +from nat.builder.framework_enum import LLMFrameworkEnum +from nat.builder.function_info import FunctionInfo +from nat.cli.register_workflow import register_function +from nat.data_models.common import SerializableSecretStr, get_secret_value +from nat.data_models.function import FunctionBaseConfig +from pydantic import Field + +logger = logging.getLogger(__name__) + +PERPLEXITY_SEARCH_URL = "https://api.perplexity.ai/search" + + +# Internet Search tool +class PerplexityInternetSearchToolConfig(FunctionBaseConfig, name="perplexity_internet_search"): + """ + Tool that retrieves relevant contexts from web search (using Perplexity) for the given question. + Requires a PERPLEXITY_API_KEY. + """ + max_results: int = Field(default=5, ge=1, le=20, description="Maximum number of search results to return.") + api_key: SerializableSecretStr = Field(default_factory=lambda: SerializableSecretStr(""), + description="The API key for the Perplexity service.") + max_retries: int = Field(default=3, ge=1, description="Maximum number of retries for the search request") + max_query_length: int = Field( + default=2000, + ge=1, + description="Maximum query length in characters. 
Queries exceeding this limit will be truncated.") + search_recency_filter: Literal["hour", "day", "week", "month", "year"] | None = Field( + default=None, description="Filter search results by recency - 'hour', 'day', 'week', 'month', or 'year'.") + country: str | None = Field(default=None, description="Country to filter search results by ISO 3166-1 alpha-2 code.") + max_tokens_per_page: int = Field( + default=4096, ge=1, description="Maximum number of tokens to retrieve per search result page.") + + +def _get_integration_header() -> str: + try: + package_version = metadata.version("nvidia-nat-langchain") + except metadata.PackageNotFoundError: + package_version = "unknown" + return f"nemo-agent-toolkit/{package_version}" + + +@register_function(config_type=PerplexityInternetSearchToolConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN]) +async def perplexity_internet_search(tool_config: PerplexityInternetSearchToolConfig, builder: Builder): + api_key = get_secret_value(tool_config.api_key) if tool_config.api_key else "" + resolved_api_key = api_key or os.environ.get("PERPLEXITY_API_KEY", "") + + async def _perplexity_internet_search(question: str) -> str: + """This tool retrieves relevant contexts from web search (using Perplexity) for the given question. + + Args: + question (str): The question to be answered. + + Returns: + str: The web search results. + """ + if not resolved_api_key: + return "Web search is unavailable: 'PERPLEXITY_API_KEY' is not configured." + + # Truncate long queries to the configured limit + max_len = tool_config.max_query_length + if len(question) > max_len: + logger.warning("Perplexity query truncated from %d to %d characters", len(question), max_len) + question = question[:max_len - 3] + "..." 
if max_len > 3 else question[:max_len] + + request_body = { + "query": question, + "max_results": tool_config.max_results, + "max_tokens_per_page": tool_config.max_tokens_per_page, + } + if tool_config.search_recency_filter is not None: + request_body["search_recency_filter"] = tool_config.search_recency_filter + if tool_config.country is not None: + request_body["country"] = tool_config.country + + headers = { + "Authorization": f"Bearer {resolved_api_key}", + "Content-Type": "application/json", + "X-Pplx-Integration": _get_integration_header(), + } + + async with httpx.AsyncClient() as client: + for attempt in range(tool_config.max_retries): + try: + response = await client.post(PERPLEXITY_SEARCH_URL, headers=headers, json=request_body) + response.raise_for_status() + search_response = response.json() + results = search_response.get("results") if isinstance(search_response, dict) else None + if not results: + return f"No web search results found for: {question}" + + web_search_results = "\n\n---\n\n".join([ + f'\n{doc.get("snippet", "")}\n' + for doc in results + ]) + return web_search_results or f"No web search results found for: {question}" + except httpx.HTTPError: + # Return a graceful message instead of raising, so the agent can + # continue reasoning without web search rather than failing entirely. + logger.exception("Perplexity search HTTP attempt %d of %d failed", + attempt + 1, + tool_config.max_retries) + if attempt == tool_config.max_retries - 1: + return f"Web search failed after {tool_config.max_retries} attempts for: {question}" + await asyncio.sleep(2**attempt) + except Exception: + # Return a graceful message instead of raising, so the agent can + # continue reasoning without web search rather than failing entirely. 
+ logger.exception("Perplexity search attempt %d of %d failed", attempt + 1, tool_config.max_retries) + if attempt == tool_config.max_retries - 1: + return f"Web search failed after {tool_config.max_retries} attempts for: {question}" + await asyncio.sleep(2**attempt) + return f"Web search failed after {tool_config.max_retries} attempts for: {question}" + + # Create a Generic NAT tool that can be used with any supported LLM framework + yield FunctionInfo.from_fn( + _perplexity_internet_search, + description=_perplexity_internet_search.__doc__, + ) diff --git a/packages/nvidia_nat_langchain/src/nat/plugins/langchain/tools/register.py b/packages/nvidia_nat_langchain/src/nat/plugins/langchain/tools/register.py index dc981b627b..688dedfaa6 100644 --- a/packages/nvidia_nat_langchain/src/nat/plugins/langchain/tools/register.py +++ b/packages/nvidia_nat_langchain/src/nat/plugins/langchain/tools/register.py @@ -20,5 +20,6 @@ from . import code_generation_tool from . import exa_internet_search +from . import perplexity_internet_search from . import tavily_internet_search from . import wikipedia_search diff --git a/packages/nvidia_nat_langchain/tests/test_perplexity_internet_search.py b/packages/nvidia_nat_langchain/tests/test_perplexity_internet_search.py new file mode 100644 index 0000000000..f96518372e --- /dev/null +++ b/packages/nvidia_nat_langchain/tests/test_perplexity_internet_search.py @@ -0,0 +1,178 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from nat.plugins.langchain.tools.perplexity_internet_search import ( + PerplexityInternetSearchToolConfig, + perplexity_internet_search, +) +from pydantic import SecretStr, ValidationError + +# -- Config validation tests -- + + +@pytest.mark.parametrize("constructor_args", [{}, { + "api_key": "" +}, { + "api_key": "my_api_key" +}], + ids=["default", "empty_api_key", "provided_api_key"]) +def test_api_key_is_secret_str(constructor_args: dict): + expected_api_key = constructor_args.get("api_key", "") + + config = PerplexityInternetSearchToolConfig(**constructor_args) + assert isinstance(config.api_key, SecretStr) + + api_key = config.api_key.get_secret_value() + assert api_key == expected_api_key + + +def test_default_api_key_is_unique_instance(): + config1 = PerplexityInternetSearchToolConfig() + config2 = PerplexityInternetSearchToolConfig() + + assert config1.api_key is not config2.api_key + + +def test_max_retries_rejects_zero(): + with pytest.raises(ValidationError): + PerplexityInternetSearchToolConfig(max_retries=0) + + +def test_max_results_rejects_zero(): + with pytest.raises(ValidationError): + PerplexityInternetSearchToolConfig(max_results=0) + + +def test_max_results_rejects_above_20(): + with pytest.raises(ValidationError): + PerplexityInternetSearchToolConfig(max_results=21) + + +def test_invalid_search_recency_filter_rejected(): + with pytest.raises(ValidationError): + PerplexityInternetSearchToolConfig(search_recency_filter="invalid") + + +# -- Tool behavior 
tests -- + + +@pytest.fixture +def tool_config(): + return PerplexityInternetSearchToolConfig(api_key="test-key", max_retries=2, max_query_length=50) + + +def _mock_response(results: list[dict] | None): + response = MagicMock() + response.raise_for_status.return_value = None + response.json.return_value = {"results": results} + return response + + +def _mock_async_client(post_mock: AsyncMock): + mock_client = MagicMock() + mock_client.post = post_mock + mock_context_manager = MagicMock() + mock_context_manager.__aenter__ = AsyncMock(return_value=mock_client) + mock_context_manager.__aexit__ = AsyncMock(return_value=None) + return mock_context_manager, mock_client + + +async def test_empty_key_returns_unavailable(): + config = PerplexityInternetSearchToolConfig(api_key="") + with patch.dict(os.environ, {"PERPLEXITY_API_KEY": ""}): + async with perplexity_internet_search(config, None) as func_info: + result = await func_info.single_fn("test query") + assert "unavailable" in result.lower() + assert "PERPLEXITY_API_KEY" in result + + +async def test_query_truncation(tool_config): + long_query = "a" * 100 # exceeds max_query_length=50 + post_mock = AsyncMock(return_value=_mock_response([])) + mock_context_manager, mock_client = _mock_async_client(post_mock) + + with patch("nat.plugins.langchain.tools.perplexity_internet_search.httpx.AsyncClient", + return_value=mock_context_manager): + async with perplexity_internet_search(tool_config, None) as func_info: + await func_info.single_fn(long_query) + + # Verify the query was truncated + call_args = mock_client.post.call_args + truncated_query = call_args.kwargs["json"]["query"] + assert len(truncated_query) <= 50 + assert truncated_query.endswith("...") + + +async def test_empty_results(tool_config): + post_mock = AsyncMock(return_value=_mock_response([])) + mock_context_manager, _ = _mock_async_client(post_mock) + + with patch("nat.plugins.langchain.tools.perplexity_internet_search.httpx.AsyncClient", + 
return_value=mock_context_manager): + async with perplexity_internet_search(tool_config, None) as func_info: + result = await func_info.single_fn("test query") + assert "No web search results found" in result + + +async def test_retries_on_exception(tool_config): + post_mock = AsyncMock(side_effect=Exception("API error")) + mock_context_manager, mock_client = _mock_async_client(post_mock) + + with patch("nat.plugins.langchain.tools.perplexity_internet_search.httpx.AsyncClient", + return_value=mock_context_manager), patch( + "nat.plugins.langchain.tools.perplexity_internet_search.asyncio.sleep", new_callable=AsyncMock): + async with perplexity_internet_search(tool_config, None) as func_info: + result = await func_info.single_fn("test query") + + # Should have retried max_retries times (2) + assert mock_client.post.call_count == 2 + assert "Web search failed after 2 attempts" in result + + +async def test_attribution_header_sent(tool_config): + post_mock = AsyncMock(return_value=_mock_response([])) + mock_context_manager, mock_client = _mock_async_client(post_mock) + + with patch("nat.plugins.langchain.tools.perplexity_internet_search.httpx.AsyncClient", + return_value=mock_context_manager): + async with perplexity_internet_search(tool_config, None) as func_info: + await func_info.single_fn("test query") + + call_args = mock_client.post.call_args + assert call_args.kwargs["headers"]["X-Pplx-Integration"].startswith("nemo-agent-toolkit/") + + +async def test_results_formatted_as_documents(tool_config): + post_mock = AsyncMock(return_value=_mock_response([{ + "url": "https://example.com/one", + "snippet": "First result.", + }, { + "url": "https://example.com/two", + "snippet": "Second result.", + }])) + mock_context_manager, _ = _mock_async_client(post_mock) + + with patch("nat.plugins.langchain.tools.perplexity_internet_search.httpx.AsyncClient", + return_value=mock_context_manager): + async with perplexity_internet_search(tool_config, None) as func_info: + result = 
await func_info.single_fn("test query") + + assert '' in result + assert '' in result + assert "\n\n---\n\n" in result From 42c15bda92db62b5ba2b1b9d3d4f6b660a17e567 Mon Sep 17 00:00:00 2001 From: James Liounis Date: Thu, 30 Apr 2026 13:48:59 +0000 Subject: [PATCH 2/7] Match Exa conventions: import grouping, consolidated exception handling Signed-off-by: James Liounis --- .../tools/perplexity_internet_search.py | 52 ++++----- .../tests/test_perplexity_internet_search.py | 100 ++++++++++++------ 2 files changed, 98 insertions(+), 54 deletions(-) diff --git a/packages/nvidia_nat_langchain/src/nat/plugins/langchain/tools/perplexity_internet_search.py b/packages/nvidia_nat_langchain/src/nat/plugins/langchain/tools/perplexity_internet_search.py index a3f13e7445..f267024432 100644 --- a/packages/nvidia_nat_langchain/src/nat/plugins/langchain/tools/perplexity_internet_search.py +++ b/packages/nvidia_nat_langchain/src/nat/plugins/langchain/tools/perplexity_internet_search.py @@ -16,18 +16,18 @@ import asyncio import logging -import os -from importlib import metadata from typing import Literal import httpx +from pydantic import Field + from nat.builder.builder import Builder from nat.builder.framework_enum import LLMFrameworkEnum from nat.builder.function_info import FunctionInfo from nat.cli.register_workflow import register_function -from nat.data_models.common import SerializableSecretStr, get_secret_value +from nat.data_models.common import SerializableSecretStr +from nat.data_models.common import get_secret_value from nat.data_models.function import FunctionBaseConfig -from pydantic import Field logger = logging.getLogger(__name__) @@ -40,22 +40,31 @@ class PerplexityInternetSearchToolConfig(FunctionBaseConfig, name="perplexity_in Tool that retrieves relevant contexts from web search (using Perplexity) for the given question. Requires a PERPLEXITY_API_KEY. 
""" + max_results: int = Field(default=5, ge=1, le=20, description="Maximum number of search results to return.") - api_key: SerializableSecretStr = Field(default_factory=lambda: SerializableSecretStr(""), - description="The API key for the Perplexity service.") + api_key: SerializableSecretStr = Field( + default_factory=lambda: SerializableSecretStr(""), description="The API key for the Perplexity service." + ) max_retries: int = Field(default=3, ge=1, description="Maximum number of retries for the search request") max_query_length: int = Field( default=2000, ge=1, - description="Maximum query length in characters. Queries exceeding this limit will be truncated.") + description="Maximum query length in characters. Queries exceeding this limit will be truncated.", + ) search_recency_filter: Literal["hour", "day", "week", "month", "year"] | None = Field( - default=None, description="Filter search results by recency - 'hour', 'day', 'week', 'month', or 'year'.") - country: str | None = Field(default=None, description="Country to filter search results by ISO 3166-1 alpha-2 code.") + default=None, description="Filter search results by recency - 'hour', 'day', 'week', 'month', or 'year'." + ) + country: str | None = Field( + default=None, description="Country to filter search results by ISO 3166-1 alpha-2 code." + ) max_tokens_per_page: int = Field( - default=4096, ge=1, description="Maximum number of tokens to retrieve per search result page.") + default=4096, ge=1, description="Maximum number of tokens to retrieve per search result page." 
+ ) def _get_integration_header() -> str: + from importlib import metadata + try: package_version = metadata.version("nvidia-nat-langchain") except metadata.PackageNotFoundError: @@ -65,6 +74,8 @@ def _get_integration_header() -> str: @register_function(config_type=PerplexityInternetSearchToolConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN]) async def perplexity_internet_search(tool_config: PerplexityInternetSearchToolConfig, builder: Builder): + import os + api_key = get_secret_value(tool_config.api_key) if tool_config.api_key else "" resolved_api_key = api_key or os.environ.get("PERPLEXITY_API_KEY", "") @@ -84,7 +95,7 @@ async def _perplexity_internet_search(question: str) -> str: max_len = tool_config.max_query_length if len(question) > max_len: logger.warning("Perplexity query truncated from %d to %d characters", len(question), max_len) - question = question[:max_len - 3] + "..." if max_len > 3 else question[:max_len] + question = question[: max_len - 3] + "..." if max_len > 3 else question[:max_len] request_body = { "query": question, @@ -112,20 +123,13 @@ async def _perplexity_internet_search(question: str) -> str: if not results: return f"No web search results found for: {question}" - web_search_results = "\n\n---\n\n".join([ - f'\n{doc.get("snippet", "")}\n' - for doc in results - ]) + web_search_results = "\n\n---\n\n".join( + [ + f'\n{doc.get("snippet", "")}\n' + for doc in results + ] + ) return web_search_results or f"No web search results found for: {question}" - except httpx.HTTPError: - # Return a graceful message instead of raising, so the agent can - # continue reasoning without web search rather than failing entirely. 
- logger.exception("Perplexity search HTTP attempt %d of %d failed", - attempt + 1, - tool_config.max_retries) - if attempt == tool_config.max_retries - 1: - return f"Web search failed after {tool_config.max_retries} attempts for: {question}" - await asyncio.sleep(2**attempt) except Exception: # Return a graceful message instead of raising, so the agent can # continue reasoning without web search rather than failing entirely. diff --git a/packages/nvidia_nat_langchain/tests/test_perplexity_internet_search.py b/packages/nvidia_nat_langchain/tests/test_perplexity_internet_search.py index f96518372e..523a9a68a7 100644 --- a/packages/nvidia_nat_langchain/tests/test_perplexity_internet_search.py +++ b/packages/nvidia_nat_langchain/tests/test_perplexity_internet_search.py @@ -14,25 +14,25 @@ # limitations under the License. import os -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock +from unittest.mock import MagicMock +from unittest.mock import patch import pytest -from nat.plugins.langchain.tools.perplexity_internet_search import ( - PerplexityInternetSearchToolConfig, - perplexity_internet_search, -) -from pydantic import SecretStr, ValidationError +from pydantic import SecretStr +from pydantic import ValidationError # -- Config validation tests -- -@pytest.mark.parametrize("constructor_args", [{}, { - "api_key": "" -}, { - "api_key": "my_api_key" -}], - ids=["default", "empty_api_key", "provided_api_key"]) +@pytest.mark.parametrize( + "constructor_args", + [{}, {"api_key": ""}, {"api_key": "my_api_key"}], + ids=["default", "empty_api_key", "provided_api_key"], +) def test_api_key_is_secret_str(constructor_args: dict): + from nat.plugins.langchain.tools.perplexity_internet_search import PerplexityInternetSearchToolConfig + expected_api_key = constructor_args.get("api_key", "") config = PerplexityInternetSearchToolConfig(**constructor_args) @@ -43,6 +43,8 @@ def test_api_key_is_secret_str(constructor_args: dict): def 
test_default_api_key_is_unique_instance(): + from nat.plugins.langchain.tools.perplexity_internet_search import PerplexityInternetSearchToolConfig + config1 = PerplexityInternetSearchToolConfig() config2 = PerplexityInternetSearchToolConfig() @@ -50,21 +52,29 @@ def test_default_api_key_is_unique_instance(): def test_max_retries_rejects_zero(): + from nat.plugins.langchain.tools.perplexity_internet_search import PerplexityInternetSearchToolConfig + with pytest.raises(ValidationError): PerplexityInternetSearchToolConfig(max_retries=0) def test_max_results_rejects_zero(): + from nat.plugins.langchain.tools.perplexity_internet_search import PerplexityInternetSearchToolConfig + with pytest.raises(ValidationError): PerplexityInternetSearchToolConfig(max_results=0) def test_max_results_rejects_above_20(): + from nat.plugins.langchain.tools.perplexity_internet_search import PerplexityInternetSearchToolConfig + with pytest.raises(ValidationError): PerplexityInternetSearchToolConfig(max_results=21) def test_invalid_search_recency_filter_rejected(): + from nat.plugins.langchain.tools.perplexity_internet_search import PerplexityInternetSearchToolConfig + with pytest.raises(ValidationError): PerplexityInternetSearchToolConfig(search_recency_filter="invalid") @@ -74,6 +84,8 @@ def test_invalid_search_recency_filter_rejected(): @pytest.fixture def tool_config(): + from nat.plugins.langchain.tools.perplexity_internet_search import PerplexityInternetSearchToolConfig + return PerplexityInternetSearchToolConfig(api_key="test-key", max_retries=2, max_query_length=50) @@ -94,6 +106,9 @@ def _mock_async_client(post_mock: AsyncMock): async def test_empty_key_returns_unavailable(): + from nat.plugins.langchain.tools.perplexity_internet_search import PerplexityInternetSearchToolConfig + from nat.plugins.langchain.tools.perplexity_internet_search import perplexity_internet_search + config = PerplexityInternetSearchToolConfig(api_key="") with patch.dict(os.environ, {"PERPLEXITY_API_KEY": 
""}): async with perplexity_internet_search(config, None) as func_info: @@ -103,12 +118,15 @@ async def test_empty_key_returns_unavailable(): async def test_query_truncation(tool_config): + from nat.plugins.langchain.tools.perplexity_internet_search import perplexity_internet_search + long_query = "a" * 100 # exceeds max_query_length=50 post_mock = AsyncMock(return_value=_mock_response([])) mock_context_manager, mock_client = _mock_async_client(post_mock) - with patch("nat.plugins.langchain.tools.perplexity_internet_search.httpx.AsyncClient", - return_value=mock_context_manager): + with patch( + "nat.plugins.langchain.tools.perplexity_internet_search.httpx.AsyncClient", return_value=mock_context_manager + ): async with perplexity_internet_search(tool_config, None) as func_info: await func_info.single_fn(long_query) @@ -120,23 +138,32 @@ async def test_query_truncation(tool_config): async def test_empty_results(tool_config): + from nat.plugins.langchain.tools.perplexity_internet_search import perplexity_internet_search + post_mock = AsyncMock(return_value=_mock_response([])) mock_context_manager, _ = _mock_async_client(post_mock) - with patch("nat.plugins.langchain.tools.perplexity_internet_search.httpx.AsyncClient", - return_value=mock_context_manager): + with patch( + "nat.plugins.langchain.tools.perplexity_internet_search.httpx.AsyncClient", return_value=mock_context_manager + ): async with perplexity_internet_search(tool_config, None) as func_info: result = await func_info.single_fn("test query") assert "No web search results found" in result async def test_retries_on_exception(tool_config): + from nat.plugins.langchain.tools.perplexity_internet_search import perplexity_internet_search + post_mock = AsyncMock(side_effect=Exception("API error")) mock_context_manager, mock_client = _mock_async_client(post_mock) - with patch("nat.plugins.langchain.tools.perplexity_internet_search.httpx.AsyncClient", - return_value=mock_context_manager), patch( - 
"nat.plugins.langchain.tools.perplexity_internet_search.asyncio.sleep", new_callable=AsyncMock): + with ( + patch( + "nat.plugins.langchain.tools.perplexity_internet_search.httpx.AsyncClient", + return_value=mock_context_manager, + ), + patch("nat.plugins.langchain.tools.perplexity_internet_search.asyncio.sleep", new_callable=AsyncMock), + ): async with perplexity_internet_search(tool_config, None) as func_info: result = await func_info.single_fn("test query") @@ -146,11 +173,14 @@ async def test_retries_on_exception(tool_config): async def test_attribution_header_sent(tool_config): + from nat.plugins.langchain.tools.perplexity_internet_search import perplexity_internet_search + post_mock = AsyncMock(return_value=_mock_response([])) mock_context_manager, mock_client = _mock_async_client(post_mock) - with patch("nat.plugins.langchain.tools.perplexity_internet_search.httpx.AsyncClient", - return_value=mock_context_manager): + with patch( + "nat.plugins.langchain.tools.perplexity_internet_search.httpx.AsyncClient", return_value=mock_context_manager + ): async with perplexity_internet_search(tool_config, None) as func_info: await func_info.single_fn("test query") @@ -159,17 +189,27 @@ async def test_attribution_header_sent(tool_config): async def test_results_formatted_as_documents(tool_config): - post_mock = AsyncMock(return_value=_mock_response([{ - "url": "https://example.com/one", - "snippet": "First result.", - }, { - "url": "https://example.com/two", - "snippet": "Second result.", - }])) + from nat.plugins.langchain.tools.perplexity_internet_search import perplexity_internet_search + + post_mock = AsyncMock( + return_value=_mock_response( + [ + { + "url": "https://example.com/one", + "snippet": "First result.", + }, + { + "url": "https://example.com/two", + "snippet": "Second result.", + }, + ] + ) + ) mock_context_manager, _ = _mock_async_client(post_mock) - with patch("nat.plugins.langchain.tools.perplexity_internet_search.httpx.AsyncClient", - 
return_value=mock_context_manager): + with patch( + "nat.plugins.langchain.tools.perplexity_internet_search.httpx.AsyncClient", return_value=mock_context_manager + ): async with perplexity_internet_search(tool_config, None) as func_info: result = await func_info.single_fn("test query") From c0bd9ab153831356ae76ab3901eb8624675f1068 Mon Sep 17 00:00:00 2001 From: James Liounis Date: Thu, 30 Apr 2026 15:40:21 +0000 Subject: [PATCH 3/7] Add docstrings to perplexity_internet_search registration and integration-header helper Addresses CodeRabbit review nit: public registration function and the integration-header helper now carry Google-style docstrings, raising docstring coverage on the new module above NeMo's 80% threshold. Signed-off-by: James Liounis --- .../tools/perplexity_internet_search.py | 23 +++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/packages/nvidia_nat_langchain/src/nat/plugins/langchain/tools/perplexity_internet_search.py b/packages/nvidia_nat_langchain/src/nat/plugins/langchain/tools/perplexity_internet_search.py index f267024432..58d29b5c83 100644 --- a/packages/nvidia_nat_langchain/src/nat/plugins/langchain/tools/perplexity_internet_search.py +++ b/packages/nvidia_nat_langchain/src/nat/plugins/langchain/tools/perplexity_internet_search.py @@ -63,6 +63,12 @@ class PerplexityInternetSearchToolConfig(FunctionBaseConfig, name="perplexity_in def _get_integration_header() -> str: + """Build the ``X-Pplx-Integration`` header value for outbound Perplexity requests. + + Returns: + str: A ``"nemo-agent-toolkit/"`` slug. Falls back to ``"unknown"`` + when the ``nvidia-nat-langchain`` package metadata cannot be resolved. 
+ """ from importlib import metadata try: @@ -74,6 +80,23 @@ def _get_integration_header() -> str: @register_function(config_type=PerplexityInternetSearchToolConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN]) async def perplexity_internet_search(tool_config: PerplexityInternetSearchToolConfig, builder: Builder): + """Register the Perplexity internet search tool with the NAT runtime. + + Resolves the Perplexity API key from the tool config or the + ``PERPLEXITY_API_KEY`` environment variable, then yields a ``FunctionInfo`` + wrapping an async coroutine that calls Perplexity's Search API + (``POST https://api.perplexity.ai/search``) with retry/backoff and formats + results as ```` blocks suitable for LLM consumption. + + Args: + tool_config: ``PerplexityInternetSearchToolConfig`` carrying API key, + result/retry caps, recency/country filters, and per-page token cap. + builder: NAT ``Builder`` instance (unused by this tool but required by + the ``@register_function`` contract). + + Yields: + FunctionInfo: The registered async search function. + """ import os api_key = get_secret_value(tool_config.api_key) if tool_config.api_key else "" From f3495c03966ec7bfa598fe79a315fe8f384f2b6e Mon Sep 17 00:00:00 2001 From: James Liounis Date: Tue, 5 May 2026 13:28:55 +0000 Subject: [PATCH 4/7] refactor: register Perplexity Search tool with all 8 NAT frameworks Per maintainer @bbednarski9's request on PR #1903, extend the Perplexity internet search tool from LangChain-only to all 8 frameworks supported by LLMFrameworkEnum: LANGCHAIN, LLAMA_INDEX, CREWAI, SEMANTIC_KERNEL, AGNO, ADK, STRANDS, AUTOGEN. The tool implementation is HTTP-only (raw httpx calls to api.perplexity.ai/search) with no framework-specific imports, so it can live in a single framework-agnostic location and be wrapped for each target framework via NAT's existing tool_wrapper.py mechanism in each plugin package. 
Changes: - Move tool from packages/nvidia_nat_langchain/.../tools/ to packages/nvidia_nat_core/.../tool/ (framework-agnostic location) - Expand framework_wrappers list to include all 8 LLMFrameworkEnum values - Move tests to nvidia_nat_core's test suite - Update docs to reflect multi-framework support Signed-off-by: James Liounis --- .../tutorials/add-tools-to-a-workflow.md | 4 +- .../nat/tool}/perplexity_internet_search.py | 22 ++++++++-- .../nvidia_nat_core/src/nat/tool/register.py | 1 + .../tools}/test_perplexity_internet_search.py | 40 +++++++++---------- .../nat/plugins/langchain/tools/register.py | 1 - 5 files changed, 42 insertions(+), 26 deletions(-) rename packages/{nvidia_nat_langchain/src/nat/plugins/langchain/tools => nvidia_nat_core/src/nat/tool}/perplexity_internet_search.py (91%) rename packages/{nvidia_nat_langchain/tests => nvidia_nat_core/tests/nat/tools}/test_perplexity_internet_search.py (74%) diff --git a/docs/source/get-started/tutorials/add-tools-to-a-workflow.md b/docs/source/get-started/tutorials/add-tools-to-a-workflow.md index 969b7fa4cc..f4a2765dca 100644 --- a/docs/source/get-started/tutorials/add-tools-to-a-workflow.md +++ b/docs/source/get-started/tutorials/add-tools-to-a-workflow.md @@ -199,10 +199,10 @@ workflow: ### Using Perplexity Search -The `perplexity_internet_search` tool is also part of the `nvidia-nat[langchain]` package. If you haven't already installed it: +The `perplexity_internet_search` tool ships with the core `nvidia-nat` package and is framework-agnostic — it can be used with any of the agent frameworks supported by NAT (`langchain`, `llama_index`, `crewai`, `semantic_kernel`, `agno`, `adk`, `strands`, and `autogen`). No framework-specific extra is required to install it: ```bash # local package install from source -uv pip install -e ".[langchain]" +uv pip install -e . 
``` Prior to using the `perplexity_internet_search` tool, create a Perplexity account and obtain an API key from the [API key page](https://www.perplexity.ai/account/api/keys). Once obtained, set the `PERPLEXITY_API_KEY` environment variable to the API key: diff --git a/packages/nvidia_nat_langchain/src/nat/plugins/langchain/tools/perplexity_internet_search.py b/packages/nvidia_nat_core/src/nat/tool/perplexity_internet_search.py similarity index 91% rename from packages/nvidia_nat_langchain/src/nat/plugins/langchain/tools/perplexity_internet_search.py rename to packages/nvidia_nat_core/src/nat/tool/perplexity_internet_search.py index 58d29b5c83..1697544cb8 100644 --- a/packages/nvidia_nat_langchain/src/nat/plugins/langchain/tools/perplexity_internet_search.py +++ b/packages/nvidia_nat_core/src/nat/tool/perplexity_internet_search.py @@ -67,18 +67,30 @@ def _get_integration_header() -> str: Returns: str: A ``"nemo-agent-toolkit/"`` slug. Falls back to ``"unknown"`` - when the ``nvidia-nat-langchain`` package metadata cannot be resolved. + when the ``nvidia-nat`` package metadata cannot be resolved. 
""" from importlib import metadata try: - package_version = metadata.version("nvidia-nat-langchain") + package_version = metadata.version("nvidia-nat") except metadata.PackageNotFoundError: package_version = "unknown" return f"nemo-agent-toolkit/{package_version}" -@register_function(config_type=PerplexityInternetSearchToolConfig, framework_wrappers=[LLMFrameworkEnum.LANGCHAIN]) +@register_function( + config_type=PerplexityInternetSearchToolConfig, + framework_wrappers=[ + LLMFrameworkEnum.LANGCHAIN, + LLMFrameworkEnum.LLAMA_INDEX, + LLMFrameworkEnum.CREWAI, + LLMFrameworkEnum.SEMANTIC_KERNEL, + LLMFrameworkEnum.AGNO, + LLMFrameworkEnum.ADK, + LLMFrameworkEnum.STRANDS, + LLMFrameworkEnum.AUTOGEN, + ], +) async def perplexity_internet_search(tool_config: PerplexityInternetSearchToolConfig, builder: Builder): """Register the Perplexity internet search tool with the NAT runtime. @@ -88,6 +100,10 @@ async def perplexity_internet_search(tool_config: PerplexityInternetSearchToolCo (``POST https://api.perplexity.ai/search``) with retry/backoff and formats results as ```` blocks suitable for LLM consumption. + The tool is framework-agnostic (uses raw ``httpx``) and is registered with + every framework wrapper in :class:`LLMFrameworkEnum` so it can be consumed + by any NAT-supported agent framework. + Args: tool_config: ``PerplexityInternetSearchToolConfig`` carrying API key, result/retry caps, recency/country filters, and per-page token cap. diff --git a/packages/nvidia_nat_core/src/nat/tool/register.py b/packages/nvidia_nat_core/src/nat/tool/register.py index 0f658d45fe..9b0643de21 100644 --- a/packages/nvidia_nat_core/src/nat/tool/register.py +++ b/packages/nvidia_nat_core/src/nat/tool/register.py @@ -21,6 +21,7 @@ from . import document_search from . import github_tools from . import nvidia_rag +from . import perplexity_internet_search from . import retriever from . 
import server_tools from .code_execution import register diff --git a/packages/nvidia_nat_langchain/tests/test_perplexity_internet_search.py b/packages/nvidia_nat_core/tests/nat/tools/test_perplexity_internet_search.py similarity index 74% rename from packages/nvidia_nat_langchain/tests/test_perplexity_internet_search.py rename to packages/nvidia_nat_core/tests/nat/tools/test_perplexity_internet_search.py index 523a9a68a7..06ebfa1ffb 100644 --- a/packages/nvidia_nat_langchain/tests/test_perplexity_internet_search.py +++ b/packages/nvidia_nat_core/tests/nat/tools/test_perplexity_internet_search.py @@ -31,7 +31,7 @@ ids=["default", "empty_api_key", "provided_api_key"], ) def test_api_key_is_secret_str(constructor_args: dict): - from nat.plugins.langchain.tools.perplexity_internet_search import PerplexityInternetSearchToolConfig + from nat.tool.perplexity_internet_search import PerplexityInternetSearchToolConfig expected_api_key = constructor_args.get("api_key", "") @@ -43,7 +43,7 @@ def test_api_key_is_secret_str(constructor_args: dict): def test_default_api_key_is_unique_instance(): - from nat.plugins.langchain.tools.perplexity_internet_search import PerplexityInternetSearchToolConfig + from nat.tool.perplexity_internet_search import PerplexityInternetSearchToolConfig config1 = PerplexityInternetSearchToolConfig() config2 = PerplexityInternetSearchToolConfig() @@ -52,28 +52,28 @@ def test_default_api_key_is_unique_instance(): def test_max_retries_rejects_zero(): - from nat.plugins.langchain.tools.perplexity_internet_search import PerplexityInternetSearchToolConfig + from nat.tool.perplexity_internet_search import PerplexityInternetSearchToolConfig with pytest.raises(ValidationError): PerplexityInternetSearchToolConfig(max_retries=0) def test_max_results_rejects_zero(): - from nat.plugins.langchain.tools.perplexity_internet_search import PerplexityInternetSearchToolConfig + from nat.tool.perplexity_internet_search import PerplexityInternetSearchToolConfig with 
pytest.raises(ValidationError): PerplexityInternetSearchToolConfig(max_results=0) def test_max_results_rejects_above_20(): - from nat.plugins.langchain.tools.perplexity_internet_search import PerplexityInternetSearchToolConfig + from nat.tool.perplexity_internet_search import PerplexityInternetSearchToolConfig with pytest.raises(ValidationError): PerplexityInternetSearchToolConfig(max_results=21) def test_invalid_search_recency_filter_rejected(): - from nat.plugins.langchain.tools.perplexity_internet_search import PerplexityInternetSearchToolConfig + from nat.tool.perplexity_internet_search import PerplexityInternetSearchToolConfig with pytest.raises(ValidationError): PerplexityInternetSearchToolConfig(search_recency_filter="invalid") @@ -84,7 +84,7 @@ def test_invalid_search_recency_filter_rejected(): @pytest.fixture def tool_config(): - from nat.plugins.langchain.tools.perplexity_internet_search import PerplexityInternetSearchToolConfig + from nat.tool.perplexity_internet_search import PerplexityInternetSearchToolConfig return PerplexityInternetSearchToolConfig(api_key="test-key", max_retries=2, max_query_length=50) @@ -106,8 +106,8 @@ def _mock_async_client(post_mock: AsyncMock): async def test_empty_key_returns_unavailable(): - from nat.plugins.langchain.tools.perplexity_internet_search import PerplexityInternetSearchToolConfig - from nat.plugins.langchain.tools.perplexity_internet_search import perplexity_internet_search + from nat.tool.perplexity_internet_search import PerplexityInternetSearchToolConfig + from nat.tool.perplexity_internet_search import perplexity_internet_search config = PerplexityInternetSearchToolConfig(api_key="") with patch.dict(os.environ, {"PERPLEXITY_API_KEY": ""}): @@ -118,14 +118,14 @@ async def test_empty_key_returns_unavailable(): async def test_query_truncation(tool_config): - from nat.plugins.langchain.tools.perplexity_internet_search import perplexity_internet_search + from nat.tool.perplexity_internet_search import 
perplexity_internet_search long_query = "a" * 100 # exceeds max_query_length=50 post_mock = AsyncMock(return_value=_mock_response([])) mock_context_manager, mock_client = _mock_async_client(post_mock) with patch( - "nat.plugins.langchain.tools.perplexity_internet_search.httpx.AsyncClient", return_value=mock_context_manager + "nat.tool.perplexity_internet_search.httpx.AsyncClient", return_value=mock_context_manager ): async with perplexity_internet_search(tool_config, None) as func_info: await func_info.single_fn(long_query) @@ -138,13 +138,13 @@ async def test_query_truncation(tool_config): async def test_empty_results(tool_config): - from nat.plugins.langchain.tools.perplexity_internet_search import perplexity_internet_search + from nat.tool.perplexity_internet_search import perplexity_internet_search post_mock = AsyncMock(return_value=_mock_response([])) mock_context_manager, _ = _mock_async_client(post_mock) with patch( - "nat.plugins.langchain.tools.perplexity_internet_search.httpx.AsyncClient", return_value=mock_context_manager + "nat.tool.perplexity_internet_search.httpx.AsyncClient", return_value=mock_context_manager ): async with perplexity_internet_search(tool_config, None) as func_info: result = await func_info.single_fn("test query") @@ -152,17 +152,17 @@ async def test_empty_results(tool_config): async def test_retries_on_exception(tool_config): - from nat.plugins.langchain.tools.perplexity_internet_search import perplexity_internet_search + from nat.tool.perplexity_internet_search import perplexity_internet_search post_mock = AsyncMock(side_effect=Exception("API error")) mock_context_manager, mock_client = _mock_async_client(post_mock) with ( patch( - "nat.plugins.langchain.tools.perplexity_internet_search.httpx.AsyncClient", + "nat.tool.perplexity_internet_search.httpx.AsyncClient", return_value=mock_context_manager, ), - patch("nat.plugins.langchain.tools.perplexity_internet_search.asyncio.sleep", new_callable=AsyncMock), + 
patch("nat.tool.perplexity_internet_search.asyncio.sleep", new_callable=AsyncMock), ): async with perplexity_internet_search(tool_config, None) as func_info: result = await func_info.single_fn("test query") @@ -173,13 +173,13 @@ async def test_retries_on_exception(tool_config): async def test_attribution_header_sent(tool_config): - from nat.plugins.langchain.tools.perplexity_internet_search import perplexity_internet_search + from nat.tool.perplexity_internet_search import perplexity_internet_search post_mock = AsyncMock(return_value=_mock_response([])) mock_context_manager, mock_client = _mock_async_client(post_mock) with patch( - "nat.plugins.langchain.tools.perplexity_internet_search.httpx.AsyncClient", return_value=mock_context_manager + "nat.tool.perplexity_internet_search.httpx.AsyncClient", return_value=mock_context_manager ): async with perplexity_internet_search(tool_config, None) as func_info: await func_info.single_fn("test query") @@ -189,7 +189,7 @@ async def test_attribution_header_sent(tool_config): async def test_results_formatted_as_documents(tool_config): - from nat.plugins.langchain.tools.perplexity_internet_search import perplexity_internet_search + from nat.tool.perplexity_internet_search import perplexity_internet_search post_mock = AsyncMock( return_value=_mock_response( @@ -208,7 +208,7 @@ async def test_results_formatted_as_documents(tool_config): mock_context_manager, _ = _mock_async_client(post_mock) with patch( - "nat.plugins.langchain.tools.perplexity_internet_search.httpx.AsyncClient", return_value=mock_context_manager + "nat.tool.perplexity_internet_search.httpx.AsyncClient", return_value=mock_context_manager ): async with perplexity_internet_search(tool_config, None) as func_info: result = await func_info.single_fn("test query") diff --git a/packages/nvidia_nat_langchain/src/nat/plugins/langchain/tools/register.py b/packages/nvidia_nat_langchain/src/nat/plugins/langchain/tools/register.py index 688dedfaa6..dc981b627b 100644 --- 
a/packages/nvidia_nat_langchain/src/nat/plugins/langchain/tools/register.py +++ b/packages/nvidia_nat_langchain/src/nat/plugins/langchain/tools/register.py @@ -20,6 +20,5 @@ from . import code_generation_tool from . import exa_internet_search -from . import perplexity_internet_search from . import tavily_internet_search from . import wikipedia_search From 470d3b49f68475e12ac8f0a9aaca3cfa89eeb44e Mon Sep 17 00:00:00 2001 From: James Liounis Date: Tue, 5 May 2026 20:56:41 +0000 Subject: [PATCH 5/7] fix: address CodeRabbit review feedback - Validate country field as ISO 3166-1 alpha-2 (min/max length 2, pattern ^[A-Z]{2}$) - Skip retry on non-retriable 4xx (401/403/404); continue retrying 5xx and timeouts - Replace 'NAT' acronym with 'the toolkit' in user-facing docs (per repo guideline) - Add AsyncIterator[FunctionInfo] return type annotation to the public registration function Signed-off-by: James Liounis --- .../tutorials/add-tools-to-a-workflow.md | 2 +- .../nat/tool/perplexity_internet_search.py | 32 +++++++++++++++++-- 2 files changed, 31 insertions(+), 3 deletions(-) diff --git a/docs/source/get-started/tutorials/add-tools-to-a-workflow.md b/docs/source/get-started/tutorials/add-tools-to-a-workflow.md index f4a2765dca..aae36030b0 100644 --- a/docs/source/get-started/tutorials/add-tools-to-a-workflow.md +++ b/docs/source/get-started/tutorials/add-tools-to-a-workflow.md @@ -199,7 +199,7 @@ workflow: ### Using Perplexity Search -The `perplexity_internet_search` tool ships with the core `nvidia-nat` package and is framework-agnostic — it can be used with any of the agent frameworks supported by NAT (`langchain`, `llama_index`, `crewai`, `semantic_kernel`, `agno`, `adk`, `strands`, and `autogen`). 
No framework-specific extra is required to install it: +The `perplexity_internet_search` tool ships with the core `nvidia-nat` package and is framework-agnostic — it can be used with any of the agent frameworks supported by the toolkit (`langchain`, `llama_index`, `crewai`, `semantic_kernel`, `agno`, `adk`, `strands`, and `autogen`). No framework-specific extra is required to install it: ```bash # local package install from source uv pip install -e . diff --git a/packages/nvidia_nat_core/src/nat/tool/perplexity_internet_search.py b/packages/nvidia_nat_core/src/nat/tool/perplexity_internet_search.py index 1697544cb8..3b437e4037 100644 --- a/packages/nvidia_nat_core/src/nat/tool/perplexity_internet_search.py +++ b/packages/nvidia_nat_core/src/nat/tool/perplexity_internet_search.py @@ -16,6 +16,7 @@ import asyncio import logging +from collections.abc import AsyncIterator from typing import Literal import httpx @@ -55,7 +56,11 @@ class PerplexityInternetSearchToolConfig(FunctionBaseConfig, name="perplexity_in default=None, description="Filter search results by recency - 'hour', 'day', 'week', 'month', or 'year'." ) country: str | None = Field( - default=None, description="Country to filter search results by ISO 3166-1 alpha-2 code." + default=None, + min_length=2, + max_length=2, + pattern=r"^[A-Z]{2}$", + description="Country to filter search results by ISO 3166-1 alpha-2 code (uppercase, e.g. 'US', 'GB').", ) max_tokens_per_page: int = Field( default=4096, ge=1, description="Maximum number of tokens to retrieve per search result page." @@ -91,7 +96,9 @@ def _get_integration_header() -> str: LLMFrameworkEnum.AUTOGEN, ], ) -async def perplexity_internet_search(tool_config: PerplexityInternetSearchToolConfig, builder: Builder): +async def perplexity_internet_search( + tool_config: PerplexityInternetSearchToolConfig, builder: Builder +) -> AsyncIterator[FunctionInfo]: """Register the Perplexity internet search tool with the NAT runtime. 
Resolves the Perplexity API key from the tool config or the @@ -156,6 +163,16 @@ async def _perplexity_internet_search(question: str) -> str: for attempt in range(tool_config.max_retries): try: response = await client.post(PERPLEXITY_SEARCH_URL, headers=headers, json=request_body) + # Short-circuit on non-retriable client errors + if response.status_code in {401, 403, 404}: + logger.error( + "Perplexity search returned non-retriable status %d; aborting without retry", + response.status_code, + ) + return ( + f"Web search failed with non-retriable status {response.status_code}. " + f"Check that PERPLEXITY_API_KEY is set correctly." + ) response.raise_for_status() search_response = response.json() results = search_response.get("results") if isinstance(search_response, dict) else None @@ -169,6 +186,17 @@ async def _perplexity_internet_search(question: str) -> str: ] ) return web_search_results or f"No web search results found for: {question}" + except httpx.HTTPStatusError as exc: + # raise_for_status() raises HTTPStatusError for 4xx/5xx not handled above (e.g. 5xx) + logger.warning( + "Perplexity search attempt %d of %d failed with status %d", + attempt + 1, + tool_config.max_retries, + exc.response.status_code, + ) + if attempt == tool_config.max_retries - 1: + return f"Web search failed after {tool_config.max_retries} attempts for: {question}" + await asyncio.sleep(2**attempt) except Exception: # Return a graceful message instead of raising, so the agent can # continue reasoning without web search rather than failing entirely. From 0d4d1882206171273435a61962e251c8d1dda5b5 Mon Sep 17 00:00:00 2001 From: James Liounis Date: Tue, 5 May 2026 21:29:32 +0000 Subject: [PATCH 6/7] docs: lead with perplexity_internet_search in add-tools tutorial Reorder the 'Alternate Method Using a Web Search Tool' section so that perplexity_internet_search is the lead example, with tavily_internet_search and exa_internet_search retained as alternatives below. 
perplexity_internet_search ships with nvidia-nat-core and is framework- agnostic across all 8 LLMFrameworkEnum values, making it the only single default that works regardless of which agent framework the reader picked in the earlier tutorial steps. The Tavily and Exa subsections are kept verbatim for users who prefer those backends. Signed-off-by: James Liounis --- .../tutorials/add-tools-to-a-workflow.md | 88 +++++++++---------- 1 file changed, 41 insertions(+), 47 deletions(-) diff --git a/docs/source/get-started/tutorials/add-tools-to-a-workflow.md b/docs/source/get-started/tutorials/add-tools-to-a-workflow.md index aae36030b0..57dc7d8230 100644 --- a/docs/source/get-started/tutorials/add-tools-to-a-workflow.md +++ b/docs/source/get-started/tutorials/add-tools-to-a-workflow.md @@ -109,26 +109,26 @@ Workflow Result: ``` ## Alternate Method Using a Web Search Tool -Adding individual web pages to a workflow can be cumbersome, especially when dealing with multiple web pages. An alternative method is to use a web search tool. NeMo Agent Toolkit provides web search tools including: `tavily_internet_search` which utilizes the [Tavily Search API](https://tavily.com/), `exa_internet_search` which utilizes the [Exa Search API](https://exa.ai/), and `perplexity_internet_search` which utilizes the [Perplexity Search API](https://docs.perplexity.ai/api-reference/search-post). +Adding individual web pages to a workflow can be cumbersome, especially when dealing with multiple web pages. An alternative method is to use a web search tool. NeMo Agent Toolkit provides web search tools including: `perplexity_internet_search` which utilizes the [Perplexity Search API](https://docs.perplexity.ai/api-reference/search-post), `tavily_internet_search` which utilizes the [Tavily Search API](https://tavily.com/), and `exa_internet_search` which utilizes the [Exa Search API](https://exa.ai/). 
-### Using Tavily Search +### Using Perplexity Search -The `tavily_internet_search` tool is part of the `nvidia-nat[langchain]` package, to install the package run: +The `perplexity_internet_search` tool ships with the core `nvidia-nat` package and is framework-agnostic — it can be used with any of the agent frameworks supported by the toolkit (`langchain`, `llama_index`, `crewai`, `semantic_kernel`, `agno`, `adk`, `strands`, and `autogen`). No framework-specific extra is required to install it: ```bash # local package install from source -uv pip install -e ".[langchain]" +uv pip install -e . ``` -Prior to using the `tavily_internet_search` tool, create an account at [`tavily.com`](https://tavily.com/) and obtain an API key. Once obtained, set the `TAVILY_API_KEY` environment variable to the API key: +Prior to using the `perplexity_internet_search` tool, create a Perplexity account and obtain an API key from the [API key page](https://www.perplexity.ai/account/api/keys). Once obtained, set the `PERPLEXITY_API_KEY` environment variable to the API key (`PPLX_API_KEY` is also accepted as a fallback): ```bash -export TAVILY_API_KEY= +export PERPLEXITY_API_KEY= ``` -We will now update the `functions` section of the configuration file replacing the two `webpage_query` tools with a single `tavily_internet_search` tool entry: +We will now update the `functions` section of the configuration file replacing the two `webpage_query` tools with a single `perplexity_internet_search` tool entry: ```yaml functions: internet_search: - _type: tavily_internet_search + _type: perplexity_internet_search current_datetime: _type: current_datetime ``` @@ -140,56 +140,47 @@ workflow: tool_names: [internet_search, current_datetime] ``` -The resulting configuration file is located at `examples/documentation_guides/workflows/custom_workflow/search_config.yml` in the NeMo Agent Toolkit repository. 
- When you re-run the workflow with the updated configuration file: ```bash -nat run --config_file examples/documentation_guides/workflows/custom_workflow/search_config.yml \ +nat run --config_file path/to/your/config.yml \ --input "How do I trace only specific parts of my LangChain application?" ``` -Which will then yield a slightly different result to the same question: -``` -Workflow Result: -['To trace only specific parts of a LangChain application, users can use the `@traceable` decorator to mark specific functions or methods as traceable. Additionally, users can configure the tracing functionality to log traces to a specific project, add metadata and tags to traces, and customize the run name and ID. Users can also use the `LangChainTracer` class to trace specific invocations or parts of their application. Furthermore, users can use the `tracing_v2_enabled` context manager to trace a specific block of code.'] +The `perplexity_internet_search` tool supports additional configuration options: +```yaml +functions: + internet_search: + _type: perplexity_internet_search + max_results: 5 + max_retries: 3 + max_query_length: 2000 # queries longer than this are truncated + search_recency_filter: week # 'hour', 'day', 'week', 'month', or 'year' + country: US # ISO 3166-1 alpha-2 country code + max_tokens_per_page: 4096 ``` -### Using Exa Search +### Using Tavily Search -The `exa_internet_search` tool is also part of the `nvidia-nat[langchain]` package. If you haven't already installed it: +The `tavily_internet_search` tool is part of the `nvidia-nat[langchain]` package, to install the package run: ```bash # local package install from source uv pip install -e ".[langchain]" ``` -Prior to using the `exa_internet_search` tool, create an account at [`exa.ai`](https://exa.ai/) and obtain an API key.
Once obtained, set the `EXA_API_KEY` environment variable to the API key: +Prior to using the `tavily_internet_search` tool, create an account at [`tavily.com`](https://tavily.com/) and obtain an API key. Once obtained, set the `TAVILY_API_KEY` environment variable to the API key: ```bash -export EXA_API_KEY= +export TAVILY_API_KEY= ``` -You can use the `exa_internet_search` tool in the same way as `tavily_internet_search` by updating the `functions` section of the configuration file: +You can use the `tavily_internet_search` tool by updating the `functions` section of the configuration file: ```yaml functions: internet_search: - _type: exa_internet_search + _type: tavily_internet_search current_datetime: _type: current_datetime ``` -The `exa_internet_search` tool supports additional configuration options: -```yaml -functions: - internet_search: - _type: exa_internet_search - max_results: 5 - search_type: neural # 'auto', 'fast', 'deep', 'neural', or 'instant' - livecrawl: fallback # 'always', 'fallback', or 'never' - max_retries: 3 - max_query_length: 2000 # queries longer than this are truncated - highlights: true # include highlights in results - max_content_length: 10000 # max chars of text per result; set to None to disable -``` - Then ensure the tool is included in the workflow tool list: ```yaml workflow: @@ -197,39 +188,42 @@ workflow: tool_names: [internet_search, current_datetime] ``` -### Using Perplexity Search +A sample configuration file using `tavily_internet_search` is located at `examples/documentation_guides/workflows/custom_workflow/search_config.yml` in the NeMo Agent Toolkit repository. -The `perplexity_internet_search` tool ships with the core `nvidia-nat` package and is framework-agnostic — it can be used with any of the agent frameworks supported by the toolkit (`langchain`, `llama_index`, `crewai`, `semantic_kernel`, `agno`, `adk`, `strands`, and `autogen`). 
No framework-specific extra is required to install it: +### Using Exa Search + +The `exa_internet_search` tool is also part of the `nvidia-nat[langchain]` package. If you haven't already installed it: ```bash # local package install from source -uv pip install -e . +uv pip install -e ".[langchain]" ``` -Prior to using the `perplexity_internet_search` tool, create a Perplexity account and obtain an API key from the [API key page](https://www.perplexity.ai/account/api/keys). Once obtained, set the `PERPLEXITY_API_KEY` environment variable to the API key: +Prior to using the `exa_internet_search` tool, create an account at [`exa.ai`](https://exa.ai/) and obtain an API key. Once obtained, set the `EXA_API_KEY` environment variable to the API key: ```bash -export PERPLEXITY_API_KEY= +export EXA_API_KEY= ``` -You can use the `perplexity_internet_search` tool in the same way as the other web search tools by updating the `functions` section of the configuration file: +You can use the `exa_internet_search` tool in the same way as the other web search tools by updating the `functions` section of the configuration file: ```yaml functions: internet_search: - _type: perplexity_internet_search + _type: exa_internet_search current_datetime: _type: current_datetime ``` -The `perplexity_internet_search` tool supports additional configuration options: +The `exa_internet_search` tool supports additional configuration options: ```yaml functions: internet_search: - _type: perplexity_internet_search + _type: exa_internet_search max_results: 5 + search_type: neural # 'auto', 'fast', 'deep', 'neural', or 'instant' + livecrawl: fallback # 'always', 'fallback', or 'never' max_retries: 3 max_query_length: 2000 # queries longer than this are truncated - search_recency_filter: week # 'hour', 'day', 'week', 'month', or 'year' - country: US # ISO 3166-1 alpha-2 country code - max_tokens_per_page: 4096 + highlights: true # include highlights in results + max_content_length: 10000 # max chars of 
text per result; set to None to disable ``` Then ensure the tool is included in the workflow tool list: From c19db753d8ee3c04ec98141b3f71a6056b69ec42 Mon Sep 17 00:00:00 2001 From: James Liounis Date: Mon, 11 May 2026 19:18:25 +0000 Subject: [PATCH 7/7] feat(embedder): add Perplexity Embeddings API provider Add a new `perplexity` embedder provider backed by Perplexity's dedicated Embeddings API (`POST https://api.perplexity.ai/v1/embeddings`). The provider ships in the core `nvidia-nat` package and is wired into both the LangChain and LlamaIndex framework plugins, mirroring the existing `openai` / `azure_openai` / `nim` embedders. Highlights: - `PerplexityEmbedderModelConfig` supports all four documented models (`pplx-embed-v1-0.6b`, `pplx-embed-v1-4b`, plus the contextualized `pplx-embed-context-v1-*` variants), Matryoshka `dimensions`, `batch_size`, and the two on-wire `encoding_format` options (`base64_int8`, `base64_binary`). Resolves the API key from config or `PERPLEXITY_API_KEY`. - A framework-agnostic `PerplexityEmbeddings` LangChain client batches inputs (up to 512/request), decodes the base64 payload locally to a float vector, retries transient HTTP failures with backoff, and forwards an `X-Pplx-Integration: nemo-agent-toolkit/{version}` attribution header on every request. - A `PerplexityLlamaIndexEmbedding` adapter wraps the shared client to satisfy LlamaIndex's `BaseEmbedding` interface so the same NAT config can be consumed by LlamaIndex-powered retrievers/RAG flows. - Docs: adds Perplexity to the embedder provider table and configuration examples in `docs/source/build-workflows/embedders.md`, and adds Perplexity to the LangChain/LlamaIndex embedder rows in `docs/source/components/integrations/frameworks.md`.
- Tests: 8 config tests + 11 LangChain client tests + 4 LlamaIndex adapter tests (23 total) covering config validation, batching, base64_int8/binary decoding, attribution header propagation, non-retriable status short-circuit, async paths, and registration with `PERPLEXITY_API_KEY` resolution. Reference: https://docs.perplexity.ai/api-reference/embeddings-post Signed-off-by: James Liounis --- docs/source/build-workflows/embedders.md | 29 ++ .../components/integrations/frameworks.md | 6 +- .../src/nat/embedder/perplexity_embedder.py | 98 +++++++ .../src/nat/embedder/register.py | 1 + .../nat/embedder/test_perplexity_embedder.py | 93 ++++++ .../src/nat/plugins/langchain/embedder.py | 43 +++ .../langchain/perplexity_embeddings_client.py | 267 ++++++++++++++++++ .../tests/test_perplexity_embeddings.py | 209 ++++++++++++++ .../src/nat/plugins/llama_index/embedder.py | 44 +++ .../perplexity_embeddings_client.py | 67 +++++ .../test_perplexity_embedder_llama_index.py | 75 +++++ 11 files changed, 929 insertions(+), 3 deletions(-) create mode 100644 packages/nvidia_nat_core/src/nat/embedder/perplexity_embedder.py create mode 100644 packages/nvidia_nat_core/tests/nat/embedder/test_perplexity_embedder.py create mode 100644 packages/nvidia_nat_langchain/src/nat/plugins/langchain/perplexity_embeddings_client.py create mode 100644 packages/nvidia_nat_langchain/tests/test_perplexity_embeddings.py create mode 100644 packages/nvidia_nat_llama_index/src/nat/plugins/llama_index/perplexity_embeddings_client.py create mode 100644 packages/nvidia_nat_llama_index/tests/test_perplexity_embedder_llama_index.py diff --git a/docs/source/build-workflows/embedders.md b/docs/source/build-workflows/embedders.md index ab114b68e7..e53522454c 100644 --- a/docs/source/build-workflows/embedders.md +++ b/docs/source/build-workflows/embedders.md @@ -25,6 +25,7 @@ NeMo Agent Toolkit supports the following embedder providers: | [OpenAI](https://openai.com) | `openai` | OpenAI API | | [Azure 
OpenAI](https://learn.microsoft.com/en-us/azure/ai-foundry/openai/quickstart) | `azure_openai` | Azure OpenAI API | | [Hugging Face](https://huggingface.co) | `huggingface` | Local sentence-transformers or remote Inference Endpoints (TEI) | +| [Perplexity](https://docs.perplexity.ai/api-reference/embeddings-post) | `perplexity` | Perplexity Embeddings API (`pplx-embed-v1-0.6b`, `pplx-embed-v1-4b`, contextualized variants) | ## Embedder Configuration @@ -41,6 +42,9 @@ embedders: azure_openai_embedder: _type: azure_openai azure_deployment: text-embedding-3-small + perplexity_embedder: + _type: perplexity + model_name: pplx-embed-v1-0.6b ``` ### NVIDIA NIM @@ -120,3 +124,28 @@ embedders: endpoint_url: http://localhost:8081 api_key: ${HF_TOKEN} ``` + +### Perplexity + +Perplexity exposes a dedicated embeddings endpoint at `POST https://api.perplexity.ai/v1/embeddings`. The toolkit ships a native client that batches inputs, decodes the on-wire base64 payload locally, and forwards an `X-Pplx-Integration: nemo-agent-toolkit/<version>` attribution header on every request. + +You can use the following environment variables to configure the Perplexity embedder provider: + +* `PERPLEXITY_API_KEY` - The API key to access the Perplexity Embeddings API + +The Perplexity embedder provider is defined by the {py:class}`~nat.embedder.perplexity_embedder.PerplexityEmbedderModelConfig` class. + +* `model_name` - Embedding model identifier. Standard embeddings: `pplx-embed-v1-0.6b` (1024-dim, default) or `pplx-embed-v1-4b` (2560-dim). Document-aware: `pplx-embed-context-v1-0.6b` or `pplx-embed-context-v1-4b` +* `api_key` - Perplexity API key (falls back to `PERPLEXITY_API_KEY`) +* `base_url` - Base URL for the Perplexity API (default: `https://api.perplexity.ai/v1`) +* `dimensions` - Optional Matryoshka output dimension. Range 128-1024 for `0.6b` models and 128-2560 for `4b` models.
Omit for full dimensions +* `encoding_format` - On-wire encoding: `base64_int8` (default, signed int8) or `base64_binary` (packed bits for large-scale retrieval) +* `batch_size` - Maximum inputs per request (1-512). Defaults to `64` + +```yaml +embedders: + perplexity_embedder: + _type: perplexity + model_name: pplx-embed-v1-0.6b + dimensions: 512 # optional Matryoshka truncation +``` diff --git a/docs/source/components/integrations/frameworks.md b/docs/source/components/integrations/frameworks.md index 536da5f03f..c73938b859 100644 --- a/docs/source/components/integrations/frameworks.md +++ b/docs/source/components/integrations/frameworks.md @@ -40,7 +40,7 @@ NeMo Agent Toolkit provides different levels of support for each framework acros The ability to use various large language model providers with a framework, including NVIDIA NIM, OpenAI, Azure OpenAI, AWS Bedrock, LiteLLM, and Hugging Face. ### Embedder Provider Support -The ability to use embedding model providers for vector representations, including NVIDIA NIM embeddings, OpenAI embeddings, and Azure OpenAI embeddings. +The ability to use embedding model providers for vector representations, including NVIDIA NIM embeddings, OpenAI embeddings, Azure OpenAI embeddings, and Perplexity embeddings. ### Retriever Provider Support The ability to integrate with vector databases and retrieval systems, such as NeMo Retriever and Milvus. 
@@ -154,7 +154,7 @@ For more information, visit the [LangChain documentation](https://docs.langchain | Capability | Providers / Details | |-------------------------|-------------------------------------------------------------------------------------| | **LLM Providers** | NVIDIA NIM, OpenAI, Azure OpenAI, AWS Bedrock, LiteLLM, Hugging Face | -| **Embedder Providers** | NVIDIA NIM, OpenAI, Azure OpenAI | +| **Embedder Providers** | NVIDIA NIM, OpenAI, Azure OpenAI, Perplexity | | **Retriever Providers** | NeMo Retriever, Milvus | | **Tool Calling** | Fully supported through LangChain's `StructuredTool` interface | | **Profiling** | Comprehensive profiling support with callback handlers | @@ -174,7 +174,7 @@ For more information, visit the [LlamaIndex website](https://www.llamaindex.ai/) | Capability | Providers / Details | |-------------------------|-------------------------------------------------------------------------------------| | **LLM Providers** | NVIDIA NIM, OpenAI, Azure OpenAI, AWS Bedrock, LiteLLM | -| **Embedder Providers** | NVIDIA NIM, OpenAI, Azure OpenAI | +| **Embedder Providers** | NVIDIA NIM, OpenAI, Azure OpenAI, Perplexity | | **Retriever Providers** | None (Use LlamaIndex native retrievers) | | **Tool Calling** | Fully supported through LlamaIndex's `FunctionTool` interface | | **Profiling** | Comprehensive profiling support with callback handlers | diff --git a/packages/nvidia_nat_core/src/nat/embedder/perplexity_embedder.py b/packages/nvidia_nat_core/src/nat/embedder/perplexity_embedder.py new file mode 100644 index 0000000000..4f33c0df2d --- /dev/null +++ b/packages/nvidia_nat_core/src/nat/embedder/perplexity_embedder.py @@ -0,0 +1,98 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import typing + +from pydantic import AliasChoices +from pydantic import ConfigDict +from pydantic import Field + +from nat.builder.builder import Builder +from nat.builder.embedder import EmbedderProviderInfo +from nat.cli.register_workflow import register_embedder_provider +from nat.data_models.common import OptionalSecretStr +from nat.data_models.embedder import EmbedderBaseConfig +from nat.data_models.retry_mixin import RetryMixin +from nat.data_models.ssl_verification_mixin import SSLVerificationMixin + +# Supported model identifiers for the Perplexity Embeddings API. +# Standard embeddings are for independent texts (queries/documents). +# Contextualized embeddings are document-aware; chunks from the same document share context. +PerplexityEmbeddingModel = typing.Literal[ + "pplx-embed-v1-0.6b", + "pplx-embed-v1-4b", + "pplx-embed-context-v1-0.6b", + "pplx-embed-context-v1-4b", +] + + +class PerplexityEmbedderModelConfig(EmbedderBaseConfig, RetryMixin, SSLVerificationMixin, name="perplexity"): + """A Perplexity Embeddings API provider to be used with an embedder client. + + Perplexity exposes a dedicated embeddings endpoint at + ``https://api.perplexity.ai/v1/embeddings`` (standard) and + ``https://api.perplexity.ai/v1/contextualizedembeddings`` (contextualized). + Authentication uses ``PERPLEXITY_API_KEY``. 
+ + Reference: https://docs.perplexity.ai/api-reference/embeddings-post + """ + + model_config = ConfigDict(protected_namespaces=(), extra="allow") + + api_key: OptionalSecretStr = Field( + default=None, + description="Perplexity API key to interact with the embeddings endpoint. " + "Falls back to the ``PERPLEXITY_API_KEY`` environment variable when unset.", + ) + base_url: str = Field( + default="https://api.perplexity.ai/v1", + description="Base URL for the Perplexity API. The embedder appends ``/embeddings`` " + "for standard models and ``/contextualizedembeddings`` for context models.", + ) + model_name: PerplexityEmbeddingModel = Field( + default="pplx-embed-v1-0.6b", + validation_alias=AliasChoices("model_name", "model"), + serialization_alias="model", + description="Perplexity embedding model. Standard: ``pplx-embed-v1-0.6b`` (1024-dim) " + "or ``pplx-embed-v1-4b`` (2560-dim). Contextualized: ``pplx-embed-context-v1-0.6b`` " + "or ``pplx-embed-context-v1-4b``.", + ) + dimensions: int | None = Field( + default=None, + ge=128, + le=2560, + description="Matryoshka output dimensions. Range is 128–1024 for ``0.6b`` models and " + "128–2560 for ``4b`` models. Defaults to full dimensions when unset.", + ) + batch_size: int = Field( + default=64, + ge=1, + le=512, + description="Maximum number of input texts to send per request. The Perplexity API " + "accepts up to 512 inputs per call (subject to a 120,000 total-token cap).", + ) + encoding_format: typing.Literal["base64_int8", "base64_binary"] = Field( + default="base64_int8", + description="On-wire encoding for the embedding payload. 
``base64_int8`` (default) " + "returns signed int8 values; ``base64_binary`` returns 1-bit-per-dimension packed bits.", + ) + + +@register_embedder_provider(config_type=PerplexityEmbedderModelConfig) +async def perplexity_embedder_model(config: PerplexityEmbedderModelConfig, _builder: Builder): + yield EmbedderProviderInfo( + config=config, + description="A Perplexity Embeddings API model for use with an Embedder client.", + ) diff --git a/packages/nvidia_nat_core/src/nat/embedder/register.py b/packages/nvidia_nat_core/src/nat/embedder/register.py index ac82ccb1d7..ff21dbe349 100644 --- a/packages/nvidia_nat_core/src/nat/embedder/register.py +++ b/packages/nvidia_nat_core/src/nat/embedder/register.py @@ -21,3 +21,4 @@ from . import huggingface_embedder from . import nim_embedder from . import openai_embedder +from . import perplexity_embedder diff --git a/packages/nvidia_nat_core/tests/nat/embedder/test_perplexity_embedder.py b/packages/nvidia_nat_core/tests/nat/embedder/test_perplexity_embedder.py new file mode 100644 index 0000000000..112315a5e3 --- /dev/null +++ b/packages/nvidia_nat_core/tests/nat/embedder/test_perplexity_embedder.py @@ -0,0 +1,93 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest +from pydantic import SecretStr +from pydantic import ValidationError + +from nat.embedder.perplexity_embedder import PerplexityEmbedderModelConfig + + +def test_defaults(): + """Default config picks the small standard model and full dimensions.""" + config = PerplexityEmbedderModelConfig() + assert config.type == "perplexity" + assert config.model_name == "pplx-embed-v1-0.6b" + assert config.base_url == "https://api.perplexity.ai/v1" + assert config.dimensions is None + assert config.batch_size == 64 + assert config.encoding_format == "base64_int8" + + +def test_accepts_supported_models(): + """All four model identifiers documented by the Perplexity API are accepted.""" + for model in ( + "pplx-embed-v1-0.6b", + "pplx-embed-v1-4b", + "pplx-embed-context-v1-0.6b", + "pplx-embed-context-v1-4b", + ): + cfg = PerplexityEmbedderModelConfig(model_name=model) + assert cfg.model_name == model + + +def test_rejects_unsupported_model(): + """Unsupported model identifiers raise a validation error.""" + with pytest.raises(ValidationError): + PerplexityEmbedderModelConfig(model_name="text-embedding-3-small") + + +def test_dimensions_bounds(): + """Matryoshka ``dimensions`` is bounded to the documented range [128, 2560].""" + PerplexityEmbedderModelConfig(model_name="pplx-embed-v1-4b", dimensions=128) + PerplexityEmbedderModelConfig(model_name="pplx-embed-v1-4b", dimensions=2560) + with pytest.raises(ValidationError): + PerplexityEmbedderModelConfig(dimensions=64) + with pytest.raises(ValidationError): + PerplexityEmbedderModelConfig(dimensions=4096) + + +def test_batch_size_bounds(): + """``batch_size`` is bounded by Perplexity's documented 512-input-per-request cap.""" + PerplexityEmbedderModelConfig(batch_size=1) + PerplexityEmbedderModelConfig(batch_size=512) + with pytest.raises(ValidationError): + PerplexityEmbedderModelConfig(batch_size=0) + with pytest.raises(ValidationError): + PerplexityEmbedderModelConfig(batch_size=1024) + + +def 
test_encoding_format_choices(): + """Only the two on-wire encoding formats supported by Perplexity are accepted.""" + PerplexityEmbedderModelConfig(encoding_format="base64_int8") + PerplexityEmbedderModelConfig(encoding_format="base64_binary") + with pytest.raises(ValidationError): + # ``float`` is *not* supported by Perplexity's embeddings endpoint. + PerplexityEmbedderModelConfig(encoding_format="float") + + +def test_api_key_secret_str(): + """``api_key`` is stored as a SecretStr-style field, not a plain string.""" + cfg = PerplexityEmbedderModelConfig(api_key="pplx-test-key") + assert cfg.api_key is not None + # Pydantic ``SecretStr`` masks the value in ``repr`` / ``str``. + assert "pplx-test-key" not in repr(cfg.api_key) + assert isinstance(cfg.api_key, SecretStr) + + +def test_model_alias_accepted(): + """The ``model`` alias is accepted in addition to ``model_name`` and round-trips.""" + cfg = PerplexityEmbedderModelConfig(model="pplx-embed-v1-4b") + assert cfg.model_name == "pplx-embed-v1-4b" diff --git a/packages/nvidia_nat_langchain/src/nat/plugins/langchain/embedder.py b/packages/nvidia_nat_langchain/src/nat/plugins/langchain/embedder.py index 03b99e26b3..4b135f0ffe 100644 --- a/packages/nvidia_nat_langchain/src/nat/plugins/langchain/embedder.py +++ b/packages/nvidia_nat_langchain/src/nat/plugins/langchain/embedder.py @@ -25,6 +25,7 @@ from nat.embedder.huggingface_embedder import HuggingFaceEmbedderConfig from nat.embedder.nim_embedder import NIMEmbedderModelConfig from nat.embedder.openai_embedder import OpenAIEmbedderModelConfig +from nat.embedder.perplexity_embedder import PerplexityEmbedderModelConfig from nat.llm.utils.http_client import http_clients from nat.utils.exception_handlers.automatic_retries import patch_with_retry @@ -96,6 +97,48 @@ async def openai_langchain(embedder_config: OpenAIEmbedderModelConfig, builder: yield client +@register_embedder_client(config_type=PerplexityEmbedderModelConfig, wrapper_type=LLMFrameworkEnum.LANGCHAIN) 
+async def perplexity_langchain(embedder_config: PerplexityEmbedderModelConfig, _builder: Builder):
+    """LangChain client for the Perplexity Embeddings API.
+
+    Resolves the API key from the config or ``PERPLEXITY_API_KEY`` and yields a
+    :class:`~nat.plugins.langchain.perplexity_embeddings_client.PerplexityEmbeddings`
+    that sends ``X-Pplx-Integration: nemo-agent-toolkit/<version>`` on each request.
+    NOTE(review): ``num_retries`` feeds both the client's own retry loop and
+    ``patch_with_retry``, so the two layers appear to multiply -- confirm intent.
+    """
+    import os
+
+    from nat.plugins.langchain.perplexity_embeddings_client import PerplexityEmbeddings
+
+    raw_key = get_secret_value(embedder_config.api_key) if embedder_config.api_key else ""
+    resolved_key = raw_key or os.environ.get("PERPLEXITY_API_KEY", "")
+    if not resolved_key:
+        raise ValueError(
+            "Perplexity embedder requires a non-empty API key. Set ``api_key`` in the "
+            "workflow config or export ``PERPLEXITY_API_KEY``."
+        )
+
+    client = PerplexityEmbeddings(
+        api_key=resolved_key,
+        base_url=embedder_config.base_url,
+        model=embedder_config.model_name,
+        dimensions=embedder_config.dimensions,
+        encoding_format=embedder_config.encoding_format,
+        batch_size=embedder_config.batch_size,
+        max_retries=getattr(embedder_config, "num_retries", 3),
+        verify_ssl=getattr(embedder_config, "verify_ssl", True),
+    )
+
+    if isinstance(embedder_config, RetryMixin):
+        client = patch_with_retry(client,
+                                  retries=embedder_config.num_retries,
+                                  retry_codes=embedder_config.retry_on_status_codes,
+                                  retry_on_messages=embedder_config.retry_on_errors)
+
+    yield client
+
+
 @register_embedder_client(config_type=HuggingFaceEmbedderConfig, wrapper_type=LLMFrameworkEnum.LANGCHAIN)
 async def huggingface_langchain(embedder_config: HuggingFaceEmbedderConfig, _builder: Builder) -> AsyncIterator[Any]:
     """LangChain client for HuggingFace embedder - local or remote based on endpoint_url."""
diff --git a/packages/nvidia_nat_langchain/src/nat/plugins/langchain/perplexity_embeddings_client.py
b/packages/nvidia_nat_langchain/src/nat/plugins/langchain/perplexity_embeddings_client.py
new file mode 100644
index 0000000000..7e51d85f99
--- /dev/null
+++ b/packages/nvidia_nat_langchain/src/nat/plugins/langchain/perplexity_embeddings_client.py
@@ -0,0 +1,267 @@
+# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""LangChain ``Embeddings`` client for the Perplexity Embeddings API.
+
+Perplexity's embeddings endpoint (``POST /v1/embeddings``) returns base64-encoded
+quantized values (``base64_int8`` or ``base64_binary``) rather than JSON float
+arrays, so this module provides a thin LangChain
+:class:`~langchain_core.embeddings.Embeddings` implementation that decodes them.
+NOTE(review): every request goes to ``/embeddings``; the provider config also
+describes ``/contextualizedembeddings`` for context models -- confirm routing.
+"""
+from __future__ import annotations
+
+import base64
+import logging
+from collections.abc import Iterable
+
+import httpx
+import numpy as np
+from langchain_core.embeddings import Embeddings
+
+logger = logging.getLogger(__name__)
+
+
+def _decode_int8(payload: str) -> list[float]:
+    """Decode a ``base64_int8`` embedding into a float32 list.
+
+    Args:
+        payload: Base64-encoded signed int8 buffer returned by Perplexity.
+
+    Returns:
+        The decoded vector as a list of Python floats.
+    """
+    return np.frombuffer(base64.b64decode(payload), dtype=np.int8).astype(np.float32).tolist()
+
+
+def _decode_binary(payload: str) -> list[float]:
+    """Decode a ``base64_binary`` embedding into a 0/1 float list.
+
+    Args:
+        payload: Base64-encoded packed bits (LSB first, 1 bit per dimension).
+
+    Returns:
+        The unpacked bit vector as a list of Python floats (0.0 or 1.0).
+    """
+    packed = np.frombuffer(base64.b64decode(payload), dtype=np.uint8)
+    return np.unpackbits(packed, bitorder="little").astype(np.float32).tolist()
+
+
+class PerplexityEmbeddings(Embeddings):
+    """LangChain ``Embeddings`` client for the Perplexity Embeddings API.
+
+    The client batches inputs (``batch_size`` per request, default 64), decodes the
+    base64 payload locally, and forwards the ``X-Pplx-Integration`` attribution
+    header so Perplexity can identify NeMo Agent Toolkit traffic.
+
+    Args:
+        api_key: Perplexity API key. Required.
+        base_url: Base URL for the Perplexity API. Defaults to ``https://api.perplexity.ai/v1``.
+        model: Embedding model identifier (e.g. ``pplx-embed-v1-0.6b``).
+        dimensions: Optional Matryoshka output dimension.
+        encoding_format: ``base64_int8`` (default) or ``base64_binary``.
+        batch_size: Maximum inputs per request (1–512). Defaults to 64.
+        max_retries: Number of retry attempts on transient failures. Defaults to 3.
+        verify_ssl: Whether to verify TLS certificates. Defaults to True.
+        integration_header: Optional ``X-Pplx-Integration`` header value. Defaults
+            to ``"nemo-agent-toolkit/<version>"`` resolved at instantiation time.
+    """
+
+    def __init__(
+        self,
+        api_key: str,
+        *,
+        base_url: str = "https://api.perplexity.ai/v1",
+        model: str = "pplx-embed-v1-0.6b",
+        dimensions: int | None = None,
+        encoding_format: str = "base64_int8",
+        batch_size: int = 64,
+        max_retries: int = 3,
+        verify_ssl: bool = True,
+        integration_header: str | None = None,
+    ) -> None:
+        if not api_key:
+            raise ValueError(
+                "PerplexityEmbeddings requires a non-empty api_key. "
+                "Set the PERPLEXITY_API_KEY environment variable or pass api_key explicitly."
+            )
+        if encoding_format not in ("base64_int8", "base64_binary"):
+            raise ValueError(
+                f"encoding_format must be 'base64_int8' or 'base64_binary', got {encoding_format!r}."
+            )
+
+        self._api_key = api_key
+        self._base_url = base_url.rstrip("/")
+        self._model = model
+        self._dimensions = dimensions
+        self._encoding_format = encoding_format
+        self._batch_size = max(1, min(int(batch_size), 512))
+        self._max_retries = max(1, int(max_retries))
+        self._verify_ssl = verify_ssl
+        self._integration_header = integration_header or _default_integration_header()
+
+    # ------------------------------------------------------------------
+    # LangChain interface
+    # ------------------------------------------------------------------
+    def embed_documents(self, texts: list[str]) -> list[list[float]]:  # type: ignore[override]
+        return self._embed(texts)
+
+    def embed_query(self, text: str) -> list[float]:  # type: ignore[override]
+        results = self._embed([text])
+        return results[0]
+
+    async def aembed_documents(self, texts: list[str]) -> list[list[float]]:  # type: ignore[override]
+        return await self._aembed(texts)
+
+    async def aembed_query(self, text: str) -> list[float]:  # type: ignore[override]
+        results = await self._aembed([text])
+        return results[0]
+
+    # ------------------------------------------------------------------
+    # Internal helpers
+    # ------------------------------------------------------------------
+    def _endpoint(self) -> str:
+        return f"{self._base_url}/embeddings"
+
+    def _headers(self) -> dict[str, str]:
+        return {
+            "Authorization": f"Bearer {self._api_key}",
+            "Content-Type": "application/json",
+            "X-Pplx-Integration": self._integration_header,
+        }
+
+    def _build_body(self, inputs: list[str]) -> dict:
+        body: dict = {
+            "input": inputs,
+            "model": self._model,
+            "encoding_format": self._encoding_format,
+        }
+        if self._dimensions is not None:
+            body["dimensions"] = self._dimensions
+        return body
+
+    def _decode_response(self, payload: dict) -> list[list[float]]:
+        """Decode one response payload into float vectors in input order."""
+        decoder = _decode_int8 if self._encoding_format == "base64_int8" else _decode_binary
+        items: list[dict] = list(payload.get("data") or [])
+        # Restore input order from the ``index`` each item echoes back; the sort is
+        # stable, so when no item carries an ``index`` the arrival order is kept.
+        items.sort(key=lambda item: item.get("index") or 0)
+        return [decoder(item.get("embedding", "")) for item in items]
+
+    def _batches(self, texts: list[str]) -> Iterable[list[str]]:
+        for start in range(0, len(texts), self._batch_size):
+            yield texts[start:start + self._batch_size]
+
+    def _embed(self, texts: list[str]) -> list[list[float]]:
+        if not texts:
+            return []
+        embeddings: list[list[float]] = []
+        with httpx.Client(verify=self._verify_ssl, timeout=60.0) as client:
+            for batch in self._batches(texts):
+                payload = self._post_with_retry_sync(client, batch)
+                embeddings.extend(self._decode_response(payload))
+        return embeddings
+
+    async def _aembed(self, texts: list[str]) -> list[list[float]]:
+        if not texts:
+            return []
+        embeddings: list[list[float]] = []
+        async with httpx.AsyncClient(verify=self._verify_ssl, timeout=60.0) as client:
+            for batch in self._batches(texts):
+                payload = await self._post_with_retry_async(client, batch)
+                embeddings.extend(self._decode_response(payload))
+        return embeddings
+
+    def _post_with_retry_sync(self, client: httpx.Client, batch: list[str]) -> dict:
+        import time
+
+        last_error: Exception | None = None
+        for attempt in range(self._max_retries):
+            try:
+                response = client.post(self._endpoint(), headers=self._headers(), json=self._build_body(batch))
+                # ``raise_for_status`` raises ``HTTPStatusError`` on every 4xx/5xx; the
+                # handler below decides which of those statuses are worth retrying.
+                response.raise_for_status()
+                return response.json()
+            except httpx.HTTPStatusError as exc:
+                last_error = exc
+                if exc.response.status_code in {401, 403, 404} or attempt == self._max_retries - 1:
+                    raise
+                logger.warning(
+                    "Perplexity embeddings attempt %d/%d failed with status %d",
+                    attempt + 1,
+                    self._max_retries,
+                    exc.response.status_code,
+                )
+            except httpx.RequestError as exc:
+                last_error = exc
+                if attempt == self._max_retries - 1:
+                    raise
+                logger.warning(
+                    "Perplexity embeddings attempt %d/%d failed: %s",
+                    attempt + 1,
+                    self._max_retries,
+                    exc,
+                )
+            time.sleep(2**attempt)
+        raise RuntimeError("Perplexity embeddings request failed") from last_error
+
+    async def _post_with_retry_async(self, client: httpx.AsyncClient, batch: list[str]) -> dict:
+        import asyncio
+
+        last_error: Exception | None = None
+        for attempt in range(self._max_retries):
+            try:
+                response = await client.post(
+                    self._endpoint(), headers=self._headers(), json=self._build_body(batch)
+                )
+                # ``raise_for_status`` raises ``HTTPStatusError`` on every 4xx/5xx; the
+                # handler below decides which of those statuses are worth retrying.
+                response.raise_for_status()
+                return response.json()
+            except httpx.HTTPStatusError as exc:
+                last_error = exc
+                if exc.response.status_code in {401, 403, 404} or attempt == self._max_retries - 1:
+                    raise
+                logger.warning(
+                    "Perplexity embeddings attempt %d/%d failed with status %d",
+                    attempt + 1,
+                    self._max_retries,
+                    exc.response.status_code,
+                )
+            except httpx.RequestError as exc:
+                last_error = exc
+                if attempt == self._max_retries - 1:
+                    raise
+                logger.warning(
+                    "Perplexity embeddings attempt %d/%d failed: %s",
+                    attempt + 1,
+                    self._max_retries,
+                    exc,
+                )
+            await asyncio.sleep(2**attempt)
+        raise RuntimeError("Perplexity embeddings request failed") from last_error
+
+
+def _default_integration_header() -> str:
+    """Return the default ``X-Pplx-Integration`` header value for outbound requests."""
+    from importlib import metadata
+
+    try:
+        package_version = metadata.version("nvidia-nat")
+    except metadata.PackageNotFoundError:
+        package_version = "unknown"
+    return f"nemo-agent-toolkit/{package_version}"
diff --git a/packages/nvidia_nat_langchain/tests/test_perplexity_embeddings.py
b/packages/nvidia_nat_langchain/tests/test_perplexity_embeddings.py new file mode 100644 index 0000000000..6031eab70e --- /dev/null +++ b/packages/nvidia_nat_langchain/tests/test_perplexity_embeddings.py @@ -0,0 +1,209 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import base64 +from unittest.mock import AsyncMock +from unittest.mock import MagicMock +from unittest.mock import patch + +import httpx +import numpy as np +import pytest + +from nat.embedder.perplexity_embedder import PerplexityEmbedderModelConfig +from nat.plugins.langchain.embedder import perplexity_langchain +from nat.plugins.langchain.perplexity_embeddings_client import PerplexityEmbeddings +from nat.plugins.langchain.perplexity_embeddings_client import _decode_binary +from nat.plugins.langchain.perplexity_embeddings_client import _decode_int8 +from nat.plugins.langchain.perplexity_embeddings_client import _default_integration_header + + +# --------------------------------------------------------------------------- +# Decoding helpers +# --------------------------------------------------------------------------- + + +def _encode_int8(values: list[int]) -> str: + return base64.b64encode(np.array(values, dtype=np.int8).tobytes()).decode() + + +def _make_response(status_code: int, payload: dict) -> httpx.Response: + """Build an ``httpx.Response`` with a bound 
dummy request so ``raise_for_status`` works in tests.""" + return httpx.Response( + status_code, + json=payload, + request=httpx.Request("POST", "https://api.perplexity.ai/v1/embeddings"), + ) + + +def _encode_binary(bits: list[int]) -> str: + packed = np.packbits(np.array(bits, dtype=np.uint8), bitorder="little") + return base64.b64encode(packed.tobytes()).decode() + + +def test_decode_int8_round_trip(): + payload = _encode_int8([-128, -1, 0, 1, 127]) + decoded = _decode_int8(payload) + assert decoded == [-128.0, -1.0, 0.0, 1.0, 127.0] + + +def test_decode_binary_round_trip(): + bits = [1, 0, 1, 1, 0, 0, 1, 0] + payload = _encode_binary(bits) + decoded = _decode_binary(payload) + assert decoded == [float(b) for b in bits] + + +def test_default_integration_header_uses_nemo_slug(): + header = _default_integration_header() + assert header.startswith("nemo-agent-toolkit/") + + +# --------------------------------------------------------------------------- +# PerplexityEmbeddings client +# --------------------------------------------------------------------------- + + +class TestPerplexityEmbeddingsClient: + + def test_requires_api_key(self): + with pytest.raises(ValueError): + PerplexityEmbeddings(api_key="") + + def test_rejects_unsupported_encoding(self): + with pytest.raises(ValueError): + PerplexityEmbeddings(api_key="pplx-x", encoding_format="float") + + def test_batches_inputs_and_decodes_int8(self): + """The client should split inputs into batches of ``batch_size`` and decode int8 payloads.""" + client = PerplexityEmbeddings(api_key="pplx-x", batch_size=2) + + responses = [ + _make_response( + 200, + { + "object": "list", + "data": [ + {"object": "embedding", "index": 0, "embedding": _encode_int8([1, 2])}, + {"object": "embedding", "index": 1, "embedding": _encode_int8([3, 4])}, + ], + "model": "pplx-embed-v1-0.6b", + }, + ), + _make_response( + 200, + { + "object": "list", + "data": [{"object": "embedding", "index": 0, "embedding": _encode_int8([5, 6])}], + 
"model": "pplx-embed-v1-0.6b", + }, + ), + ] + + with patch("httpx.Client") as mock_client_cls: + instance = MagicMock() + instance.__enter__.return_value = instance + instance.post.side_effect = responses + mock_client_cls.return_value = instance + + vectors = client.embed_documents(["a", "b", "c"]) + + assert vectors == [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]] + # Two batches of size 2 (last one is size 1). + assert instance.post.call_count == 2 + + def test_request_body_includes_dimensions_and_attribution_header(self): + client = PerplexityEmbeddings(api_key="pplx-x", dimensions=256, integration_header="test-suite/1.2.3") + + with patch("httpx.Client") as mock_client_cls: + instance = MagicMock() + instance.__enter__.return_value = instance + instance.post.return_value = _make_response( + 200, + { + "object": "list", + "data": [{"object": "embedding", "index": 0, "embedding": _encode_int8([0])}], + "model": "pplx-embed-v1-0.6b", + }, + ) + mock_client_cls.return_value = instance + + client.embed_query("hello") + + call = instance.post.call_args + assert call.kwargs["json"]["dimensions"] == 256 + assert call.kwargs["json"]["encoding_format"] == "base64_int8" + assert call.kwargs["json"]["model"] == "pplx-embed-v1-0.6b" + assert call.kwargs["headers"]["X-Pplx-Integration"] == "test-suite/1.2.3" + assert call.kwargs["headers"]["Authorization"] == "Bearer pplx-x" + assert call.args[0].endswith("/v1/embeddings") + + def test_non_retriable_status_raises(self): + client = PerplexityEmbeddings(api_key="pplx-x", max_retries=3) + with patch("httpx.Client") as mock_client_cls: + instance = MagicMock() + instance.__enter__.return_value = instance + instance.post.return_value = _make_response(401, {"error": "unauthorized"}) + mock_client_cls.return_value = instance + + with pytest.raises(httpx.HTTPStatusError): + client.embed_query("hello") + + # 401 short-circuits — exactly one POST attempt. 
+ assert instance.post.call_count == 1 + + async def test_async_embed_query(self): + client = PerplexityEmbeddings(api_key="pplx-x") + + async_response = _make_response( + 200, + { + "object": "list", + "data": [{"object": "embedding", "index": 0, "embedding": _encode_int8([7, 8, 9])}], + "model": "pplx-embed-v1-0.6b", + }, + ) + with patch("httpx.AsyncClient") as mock_client_cls: + instance = MagicMock() + instance.__aenter__ = AsyncMock(return_value=instance) + instance.__aexit__ = AsyncMock(return_value=None) + instance.post = AsyncMock(return_value=async_response) + mock_client_cls.return_value = instance + + vector = await client.aembed_query("hi") + + assert vector == [7.0, 8.0, 9.0] + + +# --------------------------------------------------------------------------- +# perplexity_langchain registration +# --------------------------------------------------------------------------- + + +class TestPerplexityLangChainRegistration: + + async def test_requires_api_key_in_env_or_config(self, monkeypatch, mock_builder): + monkeypatch.delenv("PERPLEXITY_API_KEY", raising=False) + cfg = PerplexityEmbedderModelConfig() + with pytest.raises(ValueError, match="non-empty API key"): + async with perplexity_langchain(cfg, mock_builder): + pass + + async def test_uses_environment_api_key(self, monkeypatch, mock_builder): + monkeypatch.setenv("PERPLEXITY_API_KEY", "env-key") + cfg = PerplexityEmbedderModelConfig() + async with perplexity_langchain(cfg, mock_builder) as client: + assert hasattr(client, "embed_documents") + assert hasattr(client, "embed_query") diff --git a/packages/nvidia_nat_llama_index/src/nat/plugins/llama_index/embedder.py b/packages/nvidia_nat_llama_index/src/nat/plugins/llama_index/embedder.py index b15c9ce017..f1cd07c756 100644 --- a/packages/nvidia_nat_llama_index/src/nat/plugins/llama_index/embedder.py +++ b/packages/nvidia_nat_llama_index/src/nat/plugins/llama_index/embedder.py @@ -17,9 +17,11 @@ from nat.builder.framework_enum import LLMFrameworkEnum 
@register_embedder_client(config_type=PerplexityEmbedderModelConfig, wrapper_type=LLMFrameworkEnum.LLAMA_INDEX)
async def perplexity_llama_index(embedder_config: PerplexityEmbedderModelConfig, _builder: Builder):
    """Yield a LlamaIndex embedding client for the Perplexity Embeddings API.

    The API key is taken from ``embedder_config.api_key`` when that is set, and
    otherwise from the ``PERPLEXITY_API_KEY`` environment variable. A
    ``PerplexityEmbeddings`` HTTP client is built from the config and wrapped in
    a ``PerplexityLlamaIndexEmbedding`` adapter. When the config is a
    ``RetryMixin`` instance, the adapter is additionally passed through
    ``patch_with_retry`` before being yielded.

    NOTE(review): the previous docstring stated that outbound requests carry the
    ``X-Pplx-Integration: nemo-agent-toolkit/`` attribution header; that happens
    inside ``PerplexityEmbeddings`` (not visible here) - confirm against the client.

    Args:
        embedder_config: Perplexity embedder settings (key, base URL, model, ...).
        _builder: Workflow builder; not used by this function.

    Yields:
        The (optionally retry-patched) ``PerplexityLlamaIndexEmbedding`` instance.

    Raises:
        ValueError: If neither the config nor the environment provides a non-empty key.
    """
    import os

    from nat.plugins.langchain.perplexity_embeddings_client import PerplexityEmbeddings
    from nat.plugins.llama_index.perplexity_embeddings_client import PerplexityLlamaIndexEmbedding

    # A key set in the config wins; the environment variable is only a fallback.
    api_key = get_secret_value(embedder_config.api_key) if embedder_config.api_key else ""
    if not api_key:
        api_key = os.environ.get("PERPLEXITY_API_KEY", "")
    if not api_key:
        raise ValueError("Perplexity embedder requires a non-empty API key. Set ``api_key`` in the "
                         "workflow config or export ``PERPLEXITY_API_KEY``.")

    http_client_kwargs = {
        "api_key": api_key,
        "base_url": embedder_config.base_url,
        "model": embedder_config.model_name,
        "dimensions": embedder_config.dimensions,
        "encoding_format": embedder_config.encoding_format,
        "batch_size": embedder_config.batch_size,
        # These two attributes are read defensively, with fallbacks if the config lacks them.
        "max_retries": getattr(embedder_config, "num_retries", 3),
        "verify_ssl": getattr(embedder_config, "verify_ssl", True),
    }
    adapter = PerplexityLlamaIndexEmbedding(client=PerplexityEmbeddings(**http_client_kwargs))

    if isinstance(embedder_config, RetryMixin):
        adapter = patch_with_retry(adapter,
                                   retries=embedder_config.num_retries,
                                   retry_codes=embedder_config.retry_on_status_codes,
                                   retry_on_messages=embedder_config.retry_on_errors)

    yield adapter
class PerplexityLlamaIndexEmbedding(BaseEmbedding):
    """Expose a ``PerplexityEmbeddings`` client through LlamaIndex's ``BaseEmbedding`` hooks.

    Every hook is a thin delegation to the wrapped client: single texts and
    queries go through ``embed_query`` / ``aembed_query``, batches of texts go
    through ``embed_documents`` / ``aembed_documents``.
    """

    # Wrapped HTTP client; kept as a private attribute so it is not a model field.
    _client: PerplexityEmbeddings = PrivateAttr()

    def __init__(self, *, client: PerplexityEmbeddings, **data: Any) -> None:
        """Store ``client`` and default ``model_name`` to the client's model when not given.

        Args:
            client: The Perplexity embeddings client every call is forwarded to.
            **data: Extra keyword arguments passed on to ``BaseEmbedding``.
        """
        # Falls back to a fixed model id when the client has no ``_model`` attribute.
        fallback_model = getattr(client, "_model", "pplx-embed-v1-0.6b")
        init_kwargs = {"model_name": fallback_model, **data}
        super().__init__(**init_kwargs)
        self._client = client

    @classmethod
    def class_name(cls) -> str:
        """Return the identifier string for this embedding class."""
        return "PerplexityLlamaIndexEmbedding"

    # ------------------------------------------------------------------
    # Query embeddings (sync / async)
    # ------------------------------------------------------------------
    def _get_query_embedding(self, query: str) -> list[float]:
        """Embed one query string via the client's ``embed_query``."""
        vector = self._client.embed_query(query)
        return vector

    async def _aget_query_embedding(self, query: str) -> list[float]:
        """Embed one query string via the client's ``aembed_query``."""
        vector = await self._client.aembed_query(query)
        return vector

    # ------------------------------------------------------------------
    # Single-text embeddings (sync / async)
    # ------------------------------------------------------------------
    def _get_text_embedding(self, text: str) -> list[float]:
        """Embed one text; this also goes through the client's ``embed_query``."""
        vector = self._client.embed_query(text)
        return vector

    async def _aget_text_embedding(self, text: str) -> list[float]:
        """Embed one text; this also goes through the client's ``aembed_query``."""
        vector = await self._client.aembed_query(text)
        return vector

    # ------------------------------------------------------------------
    # Batched text embeddings (sync / async)
    # ------------------------------------------------------------------
    def _get_text_embeddings(self, texts: list[str]) -> list[list[float]]:
        """Embed a list of texts via the client's ``embed_documents``."""
        vectors = self._client.embed_documents(texts)
        return vectors

    async def _aget_text_embeddings(self, texts: list[str]) -> list[list[float]]:
        """Embed a list of texts via the client's ``aembed_documents``."""
        vectors = await self._client.aembed_documents(texts)
        return vectors
class TestPerplexityLlamaIndexRegistration:
    """Checks how the ``perplexity_llama_index`` async context manager resolves its API key."""

    async def test_requires_api_key(self, monkeypatch, mock_builder):
        """No key in the config and none in the environment is expected to raise ``ValueError``."""
        # Make sure a key exported in the developer's shell cannot leak into the test.
        monkeypatch.delenv("PERPLEXITY_API_KEY", raising=False)
        keyless_config = PerplexityEmbedderModelConfig()

        with pytest.raises(ValueError, match="non-empty API key"):
            async with perplexity_llama_index(keyless_config, mock_builder):
                pass

    async def test_yields_llama_index_embedding(self, monkeypatch, mock_builder):
        """With an environment key, the yielded object offers the public embedding methods."""
        monkeypatch.setenv("PERPLEXITY_API_KEY", "env-key")
        default_config = PerplexityEmbedderModelConfig()

        async with perplexity_llama_index(default_config, mock_builder) as embedder:
            # LlamaIndex's ``BaseEmbedding`` exposes ``get_text_embedding``.
            for method_name in ("get_text_embedding", "get_query_embedding"):
                assert hasattr(embedder, method_name)


class TestPerplexityLlamaIndexAdapter:
    """Checks that the adapter hooks forward to the wrapped client and return its results."""

    def test_get_text_embedding_delegates_to_client(self):
        """Sync hooks return whatever the client's ``embed_query`` / ``embed_documents`` return."""
        from nat.plugins.langchain.perplexity_embeddings_client import PerplexityEmbeddings
        from nat.plugins.llama_index.perplexity_embeddings_client import PerplexityLlamaIndexEmbedding

        query_vector = [0.1, 0.2, 0.3]
        batch_vectors = [[0.1], [0.2]]

        fake_client = MagicMock(spec=PerplexityEmbeddings)
        # The adapter reads ``_model`` to pick its default ``model_name``.
        fake_client._model = "pplx-embed-v1-0.6b"
        fake_client.embed_query.return_value = query_vector
        fake_client.embed_documents.return_value = batch_vectors
        fake_client.aembed_query = AsyncMock(return_value=[0.5])
        fake_client.aembed_documents = AsyncMock(return_value=[[0.6]])

        embedding = PerplexityLlamaIndexEmbedding(client=fake_client)

        assert embedding._get_query_embedding("q") == [0.1, 0.2, 0.3]
        assert embedding._get_text_embedding("t") == [0.1, 0.2, 0.3]
        assert embedding._get_text_embeddings(["a", "b"]) == [[0.1], [0.2]]

    async def test_async_methods_delegate(self):
        """Async hooks return whatever the client's ``aembed_query`` / ``aembed_documents`` return."""
        from nat.plugins.langchain.perplexity_embeddings_client import PerplexityEmbeddings
        from nat.plugins.llama_index.perplexity_embeddings_client import PerplexityLlamaIndexEmbedding

        fake_client = MagicMock(spec=PerplexityEmbeddings)
        fake_client._model = "pplx-embed-v1-0.6b"
        fake_client.aembed_query = AsyncMock(return_value=[0.5])
        fake_client.aembed_documents = AsyncMock(return_value=[[0.6]])

        embedding = PerplexityLlamaIndexEmbedding(client=fake_client)

        single_query = await embedding._aget_query_embedding("q")
        single_text = await embedding._aget_text_embedding("t")
        text_batch = await embedding._aget_text_embeddings(["x"])

        assert single_query == [0.5]
        assert single_text == [0.5]
        assert text_batch == [[0.6]]