From 53648bdbe1ca585979c3b500a1f3eba51f9b6464 Mon Sep 17 00:00:00 2001 From: Dave Lawrence Date: Thu, 2 Apr 2026 10:42:33 +1030 Subject: [PATCH] Security hardening for sync app (SACGF/variantgrid_private#3831) - oauth.py: validate URL scheme and use urljoin instead of string concat - models_classification_sync.py: validate remote_pk, use urljoin for URL construction - variant_grid_download.py: validate exclude_labs/orgs values and lab_group_name format before creating Lab/Org records - alissa_upload.py: safe int conversion for response counts, narrow except clause, store only sanitized response summaries in SyncRun.meta - query_json_filter.py: validate filter keys against safe identifier pattern - sync_runner.py: remove config dump (including credentials) from error message --- library/oauth.py | 9 ++- sync/alissa/alissa_upload.py | 73 ++++++++++++++--------- sync/models/models_classification_sync.py | 8 ++- sync/shariant/query_json_filter.py | 5 ++ sync/shariant/variant_grid_download.py | 17 +++++- sync/sync_runner.py | 3 +- 6 files changed, 79 insertions(+), 36 deletions(-) diff --git a/library/oauth.py b/library/oauth.py index c90673400..1a98cd500 100644 --- a/library/oauth.py +++ b/library/oauth.py @@ -83,6 +83,9 @@ def post(self, url_suffix: str, timeout: int = MINUTE_SECS, **kwargs) -> Respons ) def url(self, path: str): - if path[0:1] == '/': - path = path[1:] - return self.host + '/' + path + from urllib.parse import urlparse, urljoin + parsed = urlparse(self.host) + if parsed.scheme not in ('https', 'http'): + raise ValueError(f"ServerAuth host must use http(s) scheme, got: {parsed.scheme!r}") + base = self.host if self.host.endswith('/') else self.host + '/' + return urljoin(base, path.lstrip('/')) diff --git a/sync/alissa/alissa_upload.py b/sync/alissa/alissa_upload.py index 34131c210..5950f21c4 100644 --- a/sync/alissa/alissa_upload.py +++ b/sync/alissa/alissa_upload.py @@ -123,40 +123,57 @@ def sync(self, sync_run_instance: SyncRunInstance): try: if response_json := response.json(): - total_failed += int(response_json.get("numberFailed")) - total_differs += int(response_json.get("numberDiffers")) - total_imported += int(response_json.get("numberImported")) - - response_jsons.append(response_json) + if not isinstance(response_json, dict): + raise ValueError(f"Expected dict response from Alissa, got {type(response_json).__name__}") + + def _safe_int(val): + try: + return int(val) + except (TypeError, ValueError): + return 0 + + num_failed = _safe_int(response_json.get("numberFailed")) + num_differs = _safe_int(response_json.get("numberDiffers")) + num_imported = _safe_int(response_json.get("numberImported")) + total_failed += num_failed + total_differs += num_differs + total_imported += num_imported + + response_jsons.append({ + "numberFailed": num_failed, + "numberDiffers": num_differs, + "numberImported": num_imported, + "failures": response_json.get("failures") or [], + "infos": response_json.get("infos") or [], + }) if response_error := response_json.get("error"): notify = AdminNotificationBuilder(message="Error Uploading") notify.add_field("Sync Destination", sync_run_instance.name) - notify.add_field("Error", response_error) + notify.add_field("Error", str(response_error)[:500]) notify.send() - elif numberFailed := int(response_json.get("numberFailed")): - if numberFailed > 0: - notify = AdminNotificationBuilder(message="Error Uploading") - notify.add_field("Sync Destination", sync_run_instance.sync_destination.name) - notify.add_field("Failures", numberFailed) - - failure: str - for failure in response_json.get("failures"): - if "\t" in failure: - parts = failure.split("\t") - message = parts[0] - json_data_str = parts[1] - try: - json_data = json.loads(json_data_str) - transcript = json_data.get("transcript") - c_nomen = json_data.get("cNomen") - notify.add_markdown(f"{transcript}:{c_nomen} - \"{message}\"") - except ValueError: - notify.add_markdown(failure) - else: + elif num_failed > 0: + notify = AdminNotificationBuilder(message="Error Uploading") + notify.add_field("Sync Destination", sync_run_instance.sync_destination.name) + notify.add_field("Failures", num_failed) + + failure: str + for failure in (response_json.get("failures") or []): + if "\t" in failure: + parts = failure.split("\t") + message = parts[0] + json_data_str = parts[1] + try: + json_data = json.loads(json_data_str) + transcript = json_data.get("transcript") + c_nomen = json_data.get("cNomen") + notify.add_markdown(f"{transcript}:{c_nomen} - \"{message}\"") + except (ValueError, KeyError): notify.add_markdown(failure) + else: + notify.add_markdown(failure) - notify.send() - except Exception: + notify.send() + except (ValueError, AttributeError): report_exc_info() since_timestamp = None diff --git a/sync/models/models_classification_sync.py b/sync/models/models_classification_sync.py index 8b5dd6025..9cedf51b4 100644 --- a/sync/models/models_classification_sync.py +++ b/sync/models/models_classification_sync.py @@ -1,5 +1,6 @@ -import os +import re from typing import Optional +from urllib.parse import urljoin from django.db import models from django.db.models.deletion import CASCADE @@ -33,5 +34,8 @@ def remote_url(self) -> Optional[str]: raise ValueError(f"{self.classification_modification} not successfully synced") url = self.run.destination.sync_details["host"] remote_pk = self.meta["meta"]["id"] + if not re.fullmatch(r'[\w\-]+', str(remote_pk)): + raise ValueError(f"remote_pk contains unsafe characters: {remote_pk!r}") path = self.classification_modification.classification.get_url_for_pk(remote_pk) - return os.path.join(url, path) + base = url if url.endswith('/') else url + '/' + return urljoin(base, path.lstrip('/')) diff --git a/sync/shariant/query_json_filter.py b/sync/shariant/query_json_filter.py index b65bb884b..b72c1e195 100644 --- a/sync/shariant/query_json_filter.py +++ b/sync/shariant/query_json_filter.py @@ -1,9 +1,12 @@ import operator +import re from functools import reduce from typing import Any from django.db.models import Q +_SAFE_KEY_RE = re.compile(r'^[a-zA-Z]\w*$') + class QueryJsonFilter: """ @@ -58,6 +61,8 @@ def convert_to_q_w_key(self, key: str, value) -> Q: return self.convert_to_q(value, operator.__and__) else: # key is assumed to be regular value key + if not _SAFE_KEY_RE.match(key): + raise ValueError(f"Filter key contains unsafe characters: {key!r}") if isinstance(value, list): handle_none = False is_not = False diff --git a/sync/shariant/variant_grid_download.py b/sync/shariant/variant_grid_download.py index a8375b9ce..38753b519 100644 --- a/sync/shariant/variant_grid_download.py +++ b/sync/shariant/variant_grid_download.py @@ -1,3 +1,4 @@ +import re import time from typing import Optional @@ -31,12 +32,20 @@ def sync(self, sync_run_instance: SyncRunInstance): 'type': 'json', 'build': required_build} + _safe_identifier = re.compile(r'^[\w\-]+$') + exclude_labs = config.get('exclude_labs', None) if exclude_labs: + for lab in exclude_labs: + if not _safe_identifier.match(str(lab)): + raise ValueError(f"exclude_labs contains unsafe value: {lab!r}") params['exclude_labs'] = ','.join(exclude_labs) exclude_orgs = config.get('exclude_orgs', None) if exclude_orgs: + for org in exclude_orgs: + if not _safe_identifier.match(str(org)): + raise ValueError(f"exclude_orgs contains unsafe value: {org!r}") params['exclude_orgs'] = ','.join(exclude_orgs) if not sync_run_instance.full_sync: @@ -92,12 +101,18 @@ def shariant_download_to_upload(known_keys: EvidenceKeyMap, record_json: dict) - return None if not lab: + _lab_group_name_re = re.compile(r'^[\w\-]+/[\w\-]+$') + if not _lab_group_name_re.match(lab_group_name): + raise ValueError(f"lab_group_name has unexpected format: {lab_group_name!r}") parts = lab_group_name.split('/') + lab_name = meta.get('lab_name') or parts[1] + if not isinstance(lab_name, str) or len(lab_name) > 255: + raise ValueError(f"lab_name is invalid: {lab_name!r}") org, _ = Organization.objects.get_or_create(group_name=parts[0], defaults={"name": parts[0]}) australia, _ = Country.objects.get_or_create(name='Australia') Lab.objects.create( group_name=lab_group_name, - name=meta.get('lab_name'), + name=lab_name, organization=org, city='Unknown', country=australia, diff --git a/sync/sync_runner.py b/sync/sync_runner.py index c20051399..e17b81f81 100644 --- a/sync/sync_runner.py +++ b/sync/sync_runner.py @@ -1,4 +1,3 @@ -import json from abc import ABC, abstractmethod from dataclasses import dataclass from datetime import datetime @@ -111,4 +110,4 @@ def sync_runner_for_destination(sync_destination: SyncDestination) -> SyncRunner if factory_requirements.matches(sync_destination): return factory_requirements.factory() - raise ValueError(f"None of the {len(_sync_runner_registry)} SyncRunners is configured for the config of {sync_destination}: ({json.dumps(sync_destination.config)})") + raise ValueError(f"None of the {len(_sync_runner_registry)} SyncRunners matched destination {sync_destination!r}")