diff --git a/src/google/adk_community/plugins/__init__.py b/src/google/adk_community/plugins/__init__.py index ab61116..2e4b2ee 100644 --- a/src/google/adk_community/plugins/__init__.py +++ b/src/google/adk_community/plugins/__init__.py @@ -15,5 +15,23 @@ from google.adk_community.plugins.agent_governance_plugin import ( AgentGovernancePlugin, ) +from google.adk_community.plugins.taxonomy import ( + DefaultSkillPolicy, + SkillPolicy, + TaxonomyPipeline, + TaxonomyPlugin, + TaxonomyRegistry, + TaxonomyResolver, + TaxonomyTerm, +) -__all__ = ["AgentGovernancePlugin"] +__all__ = [ + "AgentGovernancePlugin", + "DefaultSkillPolicy", + "SkillPolicy", + "TaxonomyPipeline", + "TaxonomyPlugin", + "TaxonomyRegistry", + "TaxonomyResolver", + "TaxonomyTerm", +] diff --git a/src/google/adk_community/plugins/taxonomy/__init__.py b/src/google/adk_community/plugins/taxonomy/__init__.py new file mode 100644 index 0000000..780840c --- /dev/null +++ b/src/google/adk_community/plugins/taxonomy/__init__.py @@ -0,0 +1,33 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class TaxonomyResolver(ABC):
  """Abstract base class for taxonomy resolution.

  Resolvers analyze context and LLM history to determine which taxonomy
  classification domains (e.g. URI strings) are currently active and relevant.
  """

  @abstractmethod
  async def resolve_taxonomies(
      self, context: ReadonlyContext, llm_request: LlmRequest
  ) -> list[str]:
    """Resolves active taxonomy domain URIs from context and LLM history.

    Args:
      context: The current read-only execution context.
      llm_request: The upcoming LLM request holding prompt configurations.

    Returns:
      A list of resolved active taxonomy strings/URIs.
    """


class TaxonomyPipeline(TaxonomyResolver):
  """Executes a sequence of taxonomy resolvers in order (multi-step pipeline).

  This implements a composite/pipeline pattern to merge active taxonomy domains
  identified by multiple independent heuristics (e.g. lexical, model-based).
  """

  def __init__(self, resolvers: list[TaxonomyResolver]):
    """Stores the resolvers to run.

    Args:
      resolvers: Resolvers to execute, in the order they should be awaited.
    """
    self.resolvers = resolvers

  async def resolve_taxonomies(
      self, context: ReadonlyContext, llm_request: LlmRequest
  ) -> list[str]:
    """Runs every resolver and merges their results without duplicates.

    Args:
      context: The current read-only execution context.
      llm_request: The upcoming LLM request, forwarded to each resolver.

    Returns:
      The unique taxonomy domains in first-seen order, so the result is
      deterministic for a given resolver order. Resolvers that return an
      empty or falsy result contribute nothing.
    """
    # A dict keeps insertion order, giving stable de-duplication; a plain set
    # would make the returned order vary between runs.
    merged: dict[str, None] = {}
    for resolver in self.resolvers:
      domains = await resolver.resolve_taxonomies(context, llm_request)
      if domains:
        merged.update(dict.fromkeys(domains))
    return list(merged)


