refactor(llm): framework-owned provider registry#1773
Conversation
Greptile SummaryThis PR completes stack-7 of the LangChain decoupling effort by adding
|
| Filename | Overview |
|---|---|
| nemoguardrails/types.py | Added register_provider and get_provider_names to the LLMFramework @runtime_checkable Protocol; tests in test_types.py correctly verify both satisfaction and failure paths. |
| nemoguardrails/llm/providers/init.py | Public API rewired to delegate to the active framework; deprecated wrappers retain hasattr guards and emit DeprecationWarning; get_chat_provider_names and get_llm_provider_names now return distinct lists via LangChain-specific passthrough. |
| nemoguardrails/integrations/langchain/llm_adapter.py | New LangChainFramework class added; register_provider routes unconditionally to _register_chat (chat-first design); register_llm_provider/get_chat_provider_names/get_llm_provider_names kept as backward-compat passthrough to internal provider registries. |
| nemoguardrails/cli/providers.py | Wraps get_llm_provider_names calls in warnings.catch_warnings to suppress the new deprecation warning; CLI continues to render chat and text-completion lists as separate sections. |
| tests/llm/test_frameworks.py | FakeFramework only implements create_model, leaving it protocol-incomplete after register_provider and get_provider_names became required; existing tests pass because no provider operations are called on it, but it misrepresents the new protocol contract. |
| tests/test_types.py | Added TestLLMFrameworkProtocol that correctly asserts a full three-method MockFramework satisfies the protocol and that an empty class fails the isinstance check. |
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A["Public API\nnemoguardrails.llm.providers"] -->|register_provider / get_provider_names| B["_active_framework()\n→ LangChainFramework"]
A -->|register_chat_provider| B
A -->|register_llm_provider\n⚠️ DeprecationWarning| C{hasattr register_llm_provider?}
A -->|get_chat_provider_names| D{hasattr get_chat_provider_names?}
A -->|get_llm_provider_names\n⚠️ DeprecationWarning| E{hasattr get_llm_provider_names?}
B -->|register_provider| F["_register_chat()\n→ _chat_providers dict"]
B -->|register_llm_provider| G["_register_llm()\n→ _llm_providers dict\n(validates _acall)"]
B -->|get_provider_names| H["sorted union\nchat ∪ llm names"]
B -->|get_chat_provider_names| I["_chat_providers names\n+ partner providers"]
B -->|get_llm_provider_names| J["_llm_providers names"]
C -->|yes| G
C -->|no| F
D -->|yes| I
D -->|no| H
E -->|yes| J
E -->|no| H
Prompt To Fix All With AI
This is a comment left during a code review.
Path: tests/llm/test_frameworks.py
Line: 46-49
Comment:
**`FakeFramework` doesn't satisfy the updated `LLMFramework` protocol**
`FakeFramework` only implements `create_model`, but `LLMFramework` now requires `register_provider` and `get_provider_names` as well (added in `types.py` in this PR). `isinstance(FakeFramework(), LLMFramework)` returns `False`. The existing `TestRegistry` tests still pass because none of them call provider operations on `FakeFramework`, but `test_set_and_get_default_framework` sets it as the active default framework — any provider call inside that window would raise `AttributeError`. `TestLLMFrameworkProtocol` in `test_types.py` already verifies the protocol shape; keeping `FakeFramework` complete here avoids silent failures for future test additions.
```suggestion
class FakeFramework:
def create_model(self, model_name, provider_name, model_kwargs=None):
return MagicMock(spec=LLMModel)
def register_provider(self, name, provider_cls):
pass
def get_provider_names(self):
return []
```
How can I resolve this? If you propose a fix, please make it concise.Reviews (10): Last reviewed commit: "fix(llm): keep separate chat/llm provide..." | Re-trigger Greptile
Codecov Report❌ Patch coverage is
📢 Thoughts on this report? Let us know! |
aca95ea to
db03710
Compare
db03710 to
6cc77c5
Compare
a7f2246 to
a43a857
Compare
bfb40c5 to
d441096
Compare
a43a857 to
2140eed
Compare
3224c49 to
cfe3d8d
Compare
d048b6f to
5d6d8e9
Compare
e55fbcd to
b781f72
Compare
Add register_provider/get_provider_names to LLMFramework protocol. Public API (llm/providers/__init__.py) delegates to the active framework instead of importing from LangChain internals directly. LangChainFramework implements the new methods, routing to its internal chat/llm provider registries. Backwards-compat aliases register_chat_provider/register_llm_provider still work.
…ings LangChainFramework exposes get_chat_provider_names() and get_llm_provider_names() separately so CLI and existing tests work. register_llm_provider and get_llm_provider_names emit DeprecationWarning (removal in 0.23.0). register_provider always targets chat registry. Add 7 provider registration tests with cleanup fixture to prevent polluting the global LangChain provider dicts.
cfe3d8d to
af30ae5
Compare
📝 WalkthroughWalkthroughThis pull request introduces a framework-driven provider management system. The Changes
Sequence Diagram(s)sequenceDiagram
actor User
participant CLI
participant Provider API
participant LLMFramework
participant LangChainFramework as LangChain<br/>Framework
participant ProviderRegistry
User->>CLI: List providers (text completion)
CLI->>Provider API: get_llm_provider_names()
note over Provider API: Emit DeprecationWarning
Provider API->>LLMFramework: get_framework()
LLMFramework->>LangChainFramework: get_llm_provider_names()
LangChainFramework->>ProviderRegistry: Query LLM providers
ProviderRegistry-->>LangChainFramework: Provider names list
LangChainFramework-->>Provider API: Names
Provider API-->>CLI: Names (warning suppressed)
CLI-->>User: Display provider list
User->>CLI: Register new provider
CLI->>Provider API: register_provider(name, cls)
Provider API->>LLMFramework: get_framework()
LLMFramework->>LangChainFramework: register_provider(name, cls)
LangChainFramework->>ProviderRegistry: Store provider
ProviderRegistry-->>LangChainFramework: Stored
LangChainFramework-->>Provider API: Success
Provider API-->>CLI: Confirmed
CLI-->>User: Provider registered
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~25 minutes 🚥 Pre-merge checks | ✅ 2 | ❌ 2❌ Failed checks (2 warnings)
✅ Passed checks (2 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches📝 Generate docstrings
🧪 Generate unit tests (beta)
Comment |
There was a problem hiding this comment.
🧹 Nitpick comments (3)
nemoguardrails/llm/providers/__init__.py (1)
22-23: Add a concrete return type to_active_frameworkfor stronger static checks.A return annotation improves readability/type-safety for all delegated calls in this module.
♻️ Suggested typing tweak
from typing import Any, List @@ +from nemoguardrails.types import LLMFramework @@ -def _active_framework(): +def _active_framework() -> LLMFramework: return get_framework(get_default_framework())🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@nemoguardrails/llm/providers/__init__.py` around lines 22 - 23, The helper function _active_framework currently has no return annotation; add a concrete return type to improve static checking by annotating _active_framework with the framework interface/type returned by get_framework (e.g., import and use the exported Framework/FrameworkProtocol type from the module that defines get_framework) and keep the implementation as return get_framework(get_default_framework()) so callers of _active_framework have a precise type.tests/llm/test_frameworks.py (1)
129-145: Consider a local helper/context manager for deprecation-warning suppression in tests.You repeat the same warning-suppression pattern several times; consolidating it will reduce noise in this test class.
♻️ Suggested test helper pattern
+from contextlib import contextmanager + +@contextmanager +def _ignore_deprecated_llm_provider_warning(): + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + yield @@ - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) + with _ignore_deprecated_llm_provider_warning(): register_llm_provider("test_llm", FakeLLMProvider) - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) + with _ignore_deprecated_llm_provider_warning(): assert "test_llm" in get_llm_provider_names()Also applies to: 154-156
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/llm/test_frameworks.py` around lines 129 - 145, Create a small local context manager (e.g., suppress_deprecation_warnings) in tests/llm/test_frameworks.py that uses warnings.catch_warnings() and warnings.simplefilter("ignore", DeprecationWarning), then replace each repeated with warnings.catch_warnings(): ... block around calls to register_llm_provider, register_chat_provider, get_llm_provider_names, and get_chat_provider_names with a single with suppress_deprecation_warnings(): wrapper; apply the same replacement for the other occurrences mentioned (the later get_* calls).nemoguardrails/cli/providers.py (1)
35-42: Extract deprecated LLM-provider lookup into one helper.The behavior is correct, but the warning-suppression block is duplicated. A tiny helper will keep this consistent and easier to update.
♻️ Suggested refactor
+def _get_llm_provider_names_silenced() -> List[str]: + with warnings.catch_warnings(): + warnings.simplefilter("ignore", DeprecationWarning) + return sorted(get_llm_provider_names()) + def _list_providers() -> None: @@ - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - console.print("\n[bold]Text Completion Providers:[/]") - for provider in sorted(get_llm_provider_names()): - console.print(f" • {provider}") + console.print("\n[bold]Text Completion Providers:[/]") + for provider in _get_llm_provider_names_silenced(): + console.print(f" • {provider}") @@ def _get_provider_completions( provider_type: Optional[ProviderType] = None, ) -> List[str]: @@ if provider_type == "text completion": - # See comment in _list_providers for why we suppress this warning. - with warnings.catch_warnings(): - warnings.simplefilter("ignore", DeprecationWarning) - return sorted(get_llm_provider_names()) + return _get_llm_provider_names_silenced()Also applies to: 54-57
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@nemoguardrails/cli/providers.py` around lines 35 - 42, Extract the duplicated warning-suppression + printing logic into a small helper (e.g., _print_providers_suppressing_deprecation) that accepts a heading string and a provider-getter function; inside the helper use warnings.catch_warnings() with warnings.simplefilter("ignore", DeprecationWarning) and print the heading then iterate sorted(provider_getter()) to print each provider. Replace the two duplicated blocks that call get_llm_provider_names() and get_embedding_provider_names() with calls to this helper (keep console.print formatting identical).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Nitpick comments:
In `@nemoguardrails/cli/providers.py`:
- Around line 35-42: Extract the duplicated warning-suppression + printing logic
into a small helper (e.g., _print_providers_suppressing_deprecation) that
accepts a heading string and a provider-getter function; inside the helper use
warnings.catch_warnings() with warnings.simplefilter("ignore",
DeprecationWarning) and print the heading then iterate sorted(provider_getter())
to print each provider. Replace the two duplicated blocks that call
get_llm_provider_names() and get_embedding_provider_names() with calls to this
helper (keep console.print formatting identical).
In `@nemoguardrails/llm/providers/__init__.py`:
- Around line 22-23: The helper function _active_framework currently has no
return annotation; add a concrete return type to improve static checking by
annotating _active_framework with the framework interface/type returned by
get_framework (e.g., import and use the exported Framework/FrameworkProtocol
type from the module that defines get_framework) and keep the implementation as
return get_framework(get_default_framework()) so callers of _active_framework
have a precise type.
In `@tests/llm/test_frameworks.py`:
- Around line 129-145: Create a small local context manager (e.g.,
suppress_deprecation_warnings) in tests/llm/test_frameworks.py that uses
warnings.catch_warnings() and warnings.simplefilter("ignore",
DeprecationWarning), then replace each repeated with warnings.catch_warnings():
... block around calls to register_llm_provider, register_chat_provider,
get_llm_provider_names, and get_chat_provider_names with a single with
suppress_deprecation_warnings(): wrapper; apply the same replacement for the
other occurrences mentioned (the later get_* calls).
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro Plus
Run ID: 0ff883f7-0fb0-4212-ac5f-06dd4e8f2b91
📒 Files selected for processing (6)
nemoguardrails/cli/providers.pynemoguardrails/integrations/langchain/llm_adapter.pynemoguardrails/llm/providers/__init__.pynemoguardrails/types.pytests/llm/test_frameworks.pytests/test_types.py
Part of the LangChain decoupling stack:
Description
Add register_provider/get_provider_names to LLMFramework protocol.
our public API delegates to the active framework instead of importing from LangChain internals directly. LangChainFramework implements the new methods, routing to its internal chat/llm provider registries. Backward compat aliases register_chat_provider/register_llm_provider still work.
Summary by CodeRabbit
New Features
Deprecations