diff --git a/Improvement.md b/Improvement.md new file mode 100644 index 0000000..adfe4ce --- /dev/null +++ b/Improvement.md @@ -0,0 +1,405 @@ +# Watchflow Improvements + +## I THINK these are ISSUES (Must Fix Soon) + +### 1. Agents Don't Talk to Each Other +**What it means:** Watchflow has multiple AI agents (like workers), but they work alone. They don't coordinate. + +**Real-world example:** Imagine you have 3 security guards, but they never talk. Guard 1 sees something suspicious, but Guard 2 doesn't know about it. They can't work together to solve complex problems. + +**Why it matters:** For complex rules that need multiple checks, the agents can't combine their knowledge. They each do their own thing independently. + +**What needs to happen:** Make agents work together. When one agent finds something, others should know. They should be able to discuss and make better decisions together. + +--- + +### 2. Same Violations Reported Multiple Times +**What it means:** If someone breaks a rule, Watchflow might tell you about it 5 times instead of once. + +**Real-world example:** Like getting 5 emails about the same meeting reminder. Annoying, right? + +**Why it matters:** Developers get spammed with the same violation messages. It's noise, not useful information. + +**What needs to happen:** Track what violations have already been reported. If we've seen this exact violation before, don't report it again (or at least mark it as "already reported"). + +--- + +### 3. System Doesn't Learn from Mistakes +**What it means:** Watchflow makes the same wrong decisions over and over. It doesn't learn. + +**Real-world example:** Like a teacher who keeps giving the same wrong answer to students, never learning from feedback. + +**Why it matters:** If Watchflow incorrectly blocks a PR (false positive), it will keep doing it. If it misses a real violation (false negative), it keeps missing it. No improvement over time. + +**What needs to happen:** When developers say "this was wrong" or "this was right", Watchflow should remember and adjust. Over time, it gets smarter. + +--- + +### 4. Error Handling is Confusing +**What it means:** When something goes wrong, the system sometimes says "everything is fine" instead of "something broke." + +**Real-world example:** Your car's check engine light is broken, so it never lights up even when there's a problem. You think everything is fine, but it's not. + +**Why it matters:** If a validator (rule checker) crashes, Watchflow might say "no violations found" when really it just couldn't check. This is dangerous - it looks like everything passed, but actually we don't know. + +**What needs to happen:** Clearly distinguish between: +- ✅ "Rule passed - everything is good" +- ❌ "Rule failed - violation found" +- ⚠️ "Error - couldn't check, need to investigate" + +--- + +## TECHNICAL DEBT (Code Quality Issues) + +### 5. Abstract Classes Use "Pass" Instead of Proper Errors +**What it means:** In programming, there are "abstract" classes - templates that other classes must fill in. Currently, if someone forgets to fill in a required part, the code just says "pass" (do nothing) instead of raising an error. + +**Real-world example:** Like a job application form where you can skip required fields and it still accepts it, instead of saying "you must fill this out." + +**Why it matters:** If a developer forgets to implement something, the code will silently fail later, making it hard to debug. 
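+
+A generic before/after illustration of the pattern (not Watchflow's actual classes, just a sketch of why the raise is better):
+
+```python
+from abc import ABC, abstractmethod
+
+
+class BaseAgent(ABC):
+    @abstractmethod
+    async def execute(self, **kwargs):
+        """Execute the agent."""
+        # With `pass` here, a subclass that falls through to the base method
+        # (e.g. via super().execute()) silently gets None back.
+        # Raising makes the missing implementation loud and immediate:
+        raise NotImplementedError("Subclasses must implement execute")
+
+
+class MyAgent(BaseAgent):
+    async def execute(self, **kwargs):
+        # Forgot the real logic and deferred to the base class:
+        return await super().execute(**kwargs)  # -> NotImplementedError, not a quiet None
+```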
+ +**What needs to happen:** Change `pass` to `raise NotImplementedError` so if someone forgets to implement something, they get an immediate, clear error message. + +--- + +### 6. Not Enough Tests +**What it means:** Many parts of the code don't have automated tests to verify they work correctly. + +**Real-world example:** Like a car manufacturer that only tests the engine, but never tests the brakes, steering, or lights. + +**Why it matters:** When you change code, you don't know if you broke something. Tests catch bugs before they reach production. + +**What needs to happen:** Write tests for: +- Acknowledgment agent (handles when developers say "I know about this violation") +- Repository analysis agent (analyzes repos to suggest rules) +- Deployment processors (handles deployment events) +- End-to-end workflows (test the whole process from PR to decision) + +--- + +### 7. Can't Combine Rules with AND/OR Logic +**What it means:** You can't create complex rules like "Block if (author is X AND file is /auth) OR (author is Y AND it's weekend)" + +**Real-world example:** Like a security system that can check "is door locked?" OR "is window closed?" but can't check "is door locked AND window closed at the same time?" + +**Why it matters:** Real-world policies are complex. You might want: "Prevent John from modifying the authentication code, unless it's an emergency and he has approval." That needs multiple conditions combined. + +**What needs to happen:** Add support for combining validators with AND, OR, and NOT operators. Allow nested conditions. + +--- + +## PERFORMANCE & SCALABILITY + +### 8. Worker Count is Hardcoded +**What it means:** The system uses exactly 5 workers (background processes) to handle tasks. This number is written in code, not configurable. + +**Real-world example:** Like a restaurant that always has exactly 5 waiters, even if it's super busy (needs 10) or empty (needs 1). + +**Why it matters:** Can't scale up when busy, wastes resources when idle. + +**What needs to happen:** Make worker count configurable via environment variable. Allow auto-scaling based on load. + +--- + +### 9. Caching Strategy is Unclear +**What it means:** The system caches (stores) some data to avoid re-fetching it, but we don't know: +- How long data is cached +- When cache is cleared +- How much memory is used + +**Real-world example:** Like a library that caches books, but you don't know how long books stay in cache, when they're removed, or if the cache is full. + +**Why it matters:** Without understanding caching, you can't optimize performance or debug issues. + +**What needs to happen:** Document the caching strategy. Make cache settings (TTL, size limits) configurable. + +--- + +### 10. AI Costs Not Optimized +**What it means:** Every time Watchflow uses AI (LLM), it costs money. There's no clear strategy to reduce these costs. + +**Real-world example:** Like making expensive phone calls every time you need information, instead of writing it down and reusing it. + +**Why it matters:** AI calls are expensive. If you're checking 100 PRs per day, costs add up quickly. + +**What needs to happen:** +- Track how much each AI call costs +- Cache similar rule evaluations (if we checked this before, reuse the result) +- Batch multiple rules together when possible + +--- + +## MONITORING & OBSERVABILITY + +### 11. No Metrics or Monitoring Dashboard +**What it means:** Documentation says "Prometheus and Grafana" but they're not actually implemented. 
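+
+For reference, a first metrics endpoint might look like the sketch below (this assumes the `prometheus_client` library; the metric names, labels, and mount path are illustrative, not existing code):
+
+```python
+from prometheus_client import Counter, Histogram, make_asgi_app
+
+EVENTS_PROCESSED = Counter(
+    "watchflow_events_processed_total",
+    "Events processed, labeled by event type and outcome",
+    ["event_type", "state"],
+)
+EVENT_PROCESSING_SECONDS = Histogram(
+    "watchflow_event_processing_seconds",
+    "Wall-clock time spent processing a single event",
+)
+
+# In the FastAPI app factory (name assumed), expose /metrics for Prometheus to scrape:
+# app.mount("/metrics", make_asgi_app())
+```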
+ +**Real-world example:** Like a car with no dashboard - you can't see speed, fuel level, or if the engine is overheating. + +**Why it matters:** In production, you need to know: +- Is the system healthy? +- How fast are responses? +- How many errors are happening? +- How much is this costing? + +**What needs to happen:** +- Add Prometheus metrics endpoint (exposes metrics) +- Create Grafana dashboards (visualize metrics) +- Track: response times, error rates, AI costs, cache performance + +--- + +### 12. Logging is Messy +**What it means:** Lots of debug logs everywhere, but no clear structure. Hard to find what you need. + +**Real-world example:** Like a diary with no dates, no organization, just random thoughts scattered everywhere. + +**Why it matters:** When something breaks in production, you need to find the relevant logs quickly. Too much noise makes it hard. + +**What needs to happen:** +- Standardize log levels (INFO for normal operations, DEBUG for development) +- Use structured logging (JSON format, easier to search) +- Add correlation IDs (track one request across multiple log entries) + +--- + +## SECURITY & COMPLIANCE + +### 13. Audit Trail Not Clear +**What it means:** Documentation says "complete audit trail" but it's unclear where logs are stored, how long they're kept, or how to search them. + +**Real-world example:** Like a security camera system that records everything, but you don't know where the recordings are stored, how long they're kept, or how to find a specific event. + +**Why it matters:** For compliance (SOC2, GDPR, etc.), you need to prove what decisions were made and why. You need to be able to search and retrieve audit logs. + +**What needs to happen:** +- Implement audit log storage (database or file-based) +- Define retention policy (how long to keep logs) +- Add search/query API for audit logs + +--- + +### 14. Secrets Stored in Environment Variables +**What it means:** GitHub App private keys are stored as base64-encoded environment variables. + +**Real-world example:** Like writing your password on a sticky note and putting it on your desk. It works, but not secure. + +**Why it matters:** If environment variables are logged, exposed in error messages, or accessed by unauthorized people, secrets are compromised. + +**What needs to happen:** +- Use a secret management service (AWS Secrets Manager, HashiCorp Vault) +- Support secret rotation (change keys periodically) +- Never log secrets, even in debug mode + +--- + +## ARCHITECTURE IMPROVEMENTS + +### 15. Decision Orchestrator Missing +**What it means:** Documentation describes a "Decision Orchestrator" that combines rule-based and AI-based decisions, but it doesn't actually exist in code. + +**Real-world example:** Like a recipe that says "combine ingredients in the mixer" but you don't have a mixer - you're just mixing by hand inconsistently. + +**Why it matters:** Without a central orchestrator, decisions are made inconsistently. Sometimes rules win, sometimes AI wins, but there's no smart way to combine them. + +**What needs to happen:** Build the Decision Orchestrator that: +- Takes input from both rule engine and AI agents +- Intelligently combines them (maybe rules for simple cases, AI for complex) +- Handles conflicts (what if rule says "pass" but AI says "fail"?) + +--- + +### 16. Only GitHub Supported +**What it means:** Watchflow only works with GitHub. Documentation mentions GitLab and Azure DevOps as future features, but they're not implemented. 
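+
+A sketch of what the provider abstraction could look like (names are illustrative, not existing code):
+
+```python
+from abc import ABC, abstractmethod
+from typing import Any
+
+
+class GitProvider(ABC):
+    """Minimal surface the event processors would depend on."""
+
+    @abstractmethod
+    async def get_pull_request(self, repo: str, number: int) -> dict[str, Any]:
+        raise NotImplementedError
+
+    @abstractmethod
+    async def create_check_run(self, repo: str, sha: str, conclusion: str, summary: str) -> None:
+        raise NotImplementedError
+
+    @abstractmethod
+    async def post_comment(self, repo: str, number: int, body: str) -> None:
+        raise NotImplementedError
+
+
+class GitHubProvider(GitProvider):
+    """Wraps the existing GitHub API client (current behavior)."""
+
+
+class GitLabProvider(GitProvider):
+    """Future: map the same operations onto GitLab's REST API."""
+```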
+ +**Real-world example:** Like a phone that only works with one carrier, when you could support multiple carriers and reach more customers. + +**Why it matters:** Limits market reach. Many companies use GitLab or Azure DevOps. + +**What needs to happen:** +- Abstract the provider interface (make it easy to add new platforms) +- Implement GitLab support +- Implement Azure DevOps support + +--- + +### 17. No Specialized Agents +**What it means:** All agents are general-purpose. There are no specialized agents for security, compliance, or performance. + +**Real-world example:** Like having general doctors but no specialists. A general doctor can help, but a cardiologist is better for heart problems. + +**Why it matters:** Specialized agents would be better at their specific domains. A security agent would understand security patterns better than a general agent. + +**What needs to happen:** +- Create security-focused agent (specializes in security rules) +- Create compliance-focused agent (specializes in compliance rules) +- Create performance-focused agent (specializes in performance rules) + +--- + +## DOCUMENTATION & DEVELOPER EXPERIENCE + +### 18. API Documentation is Basic +**What it means:** FastAPI auto-generates API docs, but they're missing examples, error codes, and rate limiting info. + +**Real-world example:** Like a product manual that lists features but doesn't show how to use them or what to do when something goes wrong. + +**Why it matters:** Developers using the API need clear examples and error handling guidance. + +**What needs to happen:** Enhance API documentation with: +- Example requests and responses +- All possible error codes and what they mean +- Rate limiting information (how many requests per minute) + +--- + +### 19. Configuration is Scattered +**What it means:** Configuration options are spread across multiple files. Hard to know all available options. + +**Real-world example:** Like settings for your phone scattered across 10 different menus instead of one settings page. + +**Why it matters:** Hard to configure the system. You might miss important settings. + +**What needs to happen:** +- Create comprehensive configuration guide +- Add configuration validation (warn if settings are wrong) +- Provide examples for common scenarios + +--- + +## TESTING & QUALITY + +### 20. No Load Testing +**What it means:** No tests to see how the system performs under heavy load (many PRs at once). + +**Real-world example:** Like opening a restaurant without testing if the kitchen can handle a full house. + +**Why it matters:** In production, you might get 100 PRs at once. Will the system handle it? Will it crash? Slow down? We don't know. + +**What needs to happen:** +- Add load testing with Locust (mentioned in docs but not implemented) +- Define performance SLAs (e.g., "must respond in < 2 seconds") +- Add performance regression tests (make sure new code doesn't slow things down) + +--- + +### 21. No Real GitHub Integration Tests +**What it means:** All tests use mocks (fake GitHub API). Never tested against real GitHub. + +**Real-world example:** Like practicing driving in a parking lot but never on real roads. It's good practice, but real conditions are different. + +**Why it matters:** Real GitHub API might behave differently than mocks. API might change. We need to know it actually works. 
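+
+One way to gate such tests so they stay out of the default CI run (a sketch that assumes pytest with pytest-asyncio; the environment variable, fixture, and repository name are placeholders):
+
+```python
+import os
+
+import pytest
+
+requires_real_github = pytest.mark.skipif(
+    os.getenv("WATCHFLOW_RUN_GITHUB_INTEGRATION") != "1",
+    reason="Set WATCHFLOW_RUN_GITHUB_INTEGRATION=1 to run against real GitHub",
+)
+
+
+@requires_real_github
+@pytest.mark.asyncio
+async def test_fetch_rules_from_real_repo(github_client):
+    # `github_client` would be a fixture backed by a dedicated test GitHub App;
+    # the repository below is a placeholder, not a real project.
+    rules = await github_client.get_file("watchflow-test/sandbox", ".watchflow/rules.yaml")
+    assert rules is not None
+```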
+ +**What needs to happen:** +- Add optional integration tests with real GitHub (behind a flag, so they don't run in CI by default) +- Use a test GitHub App for CI/CD +- Test against GitHub API changes + +--- + +## FEATURE ENHANCEMENTS + +### 22. No Custom Agent Framework +**What it means:** Users can't create their own custom agents. They're stuck with what Watchflow provides. + +**Real-world example:** Like a LEGO set with fixed pieces - you can only build what the instructions say, not your own creations. + +**Why it matters:** Different companies have different needs. They should be able to create custom agents for their specific use cases. + +**What needs to happen:** +- Create agent plugin system (allow users to add custom agents) +- Provide agent development SDK (tools to build agents) +- Add examples of custom agents + +--- + +### 23. No Analytics Dashboard +**What it means:** Documentation mentions analytics, but there's no dashboard to see: +- Which rules are violated most often? +- How many false positives? +- How effective are rules? + +**Real-world example:** Like a business with no sales reports. You don't know what's working and what's not. + +**Why it matters:** Can't measure effectiveness. Can't improve. Can't show value to management. + +**What needs to happen:** +- Build analytics dashboard +- Track: violation rates, acknowledgment patterns, false positive rates +- Show trends over time + +--- + +### 24. No Rule Versioning +**What it means:** When you change a rule, there's no history. Can't see what changed, when, or rollback if something breaks. + +**Real-world example:** Like editing a document without "track changes" - you can't see what you changed or go back. + +**Why it matters:** If a rule change breaks things, you need to rollback quickly. You also need to see rule history for compliance. + +**What needs to happen:** +- Add rule versioning (track all changes) +- Add rollback capability (revert to previous version) +- Track who changed what and when + +--- + +## BUGS & EDGE CASES + +### 25. Validator Errors Treated as "Passed" +**What it means:** If a validator crashes, the system says "no violation found" instead of "error occurred." + +**Real-world example:** Like a smoke detector that breaks and just stays silent. You think everything is fine, but it's actually broken. + +**Why it matters:** Dangerous - looks like rules passed, but actually we don't know. + +**What needs to happen:** Return error state instead of treating as "passed." Maybe block PR to be safe, or retry. + +--- + +### 26. LLM Response Parsing is Fragile +**What it means:** When AI returns a response, sometimes it's malformed (truncated JSON). The fallback logic is complex and might miss violations. + +**Real-world example:** Like a translator that sometimes gets cut off mid-sentence, and you have to guess what they meant. + +**Why it matters:** Might miss real violations if parsing fails. + +**What needs to happen:** Improve error handling and retry logic for malformed responses. + +--- + +### 27. Deployment Scheduler Started Twice +**What it means:** Code starts the deployment scheduler twice (line 44 and line 68). It's safe (has a check), but redundant and confusing. + +**Real-world example:** Like pressing the "start" button twice on your car - it's already running, so nothing happens, but why press it twice? + +**Why it matters:** Confusing code. Future developers might think it's intentional and add more redundant code. + +**What needs to happen:** Remove one of the calls. 
Keep the one with the safety check. + +--- + +## PRIORITY SUMMARY + +### CRITICAL (Fix First) +1. **Agent Coordination** - Make agents work together +2. **Regression Prevention** - Stop duplicate violation reports +3. **Error Handling** - Don't hide errors as "passed" +4. **Test Coverage** - Test all the things + +### HIGH PRIORITY (Fix Soon) +5. **Learning Agent** - Learn from feedback +6. **Decision Orchestrator** - Smart decision combining +7. **Monitoring** - Know what's happening +8. **Validator Combinations** - Support complex rules + +### MEDIUM PRIORITY (Nice to Have) +9. **Enterprise Policies** - More rule types +10. **Cross-Platform** - Support GitLab/Azure DevOps +11. **Custom Agents** - Let users build their own +12. **Analytics** - Measure effectiveness + +### LOW PRIORITY (Future) +13. **Agent Specialization** - Specialized agents +14. **Rule Versioning** - Track rule changes +15. **Performance** - Optimize costs and speed diff --git a/src/agents/base.py b/src/agents/base.py index 44e6617..fde16bd 100644 --- a/src/agents/base.py +++ b/src/agents/base.py @@ -54,7 +54,7 @@ def __init__(self, max_retries: int = 3, retry_delay: float = 1.0, agent_name: s @abstractmethod def _build_graph(self): """Build the LangGraph workflow for this agent.""" - pass + raise NotImplementedError("Subclasses must implement _build_graph") async def _retry_structured_output(self, llm, output_model, prompt, **kwargs) -> T: """ @@ -110,4 +110,4 @@ async def _execute_with_timeout(self, coro, timeout: float = 30.0): @abstractmethod async def execute(self, **kwargs) -> AgentResult: """Execute the agent with given parameters.""" - pass + raise NotImplementedError("Subclasses must implement execute") diff --git a/src/core/config/cache_config.py b/src/core/config/cache_config.py new file mode 100644 index 0000000..d278c2e --- /dev/null +++ b/src/core/config/cache_config.py @@ -0,0 +1,25 @@ +""" +Cache configuration. + +Defines configurable settings for caching strategy including TTL, size limits, +and cache behavior. 
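+
+The fields map to environment variables read in settings.py, for example
+CACHE_GLOBAL_MAXSIZE, CACHE_GLOBAL_TTL, CACHE_DEFAULT_MAXSIZE, CACHE_DEFAULT_TTL,
+CACHE_ENABLE and CACHE_ENABLE_METRICS.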
+""" + +from dataclasses import dataclass + + +@dataclass +class CacheConfig: + """Cache configuration.""" + + # Global cache settings (used by recommendations API and other module-level caches) + global_maxsize: int = 1024 + global_ttl: int = 3600 # 1 hour in seconds + + # Default cache settings for new AsyncCache instances + default_maxsize: int = 100 + default_ttl: int = 3600 # 1 hour in seconds + + # Cache behavior settings + enable_cache: bool = True # Master switch to disable all caching + enable_metrics: bool = False # Track cache hit/miss rates (future feature) diff --git a/src/core/config/settings.py b/src/core/config/settings.py index add5f3b..4e01f1b 100644 --- a/src/core/config/settings.py +++ b/src/core/config/settings.py @@ -7,6 +7,7 @@ from dotenv import load_dotenv +from src.core.config.cache_config import CacheConfig from src.core.config.cors_config import CORSConfig from src.core.config.github_config import GitHubConfig from src.core.config.langsmith_config import LangSmithConfig @@ -107,6 +108,16 @@ def __init__(self): file_path=os.getenv("LOG_FILE_PATH"), ) + # Cache configuration + self.cache = CacheConfig( + global_maxsize=int(os.getenv("CACHE_GLOBAL_MAXSIZE", "1024")), + global_ttl=int(os.getenv("CACHE_GLOBAL_TTL", "3600")), + default_maxsize=int(os.getenv("CACHE_DEFAULT_MAXSIZE", "100")), + default_ttl=int(os.getenv("CACHE_DEFAULT_TTL", "3600")), + enable_cache=os.getenv("CACHE_ENABLE", "true").lower() == "true", + enable_metrics=os.getenv("CACHE_ENABLE_METRICS", "false").lower() == "true", + ) + # Development settings self.debug = os.getenv("DEBUG", "false").lower() == "true" self.environment = os.getenv("ENVIRONMENT", "development") diff --git a/src/core/utils/__init__.py b/src/core/utils/__init__.py index 03af47c..cc34142 100644 --- a/src/core/utils/__init__.py +++ b/src/core/utils/__init__.py @@ -1,5 +1,5 @@ """ -Shared utilities for retry, caching, logging, metrics, and timeout handling. +Shared utilities for retry, caching, logging, metrics, timeout handling, and violation tracking. This module provides reusable utilities that can be used across the codebase to avoid code duplication and ensure consistent behavior. @@ -10,6 +10,7 @@ from src.core.utils.metrics import track_metrics from src.core.utils.retry import retry_with_backoff from src.core.utils.timeout import execute_with_timeout +from src.core.utils.violation_tracker import ViolationTracker, get_violation_tracker __all__ = [ "AsyncCache", @@ -18,4 +19,6 @@ "track_metrics", "retry_with_backoff", "execute_with_timeout", + "ViolationTracker", + "get_violation_tracker", ] diff --git a/src/core/utils/caching.py b/src/core/utils/caching.py index 6c0ce3a..91b4d38 100644 --- a/src/core/utils/caching.py +++ b/src/core/utils/caching.py @@ -1,8 +1,27 @@ """ Caching utilities for async operations. -Provides async-friendly caching with TTL support and decorators +Provides async-friendly caching with TTL (Time To Live) support and decorators for caching function results. + +Caching Strategy +---------------- +This module implements a caching strategy with the following features: + +1. **TTL (Time To Live)**: Automatic expiration of cached entries after a configurable time period +2. **Eviction Policy**: Uses LRU (Least Recently Used) eviction when cache reaches max size +3. **Configuration**: Configurable via environment variables and CacheConfig +4. 
**Async Support**: Designed for async operations with proper async/await support + +Configuration +------------- +Cache behavior can be configured via environment variables: +- CACHE_GLOBAL_MAXSIZE: Maximum number of entries in global cache (default: 1024) +- CACHE_GLOBAL_TTL: Global cache TTL in seconds (default: 3600) +- CACHE_DEFAULT_MAXSIZE: Default max size for new cache instances (default: 100) +- CACHE_DEFAULT_TTL: Default TTL for new cache instances (default: 3600) +- CACHE_ENABLE: Master switch to enable/disable caching (default: true) + """ import logging @@ -15,6 +34,19 @@ logger = logging.getLogger(__name__) +# Lazy import to avoid circular dependency +_config: Any = None + + +def _get_config(): + """Lazy load config to avoid circular dependencies.""" + global _config + if _config is None: + from src.core.config.settings import config + + _config = config + return _config + class AsyncCache: """ @@ -50,6 +82,10 @@ def get(self, key: str) -> Any | None: Returns: Cached value or None if not found or expired + + Note: + Expired entries are removed lazily (on access) to avoid + background cleanup overhead. """ if key not in self._cache: return None @@ -72,9 +108,13 @@ def set(self, key: str, value: Any) -> None: Args: key: Cache key value: Value to cache + + Note: + Uses LRU (Least Recently Used) eviction policy when cache is full. + The oldest entry (by timestamp) is removed to make room. """ if len(self._cache) >= self.maxsize: - # Remove oldest entry + # Remove oldest entry (LRU eviction) oldest_key = min( self._cache.keys(), key=lambda k: self._cache[k].get("timestamp", 0), @@ -115,31 +155,78 @@ def size(self) -> int: return len(self._cache) -# Simple module-level cache used by recommendations API -_GLOBAL_CACHE = AsyncCache(maxsize=1024, ttl=3600) +# Global module-level cache used by recommendations API and other shared operations +# Initialized lazily with config values to avoid circular dependencies +_GLOBAL_CACHE: AsyncCache | None = None + + +def _get_global_cache() -> AsyncCache: + """ + Get or initialize the global cache with config values. + + Returns: + Global AsyncCache instance configured from settings + """ + global _GLOBAL_CACHE + if _GLOBAL_CACHE is None: + config = _get_config() + _GLOBAL_CACHE = AsyncCache( + maxsize=config.cache.global_maxsize, + ttl=config.cache.global_ttl, + ) + return _GLOBAL_CACHE async def get_cache(key: str) -> Any | None: """ Async helper to fetch from the module-level cache. + + Args: + key: Cache key to retrieve + + Returns: + Cached value or None if not found, expired, or caching disabled + + Note: + Respects CACHE_ENABLE setting - returns None if caching is disabled. """ - return _GLOBAL_CACHE.get(key) + config = _get_config() + if not config.cache.enable_cache: + return None + return _get_global_cache().get(key) async def set_cache(key: str, value: Any, ttl: int | None = None) -> None: """ Async helper to store into the module-level cache. + + Args: + key: Cache key + value: Value to cache + ttl: Optional TTL override (applies to entire cache, not just this entry) + + Note: + Respects CACHE_ENABLE setting - no-op if caching is disabled. + If ttl is provided, it updates the cache's TTL for all entries. + Individual entry TTL is not supported; all entries share the cache TTL. 
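+
+    Example (illustrative key and value):
+        await set_cache("recommendations:owner/repo", payload, ttl=600)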
""" - if ttl and ttl != _GLOBAL_CACHE.ttl: - _GLOBAL_CACHE.ttl = ttl - _GLOBAL_CACHE.set(key, value) + config = _get_config() + if not config.cache.enable_cache: + return + + cache = _get_global_cache() + if ttl and ttl != cache.ttl: + # Update cache TTL (affects all entries) + cache.ttl = ttl + logger.debug(f"Updated global cache TTL to {ttl}s") + cache.set(key, value) def cached_async( cache: AsyncCache | TTLCache | None = None, key_func: Callable[..., str] | None = None, ttl: int | None = None, - maxsize: int = 100, + maxsize: int | None = None, ): """ Decorator for caching async function results. @@ -148,7 +235,7 @@ def cached_async( cache: Cache instance to use (creates new AsyncCache if None) key_func: Function to generate cache key from function arguments ttl: Time to live in seconds (only used if cache is None) - maxsize: Maximum cache size (only used if cache is None) + maxsize: Maximum cache size (only used if cache is None, defaults to config) Returns: Decorated async function with caching @@ -157,17 +244,26 @@ def cached_async( @cached_async(ttl=3600, key_func=lambda repo, *args: f"repo:{repo}") async def fetch_repo_data(repo: str): return await api_call(repo) + + Note: + Respects CACHE_ENABLE setting - bypasses cache if disabled. + Uses config defaults for ttl and maxsize if not provided. """ if cache is None: - if ttl: - cache = AsyncCache(maxsize=maxsize, ttl=ttl) - else: - # Use TTLCache as fallback - cache = TTLCache(maxsize=maxsize, ttl=ttl or 3600) + config = _get_config() + # Use provided values or fall back to config defaults + cache_ttl = ttl if ttl is not None else config.cache.default_ttl + cache_maxsize = maxsize if maxsize is not None else config.cache.default_maxsize + cache = AsyncCache(maxsize=cache_maxsize, ttl=cache_ttl) def decorator(func: Callable[..., Any]) -> Callable[..., Any]: @wraps(func) async def wrapper(*args: Any, **kwargs: Any) -> Any: + config = _get_config() + # Bypass cache if disabled + if not config.cache.enable_cache: + return await func(*args, **kwargs) + # Generate cache key if key_func: cache_key = key_func(*args, **kwargs) @@ -201,4 +297,4 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any: return wrapper - return decorator + return decorator \ No newline at end of file diff --git a/src/core/utils/violation_tracker.py b/src/core/utils/violation_tracker.py new file mode 100644 index 0000000..da026b2 --- /dev/null +++ b/src/core/utils/violation_tracker.py @@ -0,0 +1,258 @@ +""" +Violation tracking and deduplication service. + +Tracks reported violations to prevent duplicate reports of the same violation. +Uses fingerprinting to identify unique violations based on rule, context, and event data. +""" + +import hashlib +import json +import logging +from datetime import datetime +from typing import Any + +from src.core.config.settings import config + +logger = logging.getLogger(__name__) + + +class ViolationTracker: + """ + Tracks reported violations to prevent duplicates. + + Uses in-memory storage with TTL-based expiration. Each violation is + fingerprinted based on its content and context to identify duplicates. + """ + + def __init__(self, ttl_seconds: int = 86400): + """ + Initialize violation tracker. 
+ + Args: + ttl_seconds: Time to keep violation records (default: 24 hours) + """ + # Store: {fingerprint: {"reported_at": timestamp, "count": int}} + self._reported: dict[str, dict[str, Any]] = {} + self.ttl_seconds = ttl_seconds + self._cleanup_threshold = 1000 # Clean up when we have this many entries + + def generate_fingerprint( + self, + violation: dict[str, Any], + repo_full_name: str, + context: dict[str, Any] | None = None, + ) -> str: + """ + Generate a unique fingerprint for a violation. + + The fingerprint is based on: + - Rule description + - Violation message + - Severity + - Repository + - Context-specific data (PR number, commit SHA, etc.) + + Args: + violation: Violation dictionary + repo_full_name: Repository full name (e.g., "owner/repo") + context: Optional context data (PR number, commit SHA, etc.) + + Returns: + SHA256 hash string representing the violation fingerprint + """ + # Extract key fields that make a violation unique + rule_description = violation.get("rule_description", "") + message = violation.get("message", "") + severity = violation.get("severity", "") + details = violation.get("details", {}) + + # Build fingerprint data + fingerprint_data = { + "rule_description": rule_description, + "message": message, + "severity": severity, + "repo": repo_full_name, + # Include relevant details (but not all, to avoid too much variation) + "validator": details.get("validator_used", ""), + "parameters": details.get("parameters", {}), + } + + # Add context-specific data if provided + if context: + # Include PR number if available (violations on same PR are duplicates) + if "pr_number" in context: + fingerprint_data["pr_number"] = context["pr_number"] + # Include commit SHA if available (violations on same commit are duplicates) + if "commit_sha" in context: + fingerprint_data["commit_sha"] = context["commit_sha"] + # Include branch if available + if "branch" in context: + fingerprint_data["branch"] = context["branch"] + + # Create deterministic JSON string (sorted keys for consistency) + json_str = json.dumps(fingerprint_data, sort_keys=True, default=str) + + # Generate SHA256 hash + fingerprint = hashlib.sha256(json_str.encode("utf-8")).hexdigest() + + logger.debug(f"Generated fingerprint for violation: {fingerprint[:16]}...") + return fingerprint + + def is_reported(self, fingerprint: str) -> bool: + """ + Check if a violation has already been reported. + + Args: + fingerprint: Violation fingerprint + + Returns: + True if violation was already reported, False otherwise + """ + if fingerprint not in self._reported: + return False + + # Check if entry has expired + entry = self._reported[fingerprint] + age = datetime.now().timestamp() - entry.get("reported_at", 0) + + if age >= self.ttl_seconds: + # Entry expired, remove it + del self._reported[fingerprint] + logger.debug(f"Violation fingerprint {fingerprint[:16]}... expired and removed") + return False + + return True + + def mark_reported( + self, + fingerprint: str, + violation: dict[str, Any] | None = None, + ) -> None: + """ + Mark a violation as reported. 
+ + Args: + fingerprint: Violation fingerprint + violation: Optional violation data for logging + """ + self._reported[fingerprint] = { + "reported_at": datetime.now().timestamp(), + "count": self._reported.get(fingerprint, {}).get("count", 0) + 1, + } + + if violation: + rule_desc = violation.get("rule_description", "Unknown") + logger.debug(f"Marked violation as reported: {rule_desc} (fingerprint: {fingerprint[:16]}...)") + + # Periodic cleanup to prevent memory growth + if len(self._reported) > self._cleanup_threshold: + self._cleanup_expired() + + def filter_new_violations( + self, + violations: list[dict[str, Any]], + repo_full_name: str, + context: dict[str, Any] | None = None, + ) -> list[dict[str, Any]]: + """ + Filter out violations that have already been reported. + + Args: + violations: List of violation dictionaries + repo_full_name: Repository full name + context: Optional context data (PR number, commit SHA, etc.) + + Returns: + List of violations that haven't been reported yet + """ + new_violations = [] + duplicate_count = 0 + + for violation in violations: + fingerprint = self.generate_fingerprint(violation, repo_full_name, context) + + if self.is_reported(fingerprint): + duplicate_count += 1 + logger.debug( + f"Skipping duplicate violation: {violation.get('rule_description', 'Unknown')} " + f"(fingerprint: {fingerprint[:16]}...)" + ) + else: + new_violations.append(violation) + # Mark as reported immediately + self.mark_reported(fingerprint, violation) + + if duplicate_count > 0: + logger.info( + f"Filtered out {duplicate_count} duplicate violation(s), {len(new_violations)} new violation(s) remain" + ) + + return new_violations + + def _cleanup_expired(self) -> None: + """Remove expired entries to free memory.""" + now = datetime.now().timestamp() + expired_keys = [ + fingerprint + for fingerprint, entry in self._reported.items() + if (now - entry.get("reported_at", 0)) >= self.ttl_seconds + ] + + for key in expired_keys: + del self._reported[key] + + if expired_keys: + logger.debug(f"Cleaned up {len(expired_keys)} expired violation records") + + def get_stats(self) -> dict[str, Any]: + """ + Get statistics about tracked violations. + + Returns: + Dictionary with statistics + """ + now = datetime.now().timestamp() + active = sum(1 for entry in self._reported.values() if (now - entry.get("reported_at", 0)) < self.ttl_seconds) + + total_reports = sum(entry.get("count", 0) for entry in self._reported.values()) + + return { + "total_tracked": len(self._reported), + "active": active, + "expired": len(self._reported) - active, + "total_reports": total_reports, + "ttl_seconds": self.ttl_seconds, + } + + def clear(self) -> None: + """Clear all tracked violations (useful for testing).""" + count = len(self._reported) + self._reported.clear() + logger.debug(f"Cleared {count} violation records") + + +# Global violation tracker instance +_global_tracker: ViolationTracker | None = None + + +def get_violation_tracker() -> ViolationTracker: + """ + Get or create the global violation tracker instance. 
+ + Returns: + Global ViolationTracker instance + """ + global _global_tracker + if _global_tracker is None: + # Use config if available, otherwise default TTL + ttl = getattr(config, "cache", None) + if ttl and hasattr(ttl, "global_ttl"): + # Use cache TTL as a reasonable default for violation tracking + ttl_seconds = ttl.global_ttl + else: + ttl_seconds = 86400 # 24 hours default + + _global_tracker = ViolationTracker(ttl_seconds=ttl_seconds) + logger.info(f"Initialized violation tracker with TTL: {ttl_seconds}s") + + return _global_tracker diff --git a/src/event_processors/base.py b/src/event_processors/base.py index 04de24d..74c0571 100644 --- a/src/event_processors/base.py +++ b/src/event_processors/base.py @@ -1,5 +1,6 @@ import logging from abc import ABC, abstractmethod +from enum import Enum from typing import Any from pydantic import BaseModel, Field @@ -13,15 +14,40 @@ logger = logging.getLogger(__name__) +class ProcessingState(str, Enum): + """ + Processing state for event processing results. + + - PASS: Rules passed - everything is good, no violations found + - FAIL: Rules failed - violations found, action required + - ERROR: Error occurred - couldn't check, need to investigate + """ + + PASS = "pass" + FAIL = "fail" + ERROR = "error" + + class ProcessingResult(BaseModel): """Result of event processing.""" - success: bool + state: ProcessingState violations: list[dict[str, Any]] = Field(default_factory=list) api_calls_made: int processing_time_ms: int error: str | None = None + @property + def success(self) -> bool: + """ + Legacy property for backward compatibility. + + Returns True only for PASS state, False for FAIL or ERROR. + Note: This doesn't distinguish between FAIL and ERROR. + Use .state instead for explicit state checking. + """ + return self.state == ProcessingState.PASS + class BaseEventProcessor(ABC): """Base class for all event processors.""" @@ -33,22 +59,22 @@ def __init__(self): @abstractmethod async def process(self, task: Task) -> ProcessingResult: """Process the event task.""" - pass + raise NotImplementedError("Subclasses must implement process") @abstractmethod def get_event_type(self) -> str: """Get the event type this processor handles.""" - pass + raise NotImplementedError("Subclasses must implement get_event_type") @abstractmethod async def prepare_webhook_data(self, task: Task) -> dict[str, Any]: """Prepare data from webhook payload.""" - pass + raise NotImplementedError("Subclasses must implement prepare_webhook_data") @abstractmethod async def prepare_api_data(self, task: Task) -> dict[str, Any]: """Prepare data from GitHub API calls.""" - pass + raise NotImplementedError("Subclasses must implement prepare_api_data") def _get_rule_provider(self) -> RuleLoader: """Get the rule provider for this processor.""" diff --git a/src/event_processors/check_run.py b/src/event_processors/check_run.py index 6d80ee6..3cb0484 100644 --- a/src/event_processors/check_run.py +++ b/src/event_processors/check_run.py @@ -3,7 +3,7 @@ from typing import Any from src.agents import get_agent -from src.event_processors.base import BaseEventProcessor, ProcessingResult +from src.event_processors.base import BaseEventProcessor, ProcessingResult, ProcessingState from src.tasks.task_queue import Task logger = logging.getLogger(__name__) @@ -31,7 +31,10 @@ async def process(self, task: Task) -> ProcessingResult: if "watchflow" in check_run.get("name", "").lower(): logger.info("Ignoring Watchflow's own check run to prevent recursive loops.") return ProcessingResult( - 
success=True, violations=[], api_calls_made=0, processing_time_ms=int((time.time() - start_time) * 1000) + state=ProcessingState.PASS, + violations=[], + api_calls_made=0, + processing_time_ms=int((time.time() - start_time) * 1000), ) logger.info("=" * 80) @@ -64,6 +67,18 @@ async def process(self, task: Task) -> ProcessingResult: rules=formatted_rules, ) + # Check if agent execution failed + if not result.success: + processing_time = int((time.time() - start_time) * 1000) + logger.error(f"❌ Agent execution failed: {result.message}") + return ProcessingResult( + state=ProcessingState.ERROR, + violations=[], + api_calls_made=1, + processing_time_ms=processing_time, + error=f"Agent execution failed: {result.message}", + ) + violations = result.data.get("violations", []) logger.info("=" * 80) @@ -72,7 +87,7 @@ async def process(self, task: Task) -> ProcessingResult: logger.info("=" * 80) return ProcessingResult( - success=(not violations), + state=ProcessingState.PASS if not violations else ProcessingState.FAIL, violations=violations, api_calls_made=1, processing_time_ms=int((time.time() - start_time) * 1000), diff --git a/src/event_processors/deployment.py b/src/event_processors/deployment.py index 0441c50..aeb2776 100644 --- a/src/event_processors/deployment.py +++ b/src/event_processors/deployment.py @@ -2,7 +2,7 @@ import time from typing import Any -from src.event_processors.base import BaseEventProcessor, ProcessingResult +from src.event_processors.base import BaseEventProcessor, ProcessingResult, ProcessingState from src.tasks.task_queue import Task logger = logging.getLogger(__name__) @@ -44,7 +44,10 @@ async def process(self, task: Task) -> ProcessingResult: logger.info("=" * 80) return ProcessingResult( - success=True, violations=[], api_calls_made=0, processing_time_ms=int((time.time() - start_time) * 1000) + state=ProcessingState.PASS, + violations=[], + api_calls_made=0, + processing_time_ms=int((time.time() - start_time) * 1000), ) async def prepare_webhook_data(self, task: Task) -> dict[str, Any]: diff --git a/src/event_processors/deployment_protection_rule.py b/src/event_processors/deployment_protection_rule.py index d97c77b..d2eb000 100644 --- a/src/event_processors/deployment_protection_rule.py +++ b/src/event_processors/deployment_protection_rule.py @@ -3,7 +3,7 @@ from typing import Any from src.agents import get_agent -from src.event_processors.base import BaseEventProcessor, ProcessingResult +from src.event_processors.base import BaseEventProcessor, ProcessingResult, ProcessingState from src.tasks.scheduler.deployment_scheduler import get_deployment_scheduler from src.tasks.task_queue import Task @@ -50,7 +50,7 @@ async def process(self, task: Task) -> ProcessingResult: deployment_callback_url, environment, "No rules configured", installation_id ) return ProcessingResult( - success=True, + state=ProcessingState.PASS, violations=[], api_calls_made=1, processing_time_ms=int((time.time() - start_time) * 1000), @@ -76,7 +76,7 @@ async def process(self, task: Task) -> ProcessingResult: deployment_callback_url, environment, "No deployment rules configured", installation_id ) return ProcessingResult( - success=True, + state=ProcessingState.PASS, violations=[], api_calls_made=1, processing_time_ms=int((time.time() - start_time) * 1000), @@ -103,6 +103,18 @@ async def process(self, task: Task) -> ProcessingResult: rules=formatted_rules, ) + # Check if agent execution failed + if not analysis_result.success: + processing_time = int((time.time() - start_time) * 1000) + 
logger.error(f"❌ Agent execution failed: {analysis_result.message}") + return ProcessingResult( + state=ProcessingState.ERROR, + violations=[], + api_calls_made=1, + processing_time_ms=processing_time, + error=f"Agent execution failed: {analysis_result.message}", + ) + # Extract violations from AgentResult - same pattern as acknowledgment processor violations = [] if analysis_result.data and "evaluation_result" in analysis_result.data: @@ -166,13 +178,16 @@ async def process(self, task: Task) -> ProcessingResult: logger.info("=" * 80) return ProcessingResult( - success=(not violations), violations=violations, api_calls_made=1, processing_time_ms=processing_time + state=ProcessingState.PASS if not violations else ProcessingState.FAIL, + violations=violations, + api_calls_made=1, + processing_time_ms=processing_time, ) except Exception as e: logger.error(f"❌ Error processing deployment protection rule: {str(e)}") return ProcessingResult( - success=False, + state=ProcessingState.ERROR, violations=[], api_calls_made=0, processing_time_ms=int((time.time() - start_time) * 1000), diff --git a/src/event_processors/deployment_review.py b/src/event_processors/deployment_review.py index 50be9aa..38dd7ae 100644 --- a/src/event_processors/deployment_review.py +++ b/src/event_processors/deployment_review.py @@ -3,7 +3,7 @@ from typing import Any from src.agents import get_agent -from src.event_processors.base import BaseEventProcessor, ProcessingResult +from src.event_processors.base import BaseEventProcessor, ProcessingResult, ProcessingState from src.tasks.task_queue import Task logger = logging.getLogger(__name__) @@ -80,7 +80,10 @@ async def process(self, task: Task) -> ProcessingResult: if not deployment_review_rules: logger.info("📋 No deployment_review rules found") return ProcessingResult( - success=True, violations=[], api_calls_made=1, processing_time_ms=int((time.time() - start_time) * 1000) + state=ProcessingState.PASS, + violations=[], + api_calls_made=1, + processing_time_ms=int((time.time() - start_time) * 1000), ) logger.info(f"📋 Found {len(deployment_review_rules)} applicable rules for deployment_review") @@ -95,6 +98,18 @@ async def process(self, task: Task) -> ProcessingResult: rules=formatted_rules, ) + # Check if agent execution failed + if not result.success: + processing_time = int((time.time() - start_time) * 1000) + logger.error(f"❌ Agent execution failed: {result.message}") + return ProcessingResult( + state=ProcessingState.ERROR, + violations=[], + api_calls_made=1, + processing_time_ms=processing_time, + error=f"Agent execution failed: {result.message}", + ) + violations = result.data.get("violations", []) logger.info("=" * 80) @@ -103,7 +118,7 @@ async def process(self, task: Task) -> ProcessingResult: logger.info("=" * 80) return ProcessingResult( - success=(not violations), + state=ProcessingState.PASS if not violations else ProcessingState.FAIL, violations=violations, api_calls_made=1, processing_time_ms=int((time.time() - start_time) * 1000), diff --git a/src/event_processors/deployment_status.py b/src/event_processors/deployment_status.py index e0a7aa4..20c614d 100644 --- a/src/event_processors/deployment_status.py +++ b/src/event_processors/deployment_status.py @@ -2,7 +2,7 @@ import time from typing import Any -from src.event_processors.base import BaseEventProcessor, ProcessingResult +from src.event_processors.base import BaseEventProcessor, ProcessingResult, ProcessingState from src.tasks.task_queue import Task logger = logging.getLogger(__name__) @@ -52,7 +52,10 @@ 
async def process(self, task: Task) -> ProcessingResult: logger.info("=" * 80) return ProcessingResult( - success=True, violations=[], api_calls_made=0, processing_time_ms=int((time.time() - start_time) * 1000) + state=ProcessingState.PASS, + violations=[], + api_calls_made=0, + processing_time_ms=int((time.time() - start_time) * 1000), ) async def prepare_webhook_data(self, task: Task) -> dict[str, Any]: diff --git a/src/event_processors/pull_request.py b/src/event_processors/pull_request.py index f7a4a66..c6c8d80 100644 --- a/src/event_processors/pull_request.py +++ b/src/event_processors/pull_request.py @@ -4,7 +4,8 @@ from typing import Any from src.agents import get_agent -from src.event_processors.base import BaseEventProcessor, ProcessingResult +from src.core.utils.violation_tracker import get_violation_tracker +from src.event_processors.base import BaseEventProcessor, ProcessingResult, ProcessingState from src.rules.loaders.github_loader import RulesFileNotFoundError from src.tasks.task_queue import Task @@ -59,7 +60,7 @@ async def process(self, task: Task) -> ProcessingResult: error="Rules not configured. Please create `.watchflow/rules.yaml` in your repository.", ) return ProcessingResult( - success=True, # Not a failure, just needs setup + state=ProcessingState.PASS, # Not a failure, just needs setup violations=[], api_calls_made=api_calls, processing_time_ms=int((time.time() - start_time) * 1000), @@ -99,6 +100,19 @@ async def process(self, task: Task) -> ProcessingResult: event_type="pull_request", event_data=event_data, rules=formatted_rules ) + # Check if agent execution failed + if not result.success: + processing_time = int((time.time() - start_time) * 1000) + logger.error(f"❌ Agent execution failed: {result.message}") + await self._create_check_run(task, [], "failure", error=result.message) + return ProcessingResult( + state=ProcessingState.ERROR, + violations=[], + api_calls_made=api_calls, + processing_time_ms=processing_time, + error=f"Agent execution failed: {result.message}", + ) + # Extract violations from engine result violations = [] if result.data and "evaluation_result" in result.data: @@ -131,25 +145,50 @@ async def process(self, task: Task) -> ProcessingResult: # Use violations requiring fixes for final result violations = require_acknowledgment_violations + # Filter out duplicate violations before posting + pr_data = task.payload.get("pull_request", {}) + pr_number = pr_data.get("number") + context = ( + { + "pr_number": pr_number, + "commit_sha": pr_data.get("head", {}).get("sha"), + "branch": pr_data.get("head", {}).get("ref"), + } + if pr_number + else {} + ) + + violation_tracker = get_violation_tracker() + new_violations = violation_tracker.filter_new_violations(violations, task.repo_full_name, context) + + if len(new_violations) < len(violations): + logger.info( + f"🔍 Deduplication: {len(violations) - len(new_violations)} duplicate violation(s) filtered out, " + f"{len(new_violations)} new violation(s) to report" + ) + # Create check run based on whether we have acknowledgments if previous_acknowledgments and original_violations: # Create check run with acknowledgment context await self._create_check_run_with_acknowledgment( - task, acknowledgable_violations, violations, previous_acknowledgments + task, acknowledgable_violations, new_violations, previous_acknowledgments ) else: # No acknowledgments or no violations - create normal check run - await self._create_check_run(task, violations) + await self._create_check_run(task, new_violations) 
processing_time = int((time.time() - start_time) * 1000) - # Post violations as comments (if any) - if violations: - logger.info(f"🚨 Found {len(violations)} violations, posting to PR...") - await self._post_violations_to_github(task, violations) + # Post violations as comments (if any new violations) + if new_violations: + logger.info(f"🚨 Found {len(new_violations)} new violations, posting to PR...") + await self._post_violations_to_github(task, new_violations) api_calls += 1 else: - logger.info("✅ No violations found, skipping PR comment") + if violations: + logger.info(f"✅ All {len(violations)} violations were already reported, skipping PR comment") + else: + logger.info("✅ No violations found, skipping PR comment") # Summary logger.info("=" * 80) @@ -158,10 +197,10 @@ async def process(self, task: Task) -> ProcessingResult: logger.info(f" Violations found: {len(violations)}") logger.info(f" API calls made: {api_calls}") - if violations: + if new_violations: logger.warning("🚨 VIOLATION SUMMARY:") # Format violations for logging - for i, violation in enumerate(violations, 1): + for i, violation in enumerate(new_violations, 1): logger.info( f" {i}. {violation.get('rule_description', 'Unknown')} ({violation.get('severity', 'medium')})" ) @@ -172,8 +211,8 @@ async def process(self, task: Task) -> ProcessingResult: logger.info("=" * 80) return ProcessingResult( - success=(not violations), - violations=violations, + state=ProcessingState.PASS if not new_violations else ProcessingState.FAIL, + violations=new_violations, api_calls_made=api_calls, processing_time_ms=processing_time, ) @@ -183,7 +222,7 @@ async def process(self, task: Task) -> ProcessingResult: # Create a failing check run for errors await self._create_check_run(task, [], "failure", error=str(e)) return ProcessingResult( - success=False, + state=ProcessingState.ERROR, violations=[], api_calls_made=api_calls, processing_time_ms=int((time.time() - start_time) * 1000), diff --git a/src/event_processors/push.py b/src/event_processors/push.py index 0954554..75a214f 100644 --- a/src/event_processors/push.py +++ b/src/event_processors/push.py @@ -3,7 +3,8 @@ from typing import Any from src.agents import get_agent -from src.event_processors.base import BaseEventProcessor, ProcessingResult +from src.core.utils.violation_tracker import get_violation_tracker +from src.event_processors.base import BaseEventProcessor, ProcessingResult, ProcessingState from src.tasks.task_queue import Task logger = logging.getLogger(__name__) @@ -57,7 +58,10 @@ async def process(self, task: Task) -> ProcessingResult: if not rules: logger.info("No rules found for this repository") return ProcessingResult( - success=True, violations=[], api_calls_made=1, processing_time_ms=int((time.time() - start_time) * 1000) + state=ProcessingState.PASS, + violations=[], + api_calls_made=1, + processing_time_ms=int((time.time() - start_time) * 1000), ) logger.info(f"📋 Loaded {len(rules)} rules for evaluation") @@ -68,35 +72,83 @@ async def process(self, task: Task) -> ProcessingResult: # Run agentic analysis using the instance result = await self.engine_agent.execute(event_type="push", event_data=event_data, rules=formatted_rules) - violations = result.data.get("violations", []) + # Check if agent execution failed + if not result.success: + processing_time = int((time.time() - start_time) * 1000) + logger.error(f"❌ Agent execution failed: {result.message}") + return ProcessingResult( + state=ProcessingState.ERROR, + violations=[], + api_calls_made=1, + 
processing_time_ms=processing_time, + error=f"Agent execution failed: {result.message}", + ) + + # Extract violations from engine result (same pattern as PullRequestProcessor) + violations = [] + if result.data and "evaluation_result" in result.data: + eval_result = result.data["evaluation_result"] + if hasattr(eval_result, "violations"): + # Convert RuleViolation objects to dictionaries + violations = [v.__dict__ for v in eval_result.violations] + else: + # Fallback to old format if evaluation_result not present + violations = result.data.get("violations", []) + + # Filter out duplicate violations before creating check run + commit_sha = payload.get("after") + context = ( + { + "commit_sha": commit_sha, + "branch": ref.replace("refs/heads/", "") if ref.startswith("refs/heads/") else ref, + } + if commit_sha + else {} + ) + + violation_tracker = get_violation_tracker() + new_violations = violation_tracker.filter_new_violations(violations, task.repo_full_name, context) + + if len(new_violations) < len(violations): + logger.info( + f"🔍 Deduplication: {len(violations) - len(new_violations)} duplicate violation(s) filtered out, " + f"{len(new_violations)} new violation(s) to report" + ) processing_time = int((time.time() - start_time) * 1000) # Post results to GitHub (create check run) api_calls = 1 # Initial rule fetch - if violations: - await self._create_check_run(task, violations) + if new_violations: + await self._create_check_run(task, new_violations) api_calls += 1 # Summary logger.info("=" * 80) logger.info(f"🏁 PUSH processing completed in {processing_time}ms") logger.info(f" Rules evaluated: {len(formatted_rules)}") - logger.info(f" Violations found: {len(violations)}") + logger.info(f" Violations found: {len(new_violations)}") logger.info(f" API calls made: {api_calls}") - if violations: + if new_violations: logger.warning("🚨 VIOLATION SUMMARY:") - for i, violation in enumerate(violations, 1): - logger.warning(f" {i}. {violation.get('rule', 'Unknown')} ({violation.get('severity', 'medium')})") + for i, violation in enumerate(new_violations, 1): + rule_desc = violation.get("rule_description") or violation.get("rule", "Unknown") + logger.warning(f" {i}. 
{rule_desc} ({violation.get('severity', 'medium')})") logger.warning(f" {violation.get('message', '')}") else: - logger.info("✅ All rules passed - no violations detected!") + if violations: + logger.info(f"✅ All {len(violations)} violations were already reported, skipping check run") + else: + logger.info("✅ All rules passed - no violations detected!") logger.info("=" * 80) return ProcessingResult( - success=True, violations=violations, api_calls_made=api_calls, processing_time_ms=processing_time + state=ProcessingState.PASS if not new_violations else ProcessingState.FAIL, + violations=new_violations, + api_calls_made=api_calls, + processing_time_ms=processing_time, ) def _convert_rules_to_new_format(self, rules: list[Any]) -> list[dict[str, Any]]: diff --git a/src/event_processors/rule_creation.py b/src/event_processors/rule_creation.py index 48adfbd..57d4e6b 100644 --- a/src/event_processors/rule_creation.py +++ b/src/event_processors/rule_creation.py @@ -3,13 +3,23 @@ import time from typing import Any +from pydantic import BaseModel, Field + from src.agents import get_agent -from src.event_processors.base import BaseEventProcessor, ProcessingResult +from src.event_processors.base import BaseEventProcessor, ProcessingResult, ProcessingState from src.tasks.task_queue import Task logger = logging.getLogger(__name__) +class FeasibilityResult(BaseModel): + """Result from feasibility analysis for rule creation.""" + + is_feasible: bool = Field(description="Whether the rule is feasible to implement") + yaml_content: str = Field(description="Generated YAML configuration for the rule", default="") + feedback: str = Field(description="Feedback about the rule feasibility", default="") + + class RuleCreationProcessor(BaseEventProcessor): """Processor for rule creation commands via comments.""" @@ -37,7 +47,7 @@ async def process(self, task: Task) -> ProcessingResult: if not rule_description: return ProcessingResult( - success=False, + state=ProcessingState.ERROR, violations=[], api_calls_made=0, processing_time_ms=int((time.time() - start_time) * 1000), @@ -47,7 +57,29 @@ async def process(self, task: Task) -> ProcessingResult: logger.info(f"📝 Rule description: {rule_description}") # Use the feasibility agent to check if the rule is supported - feasibility_result = await self.feasibility_agent.check_feasibility(rule_description) + agent_result = await self.feasibility_agent.execute(rule_description) + + # Check if agent execution failed + if not agent_result.success: + processing_time = int((time.time() - start_time) * 1000) + logger.error(f"❌ Feasibility agent failed: {agent_result.message}") + return ProcessingResult( + state=ProcessingState.ERROR, + violations=[], + api_calls_made=0, + processing_time_ms=processing_time, + error=f"Feasibility agent failed: {agent_result.message}", + ) + + # Extract feasibility data from AgentResult + feasibility_data = agent_result.data + is_feasible = feasibility_data.get("is_feasible", False) + yaml_content = feasibility_data.get("yaml_content", "") + feedback = agent_result.message # Feedback is in the message field + + feasibility_result = FeasibilityResult( + is_feasible=is_feasible, yaml_content=yaml_content, feedback=feedback + ) processing_time = int((time.time() - start_time) * 1000) @@ -67,12 +99,14 @@ async def process(self, task: Task) -> ProcessingResult: logger.info("=" * 80) - return ProcessingResult(success=True, violations=[], api_calls_made=1, processing_time_ms=processing_time) + return ProcessingResult( + state=ProcessingState.PASS, 
violations=[], api_calls_made=1, processing_time_ms=processing_time + ) except Exception as e: logger.error(f"❌ Error processing rule creation: {e}") return ProcessingResult( - success=False, + state=ProcessingState.ERROR, violations=[], api_calls_made=0, processing_time_ms=int((time.time() - start_time) * 1000), @@ -97,7 +131,7 @@ def _extract_rule_description(self, task: Task) -> str: return "" - async def _post_result_to_comment(self, task: Task, feasibility_result): + async def _post_result_to_comment(self, task: Task, feasibility_result: FeasibilityResult): """Post the feasibility result as a reply to the original comment.""" try: # Get issue/PR number from the webhook payload @@ -123,7 +157,7 @@ async def _post_result_to_comment(self, task: Task, feasibility_result): except Exception as e: logger.error(f"Error posting feasibility reply: {e}") - def _format_feasibility_reply(self, feasibility_result) -> str: + def _format_feasibility_reply(self, feasibility_result: FeasibilityResult) -> str: """Format the feasibility result as a comment reply.""" if feasibility_result.is_feasible: reply = "## ✅ Rule Creation Successful!\n\n" diff --git a/src/event_processors/violation_acknowledgment.py b/src/event_processors/violation_acknowledgment.py index 96bf57e..35adbea 100644 --- a/src/event_processors/violation_acknowledgment.py +++ b/src/event_processors/violation_acknowledgment.py @@ -5,7 +5,8 @@ from src.agents import get_agent from src.core.models import EventType -from src.event_processors.base import BaseEventProcessor, ProcessingResult +from src.core.utils.violation_tracker import get_violation_tracker +from src.event_processors.base import BaseEventProcessor, ProcessingResult, ProcessingState from src.tasks.task_queue import Task logger = logging.getLogger(__name__) @@ -74,7 +75,7 @@ async def process(self, task: Task) -> ProcessingResult: api_calls += 1 return ProcessingResult( - success=True, + state=ProcessingState.PASS, violations=[], api_calls_made=api_calls, processing_time_ms=int((time.time() - start_time) * 1000), @@ -87,7 +88,7 @@ async def process(self, task: Task) -> ProcessingResult: if not github_token: logger.error(f"❌ Failed to get installation token for {installation_id}") return ProcessingResult( - success=False, + state=ProcessingState.ERROR, violations=[], api_calls_made=api_calls, processing_time_ms=int((time.time() - start_time) * 1000), @@ -178,7 +179,7 @@ async def process(self, task: Task) -> ProcessingResult: ) api_calls += 1 return ProcessingResult( - success=False, + state=ProcessingState.ERROR, violations=[], api_calls_made=api_calls, processing_time_ms=int((time.time() - start_time) * 1000), @@ -203,12 +204,27 @@ async def process(self, task: Task) -> ProcessingResult: api_calls += 1 return ProcessingResult( - success=True, + state=ProcessingState.PASS, violations=[], api_calls_made=api_calls, processing_time_ms=int((time.time() - start_time) * 1000), ) + # Filter out duplicate violations before evaluation + context = ( + { + "pr_number": pr_number, + "commit_sha": pr_data.get("head", {}).get("sha"), + "branch": pr_data.get("head", {}).get("ref"), + } + if pr_number + else {} + ) + + violation_tracker = get_violation_tracker() + # Note: We don't filter here because we want to evaluate all violations + # But we'll filter before posting comments to avoid duplicate messages + # Evaluate acknowledgment against ALL violations evaluation_result = await self._evaluate_acknowledgment( acknowledgment_reason=acknowledgment_reason, @@ -218,13 +234,32 @@ async def 
process(self, task: Task) -> ProcessingResult: rules=formatted_rules, # Pass the formatted rules ) + # Filter duplicates from violations that will be posted in comments + new_acknowledgable = violation_tracker.filter_new_violations( + evaluation_result["acknowledgable_violations"], repo, context + ) + new_require_fixes = violation_tracker.filter_new_violations( + evaluation_result["require_fixes"], repo, context + ) + + if len(new_acknowledgable) < len(evaluation_result["acknowledgable_violations"]): + logger.info( + f"🔍 Deduplication: {len(evaluation_result['acknowledgable_violations']) - len(new_acknowledgable)} " + f"duplicate acknowledged violation(s) filtered out" + ) + if len(new_require_fixes) < len(evaluation_result["require_fixes"]): + logger.info( + f"🔍 Deduplication: {len(evaluation_result['require_fixes']) - len(new_require_fixes)} " + f"duplicate required-fix violation(s) filtered out" + ) + if evaluation_result["valid"]: # Acknowledgment is valid - selectively approve violations and provide guidance await self._approve_violations_selectively( repo=repo, pr_number=pr_number, - acknowledgable_violations=evaluation_result["acknowledgable_violations"], - require_fixes=evaluation_result["require_fixes"], + acknowledgable_violations=new_acknowledgable, + require_fixes=new_require_fixes, reason=acknowledgment_reason, commenter=commenter, installation_id=installation_id, @@ -235,8 +270,8 @@ async def process(self, task: Task) -> ProcessingResult: await self._update_check_run( repo=repo, pr_number=pr_number, - acknowledgable_violations=evaluation_result["acknowledgable_violations"], - require_fixes=evaluation_result["require_fixes"], + acknowledgable_violations=new_acknowledgable, + require_fixes=new_require_fixes, installation_id=installation_id, ) api_calls += 1 @@ -249,7 +284,7 @@ async def process(self, task: Task) -> ProcessingResult: pr_number=pr_number, reason=evaluation_result["reason"], commenter=commenter, - require_fixes=evaluation_result["require_fixes"], + require_fixes=new_require_fixes, installation_id=installation_id, ) api_calls += 1 @@ -262,9 +297,18 @@ async def process(self, task: Task) -> ProcessingResult: logger.info(f" Status: {'accepted' if evaluation_result['valid'] else 'rejected'}") logger.info("=" * 80) + # If the acknowledgment is valid, the process passes. Otherwise, it fails. 
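# A minimal, hypothetical sketch (not drawn from this patch) of how a caller can branch on the
# tri-state result that replaces the old boolean `success`; it assumes the ProcessingState
# members PASS / FAIL / ERROR and the ProcessingResult fields (state, violations, error)
# used throughout this diff, plus the imports already at the top of this module.
def _summarize_result(result: ProcessingResult) -> str:
    if result.state is ProcessingState.ERROR:
        # The evaluation itself broke; this must not be reported as "no violations found".
        return f"evaluation error: {result.error}"
    if result.state is ProcessingState.FAIL:
        return f"{len(result.violations)} violation(s) still require attention"
    # ProcessingState.PASS: rules evaluated cleanly, or every violation was already reported.
    return "all rules passed"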
+ if evaluation_result["valid"]: + state = ProcessingState.PASS + else: + state = ProcessingState.FAIL + + # Use filtered violations for the result + require_fixes_for_result = new_require_fixes if not evaluation_result["valid"] else [] + return ProcessingResult( - success=True, - violations=evaluation_result["require_fixes"] if not evaluation_result["valid"] else [], + state=state, + violations=require_fixes_for_result, api_calls_made=api_calls, processing_time_ms=processing_time, ) @@ -272,7 +316,7 @@ async def process(self, task: Task) -> ProcessingResult: except Exception as e: logger.error(f"❌ Error processing violation acknowledgment: {str(e)}") return ProcessingResult( - success=False, + state=ProcessingState.ERROR, violations=[], api_calls_made=api_calls, processing_time_ms=int((time.time() - start_time) * 1000), diff --git a/src/integrations/providers/base.py b/src/integrations/providers/base.py index 067ff12..6cd1e81 100644 --- a/src/integrations/providers/base.py +++ b/src/integrations/providers/base.py @@ -20,17 +20,17 @@ def __init__(self, model: str, max_tokens: int = 4096, temperature: float = 0.1, @abstractmethod def get_chat_model(self) -> Any: """Get the chat model instance.""" - pass + raise NotImplementedError("Subclasses must implement get_chat_model") @abstractmethod def supports_structured_output(self) -> bool: """Check if this provider supports structured output.""" - pass + raise NotImplementedError("Subclasses must implement supports_structured_output") @abstractmethod def get_provider_name(self) -> str: """Get the provider name.""" - pass + raise NotImplementedError("Subclasses must implement get_provider_name") def get_model_info(self) -> dict[str, Any]: """Get model information.""" diff --git a/src/rules/condition_evaluator.py b/src/rules/condition_evaluator.py new file mode 100644 index 0000000..170a97f --- /dev/null +++ b/src/rules/condition_evaluator.py @@ -0,0 +1,289 @@ +""" +Condition evaluator for complex boolean logic (AND/OR/NOT). + +Supports nested conditions with AND, OR, and NOT operators. +""" + +import logging +from typing import Any + +from src.rules.models import RuleCondition +from src.rules.validators import VALIDATOR_REGISTRY + +logger = logging.getLogger(__name__) + + +class ConditionExpression: + """ + Represents a condition expression with logical operators. + + Supports: + - Simple conditions: single condition evaluation + - AND: all conditions must be true + - OR: at least one condition must be true + - NOT: negates a condition + - Nested: conditions can be nested arbitrarily + """ + + def __init__( + self, + operator: str | None = None, + condition: RuleCondition | None = None, + conditions: list["ConditionExpression"] | None = None, + ): + """ + Initialize a condition expression. 
+ + Args: + operator: Logical operator ("AND", "OR", "NOT", or None for simple condition) + condition: Single condition (for simple conditions or NOT) + conditions: List of nested conditions (for AND/OR) + """ + self.operator = operator.upper() if operator else None + self.condition = condition + self.conditions = conditions or [] + + # Validate structure + if self.operator == "NOT": + if not self.condition: + raise ValueError("NOT operator requires a single condition") + elif self.operator in ("AND", "OR"): + if not self.conditions: + raise ValueError(f"{self.operator} operator requires at least one condition") + elif self.operator is None: + if not self.condition: + raise ValueError("Simple condition requires a condition") + else: + raise ValueError(f"Unknown operator: {self.operator}") + + def to_dict(self) -> dict[str, Any]: + """Convert expression to dictionary format.""" + if self.operator is None: + # Simple condition + return { + "type": self.condition.type, + "parameters": self.condition.parameters, + } + elif self.operator == "NOT": + return { + "operator": "NOT", + "condition": { + "type": self.condition.type, + "parameters": self.condition.parameters, + }, + } + else: + return { + "operator": self.operator, + "conditions": [cond.to_dict() for cond in self.conditions], + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "ConditionExpression": + """ + Create ConditionExpression from dictionary. + + Supports formats: + - Simple: {"type": "author_team_is", "parameters": {"team": "devops"}} + - NOT: {"operator": "NOT", "condition": {...}} + - AND/OR: {"operator": "AND", "conditions": [...]} + """ + if "operator" in data: + operator = data["operator"].upper() + if operator == "NOT": + condition_data = data["condition"] + condition = RuleCondition(type=condition_data["type"], parameters=condition_data.get("parameters", {})) + return cls(operator="NOT", condition=condition) + else: + # AND or OR + nested_conditions = [cls.from_dict(cond) for cond in data["conditions"]] + return cls(operator=operator, conditions=nested_conditions) + else: + # Simple condition + condition = RuleCondition(type=data["type"], parameters=data.get("parameters", {})) + return cls(condition=condition) + + +class ConditionEvaluator: + """ + Evaluates condition expressions against event data. + + Handles: + - Simple conditions using validators + - AND/OR/NOT logical operators + - Nested condition expressions + """ + + def __init__(self): + """Initialize the condition evaluator.""" + self.validator_registry = VALIDATOR_REGISTRY + + async def evaluate( + self, expression: ConditionExpression, event_data: dict[str, Any] + ) -> tuple[bool, dict[str, Any]]: + """ + Evaluate a condition expression against event data. 
+ + Args: + expression: Condition expression to evaluate + event_data: Event data to evaluate against + + Returns: + Tuple of (result: bool, metadata: dict) where metadata contains evaluation details + """ + metadata = { + "operator": expression.operator, + "evaluated_at": "condition_evaluator", + } + + try: + if expression.operator is None: + # Simple condition + result = await self._evaluate_simple_condition(expression.condition, event_data) + metadata["condition_type"] = expression.condition.type + metadata["result"] = result + return result, metadata + + elif expression.operator == "NOT": + # Negate the condition + result, sub_metadata = await self.evaluate( + ConditionExpression(condition=expression.condition), event_data + ) + negated_result = not result + metadata["negated"] = True + metadata["original_result"] = result + metadata["sub_condition"] = sub_metadata + return negated_result, metadata + + elif expression.operator == "AND": + # All conditions must be true + results = [] + sub_metadata_list = [] + for sub_expr in expression.conditions: + sub_result, sub_metadata = await self.evaluate(sub_expr, event_data) + results.append(sub_result) + sub_metadata_list.append(sub_metadata) + # Short-circuit: if any is False, AND is False + if not sub_result: + break + + result = all(results) + metadata["sub_conditions"] = sub_metadata_list + metadata["results"] = results + metadata["result"] = result + return result, metadata + + elif expression.operator == "OR": + # At least one condition must be true + results = [] + sub_metadata_list = [] + for sub_expr in expression.conditions: + sub_result, sub_metadata = await self.evaluate(sub_expr, event_data) + results.append(sub_result) + sub_metadata_list.append(sub_metadata) + # Short-circuit: if any is True, OR is True + if sub_result: + break + + result = any(results) + metadata["sub_conditions"] = sub_metadata_list + metadata["results"] = results + metadata["result"] = result + return result, metadata + + else: + raise ValueError(f"Unknown operator: {expression.operator}") + + except Exception as e: + logger.error(f"Error evaluating condition expression: {e}") + metadata["error"] = str(e) + # Fail closed: if we can't evaluate, assume violation + return False, metadata + + async def _evaluate_simple_condition(self, condition: RuleCondition, event_data: dict[str, Any]) -> bool: + """ + Evaluate a simple condition using a validator. + + Args: + condition: Rule condition to evaluate + event_data: Event data to evaluate against + + Returns: + True if condition is met, False otherwise + """ + validator = self.validator_registry.get(condition.type) + if not validator: + logger.warning(f"Unknown condition type: {condition.type}") + return False + + try: + result = await validator.validate(condition.parameters, event_data) + logger.debug(f"Condition {condition.type} evaluated: {result} (parameters: {condition.parameters})") + return result + except Exception as e: + logger.error(f"Error evaluating condition {condition.type}: {e}") + return False + + async def evaluate_rule_conditions( + self, + conditions: list[ConditionExpression] | list[RuleCondition] | None, + event_data: dict[str, Any], + ) -> tuple[bool, dict[str, Any]]: + """ + Evaluate rule conditions (backward compatible). + + If conditions is a list of RuleCondition (old format), treats them as AND. + If conditions is a list of ConditionExpression (new format), evaluates them. 
+ + Args: + conditions: List of conditions or condition expressions + event_data: Event data to evaluate against + + Returns: + Tuple of (result: bool, metadata: dict) + """ + if not conditions: + # No conditions = rule passes + return True, {"message": "No conditions to evaluate"} + + # Check if it's old format (list of RuleCondition) + if conditions and isinstance(conditions[0], RuleCondition): + # Old format: treat as AND of all conditions + logger.debug("Using legacy condition format (treating as AND)") + results = [] + for condition in conditions: + result = await self._evaluate_simple_condition(condition, event_data) + results.append(result) + if not result: + break # Short-circuit AND + + all_passed = all(results) + return all_passed, { + "format": "legacy", + "results": results, + "condition_count": len(conditions), + } + + # New format: list of ConditionExpression + if len(conditions) == 1: + # Single condition expression + return await self.evaluate(conditions[0], event_data) + else: + # Multiple condition expressions - treat as AND by default + logger.debug(f"Evaluating {len(conditions)} condition expressions (default: AND)") + results = [] + metadata_list = [] + for expr in conditions: + result, metadata = await self.evaluate(expr, event_data) + results.append(result) + metadata_list.append(metadata) + if not result: + break # Short-circuit AND + + all_passed = all(results) + return all_passed, { + "format": "expression", + "default_operator": "AND", + "results": results, + "sub_conditions": metadata_list, + } diff --git a/src/rules/evaluator.py b/src/rules/evaluator.py new file mode 100644 index 0000000..92b7395 --- /dev/null +++ b/src/rules/evaluator.py @@ -0,0 +1,45 @@ +""" +Rule evaluation utilities including condition expression evaluation. +""" + +import logging +from typing import Any + +from src.rules.condition_evaluator import ConditionEvaluator +from src.rules.models import Rule + +logger = logging.getLogger(__name__) + + +async def evaluate_rule_conditions(rule: Rule, event_data: dict[str, Any]) -> tuple[bool, dict[str, Any]]: + """ + Evaluate rule conditions (both legacy and new format). 
+ + Args: + rule: Rule object to evaluate + event_data: Event data to evaluate against + + Returns: + Tuple of (condition_passed: bool, metadata: dict) + - condition_passed: True if all conditions pass, False otherwise + - metadata: Evaluation details + """ + evaluator = ConditionEvaluator() + + # Check if rule has new condition expression format + if rule.condition is not None: + logger.debug(f"Evaluating condition expression for rule: {rule.description}") + result, metadata = await evaluator.evaluate(rule.condition, event_data) + metadata["format"] = "expression" + return result, metadata + + # Check if rule has legacy conditions + if rule.conditions: + logger.debug(f"Evaluating legacy conditions for rule: {rule.description}") + result, metadata = await evaluator.evaluate_rule_conditions(rule.conditions, event_data) + metadata["format"] = "legacy" + return result, metadata + + # No conditions = rule passes (no restrictions) + logger.debug(f"No conditions for rule: {rule.description}") + return True, {"message": "No conditions to evaluate", "format": "none"} diff --git a/src/rules/interface.py b/src/rules/interface.py index ae7bce8..f3b01b7 100644 --- a/src/rules/interface.py +++ b/src/rules/interface.py @@ -23,4 +23,4 @@ async def get_rules(self, repository: str, installation_id: int) -> list[Rule]: Returns: list of Rule objects for the repository """ - pass + raise NotImplementedError("Subclasses must implement get_rules") diff --git a/src/rules/loaders/github_loader.py b/src/rules/loaders/github_loader.py index 7782100..635869d 100644 --- a/src/rules/loaders/github_loader.py +++ b/src/rules/loaders/github_loader.py @@ -12,8 +12,9 @@ from src.core.config import config from src.core.models import EventType from src.integrations.github import GitHubClient, github_client +from src.rules.condition_evaluator import ConditionExpression from src.rules.interface import RuleLoader -from src.rules.models import Rule, RuleAction, RuleSeverity +from src.rules.models import Rule, RuleAction, RuleCondition, RuleSeverity logger = logging.getLogger(__name__) @@ -87,6 +88,23 @@ def _parse_rule(rule_data: dict[str, Any]) -> Rule: # No mapping: just pass parameters as-is parameters = rule_data.get("parameters", {}) + # Parse conditions (legacy format: list of conditions) + conditions = [] + if "conditions" in rule_data: + for cond_data in rule_data["conditions"]: + condition = RuleCondition(type=cond_data["type"], parameters=cond_data.get("parameters", {})) + conditions.append(condition) + + # Parse condition expression (new format: AND/OR/NOT) + condition_expr = None + if "condition" in rule_data: + try: + condition_expr = ConditionExpression.from_dict(rule_data["condition"]) + logger.debug(f"Parsed condition expression for rule: {rule_data['description']}") + except Exception as e: + logger.warning(f"Failed to parse condition expression for rule '{rule_data['description']}': {e}") + # Fall back to legacy format if condition expression parsing fails + # Actions are optional and not mapped actions = [] if "actions" in rule_data: @@ -99,8 +117,8 @@ def _parse_rule(rule_data: dict[str, Any]) -> Rule: enabled=rule_data.get("enabled", True), severity=RuleSeverity(rule_data.get("severity", "medium")), event_types=event_types, - # No conditions: parameters are passed as-is - conditions=[], + conditions=conditions, # Legacy format + condition=condition_expr, # New format actions=actions, parameters=parameters, ) diff --git a/src/rules/models.py b/src/rules/models.py index 8823bd3..9ca2be9 100644 --- 
a/src/rules/models.py +++ b/src/rules/models.py @@ -1,10 +1,13 @@ from enum import Enum -from typing import Any +from typing import TYPE_CHECKING, Any from pydantic import BaseModel, Field from src.core.models import EventType +if TYPE_CHECKING: + from src.rules.condition_evaluator import ConditionExpression + class RuleSeverity(str, Enum): """Enumerates the severity levels of a rule violation.""" @@ -32,12 +35,32 @@ class RuleAction(BaseModel): class Rule(BaseModel): - """Represents a rule that can be evaluated against repository events.""" + """ + Represents a rule that can be evaluated against repository events. + + Supports both legacy conditions (list of RuleCondition) and new condition expressions + (ConditionExpression with AND/OR/NOT operators). + """ + + model_config = {"arbitrary_types_allowed": True} description: str = Field(description="Primary identifier and description of the rule") enabled: bool = True severity: RuleSeverity = RuleSeverity.MEDIUM event_types: list[EventType] = Field(default_factory=list) - conditions: list[RuleCondition] = Field(default_factory=list) + conditions: list[RuleCondition] = Field(default_factory=list, description="Legacy conditions (treated as AND)") + condition: "ConditionExpression | None" = Field( + default=None, description="New condition expression with AND/OR/NOT support" + ) actions: list[RuleAction] = Field(default_factory=list) - parameters: dict[str, Any] = Field(default_factory=dict) # Store parameters as-is from YAML + parameters: dict[str, Any] = Field(default_factory=dict, description="Store parameters as-is from YAML") + + +def _rebuild_rule_model() -> None: + """Rebuild Rule model to resolve forward references.""" + from src.rules.condition_evaluator import ConditionExpression # noqa: F401 + + Rule.model_rebuild() + + +_rebuild_rule_model() diff --git a/src/rules/validators.py b/src/rules/validators.py index 2d575c1..3ac9d5c 100644 --- a/src/rules/validators.py +++ b/src/rules/validators.py @@ -101,7 +101,7 @@ async def validate(self, parameters: dict[str, Any], event: dict[str, Any]) -> b Returns: True if the condition is met, False otherwise """ - pass + raise NotImplementedError("Subclasses must implement validate") def get_description(self) -> dict[str, Any]: """Get validator description for dynamic strategy selection.""" diff --git a/src/webhooks/handlers/base.py b/src/webhooks/handlers/base.py index 9cda358..849c1e0 100644 --- a/src/webhooks/handlers/base.py +++ b/src/webhooks/handlers/base.py @@ -23,4 +23,4 @@ async def handle(self, event: WebhookEvent) -> dict[str, Any]: Returns: A dictionary containing the results of the handling logic. """ - pass + raise NotImplementedError("Subclasses must implement handle") diff --git a/tests/feedback/test_abstract_classes.py b/tests/feedback/test_abstract_classes.py new file mode 100644 index 0000000..75e7267 --- /dev/null +++ b/tests/feedback/test_abstract_classes.py @@ -0,0 +1,227 @@ +#!/usr/bin/env python3 +""" +Test script to verify that abstract methods now raise NotImplementedError +instead of using pass. + +This script tests by directly reading source files (no imports needed), +verifying the code changes are correct. +""" + +import re +from pathlib import Path + + +def check_file_uses_notimplemented(file_path: Path, class_name: str, method_names: list[str]) -> tuple[bool, list[str]]: + """ + Check if abstract methods in a file use raise NotImplementedError. 
+ + Returns: + (all_passed, list_of_issues) + """ + try: + content = file_path.read_text() + issues = [] + all_passed = True + + for method_name in method_names: + # Find the abstract method definition + # Pattern: @abstractmethod ... def method_name(...): ... raise NotImplementedError + pattern = rf'@abstractmethod\s+(?:async\s+)?def\s+{re.escape(method_name)}\s*\([^)]*\)[^:]*:\s*"""[^"]*"""\s*(.*?)(?=\n @|\n\nclass|\nclass|\Z)' + match = re.search(pattern, content, re.DOTALL) + + if not match: + # Try a simpler pattern without docstring + pattern = rf"@abstractmethod\s+(?:async\s+)?def\s+{re.escape(method_name)}\s*\([^)]*\)[^:]*:\s*(.*?)(?=\n @|\n\nclass|\nclass|\Z)" + match = re.search(pattern, content, re.DOTALL) + + if match: + method_body = match.group(1) + # Check if it contains raise NotImplementedError + if "raise NotImplementedError" in method_body: + print(f" ✅ {class_name}.{method_name} uses raise NotImplementedError") + elif "pass" in method_body and "raise" not in method_body: + print(f" ❌ {class_name}.{method_name} still uses pass instead of raise NotImplementedError") + issues.append(f"{class_name}.{method_name} still uses pass") + all_passed = False + else: + # Might have both or neither, check more carefully + lines = method_body.strip().split("\n") + has_raise = any("raise NotImplementedError" in line for line in lines) + has_pass_only = any(line.strip() == "pass" for line in lines if "raise" not in line) + if has_pass_only and not has_raise: + print(f" ❌ {class_name}.{method_name} still uses pass instead of raise NotImplementedError") + issues.append(f"{class_name}.{method_name} still uses pass") + all_passed = False + elif has_raise: + print(f" ✅ {class_name}.{method_name} uses raise NotImplementedError") + else: + print(f" ⚠️ {class_name}.{method_name} - could not verify (unusual format)") + else: + print(f" ⚠️ {class_name}.{method_name} - could not find method definition") + + return all_passed, issues + except Exception as e: + print(f" ❌ Error checking {file_path}: {e}") + return False, [f"Error: {str(e)}"] + + +def verify_code_changes_direct(): + """Verify code changes by reading source files directly.""" + print("Verifying code changes by inspecting source files...") + print() + + base_path = Path(__file__).parent.parent.parent / "src" + all_passed = True + all_issues = [] + + files_to_check = [ + ( + "integrations/providers/base.py", + "BaseProvider", + ["get_chat_model", "supports_structured_output", "get_provider_name"], + ), + ("agents/base.py", "BaseAgent", ["_build_graph", "execute"]), + ( + "event_processors/base.py", + "BaseEventProcessor", + ["process", "get_event_type", "prepare_webhook_data", "prepare_api_data"], + ), + ("webhooks/handlers/base.py", "EventHandler", ["handle"]), + ("rules/interface.py", "RuleLoader", ["get_rules"]), + ("rules/validators.py", "Condition", ["validate"]), + ] + + for file_rel_path, class_name, method_names in files_to_check: + file_path = base_path / file_rel_path + if not file_path.exists(): + print(f"❌ File not found: {file_path}") + all_passed = False + all_issues.append(f"File not found: {file_rel_path}") + continue + + print(f"Checking {file_rel_path} ({class_name})...") + passed, issues = check_file_uses_notimplemented(file_path, class_name, method_names) + if not passed: + all_passed = False + all_issues.extend(issues) + print() + + return all_passed, all_issues + + +def verify_with_grep(): + """Alternative verification using grep-like pattern matching.""" + print("Alternative verification: Checking for remaining 
'pass' in abstract methods...") + print() + + base_path = Path(__file__).parent.parent.parent / "src" + + # Find all Python files + python_files = list(base_path.rglob("*.py")) + + issues_found = [] + + for py_file in python_files: + try: + content = py_file.read_text() + + # Look for @abstractmethod followed by method definition and pass + # More specific pattern: @abstractmethod ... def ... : ... pass (with proper indentation) + lines = content.split("\n") + + in_abstract_method = False + abstract_method_indent = 0 + method_name = None + class_name = None + + for i, line in enumerate(lines): + # Detect class definitions + class_match = re.match(r"^class\s+(\w+)", line) + if class_match: + class_name = class_match.group(1) + + # Detect @abstractmethod + if "@abstractmethod" in line: + in_abstract_method = True + abstract_method_indent = len(line) - len(line.lstrip()) + # Look ahead for method definition + for j in range(i + 1, min(i + 10, len(lines))): + method_match = re.match(r"^\s*(?:async\s+)?def\s+(\w+)", lines[j]) + if method_match: + method_name = method_match.group(1) + break + continue + + # If we're in an abstract method, check for pass + if in_abstract_method: + # Check if this line is just "pass" with appropriate indentation + stripped = line.strip() + line_indent = len(line) - len(line.lstrip()) + + # If we hit another method or class at same or less indentation, we're done + if ( + re.match(r"^\s*(?:@|def\s+|class\s+|async\s+def\s+)", line) + and line_indent <= abstract_method_indent + and line_indent > 0 + ): + in_abstract_method = False + continue + + # Check for standalone pass (not in a comment or string) + if stripped == "pass" and "raise NotImplementedError" not in content[max(0, i - 20) : i + 5]: + # Double check this is actually in the method body + # Look backwards to see if there's a raise statement + method_content = "\n".join(lines[max(0, i - 30) : i + 5]) + if "@abstractmethod" in method_content and "raise NotImplementedError" not in method_content: + rel_path = py_file.relative_to(base_path.parent) + issues_found.append(f"{rel_path}: {class_name}.{method_name} uses pass") + print(f" ❌ Found 'pass' in abstract method: {rel_path} -> {class_name}.{method_name}") + in_abstract_method = False + + except Exception: + continue # Skip files we can't read + + return len(issues_found) == 0, issues_found + + +def main(): + """Main test function.""" + print("=" * 60) + print("Testing Abstract Methods - NotImplementedError Changes") + print("=" * 60) + print() + + # Method 1: Direct file inspection (more reliable) + all_passed, issues = verify_code_changes_direct() + + print("=" * 60) + + # Method 2: Grep-like verification (double check) + grep_passed, grep_issues = verify_with_grep() + if grep_issues: + print("\nGrep verification found additional issues:") + for issue in grep_issues: + print(f" - {issue}") + all_passed = False + + print() + print("=" * 60) + if all_passed and grep_passed: + print("✅ ALL VERIFICATIONS PASSED") + print() + print("Summary:") + print(" - All abstract methods now use 'raise NotImplementedError'") + print(" - No abstract methods use 'pass' anymore") + print(" - Code changes are correct") + return 0 + else: + print("❌ SOME VERIFICATIONS FAILED") + print() + print("Issues found:") + for issue in issues + grep_issues: + print(f" - {issue}") + return 1 + + +if __name__ == "__main__": + exit(main()) diff --git a/tests/feedback/test_caching_strategy.py b/tests/feedback/test_caching_strategy.py new file mode 100644 index 0000000..f350adb --- /dev/null 
+++ b/tests/feedback/test_caching_strategy.py @@ -0,0 +1,507 @@ +#!/usr/bin/env python3 +""" +Tests for caching strategy and configuration. + +This test suite verifies that: +1. Cache configuration is properly loaded from environment variables +2. Cache TTL and size limits are configurable +3. Cache respects enable/disable settings +4. Cache eviction policy (LRU) works correctly +5. Cache expiration works as expected +6. Global cache uses config values +7. Function-level cache decorator respects config + +Can be run in two ways: +1. As pytest test: pytest tests/feedback/test_caching_strategy.py -v +2. As standalone verification: python3 tests/feedback/test_caching_strategy.py +""" + +import os +import sys +import time +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest + +# Add project root to path for imports when running directly +ROOT = Path(__file__).resolve().parent.parent.parent +if str(ROOT) not in sys.path: + sys.path.insert(0, str(ROOT)) + + +class TestCacheConfig: + """Test cache configuration loading.""" + + def test_cache_config_defaults(self): + """Test that CacheConfig has sensible defaults.""" + from src.core.config.cache_config import CacheConfig + + config = CacheConfig() + assert config.global_maxsize == 1024 + assert config.global_ttl == 3600 + assert config.default_maxsize == 100 + assert config.default_ttl == 3600 + assert config.enable_cache is True + assert config.enable_metrics is False + + def test_cache_config_from_env(self): + """Test that cache config loads from environment variables.""" + from src.core.config.settings import Config + + with patch.dict( + os.environ, + { + "CACHE_GLOBAL_MAXSIZE": "2048", + "CACHE_GLOBAL_TTL": "7200", + "CACHE_DEFAULT_MAXSIZE": "200", + "CACHE_DEFAULT_TTL": "1800", + "CACHE_ENABLE": "false", + "CACHE_ENABLE_METRICS": "true", + }, + ): + # Reload config to pick up env vars + config = Config() + assert config.cache.global_maxsize == 2048 + assert config.cache.global_ttl == 7200 + assert config.cache.default_maxsize == 200 + assert config.cache.default_ttl == 1800 + assert config.cache.enable_cache is False + assert config.cache.enable_metrics is True + + def test_cache_config_in_main_config(self): + """Test that cache config is included in main Config class.""" + from src.core.config.settings import Config + + config = Config() + assert hasattr(config, "cache") + assert config.cache is not None + assert hasattr(config.cache, "global_maxsize") + assert hasattr(config.cache, "global_ttl") + assert hasattr(config.cache, "default_maxsize") + assert hasattr(config.cache, "default_ttl") + assert hasattr(config.cache, "enable_cache") + assert hasattr(config.cache, "enable_metrics") + + +class TestAsyncCache: + """Test AsyncCache class functionality.""" + + def test_cache_initialization(self): + """Test cache initialization with custom values.""" + from src.core.utils.caching import AsyncCache + + cache = AsyncCache(maxsize=50, ttl=300) + assert cache.maxsize == 50 + assert cache.ttl == 300 + assert cache.size() == 0 + + def test_cache_set_and_get(self): + """Test basic cache set and get operations.""" + from src.core.utils.caching import AsyncCache + + cache = AsyncCache(maxsize=10, ttl=3600) + cache.set("key1", "value1") + assert cache.get("key1") == "value1" + assert cache.size() == 1 + + def test_cache_expiration(self): + """Test that cache entries expire after TTL.""" + from src.core.utils.caching import AsyncCache + + cache = AsyncCache(maxsize=10, ttl=1) # 1 second TTL + cache.set("key1", "value1") 
+ assert cache.get("key1") == "value1" + + # Wait for expiration + time.sleep(1.1) + assert cache.get("key1") is None + assert cache.size() == 0 + + def test_cache_lru_eviction(self): + """Test that cache evicts oldest entries when full (LRU policy).""" + from src.core.utils.caching import AsyncCache + + cache = AsyncCache(maxsize=3, ttl=3600) + # Fill cache to capacity + cache.set("key1", "value1") + time.sleep(0.01) # Small delay to ensure different timestamps + cache.set("key2", "value2") + time.sleep(0.01) + cache.set("key3", "value3") + assert cache.size() == 3 + + # Add one more - should evict oldest (key1) + time.sleep(0.01) + cache.set("key4", "value4") + assert cache.size() == 3 + assert cache.get("key1") is None # Oldest evicted + assert cache.get("key2") == "value2" # Still present + assert cache.get("key3") == "value3" # Still present + assert cache.get("key4") == "value4" # Newest present + + def test_cache_clear(self): + """Test clearing all cache entries.""" + from src.core.utils.caching import AsyncCache + + cache = AsyncCache(maxsize=10, ttl=3600) + cache.set("key1", "value1") + cache.set("key2", "value2") + assert cache.size() == 2 + + cache.clear() + assert cache.size() == 0 + assert cache.get("key1") is None + assert cache.get("key2") is None + + def test_cache_invalidate(self): + """Test invalidating a specific cache entry.""" + from src.core.utils.caching import AsyncCache + + cache = AsyncCache(maxsize=10, ttl=3600) + cache.set("key1", "value1") + cache.set("key2", "value2") + assert cache.size() == 2 + + cache.invalidate("key1") + assert cache.size() == 1 + assert cache.get("key1") is None + assert cache.get("key2") == "value2" + + def test_cache_size_tracking(self): + """Test that cache correctly tracks size.""" + from src.core.utils.caching import AsyncCache + + cache = AsyncCache(maxsize=5, ttl=3600) + assert cache.size() == 0 + + for i in range(3): + cache.set(f"key{i}", f"value{i}") + assert cache.size() == 3 + + cache.invalidate("key1") + assert cache.size() == 2 + + +class TestGlobalCache: + """Test global module-level cache functionality.""" + + @pytest.mark.asyncio + async def test_global_cache_uses_config(self): + """Test that global cache uses config values.""" + # Reset global cache to force re-initialization + import src.core.utils.caching as caching_module + from src.core.utils.caching import _get_global_cache + + caching_module._GLOBAL_CACHE = None + + with patch("src.core.utils.caching._get_config") as mock_get_config: + mock_config = MagicMock() + mock_config.cache.global_maxsize = 512 + mock_config.cache.global_ttl = 1800 + mock_config.cache.enable_cache = True + mock_get_config.return_value = mock_config + + cache = _get_global_cache() + assert cache.maxsize == 512 + assert cache.ttl == 1800 + + @pytest.mark.asyncio + async def test_get_cache_respects_enable_flag(self): + """Test that get_cache respects CACHE_ENABLE setting.""" + # Reset global cache + import src.core.utils.caching as caching_module + from src.core.utils.caching import get_cache + + caching_module._GLOBAL_CACHE = None + + with patch("src.core.utils.caching._get_config") as mock_get_config: + mock_config = MagicMock() + mock_config.cache.enable_cache = False + mock_get_config.return_value = mock_config + + # Should return None when cache is disabled + result = await get_cache("test_key") + assert result is None + + @pytest.mark.asyncio + async def test_set_cache_respects_enable_flag(self): + """Test that set_cache respects CACHE_ENABLE setting.""" + # Reset global cache + import 
src.core.utils.caching as caching_module + from src.core.utils.caching import get_cache, set_cache + + caching_module._GLOBAL_CACHE = None + + with patch("src.core.utils.caching._get_config") as mock_get_config: + mock_config = MagicMock() + mock_config.cache.enable_cache = False + mock_get_config.return_value = mock_config + + # Should be no-op when cache is disabled + await set_cache("test_key", "test_value") + result = await get_cache("test_key") + assert result is None + + @pytest.mark.asyncio + async def test_get_cache_and_set_cache_integration(self): + """Test integration of get_cache and set_cache.""" + # Reset global cache + import src.core.utils.caching as caching_module + from src.core.utils.caching import get_cache, set_cache + + caching_module._GLOBAL_CACHE = None + + with patch("src.core.utils.caching._get_config") as mock_get_config: + mock_config = MagicMock() + mock_config.cache.global_maxsize = 100 + mock_config.cache.global_ttl = 3600 + mock_config.cache.enable_cache = True + mock_get_config.return_value = mock_config + + # Set and get value + await set_cache("test_key", "test_value") + result = await get_cache("test_key") + assert result == "test_value" + + @pytest.mark.asyncio + async def test_set_cache_ttl_override(self): + """Test that set_cache can override TTL.""" + # Reset global cache + import src.core.utils.caching as caching_module + from src.core.utils.caching import _get_global_cache, set_cache + + caching_module._GLOBAL_CACHE = None + + with patch("src.core.utils.caching._get_config") as mock_get_config: + mock_config = MagicMock() + mock_config.cache.global_maxsize = 100 + mock_config.cache.global_ttl = 3600 + mock_config.cache.enable_cache = True + mock_get_config.return_value = mock_config + + cache = _get_global_cache() + assert cache.ttl == 3600 + + # Override TTL + await set_cache("test_key", "test_value", ttl=1800) + assert cache.ttl == 1800 + + +class TestCachedAsyncDecorator: + """Test @cached_async decorator functionality.""" + + @pytest.mark.asyncio + async def test_cached_async_basic(self): + """Test basic caching with @cached_async decorator.""" + from src.core.utils.caching import cached_async + + call_count = 0 + + @cached_async(ttl=3600, maxsize=10) + async def test_func(x: int) -> int: + nonlocal call_count + call_count += 1 + return x * 2 + + # First call - cache miss + result1 = await test_func(5) + assert result1 == 10 + assert call_count == 1 + + # Second call - cache hit + result2 = await test_func(5) + assert result2 == 10 + assert call_count == 1 # Not called again + + @pytest.mark.asyncio + async def test_cached_async_uses_config_defaults(self): + """Test that @cached_async uses config defaults when not specified.""" + from src.core.utils.caching import cached_async + + with patch("src.core.utils.caching._get_config") as mock_get_config: + mock_config = MagicMock() + mock_config.cache.default_maxsize = 50 + mock_config.cache.default_ttl = 1800 + mock_config.cache.enable_cache = True + mock_get_config.return_value = mock_config + + call_count = 0 + + @cached_async() # No parameters - should use config defaults + async def test_func(x: int) -> int: + nonlocal call_count + call_count += 1 + return x * 2 + + result = await test_func(5) + assert result == 10 + assert call_count == 1 + + @pytest.mark.asyncio + async def test_cached_async_respects_enable_flag(self): + """Test that @cached_async respects CACHE_ENABLE setting.""" + from src.core.utils.caching import cached_async + + with patch("src.core.utils.caching._get_config") as 
mock_get_config: + mock_config = MagicMock() + mock_config.cache.default_maxsize = 50 + mock_config.cache.default_ttl = 1800 + mock_config.cache.enable_cache = False # Cache disabled + mock_get_config.return_value = mock_config + + call_count = 0 + + @cached_async() + async def test_func(x: int) -> int: + nonlocal call_count + call_count += 1 + return x * 2 + + # Both calls should execute (cache disabled) + result1 = await test_func(5) + result2 = await test_func(5) + assert result1 == 10 + assert result2 == 10 + assert call_count == 2 # Called twice because cache is disabled + + @pytest.mark.asyncio + async def test_cached_async_custom_key_func(self): + """Test @cached_async with custom key function.""" + from src.core.utils.caching import cached_async + + call_count = 0 + + def key_func(x: int, y: int) -> str: + return f"custom:{x}:{y}" + + @cached_async(ttl=3600, key_func=key_func) + async def test_func(x: int, y: int) -> int: + nonlocal call_count + call_count += 1 + return x + y + + # First call + result1 = await test_func(5, 3) + assert result1 == 8 + assert call_count == 1 + + # Second call with same args - cache hit + result2 = await test_func(5, 3) + assert result2 == 8 + assert call_count == 1 + + # Different args - cache miss + result3 = await test_func(5, 4) + assert result3 == 9 + assert call_count == 2 + + +class TestCacheDocumentation: + """Test that caching strategy is properly documented.""" + + def test_caching_module_has_documentation(self): + """Test that caching.py has comprehensive documentation.""" + caching_file = ROOT / "src" / "core" / "utils" / "caching.py" + assert caching_file.exists() + + content = caching_file.read_text() + + # Check for key documentation sections + assert "Caching Strategy" in content + assert "TTL" in content or "Time To Live" in content + assert "Eviction Policy" in content or "LRU" in content + assert "Configuration" in content + assert "Environment variables" in content or "CACHE_" in content + + def test_cache_config_has_docstrings(self): + """Test that CacheConfig has proper docstrings.""" + from src.core.config.cache_config import CacheConfig + + assert CacheConfig.__doc__ is not None + assert len(CacheConfig.__doc__.strip()) > 0 + + +def run_standalone_verification(): + """Run verification checks that don't require pytest.""" + print("=" * 60) + print("Caching Strategy Verification") + print("=" * 60) + print() + + all_passed = True + + # Test 1: Config exists + print("1. Checking CacheConfig exists...") + try: + from src.core.config.cache_config import CacheConfig + + config = CacheConfig() + print(" ✅ CacheConfig created with defaults") + print(f" - global_maxsize: {config.global_maxsize}") + print(f" - global_ttl: {config.global_ttl}") + print(f" - default_maxsize: {config.default_maxsize}") + print(f" - default_ttl: {config.default_ttl}") + print(f" - enable_cache: {config.enable_cache}") + except Exception as e: + print(f" ❌ Failed to import CacheConfig: {e}") + all_passed = False + + # Test 2: Config in main settings + print() + print("2. Checking CacheConfig in main Config...") + try: + from src.core.config.settings import Config + + config = Config() + assert hasattr(config, "cache") + print(" ✅ CacheConfig included in main Config") + print(f" - config.cache.global_maxsize: {config.cache.global_maxsize}") + except Exception as e: + print(f" ❌ Failed to access cache config: {e}") + all_passed = False + + # Test 3: AsyncCache works + print() + print("3. 
Checking AsyncCache functionality...") + try: + from src.core.utils.caching import AsyncCache + + cache = AsyncCache(maxsize=5, ttl=1) + cache.set("test", "value") + assert cache.get("test") == "value" + assert cache.size() == 1 + print(" ✅ AsyncCache basic operations work") + except Exception as e: + print(f" ❌ AsyncCache test failed: {e}") + all_passed = False + + # Test 4: Documentation exists + print() + print("4. Checking documentation...") + caching_file = ROOT / "src" / "core" / "utils" / "caching.py" + if caching_file.exists(): + content = caching_file.read_text() + if "Caching Strategy" in content: + print(" ✅ Caching strategy documentation found") + else: + print(" ⚠️ Caching strategy documentation may be incomplete") + else: + print(" ❌ Caching file not found") + all_passed = False + + print() + print("=" * 60) + if all_passed: + print("✅ All verification checks passed!") + else: + print("❌ Some checks failed") + print("=" * 60) + + return all_passed + + +if __name__ == "__main__": + # Run standalone verification when executed directly + success = run_standalone_verification() + sys.exit(0 if success else 1) diff --git a/tests/feedback/test_condition_logic.py b/tests/feedback/test_condition_logic.py new file mode 100644 index 0000000..e42dcd9 --- /dev/null +++ b/tests/feedback/test_condition_logic.py @@ -0,0 +1,560 @@ +#!/usr/bin/env python3 +""" +Tests for condition logic with AND/OR/NOT operators. + +This test suite verifies that: +1. Simple conditions work correctly +2. AND operator works (all conditions must pass) +3. OR operator works (at least one condition must pass) +4. NOT operator works (negates condition) +5. Nested conditions work correctly +6. Edge cases are handled properly + +Can be run in two ways: +1. As pytest test: pytest tests/feedback/test_condition_logic.py -v +2. 
As standalone verification: python3 tests/feedback/test_condition_logic.py +""" + +import sys +from pathlib import Path +from unittest.mock import AsyncMock, patch + +import pytest + +# Add project root to path for imports when running directly +ROOT = Path(__file__).resolve().parent.parent.parent +if str(ROOT) not in sys.path: + sys.path.insert(0, str(ROOT)) + + +class TestConditionExpression: + """Test ConditionExpression model.""" + + def test_simple_condition_creation(self): + """Test creating a simple condition expression.""" + from src.rules.condition_evaluator import ConditionExpression + from src.rules.models import RuleCondition + + condition = RuleCondition(type="author_team_is", parameters={"team": "devops"}) + expr = ConditionExpression(condition=condition) + + assert expr.operator is None + assert expr.condition == condition + assert expr.conditions == [] + + def test_and_condition_creation(self): + """Test creating an AND condition expression.""" + from src.rules.condition_evaluator import ConditionExpression + from src.rules.models import RuleCondition + + cond1 = RuleCondition(type="author_team_is", parameters={"team": "devops"}) + cond2 = RuleCondition(type="files_match_pattern", parameters={"pattern": "*.py"}) + expr1 = ConditionExpression(condition=cond1) + expr2 = ConditionExpression(condition=cond2) + + and_expr = ConditionExpression(operator="AND", conditions=[expr1, expr2]) + + assert and_expr.operator == "AND" + assert len(and_expr.conditions) == 2 + + def test_or_condition_creation(self): + """Test creating an OR condition expression.""" + from src.rules.condition_evaluator import ConditionExpression + from src.rules.models import RuleCondition + + cond1 = RuleCondition(type="author_team_is", parameters={"team": "devops"}) + cond2 = RuleCondition(type="is_weekend", parameters={}) + expr1 = ConditionExpression(condition=cond1) + expr2 = ConditionExpression(condition=cond2) + + or_expr = ConditionExpression(operator="OR", conditions=[expr1, expr2]) + + assert or_expr.operator == "OR" + assert len(or_expr.conditions) == 2 + + def test_not_condition_creation(self): + """Test creating a NOT condition expression.""" + from src.rules.condition_evaluator import ConditionExpression + from src.rules.models import RuleCondition + + condition = RuleCondition(type="is_weekend", parameters={}) + not_expr = ConditionExpression(operator="NOT", condition=condition) + + assert not_expr.operator == "NOT" + assert not_expr.condition == condition + + def test_from_dict_simple(self): + """Test creating expression from dictionary (simple condition).""" + from src.rules.condition_evaluator import ConditionExpression + + data = {"type": "author_team_is", "parameters": {"team": "devops"}} + expr = ConditionExpression.from_dict(data) + + assert expr.operator is None + assert expr.condition.type == "author_team_is" + assert expr.condition.parameters == {"team": "devops"} + + def test_from_dict_and(self): + """Test creating expression from dictionary (AND operator).""" + from src.rules.condition_evaluator import ConditionExpression + + data = { + "operator": "AND", + "conditions": [ + {"type": "author_team_is", "parameters": {"team": "devops"}}, + {"type": "files_match_pattern", "parameters": {"pattern": "*.py"}}, + ], + } + expr = ConditionExpression.from_dict(data) + + assert expr.operator == "AND" + assert len(expr.conditions) == 2 + + def test_from_dict_nested(self): + """Test creating nested expression from dictionary.""" + from src.rules.condition_evaluator import ConditionExpression + + 
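# For reference, the nested dictionary constructed below is the shape this patch's GitHub rule
# loader passes to ConditionExpression.from_dict. In a repository's rules YAML it could be
# written roughly as follows (hypothetical rule: only the `condition:` key and the
# type/parameters fields come from this patch; the surrounding rule fields follow the Rule model):
#
#   - description: "Restrict auth changes to security, or devops on weekends"
#     severity: high
#     condition:
#       operator: OR
#       conditions:
#         - operator: AND
#           conditions:
#             - {type: author_team_is, parameters: {team: security}}
#             - {type: files_match_pattern, parameters: {pattern: "**/auth/**"}}
#         - operator: AND
#           conditions:
#             - {type: author_team_is, parameters: {team: devops}}
#             - {type: is_weekend, parameters: {}}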
data = { + "operator": "OR", + "conditions": [ + { + "operator": "AND", + "conditions": [ + {"type": "author_team_is", "parameters": {"team": "security"}}, + {"type": "files_match_pattern", "parameters": {"pattern": "**/auth/**"}}, + ], + }, + { + "operator": "AND", + "conditions": [ + {"type": "author_team_is", "parameters": {"team": "devops"}}, + {"type": "is_weekend", "parameters": {}}, + ], + }, + ], + } + expr = ConditionExpression.from_dict(data) + + assert expr.operator == "OR" + assert len(expr.conditions) == 2 + assert expr.conditions[0].operator == "AND" + assert expr.conditions[1].operator == "AND" + + def test_to_dict_simple(self): + """Test converting simple expression to dictionary.""" + from src.rules.condition_evaluator import ConditionExpression + from src.rules.models import RuleCondition + + condition = RuleCondition(type="author_team_is", parameters={"team": "devops"}) + expr = ConditionExpression(condition=condition) + data = expr.to_dict() + + assert data["type"] == "author_team_is" + assert data["parameters"] == {"team": "devops"} + + +class TestConditionEvaluator: + """Test ConditionEvaluator functionality.""" + + @pytest.mark.asyncio + async def test_evaluate_simple_condition(self): + """Test evaluating a simple condition.""" + from src.rules.condition_evaluator import ConditionEvaluator, ConditionExpression + from src.rules.models import RuleCondition + + # Mock validator + mock_validator = AsyncMock() + mock_validator.validate = AsyncMock(return_value=True) + + with patch("src.rules.condition_evaluator.VALIDATOR_REGISTRY", {"author_team_is": mock_validator}): + evaluator = ConditionEvaluator() + condition = RuleCondition(type="author_team_is", parameters={"team": "devops"}) + expr = ConditionExpression(condition=condition) + + event_data = {"sender": {"login": "devops-user"}} + result, metadata = await evaluator.evaluate(expr, event_data) + + assert result is True + assert metadata["condition_type"] == "author_team_is" + mock_validator.validate.assert_called_once() + + @pytest.mark.asyncio + async def test_evaluate_and_condition_all_pass(self): + """Test AND condition where all conditions pass.""" + from src.rules.condition_evaluator import ConditionEvaluator, ConditionExpression + from src.rules.models import RuleCondition + + # Mock validators + mock_validator1 = AsyncMock() + mock_validator1.validate = AsyncMock(return_value=True) + mock_validator2 = AsyncMock() + mock_validator2.validate = AsyncMock(return_value=True) + + with patch( + "src.rules.condition_evaluator.VALIDATOR_REGISTRY", + {"author_team_is": mock_validator1, "files_match_pattern": mock_validator2}, + ): + evaluator = ConditionEvaluator() + cond1 = RuleCondition(type="author_team_is", parameters={"team": "devops"}) + cond2 = RuleCondition(type="files_match_pattern", parameters={"pattern": "*.py"}) + expr1 = ConditionExpression(condition=cond1) + expr2 = ConditionExpression(condition=cond2) + and_expr = ConditionExpression(operator="AND", conditions=[expr1, expr2]) + + event_data = {} + result, metadata = await evaluator.evaluate(and_expr, event_data) + + assert result is True + assert metadata["operator"] == "AND" + assert all(metadata["results"]) + + @pytest.mark.asyncio + async def test_evaluate_and_condition_one_fails(self): + """Test AND condition where one condition fails.""" + from src.rules.condition_evaluator import ConditionEvaluator, ConditionExpression + from src.rules.models import RuleCondition + + # Mock validators - one passes, one fails + mock_validator1 = AsyncMock() + 
mock_validator1.validate = AsyncMock(return_value=True) + mock_validator2 = AsyncMock() + mock_validator2.validate = AsyncMock(return_value=False) + + with patch( + "src.rules.condition_evaluator.VALIDATOR_REGISTRY", + {"author_team_is": mock_validator1, "files_match_pattern": mock_validator2}, + ): + evaluator = ConditionEvaluator() + cond1 = RuleCondition(type="author_team_is", parameters={"team": "devops"}) + cond2 = RuleCondition(type="files_match_pattern", parameters={"pattern": "*.py"}) + expr1 = ConditionExpression(condition=cond1) + expr2 = ConditionExpression(condition=cond2) + and_expr = ConditionExpression(operator="AND", conditions=[expr1, expr2]) + + event_data = {} + result, metadata = await evaluator.evaluate(and_expr, event_data) + + assert result is False + assert metadata["operator"] == "AND" + assert not all(metadata["results"]) + + @pytest.mark.asyncio + async def test_evaluate_or_condition_one_passes(self): + """Test OR condition where one condition passes.""" + from src.rules.condition_evaluator import ConditionEvaluator, ConditionExpression + from src.rules.models import RuleCondition + + # Mock validators - one passes, one fails + mock_validator1 = AsyncMock() + mock_validator1.validate = AsyncMock(return_value=True) + mock_validator2 = AsyncMock() + mock_validator2.validate = AsyncMock(return_value=False) + + with patch( + "src.rules.condition_evaluator.VALIDATOR_REGISTRY", + {"author_team_is": mock_validator1, "is_weekend": mock_validator2}, + ): + evaluator = ConditionEvaluator() + cond1 = RuleCondition(type="author_team_is", parameters={"team": "devops"}) + cond2 = RuleCondition(type="is_weekend", parameters={}) + expr1 = ConditionExpression(condition=cond1) + expr2 = ConditionExpression(condition=cond2) + or_expr = ConditionExpression(operator="OR", conditions=[expr1, expr2]) + + event_data = {} + result, metadata = await evaluator.evaluate(or_expr, event_data) + + assert result is True + assert metadata["operator"] == "OR" + assert any(metadata["results"]) + + @pytest.mark.asyncio + async def test_evaluate_or_condition_all_fail(self): + """Test OR condition where all conditions fail.""" + from src.rules.condition_evaluator import ConditionEvaluator, ConditionExpression + from src.rules.models import RuleCondition + + # Mock validators - both fail + mock_validator1 = AsyncMock() + mock_validator1.validate = AsyncMock(return_value=False) + mock_validator2 = AsyncMock() + mock_validator2.validate = AsyncMock(return_value=False) + + with patch( + "src.rules.condition_evaluator.VALIDATOR_REGISTRY", + {"author_team_is": mock_validator1, "is_weekend": mock_validator2}, + ): + evaluator = ConditionEvaluator() + cond1 = RuleCondition(type="author_team_is", parameters={"team": "devops"}) + cond2 = RuleCondition(type="is_weekend", parameters={}) + expr1 = ConditionExpression(condition=cond1) + expr2 = ConditionExpression(condition=cond2) + or_expr = ConditionExpression(operator="OR", conditions=[expr1, expr2]) + + event_data = {} + result, metadata = await evaluator.evaluate(or_expr, event_data) + + assert result is False + assert metadata["operator"] == "OR" + assert not any(metadata["results"]) + + @pytest.mark.asyncio + async def test_evaluate_not_condition(self): + """Test NOT condition (negation).""" + from src.rules.condition_evaluator import ConditionEvaluator, ConditionExpression + from src.rules.models import RuleCondition + + # Mock validator returns True + mock_validator = AsyncMock() + mock_validator.validate = AsyncMock(return_value=True) + + with 
patch("src.rules.condition_evaluator.VALIDATOR_REGISTRY", {"is_weekend": mock_validator}): + evaluator = ConditionEvaluator() + condition = RuleCondition(type="is_weekend", parameters={}) + not_expr = ConditionExpression(operator="NOT", condition=condition) + + event_data = {} + result, metadata = await evaluator.evaluate(not_expr, event_data) + + # NOT True = False + assert result is False + assert metadata["operator"] == "NOT" + assert metadata["negated"] is True + assert metadata["original_result"] is True + + @pytest.mark.asyncio + async def test_evaluate_nested_conditions(self): + """Test nested condition expressions.""" + from src.rules.condition_evaluator import ConditionEvaluator, ConditionExpression + from src.rules.models import RuleCondition + + # Mock validators + mock_validators = { + "author_team_is": AsyncMock(), + "files_match_pattern": AsyncMock(), + "is_weekend": AsyncMock(), + } + mock_validators["author_team_is"].validate = AsyncMock(return_value=True) + mock_validators["files_match_pattern"].validate = AsyncMock(return_value=True) + mock_validators["is_weekend"].validate = AsyncMock(return_value=False) + + with patch("src.rules.condition_evaluator.VALIDATOR_REGISTRY", mock_validators): + evaluator = ConditionEvaluator() + + # (author_team_is AND files_match_pattern) OR is_weekend + cond1 = RuleCondition(type="author_team_is", parameters={"team": "security"}) + cond2 = RuleCondition(type="files_match_pattern", parameters={"pattern": "**/auth/**"}) + cond3 = RuleCondition(type="is_weekend", parameters={}) + + expr1 = ConditionExpression(condition=cond1) + expr2 = ConditionExpression(condition=cond2) + expr3 = ConditionExpression(condition=cond3) + + and_expr = ConditionExpression(operator="AND", conditions=[expr1, expr2]) + or_expr = ConditionExpression(operator="OR", conditions=[and_expr, expr3]) + + event_data = {} + result, metadata = await evaluator.evaluate(or_expr, event_data) + + # (True AND True) OR False = True OR False = True + assert result is True + assert metadata["operator"] == "OR" + + @pytest.mark.asyncio + async def test_evaluate_legacy_conditions(self): + """Test evaluating legacy conditions (list of RuleCondition).""" + from src.rules.condition_evaluator import ConditionEvaluator + from src.rules.models import RuleCondition + + # Mock validators + mock_validator1 = AsyncMock() + mock_validator1.validate = AsyncMock(return_value=True) + mock_validator2 = AsyncMock() + mock_validator2.validate = AsyncMock(return_value=True) + + with patch( + "src.rules.condition_evaluator.VALIDATOR_REGISTRY", + {"author_team_is": mock_validator1, "files_match_pattern": mock_validator2}, + ): + evaluator = ConditionEvaluator() + conditions = [ + RuleCondition(type="author_team_is", parameters={"team": "devops"}), + RuleCondition(type="files_match_pattern", parameters={"pattern": "*.py"}), + ] + + event_data = {} + result, metadata = await evaluator.evaluate_rule_conditions(conditions, event_data) + + assert result is True + assert metadata["format"] == "legacy" + assert len(metadata["results"]) == 2 + + +class TestRuleConditionEvaluation: + """Test rule condition evaluation integration.""" + + @pytest.mark.asyncio + async def test_evaluate_rule_with_condition_expression(self): + """Test evaluating a rule with condition expression.""" + from src.core.models import EventType + from src.rules.condition_evaluator import ConditionExpression + from src.rules.evaluator import evaluate_rule_conditions + from src.rules.models import Rule, RuleSeverity + + # Mock validator + 
mock_validator = AsyncMock() + mock_validator.validate = AsyncMock(return_value=True) + + with patch("src.rules.condition_evaluator.VALIDATOR_REGISTRY", {"author_team_is": mock_validator}): + condition = ConditionExpression.from_dict({"type": "author_team_is", "parameters": {"team": "devops"}}) + + rule = Rule( + description="Test rule", + enabled=True, + severity=RuleSeverity.HIGH, + event_types=[EventType.PULL_REQUEST], + condition=condition, + ) + + event_data = {"sender": {"login": "devops-user"}} + result, metadata = await evaluate_rule_conditions(rule, event_data) + + assert result is True + assert metadata["format"] == "expression" + + @pytest.mark.asyncio + async def test_evaluate_rule_with_legacy_conditions(self): + """Test evaluating a rule with legacy conditions.""" + from src.core.models import EventType + from src.rules.evaluator import evaluate_rule_conditions + from src.rules.models import Rule, RuleCondition, RuleSeverity + + # Mock validators + mock_validator = AsyncMock() + mock_validator.validate = AsyncMock(return_value=True) + + with patch("src.rules.condition_evaluator.VALIDATOR_REGISTRY", {"author_team_is": mock_validator}): + rule = Rule( + description="Test rule", + enabled=True, + severity=RuleSeverity.HIGH, + event_types=[EventType.PULL_REQUEST], + conditions=[ + RuleCondition(type="author_team_is", parameters={"team": "devops"}), + ], + ) + + event_data = {"sender": {"login": "devops-user"}} + result, metadata = await evaluate_rule_conditions(rule, event_data) + + assert result is True + assert metadata["format"] == "legacy" + + @pytest.mark.asyncio + async def test_evaluate_rule_without_conditions(self): + """Test evaluating a rule without conditions.""" + from src.core.models import EventType + from src.rules.evaluator import evaluate_rule_conditions + from src.rules.models import Rule, RuleSeverity + + rule = Rule( + description="Test rule", + enabled=True, + severity=RuleSeverity.HIGH, + event_types=[EventType.PULL_REQUEST], + ) + + event_data = {} + result, metadata = await evaluate_rule_conditions(rule, event_data) + + assert result is True + assert metadata["format"] == "none" + + +def run_standalone_verification(): + """Run verification checks that don't require pytest.""" + print("=" * 60) + print("Condition Logic Verification") + print("=" * 60) + print() + + all_passed = True + + # Test 1: ConditionExpression exists + print("1. Checking ConditionExpression exists...") + try: + from src.rules.condition_evaluator import ConditionEvaluator, ConditionExpression + from src.rules.models import RuleCondition + + condition = RuleCondition(type="author_team_is", parameters={"team": "devops"}) + expr = ConditionExpression(condition=condition) + print(" ✅ ConditionExpression created successfully") + except Exception as e: + print(f" ❌ Failed to import ConditionExpression: {e}") + all_passed = False + + # Test 2: ConditionEvaluator exists + print() + print("2. Checking ConditionEvaluator exists...") + try: + from src.rules.condition_evaluator import ConditionEvaluator + + _evaluator = ConditionEvaluator() + print(" ✅ ConditionEvaluator created successfully") + except Exception as e: + print(f" ❌ Failed to import ConditionEvaluator: {e}") + all_passed = False + + # Test 3: Rule model supports condition + print() + print("3. 
Checking Rule model supports condition field...") + try: + from src.rules.models import Rule + + # Check if Rule has condition field + rule_fields = Rule.model_fields.keys() + if "condition" in rule_fields: + print(" ✅ Rule model has 'condition' field") + else: + print(" ❌ Rule model missing 'condition' field") + all_passed = False + except Exception as e: + print(f" ❌ Failed to check Rule model: {e}") + all_passed = False + + # Test 4: from_dict works + print() + print("4. Checking ConditionExpression.from_dict...") + try: + from src.rules.condition_evaluator import ConditionExpression + + data = { + "operator": "AND", + "conditions": [ + {"type": "author_team_is", "parameters": {"team": "devops"}}, + {"type": "files_match_pattern", "parameters": {"pattern": "*.py"}}, + ], + } + expr = ConditionExpression.from_dict(data) + assert expr.operator == "AND" + assert len(expr.conditions) == 2 + print(" ✅ from_dict works correctly") + except Exception as e: + print(f" ❌ from_dict test failed: {e}") + all_passed = False + + print() + print("=" * 60) + if all_passed: + print("✅ All verification checks passed!") + else: + print("❌ Some checks failed") + print("=" * 60) + + return all_passed + + +if __name__ == "__main__": + # Run standalone verification when executed directly + success = run_standalone_verification() + sys.exit(0 if success else 1) diff --git a/tests/feedback/test_processing_states.py b/tests/feedback/test_processing_states.py new file mode 100644 index 0000000..2657982 --- /dev/null +++ b/tests/feedback/test_processing_states.py @@ -0,0 +1,448 @@ +#!/usr/bin/env python3 +""" +Tests for ProcessingState enum and ProcessingResult error handling states. + +This test suite verifies that: +1. ProcessingState enum correctly distinguishes between PASS, FAIL, and ERROR +2. ProcessingResult correctly uses ProcessingState instead of boolean success +3. Backward compatibility property works correctly +4. All three states are properly handled in different scenarios + +Can be run in two ways: +1. As pytest test: pytest tests/feedback/test_processing_states.py -v +2. As standalone verification: python3 tests/feedback/test_processing_states.py + (runs code structure checks without requiring dependencies) +""" + +import sys +from pathlib import Path + +# Add project root to path for imports when running directly +ROOT = Path(__file__).resolve().parent.parent.parent +if str(ROOT) not in sys.path: + sys.path.insert(0, str(ROOT)) + + +def verify_implementation_structure(): + """ + Verify ProcessingState implementation structure without requiring dependencies. + This can run even if project dependencies aren't installed. 
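+    It only reads src/event_processors/base.py as plain text and checks for the expected definitions.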
+ """ + print("=" * 60) + print("Verifying ProcessingState Implementation Structure") + print("=" * 60) + print() + + try: + # Read the base.py file to verify the enum definition + base_file = ROOT / "src" / "event_processors" / "base.py" + if not base_file.exists(): + print("❌ ERROR: base.py not found") + return False + + content = base_file.read_text() + + # Check for ProcessingState enum + if "class ProcessingState(str, Enum):" not in content: + print("❌ ERROR: ProcessingState enum not found") + return False + print("✅ ProcessingState enum class found") + + # Check for enum values + if 'PASS = "pass"' not in content: + print("❌ ERROR: PASS value not found") + return False + print("✅ PASS value found") + + if 'FAIL = "fail"' not in content: + print("❌ ERROR: FAIL value not found") + return False + print("✅ FAIL value found") + + if 'ERROR = "error"' not in content: + print("❌ ERROR: ERROR value not found") + return False + print("✅ ERROR value found") + + # Check for ProcessingResult with state field + if "state: ProcessingState" not in content: + print("❌ ERROR: ProcessingResult.state field not found") + return False + print("✅ ProcessingResult.state field found") + + # Check for backward compatibility property + if "@property" in content and "def success(self)" in content: + print("✅ Backward compatibility .success property found") + else: + print("⚠️ WARNING: Backward compatibility .success property not found") + + print() + print("=" * 60) + print("✅ All structure checks passed!") + print("=" * 60) + print() + return True + + except Exception as e: + print(f"❌ ERROR: {e}") + import traceback + + traceback.print_exc() + return False + + +# Try to import pytest and the actual classes for full testing +try: + import pytest + + from src.event_processors.base import ProcessingResult, ProcessingState + + HAS_DEPENDENCIES = True +except (ImportError, ModuleNotFoundError): + # If imports fail, we can still run structure verification + HAS_DEPENDENCIES = False + pytest = None + ProcessingResult = None + ProcessingState = None + + +class TestProcessingState: + """Test ProcessingState enum values and behavior.""" + + def test_processing_state_values(self): + """Verify ProcessingState has correct string values.""" + if not HAS_DEPENDENCIES: + pytest.skip("Dependencies not available") + assert ProcessingState.PASS == "pass" + assert ProcessingState.FAIL == "fail" + assert ProcessingState.ERROR == "error" + + def test_processing_state_enum_membership(self): + """Verify ProcessingState values are proper enum members.""" + if not HAS_DEPENDENCIES: + pytest.skip("Dependencies not available") + assert isinstance(ProcessingState.PASS, ProcessingState) + assert isinstance(ProcessingState.FAIL, ProcessingState) + assert isinstance(ProcessingState.ERROR, ProcessingState) + + +class TestProcessingResultStates: + """Test ProcessingResult with different ProcessingState values.""" + + def test_processing_result_pass_state(self): + """Test ProcessingResult with PASS state (no violations).""" + if not HAS_DEPENDENCIES: + pytest.skip("Dependencies not available") + result = ProcessingResult( + state=ProcessingState.PASS, + violations=[], + api_calls_made=1, + processing_time_ms=100, + ) + + assert result.state == ProcessingState.PASS + assert result.success is True # Backward compatibility + assert result.violations == [] + assert result.error is None + + def test_processing_result_fail_state(self): + """Test ProcessingResult with FAIL state (violations found).""" + if not HAS_DEPENDENCIES: + 
pytest.skip("Dependencies not available") + violations = [{"rule": "test-rule", "severity": "high", "message": "Test violation"}] + result = ProcessingResult( + state=ProcessingState.FAIL, + violations=violations, + api_calls_made=2, + processing_time_ms=200, + ) + + assert result.state == ProcessingState.FAIL + assert result.success is False # Backward compatibility + assert len(result.violations) == 1 + assert result.error is None + + def test_processing_result_error_state(self): + """Test ProcessingResult with ERROR state (exception occurred).""" + if not HAS_DEPENDENCIES: + pytest.skip("Dependencies not available") + result = ProcessingResult( + state=ProcessingState.ERROR, + violations=[], + api_calls_made=0, + processing_time_ms=50, + error="Failed to fetch rules", + ) + + assert result.state == ProcessingState.ERROR + assert result.success is False # Backward compatibility + assert result.violations == [] + assert result.error == "Failed to fetch rules" + + def test_processing_result_backward_compatibility(self): + """Test that .success property works for backward compatibility.""" + if not HAS_DEPENDENCIES: + pytest.skip("Dependencies not available") + # PASS state should return True + pass_result = ProcessingResult( + state=ProcessingState.PASS, + violations=[], + api_calls_made=1, + processing_time_ms=100, + ) + assert pass_result.success is True + + # FAIL state should return False + fail_result = ProcessingResult( + state=ProcessingState.FAIL, + violations=[{"rule": "test"}], + api_calls_made=1, + processing_time_ms=100, + ) + assert fail_result.success is False + + # ERROR state should return False + error_result = ProcessingResult( + state=ProcessingState.ERROR, + violations=[], + api_calls_made=0, + processing_time_ms=50, + error="Test error", + ) + assert error_result.success is False + + def test_processing_result_state_distinction(self): + """Test that PASS, FAIL, and ERROR are clearly distinguished.""" + if not HAS_DEPENDENCIES: + pytest.skip("Dependencies not available") + # PASS: No violations, no errors + pass_result = ProcessingResult( + state=ProcessingState.PASS, + violations=[], + api_calls_made=1, + processing_time_ms=100, + ) + + # FAIL: Violations found, but processing succeeded + fail_result = ProcessingResult( + state=ProcessingState.FAIL, + violations=[{"rule": "test", "message": "Violation"}], + api_calls_made=1, + processing_time_ms=100, + ) + + # ERROR: Exception occurred, couldn't check + error_result = ProcessingResult( + state=ProcessingState.ERROR, + violations=[], + api_calls_made=0, + processing_time_ms=50, + error="Exception occurred", + ) + + # Verify states are distinct + assert pass_result.state != fail_result.state + assert pass_result.state != error_result.state + assert fail_result.state != error_result.state + + # Verify PASS has no violations and no error + assert pass_result.violations == [] + assert pass_result.error is None + + # Verify FAIL has violations but no error + assert len(fail_result.violations) > 0 + assert fail_result.error is None + + # Verify ERROR has error message + assert error_result.error is not None + + def test_processing_result_with_violations_and_error(self): + """Test edge case: result with both violations and error (should be ERROR state).""" + if not HAS_DEPENDENCIES: + pytest.skip("Dependencies not available") + # If there's an error, state should be ERROR regardless of violations + result = ProcessingResult( + state=ProcessingState.ERROR, + violations=[{"rule": "test"}], # Violations found before error + 
api_calls_made=1, + processing_time_ms=100, + error="Processing failed after finding violations", + ) + + assert result.state == ProcessingState.ERROR + assert result.error is not None + # Even though violations exist, state is ERROR because processing failed + + def test_processing_result_pydantic_validation(self): + """Test that ProcessingResult validates state correctly.""" + if not HAS_DEPENDENCIES: + pytest.skip("Dependencies not available") + # Valid state should work + result = ProcessingResult( + state=ProcessingState.PASS, + violations=[], + api_calls_made=1, + processing_time_ms=100, + ) + assert result.state == ProcessingState.PASS + + # Invalid state should raise validation error + with pytest.raises(Exception): # Pydantic validation error + ProcessingResult( + state="invalid_state", # type: ignore + violations=[], + api_calls_made=1, + processing_time_ms=100, + ) + + +class TestProcessingStateScenarios: + """Test real-world scenarios for each processing state.""" + + def test_scenario_pass_no_rules_configured(self): + """Scenario: No rules configured - should be PASS (not an error).""" + if not HAS_DEPENDENCIES: + pytest.skip("Dependencies not available") + result = ProcessingResult( + state=ProcessingState.PASS, + violations=[], + api_calls_made=1, + processing_time_ms=50, + ) + + assert result.state == ProcessingState.PASS + assert result.success is True + + def test_scenario_pass_all_rules_passed(self): + """Scenario: Rules evaluated, all passed - should be PASS.""" + if not HAS_DEPENDENCIES: + pytest.skip("Dependencies not available") + result = ProcessingResult( + state=ProcessingState.PASS, + violations=[], + api_calls_made=2, + processing_time_ms=150, + ) + + assert result.state == ProcessingState.PASS + assert result.success is True + + def test_scenario_fail_violations_found(self): + """Scenario: Rules evaluated, violations found - should be FAIL.""" + if not HAS_DEPENDENCIES: + pytest.skip("Dependencies not available") + violations = [ + {"rule": "min-approvals", "severity": "high", "message": "Need 2 approvals"}, + {"rule": "required-labels", "severity": "medium", "message": "Missing label"}, + ] + result = ProcessingResult( + state=ProcessingState.FAIL, + violations=violations, + api_calls_made=2, + processing_time_ms=200, + ) + + assert result.state == ProcessingState.FAIL + assert result.success is False + assert len(result.violations) == 2 + + def test_scenario_error_exception_occurred(self): + """Scenario: Exception during processing - should be ERROR.""" + if not HAS_DEPENDENCIES: + pytest.skip("Dependencies not available") + result = ProcessingResult( + state=ProcessingState.ERROR, + violations=[], + api_calls_made=0, + processing_time_ms=10, + error="Failed to fetch rules: Connection timeout", + ) + + assert result.state == ProcessingState.ERROR + assert result.success is False + assert result.error is not None + assert "timeout" in result.error.lower() + + def test_scenario_error_rules_file_not_found(self): + """Scenario: Rules file not found - could be PASS or ERROR depending on context.""" + if not HAS_DEPENDENCIES: + pytest.skip("Dependencies not available") + # If rules file not found is expected (first time setup), it's PASS + result_pass = ProcessingResult( + state=ProcessingState.PASS, + violations=[], + api_calls_made=1, + processing_time_ms=50, + error="Rules not configured", # Informational, not an error state + ) + + # If rules file should exist but doesn't, it's ERROR + result_error = ProcessingResult( + state=ProcessingState.ERROR, + 
violations=[], + api_calls_made=1, + processing_time_ms=50, + error="Rules file not found: .watchflow/rules.yaml", + ) + + # Both have error messages, but different states + assert result_pass.state == ProcessingState.PASS + assert result_error.state == ProcessingState.ERROR + + +class TestProcessingStateComparison: + """Test comparison and equality of ProcessingState values.""" + + def test_processing_state_equality(self): + """Test that ProcessingState values can be compared.""" + if not HAS_DEPENDENCIES: + pytest.skip("Dependencies not available") + assert ProcessingState.PASS == ProcessingState.PASS + assert ProcessingState.FAIL == ProcessingState.FAIL + assert ProcessingState.ERROR == ProcessingState.ERROR + assert ProcessingState.PASS != ProcessingState.FAIL + assert ProcessingState.PASS != ProcessingState.ERROR + assert ProcessingState.FAIL != ProcessingState.ERROR + + def test_processing_result_state_comparison(self): + """Test comparing ProcessingResult states.""" + if not HAS_DEPENDENCIES: + pytest.skip("Dependencies not available") + pass_result = ProcessingResult( + state=ProcessingState.PASS, + violations=[], + api_calls_made=1, + processing_time_ms=100, + ) + + fail_result = ProcessingResult( + state=ProcessingState.FAIL, + violations=[{"rule": "test"}], + api_calls_made=1, + processing_time_ms=100, + ) + + assert pass_result.state == ProcessingState.PASS + assert fail_result.state == ProcessingState.FAIL + assert pass_result.state != fail_result.state + + +if __name__ == "__main__": + # When run directly, do structure verification first + print("Running structure verification...") + structure_ok = verify_implementation_structure() + + # If pytest is available, also run the tests + if HAS_DEPENDENCIES and pytest: + print("\n" + "=" * 60) + print("Running pytest tests...") + print("=" * 60 + "\n") + exit_code = pytest.main([__file__, "-v"]) + sys.exit(exit_code) + else: + # Just exit based on structure verification + if structure_ok: + print("\n✅ Structure verification passed!") + print("Note: Install dependencies to run full pytest tests:") + print(" pytest tests/feedback/test_processing_states.py -v") + sys.exit(0 if structure_ok else 1) diff --git a/tests/feedback/test_violation_deduplication.py b/tests/feedback/test_violation_deduplication.py new file mode 100644 index 0000000..35058ca --- /dev/null +++ b/tests/feedback/test_violation_deduplication.py @@ -0,0 +1,472 @@ +#!/usr/bin/env python3 +""" +Tests for violation deduplication and tracking. + +This test suite verifies that: +1. ViolationTracker correctly identifies duplicate violations +2. Fingerprinting generates unique IDs for violations +3. Duplicate violations are filtered before reporting +4. TTL-based expiration works correctly +5. Integration with event processors works + +Can be run in two ways: +1. As pytest test: pytest tests/feedback/test_violation_deduplication.py -v +2. 
As standalone verification: python3 tests/feedback/test_violation_deduplication.py +""" + +import sys +import time +from pathlib import Path + +# Add project root to path for imports when running directly +ROOT = Path(__file__).resolve().parent.parent.parent +if str(ROOT) not in sys.path: + sys.path.insert(0, str(ROOT)) + + +class TestViolationFingerprinting: + """Test violation fingerprinting functionality.""" + + def test_same_violation_same_fingerprint(self): + """Test that the same violation generates the same fingerprint.""" + from src.core.utils.violation_tracker import ViolationTracker + + tracker = ViolationTracker() + violation = { + "rule_description": "Test Rule", + "message": "Rule validation failed", + "severity": "high", + "details": {"validator_used": "test_validator", "parameters": {"key": "value"}}, + } + + fingerprint1 = tracker.generate_fingerprint(violation, "owner/repo") + fingerprint2 = tracker.generate_fingerprint(violation, "owner/repo") + + assert fingerprint1 == fingerprint2, "Same violation should generate same fingerprint" + + def test_different_violations_different_fingerprints(self): + """Test that different violations generate different fingerprints.""" + from src.core.utils.violation_tracker import ViolationTracker + + tracker = ViolationTracker() + violation1 = { + "rule_description": "Test Rule 1", + "message": "Rule validation failed", + "severity": "high", + "details": {}, + } + violation2 = { + "rule_description": "Test Rule 2", + "message": "Rule validation failed", + "severity": "high", + "details": {}, + } + + fingerprint1 = tracker.generate_fingerprint(violation1, "owner/repo") + fingerprint2 = tracker.generate_fingerprint(violation2, "owner/repo") + + assert fingerprint1 != fingerprint2, "Different violations should generate different fingerprints" + + def test_fingerprint_includes_repo(self): + """Test that fingerprint includes repository name.""" + from src.core.utils.violation_tracker import ViolationTracker + + tracker = ViolationTracker() + violation = { + "rule_description": "Test Rule", + "message": "Rule validation failed", + "severity": "high", + "details": {}, + } + + fingerprint1 = tracker.generate_fingerprint(violation, "owner/repo1") + fingerprint2 = tracker.generate_fingerprint(violation, "owner/repo2") + + assert fingerprint1 != fingerprint2, "Different repos should generate different fingerprints" + + def test_fingerprint_includes_context(self): + """Test that fingerprint includes context data (PR number, commit SHA).""" + from src.core.utils.violation_tracker import ViolationTracker + + tracker = ViolationTracker() + violation = { + "rule_description": "Test Rule", + "message": "Rule validation failed", + "severity": "high", + "details": {}, + } + + context1 = {"pr_number": 1, "commit_sha": "abc123"} + context2 = {"pr_number": 2, "commit_sha": "def456"} + + fingerprint1 = tracker.generate_fingerprint(violation, "owner/repo", context1) + fingerprint2 = tracker.generate_fingerprint(violation, "owner/repo", context2) + + assert fingerprint1 != fingerprint2, "Different contexts should generate different fingerprints" + + def test_fingerprint_same_context_same_fingerprint(self): + """Test that same violation with same context generates same fingerprint.""" + from src.core.utils.violation_tracker import ViolationTracker + + tracker = ViolationTracker() + violation = { + "rule_description": "Test Rule", + "message": "Rule validation failed", + "severity": "high", + "details": {}, + } + + context = {"pr_number": 1, "commit_sha": "abc123"} 
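+        # Identical violation, repo, and context must hash to the same fingerprint for deduplication to work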
+ + fingerprint1 = tracker.generate_fingerprint(violation, "owner/repo", context) + fingerprint2 = tracker.generate_fingerprint(violation, "owner/repo", context) + + assert fingerprint1 == fingerprint2, "Same violation with same context should generate same fingerprint" + + +class TestViolationTracking: + """Test violation tracking and deduplication.""" + + def test_mark_reported(self): + """Test marking a violation as reported.""" + from src.core.utils.violation_tracker import ViolationTracker + + tracker = ViolationTracker() + violation = { + "rule_description": "Test Rule", + "message": "Rule validation failed", + "severity": "high", + "details": {}, + } + + fingerprint = tracker.generate_fingerprint(violation, "owner/repo") + assert not tracker.is_reported(fingerprint), "Violation should not be reported initially" + + tracker.mark_reported(fingerprint, violation) + assert tracker.is_reported(fingerprint), "Violation should be reported after marking" + + def test_filter_new_violations(self): + """Test filtering out duplicate violations.""" + from src.core.utils.violation_tracker import ViolationTracker + + tracker = ViolationTracker() + violation1 = { + "rule_description": "Test Rule 1", + "message": "Rule validation failed", + "severity": "high", + "details": {}, + } + violation2 = { + "rule_description": "Test Rule 2", + "message": "Rule validation failed", + "severity": "high", + "details": {}, + } + + # First batch - all should be new + violations = [violation1, violation2] + new_violations = tracker.filter_new_violations(violations, "owner/repo") + assert len(new_violations) == 2, "All violations should be new initially" + + # Second batch - same violations should be filtered + new_violations = tracker.filter_new_violations(violations, "owner/repo") + assert len(new_violations) == 0, "Duplicate violations should be filtered out" + + def test_filter_mixed_violations(self): + """Test filtering with mix of new and duplicate violations.""" + from src.core.utils.violation_tracker import ViolationTracker + + tracker = ViolationTracker() + violation1 = { + "rule_description": "Test Rule 1", + "message": "Rule validation failed", + "severity": "high", + "details": {}, + } + violation2 = { + "rule_description": "Test Rule 2", + "message": "Rule validation failed", + "severity": "high", + "details": {}, + } + violation3 = { + "rule_description": "Test Rule 3", + "message": "Rule validation failed", + "severity": "high", + "details": {}, + } + + # First batch + violations = [violation1, violation2] + new_violations = tracker.filter_new_violations(violations, "owner/repo") + assert len(new_violations) == 2 + + # Second batch with one duplicate and one new + violations = [violation1, violation3] + new_violations = tracker.filter_new_violations(violations, "owner/repo") + assert len(new_violations) == 1, "Should filter duplicate, keep new" + assert new_violations[0]["rule_description"] == "Test Rule 3" + + def test_ttl_expiration(self): + """Test that violations expire after TTL.""" + from src.core.utils.violation_tracker import ViolationTracker + + tracker = ViolationTracker(ttl_seconds=1) # 1 second TTL + violation = { + "rule_description": "Test Rule", + "message": "Rule validation failed", + "severity": "high", + "details": {}, + } + + fingerprint = tracker.generate_fingerprint(violation, "owner/repo") + tracker.mark_reported(fingerprint, violation) + assert tracker.is_reported(fingerprint), "Violation should be reported" + + # Wait for expiration + time.sleep(1.1) + assert not 
tracker.is_reported(fingerprint), "Violation should expire after TTL" + + def test_get_stats(self): + """Test getting tracker statistics.""" + from src.core.utils.violation_tracker import ViolationTracker + + tracker = ViolationTracker(ttl_seconds=3600) + violation = { + "rule_description": "Test Rule", + "message": "Rule validation failed", + "severity": "high", + "details": {}, + } + + fingerprint = tracker.generate_fingerprint(violation, "owner/repo") + tracker.mark_reported(fingerprint, violation) + + stats = tracker.get_stats() + assert stats["total_tracked"] == 1 + assert stats["active"] == 1 + assert stats["expired"] == 0 + assert stats["total_reports"] >= 1 + + def test_clear(self): + """Test clearing all tracked violations.""" + from src.core.utils.violation_tracker import ViolationTracker + + tracker = ViolationTracker() + violation = { + "rule_description": "Test Rule", + "message": "Rule validation failed", + "severity": "high", + "details": {}, + } + + fingerprint = tracker.generate_fingerprint(violation, "owner/repo") + tracker.mark_reported(fingerprint, violation) + assert tracker.is_reported(fingerprint) + + tracker.clear() + assert not tracker.is_reported(fingerprint), "Violation should not be reported after clear" + + +class TestGlobalViolationTracker: + """Test global violation tracker instance.""" + + def test_get_violation_tracker(self): + """Test getting the global violation tracker.""" + from src.core.utils.violation_tracker import get_violation_tracker + + tracker1 = get_violation_tracker() + tracker2 = get_violation_tracker() + + assert tracker1 is tracker2, "Should return the same global instance" + + def test_global_tracker_functionality(self): + """Test that global tracker works correctly.""" + from src.core.utils.violation_tracker import get_violation_tracker + + tracker = get_violation_tracker() + violation = { + "rule_description": "Test Rule", + "message": "Rule validation failed", + "severity": "high", + "details": {}, + } + + new_violations = tracker.filter_new_violations([violation], "owner/repo") + assert len(new_violations) == 1, "First violation should be new" + + new_violations = tracker.filter_new_violations([violation], "owner/repo") + assert len(new_violations) == 0, "Duplicate should be filtered" + + +class TestViolationDeduplicationIntegration: + """Test integration with event processors.""" + + def test_violation_dict_format(self): + """Test that violation dictionaries are properly formatted for tracking.""" + from src.core.utils.violation_tracker import ViolationTracker + + tracker = ViolationTracker() + # Simulate violation format from RuleViolation model + violation = { + "rule_description": "Test Rule", + "message": "Rule validation failed: Test Rule", + "severity": "high", + "details": { + "validator_used": "required_labels", + "parameters": {"required_labels": ["security"]}, + "validation_result": "failed", + }, + "how_to_fix": "Add required labels", + "docs_url": "", + "validation_strategy": "validator", + "execution_time_ms": 10.5, + } + + fingerprint = tracker.generate_fingerprint(violation, "owner/repo") + assert fingerprint is not None + assert len(fingerprint) == 64 # SHA256 hex string length + + def test_context_aware_deduplication(self): + """Test that deduplication works with context (PR number, commit SHA).""" + from src.core.utils.violation_tracker import ViolationTracker + + tracker = ViolationTracker() + violation = { + "rule_description": "Test Rule", + "message": "Rule validation failed", + "severity": "high", + "details": 
{}, + } + + # Same violation on different PRs should be tracked separately + context1 = {"pr_number": 1} + context2 = {"pr_number": 2} + + new1 = tracker.filter_new_violations([violation], "owner/repo", context1) + assert len(new1) == 1 + + new2 = tracker.filter_new_violations([violation], "owner/repo", context2) + assert len(new2) == 1, "Same violation on different PR should be new" + + # But same violation on same PR should be duplicate + new3 = tracker.filter_new_violations([violation], "owner/repo", context1) + assert len(new3) == 0, "Same violation on same PR should be duplicate" + + +def run_standalone_verification(): + """Run verification checks that don't require pytest.""" + print("=" * 60) + print("Violation Deduplication Verification") + print("=" * 60) + print() + + all_passed = True + + # Test 1: ViolationTracker exists + print("1. Checking ViolationTracker exists...") + try: + from src.core.utils.violation_tracker import ViolationTracker, get_violation_tracker + + tracker = ViolationTracker() + print(" ✅ ViolationTracker created successfully") + print(f" - TTL: {tracker.ttl_seconds}s") + except Exception as e: + print(f" ❌ Failed to import ViolationTracker: {e}") + all_passed = False + + # Test 2: Fingerprinting works + print() + print("2. Checking violation fingerprinting...") + try: + from src.core.utils.violation_tracker import ViolationTracker + + tracker = ViolationTracker() + violation = { + "rule_description": "Test Rule", + "message": "Rule validation failed", + "severity": "high", + "details": {}, + } + + fingerprint1 = tracker.generate_fingerprint(violation, "owner/repo") + fingerprint2 = tracker.generate_fingerprint(violation, "owner/repo") + assert fingerprint1 == fingerprint2 + print(" ✅ Fingerprinting works correctly") + print(f" - Fingerprint length: {len(fingerprint1)}") + except Exception as e: + print(f" ❌ Fingerprinting test failed: {e}") + all_passed = False + + # Test 3: Deduplication works + print() + print("3. Checking violation deduplication...") + try: + from src.core.utils.violation_tracker import ViolationTracker + + tracker = ViolationTracker() + violation = { + "rule_description": "Test Rule", + "message": "Rule validation failed", + "severity": "high", + "details": {}, + } + + # First batch + new1 = tracker.filter_new_violations([violation], "owner/repo") + assert len(new1) == 1 + + # Second batch (duplicates) + new2 = tracker.filter_new_violations([violation], "owner/repo") + assert len(new2) == 0 + + print(" ✅ Deduplication works correctly") + print(f" - First batch: {len(new1)} new violation(s)") + print(f" - Second batch: {len(new2)} new violation(s) (duplicates filtered)") + except Exception as e: + print(f" ❌ Deduplication test failed: {e}") + all_passed = False + + # Test 4: Global tracker works + print() + print("4. Checking global violation tracker...") + try: + from src.core.utils.violation_tracker import get_violation_tracker + + tracker = get_violation_tracker() + stats = tracker.get_stats() + print(" ✅ Global tracker works") + print(f" - Total tracked: {stats['total_tracked']}") + print(f" - Active: {stats['active']}") + except Exception as e: + print(f" ❌ Global tracker test failed: {e}") + all_passed = False + + # Test 5: Integration with PullRequestProcessor + print() + print("5. 
Checking integration with PullRequestProcessor...") + try: + from src.event_processors.pull_request import PullRequestProcessor + + # Just check that the import works (actual integration tested in unit tests) + _processor = PullRequestProcessor() + print(" ✅ PullRequestProcessor imports violation tracker") + except Exception as e: + print(f" ⚠️ Integration check: {e}") + # Not a failure, might be missing dependencies + + print() + print("=" * 60) + if all_passed: + print("✅ All verification checks passed!") + else: + print("❌ Some checks failed") + print("=" * 60) + + return all_passed + + +if __name__ == "__main__": + # Run standalone verification when executed directly + success = run_standalone_verification() + sys.exit(0 if success else 1)