class SkillPolicy(ABC):
  """Abstract policy engine determining skill permissions and instruction shaping.

  This class defines the interface for two main responsibilities:
    1. Access Control (Authorization): Blocking or permitting skills based on
       active taxonomies.
    2. Cognitive Steering (Behavioral Shaping): Altering skill instructions,
       descriptions, prioritization, and global system prompts.

  ``is_skill_allowed`` and ``shape_instructions`` are abstract; the remaining
  hooks have pass-through defaults that subclasses may override.
  """

  @abstractmethod
  def is_skill_allowed(
      self,
      skill: Skill,
      context: ReadonlyContext,
      active_taxonomies: list[str],
  ) -> bool:
    """Determines if a skill can be loaded/used under the active taxonomies.

    Args:
      skill: The target Skill model instance.
      context: The read-only interaction context.
      active_taxonomies: The list of currently active taxonomy domains.

    Returns:
      True if the skill is permitted to run, False otherwise.
    """

  @abstractmethod
  def shape_instructions(
      self,
      skill: Skill,
      context: ReadonlyContext,
      original_instructions: str,
  ) -> str:
    """Applies dynamic instruction shaping/guardrails to a skill's instructions.

    Args:
      skill: The skill whose instructions are being shaped.
      context: The read-only interaction context.
      original_instructions: The instructions as produced by the skill tool.

    Returns:
      The instructions to hand to the agent.
    """

  def shape_description(
      self,
      skill: Skill,
      context: ReadonlyContext,
      original_description: str,
  ) -> str:
    """Applies dynamic description shaping before the skill is listed.

    The default implementation returns ``original_description`` unchanged.
    """
    return original_description

  def shape_system_instruction(
      self,
      context: ReadonlyContext,
      active_taxonomies: list[str],
      original_instructions: str,
  ) -> str:
    """Applies dynamic shaping to the global agent system instructions.

    The default implementation returns ``original_instructions`` unchanged.
    """
    return original_instructions

  def prioritize_skills(
      self,
      skills: list[Skill],
      context: ReadonlyContext,
      active_taxonomies: list[str],
  ) -> list[Skill]:
    """Prioritizes or reorders skills under the active taxonomy.

    The default implementation returns ``skills`` unchanged.
    """
    return skills


def _get_taxonomy_binds(skill: Skill) -> list[str]:
  """Extracts a skill's taxonomy binds as a list of strings.

  Lookup order:
    1. The ``taxonomy_binds`` attribute of ``skill.frontmatter``, if present
       and not None.
    2. The ``taxonomy-binds`` then ``taxonomy_binds`` keys of the frontmatter's
       ``model_extra`` mapping.

  Whichever source is used, the value is normalized: a missing or empty value
  becomes ``[]`` and a single string becomes a one-element list.

  Args:
    skill: Object exposing a ``frontmatter`` attribute.

  Returns:
    A new list of taxonomy binds; empty when the skill declares none.
  """
  frontmatter = skill.frontmatter
  binds = getattr(frontmatter, "taxonomy_binds", None)
  if binds is None:
    extra = getattr(frontmatter, "model_extra", None) or {}
    binds = extra.get("taxonomy-binds") or extra.get("taxonomy_binds")
  if not binds:
    return []
  # A bare string must stay one bind; iterating it would split it into
  # characters and let single-letter taxonomies match.
  if isinstance(binds, str):
    return [binds]
  return list(binds)


class DefaultSkillPolicy(SkillPolicy):
  """Default skill policy using taxonomy-bind set-intersection matching.

  A skill with no taxonomy binds is treated as unrestricted and allowed.
  A skill with binds is allowed only when at least one bind is among the
  active taxonomies. All shaping hooks are pass-throughs;
  ``shape_description``, ``shape_system_instruction`` and ``prioritize_skills``
  are inherited from ``SkillPolicy``.
  """

  def is_skill_allowed(
      self,
      skill: Skill,
      context: ReadonlyContext,
      active_taxonomies: list[str],
  ) -> bool:
    """Returns True when the skill is unrestricted or shares a taxonomy."""
    binds = _get_taxonomy_binds(skill)
    if not binds:
      return True
    return not set(binds).isdisjoint(active_taxonomies or ())

  def shape_instructions(
      self,
      skill: Skill,
      context: ReadonlyContext,
      original_instructions: str,
  ) -> str:
    """Returns the instructions unchanged."""
    return original_instructions
def _jsonld_literal(node: object) -> Optional[str]:
  """Returns the text of a JSON-LD literal, or None when there is none.

  Accepts a plain string, a value object (``{"@value": ...}``), or a list of
  either; for a list the first non-empty text wins.
  """
  if isinstance(node, str):
    return node or None
  if isinstance(node, dict):
    value = node.get("@value")
    return value if isinstance(value, str) and value else None
  if isinstance(node, list):
    for entry in node:
      text = _jsonld_literal(entry)
      if text:
        return text
  return None


