-
Notifications
You must be signed in to change notification settings - Fork 0
fix: fail guardrail validation only when all rules are violated #24
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -40,7 +40,7 @@ def evaluate_pre_deterministic_guardrail( | |
| if has_output_rule: | ||
| return GuardrailValidationResult( | ||
| result=GuardrailValidationResultType.PASSED, | ||
| reason="Guardrail contains output-dependent rules that will be evaluated during post-execution", | ||
| reason="No rules to apply for input data.", | ||
| ) | ||
| return self._evaluate_deterministic_guardrail( | ||
| input_data=input_data, | ||
|
|
@@ -66,7 +66,7 @@ def evaluate_post_deterministic_guardrail( | |
| if not has_output_rule: | ||
| return GuardrailValidationResult( | ||
| result=GuardrailValidationResultType.PASSED, | ||
| reason="Guardrail contains only input-dependent rules that were evaluated during pre-execution", | ||
| reason="No rules to apply for output data.", | ||
| ) | ||
|
|
||
| return self._evaluate_deterministic_guardrail( | ||
|
|
@@ -117,7 +117,12 @@ def _evaluate_deterministic_guardrail( | |
| output_data: dict[str, Any], | ||
| guardrail: DeterministicGuardrail, | ||
| ) -> GuardrailValidationResult: | ||
| """Evaluate deterministic guardrail rules against input and output data.""" | ||
| """Evaluate deterministic guardrail rules against input and output data. | ||
|
|
||
| Validation fails only if ALL guardrail rules are violated. | ||
| """ | ||
| validated_conditions: list[str] = [] | ||
|
|
||
| for rule in guardrail.rules: | ||
| if isinstance(rule, WordRule): | ||
| passed, reason = evaluate_word_rule(rule, input_data, output_data) | ||
|
|
@@ -132,14 +137,25 @@ def _evaluate_deterministic_guardrail( | |
| result=GuardrailValidationResultType.VALIDATION_FAILED, | ||
| reason=f"Unknown rule type: {type(rule)}", | ||
| ) | ||
|
|
||
| if not passed: | ||
| validated_conditions.append(reason) | ||
| if passed: | ||
| return GuardrailValidationResult( | ||
| result=GuardrailValidationResultType.VALIDATION_FAILED, | ||
| reason=reason or "Rule validation failed", | ||
| result=GuardrailValidationResultType.PASSED, | ||
| reason=reason, | ||
| ) | ||
|
|
||
| has_always_rule = any( | ||
| condition == "Always rule enforced" for condition in validated_conditions | ||
| ) | ||
|
|
||
| validated_conditions_str = ", ".join(validated_conditions) | ||
| final_reason = ( | ||
| "Always rule enforced" | ||
| if has_always_rule | ||
| else f"Data matched all guardrail conditions: [{validated_conditions_str}]" | ||
| ) | ||
|
|
||
| return GuardrailValidationResult( | ||
| result=GuardrailValidationResultType.PASSED, | ||
| reason="All deterministic guardrail rules passed", | ||
| result=GuardrailValidationResultType.VALIDATION_FAILED, | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Only when all rules are violated, then the validation should fail |
||
| reason=final_reason, | ||
| ) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -159,24 +159,44 @@ def get_fields_from_selector( | |
| return fields | ||
|
|
||
|
|
||
| def format_guardrail_error_message( | ||
| def format_guardrail_passed_validation_result_message( | ||
| field_ref: FieldReference, | ||
| operator: str, | ||
| expected_value: str | None = None, | ||
| operator: str | None, | ||
| rule_description: str | None, | ||
| ) -> str: | ||
| """Format a guardrail error message following the standard pattern.""" | ||
| """Format a guardrail validation result message following the standard pattern.""" | ||
| source = "Input" if field_ref.source == FieldSource.INPUT else "Output" | ||
| message = f"{source} data didn't match the guardrail condition: [{field_ref.path}] comparing function [{operator}]" | ||
| if expected_value and expected_value.strip(): | ||
| message += f" [{expected_value.strip()}]" | ||
| return message | ||
|
|
||
| if rule_description: | ||
| return ( | ||
| f"{source} data didn't match the guardrail condition for field " | ||
| f"[{field_ref.path}]: {rule_description}" | ||
| ) | ||
|
|
||
| return ( | ||
| f"{source} data didn't match the guardrail condition: " | ||
| f"[{field_ref.path}] comparing function [{operator}]" | ||
| ) | ||
|
|
||
|
|
||
| def get_validated_conditions_description( | ||
| field_path: str, | ||
| operator: str | None, | ||
| rule_description: str | None, | ||
| ) -> str: | ||
| if rule_description: | ||
| return rule_description | ||
|
|
||
| return f"[{field_path}] comparing function [{operator}]" | ||
|
|
||
|
|
||
| def evaluate_word_rule( | ||
| rule: WordRule, input_data: dict[str, Any], output_data: dict[str, Any] | ||
| ) -> tuple[bool, str | None]: | ||
| ) -> tuple[bool, str]: | ||
| """Evaluate a word rule against input and output data.""" | ||
| fields = get_fields_from_selector(rule.field_selector, input_data, output_data) | ||
| operator = _humanize_guardrail_func(rule.detects_violation) or "violation check" | ||
| field_paths = ", ".join({field_ref.path for _, field_ref in fields}) | ||
|
|
||
| for field_value, field_ref in fields: | ||
| if field_value is None: | ||
|
|
@@ -197,22 +217,28 @@ def evaluate_word_rule( | |
| # If function raises an exception, treat as failure | ||
| violation_detected = True | ||
|
|
||
| if violation_detected: | ||
| operator = ( | ||
| _humanize_guardrail_func(rule.detects_violation) or "violation check" | ||
| if not violation_detected: | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If any of the fields don't violate the guardrail condition, then the validation should pass |
||
| reason = format_guardrail_passed_validation_result_message( | ||
| field_ref=field_ref, | ||
| operator=operator, | ||
| rule_description=rule.rule_description, | ||
| ) | ||
| reason = format_guardrail_error_message(field_ref, operator, None) | ||
| return False, reason | ||
| return True, reason | ||
|
|
||
| return True, "All word rule validations passed" | ||
| return False, get_validated_conditions_description( | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Only when all fields violate the guardrail condition, then the validation should fail |
||
| field_path=field_paths, | ||
| operator=operator, | ||
| rule_description=rule.rule_description, | ||
| ) | ||
|
|
||
|
|
||
| def evaluate_number_rule( | ||
| rule: NumberRule, input_data: dict[str, Any], output_data: dict[str, Any] | ||
| ) -> tuple[bool, str | None]: | ||
| ) -> tuple[bool, str]: | ||
| """Evaluate a number rule against input and output data.""" | ||
| fields = get_fields_from_selector(rule.field_selector, input_data, output_data) | ||
|
|
||
| operator = _humanize_guardrail_func(rule.detects_violation) or "violation check" | ||
| field_paths = ", ".join({field_ref.path for _, field_ref in fields}) | ||
| for field_value, field_ref in fields: | ||
| if field_value is None: | ||
| continue | ||
|
|
@@ -233,24 +259,30 @@ def evaluate_number_rule( | |
| # If function raises an exception, treat as failure | ||
| violation_detected = True | ||
|
|
||
| if violation_detected: | ||
| operator = ( | ||
| _humanize_guardrail_func(rule.detects_violation) or "violation check" | ||
| if not violation_detected: | ||
| reason = format_guardrail_passed_validation_result_message( | ||
| field_ref=field_ref, | ||
| operator=operator, | ||
| rule_description=rule.rule_description, | ||
| ) | ||
| reason = format_guardrail_error_message(field_ref, operator, None) | ||
| return False, reason | ||
| return True, reason | ||
|
|
||
| return True, "All number rule validations passed" | ||
| return False, get_validated_conditions_description( | ||
| field_path=field_paths, | ||
| operator=operator, | ||
| rule_description=rule.rule_description, | ||
| ) | ||
|
|
||
|
|
||
| def evaluate_boolean_rule( | ||
| rule: BooleanRule, | ||
| input_data: dict[str, Any], | ||
| output_data: dict[str, Any], | ||
| ) -> tuple[bool, str | None]: | ||
| ) -> tuple[bool, str]: | ||
| """Evaluate a boolean rule against input and output data.""" | ||
| fields = get_fields_from_selector(rule.field_selector, input_data, output_data) | ||
|
|
||
| operator = _humanize_guardrail_func(rule.detects_violation) or "violation check" | ||
| field_paths = ", ".join({field_ref.path for _, field_ref in fields}) | ||
| for field_value, field_ref in fields: | ||
| if field_value is None: | ||
| continue | ||
|
|
@@ -270,20 +302,25 @@ def evaluate_boolean_rule( | |
| # If function raises an exception, treat as failure | ||
| violation_detected = True | ||
|
|
||
| if violation_detected: | ||
| operator = ( | ||
| _humanize_guardrail_func(rule.detects_violation) or "violation check" | ||
| if not violation_detected: | ||
| reason = format_guardrail_passed_validation_result_message( | ||
| field_ref=field_ref, | ||
| operator=operator, | ||
| rule_description=rule.rule_description, | ||
| ) | ||
| reason = format_guardrail_error_message(field_ref, operator, None) | ||
| return False, reason | ||
| return True, reason | ||
|
|
||
| return True, "All boolean rule validations passed" | ||
| return False, get_validated_conditions_description( | ||
| field_path=field_paths, | ||
| operator=operator, | ||
| rule_description=rule.rule_description, | ||
| ) | ||
|
|
||
|
|
||
| def evaluate_universal_rule( | ||
| rule: UniversalRule, | ||
| output_data: dict[str, Any], | ||
| ) -> tuple[bool, str | None]: | ||
| ) -> tuple[bool, str]: | ||
| """Evaluate a universal rule against input and output data. | ||
|
|
||
| Universal rules trigger based on the apply_to scope and execution phase: | ||
|
|
@@ -302,18 +339,18 @@ def evaluate_universal_rule( | |
| if rule.apply_to == ApplyTo.INPUT: | ||
| # INPUT: triggers in pre-execution, does not trigger in post-execution | ||
| if is_pre_execution: | ||
| return False, "Universal rule validation triggered (pre-execution, input)" | ||
| return False, "Always rule enforced" | ||
| else: | ||
| return True, "Universal rule validation passed (post-execution, input)" | ||
| return True, "No rules to apply for output data" | ||
| elif rule.apply_to == ApplyTo.OUTPUT: | ||
| # OUTPUT: does not trigger in pre-execution, triggers in post-execution | ||
| if is_pre_execution: | ||
| return True, "Universal rule validation passed (pre-execution, output)" | ||
| return True, "No rules to apply for input data" | ||
| else: | ||
| return False, "Universal rule validation triggered (post-execution, output)" | ||
| return False, "Always rule enforced" | ||
| elif rule.apply_to == ApplyTo.INPUT_AND_OUTPUT: | ||
| # INPUT_AND_OUTPUT: triggers in both phases | ||
| return False, "Universal rule validation triggered (input and output)" | ||
| return False, "Always rule enforced" | ||
| else: | ||
| return False, f"Unknown apply_to value: {rule.apply_to}" | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If any of the rules is not violated, then the validation should pass