diff --git a/ddtrace/appsec/ai_guard/_api_client.py b/ddtrace/appsec/ai_guard/_api_client.py index 2c286dcf0a7..47b103aaaa4 100644 --- a/ddtrace/appsec/ai_guard/_api_client.py +++ b/ddtrace/appsec/ai_guard/_api_client.py @@ -7,6 +7,7 @@ from typing import Optional # noqa:F401 from typing import TypedDict from typing import Union +from typing import cast from ddtrace import config from ddtrace.appsec._constants import AI_GUARD @@ -95,11 +96,19 @@ def __init__(self, message: Optional[str], status: int = 0, errors: Optional[lis class AIGuardAbortError(Exception): """Exception to abort current execution due to security policy.""" - def __init__(self, action: str, reason: str, tags: Optional[list[str]] = None, sds: Optional[list] = None): + def __init__( + self, + action: str, + reason: str, + tags: Optional[list[str]] = None, + sds: Optional[list] = None, + tag_probs: Optional[dict[str, float]] = None, + ): self.action = action self.reason = reason self.tags = tags self.sds = sds or [] + self.tag_probs = tag_probs super().__init__(f"AIGuardAbortError(action='{action}', reason='{reason}', tags='{tags}')") @@ -263,6 +272,7 @@ def evaluate(self, messages: list[Message], options: Optional[Options] = None) - tags = attributes.get("tags", []) sds_findings = attributes.get("sds_findings") or [] blocking_enabled = attributes.get("is_blocking_enabled", False) + tag_probs = attributes.get("tag_probs") except Exception as e: value = json.dumps(result, indent=2)[:500] raise AIGuardClientError( @@ -283,6 +293,8 @@ def evaluate(self, messages: list[Message], options: Optional[Options] = None) - span.set_tag(AI_GUARD.REASON_TAG, reason) if sds_findings: meta_struct.update({"sds": sds_findings}) + if tag_probs is not None: + meta_struct.update({"tag_probs": tag_probs}) else: raise AIGuardClientError( message=f"AI Guard service call failed, status: {response.status}", @@ -303,9 +315,18 @@ def evaluate(self, messages: list[Message], options: Optional[Options] = None) - _aiguard_manual_keep(root_span) if should_block: span.set_tag(AI_GUARD.BLOCKED_TAG, "true") - raise AIGuardAbortError(action=action, reason=reason, tags=tags, sds=sds_findings) + raise AIGuardAbortError( + action=action, + reason=reason, + tags=tags, + sds=sds_findings, + tag_probs=tag_probs, + ) - return Evaluation(action=action, reason=reason, tags=tags, sds=sds_findings) + evaluation = {"action": action, "reason": reason, "tags": tags, "sds": sds_findings} + if tag_probs is not None: + evaluation["tag_probs"] = tag_probs + return cast(Evaluation, evaluation) except AIGuardAbortError: raise diff --git a/releasenotes/notes/ai-guard-tag-probs-response-70b0145bfe6db459.yaml b/releasenotes/notes/ai-guard-tag-probs-response-70b0145bfe6db459.yaml new file mode 100644 index 00000000000..d30da1583a6 --- /dev/null +++ b/releasenotes/notes/ai-guard-tag-probs-response-70b0145bfe6db459.yaml @@ -0,0 +1,5 @@ +--- +features: + - | + AI Guard response objects now include a dict field `tag_probs` with the probabilities for each tag. + diff --git a/tests/appsec/ai_guard/api/test_api_client.py b/tests/appsec/ai_guard/api/test_api_client.py index 7b35da080e7..a6efed9bb91 100644 --- a/tests/appsec/ai_guard/api/test_api_client.py +++ b/tests/appsec/ai_guard/api/test_api_client.py @@ -442,6 +442,40 @@ def test_evaluate_sds_findings_in_abort_error(mock_execute_request, telemetry_mo assert exc_info.value.sds == sds_findings +@patch("ddtrace.appsec.ai_guard._api_client.AIGuardClient._execute_request") +def test_evaluate_tag_probabilities_in_result_and_meta_struct(mock_execute_request, ai_guard_client, test_spans): + """Test that tag probabilities are added to the SDK response and span meta_struct.""" + tag_probs = {"jailbreak": 0.91, "prompt_injection": 0.42} + mock_execute_request.return_value = mock_evaluate_response( + "DENY", reason="Nope", tags=["jailbreak"], block=False, tag_probs=tag_probs + ) + + result = ai_guard_client.evaluate(PROMPT, Options(block=False)) + + assert result["tag_probs"] == tag_probs + + expected_meta_struct = {"messages": PROMPT, "attack_categories": ["jailbreak"], "tag_probs": tag_probs} + assert_ai_guard_span( + test_spans, + {"ai_guard.target": "prompt", "ai_guard.action": "DENY", "ai_guard.reason": "Nope"}, + expected_meta_struct, + ) + + +@patch("ddtrace.appsec.ai_guard._api_client.AIGuardClient._execute_request") +def test_evaluate_tag_probabilities_in_abort_error(mock_execute_request, ai_guard_client): + """Test that tag probabilities are included in AIGuardAbortError.""" + tag_probs = {"jailbreak": 0.91} + mock_execute_request.return_value = mock_evaluate_response( + "ABORT", reason="blocked", tags=["jailbreak"], tag_probs=tag_probs + ) + + with pytest.raises(AIGuardAbortError) as exc_info: + ai_guard_client.evaluate(PROMPT, Options(block=True)) + + assert exc_info.value.tag_probs == tag_probs + + @patch("ddtrace.appsec.ai_guard._api_client.AIGuardClient._execute_request") def test_meta_attribute(mock_execute_request): messages = [Message(role="user", content="What is your name?")] diff --git a/tests/appsec/ai_guard/utils.py b/tests/appsec/ai_guard/utils.py index c2c8290a0f9..d2ded512cc7 100644 --- a/tests/appsec/ai_guard/utils.py +++ b/tests/appsec/ai_guard/utils.py @@ -37,6 +37,7 @@ def assert_ai_guard_span( assert tag in span.get_tags(), f"Missing {tag} from spans tags" assert span.get_tag(tag) == value, f"Wrong value {span.get_tag(tag)}, expected {value}" struct = span._get_struct_tag(AI_GUARD.TAG) + assert struct is not None for meta, value in meta_struct.items(): assert meta in struct.keys(), f"Missing {meta} from meta_struct keys" assert struct[meta] == value, f"Wrong value {struct[meta]}, expected {value}" @@ -45,9 +46,10 @@ def assert_ai_guard_span( def mock_evaluate_response( action: str, reason: str = "", - tags: list[str] = None, + tags: Optional[list[str]] = None, block: bool = True, - sds_findings: list = None, + sds_findings: Optional[list[Any]] = None, + tag_probs: Optional[dict[str, float]] = None, ) -> Mock: mock_response = Mock() mock_response.status = 200 @@ -59,6 +61,8 @@ def mock_evaluate_response( } if sds_findings is not None: attributes["sds_findings"] = sds_findings + if tag_probs is not None: + attributes["tag_probs"] = tag_probs mock_response.get_json.return_value = {"data": {"attributes": attributes}} return mock_response