diff --git a/src/agents/engine_agent/agent.py b/src/agents/engine_agent/agent.py index f64c307..5fffa97 100644 --- a/src/agents/engine_agent/agent.py +++ b/src/agents/engine_agent/agent.py @@ -111,7 +111,16 @@ async def execute(self, event_type: str, event_data: dict[str, Any], rules: list logger.info(f"🔧 Rule Engine evaluation completed in {execution_time:.2f}s") # Extract violations from result - violations = result.violations if hasattr(result, "violations") else [] + # violations = result.violations if hasattr(result, "violations") else [] + + # Extract violations from result (EngineState) + violations = [] + if hasattr(result, "violations"): + violations = result.violations + elif isinstance(result, dict) and "violations" in result: + violations = result["violations"] + + logger.info(f"🔧 Rule Engine extracted {len(violations)} violations from state") logger.info(f"🔧 Rule Engine extracted {len(violations)} violations") diff --git a/src/event_processors/push.py b/src/event_processors/push.py index 92d67f3..b197b38 100644 --- a/src/event_processors/push.py +++ b/src/event_processors/push.py @@ -43,6 +43,7 @@ async def process(self, task: Task) -> ProcessingResult: "head_commit": payload.get("head_commit", {}), "before": payload.get("before"), "after": payload.get("after"), + "forced": payload.get("forced", False), }, "triggering_user": {"login": payload.get("pusher", {}).get("name")}, "repository": payload.get("repository", {}), @@ -68,7 +69,14 @@ async def process(self, task: Task) -> ProcessingResult: # Run agentic analysis using the instance result = await self.engine_agent.execute(event_type="push", event_data=event_data, rules=formatted_rules) - violations = result.data.get("violations", []) + violations: list[dict[str, Any]] = [] + + try: + eval_result = result.data.get("evaluation_result") if result and result.data else None + if eval_result and hasattr(eval_result, "violations"): + violations = [v.__dict__ if hasattr(v, "__dict__") else v for v in eval_result.violations] + except Exception as e: + logger.error(f"Error extracting violations from engine result: {e}") processing_time = int((time.time() - start_time) * 1000) @@ -88,7 +96,9 @@ async def process(self, task: Task) -> ProcessingResult: if violations: logger.warning("🚨 VIOLATION SUMMARY:") for i, violation in enumerate(violations, 1): - logger.warning(f" {i}. {violation.get('rule', 'Unknown')} ({violation.get('severity', 'medium')})") + # logger.warning(f" {i}. {violation.get('rule', 'Unknown')} ({violation.get('severity', 'medium')})") + rule_name = violation.get("rule", violation.get("rule_description", "Unknown")) + logger.warning(f" {i}. {rule_name} ({violation.get('severity', 'medium')})") logger.warning(f" {violation.get('message', '')}") else: logger.info("✅ All rules passed - no violations detected!") diff --git a/src/rules/validators.py b/src/rules/validators.py index b04ac19..1fd83fc 100644 --- a/src/rules/validators.py +++ b/src/rules/validators.py @@ -75,6 +75,31 @@ async def validate(self, parameters: dict[str, Any], event: dict[str, Any]) -> b return author_login in team_memberships.get(team_name, []) +class CommitCountLimitCondition(Condition): + """Validates that a push does not exceed a maximum number of commits.""" + + name = "commit_count_limit" + description = "Validates that a push does not exceed a maximum number of commits" + parameter_patterns = ["max_commits"] + event_types = ["push"] + examples = [{"max_commits": 5}, {"max_commits": 2}] + + async def validate(self, parameters: dict[str, Any], event: dict[str, Any]) -> bool: + max_commits = int(parameters.get("max_commits", 2)) + push_data = event.get("push", {}) + commits = push_data.get("commits", []) + + commit_count = len(commits) + logger.debug(f"CommitCountLimitCondition: found {commit_count} commits, max allowed={max_commits}") + + if commit_count > max_commits: + logger.debug("CommitCountLimitCondition: VIOLATION - too many commits in push") + return False + + logger.debug("CommitCountLimitCondition: PASS - within commit limit") + return True + + class FilePatternCondition(Condition): """Validates if files in the event match or don't match a pattern.""" @@ -454,11 +479,44 @@ class AllowForcePushCondition(Condition): examples = [{"allow_force_push": False}, {"allow_force_push": True}] async def validate(self, parameters: dict[str, Any], event: dict[str, Any]) -> bool: - # allow_force_push = parameters.get("allow_force_push", False) + push_data = event.get("push", {}) + if not push_data: + return True # No violation if we can't check - # This would typically check if the push was a force push - # For now, return True (no violation) as placeholder - return True + allow_force_push = parameters.get("allow_force_push", False) + if allow_force_push: + return True # No violation if force pushes are allowed + + is_force_push = bool(push_data.get("forced", False)) + if not is_force_push: + return True # No violation if not a force push + + ref = push_data.get("ref") + if not ref: + logger.debug("AllowForcePushCondition: No ref found in push data") + return True # No violation if we can't determine branch + + is_force_push = bool(push_data.get("forced", False)) + if not is_force_push: + return True + + # Extract branch name from ref (e.g., "refs/heads/main" -> "main") + branch_name = ref.replace("refs/heads/", "") if ref.startswith("refs/heads/") else ref + + protected_branches = parameters.get("protected_branches", []) + if not protected_branches: + logger.debug( + f"AllowForcePushCondition: Force push to {branch_name} allowed (no protected branches specified so we select all branches as protected)" + ) + return False # No protected branches means all branches are protected + + is_protected_branch = branch_name in protected_branches + if is_protected_branch: + logger.debug(f"AllowForcePushCondition: Force push to protected branch {branch_name} blocked") + return False # Violation + else: + logger.debug(f"AllowForcePushCondition: Force push to unprotected branch {branch_name} allowed") + return True # return True (no violation) as placeholder class ProtectedBranchesCondition(Condition): @@ -762,6 +820,7 @@ def _is_new_contributor(self, username: str) -> bool: "required_checks": RequiredChecksCondition(), "code_owners": CodeOwnersCondition(), "past_contributor_approval": PastContributorApprovalCondition(), + "commit_count_limit": CommitCountLimitCondition(), }