def _jsonld_reference(node: object) -> Optional[str]:
  """Returns the identifier of a JSON-LD node reference, or None.

  Accepts a plain IRI string, a node object (``{"@id": ...}``), or a list of
  either; for a list the first resolvable identifier wins.
  """
  if isinstance(node, str):
    return node or None
  if isinstance(node, dict):
    return _jsonld_reference(node.get("@id"))
  if isinstance(node, list):
    for entry in node:
      ref = _jsonld_reference(entry)
      if ref:
        return ref
  return None


class TaxonomyTerm(BaseModel):
  """A single taxonomy term with metadata for validation and LLM disambiguation.

  Fields may be populated by field name or by their camelCase alias
  (``parentId``, ``altLabels``).

  Attributes:
    id: Unique identifier of the term.
    parent_id: Identifier of the parent term, or None.
    name: Label of the term.
    definition: Optional free-text definition.
    alt_labels: Alternative labels; empty by default.
  """

  model_config = ConfigDict(populate_by_name=True)

  id: str
  parent_id: Optional[str] = Field(None, alias="parentId")
  name: str
  definition: Optional[str] = None
  alt_labels: list[str] = Field(default_factory=list, alias="altLabels")


class TaxonomyRegistry(BaseModel):
  """Central registry for taxonomy term definitions, keyed by term id.

  Supported JSON Schemas:

  **Flat Key-Value JSON** (``from_flat_json``):
      id: str
      parentId: Optional[str]
      name: str
      definition: Optional[str]
      altLabels: Optional[list[str]]

  **JSON-LD with SKOS** (``from_json_ld``):
      @id: str
      prefLabel: str, ``{"@value": str, ...}``, or a list of those
      altLabel: same shapes as prefLabel; every non-empty entry is kept
      definition: same shapes as prefLabel
      broader: Optional str, ``{"@id": str}``, or a list of those
  """

  terms: dict[str, TaxonomyTerm] = Field(default_factory=dict)

  @classmethod
  def from_flat_json(cls, data: list[dict]) -> TaxonomyRegistry:
    """Parses taxonomy terms from flat key-value JSON format.

    Args:
      data: One mapping per term, validated as a ``TaxonomyTerm``.

    Returns:
      A registry indexed by term id; a later duplicate id replaces an
      earlier one.
    """
    terms = {}
    for item in data:
      term = TaxonomyTerm.model_validate(item)
      terms[term.id] = term
    return cls(terms=terms)

  @classmethod
  def from_json_ld(cls, data: list[dict]) -> TaxonomyRegistry:
    """Parses JSON-LD SKOS nodes into a TaxonomyRegistry.

    Nodes without an ``@id`` are skipped. A missing ``prefLabel`` yields an
    empty name. When ``broader`` lists several parents only the first is kept,
    because a term stores a single ``parent_id``.

    Args:
      data: The JSON-LD node objects.

    Returns:
      A registry indexed by ``@id``.
    """
    terms = {}
    for item in data:
      term_id = item.get("@id")
      if not term_id:
        continue

      alt_raw = item.get("altLabel", [])
      if not isinstance(alt_raw, list):
        alt_raw = [alt_raw]
      alt_labels = [
          text for text in (_jsonld_literal(entry) for entry in alt_raw) if text
      ]

      terms[term_id] = TaxonomyTerm(
          id=term_id,
          parent_id=_jsonld_reference(item.get("broader")),
          name=_jsonld_literal(item.get("prefLabel")) or "",
          definition=_jsonld_literal(item.get("definition")),
          alt_labels=alt_labels,
      )
    return cls(terms=terms)

  def get_term(self, term_id: str) -> Optional[TaxonomyTerm]:
    """Looks up a term by its ID; returns None when it is not registered."""
    return self.terms.get(term_id)

  def get_children(self, parent_id: str) -> list[TaxonomyTerm]:
    """Returns all terms whose ``parent_id`` equals ``parent_id``."""
    return [t for t in self.terms.values() if t.parent_id == parent_id]

  def list_ids(self) -> list[str]:
    """Lists all term IDs in the registry."""
    return list(self.terms)
