diff --git a/src/extension_shield/api/main.py b/src/extension_shield/api/main.py index c027bb4c..b7b772c9 100644 --- a/src/extension_shield/api/main.py +++ b/src/extension_shield/api/main.py @@ -984,6 +984,65 @@ def validate_feedback(self) -> "FeedbackRequest": return self +def _coerce_json_dict(value: Any) -> Dict[str, Any]: + """Convert JSON-ish values to a dict for metadata lookups.""" + if isinstance(value, dict): + return value + if isinstance(value, str): + try: + parsed = json.loads(value) + except (TypeError, ValueError): + return {} + return parsed if isinstance(parsed, dict) else {} + return {} + + +def _extract_feedback_versions(payload: Optional[Dict[str, Any]]) -> tuple[Optional[str], Optional[str]]: + """Extract model and ruleset versions from a scan payload when available.""" + if not isinstance(payload, dict): + return None, None + + metadata = _coerce_json_dict(payload.get("metadata")) + summary = _coerce_json_dict(payload.get("summary")) + report_view_model_source = payload.get("report_view_model") + if not report_view_model_source: + report_view_model_source = summary.get("report_view_model") + report_view_model = _coerce_json_dict(report_view_model_source) + report_meta = _coerce_json_dict(report_view_model.get("meta")) + + scoring_v2 = _coerce_json_dict(payload.get("scoring_v2")) + if not scoring_v2: + scoring_v2 = _coerce_json_dict(summary.get("scoring_v2")) + + model_version = ( + summary.get("model_version") + or metadata.get("model_version") + or report_meta.get("model_version") + or summary.get("llm_model") + or metadata.get("llm_model") + ) + + ruleset_version = ( + summary.get("ruleset_version") + or metadata.get("ruleset_version") + or report_meta.get("ruleset_version") + or scoring_v2.get("weights_version") + or scoring_v2.get("scoring_version") + ) + + return model_version, ruleset_version + + +def _resolve_feedback_versions(scan_id: str) -> tuple[Optional[str], Optional[str]]: + """Look up the related scan payload and extract version metadata safely.""" + payload = scan_results.get(scan_id) + if isinstance(payload, dict): + return _extract_feedback_versions(payload) + + db_payload = db.get_scan_result(scan_id) + return _extract_feedback_versions(db_payload) + + class ReviewQueueClaimRequest(BaseModel): """Request body for claiming a review queue item.""" queue_item_id: str @@ -1504,6 +1563,16 @@ async def run_analysis_workflow(url: str, extension_id: str): if scoring_result.coverage_cap_reason is not None: scoring_v2_payload["coverage_cap_reason"] = scoring_result.coverage_cap_reason + executive_summary = final_state.get("executive_summary") or {} + if isinstance(executive_summary, dict): + executive_summary = dict(executive_summary) + else: + executive_summary = {} + executive_summary.setdefault( + "ruleset_version", + scoring_v2_payload.get("weights_version") or scoring_v2_payload.get("scoring_version"), + ) + # Build scan results - sanitize complex objects to prevent circular references raw_results = { "extension_id": extension_id, @@ -1518,7 +1587,7 @@ async def run_analysis_workflow(url: str, extension_id: str): "webstore_analysis": analysis_results.get("webstore_analysis") or {}, "virustotal_analysis": analysis_results.get("virustotal_analysis") or {}, "entropy_analysis": analysis_results.get("entropy_analysis") or {}, - "summary": final_state.get("executive_summary") or {}, + "summary": executive_summary, "impact_analysis": analysis_results.get("impact_analysis") or {}, "privacy_compliance": analysis_results.get("privacy_compliance") or {}, "extracted_path": _storage_relative_extracted_path(final_state.get("extension_dir")), @@ -2199,11 +2268,14 @@ async def submit_feedback(feedback: FeedbackRequest, http_request: Request): provide details about why it wasn't (false positive, score issues, etc.). """ user_id = _get_user_id(http_request) + if user_id == "anon": + user_id = None # If helpful=true, ignore reason/suggested_score/comment reason = None if feedback.helpful else (feedback.reason.value if feedback.reason else None) suggested_score = None if feedback.helpful else feedback.suggested_score comment = None if feedback.helpful else feedback.comment + model_version, ruleset_version = _resolve_feedback_versions(feedback.scan_id) # Save feedback to database (SQLite or Supabase) db.save_feedback( @@ -2213,8 +2285,8 @@ async def submit_feedback(feedback: FeedbackRequest, http_request: Request): suggested_score=suggested_score, comment=comment, user_id=user_id, - model_version=None, # TODO: Extract from scan result metadata - ruleset_version=None, # TODO: Extract from scan result metadata + model_version=model_version, + ruleset_version=ruleset_version, ) return {"ok": True} diff --git a/src/extension_shield/core/summary_generator.py b/src/extension_shield/core/summary_generator.py index 0cc969a0..d86c9d25 100644 --- a/src/extension_shield/core/summary_generator.py +++ b/src/extension_shield/core/summary_generator.py @@ -35,6 +35,40 @@ def _summary_contradicts_label(text: str, score_label: str) -> bool: return False +def _extract_response_model_version(response: Any) -> Optional[str]: + """Best-effort extraction of the actual model identifier returned by the LLM client.""" + if response is None: + return None + + direct_value = getattr(response, "model", None) or getattr(response, "model_name", None) + if isinstance(direct_value, str) and direct_value.strip(): + return direct_value.strip() + + response_metadata = getattr(response, "response_metadata", None) + if isinstance(response_metadata, dict): + for key in ("model_name", "model", "model_id", "deployment_name"): + value = response_metadata.get(key) + if isinstance(value, str) and value.strip(): + return value.strip() + + additional_kwargs = getattr(response, "additional_kwargs", None) + if isinstance(additional_kwargs, dict): + for key in ("model_name", "model", "model_id"): + value = additional_kwargs.get(key) + if isinstance(value, str) and value.strip(): + return value.strip() + + return None + + +def _attach_model_version(summary: Dict[str, Any], model_version: Optional[str]) -> Dict[str, Any]: + """Return a shallow copy of a summary with model_version attached when known.""" + normalized = dict(summary) if isinstance(summary, dict) else {} + if model_version: + normalized["model_version"] = model_version + return normalized + + class SummaryGenerator: """Generates executive summaries from all analysis results.""" @@ -473,11 +507,13 @@ def generate( model_name=model_name, model_parameters=model_parameters, ) + resolved_model_version = _extract_response_model_version(response) # Parse JSON response parser = JsonOutputParser() summary = parser.parse(response.content if hasattr(response, "content") else str(response)) if isinstance(summary, dict): + summary = _attach_model_version(summary, resolved_model_version) score = summary.get("score") score_label = summary.get("score_label") if score is None: @@ -542,10 +578,13 @@ def generate( # ) # Return deterministic fallback from extension_shield.core.report_view_model import _fallback_executive_summary - return _fallback_executive_summary( - score=score, - score_label=score_label, - host_scope_label=host_scope_label, + return _attach_model_version( + _fallback_executive_summary( + score=score, + score_label=score_label, + host_scope_label=host_scope_label, + ), + resolved_model_version, ) # ── Post-LLM sanity check: one_liner must not contradict score_label ── @@ -557,12 +596,15 @@ def generate( score_label, ) from extension_shield.core.report_view_model import _fallback_executive_summary - return _fallback_executive_summary( - score=score, - score_label=score_label, - host_scope_label=host_scope_label, + return _attach_model_version( + _fallback_executive_summary( + score=score, + score_label=score_label, + host_scope_label=host_scope_label, + ), + resolved_model_version, ) - + logger.info("Executive summary generated successfully") return summary except Exception as exc: diff --git a/tests/api/test_feedback_endpoint.py b/tests/api/test_feedback_endpoint.py new file mode 100644 index 00000000..fc961b10 --- /dev/null +++ b/tests/api/test_feedback_endpoint.py @@ -0,0 +1,162 @@ +from unittest.mock import MagicMock, patch + +import pytest +from fastapi.testclient import TestClient + +from extension_shield.api.main import app, scan_results + + +@pytest.fixture +def client() -> TestClient: + return TestClient(app) + + +def test_feedback_persists_versions_from_scan_payload(client: TestClient) -> None: + scan_id = "abcdefghijklmnopabcdefghijklmnop" + scan_results.pop(scan_id, None) + + scan_payload = { + "summary": {"model_version": "gpt-4o"}, + "scoring_v2": {"scoring_version": "2.0.0", "weights_version": "v1"}, + } + + with patch("extension_shield.api.main.db") as mock_db: + mock_db.get_scan_result = MagicMock(return_value=scan_payload) + mock_db.save_feedback = MagicMock() + + response = client.post( + "/api/feedback", + json={ + "scan_id": scan_id, + "helpful": False, + "reason": "score_off", + "suggested_score": 72, + "comment": "Score feels too strict", + }, + ) + + assert response.status_code == 200 + mock_db.save_feedback.assert_called_once_with( + scan_id=scan_id, + helpful=False, + reason="score_off", + suggested_score=72, + comment="Score feels too strict", + user_id=None, + model_version="gpt-4o", + ruleset_version="v1", + ) + + +def test_feedback_handles_missing_version_metadata(client: TestClient) -> None: + scan_id = "ponmlkjihgfedcbaponmlkjihgfedcba" + scan_results.pop(scan_id, None) + + with patch("extension_shield.api.main.db") as mock_db: + mock_db.get_scan_result = MagicMock(return_value={"summary": {}, "scoring_v2": {}}) + mock_db.save_feedback = MagicMock() + + response = client.post( + "/api/feedback", + json={ + "scan_id": scan_id, + "helpful": True, + }, + ) + + assert response.status_code == 200 + mock_db.save_feedback.assert_called_once_with( + scan_id=scan_id, + helpful=True, + reason=None, + suggested_score=None, + comment=None, + user_id=None, + model_version=None, + ruleset_version=None, + ) + + +def test_feedback_persists_versions_from_nested_summary_payload( + client: TestClient, +) -> None: + scan_id = "nestedabcdefghijnestedabcdefghij" + scan_results.pop(scan_id, None) + + scan_payload = { + "summary": { + "report_view_model": { + "meta": { + "model_version": "resolved-model", + "ruleset_version": "legacy-rules", + } + }, + "scoring_v2": { + "scoring_version": "2.1.0", + "weights_version": "rules-v3", + }, + } + } + + with patch("extension_shield.api.main.db") as mock_db: + mock_db.get_scan_result = MagicMock(return_value=scan_payload) + mock_db.save_feedback = MagicMock() + + response = client.post( + "/api/feedback", + json={ + "scan_id": scan_id, + "helpful": False, + "reason": "score_off", + "suggested_score": 85, + "comment": "Nested summary payload", + }, + ) + + assert response.status_code == 200 + mock_db.save_feedback.assert_called_once_with( + scan_id=scan_id, + helpful=False, + reason="score_off", + suggested_score=85, + comment="Nested summary payload", + user_id=None, + model_version="resolved-model", + ruleset_version="legacy-rules", + ) + + +def test_feedback_normalizes_anonymous_user_id_to_none(client: TestClient) -> None: + scan_id = "anonabcdefghijklmnoanonabcdefgh" + scan_results.pop(scan_id, None) + + scan_payload = { + "summary": {"model_version": "gpt-4o-mini"}, + "scoring_v2": {"weights_version": "v1"}, + } + + with patch("extension_shield.api.main.db") as mock_db: + mock_db.get_scan_result = MagicMock(return_value=scan_payload) + mock_db.save_feedback = MagicMock() + + response = client.post( + "/api/feedback", + json={ + "scan_id": scan_id, + "helpful": False, + "reason": "other", + "comment": "anonymous feedback", + }, + ) + + assert response.status_code == 200 + mock_db.save_feedback.assert_called_once_with( + scan_id=scan_id, + helpful=False, + reason="other", + suggested_score=None, + comment="anonymous feedback", + user_id=None, + model_version="gpt-4o-mini", + ruleset_version="v1", + )