Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions library/oauth.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ def post(self, url_suffix: str, timeout: int = MINUTE_SECS, **kwargs) -> Respons
)

def url(self, path: str):
if path[0:1] == '/':
path = path[1:]
return self.host + '/' + path
from urllib.parse import urlparse, urljoin
parsed = urlparse(self.host)
if parsed.scheme not in ('https', 'http'):
raise ValueError(f"ServerAuth host must use http(s) scheme, got: {parsed.scheme!r}")
base = self.host if self.host.endswith('/') else self.host + '/'
return urljoin(base, path.lstrip('/'))
73 changes: 45 additions & 28 deletions sync/alissa/alissa_upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,40 +123,57 @@ def sync(self, sync_run_instance: SyncRunInstance):

try:
if response_json := response.json():
total_failed += int(response_json.get("numberFailed"))
total_differs += int(response_json.get("numberDiffers"))
total_imported += int(response_json.get("numberImported"))

response_jsons.append(response_json)
if not isinstance(response_json, dict):
raise ValueError(f"Expected dict response from Alissa, got {type(response_json).__name__}")

def _safe_int(val):
try:
return int(val)
except (TypeError, ValueError):
return 0

num_failed = _safe_int(response_json.get("numberFailed"))
num_differs = _safe_int(response_json.get("numberDiffers"))
num_imported = _safe_int(response_json.get("numberImported"))
total_failed += num_failed
total_differs += num_differs
total_imported += num_imported

response_jsons.append({
"numberFailed": num_failed,
"numberDiffers": num_differs,
"numberImported": num_imported,
"failures": response_json.get("failures") or [],
"infos": response_json.get("infos") or [],
})
if response_error := response_json.get("error"):
notify = AdminNotificationBuilder(message="Error Uploading")
notify.add_field("Sync Destination", sync_run_instance.name)
notify.add_field("Error", response_error)
notify.add_field("Error", str(response_error)[:500])
notify.send()
elif numberFailed := int(response_json.get("numberFailed")):
if numberFailed > 0:
notify = AdminNotificationBuilder(message="Error Uploading")
notify.add_field("Sync Destination", sync_run_instance.sync_destination.name)
notify.add_field("Failures", numberFailed)

failure: str
for failure in response_json.get("failures"):
if "\t" in failure:
parts = failure.split("\t")
message = parts[0]
json_data_str = parts[1]
try:
json_data = json.loads(json_data_str)
transcript = json_data.get("transcript")
c_nomen = json_data.get("cNomen")
notify.add_markdown(f"{transcript}:{c_nomen} - \"{message}\"")
except ValueError:
notify.add_markdown(failure)
else:
elif num_failed > 0:
notify = AdminNotificationBuilder(message="Error Uploading")
notify.add_field("Sync Destination", sync_run_instance.sync_destination.name)
notify.add_field("Failures", num_failed)

failure: str
for failure in (response_json.get("failures") or []):
if "\t" in failure:
parts = failure.split("\t")
message = parts[0]
json_data_str = parts[1]
try:
json_data = json.loads(json_data_str)
transcript = json_data.get("transcript")
c_nomen = json_data.get("cNomen")
notify.add_markdown(f"{transcript}:{c_nomen} - \"{message}\"")
except (ValueError, KeyError):
notify.add_markdown(failure)
else:
notify.add_markdown(failure)

notify.send()
except Exception:
notify.send()
except (ValueError, AttributeError):
report_exc_info()

since_timestamp = None
Expand Down
8 changes: 6 additions & 2 deletions sync/models/models_classification_sync.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import re
from typing import Optional
from urllib.parse import urljoin

from django.db import models
from django.db.models.deletion import CASCADE
Expand Down Expand Up @@ -33,5 +34,8 @@ def remote_url(self) -> Optional[str]:
raise ValueError(f"{self.classification_modification} not successfully synced")
url = self.run.destination.sync_details["host"]
remote_pk = self.meta["meta"]["id"]
if not re.fullmatch(r'[\w\-]+', str(remote_pk)):
raise ValueError(f"remote_pk contains unsafe characters: {remote_pk!r}")
path = self.classification_modification.classification.get_url_for_pk(remote_pk)
return os.path.join(url, path)
base = url if url.endswith('/') else url + '/'
return urljoin(base, path.lstrip('/'))
5 changes: 5 additions & 0 deletions sync/shariant/query_json_filter.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import operator
import re
from functools import reduce
from typing import Any

from django.db.models import Q

_SAFE_KEY_RE = re.compile(r'^[a-zA-Z]\w*$')


class QueryJsonFilter:
"""
Expand Down Expand Up @@ -58,6 +61,8 @@ def convert_to_q_w_key(self, key: str, value) -> Q:
return self.convert_to_q(value, operator.__and__)
else:
# key is assumed to be regular value key
if not _SAFE_KEY_RE.match(key):
raise ValueError(f"Filter key contains unsafe characters: {key!r}")
if isinstance(value, list):
handle_none = False
is_not = False
Expand Down
17 changes: 16 additions & 1 deletion sync/shariant/variant_grid_download.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
import time
from typing import Optional

Expand Down Expand Up @@ -31,12 +32,20 @@ def sync(self, sync_run_instance: SyncRunInstance):
'type': 'json',
'build': required_build}

_safe_identifier = re.compile(r'^[\w\-]+$')

exclude_labs = config.get('exclude_labs', None)
if exclude_labs:
for lab in exclude_labs:
if not _safe_identifier.match(str(lab)):
raise ValueError(f"exclude_labs contains unsafe value: {lab!r}")
params['exclude_labs'] = ','.join(exclude_labs)

exclude_orgs = config.get('exclude_orgs', None)
if exclude_orgs:
for org in exclude_orgs:
if not _safe_identifier.match(str(org)):
raise ValueError(f"exclude_orgs contains unsafe value: {org!r}")
params['exclude_orgs'] = ','.join(exclude_orgs)

if not sync_run_instance.full_sync:
Expand Down Expand Up @@ -92,12 +101,18 @@ def shariant_download_to_upload(known_keys: EvidenceKeyMap, record_json: dict) -
return None

if not lab:
_lab_group_name_re = re.compile(r'^[\w\-]+/[\w\-]+$')
if not _lab_group_name_re.match(lab_group_name):
raise ValueError(f"lab_group_name has unexpected format: {lab_group_name!r}")
parts = lab_group_name.split('/')
lab_name = meta.get('lab_name') or parts[1]
if not isinstance(lab_name, str) or len(lab_name) > 255:
raise ValueError(f"lab_name is invalid: {lab_name!r}")
org, _ = Organization.objects.get_or_create(group_name=parts[0], defaults={"name": parts[0]})
australia, _ = Country.objects.get_or_create(name='Australia')
Lab.objects.create(
group_name=lab_group_name,
name=meta.get('lab_name'),
name=lab_name,
organization=org,
city='Unknown',
country=australia,
Expand Down
3 changes: 1 addition & 2 deletions sync/sync_runner.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import json
from abc import ABC, abstractmethod
from dataclasses import dataclass
from datetime import datetime
Expand Down Expand Up @@ -111,4 +110,4 @@ def sync_runner_for_destination(sync_destination: SyncDestination) -> SyncRunner
if factory_requirements.matches(sync_destination):
return factory_requirements.factory()

raise ValueError(f"None of the {len(_sync_runner_registry)} SyncRunners is configured for the config of {sync_destination}: ({json.dumps(sync_destination.config)})")
raise ValueError(f"None of the {len(_sync_runner_registry)} SyncRunners matched destination {sync_destination!r}")
Loading