logger = logging.getLogger("google_adk_community." + __name__)

# Session-state key under which the resolved taxonomies are stored.
_ACTIVE_TAXONOMIES_STATE_KEY = "_active_taxonomies"

# Names of the skill tools this plugin intercepts.
_SKILL_GATE_TOOLS = frozenset({
    "list_skills",
    "load_skill",
    "load_skill_resource",
    "run_skill_script",
})


class TaxonomyPlugin(BasePlugin):
  """Native ADK Plugin enforcing pluggable taxonomy policies."""

  def __init__(
      self,
      name: str = "taxonomy_plugin",
      *,
      taxonomy_registry: Optional[TaxonomyRegistry] = None,
      resolver: Optional[TaxonomyResolver] = None,
      policy: Optional[SkillPolicy] = None,
  ):
    """Initializes the plugin.

    Args:
      name: Plugin name passed to ``BasePlugin``.
      taxonomy_registry: Term registry; an empty one is created when omitted.
      resolver: Resolver producing the active taxonomies. Without it the
        model callback is a no-op.
      policy: Policy used for authorization and shaping. Policy enforcement
        in the tool callbacks requires both ``policy`` and ``resolver``.
    """
    super().__init__(name)
    self.taxonomy_registry = taxonomy_registry or TaxonomyRegistry()
    self.resolver = resolver
    self.policy = policy

  async def before_model_callback(
      self, *, callback_context: CallbackContext, llm_request: LlmRequest
  ) -> Optional[LlmResponse]:
    """Resolves active taxonomies, stores them in state, shapes the prompt.

    Always returns None, so the model call itself is never replaced.
    """
    if not self.resolver:
      return None

    active_taxonomies = await self.resolver.resolve_taxonomies(
        callback_context, llm_request
    )
    callback_context.state[_ACTIVE_TAXONOMIES_STATE_KEY] = active_taxonomies

    logger.debug(
        "[%s] Resolved active taxonomies: %s", self.name, active_taxonomies
    )

    if not self.policy:
      return None

    config = llm_request.config
    current = config.system_instruction
    # The shaping hook is typed for plain strings; leave any other
    # representation untouched rather than handing the policy an object it
    # cannot concatenate onto.
    if current is not None and not isinstance(current, str):
      logger.warning(
          "[%s] system_instruction is not a plain string (%s); skipping"
          " taxonomy system prompt shaping.",
          self.name,
          type(current).__name__,
      )
      return None

    orig_instructions = current or ""
    shaped_instructions = self.policy.shape_system_instruction(
        callback_context, active_taxonomies, orig_instructions
    )
    if shaped_instructions != orig_instructions:
      logger.debug(
          "[%s] Active taxonomy dynamic system prompt shaping applied.",
          self.name,
      )
      config.system_instruction = shaped_instructions

    return None

  async def before_tool_callback(
      self,
      *,
      tool: BaseTool,
      tool_args: dict[str, Any],
      tool_context: ToolContext,
  ) -> Optional[dict]:
    """Intercepts skill tools to enforce taxonomy policy and path validation.

    Returns:
      None to let the tool run, or a dict that replaces the tool result: the
      filtered listing for ``list_skills``, or an ``error``/``error_code``
      payload when arguments are invalid or the policy denies the skill.
    """
    if tool.name not in _SKILL_GATE_TOOLS:
      return None

    active_taxonomies = (
        tool_context.state.get(_ACTIVE_TAXONOMIES_STATE_KEY) or []
    )

    if tool.name == "list_skills":
      return self._filter_list_skills(tool, tool_context, active_taxonomies)

    skill_name = tool_args.get("skill_name")
    if not skill_name:
      return None

    # Inline path validation (avoids importing private _validate_path_segment).
    # Non-string values are rejected here instead of raising on the substring
    # checks below.
    if (
        not isinstance(skill_name, str)
        or "\x00" in skill_name
        or "/" in skill_name
        or "\\" in skill_name
        or skill_name in (".", "..")
    ):
      return {
          "error": f"Invalid skill_name parameter: {skill_name!r}",
          "error_code": "INVALID_ARGUMENTS",
      }

    file_path = tool_args.get("file_path")
    if file_path:
      if (
          not isinstance(file_path, str)
          or ".." in file_path
          or file_path.startswith(("/", "\\"))
      ):
        return {
            "error": f"Path traversal attempt blocked: {file_path}",
            "error_code": "INVALID_ARGUMENTS",
        }

    if self.policy and self.resolver:
      toolset = getattr(tool, "_toolset", None)
      if toolset:
        skill = await toolset._get_or_fetch_skill(
            skill_name, tool_context.invocation_id
        )
        if skill and not self.policy.is_skill_allowed(
            skill, tool_context, active_taxonomies
        ):
          logger.warning(
              "[%s] Skill '%s' blocked by policy. Active taxonomies: %s",
              self.name,
              skill_name,
              active_taxonomies,
          )
          return {
              "error": (
                  f"Access to skill '{skill_name}' is not permitted"
                  " under active policy constraints."
              ),
              "error_code": "SKILL_NOT_PERMITTED",
          }

    return None

  def _filter_list_skills(
      self, tool: BaseTool, tool_context: ToolContext, active_taxonomies: list[str]
  ) -> Optional[dict]:
    """Builds the list_skills result from policy-permitted skills only.

    Returns None (tool runs normally) when no policy/resolver is configured
    or the tool exposes no ``_toolset``.
    """
    if not self.policy or not self.resolver:
      return None

    toolset = getattr(tool, "_toolset", None)
    if not toolset:
      return None

    all_skills = toolset._list_skills()
    allowed_skills = [
        skill
        for skill in all_skills
        if self.policy.is_skill_allowed(skill, tool_context, active_taxonomies)
    ]

    # Reorder and prioritize skills dynamically
    prioritized_skills = self.policy.prioritize_skills(
        allowed_skills, tool_context, active_taxonomies
    )

    shaped_skills = []
    for skill in prioritized_skills:
      original_desc = skill.frontmatter.description or ""
      shaped_desc = self.policy.shape_description(
          skill, tool_context, original_desc
      )
      if shaped_desc != original_desc:
        # Copy instead of rebuilding so every other frontmatter/skill field
        # is carried over, and the toolset's own objects are never mutated.
        skill = skill.model_copy(
            update={
                "frontmatter": skill.frontmatter.model_copy(
                    update={"description": shaped_desc}
                )
            }
        )
      shaped_skills.append(skill)

    logger.debug(
        "[%s] Filtered skills: %d/%d visible",
        self.name,
        len(shaped_skills),
        len(all_skills),
    )
    return {"result": prompt.format_skills_as_xml(shaped_skills)}

  async def after_tool_callback(
      self,
      *,
      tool: BaseTool,
      tool_args: dict[str, Any],
      tool_context: ToolContext,
      result: dict,
  ) -> Optional[dict]:
    """Applies dynamic instruction shaping to load_skill results.

    Returns:
      A copy of ``result`` with shaped ``instructions`` when the policy
      changed them; otherwise None so the original result is kept.
    """
    if tool.name != "load_skill":
      return None
    if not self.policy or not self.resolver:
      return None
    if not isinstance(result, dict) or "instructions" not in result:
      return None

    skill_name = tool_args.get("skill_name")
    if not skill_name:
      return None

    toolset = getattr(tool, "_toolset", None)
    if not toolset:
      return None

    skill = await toolset._get_or_fetch_skill(
        skill_name, tool_context.invocation_id
    )
    if not skill:
      return None

    shaped_instructions = self.policy.shape_instructions(
        skill, tool_context, result["instructions"]
    )
    # NOTE(review): returning a value here is assumed to override the tool
    # result for the rest of the callback chain, so nothing is returned when
    # the policy left the instructions as they were - confirm against the
    # ADK plugin manager.
    if shaped_instructions == result["instructions"]:
      return None

    logger.debug(
        "[%s] Shaped instructions for skill '%s'",
        self.name,
        skill_name,
    )
    shaped_result = dict(result)
    shaped_result["instructions"] = shaped_instructions
    return shaped_result
