Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 24 additions & 3 deletions ddtrace/appsec/ai_guard/_api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Optional # noqa:F401
from typing import TypedDict
from typing import Union
from typing import cast

from ddtrace import config
from ddtrace.appsec._constants import AI_GUARD
Expand Down Expand Up @@ -95,11 +96,19 @@ def __init__(self, message: Optional[str], status: int = 0, errors: Optional[lis
class AIGuardAbortError(Exception):
    """Exception to abort current execution due to security policy.

    Attributes:
        action: Action passed by the caller when raising the error.
        reason: Reason passed by the caller when raising the error.
        tags: Tags associated with the evaluation, or ``None`` when not provided.
        sds: SDS findings; always a list (an empty list when none are given).
        tag_probs: Optional mapping of tag name to probability, or ``None``
            when not provided.
    """

    def __init__(
        self,
        action: str,
        reason: str,
        tags: Optional[list[str]] = None,
        sds: Optional[list] = None,
        tag_probs: Optional[dict[str, float]] = None,
    ):
        self.action = action
        self.reason = reason
        self.tags = tags
        # Normalize a missing/empty ``sds`` to a list so callers can iterate it unconditionally.
        self.sds = sds or []
        # ``tag_probs`` is kept as-is (``None`` stays ``None``); it is the last
        # parameter, so existing positional callers are unaffected.
        self.tag_probs = tag_probs
        super().__init__(f"AIGuardAbortError(action='{action}', reason='{reason}', tags='{tags}')")


Expand Down Expand Up @@ -263,6 +272,7 @@ def evaluate(self, messages: list[Message], options: Optional[Options] = None) -
tags = attributes.get("tags", [])
sds_findings = attributes.get("sds_findings") or []
blocking_enabled = attributes.get("is_blocking_enabled", False)
tag_probs = attributes.get("tag_probs")
except Exception as e:
value = json.dumps(result, indent=2)[:500]
raise AIGuardClientError(
Expand All @@ -283,6 +293,8 @@ def evaluate(self, messages: list[Message], options: Optional[Options] = None) -
span.set_tag(AI_GUARD.REASON_TAG, reason)
if sds_findings:
meta_struct.update({"sds": sds_findings})
if tag_probs is not None:
meta_struct.update({"tag_probs": tag_probs})
else:
raise AIGuardClientError(
message=f"AI Guard service call failed, status: {response.status}",
Expand All @@ -303,9 +315,18 @@ def evaluate(self, messages: list[Message], options: Optional[Options] = None) -
_aiguard_manual_keep(root_span)
if should_block:
span.set_tag(AI_GUARD.BLOCKED_TAG, "true")
raise AIGuardAbortError(action=action, reason=reason, tags=tags, sds=sds_findings)
raise AIGuardAbortError(
action=action,
reason=reason,
tags=tags,
sds=sds_findings,
tag_probs=tag_probs,
)

return Evaluation(action=action, reason=reason, tags=tags, sds=sds_findings)
evaluation = {"action": action, "reason": reason, "tags": tags, "sds": sds_findings}
if tag_probs is not None:
evaluation["tag_probs"] = tag_probs
return cast(Evaluation, evaluation)

except AIGuardAbortError:
raise
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
features:
- |
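AI Guard: evaluation results and `AIGuardAbortError` now expose an optional `tag_probs` dict mapping each tag to its probability. The field is only present in the evaluation result when a `tag_probs` value is present in the response attributes.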

34 changes: 34 additions & 0 deletions tests/appsec/ai_guard/api/test_api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,40 @@ def test_evaluate_sds_findings_in_abort_error(mock_execute_request, telemetry_mo
assert exc_info.value.sds == sds_findings


@patch("ddtrace.appsec.ai_guard._api_client.AIGuardClient._execute_request")
def test_evaluate_tag_probabilities_in_result_and_meta_struct(mock_execute_request, ai_guard_client, test_spans):
    """Tag probabilities must surface in the evaluation result and in the span meta_struct."""
    probabilities = {"jailbreak": 0.91, "prompt_injection": 0.42}
    response = mock_evaluate_response(
        "DENY",
        reason="Nope",
        tags=["jailbreak"],
        block=False,
        tag_probs=probabilities,
    )
    mock_execute_request.return_value = response

    evaluation = ai_guard_client.evaluate(PROMPT, Options(block=False))

    # The SDK response carries the probabilities returned by the service.
    assert evaluation["tag_probs"] == probabilities

    # The same probabilities are recorded on the span's meta_struct.
    expected_tags = {"ai_guard.target": "prompt", "ai_guard.action": "DENY", "ai_guard.reason": "Nope"}
    expected_struct = {"messages": PROMPT, "attack_categories": ["jailbreak"], "tag_probs": probabilities}
    assert_ai_guard_span(test_spans, expected_tags, expected_struct)


@patch("ddtrace.appsec.ai_guard._api_client.AIGuardClient._execute_request")
def test_evaluate_tag_probabilities_in_abort_error(mock_execute_request, ai_guard_client):
    """A blocking evaluation must raise AIGuardAbortError carrying the tag probabilities."""
    probabilities = {"jailbreak": 0.91}
    response = mock_evaluate_response(
        "ABORT",
        reason="blocked",
        tags=["jailbreak"],
        tag_probs=probabilities,
    )
    mock_execute_request.return_value = response

    with pytest.raises(AIGuardAbortError) as raised:
        ai_guard_client.evaluate(PROMPT, Options(block=True))

    assert raised.value.tag_probs == probabilities


@patch("ddtrace.appsec.ai_guard._api_client.AIGuardClient._execute_request")
def test_meta_attribute(mock_execute_request):
messages = [Message(role="user", content="What is your name?")]
Expand Down
8 changes: 6 additions & 2 deletions tests/appsec/ai_guard/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def assert_ai_guard_span(
assert tag in span.get_tags(), f"Missing {tag} from spans tags"
assert span.get_tag(tag) == value, f"Wrong value {span.get_tag(tag)}, expected {value}"
struct = span._get_struct_tag(AI_GUARD.TAG)
assert struct is not None
for meta, value in meta_struct.items():
assert meta in struct.keys(), f"Missing {meta} from meta_struct keys"
assert struct[meta] == value, f"Wrong value {struct[meta]}, expected {value}"
Expand All @@ -45,9 +46,10 @@ def assert_ai_guard_span(
def mock_evaluate_response(
action: str,
reason: str = "",
tags: list[str] = None,
tags: Optional[list[str]] = None,
block: bool = True,
sds_findings: list = None,
sds_findings: Optional[list[Any]] = None,
tag_probs: Optional[dict[str, float]] = None,
) -> Mock:
mock_response = Mock()
mock_response.status = 200
Expand All @@ -59,6 +61,8 @@ def mock_evaluate_response(
}
if sds_findings is not None:
attributes["sds_findings"] = sds_findings
if tag_probs is not None:
attributes["tag_probs"] = tag_probs
mock_response.get_json.return_value = {"data": {"attributes": attributes}}
return mock_response

Expand Down
Loading