import logging
import sys
from types import ModuleType


def _make_avatar_config_stub():
  """Builds an empty pydantic model named ``AvatarConfig``."""
  from pydantic import BaseModel

  class AvatarConfig(BaseModel):
    pass

  return AvatarConfig


def _ensure_avatar_config():
  """Makes sure ``google.genai.types.AvatarConfig`` exists (best effort).

  * If ``google.genai.types`` imports and already has ``AvatarConfig``,
    nothing is changed.
  * If it imports but lacks the attribute, a stub class is attached to it.
  * If the import itself fails, stub ``google.genai`` and
    ``google.genai.types`` modules carrying the stub class are registered in
    ``sys.modules``.

  A failure while building or attaching the stub is logged and otherwise
  ignored; it never replaces a ``google.genai`` that imported correctly.
  """
  try:
    import google.genai.types as genai_types
  except Exception:  # any import-time failure selects the stub-module path
    genai_types = None

  if genai_types is not None and hasattr(genai_types, "AvatarConfig"):
    return

  try:
    stub_cls = _make_avatar_config_stub()
    if genai_types is None:
      genai_pkg = ModuleType("google.genai")
      genai_types = ModuleType("google.genai.types")
      # Expose the submodule as an attribute too, so attribute access on the
      # stub package resolves the same object as the sys.modules entry.
      genai_pkg.types = genai_types
      sys.modules["google.genai"] = genai_pkg
      sys.modules["google.genai.types"] = genai_types
    genai_types.AvatarConfig = stub_cls
  except Exception:
    logging.getLogger(__name__).warning(
        "Could not install the AvatarConfig test shim", exc_info=True
    )


# Pre-emptively mock/patch google.genai.types.AvatarConfig if it's missing or
# fails to import.
_ensure_avatar_config()
def test_taxonomy_term():
  """Tests TaxonomyTerm model instantiation and defaults.

  Ensures taxonomy term instances hold core metadata and instantiate with
  standard defaults (empty alternate labels and no parent).
  """
  term = TaxonomyTerm(id="tech", name="Technology", definition="Tech domain")
  assert term.id == "tech"
  assert term.name == "Technology"
  assert term.definition == "Tech domain"
  assert term.parent_id is None
  assert term.alt_labels == []


def test_registry_flat_json():
  """Tests parsing flat JSON structure into TaxonomyRegistry.

  Verifies that a plain list of objects defining IDs and parent IDs is loaded
  and that parent-child lookups work.
  """
  data = [
      {
          "id": "eng",
          "parentId": None,
          "name": "Engineering",
          "definition": "Eng dept",
      },
      {
          "id": "ml",
          "parentId": "eng",
          "name": "Machine Learning",
          "definition": "ML team",
      },
  ]
  registry = TaxonomyRegistry.from_flat_json(data)
  assert len(registry.list_ids()) == 2
  assert "eng" in registry.list_ids()
  assert "ml" in registry.list_ids()

  term_eng = registry.get_term("eng")
  term_ml = registry.get_term("ml")
  assert term_eng.name == "Engineering"
  assert term_ml.parent_id == "eng"

  children = registry.get_children("eng")
  assert len(children) == 1
  assert children[0].id == "ml"


def test_registry_json_ld():
  """Tests parsing JSON-LD SKOS structure into TaxonomyRegistry.

  Validates ``@id`` mapping, prefLabel mapping, altLabel array conversion and
  broader relation parsing.
  """
  data = [
      {
          "@context": "http://w3.org",
          "@type": "Concept",
          "@id": "https://example.com/eng",
          "prefLabel": {"@value": "Engineering", "@language": "en"},
          "definition": {"@value": "Eng dept", "@language": "en"},
      },
      {
          "@context": "http://w3.org",
          "@type": "Concept",
          "@id": "https://example.com/ml",
          "prefLabel": {"@value": "Machine Learning", "@language": "en"},
          "altLabel": [{"@value": "ML", "@language": "en"}],
          "definition": {"@value": "ML team", "@language": "en"},
          "broader": "https://example.com/eng",
      },
  ]
  registry = TaxonomyRegistry.from_json_ld(data)
  assert len(registry.list_ids()) == 2

  term_eng = registry.get_term("https://example.com/eng")
  term_ml = registry.get_term("https://example.com/ml")
  assert term_eng.name == "Engineering"
  assert term_ml.parent_id == "https://example.com/eng"
  assert term_ml.alt_labels == ["ML"]


@pytest.mark.asyncio
async def test_taxonomy_pipeline():
  """Tests pipeline resolution chaining multiple resolvers.

  Ensures the composite pipeline runs each resolver and merges their outputs
  into one de-duplicated list.
  """

  class SimpleResolver(TaxonomyResolver):

    def __init__(self, resolved_domains: list[str]):
      self.resolved_domains = resolved_domains

    async def resolve_taxonomies(self, context, llm_request) -> list[str]:
      return self.resolved_domains

  context = mock.MagicMock()
  llm_request = mock.MagicMock()

  pipeline = TaxonomyPipeline(
      [SimpleResolver(["eng"]), SimpleResolver(["finance"])]
  )
  resolved = await pipeline.resolve_taxonomies(context, llm_request)
  assert sorted(resolved) == ["eng", "finance"]


def test_default_skill_policy():
  """Tests DefaultSkillPolicy filter mechanism.

  Checks that the intersection policy authorizes matching skills, blocks
  skills with non-overlapping binds, and always allows unrestricted skills.
  """
  policy = DefaultSkillPolicy()

  skill_eng = Skill(
      frontmatter=Frontmatter(
          name="eng-skill",
          description="Desc",
          taxonomy_binds=["eng"],
      ),
      instructions="body",
  )
  skill_finance = Skill(
      frontmatter=Frontmatter(
          name="finance-skill",
          description="Desc",
          taxonomy_binds=["finance"],
      ),
      instructions="body",
  )

  context = mock.MagicMock()
  assert policy.is_skill_allowed(skill_eng, context, ["eng"]) is True
  assert policy.is_skill_allowed(skill_finance, context, ["eng"]) is False
  assert (
      policy.is_skill_allowed(skill_finance, context, ["eng", "finance"])
      is True
  )

  skill_unrestricted = Skill(
      frontmatter=Frontmatter(name="any-skill", description="Desc"),
      instructions="body",
  )
  assert (
      policy.is_skill_allowed(skill_unrestricted, context, ["marketing"])
      is True
  )

  assert policy.shape_instructions(skill_eng, context, "original") == "original"


@pytest.mark.asyncio
async def test_taxonomy_plugin_list_skills():
  """Tests TaxonomyPlugin intercepts and filters skill lists correctly.

  Verifies that list_skills calls are intercepted in before_tool_callback,
  that only policy-allowed skills reach the XML formatter, and that the
  formatter output is returned as the ``result`` payload.
  """

  class RestrictedPolicy(SkillPolicy):

    def is_skill_allowed(
        self, skill: Skill, context, active_taxonomies: list[str]
    ) -> bool:
      binds = _get_taxonomy_binds(skill)
      return "eng" in binds

    def shape_instructions(
        self, skill: Skill, context, original_instructions: str
    ) -> str:
      return original_instructions

  mock_resolver = mock.MagicMock()
  plugin = TaxonomyPlugin(policy=RestrictedPolicy(), resolver=mock_resolver)

  skills = {
      "skill-1": Skill(
          frontmatter=Frontmatter(
              name="skill-1",
              description="Desc",
              taxonomy_binds=["eng"],
          ),
          instructions="body",
      ),
      "skill-2": Skill(
          frontmatter=Frontmatter(
              name="skill-2",
              description="Desc",
              taxonomy_binds=["finance"],
          ),
          instructions="body",
      ),
  }

  context = mock.MagicMock()
  context.state = {"_active_taxonomies": ["eng"]}

  mock_tool = mock.MagicMock()
  mock_tool.name = "list_skills"
  mock_tool._toolset._list_skills.return_value = list(skills.values())

  # The formatter is patched, so its return value says nothing about which
  # skills survived filtering; inspect the skills it was called with instead.
  with mock.patch(
      "google.adk_community.plugins.taxonomy.taxonomy_plugin.prompt.format_skills_as_xml"
  ) as mock_format:
    mock_format.return_value = "formatted-skills"

    result = await plugin.before_tool_callback(
        tool=mock_tool,
        tool_args={},
        tool_context=context,
    )

  assert result == {"result": "formatted-skills"}
  mock_format.assert_called_once()
  passed_names = [s.frontmatter.name for s in mock_format.call_args[0][0]]
  assert passed_names == ["skill-1"]


@pytest.mark.asyncio
async def test_taxonomy_steering_capabilities():
  """Tests prioritizing/sorting skills and injecting global system prompts.

  Verifies cognitive steering hooks:
    1. System Instruction Shaping (injecting instructions into the system
       prompt).
    2. Skill Prioritization (reordering skills in list_skills results).
  """

  class SteeringPolicy(SkillPolicy):

    def is_skill_allowed(
        self, skill: Skill, context, active_taxonomies: list[str]
    ) -> bool:
      return True

    def shape_instructions(
        self, skill: Skill, context, original_instructions: str
    ) -> str:
      return original_instructions

    def shape_system_instruction(
        self, context, active_taxonomies: list[str], original_instructions: str
    ) -> str:
      if "strict" in active_taxonomies:
        return original_instructions + " - MANDATED COMPLIANCE TURN"
      return original_instructions

    def prioritize_skills(
        self, skills: list[Skill], context, active_taxonomies: list[str]
    ) -> list[Skill]:
      if "strict" in active_taxonomies:
        return sorted(
            skills, key=lambda s: 0 if s.frontmatter.name == "important" else 1
        )
      return skills

  class MockResolver(TaxonomyResolver):

    async def resolve_taxonomies(self, context, llm_request) -> list[str]:
      return ["strict"]

  plugin = TaxonomyPlugin(policy=SteeringPolicy(), resolver=MockResolver())

  # 1. Verify before_model_callback system instruction injection
  context = mock.MagicMock()
  context.state = {}
  llm_request = mock.MagicMock()
  llm_request.config = mock.MagicMock()
  llm_request.config.system_instruction = "Original Prompt"

  await plugin.before_model_callback(
      callback_context=context, llm_request=llm_request
  )
  assert context.state["_active_taxonomies"] == ["strict"]
  assert (
      llm_request.config.system_instruction
      == "Original Prompt - MANDATED COMPLIANCE TURN"
  )

  # 2. Verify skill prioritization/sorting in list_skills
  skills = [
      Skill(
          frontmatter=Frontmatter(name="normal", description="Desc"),
          instructions="body",
      ),
      Skill(
          frontmatter=Frontmatter(name="important", description="Desc"),
          instructions="body",
      ),
  ]

  mock_tool = mock.MagicMock()
  mock_tool.name = "list_skills"
  mock_tool._toolset._list_skills.return_value = skills

  with mock.patch(
      "google.adk_community.plugins.taxonomy.taxonomy_plugin.prompt.format_skills_as_xml"
  ) as mock_format:
    await plugin.before_tool_callback(
        tool=mock_tool,
        tool_args={},
        tool_context=context,
    )
    # Check that format_skills_as_xml was called with "important" sorted first
    called_skills = mock_format.call_args[0][0]
    assert called_skills[0].frontmatter.name == "important"
    assert called_skills[1].frontmatter.name == "normal"