diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 824d071..f967dd1 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -1,30 +1,53 @@ -name: Python package +name: CI on: push: - branches: [ "main" ] + branches: [main] pull_request: - branches: [ "main" ] + branches: [main] + workflow_dispatch: jobs: - build: + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up Python 3.12 + uses: actions/setup-python@v5 + with: + python-version: "3.12" + + - name: Install linting tools + run: | + python -m pip install --upgrade pip + pip install ruff black + + - name: Check formatting with black + run: black --check src/ tests/ + - name: Lint with ruff + run: ruff check src/ tests/ + + test: runs-on: ubuntu-latest strategy: fail-fast: false matrix: - python-version: ["3.10", "3.11"] + python-version: ["3.10", "3.11", "3.12"] steps: - - uses: actions/checkout@v4 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install -e ".[dev]" - - name: Test with pytest - run: | - pytest + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e ".[dev]" + + - name: Test with pytest + run: pytest diff --git a/src/did/binarydoc.py b/src/did/binarydoc.py index ac2bbf3..f5cb5bb 100644 --- a/src/did/binarydoc.py +++ b/src/did/binarydoc.py @@ -1,5 +1,6 @@ import abc + class BinaryDoc(abc.ABC): def __init__(self, *args, **kwargs): pass @@ -7,7 +8,7 @@ def __init__(self, *args, **kwargs): def __del__(self): try: self.fclose() - except: + except Exception: pass @abc.abstractmethod @@ -36,4 +37,4 @@ def fread(self, count, precision, skip): @abc.abstractmethod def fclose(self): - pass \ No newline at end of file + pass diff --git a/src/did/common.py b/src/did/common.py index decb496..ccf0181 100644 --- a/src/did/common.py +++ b/src/did/common.py @@ -3,12 +3,14 @@ from pathlib import Path import uuid + def toolboxdir(): """ Returns the path to the toolbox directory. """ return os.path.dirname(os.path.abspath(__file__)) + def must_be_writable(folder_path): """ Checks if a folder is writable, and creates it if it doesn't exist. @@ -18,35 +20,41 @@ def must_be_writable(folder_path): os.makedirs(folder_path) except OSError: # Fallback for potential permission errors - folder_path = os.path.join(tempfile.gettempdir(), os.path.basename(folder_path)) + folder_path = os.path.join( + tempfile.gettempdir(), os.path.basename(folder_path) + ) os.makedirs(folder_path, exist_ok=True) test_file = os.path.join(folder_path, f"testfile_{uuid.uuid4()}.txt") try: - with open(test_file, 'w') as f: - f.write('test') + with open(test_file, "w") as f: + f.write("test") except IOError: - raise IOError(f'We do not have write access to the folder at {folder_path}') + raise IOError(f"We do not have write access to the folder at {folder_path}") finally: if os.path.exists(test_file): os.remove(test_file) + class PathConstants: """ Class that defines some global constants for the DID package. """ + PATH = toolboxdir() - DEFPATH = os.path.join(PATH, 'example_schema', 'demo_schema1') + DEFPATH = os.path.join(PATH, "example_schema", "demo_schema1") DEFINITIONS = { - '$DIDDOCUMENT_EX1': os.path.join(DEFPATH, 'database_documents'), - '$DIDSCHEMA_EX1': os.path.join(DEFPATH, 'database_schema'), - '$DIDCONTROLLEDVOCAB_EX1': os.path.join(DEFPATH, 'controlled_vocabulary') + "$DIDDOCUMENT_EX1": os.path.join(DEFPATH, "database_documents"), + "$DIDSCHEMA_EX1": os.path.join(DEFPATH, "database_schema"), + "$DIDCONTROLLEDVOCAB_EX1": os.path.join(DEFPATH, "controlled_vocabulary"), } - _temp_path = os.path.join(tempfile.gettempdir(), 'didtemp') - _file_cache_path = os.path.join(str(Path.home()), 'Documents', 'DID', 'fileCache') - _preferences_path = os.path.join(str(Path.home()), 'Documents', 'DID', 'Preferences') + _temp_path = os.path.join(tempfile.gettempdir(), "didtemp") + _file_cache_path = os.path.join(str(Path.home()), "Documents", "DID", "fileCache") + _preferences_path = os.path.join( + str(Path.home()), "Documents", "DID", "Preferences" + ) @property def temppath(self): @@ -63,14 +71,17 @@ def preferences(self): must_be_writable(self._preferences_path) return self._preferences_path + # Placeholder for fileCache class class FileCache: def __init__(self, path, size): self.path = path self.size = size + _cached_cache = None + def get_cache(): """ Returns a persistent cache object. @@ -79,4 +90,4 @@ def get_cache(): if _cached_cache is None: path_constants = PathConstants() _cached_cache = FileCache(path_constants.filecachepath, 33) - return _cached_cache \ No newline at end of file + return _cached_cache diff --git a/src/did/database.py b/src/did/database.py index 4956e74..ef2be96 100644 --- a/src/did/database.py +++ b/src/did/database.py @@ -1,19 +1,20 @@ import abc + class Database(abc.ABC): - def __init__(self, connection='', **kwargs): + def __init__(self, connection="", **kwargs): self.connection = connection self.version = None - self.current_branch_id = '' + self.current_branch_id = "" self.frozen_branch_ids = [] self.dbid = None self.preferences = {} - self.debug = kwargs.get('debug', False) + self.debug = kwargs.get("debug", False) def __del__(self): try: self.close() - except: + except Exception: pass def open(self): @@ -87,7 +88,7 @@ def _do_get_doc_ids(self, branch_id=None): def _do_add_doc(self, document_obj, branch_id, **kwargs): pass - def get_docs(self, document_ids, branch_id=None, OnMissing='error', **kwargs): + def get_docs(self, document_ids, branch_id=None, OnMissing="error", **kwargs): is_single = False if not isinstance(document_ids, list): document_ids = [document_ids] @@ -100,36 +101,40 @@ def get_docs(self, document_ids, branch_id=None, OnMissing='error', **kwargs): # Checking logic here (inefficient but generic): if branch_id is not None: - branch_doc_ids = self.get_doc_ids(branch_id) - # If branch doesn't exist? get_doc_ids might return empty or raise? - # get_doc_ids calls _do_get_doc_ids. + branch_doc_ids = self.get_doc_ids(branch_id) + # If branch doesn't exist? get_doc_ids might return empty or raise? + # get_doc_ids calls _do_get_doc_ids. docs = [] for doc_id in document_ids: if branch_id is not None: if doc_id not in branch_doc_ids: # Document not in branch - if OnMissing == 'error': - raise ValueError(f"Document {doc_id} not found in branch {branch_id}") - elif OnMissing == 'warn': - print(f"Warning: Document {doc_id} not found in branch {branch_id}") - continue + if OnMissing == "error": + raise ValueError( + f"Document {doc_id} not found in branch {branch_id}" + ) + elif OnMissing == "warn": + print( + f"Warning: Document {doc_id} not found in branch {branch_id}" + ) + continue else: - continue + continue docs.append(self._do_get_doc(doc_id, OnMissing=OnMissing, **kwargs)) - if not docs and OnMissing != 'ignore' and len(document_ids) > 0: - # If filtered out all? - pass + if not docs and OnMissing != "ignore" and len(document_ids) > 0: + # If filtered out all? + pass if is_single: - return docs[0] if docs else None + return docs[0] if docs else None else: - return docs + return docs @abc.abstractmethod - def _do_get_doc(self, document_id, OnMissing='error', **kwargs): + def _do_get_doc(self, document_id, OnMissing="error", **kwargs): pass def remove_docs(self, document_ids, branch_id=None, **kwargs): @@ -179,9 +184,11 @@ def search(self, query_obj, branch_id=None): branch_id = self.current_branch_id doc_ids = self.get_doc_ids(branch_id) - docs = self.get_docs(doc_ids, OnMissing='ignore') - if docs is None: docs = [] - if not isinstance(docs, list): docs = [docs] + docs = self.get_docs(doc_ids, OnMissing="ignore") + if docs is None: + docs = [] + if not isinstance(docs, list): + docs = [docs] search_params = query_obj.to_search_structure() @@ -196,4 +203,4 @@ def search(self, query_obj, branch_id=None): @abc.abstractmethod def do_run_sql_query(self, query_str, **kwargs): - pass \ No newline at end of file + pass diff --git a/src/did/datastructures.py b/src/did/datastructures.py index 2f7338a..72697d4 100644 --- a/src/did/datastructures.py +++ b/src/did/datastructures.py @@ -2,6 +2,7 @@ import numpy as np import re + def cell_to_str(the_list): """ Converts a 1-D list to a string representation. @@ -9,10 +10,11 @@ def cell_to_str(the_list): This function mimics the behavior of the Matlab `cell2str` function. """ if not the_list: - return '[]' + return "[]" return json.dumps(the_list) + def cell_or_item(var, index=0, use_index_for_var=False): """ Returns the ith element of a list, or a single item. @@ -27,13 +29,15 @@ def cell_or_item(var, index=0, use_index_for_var=False): else: return var + def col_vec(x): """ Returns a matrix reshaped as a column vector. This function mimics the behavior of the Matlab `colvec` function. """ - return np.array(x).flatten('F').tolist() + return np.array(x).flatten("F").tolist() + def empty_struct(*field_names): """ @@ -44,6 +48,7 @@ def empty_struct(*field_names): """ return {} + def is_empty(x): """ Checks if a value is empty (None or an empty container). @@ -55,6 +60,7 @@ def is_empty(x): except TypeError: return False + def eq_emp(x, y): """ Compares two values, with special handling for empty values. @@ -71,6 +77,7 @@ def eq_emp(x, y): else: return x == y + def size_eq(x, y): """ Determines if the size of two variables is the same. @@ -79,6 +86,7 @@ def size_eq(x, y): """ return np.array(x).shape == np.array(y).shape + def eq_tot(x, y): """ Returns the logical AND of all the results of an element-wise comparison. @@ -87,6 +95,7 @@ def eq_tot(x, y): """ return np.array_equal(x, y) + def eq_len(x, y): """ Returns True if objects to compare are equal and have the same size. @@ -98,6 +107,7 @@ def eq_len(x, y): else: return False + def eq_unique(in_list): """ Return unique elements of a list. @@ -115,6 +125,7 @@ def eq_unique(in_list): out_list.append(item) return out_list + def is_full_field(a, composite_field_name): """ Checks if a nested field exists in a dictionary. @@ -124,7 +135,7 @@ def is_full_field(a, composite_field_name): if not isinstance(a, dict): return False, None - field_names = composite_field_name.split('.') + field_names = composite_field_name.split(".") current_level = a for field_name in field_names: @@ -135,6 +146,7 @@ def is_full_field(a, composite_field_name): return True, current_level + def struct_partial_match(a, b): """ Checks if dictionary b is a subset of dictionary a. @@ -148,6 +160,7 @@ def struct_partial_match(a, b): return True + def field_search(a, search_struct): """ Searches a dictionary to determine if it matches a search structure. @@ -160,75 +173,75 @@ def field_search(a, search_struct): b = False # Assume no match initially - field = search_struct.get('field', '') - operation = search_struct.get('operation', '') - param1 = search_struct.get('param1') - param2 = search_struct.get('param2') + field = search_struct.get("field", "") + operation = search_struct.get("operation", "") + param1 = search_struct.get("param1") + param2 = search_struct.get("param2") is_there, value = is_full_field(a, field) if field else (True, a) negation = False - if operation.startswith('~'): + if operation.startswith("~"): negation = True operation = operation[1:] op_lower = operation.lower() - if op_lower == 'regexp': + if op_lower == "regexp": if is_there and isinstance(value, str): if re.search(param1, value): b = True - elif op_lower == 'exact_string': + elif op_lower == "exact_string": if is_there: - b = (value == param1) - elif op_lower == 'exact_string_anycase': + b = value == param1 + elif op_lower == "exact_string_anycase": if is_there and isinstance(value, str): - b = (value.lower() == param1.lower()) - elif op_lower == 'contains_string': + b = value.lower() == param1.lower() + elif op_lower == "contains_string": if is_there and isinstance(value, str): - b = (param1 in value) - elif op_lower == 'exact_number': + b = param1 in value + elif op_lower == "exact_number": if is_there: b = eq_len(value, param1) - elif op_lower == 'lessthan': + elif op_lower == "lessthan": if is_there: try: b = np.all(np.array(value) < param1) except (ValueError, TypeError): pass - elif op_lower == 'lessthaneq': + elif op_lower == "lessthaneq": if is_there: try: b = np.all(np.array(value) <= param1) except (ValueError, TypeError): pass - elif op_lower == 'greaterthan': + elif op_lower == "greaterthan": if is_there: try: b = np.all(np.array(value) > param1) except (ValueError, TypeError): pass - elif op_lower == 'greaterthaneq': + elif op_lower == "greaterthaneq": if is_there: try: b = np.all(np.array(value) >= param1) except (ValueError, TypeError): pass - elif op_lower == 'hassize': + elif op_lower == "hassize": if is_there: b = eq_len(np.array(value).shape, param1) - elif op_lower == 'hasmember': + elif op_lower == "hasmember": if is_there: try: b = param1 in value except TypeError: pass - elif op_lower == 'hasfield': + elif op_lower == "hasfield": b = is_there - elif op_lower == 'partial_struct': + elif op_lower == "partial_struct": if is_there: b = struct_partial_match(value, param1) - elif op_lower in ('hasanysubfield_contains_string', 'hasanysubfield_exact_string'): + elif op_lower in ("hasanysubfield_contains_string", "hasanysubfield_exact_string"): if is_there and (isinstance(value, list) or isinstance(value, dict)): items_to_check = value if isinstance(value, list) else [value] param1_list = param1 if isinstance(param1, list) else [param1] @@ -242,39 +255,40 @@ def field_search(a, search_struct): if not sub_is_there: match = False break - if op_lower == 'hasanysubfield_contains_string': + if op_lower == "hasanysubfield_contains_string": if not (isinstance(sub_value, str) and p2 in sub_value): match = False break - elif op_lower == 'hasanysubfield_exact_string': + elif op_lower == "hasanysubfield_exact_string": if not (isinstance(sub_value, str) and sub_value == p2): match = False break if match: b = True break - elif op_lower == 'or': + elif op_lower == "or": if isinstance(param1, dict) and isinstance(param2, dict): b = field_search(a, param1) or field_search(a, param2) - elif op_lower == 'depends_on': + elif op_lower == "depends_on": # param1 = dependency name, param2 = dependency value - if 'depends_on' in a: - for dep in a['depends_on']: - if dep.get('name') == param1: - if dep.get('value') == param2: + if "depends_on" in a: + for dep in a["depends_on"]: + if dep.get("name") == param1: + if dep.get("value") == param2: b = True break - elif op_lower == 'isa': + elif op_lower == "isa": # param1 = class name if param1 in a: b = True - elif 'document_class' in a and a['document_class'].get('class_name') == param1: + elif "document_class" in a and a["document_class"].get("class_name") == param1: b = True else: raise ValueError(f"Unknown search operation: {operation}") return not b if negation else b + def find_closest(arr, v): """ Finds the closest value in an array (using absolute value). @@ -287,6 +301,7 @@ def find_closest(arr, v): idx = (np.abs(arr - v)).argmin() return idx, arr[idx] + def json_encode_nan(obj): """ Encodes a Python object into a JSON object, allowing for NaN/Infinity. @@ -295,6 +310,7 @@ def json_encode_nan(obj): """ return json.dumps(obj, allow_nan=True, indent=4) + def struct_merge(s1, s2, error_if_new_field=False, do_alphabetical=True): """ Merges two dictionaries into a common dictionary. @@ -304,7 +320,9 @@ def struct_merge(s1, s2, error_if_new_field=False, do_alphabetical=True): if error_if_new_field: missing_fields = set(s2.keys()) - set(s1.keys()) if missing_fields: - raise ValueError(f"Some fields of the second dictionary are not in the first: {', '.join(missing_fields)}") + raise ValueError( + f"Some fields of the second dictionary are not in the first: {', '.join(missing_fields)}" + ) s_out = s1.copy() s_out.update(s2) @@ -313,6 +331,8 @@ def struct_merge(s1, s2, error_if_new_field=False, do_alphabetical=True): return {key: s_out[key] for key in sorted(s_out)} else: return s_out + + def table_cross_join(t1, t2): """ Performs a cross join (Cartesian product) of two lists of dictionaries. diff --git a/src/did/db.py b/src/did/db.py index 6135b91..870395a 100644 --- a/src/did/db.py +++ b/src/did/db.py @@ -1,5 +1,6 @@ import pandas as pd + def struct_name_value_search(the_struct, the_name, make_error=True): """ Searches a list of dictionaries with 'name' and 'value' keys. @@ -12,17 +13,20 @@ def struct_name_value_search(the_struct, the_name, make_error=True): for i, item in enumerate(the_struct): if not isinstance(item, dict): raise TypeError("the_struct must be a list of dictionaries.") - if 'name' not in item or 'value' not in item: - raise ValueError("Each dictionary in the_struct must have 'name' and 'value' keys.") + if "name" not in item or "value" not in item: + raise ValueError( + "Each dictionary in the_struct must have 'name' and 'value' keys." + ) - if item['name'] == the_name: - return item['value'], i + if item["name"] == the_name: + return item["value"], i if make_error: raise ValueError(f"No matching entries for {the_name} were found.") else: return None, None + def table_cross_join(table1, table2, rename_conflicting_columns=False): """ Performs a Cartesian product (SQL-style CROSS JOIN) of two pandas DataFrames. @@ -35,15 +39,24 @@ def table_cross_join(table1, table2, rename_conflicting_columns=False): conflicting_names = set(table1.columns) & set(table2.columns) if conflicting_names and not rename_conflicting_columns: - raise ValueError(f"Input DataFrames have conflicting column names: {', '.join(conflicting_names)}. " - "Set 'rename_conflicting_columns' to True to automatically rename them.") + raise ValueError( + f"Input DataFrames have conflicting column names: {', '.join(conflicting_names)}. " + "Set 'rename_conflicting_columns' to True to automatically rename them." + ) if rename_conflicting_columns: - table2 = table2.rename(columns={col: f"{col}_1" if col in conflicting_names else col for col in table2.columns}) + table2 = table2.rename( + columns={ + col: f"{col}_1" if col in conflicting_names else col + for col in table2.columns + } + ) - table1['_cross_join_key'] = 1 - table2['_cross_join_key'] = 1 + table1["_cross_join_key"] = 1 + table2["_cross_join_key"] = 1 - result_table = pd.merge(table1, table2, on='_cross_join_key').drop('_cross_join_key', axis=1) + result_table = pd.merge(table1, table2, on="_cross_join_key").drop( + "_cross_join_key", axis=1 + ) - return result_table \ No newline at end of file + return result_table diff --git a/src/did/document.py b/src/did/document.py index 1baf6cd..ab66a63 100644 --- a/src/did/document.py +++ b/src/did/document.py @@ -5,14 +5,15 @@ from . import ido from .common import PathConstants + class Document: - def __init__(self, document_type='base', **kwargs): + def __init__(self, document_type="base", **kwargs): if isinstance(document_type, dict): self.document_properties = document_type else: self.document_properties = self.read_blank_definition(document_type) - self.document_properties['base']['id'] = ido.IDO.unique_id() - self.document_properties['base']['datestamp'] = str(datetime.utcnow()) + self.document_properties["base"]["id"] = ido.IDO.unique_id() + self.document_properties["base"]["datestamp"] = str(datetime.utcnow()) for key, value in kwargs.items(): # This is a simplified way to set properties. A full implementation @@ -23,13 +24,13 @@ def __init__(self, document_type='base', **kwargs): self._reset_file_info() def id(self): - return self.document_properties.get('base', {}).get('id') + return self.document_properties.get("base", {}).get("id") def set_properties(self, **kwargs): for key, value in kwargs.items(): # This is a simplified way to set properties. A full implementation # would need to handle nested properties like 'base.name'. - path = key.split('.') + path = key.split(".") d = self.document_properties for p in path[:-1]: d = d.setdefault(p, {}) @@ -37,46 +38,45 @@ def set_properties(self, **kwargs): return self def _reset_file_info(self): - if 'files' in self.document_properties: + if "files" in self.document_properties: # Only reset if file_info is missing or we are initializing a new document - if 'file_info' not in self.document_properties['files']: - self.document_properties['files']['file_info'] = datastructures.empty_struct('name', 'locations') + if "file_info" not in self.document_properties["files"]: + self.document_properties["files"]["file_info"] = ( + datastructures.empty_struct("name", "locations") + ) def is_in_file_list(self, filename): - file_info = self.document_properties.get('files', {}).get('file_info', []) + file_info = self.document_properties.get("files", {}).get("file_info", []) if isinstance(file_info, dict) and not file_info: - file_info = [] + file_info = [] for i, info in enumerate(file_info): - if info.get('name') == filename: + if info.get("name") == filename: return True, info, i return False, None, None def add_file(self, filename, location): - if 'files' not in self.document_properties: - self.document_properties['files'] = {'file_info': []} + if "files" not in self.document_properties: + self.document_properties["files"] = {"file_info": []} - files_prop = self.document_properties['files'] - if 'file_info' not in files_prop: - files_prop['file_info'] = [] + files_prop = self.document_properties["files"] + if "file_info" not in files_prop: + files_prop["file_info"] = [] - if isinstance(files_prop['file_info'], dict) and not files_prop['file_info']: - files_prop['file_info'] = [] + if isinstance(files_prop["file_info"], dict) and not files_prop["file_info"]: + files_prop["file_info"] = [] - file_info_list = files_prop['file_info'] + file_info_list = files_prop["file_info"] is_in, _, _ = self.is_in_file_list(filename) if not is_in: - new_info = { - 'name': filename, - 'locations': {'location': location} - } - file_info_list.append(new_info) + new_info = {"name": filename, "locations": {"location": location}} + file_info_list.append(new_info) def remove_file(self, filename): is_in, _, index = self.is_in_file_list(filename) if is_in: - del self.document_properties['files']['file_info'][index] + del self.document_properties["files"]["file_info"][index] @staticmethod def set_schema_path(path): @@ -87,58 +87,58 @@ def read_blank_definition(json_file_location_string): # This is a simplified version of the Matlab function. # It reads a JSON file from a predefined location. - schema_path = os.path.join(PathConstants.DEFPATH, 'database_schema') + schema_path = os.path.join(PathConstants.DEFPATH, "database_schema") filepath = os.path.join(schema_path, f"{json_file_location_string}.schema.json") if os.path.exists(filepath): - with open(filepath, 'r') as f: + with open(filepath, "r") as f: data = json.load(f) # Ensure the 'base' key exists - if 'base' not in data: - data['base'] = {} + if "base" not in data: + data["base"] = {} return data # Fallback for base - if json_file_location_string == 'base': + if json_file_location_string == "base": return { "document_class": { "class_name": "did.document", "property_list_name": "base", "class_version": "1.0", - "superclasses": [] + "superclasses": [], }, - "base": { - "id": "", - "name": "", - "datestamp": "" - } + "base": {"id": "", "name": "", "datestamp": ""}, } - raise FileNotFoundError(f"Could not find definition for {json_file_location_string}") + raise FileNotFoundError( + f"Could not find definition for {json_file_location_string}" + ) def dependency_value(self, dependency_name, error_if_not_found=True): - if 'depends_on' in self.document_properties: - for dep in self.document_properties['depends_on']: - if dep.get('name') == dependency_name: - return dep.get('value') + if "depends_on" in self.document_properties: + for dep in self.document_properties["depends_on"]: + if dep.get("name") == dependency_name: + return dep.get("value") if error_if_not_found: raise ValueError(f"Dependency '{dependency_name}' not found.") return None def set_dependency_value(self, dependency_name, value, error_if_not_found=True): - if 'depends_on' in self.document_properties: - for dep in self.document_properties['depends_on']: - if dep.get('name') == dependency_name: - dep['value'] = value + if "depends_on" in self.document_properties: + for dep in self.document_properties["depends_on"]: + if dep.get("name") == dependency_name: + dep["value"] = value return self if error_if_not_found: raise ValueError(f"Dependency '{dependency_name}' not found.") # If not found and not erroring, add it - if 'depends_on' not in self.document_properties: - self.document_properties['depends_on'] = [] - self.document_properties['depends_on'].append({'name': dependency_name, 'value': value}) + if "depends_on" not in self.document_properties: + self.document_properties["depends_on"] = [] + self.document_properties["depends_on"].append( + {"name": dependency_name, "value": value} + ) return self - # ... other methods like validate, plus, etc. would be implemented here ... \ No newline at end of file + # ... other methods like validate, plus, etc. would be implemented here ... diff --git a/src/did/documentservice.py b/src/did/documentservice.py index a10b830..f5f5703 100644 --- a/src/did/documentservice.py +++ b/src/did/documentservice.py @@ -1,5 +1,6 @@ import abc + class DocumentService(abc.ABC): def __init__(self): pass @@ -16,4 +17,4 @@ def search_query(self): """ Create a search query to find this object as a document. """ - pass \ No newline at end of file + pass diff --git a/src/did/file.py b/src/did/file.py index 8a7a9ff..1d816b1 100644 --- a/src/did/file.py +++ b/src/did/file.py @@ -1,25 +1,55 @@ import os import time import uuid -import struct import json import re from datetime import datetime, timedelta import portalocker from urllib.parse import urlparse + def must_be_valid_permission(value): - VALID_PERMISSIONS = ["r", "w", "a", "r+", "w+", "a+", "rb", "wb", "ab", "r+b", "w+b", "a+b"] + VALID_PERMISSIONS = [ + "r", + "w", + "a", + "r+", + "w+", + "a+", + "rb", + "wb", + "ab", + "r+b", + "w+b", + "a+b", + ] if value not in VALID_PERMISSIONS: - raise ValueError(f"File permission must be one of: {', '.join(VALID_PERMISSIONS)}") + raise ValueError( + f"File permission must be one of: {', '.join(VALID_PERMISSIONS)}" + ) + def must_be_valid_machine_format(value): - VALID_MACHINE_FORMAT = ['n', 'native', 'b', 'ieee-be', 'l', 'ieee-le', 's', 'ieee-be.l64', 'a', 'ieee-le.l64'] + VALID_MACHINE_FORMAT = [ + "n", + "native", + "b", + "ieee-be", + "l", + "ieee-le", + "s", + "ieee-be.l64", + "a", + "ieee-le.l64", + ] if value not in VALID_MACHINE_FORMAT: - raise ValueError(f"Machine format must be one of: {', '.join(VALID_MACHINE_FORMAT)}") + raise ValueError( + f"Machine format must be one of: {', '.join(VALID_MACHINE_FORMAT)}" + ) + class Fileobj: - def __init__(self, fullpathfilename='', permission='r', machineformat='n'): + def __init__(self, fullpathfilename="", permission="r", machineformat="n"): must_be_valid_permission(permission) must_be_valid_machine_format(machineformat) self.fullpathfilename = fullpathfilename @@ -27,7 +57,9 @@ def __init__(self, fullpathfilename='', permission='r', machineformat='n'): self.machineformat = machineformat self.fid = None - def set_properties(self, fullpathfilename=None, permission=None, machineformat=None): + def set_properties( + self, fullpathfilename=None, permission=None, machineformat=None + ): if fullpathfilename: self.fullpathfilename = fullpathfilename if permission: @@ -53,8 +85,8 @@ def fopen(self, permission=None, machineformat=None, filename=None): # Python's open() doesn't have a direct machine format mapping like Matlab. # The 'b' for binary mode is the most relevant part of the permission string. mode = self.permission - if 'b' not in mode: - mode += 'b' # Default to binary for this class + if "b" not in mode: + mode += "b" # Default to binary for this class self.fid = open(self.fullpathfilename, mode) except IOError: @@ -62,7 +94,7 @@ def fopen(self, permission=None, machineformat=None, filename=None): return self def fclose(self): - if getattr(self, 'fid', None): + if getattr(self, "fid", None): self.fid.close() self.fid = None @@ -97,18 +129,18 @@ def fwrite(self, data): def fread(self, count=-1): if self.fid: return self.fid.read(count) - return b'', 0 + return b"", 0 def fgetl(self): if self.fid: line = self.fid.readline() - return line.strip(b'\n') - return '' + return line.strip(b"\n") + return "" def fgets(self, nchar=-1): if self.fid: return self.fid.readline(nchar) - return '' + return "" def ferror(self): # Python's file objects raise exceptions rather than setting error flags. @@ -121,6 +153,7 @@ def fileparts(self): def __del__(self): self.fclose() + def checkout_lock_file(filename, check_loops=30, throw_error=True, expiration=3600): """ Tries to establish control of a lock file. @@ -132,26 +165,28 @@ def checkout_lock_file(filename, check_loops=30, throw_error=True, expiration=36 for _ in range(check_loops): try: - lock_file = open(lock_filename, 'x') + lock_file = open(lock_filename, "x") # Use portalocker for an exclusive lock portalocker.lock(lock_file, portalocker.LOCK_EX | portalocker.LOCK_NB) expiration_time = datetime.utcnow() + timedelta(seconds=expiration) lock_file.write(f"{expiration_time.isoformat()}\n{key}") - lock_file.close() # Close the file handle, but the lock is associated with the file path + lock_file.close() # Close the file handle, but the lock is associated with the file path return lock_file, key except (IOError, portalocker.exceptions.LockException): # File exists or is locked, check for expiration try: - with open(lock_filename, 'r') as f: + with open(lock_filename, "r") as f: lines = f.readlines() if len(lines) >= 1: expiration_time_str = lines[0].strip() expiration_time = datetime.fromisoformat(expiration_time_str) if datetime.utcnow() > expiration_time: # Lock expired, try to remove it - release_lock_file(filename, lines[1].strip() if len(lines)>1 else "") - continue # Retry immediately + release_lock_file( + filename, lines[1].strip() if len(lines) > 1 else "" + ) + continue # Retry immediately except (IOError, ValueError): # Could not read lock file or parse time, wait and retry pass @@ -161,6 +196,7 @@ def checkout_lock_file(filename, check_loops=30, throw_error=True, expiration=36 raise IOError(f"Unable to obtain lock with file {filename}.") return None, None + def release_lock_file(filename, key): """ Releases a lock file with the key. @@ -172,12 +208,12 @@ def release_lock_file(filename, key): return True try: - with open(lock_filename, 'r+') as f: + with open(lock_filename, "r+") as f: portalocker.lock(f, portalocker.LOCK_EX) lines = f.readlines() if len(lines) >= 2 and lines[1].strip() == key: # We have the key, release the lock and delete the file - f.truncate(0) # Clear the file + f.truncate(0) # Clear the file portalocker.unlock(f) os.remove(lock_filename) return True @@ -189,6 +225,7 @@ def release_lock_file(filename, key): # Could not get a lock, or file was removed by another process return not os.path.exists(lock_filename) + class BinaryTable: def __init__(self, f, record_type, record_size, elements_per_column, header_size): self.file = f @@ -197,7 +234,7 @@ def __init__(self, f, record_type, record_size, elements_per_column, header_size self.elements_per_column = elements_per_column self.header_size = header_size self.has_lock = False - self.file.set_properties(machineformat='l') # always little-endian + self.file.set_properties(machineformat="l") # always little-endian if not self.file.fullpathfilename: raise ValueError("A full path file name must be given to the file object.") @@ -216,18 +253,20 @@ def get_size(self): def read_header(self): lock_fid, key = self.get_lock() try: - with open(self.file.fullpathfilename, 'rb') as f: + with open(self.file.fullpathfilename, "rb") as f: return f.read(self.header_size) finally: self.release_lock(lock_fid, key) def write_header(self, header_data): if len(header_data) > self.header_size: - raise ValueError("Header data to write is larger than the header size of the file.") + raise ValueError( + "Header data to write is larger than the header size of the file." + ) lock_fid, key = self.get_lock() try: - with open(self.file.fullpathfilename, 'r+b') as f: + with open(self.file.fullpathfilename, "r+b") as f: f.write(header_data) finally: self.release_lock(lock_fid, key) @@ -258,12 +297,16 @@ def read_row(self, row, col): # careful handling of struct format strings and file seeking. lock_fid, key = self.get_lock() try: - with open(self.file.fullpathfilename, 'rb') as f: + with open(self.file.fullpathfilename, "rb") as f: r, _, _ = self.get_size() if row > r: raise IndexError("Row index out of bounds.") - offset = self.header_size + (row - 1) * self.row_size() + sum(self.record_size[:col-1]) + offset = ( + self.header_size + + (row - 1) * self.row_size() + + sum(self.record_size[: col - 1]) + ) f.seek(offset) # This is a simplified example. The actual implementation would @@ -272,7 +315,7 @@ def read_row(self, row, col): # and handle elements_per_column correctly. # Placeholder for reading data - data = f.read(self.record_size[col-1]) + data = f.read(self.record_size[col - 1]) return data finally: self.release_lock(lock_fid, key) @@ -280,16 +323,23 @@ def read_row(self, row, col): # ... other methods like insert_row, delete_row, write_entry etc. would be implemented here ... # These would involve complex file manipulation (copying to temp files) and are omitted for brevity. + class DumbJsonDB: - def __init__(self, command='none', filename='', dirname='.dumbjsondb', unique_object_id_field='id'): - self.paramfilename = '' + def __init__( + self, + command="none", + filename="", + dirname=".dumbjsondb", + unique_object_id_field="id", + ): + self.paramfilename = "" self.dirname = dirname self.unique_object_id_field = unique_object_id_field - if command == 'new': + if command == "new": self.paramfilename = filename self._write_parameters() - elif command == 'load': + elif command == "load": self._load_parameters(filename) def _document_path(self): @@ -305,10 +355,10 @@ def _write_parameters(self): os.makedirs(path) params = { - 'dirname': self.dirname, - 'unique_object_id_field': self.unique_object_id_field + "dirname": self.dirname, + "unique_object_id_field": self.unique_object_id_field, } - with open(self.paramfilename, 'w') as f: + with open(self.paramfilename, "w") as f: json.dump(params, f, indent=4) doc_path = self._document_path() @@ -317,10 +367,12 @@ def _write_parameters(self): def _load_parameters(self, filename): self.paramfilename = filename - with open(filename, 'r') as f: + with open(filename, "r") as f: params = json.load(f) - self.dirname = params.get('dirname', self.dirname) - self.unique_object_id_field = params.get('unique_object_id_field', self.unique_object_id_field) + self.dirname = params.get("dirname", self.dirname) + self.unique_object_id_field = params.get( + "unique_object_id_field", self.unique_object_id_field + ) @staticmethod def _fix_doc_unique_id(doc_unique_id): @@ -332,7 +384,9 @@ def _fix_doc_unique_id(doc_unique_id): def _uniqueid2filename(doc_unique_id, version=0): doc_unique_id = DumbJsonDB._fix_doc_unique_id(doc_unique_id) # A simple and safe way to create a filename from an ID - safe_id = "".join([c for c in doc_unique_id if c.isalpha() or c.isdigit() or c=='_']).rstrip() + safe_id = "".join( + [c for c in doc_unique_id if c.isalpha() or c.isdigit() or c == "_"] + ).rstrip() return f"Object_id_{safe_id}_v{version:05x}.json" def doc_versions(self, doc_unique_id): @@ -343,12 +397,12 @@ def doc_versions(self, doc_unique_id): # Simplified version search, a more robust implementation would parse filenames more carefully prefix = f"Object_id_{doc_unique_id}_v" for f in os.listdir(path): - if f.startswith(prefix) and f.endswith('.json'): + if f.startswith(prefix) and f.endswith(".json"): try: - version_hex = f[len(prefix):-5] + version_hex = f[len(prefix) : -5] versions.append(int(version_hex, 16)) except ValueError: - continue # filename format not as expected + continue # filename format not as expected return sorted(versions) def read(self, doc_unique_id, version=None): @@ -363,7 +417,7 @@ def read(self, doc_unique_id, version=None): filepath = os.path.join(self._document_path(), filename) if os.path.exists(filepath): - with open(filepath, 'r') as f: + with open(filepath, "r") as f: return json.load(f), version return None, None @@ -381,27 +435,34 @@ def add(self, doc_object, overwrite=1, doc_version=None): if file_exists: if overwrite == 0: - raise IOError(f"Document with id {doc_unique_id} and version {doc_version} already exists.") + raise IOError( + f"Document with id {doc_unique_id} and version {doc_version} already exists." + ) elif overwrite == 2: - doc_version = (max(self.doc_versions(doc_unique_id) or [0]) + 1) + doc_version = max(self.doc_versions(doc_unique_id) or [0]) + 1 filename = self._uniqueid2filename(doc_unique_id, doc_version) filepath = os.path.join(self._document_path(), filename) - with open(filepath, 'w') as f: + with open(filepath, "w") as f: json.dump(doc_object, f, indent=4) # Simplified metadata update - self._update_doc_metadata('Added new version', doc_object, doc_unique_id, doc_version) + self._update_doc_metadata( + "Added new version", doc_object, doc_unique_id, doc_version + ) def _update_doc_metadata(self, operation, document, doc_unique_id, doc_version): # This is a simplified placeholder for the metadata logic. # A full implementation would be more complex. pass + class FileCache: - CACHE_INFO_FILE_NAME = '.fileCacheInfo' + CACHE_INFO_FILE_NAME = ".fileCacheInfo" - def __init__(self, directory_name, file_name_characters=32, max_size=100e9, reduce_size=80e9): + def __init__( + self, directory_name, file_name_characters=32, max_size=100e9, reduce_size=80e9 + ): if not os.path.isdir(directory_name): raise ValueError("directory_name must be an existing directory.") @@ -429,32 +490,35 @@ def set_properties(self, max_size, reduce_size, current_size): self.current_size = current_size info = { - 'fileNameCharacters': self.file_name_characters, - 'maxSize': self.max_size, - 'reduceSize': self.reduce_size, - 'currentSize': self.current_size, - 'files': {} # In Python, we can store the file list in the same JSON + "fileNameCharacters": self.file_name_characters, + "maxSize": self.max_size, + "reduceSize": self.reduce_size, + "currentSize": self.current_size, + "files": {}, # In Python, we can store the file list in the same JSON } info_file = self._info_file_name() - with open(info_file, 'w') as f: + with open(info_file, "w") as f: json.dump(info, f, indent=4) def _load_properties(self): info_file = self._info_file_name() - with open(info_file, 'r') as f: + with open(info_file, "r") as f: info = json.load(f) - self.file_name_characters = info.get('fileNameCharacters', self.file_name_characters) - self.max_size = info.get('maxSize', self.max_size) - self.reduce_size = info.get('reduceSize', self.reduce_size) - self.current_size = info.get('currentSize', self.current_size) + self.file_name_characters = info.get( + "fileNameCharacters", self.file_name_characters + ) + self.max_size = info.get("maxSize", self.max_size) + self.reduce_size = info.get("reduceSize", self.reduce_size) + self.current_size = info.get("currentSize", self.current_size) # The other methods (addFile, removeFile, etc.) would be implemented here. # These are complex and would require careful management of the JSON info file # and file system operations. For the purpose of this port, the core structure # has been established. + def fileid_value(fid_or_fileobj): """ Returns the file identifier from a raw FID or a Fileobj object. @@ -464,24 +528,28 @@ def fileid_value(fid_or_fileobj): else: return fid_or_fileobj + def filesep_conversion(filestring, orig_filesep, new_filesep): """ Converts file separators in a path string. """ return filestring.replace(orig_filesep, new_filesep) + def is_filepath_root(filepath): """ Determines if a file path is at the root or not. """ return os.path.isabs(filepath) + def full_filename(filename): """ Returns the full path file name of a file. """ return os.path.abspath(filename) + def is_url(input_string): """ Checks if a string is a URL. @@ -492,40 +560,51 @@ def is_url(input_string): except ValueError: return False + def read_lines(file_path): """ Reads lines of a file as a list of strings. """ - with open(file_path, 'r') as f: + with open(file_path, "r") as f: lines = f.readlines() # Remove trailing newline characters - return [line.rstrip('\n') for line in lines] + return [line.rstrip("\n") for line in lines] + class ReadOnlyFileobj(Fileobj): - def __init__(self, fullpathfilename='', machineformat='n'): - super().__init__(fullpathfilename=fullpathfilename, permission='r', machineformat=machineformat) + def __init__(self, fullpathfilename="", machineformat="n"): + super().__init__( + fullpathfilename=fullpathfilename, + permission="r", + machineformat=machineformat, + ) def fopen(self, permission=None, machineformat=None, filename=None): - if permission and 'r' not in permission: + if permission and "r" not in permission: raise ValueError("Read-only file must be opened with 'r' permission.") - return super().fopen(permission='r', machineformat=machineformat, filename=filename) + return super().fopen( + permission="r", machineformat=machineformat, filename=filename + ) + def str_to_text(filename, s): """ Writes a string to a text file. """ - with open(filename, 'w') as f: + with open(filename, "w") as f: f.write(s) + def string_to_filestring(s): """ Edits a string so it is suitable for use as part of a filename. """ - return re.sub(r'[^a-zA-Z0-9]', '_', s) + return re.sub(r"[^a-zA-Z0-9]", "_", s) + def text_to_cellstr(filename): """ Reads a text file and imports each line as an entry in a list of strings. This is an alias for read_lines. """ - return read_lines(filename) \ No newline at end of file + return read_lines(filename) diff --git a/src/did/fun.py b/src/did/fun.py index 740ecdb..01060e4 100644 --- a/src/did/fun.py +++ b/src/did/fun.py @@ -1,5 +1,6 @@ import networkx as nx + def docs_to_graph(document_objs): """ Creates a directed graph from a list of Document objects. @@ -12,9 +13,9 @@ def docs_to_graph(document_objs): for doc in document_objs: here_node = doc.id() - dependencies = doc.document_properties.get('depends_on', []) + dependencies = doc.document_properties.get("depends_on", []) for dep in dependencies: - there_node = dep.get('value') + there_node = dep.get("value") if there_node in nodes: # Edge from B to A if A depends on B g.add_edge(there_node, here_node) @@ -23,6 +24,7 @@ def docs_to_graph(document_objs): # The adjacency matrix and node list can be accessed from the graph object. return g + def find_all_dependencies(graph, doc_ids): """ Finds all documents that depend on a given set of documents. @@ -33,13 +35,14 @@ def find_all_dependencies(graph, doc_ids): all_deps.update(nx.descendants(graph, doc_id)) return list(all_deps) + def find_docs_missing_dependencies(db, *dependency_names): """ Finds documents that have dependencies on documents that do not exist. """ from .query import Query - q = Query('depends_on', 'hasfield', '', '') + q = Query("depends_on", "hasfield", "", "") docs_with_deps = db.search(q) missing_deps_docs = [] @@ -47,21 +50,22 @@ def find_docs_missing_dependencies(db, *dependency_names): all_doc_ids = db.all_doc_ids() for doc in docs_with_deps: - dependencies = doc.document_properties.get('depends_on', []) + dependencies = doc.document_properties.get("depends_on", []) for dep in dependencies: - dep_name = dep.get('name') - dep_value = dep.get('value') + dep_name = dep.get("name") + dep_value = dep.get("value") if dependency_names and dep_name not in dependency_names: continue if dep_value and dep_value not in all_doc_ids: missing_deps_docs.append(doc) - break # Move to the next document + break # Move to the next document return missing_deps_docs -def plot_interactive_doc_graph(docs, g, layout='spring'): + +def plot_interactive_doc_graph(docs, g, layout="spring"): """ Plots an interactive document graph. @@ -71,8 +75,8 @@ def plot_interactive_doc_graph(docs, g, layout='spring'): fig, ax = plt.subplots() - if layout == 'layered': - pos = nx.nx_agraph.graphviz_layout(g, prog='dot') + if layout == "layered": + pos = nx.nx_agraph.graphviz_layout(g, prog="dot") else: pos = nx.spring_layout(g) @@ -83,10 +87,10 @@ def on_click(event): return # Find the closest node to the click - min_dist = float('inf') + min_dist = float("inf") closest_node = None for node, (x, y) in pos.items(): - dist = (x - event.xdata)**2 + (y - event.ydata)**2 + dist = (x - event.xdata) ** 2 + (y - event.ydata) ** 2 if dist < min_dist: min_dist = dist closest_node = node @@ -112,5 +116,5 @@ def on_click(event): clicked_node = clicked_doc print("Global variable 'clicked_node' set to clicked document") - fig.canvas.mpl_connect('button_press_event', on_click) - plt.show() \ No newline at end of file + fig.canvas.mpl_connect("button_press_event", on_click) + plt.show() diff --git a/src/did/ido.py b/src/did/ido.py index dcb38a3..80ac666 100644 --- a/src/did/ido.py +++ b/src/did/ido.py @@ -1,6 +1,7 @@ import uuid import re + class IDO: def __init__(self, id_value=None): if id_value and self.is_valid(id_value): @@ -25,5 +26,7 @@ def is_valid(id_value): Checks if a unique ID is valid. """ # A simple regex to check for UUID format - pattern = re.compile(r'^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\Z', re.I) - return bool(pattern.match(str(id_value))) \ No newline at end of file + pattern = re.compile( + r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}\Z", re.I + ) + return bool(pattern.match(str(id_value))) diff --git a/src/did/implementations/binarydoc_matfid.py b/src/did/implementations/binarydoc_matfid.py index 581763e..586c6e6 100644 --- a/src/did/implementations/binarydoc_matfid.py +++ b/src/did/implementations/binarydoc_matfid.py @@ -1,18 +1,19 @@ from ..binarydoc import BinaryDoc from ..file import Fileobj + class BinaryDocMatfid(BinaryDoc, Fileobj): - def __init__(self, key='', doc_unique_id='', **kwargs): + def __init__(self, key="", doc_unique_id="", **kwargs): super().__init__(**kwargs) self.key = key self.doc_unique_id = doc_unique_id # Ensure machine format is little-endian for cross-platform compatibility - self.machineformat = 'l' + self.machineformat = "l" def fclose(self): super().fclose() # Reset properties after closing - self.permission = 'r' + self.permission = "r" # The abstract methods from BinaryDoc would be implemented here, # likely by calling the corresponding methods of the Fileobj superclass. @@ -37,4 +38,4 @@ def fwrite(self, data, precision=None, skip=0): def fread(self, count=-1, precision=None, skip=0): # The precision and skip parameters would need to be handled # using Python's struct module for a full implementation. - return super().fread(count) \ No newline at end of file + return super().fread(count) diff --git a/src/did/implementations/doc2sql.py b/src/did/implementations/doc2sql.py index d8983fd..cb25f89 100644 --- a/src/did/implementations/doc2sql.py +++ b/src/did/implementations/doc2sql.py @@ -1,71 +1,187 @@ +import re + + def get_field(doc_props, fields): if not isinstance(fields, list): fields = [fields] for field in fields: - path = field.split('.') + path = field.split(".") d = doc_props try: for p in path: d = d[p] - if d: + if d is not None and d != "": return d - except (KeyError, TypeError): + except (KeyError, TypeError, IndexError): continue - return '' + return "" + def new_column(name, value, matlab_type=None): if matlab_type is None: matlab_type = type(value).__name__ return { - 'name': name, - 'matlabType': matlab_type, - 'sqlType': sql_type_of(matlab_type), - 'value': value + "name": name, + "matlabType": matlab_type, + "sqlType": sql_type_of(matlab_type), + "value": value, } + def sql_type_of(matlab_type): - type_map = { - 'bool': 'BOOLEAN', - 'str': 'TEXT', - 'int': 'INTEGER', - 'float': 'REAL' - } - return type_map.get(matlab_type, 'BLOB') + type_map = {"bool": "BOOLEAN", "str": "TEXT", "int": "INTEGER", "float": "REAL"} + return type_map.get(matlab_type, "BLOB") + + +def _get_class_name(doc_props): + """Extract class name from document properties, supporting both DID-python and NDI formats.""" + # DID-python schema format + if "classname" in doc_props: + return doc_props["classname"] + # NDI / MATLAB format + return get_field(doc_props, ["document_class.class_name", "ndi_document.type"]) + + +def _get_superclass_str(doc_props): + """Extract superclass string matching MATLAB's doc2sql format. + + MATLAB produces comma-space separated, sorted unique superclass names. + For MATLAB-style definitions like "$PATH/base.json", strip path and extension. + For DID-python style ["base", "demoA"], use directly. + """ + # DID-python schema format: top-level 'superclasses' list of strings + if "superclasses" in doc_props and isinstance(doc_props["superclasses"], list): + superclasses = doc_props["superclasses"] + if not superclasses: + return "" + names = [] + for sc in superclasses: + if isinstance(sc, str): + names.append(sc) + elif isinstance(sc, dict) and "definition" in sc: + # MATLAB-style: extract name from definition path + defn = sc["definition"] + name = re.sub(r".+/", "", defn) + name = re.sub(r"\..+$", "", name) + names.append(name) + names = sorted(set(names)) + return ", ".join(names) + + # NDI / MATLAB format: document_class.superclasses + superclasses = get_field(doc_props, ["document_class.superclasses"]) + if isinstance(superclasses, list): + names = [] + for sc in superclasses: + if isinstance(sc, dict) and "definition" in sc: + defn = sc["definition"] + name = re.sub(r".+/", "", defn) + name = re.sub(r"\..+$", "", name) + names.append(name) + elif isinstance(sc, str): + names.append(sc) + names = sorted(set(names)) + return ", ".join(names) + + return "" + + +def _serialize_depends_on(doc_props): + """Serialize depends_on matching MATLAB's format: 'name,value;name,value;'""" + depends_on = doc_props.get("depends_on", []) + if not depends_on or not isinstance(depends_on, list): + return "" + + parts = [] + for dep in depends_on: + if isinstance(dep, dict): + name = str(dep.get("name", "")) + value = str(dep.get("value", "")) + if name and value: + parts.append(f"{name},{value};") + + return "".join(parts) + + +def _flatten_dict(d, prefix=""): + """Flatten a nested dict using ___ separator for nested keys (matching MATLAB's getMetaTableFrom).""" + items = [] + for key, value in d.items(): + col_name = f"{prefix}___{key}" if prefix else key + if isinstance(value, dict): + items.extend(_flatten_dict(value, col_name)) + elif isinstance(value, list): + # Convert lists to string representation + items.append((col_name, str(value))) + else: + items.append((col_name, value)) + return items + + +def _get_meta_table_from(group_name, doc_id, field_value): + """Create a meta-table from a field group, matching MATLAB's getMetaTableFrom.""" + table = {"name": group_name, "columns": [new_column("doc_id", doc_id)]} + + if isinstance(field_value, dict): + for col_name, col_value in _flatten_dict(field_value): + table["columns"].append(new_column(col_name, col_value)) + + return table + + +# Fields to skip when building per-group meta-tables +_SKIP_FIELDS = { + "classname", + "document_class", + "superclasses", + "depends_on", + "files", + "file", +} + def doc_to_sql(doc): + """Convert a document to SQL meta-tables matching MATLAB's did.implementations.doc2sql. + + Returns a list of meta-table dicts. The first is always 'meta' with standard + columns (doc_id, class, superclass, datestamp, creation, deletion, depends_on). + Subsequent tables correspond to top-level field groups (e.g., 'base', 'element'). + + Each meta-table has: + 'name': the group name + 'columns': list of column dicts with 'name' and 'value' + """ doc_props = doc.document_properties - sql_meta_data = { - 'name': 'meta', - 'columns': [] - } + # Build the 'meta' table + meta = {"name": "meta", "columns": []} + + id_val = get_field(doc_props, ["base.id", "ndi_document.id"]) + meta["columns"].append(new_column("doc_id", id_val)) - id_val = get_field(doc_props, ['base.id', 'ndi_document.id']) - sql_meta_data['columns'].append(new_column('doc_id', id_val)) + class_name = _get_class_name(doc_props) + meta["columns"].append(new_column("class", class_name)) - class_name = get_field(doc_props, ['document_class.class_name', 'ndi_document.type']) - sql_meta_data['columns'].append(new_column('class', class_name)) + superclass = _get_superclass_str(doc_props) + meta["columns"].append(new_column("superclass", superclass)) - # Simplified superclass and dependency handling - # A full implementation would parse these structures more carefully. - sql_meta_data['columns'].append(new_column('superclass', '')) - sql_meta_data['columns'].append(new_column('depends_on', '')) + datestamp = get_field(doc_props, ["base.datestamp", "ndi_document.datestamp"]) + meta["columns"].append(new_column("datestamp", datestamp)) - datestamp = get_field(doc_props, ['base.datestamp', 'ndi_document.datestamp']) - sql_meta_data['columns'].append(new_column('datestamp', datestamp)) + meta["columns"].append(new_column("creation", "")) + meta["columns"].append(new_column("deletion", "")) - # Process other fields - other_meta_data = [] + depends_on_str = _serialize_depends_on(doc_props) + meta["columns"].append(new_column("depends_on", depends_on_str)) + + meta_tables = [meta] + + # Build per-group tables for all other top-level dict fields for field_name, field_value in doc_props.items(): - if field_name not in ['base', 'document_class', 'depends_on', 'files']: - meta_table = { - 'name': field_name, - 'columns': [new_column('doc_id', id_val)] - } - if isinstance(field_value, dict): - for sub_field_name, sub_field_value in field_value.items(): - meta_table['columns'].append(new_column(sub_field_name, sub_field_value)) - other_meta_data.append(meta_table) - - return [sql_meta_data] + other_meta_data \ No newline at end of file + if field_name in _SKIP_FIELDS: + continue + if isinstance(field_value, dict): + table = _get_meta_table_from(field_name, id_val, field_value) + meta_tables.append(table) + + return meta_tables diff --git a/src/did/implementations/sqlitedb.py b/src/did/implementations/sqlitedb.py index 65030a0..7429bb2 100644 --- a/src/did/implementations/sqlitedb.py +++ b/src/did/implementations/sqlitedb.py @@ -1,11 +1,31 @@ import sqlite3 import os +import re as _re from ..database import Database + +def _sqlite_regexp(pattern, string): + """SQLite regexp function implementation.""" + if string is None: + return None + try: + return 1 if _re.search(pattern, str(string)) else None + except _re.error: + return None + + +def _sql_escape(value): + """Escape single quotes for SQL string literals.""" + if value is None: + return "" + return str(value).replace("'", "''") + + class SQLiteDB(Database): def __init__(self, filename): super().__init__(connection=filename) self.dbid = None + self._fields_cache = {} # (class, field_name) -> field_idx self._open_db() def _open_db(self): @@ -29,7 +49,7 @@ def _create_db_tables(self): cursor = self.dbid.cursor() # Create branches table - cursor.execute(''' + cursor.execute(""" CREATE TABLE branches ( branch_id TEXT NOT NULL UNIQUE, parent_id TEXT, @@ -37,10 +57,10 @@ def _create_db_tables(self): FOREIGN KEY(parent_id) REFERENCES branches(branch_id), PRIMARY KEY(branch_id) ) - ''') + """) # Create docs table - cursor.execute(''' + cursor.execute(""" CREATE TABLE docs ( doc_id TEXT NOT NULL UNIQUE, doc_idx INTEGER NOT NULL UNIQUE, @@ -48,10 +68,10 @@ def _create_db_tables(self): timestamp REAL, PRIMARY KEY(doc_idx AUTOINCREMENT) ) - ''') + """) # Create branch_docs table - cursor.execute(''' + cursor.execute(""" CREATE TABLE branch_docs ( branch_id TEXT NOT NULL, doc_idx INTEGER NOT NULL, @@ -60,10 +80,10 @@ def _create_db_tables(self): FOREIGN KEY(doc_idx) REFERENCES docs(doc_idx), PRIMARY KEY(branch_id, doc_idx) ) - ''') + """) # Create fields table - cursor.execute(''' + cursor.execute(""" CREATE TABLE fields ( class TEXT NOT NULL, field_name TEXT NOT NULL UNIQUE, @@ -71,10 +91,10 @@ class TEXT NOT NULL, field_idx INTEGER NOT NULL UNIQUE, PRIMARY KEY(field_idx AUTOINCREMENT) ) - ''') + """) # Create doc_data table - cursor.execute(''' + cursor.execute(""" CREATE TABLE doc_data ( doc_idx INTEGER NOT NULL, field_idx INTEGER NOT NULL, @@ -82,10 +102,10 @@ class TEXT NOT NULL, FOREIGN KEY(doc_idx) REFERENCES docs(doc_idx), FOREIGN KEY(field_idx) REFERENCES fields(field_idx) ) - ''') + """) # Create files table - cursor.execute(''' + cursor.execute(""" CREATE TABLE files ( doc_idx INTEGER NOT NULL, filename TEXT NOT NULL, @@ -97,7 +117,7 @@ class TEXT NOT NULL, FOREIGN KEY(doc_idx) REFERENCES docs(doc_idx), PRIMARY KEY(doc_idx, filename, uid) ) - ''') + """) self.dbid.commit() @@ -110,68 +130,133 @@ def do_run_sql_query(self, query_str, params=()): # For brevity, I will start with a few key methods. def _do_get_branch_ids(self): - rows = self.do_run_sql_query('SELECT DISTINCT branch_id FROM branches') - return [row['branch_id'] for row in rows] + rows = self.do_run_sql_query("SELECT DISTINCT branch_id FROM branches") + return [row["branch_id"] for row in rows] def _do_add_branch(self, branch_id, parent_branch_id): import time + cursor = self.dbid.cursor() # Handle empty string parent as NULL - if parent_branch_id == '': + if parent_branch_id == "": parent_branch_id = None # Add the new branch - cursor.execute('INSERT INTO branches (branch_id, parent_id, timestamp) VALUES (?, ?, ?)', - (branch_id, parent_branch_id, time.time())) + cursor.execute( + "INSERT INTO branches (branch_id, parent_id, timestamp) VALUES (?, ?, ?)", + (branch_id, parent_branch_id, time.time()), + ) # Copy docs from parent branch if parent_branch_id: - cursor.execute('SELECT doc_idx FROM branch_docs WHERE branch_id = ?', (parent_branch_id,)) - doc_indices = [row['doc_idx'] for row in cursor.fetchall()] + cursor.execute( + "SELECT doc_idx FROM branch_docs WHERE branch_id = ?", + (parent_branch_id,), + ) + doc_indices = [row["doc_idx"] for row in cursor.fetchall()] for doc_idx in doc_indices: - cursor.execute('INSERT OR IGNORE INTO branch_docs (branch_id, doc_idx, timestamp) VALUES (?, ?, ?)', - (branch_id, doc_idx, time.time())) + cursor.execute( + "INSERT OR IGNORE INTO branch_docs (branch_id, doc_idx, timestamp) VALUES (?, ?, ?)", + (branch_id, doc_idx, time.time()), + ) self.dbid.commit() def _do_get_doc_ids(self, branch_id=None): if branch_id: - rows = self.do_run_sql_query('SELECT d.doc_id FROM docs d JOIN branch_docs bd ON d.doc_idx = bd.doc_idx WHERE bd.branch_id = ?', (branch_id,)) + rows = self.do_run_sql_query( + "SELECT d.doc_id FROM docs d JOIN branch_docs bd ON d.doc_idx = bd.doc_idx WHERE bd.branch_id = ?", + (branch_id,), + ) + else: + rows = self.do_run_sql_query("SELECT doc_id FROM docs") + return [row["doc_id"] for row in rows] + + def _get_field_idx(self, cursor, group_name, field_name): + """Look up or create a field_idx for the given group and field. + + The field_name in the fields table uses the format '{group}.{field}', + matching MATLAB's convention. Triple-underscores in column names from + doc2sql are converted to dots. + """ + # Convert ___ back to . for the stored field_name + full_field_name = f"{group_name}.{field_name}".replace("___", ".") + json_name = full_field_name.replace(".", "___") + + cache_key = (group_name, full_field_name) + if cache_key in self._fields_cache: + return self._fields_cache[cache_key] + + cursor.execute( + "SELECT field_idx FROM fields WHERE field_name = ?", (full_field_name,) + ) + row = cursor.fetchone() + if row: + field_idx = row["field_idx"] else: - rows = self.do_run_sql_query('SELECT doc_id FROM docs') - return [row['doc_id'] for row in rows] + cursor.execute( + "INSERT INTO fields (class, field_name, json_name, field_idx) VALUES (?, ?, ?, NULL)", + (group_name, full_field_name, json_name), + ) + field_idx = cursor.lastrowid + + self._fields_cache[cache_key] = field_idx + return field_idx + + def _populate_doc_data(self, cursor, doc_idx, document_obj): + """Flatten document via doc2sql and insert into fields/doc_data tables.""" + from .doc2sql import doc_to_sql + + meta_tables = doc_to_sql(document_obj) + rows = [] + + for table in meta_tables: + group_name = table["name"] + for col in table["columns"]: + col_name = col["name"] + if col_name == "doc_id": + continue # skip doc_id columns + field_idx = self._get_field_idx(cursor, group_name, col_name) + value = col["value"] + if value is None: + value = "" + rows.append((doc_idx, field_idx, str(value))) + + if rows: + cursor.executemany( + "INSERT INTO doc_data (doc_idx, field_idx, value) VALUES (?, ?, ?)", + rows, + ) def _do_add_doc(self, document_obj, branch_id, **kwargs): - # This is a complex method that involves multiple steps: - # 1. Check if the document already exists. - # 2. If not, add it to the 'docs' table and get its 'doc_idx'. - # 3. Add the document's fields to the 'doc_data' table. - # 4. Add the document reference to the 'branch_docs' table. - # This is a simplified placeholder. import json import time - from ..document import Document doc_id = document_obj.id() cursor = self.dbid.cursor() - cursor.execute('SELECT doc_idx FROM docs WHERE doc_id = ?', (doc_id,)) + cursor.execute("SELECT doc_idx FROM docs WHERE doc_id = ?", (doc_id,)) row = cursor.fetchone() if row: - doc_idx = row['doc_idx'] + doc_idx = row["doc_idx"] else: json_code = json.dumps(document_obj.document_properties) - cursor.execute('INSERT INTO docs (doc_id, json_code, timestamp) VALUES (?, ?, ?)', - (doc_id, json_code, time.time())) + cursor.execute( + "INSERT INTO docs (doc_id, json_code, timestamp) VALUES (?, ?, ?)", + (doc_id, json_code, time.time()), + ) doc_idx = cursor.lastrowid - # Simplified field insertion - # A full implementation would parse the document and insert into doc_data + + # Populate fields and doc_data tables (matching MATLAB's doc2sql behavior) + self._populate_doc_data(cursor, doc_idx, document_obj) try: - cursor.execute('INSERT INTO branch_docs (branch_id, doc_idx, timestamp) VALUES (?, ?, ?)', - (branch_id, doc_idx, time.time())) + cursor.execute( + "INSERT INTO branch_docs (branch_id, doc_idx, timestamp) VALUES (?, ?, ?)", + (branch_id, doc_idx, time.time()), + ) self.dbid.commit() except sqlite3.IntegrityError as e: if "FOREIGN KEY" in str(e): @@ -179,22 +264,212 @@ def _do_add_doc(self, document_obj, branch_id, **kwargs): # Ignore other integrity errors (duplicates) pass - def _do_get_doc(self, document_id, OnMissing='error', **kwargs): + # --- SQL-based search (matching MATLAB's database.m) --- + + def search(self, query_obj, branch_id=None): + """Search using SQL queries against doc_data, matching MATLAB's behavior.""" + if branch_id is None: + branch_id = self.current_branch_id + + search_params = query_obj.to_search_structure() + + # Register regexp function for sqlite + self.dbid.create_function("regexp", 2, _sqlite_regexp) + + doc_ids = self._search_doc_ids(search_params, branch_id) + return doc_ids + + def _search_doc_ids(self, search_struct, branch_id): + """Recursively search for doc_ids matching the search structure. + + Matches MATLAB's search_doc_ids: struct arrays are AND'd, 'or' operations + are unioned, leaf queries go through SQL. + """ + if isinstance(search_struct, list): + if not search_struct: + return [] + # AND: intersect results from all sub-queries + result = None + for item in search_struct: + ids = self._search_doc_ids(item, branch_id) + if result is None: + result = set(ids) + else: + result &= set(ids) + return list(result) if result else [] + + if not isinstance(search_struct, dict): + return [] + + operation = search_struct.get("operation", "") + negation = False + op = operation + if op.startswith("~"): + negation = True + op = op[1:] + op_lower = op.lower() + + if op_lower == "or": + # OR: union results from param1 and param2 + p1 = search_struct.get("param1") + p2 = search_struct.get("param2") + ids1 = self._search_doc_ids(p1, branch_id) if p1 else [] + ids2 = self._search_doc_ids(p2, branch_id) if p2 else [] + result = list(set(ids1) | set(ids2)) + if negation: + all_ids = set(self._do_get_doc_ids(branch_id)) + result = list(all_ids - set(result)) + return result + + # Leaf query: build SQL and execute + sql_clause = self._query_struct_to_sql_str(search_struct) + if sql_clause is None: + # Fallback to brute-force for unsupported operations + return self._brute_force_search(search_struct, branch_id) + + query = ( + "SELECT DISTINCT docs.doc_id FROM docs, branch_docs, doc_data, fields " + "WHERE docs.doc_idx = doc_data.doc_idx " + "AND docs.doc_idx = branch_docs.doc_idx " + "AND branch_docs.branch_id = ? " + "AND fields.field_idx = doc_data.field_idx " + f"AND {sql_clause}" + ) + + try: + rows = self.do_run_sql_query(query, (branch_id,)) + matched = [row["doc_id"] for row in rows] + except sqlite3.OperationalError: + # Fallback on SQL error + return self._brute_force_search(search_struct, branch_id) + + if negation: + all_ids = set(self._do_get_doc_ids(branch_id)) + return list(all_ids - set(matched)) + + return matched + + def _query_struct_to_sql_str(self, search_struct): + """Convert a single query struct to a SQL WHERE clause fragment. + + Returns None if the operation is not supported in SQL. + Matches MATLAB's query_struct_to_sql_str. + """ + field = search_struct.get("field", "") + operation = search_struct.get("operation", "") + param1 = search_struct.get("param1") + param2 = search_struct.get("param2") + + # Strip negation prefix (handled by caller) + op = operation + if op.startswith("~"): + op = op[1:] + op_lower = op.lower() + + if op_lower == "exact_string": + return f"fields.field_name = '{field}' AND doc_data.value = '{_sql_escape(param1)}'" + + elif op_lower == "exact_string_anycase": + return f"fields.field_name = '{field}' AND LOWER(doc_data.value) = LOWER('{_sql_escape(param1)}')" + + elif op_lower == "contains_string": + return f"fields.field_name = '{field}' AND doc_data.value LIKE '%{_sql_escape(param1)}%'" + + elif op_lower == "regexp": + return f"fields.field_name = '{field}' AND regexp('{_sql_escape(param1)}', doc_data.value) IS NOT NULL" + + elif op_lower == "exact_number": + return f"fields.field_name = '{field}' AND CAST(doc_data.value AS REAL) = {float(param1)}" + + elif op_lower == "lessthan": + return f"fields.field_name = '{field}' AND CAST(doc_data.value AS REAL) < {float(param1)}" + + elif op_lower == "lessthaneq": + return f"fields.field_name = '{field}' AND CAST(doc_data.value AS REAL) <= {float(param1)}" + + elif op_lower == "greaterthan": + return f"fields.field_name = '{field}' AND CAST(doc_data.value AS REAL) > {float(param1)}" + + elif op_lower == "greaterthaneq": + return f"fields.field_name = '{field}' AND CAST(doc_data.value AS REAL) >= {float(param1)}" + + elif op_lower == "hasfield": + return ( + f"(fields.field_name = '{field}' OR fields.field_name LIKE '{field}.%')" + ) + + elif op_lower == "isa": + # isa: match on meta.class (exact) OR meta.superclass (contains) + classname = _sql_escape(param1) + return ( + f"((fields.field_name = 'meta.class' AND doc_data.value = '{classname}') " + f"OR (fields.field_name = 'meta.superclass' AND " + f"regexp('(^|, ){classname}(,|$)', doc_data.value) IS NOT NULL))" + ) + + elif op_lower == "depends_on": + # depends_on: search meta.depends_on using LIKE '%name,value;%' + name = _sql_escape(param1) + value = _sql_escape(param2) + if name == "*": + return f"fields.field_name = 'meta.depends_on' AND doc_data.value LIKE '%,{value};%'" + return f"fields.field_name = 'meta.depends_on' AND doc_data.value LIKE '%{name},{value};%'" + + elif op_lower == "hasanysubfield_exact_string": + # Used by resolved depends_on - fall back to brute force + return None + + elif op_lower == "hasanysubfield_contains_string": + # Used by resolved isa - fall back to brute force + return None + + elif op_lower == "hasmember": + # hasmember on a stored value - fall back to brute force + return None + + elif op_lower == "hassize": + return None + + elif op_lower == "partial_struct": + return None + + return None + + def _brute_force_search(self, search_struct, branch_id): + """Fall back to brute-force field_search for unsupported SQL operations.""" + from ..datastructures import field_search + + doc_ids = self._do_get_doc_ids(branch_id) + docs = self.get_docs(doc_ids, OnMissing="ignore") + if docs is None: + docs = [] + if not isinstance(docs, list): + docs = [docs] + + matched = [] + for doc in docs: + if doc and field_search(doc.document_properties, search_struct): + matched.append(doc.id()) + return matched + + def _do_get_doc(self, document_id, OnMissing="error", **kwargs): from ..document import Document import json - row = self.do_run_sql_query('SELECT json_code FROM docs WHERE doc_id = ?', (document_id,)) + row = self.do_run_sql_query( + "SELECT json_code FROM docs WHERE doc_id = ?", (document_id,) + ) if row: - json_code = row[0]['json_code'] + json_code = row[0]["json_code"] doc_struct = json.loads(json_code) return Document(doc_struct) else: # Handle missing document - if OnMissing == 'warn': + if OnMissing == "warn": print(f"Warning: Document id '{document_id}' not found.") return None - elif OnMissing == 'ignore': + elif OnMissing == "ignore": return None else: raise ValueError(f"Document id '{document_id}' not found.") @@ -208,14 +483,14 @@ def open_doc(self, doc_id, filename): is_in, info, _ = doc.is_in_file_list(filename) if is_in: - location = info['locations']['location'] + location = info["locations"]["location"] - # Rebase path if it's relative, assuming it's relative to the DB location - if not os.path.isabs(location): - db_dir = os.path.dirname(os.path.abspath(self.connection)) - location = os.path.join(db_dir, location) + # Rebase path if it's relative, assuming it's relative to the DB location + if not os.path.isabs(location): + db_dir = os.path.dirname(os.path.abspath(self.connection)) + location = os.path.join(db_dir, location) - return ReadOnlyFileobj(location) + return ReadOnlyFileobj(location) raise FileNotFoundError(f"File {filename} not found in document {doc_id}.") @@ -223,47 +498,56 @@ def _do_remove_doc(self, document_id, branch_id, **kwargs): cursor = self.dbid.cursor() # Check if branch exists - cursor.execute('SELECT 1 FROM branches WHERE branch_id = ?', (branch_id,)) + cursor.execute("SELECT 1 FROM branches WHERE branch_id = ?", (branch_id,)) if not cursor.fetchone(): - raise ValueError(f"Branch '{branch_id}' does not exist.") + raise ValueError(f"Branch '{branch_id}' does not exist.") # Get doc_idx from doc_id - cursor.execute('SELECT doc_idx FROM docs WHERE doc_id = ?', (document_id,)) + cursor.execute("SELECT doc_idx FROM docs WHERE doc_id = ?", (document_id,)) row = cursor.fetchone() if row: - doc_idx = row['doc_idx'] + doc_idx = row["doc_idx"] # Remove from branch_docs - cursor.execute('DELETE FROM branch_docs WHERE branch_id = ? AND doc_idx = ?', (branch_id, doc_idx)) + cursor.execute( + "DELETE FROM branch_docs WHERE branch_id = ? AND doc_idx = ?", + (branch_id, doc_idx), + ) # Optional: remove from docs and doc_data if no other branches reference it - cursor.execute('SELECT COUNT(*) FROM branch_docs WHERE doc_idx = ?', (doc_idx,)) + cursor.execute( + "SELECT COUNT(*) FROM branch_docs WHERE doc_idx = ?", (doc_idx,) + ) count = cursor.fetchone()[0] if count == 0: - cursor.execute('DELETE FROM doc_data WHERE doc_idx = ?', (doc_idx,)) - cursor.execute('DELETE FROM docs WHERE doc_idx = ?', (doc_idx,)) + cursor.execute("DELETE FROM doc_data WHERE doc_idx = ?", (doc_idx,)) + cursor.execute("DELETE FROM docs WHERE doc_idx = ?", (doc_idx,)) self.dbid.commit() else: # Handle missing document - on_missing = kwargs.get('OnMissing', 'error').lower() - if on_missing == 'warn': + on_missing = kwargs.get("OnMissing", "error").lower() + if on_missing == "warn": print(f"Warning: Document id '{document_id}' not found for removal.") - elif on_missing != 'ignore': + elif on_missing != "ignore": raise ValueError(f"Document id '{document_id}' not found for removal.") def _do_delete_branch(self, branch_id): cursor = self.dbid.cursor() - cursor.execute('DELETE FROM branch_docs WHERE branch_id = ?', (branch_id,)) - cursor.execute('DELETE FROM branches WHERE branch_id = ?', (branch_id,)) + cursor.execute("DELETE FROM branch_docs WHERE branch_id = ?", (branch_id,)) + cursor.execute("DELETE FROM branches WHERE branch_id = ?", (branch_id,)) self.dbid.commit() def _do_get_sub_branches(self, branch_id): - rows = self.do_run_sql_query('SELECT branch_id FROM branches WHERE parent_id = ?', (branch_id,)) - return [row['branch_id'] for row in rows] + rows = self.do_run_sql_query( + "SELECT branch_id FROM branches WHERE parent_id = ?", (branch_id,) + ) + return [row["branch_id"] for row in rows] def _do_get_branch_parent(self, branch_id): - row = self.do_run_sql_query('SELECT parent_id FROM branches WHERE branch_id = ?', (branch_id,)) + row = self.do_run_sql_query( + "SELECT parent_id FROM branches WHERE branch_id = ?", (branch_id,) + ) if row: - return row[0]['parent_id'] - return None \ No newline at end of file + return row[0]["parent_id"] + return None diff --git a/src/did/query.py b/src/did/query.py index aa66dd7..76e59ba 100644 --- a/src/did/query.py +++ b/src/did/query.py @@ -1,14 +1,28 @@ class Query: VALID_OPS = { - 'regexp', 'exact_string', 'exact_string_anycase', 'contains_string', 'exact_number', - 'lessthan', 'lessthaneq', 'greaterthan', 'greaterthaneq', 'hassize', 'hasmember', - 'hasfield', 'partial_struct', 'hasanysubfield_contains_string', 'hasanysubfield_exact_string', - 'or', 'depends_on', 'isa' + "regexp", + "exact_string", + "exact_string_anycase", + "contains_string", + "exact_number", + "lessthan", + "lessthaneq", + "greaterthan", + "greaterthaneq", + "hassize", + "hasmember", + "hasfield", + "partial_struct", + "hasanysubfield_contains_string", + "hasanysubfield_exact_string", + "or", + "depends_on", + "isa", } def __init__(self, field=None, op=None, param1=None, param2=None): if op: - check_op = op[1:] if op.startswith('~') else op + check_op = op[1:] if op.startswith("~") else op if check_op.lower() not in self.VALID_OPS: raise ValueError(f"Invalid operator: {op}") @@ -17,12 +31,14 @@ def __init__(self, field=None, op=None, param1=None, param2=None): elif isinstance(field, list): self.search_structure = self.search_cell_array_to_search_structure(field) elif field is not None: - self.search_structure = self._create_search_structure(field, op, param1, param2) + self.search_structure = self._create_search_structure( + field, op, param1, param2 + ) else: self.search_structure = [] def _create_search_structure(self, field, op, param1, param2): - return [{'field': field, 'operation': op, 'param1': param1, 'param2': param2}] + return [{"field": field, "operation": op, "param1": param1, "param2": param2}] @staticmethod def search_cell_array_to_search_structure(search_cell_array): @@ -30,9 +46,11 @@ def search_cell_array_to_search_structure(search_cell_array): search_structure = [] for i in range(0, len(search_cell_array), 2): field = search_cell_array[i] - value = search_cell_array[i+1] - op = 'exact_number' if isinstance(value, (int, float)) else 'regexp' - search_structure.append({'field': field, 'operation': op, 'param1': value, 'param2': None}) + value = search_cell_array[i + 1] + op = "exact_number" if isinstance(value, (int, float)) else "regexp" + search_structure.append( + {"field": field, "operation": op, "param1": value, "param2": None} + ) return search_structure def __and__(self, other): @@ -47,9 +65,79 @@ def __or__(self, other): if not isinstance(other, Query): return NotImplemented - return Query(field='', op='or', param1=self.search_structure, param2=other.search_structure) + return Query( + field="", + op="or", + param1=self.search_structure, + param2=other.search_structure, + ) def to_search_structure(self): - # A full implementation would recursively resolve 'isa', 'depends_on', etc. - # This is a simplified version for now. - return self.search_structure \ No newline at end of file + """Resolve high-level operations (isa, depends_on) into lower-level ones. + + This matches MATLAB's query.to_searchstructure which converts: + - 'isa' -> OR(hasanysubfield_contains_string on superclasses, exact_string on class) + - 'depends_on' -> hasanysubfield_exact_string on depends_on + """ + return self._resolve_search_structure(self.search_structure) + + @staticmethod + def _resolve_search_structure(ss): + """Recursively resolve a search structure.""" + if isinstance(ss, list): + return [Query._resolve_single(item) for item in ss] + return Query._resolve_single(ss) + + @staticmethod + def _resolve_single(item): + if not isinstance(item, dict): + return item + + operation = item.get("operation", "") + negation = False + op = operation + if op.startswith("~"): + negation = True + op = op[1:] + op_lower = op.lower() + + if op_lower == "isa": + # We keep 'isa' unresolved for field_search (which handles it directly) + # and let the SQL search handle it via its own isa logic. + # This avoids breaking the brute-force field_search path which + # works with both DID-python and MATLAB document formats. + return item + + elif op_lower == "depends_on": + name_param = item.get("param1", "") + value_param = item.get("param2", "") + param1_list = ["name", "value"] + param2_list = [name_param, value_param] + # Wildcard: if name is '*', only match on value + if name_param == "*": + param1_list = ["value"] + param2_list = [value_param] + resolved = { + "field": "depends_on", + "operation": ( + "~hasanysubfield_exact_string" + if negation + else "hasanysubfield_exact_string" + ), + "param1": param1_list, + "param2": param2_list, + } + return resolved + + elif op_lower == "or": + # Recursively resolve OR sub-structures + p1 = item.get("param1") + p2 = item.get("param2") + return { + "field": item.get("field", ""), + "operation": operation, + "param1": Query._resolve_search_structure(p1) if p1 else p1, + "param2": Query._resolve_search_structure(p2) if p2 else p2, + } + + return item diff --git a/tests/__init__.py b/tests/__init__.py index 2e70fad..e3eebaa 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1 +1 @@ -# This file makes the 'tests' directory a Python package. \ No newline at end of file +# This file makes the 'tests' directory a Python package. diff --git a/tests/helpers.py b/tests/helpers.py index f6c0cbe..b93835d 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -3,6 +3,7 @@ import networkx as nx from did.document import Document + def make_doc_tree(rates): """ Makes a 'tree' of documents to add to a database. @@ -21,7 +22,7 @@ def make_doc_tree(rates): counter = 1 for _ in range(num_a): - doc = Document('demoA', **{'demoA.value': counter}) + doc = Document("demoA", **{"demoA.value": counter}) docs.append(doc) node_names.append(str(counter)) ids_a.append(doc.id()) @@ -29,7 +30,7 @@ def make_doc_tree(rates): counter += 1 for _ in range(num_b): - doc = Document('demoB', **{'demoB.value': counter, 'demoA.value': counter}) + doc = Document("demoB", **{"demoB.value": counter, "demoA.value": counter}) docs.append(doc) node_names.append(str(counter)) ids_b.append(doc.id()) @@ -38,7 +39,7 @@ def make_doc_tree(rates): c_count = 0 for _ in range(num_c): - doc = Document('demoC', **{'demoC.value': counter}) + doc = Document("demoC", **{"demoC.value": counter}) docs.append(doc) node_names.append(str(counter)) ids_c.append(doc.id()) @@ -49,30 +50,31 @@ def make_doc_tree(rates): dep_c = random.randint(0, c_count - 1) if c_count > 0 else -1 if dep_a >= 0: - doc.set_dependency_value('item1', ids_a[dep_a], error_if_not_found=False) + doc.set_dependency_value("item1", ids_a[dep_a], error_if_not_found=False) g.add_edge(ids_a[dep_a], doc.id()) if dep_b >= 0: - doc.set_dependency_value('item2', ids_b[dep_b], error_if_not_found=False) + doc.set_dependency_value("item2", ids_b[dep_b], error_if_not_found=False) g.add_edge(ids_b[dep_b], doc.id()) if dep_c >= 0: - doc.set_dependency_value('item3', ids_c[dep_c], error_if_not_found=False) + doc.set_dependency_value("item3", ids_c[dep_c], error_if_not_found=False) g.add_edge(ids_c[dep_c], doc.id()) - if 'depends_on' not in doc.document_properties: - doc.document_properties['depends_on'] = [] + if "depends_on" not in doc.document_properties: + doc.document_properties["depends_on"] = [] counter += 1 c_count += 1 return g, node_names, docs -def verify_db_document_structure(db, g, expected_docs, OnMissing='error'): + +def verify_db_document_structure(db, g, expected_docs, OnMissing="error"): """ Verifies that the documents in the database match the expected documents and their relationships. """ from did.datastructures import eq_len - fieldset = ['demoA', 'demoB', 'demoC'] + fieldset = ["demoA", "demoB", "demoC"] for doc in expected_docs: id_here = doc.id() @@ -88,13 +90,19 @@ def verify_db_document_structure(db, g, expected_docs, OnMissing='error'): field1 = doc.document_properties[field] field2 = doc_here.document_properties[field] if not eq_len(field1, field2): - return False, f"Field {field} of document {id_here} did not match." + return ( + False, + f"Field {field} of document {id_here} did not match.", + ) else: - return False, f"Field {field} not found in document {id_here} from the database." + return ( + False, + f"Field {field} not found in document {id_here} from the database.", + ) errors = [] - fieldset = ['demoA', 'demoB', 'demoC'] + fieldset = ["demoA", "demoB", "demoC"] for doc in expected_docs: id_here = doc.id() @@ -104,7 +112,7 @@ def verify_db_document_structure(db, g, expected_docs, OnMissing='error'): doc_here = None if doc_here is None: - if OnMissing.lower() == 'ignore': + if OnMissing.lower() == "ignore": continue else: errors.append(f"Document with id {id_here} not found in the database.") @@ -117,12 +125,17 @@ def verify_db_document_structure(db, g, expected_docs, OnMissing='error'): field1 = doc.document_properties[field] field2 = doc_here.document_properties[field] if not eq_len(field1, field2): - errors.append(f"Field {field} of document {id_here} did not match.") + errors.append( + f"Field {field} of document {id_here} did not match." + ) else: - errors.append(f"Field {field} not found in document {id_here} from the database.") + errors.append( + f"Field {field} not found in document {id_here} from the database." + ) return not errors, "\n".join(errors) + def number_to_alpha_label(n): """ Converts a number to an alphabetic label. @@ -133,6 +146,7 @@ def number_to_alpha_label(n): s = chr(65 + remainder) + s return s.lower() + def name_tree(g, initial_node_name_prefix="", node_start=None): """ Names the nodes in a tree structure. @@ -163,6 +177,7 @@ def name_tree(g, initial_node_name_prefix="", node_start=None): return node_names + def make_tree(n_initial, children_rate, children_rate_decay, max_depth): """ Constructs a random tree structure. @@ -179,7 +194,12 @@ def make_tree(n_initial, children_rate, children_rate_decay, max_depth): num_children_here = np.random.poisson(children_rate) if num_children_here > 0: - sub_g = make_tree(num_children_here, children_rate * children_rate_decay, children_rate_decay, max_depth - 1) + sub_g = make_tree( + num_children_here, + children_rate * children_rate_decay, + children_rate_decay, + max_depth - 1, + ) # Renumber nodes in sub_g to avoid conflicts mapping = {n: n + current_node_count for n in sub_g.nodes()} @@ -194,6 +214,7 @@ def make_tree(n_initial, children_rate, children_rate_decay, max_depth): return g + def add_branch_nodes(db, starting_db_branch_id, g, node_names, node_start_index=None): """ Adds a tree of nodes to a DID database. @@ -218,6 +239,7 @@ def add_branch_nodes(db, starting_db_branch_id, g, node_names, node_start_index= for child in children: q.append((node_name, child)) + def verify_branch_nodes(db, g, node_names): """ Verifies all branch nodes in a digraph are in the database. @@ -226,6 +248,7 @@ def verify_branch_nodes(db, g, node_names): missing = set(node_names.values()) - set(all_branches) return not missing, list(missing) + def verify_branch_node_structure(db, g, node_names): """ Verifies branch structure in a digraph are in the database. @@ -234,8 +257,12 @@ def verify_branch_node_structure(db, g, node_names): for node, node_name in node_names.items(): # Get expected parents and children from the graph - expected_parents = [node_names[p] for p in g.predecessors(node)] if node in g else [] - expected_children = [node_names[c] for c in g.successors(node)] if node in g else [] + expected_parents = ( + [node_names[p] for p in g.predecessors(node)] if node in g else [] + ) + expected_children = ( + [node_names[c] for c in g.successors(node)] if node in g else [] + ) # Get actual parents and children from the database actual_parent = db.get_branch_parent(node_name) @@ -248,15 +275,22 @@ def verify_branch_node_structure(db, g, node_names): # Compare parents if set(expected_parents) != set(actual_parent): - return False, f"Error in parent of {node_name}. Expected {expected_parents}, got {actual_parent}" + return ( + False, + f"Error in parent of {node_name}. Expected {expected_parents}, got {actual_parent}", + ) # Compare children if set(expected_children) != set(actual_children): - return False, f"Error in sub_branch of {node_name}. Expected {expected_children}, got {actual_children}" + return ( + False, + f"Error in sub_branch of {node_name}. Expected {expected_children}, got {actual_children}", + ) db.set_branch(current_branch) return True, "" + def delete_random_branch(db, g, node_names): """ Deletes a random branch from a database and digraph. @@ -274,17 +308,19 @@ def delete_random_branch(db, g, node_names): return g, node_names + def get_demo_type(doc): """ Finds the first demo class for a given document. """ - if 'demoA' in doc.document_properties: - return 'demoA' - elif 'demoB' in doc.document_properties: - return 'demoB' - elif 'demoC' in doc.document_properties: - return 'demoC' - return '' + if "demoA" in doc.document_properties: + return "demoA" + elif "demoB" in doc.document_properties: + return "demoB" + elif "demoC" in doc.document_properties: + return "demoC" + return "" + def apply_did_query(docs, q): """ diff --git a/tests/test_branch.py b/tests/test_branch.py index 1930ce1..fcb7487 100644 --- a/tests/test_branch.py +++ b/tests/test_branch.py @@ -1,11 +1,18 @@ import unittest import os -import networkx as nx from did.implementations.sqlitedb import SQLiteDB -from tests.helpers import make_tree, name_tree, add_branch_nodes, verify_branch_nodes, verify_branch_node_structure, delete_random_branch +from tests.helpers import ( + make_tree, + name_tree, + add_branch_nodes, + verify_branch_nodes, + verify_branch_node_structure, + delete_random_branch, +) + class TestBranch(unittest.TestCase): - DB_FILENAME = 'test_db_branch.sqlite' + DB_FILENAME = "test_db_branch.sqlite" def setUp(self): # Create a temporary working directory to run tests in @@ -19,7 +26,7 @@ def tearDown(self): os.remove(self.DB_FILENAME) def test_add_and_verify_branch_nodes(self): - add_branch_nodes(self.db, '', self.g, self.node_names) + add_branch_nodes(self.db, "", self.g, self.node_names) b, missing = verify_branch_nodes(self.db, self.g, self.node_names) self.assertTrue(b, f"Some branches are missing: {missing}") @@ -28,15 +35,18 @@ def test_add_and_verify_branch_nodes(self): self.assertTrue(b, msg) def test_random_branch_deletions(self): - add_branch_nodes(self.db, '', self.g, self.node_names) + add_branch_nodes(self.db, "", self.g, self.node_names) num_random_deletions = min(35, len(self.g.nodes())) for _ in range(num_random_deletions): - self.g, self.node_names = delete_random_branch(self.db, self.g, self.node_names) + self.g, self.node_names = delete_random_branch( + self.db, self.g, self.node_names + ) b, msg = verify_branch_node_structure(self.db, self.g, self.node_names) self.assertTrue(b, f"After random deletions: {msg}") -if __name__ == '__main__': - unittest.main() \ No newline at end of file + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_datastructures.py b/tests/test_datastructures.py index dd33a2c..6893ab0 100644 --- a/tests/test_datastructures.py +++ b/tests/test_datastructures.py @@ -1,21 +1,39 @@ import unittest -from did.datastructures import * +from did.datastructures import ( + cell_or_item, + cell_to_str, + col_vec, + empty_struct, + eq_emp, + eq_len, + eq_tot, + eq_unique, + field_search, + find_closest, + is_empty, + is_full_field, + json_encode_nan, + size_eq, + struct_merge, + struct_partial_match, +) + class TestDataStructures(unittest.TestCase): def test_cell_to_str(self): - self.assertEqual(cell_to_str(['a', 'b', 'c']), '["a", "b", "c"]') - self.assertEqual(cell_to_str([]), '[]') + self.assertEqual(cell_to_str(["a", "b", "c"]), '["a", "b", "c"]') + self.assertEqual(cell_to_str([]), "[]") def test_cell_or_item(self): - self.assertEqual(cell_or_item(['a', 'b', 'c'], 1), 'b') - self.assertEqual(cell_or_item('a'), 'a') + self.assertEqual(cell_or_item(["a", "b", "c"], 1), "b") + self.assertEqual(cell_or_item("a"), "a") def test_col_vec(self): self.assertEqual(col_vec([1, 2, 3]), [1, 2, 3]) self.assertEqual(col_vec([[1, 2], [3, 4]]), [1, 3, 2, 4]) def test_empty_struct(self): - self.assertEqual(empty_struct('a', 'b'), {}) + self.assertEqual(empty_struct("a", "b"), {}) def test_is_empty(self): self.assertTrue(is_empty(None)) @@ -46,22 +64,24 @@ def test_eq_unique(self): self.assertEqual(eq_unique([[1, 2], [1, 2], [1, 3]]), [[1, 2], [1, 3]]) def test_is_full_field(self): - d = {'a': {'b': {'c': 1}}} - self.assertTrue(is_full_field(d, 'a.b.c')[0]) - self.assertFalse(is_full_field(d, 'a.b.d')[0]) + d = {"a": {"b": {"c": 1}}} + self.assertTrue(is_full_field(d, "a.b.c")[0]) + self.assertFalse(is_full_field(d, "a.b.d")[0]) def test_struct_partial_match(self): - a = {'a': 1, 'b': 2} - b = {'a': 1} - c = {'a': 2} + a = {"a": 1, "b": 2} + b = {"a": 1} + c = {"a": 2} self.assertTrue(struct_partial_match(a, b)) self.assertFalse(struct_partial_match(a, c)) def test_field_search(self): - a = {'a': 1, 'b': 'hello'} - search_struct = [{'field': 'a', 'operation': 'exact_number', 'param1': 1}] + a = {"a": 1, "b": "hello"} + search_struct = [{"field": "a", "operation": "exact_number", "param1": 1}] self.assertTrue(field_search(a, search_struct)) - search_struct = [{'field': 'b', 'operation': 'contains_string', 'param1': 'ell'}] + search_struct = [ + {"field": "b", "operation": "contains_string", "param1": "ell"} + ] self.assertTrue(field_search(a, search_struct)) def test_find_closest(self): @@ -70,15 +90,16 @@ def test_find_closest(self): self.assertEqual(find_closest(arr, 14), (3, 15)) def test_json_encode_nan(self): - d = {'a': 1, 'b': float('nan')} - self.assertIn('NaN', json_encode_nan(d)) + d = {"a": 1, "b": float("nan")} + self.assertIn("NaN", json_encode_nan(d)) def test_struct_merge(self): - s1 = {'a': 1, 'b': 2} - s2 = {'b': 3, 'c': 4} - self.assertEqual(struct_merge(s1, s2), {'a': 1, 'b': 3, 'c': 4}) + s1 = {"a": 1, "b": 2} + s2 = {"b": 3, "c": 4} + self.assertEqual(struct_merge(s1, s2), {"a": 1, "b": 3, "c": 4}) with self.assertRaises(ValueError): struct_merge(s1, s2, error_if_new_field=True) -if __name__ == '__main__': - unittest.main() \ No newline at end of file + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_db_queries.py b/tests/test_db_queries.py index 62b83ee..9bc6c12 100644 --- a/tests/test_db_queries.py +++ b/tests/test_db_queries.py @@ -3,10 +3,15 @@ import random from did.implementations.sqlitedb import SQLiteDB from did.query import Query -from tests.helpers import make_doc_tree, verify_db_document_structure, get_demo_type, apply_did_query +from tests.helpers import ( + make_doc_tree, + get_demo_type, + apply_did_query, +) + class TestDbQueries(unittest.TestCase): - DB_FILENAME = 'test_db_queries.sqlite' + DB_FILENAME = "test_db_queries.sqlite" db = None docs = None @@ -17,10 +22,10 @@ def setUpClass(cls): os.remove(cls.DB_FILENAME) cls.db = SQLiteDB(cls.DB_FILENAME) - cls.db.add_branch('a') + cls.db.add_branch("a") _, _, cls.docs = make_doc_tree([10, 10, 10]) for doc in cls.docs: - cls.db._do_add_doc(doc, 'a') + cls.db._do_add_doc(doc, "a") @classmethod def tearDownClass(cls): @@ -29,7 +34,7 @@ def tearDownClass(cls): os.remove(cls.DB_FILENAME) def _test_query(self, q): - ids_actual = self.db.search(q, branch_id='a') + ids_actual = self.db.search(q, branch_id="a") self.assertIsInstance(ids_actual, list) ids_expected, _ = apply_did_query(self.docs, q) @@ -41,19 +46,21 @@ def get_random_document_id(self): def test_exact_string(self): id_chosen = self.get_random_document_id() - q = Query('base.id', 'exact_string', id_chosen) + q = Query("base.id", "exact_string", id_chosen) self._test_query(q) def test_not_exact_string(self): id_chosen = self.get_random_document_id() - q = Query('base.id', '~exact_string', id_chosen) + q = Query("base.id", "~exact_string", id_chosen) self._test_query(q) def test_and(self): doc_id = self.docs[0].id() demo_type = get_demo_type(self.docs[0]) field_name = f"{demo_type}.value" - q = Query('base.id', 'exact_string', doc_id) & Query(field_name, 'exact_number', 1) + q = Query("base.id", "exact_string", doc_id) & Query( + field_name, "exact_number", 1 + ) self._test_query(q) def test_or(self): @@ -63,69 +70,75 @@ def test_or(self): demo_type2 = get_demo_type(doc2) field_name1 = f"{demo_type1}.value" field_name2 = f"{demo_type2}.value" - q = Query(field_name1, 'exact_number', 1) | Query(field_name2, 'exact_number', 2) + q = Query(field_name1, "exact_number", 1) | Query( + field_name2, "exact_number", 2 + ) self._test_query(q) def test_contains_string(self): id_chosen = self.get_random_document_id() sub_string = id_chosen[10:12] - q = Query('base.id', 'contains_string', sub_string) + q = Query("base.id", "contains_string", sub_string) self._test_query(q) def test_do_not_contains_string(self): id_chosen = self.get_random_document_id() sub_string = id_chosen[10:12] - q = Query('base.id', '~contains_string', sub_string) + q = Query("base.id", "~contains_string", sub_string) self._test_query(q) def test_less_than(self): number_chosen = random.randint(1, 100) - q = Query('demoA.value', 'lessthan', number_chosen) + q = Query("demoA.value", "lessthan", number_chosen) self._test_query(q) def test_less_than_equal(self): number_chosen = 48 - q = Query('demoA.value', 'lessthaneq', number_chosen) + q = Query("demoA.value", "lessthaneq", number_chosen) self._test_query(q) def test_do_greater_than(self): number_chosen = 1 - q = Query('demoA.value', 'greaterthan', number_chosen) + q = Query("demoA.value", "greaterthan", number_chosen) self._test_query(q) def test_do_greater_than_equal(self): number_chosen = 1 - q = Query('demoA.value', 'greaterthaneq', number_chosen) + q = Query("demoA.value", "greaterthaneq", number_chosen) self._test_query(q) def test_has_field(self): - q = Query('demoA.value', 'hasfield') + q = Query("demoA.value", "hasfield") self._test_query(q) def test_has_member(self): - q = Query('demoA.value', 'hasmember', 1) + q = Query("demoA.value", "hasmember", 1) self._test_query(q) def test_depends_on(self): # Find a doc with a dependency doc_with_dep = None for doc in self.docs: - if 'depends_on' in doc.document_properties and doc.document_properties['depends_on']: + if ( + "depends_on" in doc.document_properties + and doc.document_properties["depends_on"] + ): doc_with_dep = doc break if doc_with_dep: - dep = doc_with_dep.document_properties['depends_on'][0] - q = Query('', 'depends_on', dep['name'], dep['value']) + dep = doc_with_dep.document_properties["depends_on"][0] + q = Query("", "depends_on", dep["name"], dep["value"]) self._test_query(q) def test_do_is_a(self): - q = Query('', 'isa', 'demoB') + q = Query("", "isa", "demoB") self._test_query(q) def test_do_reg_exp(self): - q = Query('base.datestamp', 'regexp', r'\d{4}-\d{2}-\d{2}') + q = Query("base.datestamp", "regexp", r"\d{4}-\d{2}-\d{2}") self._test_query(q) -if __name__ == '__main__': - unittest.main() \ No newline at end of file + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_document.py b/tests/test_document.py index 58561ad..e3c283e 100644 --- a/tests/test_document.py +++ b/tests/test_document.py @@ -2,52 +2,81 @@ import os from did.document import Document + class TestDocument(unittest.TestCase): def setUp(self): # Set the schema path for the document class - self.schema_path = os.path.join(os.path.dirname(__file__), '..', 'src', 'did', 'example_schema', 'demo_schema1') + self.schema_path = os.path.join( + os.path.dirname(__file__), + "..", + "src", + "did", + "example_schema", + "demo_schema1", + ) Document.set_schema_path(self.schema_path) def test_dependency_management(self): # Create a document of type 'demoC', which has 'depends_on' fields - doc = Document('demoC') + doc = Document("demoC") # Verify 'depends_on' field exists - self.assertIn('depends_on', doc.document_properties, - "The 'depends_on' field should exist for 'demoC' document type.") + self.assertIn( + "depends_on", + doc.document_properties, + "The 'depends_on' field should exist for 'demoC' document type.", + ) # Test setting a new dependency value - doc.set_dependency_value('item1', 'new_value') - retrieved_value = doc.dependency_value('item1') - self.assertEqual(retrieved_value, 'new_value', - "Failed to set and retrieve a new dependency value.") + doc.set_dependency_value("item1", "new_value") + retrieved_value = doc.dependency_value("item1") + self.assertEqual( + retrieved_value, + "new_value", + "Failed to set and retrieve a new dependency value.", + ) # Test updating an existing dependency value - doc.set_dependency_value('item1', 'updated_value') - retrieved_value = doc.dependency_value('item1') - self.assertEqual(retrieved_value, 'updated_value', - "Failed to update an existing dependency value.") + doc.set_dependency_value("item1", "updated_value") + retrieved_value = doc.dependency_value("item1") + self.assertEqual( + retrieved_value, + "updated_value", + "Failed to update an existing dependency value.", + ) def test_file_management(self): # Create a document of type 'demoFile', which is defined to handle files - doc = Document('demoFile') + doc = Document("demoFile") # Add a file and verify it was added - doc.add_file('filename1.ext', '/path/to/file1.txt') - is_in, _, fI_index = doc.is_in_file_list('filename1.ext') + doc.add_file("filename1.ext", "/path/to/file1.txt") + is_in, _, fI_index = doc.is_in_file_list("filename1.ext") self.assertTrue(is_in, "File 'filename1.ext' should be in the file list.") - self.assertIsNotNone(fI_index, "File info index should not be empty after adding a file.") + self.assertIsNotNone( + fI_index, "File info index should not be empty after adding a file." + ) # Verify the location of the added file - self.assertEqual(doc.document_properties['files']['file_info'][fI_index]['locations']['location'], '/path/to/file1.txt', - "The location of the added file is incorrect.") + self.assertEqual( + doc.document_properties["files"]["file_info"][fI_index]["locations"][ + "location" + ], + "/path/to/file1.txt", + "The location of the added file is incorrect.", + ) # Remove the file and verify it was removed - doc.remove_file('filename1.ext') - is_in_after_removal, _, fI_index_after_removal = doc.is_in_file_list('filename1.ext') + doc.remove_file("filename1.ext") + is_in_after_removal, _, fI_index_after_removal = doc.is_in_file_list( + "filename1.ext" + ) # After removal, searching for the file info should yield an empty index self.assertFalse(is_in_after_removal) - self.assertIsNone(fI_index_after_removal, "File info should be empty after removing the file.") + self.assertIsNone( + fI_index_after_removal, "File info should be empty after removing the file." + ) + -if __name__ == '__main__': - unittest.main() \ No newline at end of file +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_file_document.py b/tests/test_file_document.py index df3a3aa..33ad92e 100644 --- a/tests/test_file_document.py +++ b/tests/test_file_document.py @@ -5,8 +5,9 @@ from did.implementations.sqlitedb import SQLiteDB from tests.helpers import make_doc_tree + class TestFileDocument(unittest.TestCase): - DB_FILENAME = 'test_file_document.sqlite' + DB_FILENAME = "test_file_document.sqlite" def setUp(self): # Create a temporary database for testing @@ -14,10 +15,10 @@ def setUp(self): if os.path.exists(self.db_path): os.remove(self.db_path) self.db = SQLiteDB(self.db_path) - self.db.add_branch('a') + self.db.add_branch("a") _, _, self.docs = make_doc_tree([1, 1, 1]) for doc in self.docs: - self.db._do_add_doc(doc, 'a') + self.db._do_add_doc(doc, "a") def tearDown(self): # Clean up the database file @@ -27,27 +28,35 @@ def tearDown(self): def test_add_and_open_file(self): # Create a document of type 'demoFile' - schema_path = os.path.join(os.path.dirname(__file__), '..', 'src', 'did', 'example_schema', 'demo_schema1') + schema_path = os.path.join( + os.path.dirname(__file__), + "..", + "src", + "did", + "example_schema", + "demo_schema1", + ) Document.set_schema_path(schema_path) - doc = Document('demoFile') + doc = Document("demoFile") # Create a dummy file to add - dummy_file_path = 'dummy_file.txt' - with open(dummy_file_path, 'w') as f: - f.write('This is a test file.') + dummy_file_path = "dummy_file.txt" + with open(dummy_file_path, "w") as f: + f.write("This is a test file.") # Add the file to the document - doc.add_file('test_file.txt', dummy_file_path) + doc.add_file("test_file.txt", dummy_file_path) # Add the document to the database - self.db._do_add_doc(doc, 'a') + self.db._do_add_doc(doc, "a") # Open the file from the document - file_obj = self.db.open_doc(doc.id(), 'test_file.txt') + file_obj = self.db.open_doc(doc.id(), "test_file.txt") self.assertIsInstance(file_obj, ReadOnlyFileobj) # Clean up the dummy file os.remove(dummy_file_path) -if __name__ == '__main__': - unittest.main() \ No newline at end of file + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_fileobj.py b/tests/test_fileobj.py index c6e6494..2bb0c80 100644 --- a/tests/test_fileobj.py +++ b/tests/test_fileobj.py @@ -1,22 +1,24 @@ import unittest from did.file import Fileobj + class TestFileobj(unittest.TestCase): def test_constructor(self): # Test creating a fileobj the_file_obj = Fileobj() self.assertIsNone(the_file_obj.fid) - self.assertEqual(the_file_obj.permission, 'r') - self.assertEqual(the_file_obj.machineformat, 'n') - self.assertEqual(the_file_obj.fullpathfilename, '') + self.assertEqual(the_file_obj.permission, "r") + self.assertEqual(the_file_obj.machineformat, "n") + self.assertEqual(the_file_obj.fullpathfilename, "") def test_custom_file_handler_error(self): # Test that passing customFileHandler to the constructor throws an error def my_handler(x): - print(f'File operation: {x}') + print(f"File operation: {x}") with self.assertRaises(TypeError): Fileobj(customFileHandler=my_handler) -if __name__ == '__main__': - unittest.main() \ No newline at end of file + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_invalid_modification.py b/tests/test_invalid_modification.py index fe92ed7..51172ab 100644 --- a/tests/test_invalid_modification.py +++ b/tests/test_invalid_modification.py @@ -1,11 +1,11 @@ import unittest import os -from did.document import Document from did.implementations.sqlitedb import SQLiteDB from tests.helpers import make_doc_tree + class TestInvalidModification(unittest.TestCase): - DB_FILENAME = 'test_invalid_modification.sqlite' + DB_FILENAME = "test_invalid_modification.sqlite" def setUp(self): # Create a temporary database for testing @@ -13,14 +13,14 @@ def setUp(self): if os.path.exists(self.db_path): os.remove(self.db_path) self.db = SQLiteDB(self.db_path) - self.db.add_branch('a') + self.db.add_branch("a") # Ensure at least one document is created _, _, self.docs = make_doc_tree([1, 1, 1]) while not self.docs: _, _, self.docs = make_doc_tree([1, 1, 1]) for doc in self.docs: - self.db._do_add_doc(doc, 'a') + self.db._do_add_doc(doc, "a") def tearDown(self): # Clean up the database file @@ -32,25 +32,28 @@ def test_add_doc_twice(self): # Adding the same document twice should not raise an error doc = self.docs[0] try: - self.db._do_add_doc(doc, 'a') + self.db._do_add_doc(doc, "a") except Exception as e: self.fail(f"Adding the same document twice raised an exception: {e}") def test_add_doc_to_nonexistent_branch(self): doc = self.docs[0] with self.assertRaises(ValueError): - self.db._do_add_doc(doc, 'nonexistent_branch') + self.db._do_add_doc(doc, "nonexistent_branch") def test_remove_doc_from_nonexistent_branch(self): doc_id = self.docs[0].id() with self.assertRaises(ValueError): - self.db._do_remove_doc(doc_id, 'nonexistent_branch') + self.db._do_remove_doc(doc_id, "nonexistent_branch") def test_get_doc_from_nonexistent_branch(self): doc_id = self.docs[0].id() # This should not raise an error, but should return None - result = self.db.get_docs(doc_id, branch_id='nonexistent_branch', OnMissing='ignore') + result = self.db.get_docs( + doc_id, branch_id="nonexistent_branch", OnMissing="ignore" + ) self.assertIsNone(result) -if __name__ == '__main__': - unittest.main() \ No newline at end of file + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_path_rebasing.py b/tests/test_path_rebasing.py index ad38bda..459c2a1 100644 --- a/tests/test_path_rebasing.py +++ b/tests/test_path_rebasing.py @@ -4,11 +4,11 @@ from did.document import Document from did.file import ReadOnlyFileobj from did.implementations.sqlitedb import SQLiteDB -from tests.helpers import make_doc_tree + class TestPathRebasing(unittest.TestCase): - DB_FILENAME = 'test_path_rebasing.sqlite' - SUBDIR = 'subdir_for_test' + DB_FILENAME = "test_path_rebasing.sqlite" + SUBDIR = "subdir_for_test" def setUp(self): # Create a subdirectory for the database @@ -18,12 +18,12 @@ def setUp(self): self.db_path = os.path.join(self.SUBDIR, self.DB_FILENAME) self.db = SQLiteDB(self.db_path) - self.db.add_branch('a') + self.db.add_branch("a") # Create a file inside the subdirectory (relative to DB) - self.relative_filename = 'relative_file.txt' - self.file_content = 'Content of relative file.' - with open(os.path.join(self.SUBDIR, self.relative_filename), 'w') as f: + self.relative_filename = "relative_file.txt" + self.file_content = "Content of relative file." + with open(os.path.join(self.SUBDIR, self.relative_filename), "w") as f: f.write(self.file_content) def tearDown(self): @@ -33,17 +33,24 @@ def tearDown(self): def test_open_doc_relative_path(self): # Create a document and add the file using a relative path - schema_path = os.path.join(os.path.dirname(__file__), '..', 'src', 'did', 'example_schema', 'demo_schema1') + schema_path = os.path.join( + os.path.dirname(__file__), + "..", + "src", + "did", + "example_schema", + "demo_schema1", + ) Document.set_schema_path(schema_path) - doc = Document('demoFile') + doc = Document("demoFile") # Add file with just the filename, which implies it's in the same dir as DB - doc.add_file('my_file', self.relative_filename) - self.db._do_add_doc(doc, 'a') + doc.add_file("my_file", self.relative_filename) + self.db._do_add_doc(doc, "a") # Open the document and retrieve the file # The logic should resolve self.relative_filename relative to self.db_path - file_obj = self.db.open_doc(doc.id(), 'my_file') + file_obj = self.db.open_doc(doc.id(), "my_file") self.assertIsInstance(file_obj, ReadOnlyFileobj) # Check content to ensure correct file was opened @@ -52,10 +59,11 @@ def test_open_doc_relative_path(self): # But ReadOnlyFileobj sets permission='r'. Fileobj.fopen adds 'b' if not present. So it opens as 'rb'. file_obj.fopen() - content = file_obj.fread().decode('utf-8') + content = file_obj.fread().decode("utf-8") file_obj.fclose() self.assertEqual(content, self.file_content) -if __name__ == '__main__': + +if __name__ == "__main__": unittest.main() diff --git a/tests/test_query.py b/tests/test_query.py index e8dc2c4..bc626ad 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -1,52 +1,54 @@ import unittest from did.query import Query + class TestQuery(unittest.TestCase): def test_creation(self): # Test creating a query - q = Query('base.name', 'exact_string', 'myname') + q = Query("base.name", "exact_string", "myname") ss = q.to_search_structure() - self.assertEqual(ss[0]['field'], 'base.name') - self.assertEqual(ss[0]['operation'], 'exact_string') - self.assertEqual(ss[0]['param1'], 'myname') + self.assertEqual(ss[0]["field"], "base.name") + self.assertEqual(ss[0]["operation"], "exact_string") + self.assertEqual(ss[0]["param1"], "myname") def test_invalid_operator(self): # Test that an invalid operator throws an error with self.assertRaises(ValueError): - Query('base.name', 'invalid_op', 'myname') + Query("base.name", "invalid_op", "myname") def test_and_query(self): # Test combining queries with AND - q1 = Query('base.name', 'exact_string', 'myname') - q2 = Query('base.age', 'greaterthan', 30) + q1 = Query("base.name", "exact_string", "myname") + q2 = Query("base.age", "greaterthan", 30) q_and = q1 & q2 ss = q_and.to_search_structure() self.assertEqual(len(ss), 2) - self.assertEqual(ss[0]['field'], 'base.name') - self.assertEqual(ss[0]['operation'], 'exact_string') - self.assertEqual(ss[0]['param1'], 'myname') - self.assertEqual(ss[1]['field'], 'base.age') - self.assertEqual(ss[1]['operation'], 'greaterthan') - self.assertEqual(ss[1]['param1'], 30) + self.assertEqual(ss[0]["field"], "base.name") + self.assertEqual(ss[0]["operation"], "exact_string") + self.assertEqual(ss[0]["param1"], "myname") + self.assertEqual(ss[1]["field"], "base.age") + self.assertEqual(ss[1]["operation"], "greaterthan") + self.assertEqual(ss[1]["param1"], 30) def test_or_query(self): # Test combining queries with OR - q1 = Query('base.name', 'exact_string', 'myname') - q2 = Query('base.age', 'greaterthan', 30) + q1 = Query("base.name", "exact_string", "myname") + q2 = Query("base.age", "greaterthan", 30) q_or = q1 | q2 ss = q_or.to_search_structure() - self.assertEqual(ss[0]['operation'], 'or') + self.assertEqual(ss[0]["operation"], "or") # The parameters of the OR query are themselves search structures - param1 = ss[0]['param1'] - self.assertEqual(param1[0]['field'], 'base.name') - self.assertEqual(param1[0]['operation'], 'exact_string') - self.assertEqual(param1[0]['param1'], 'myname') - - param2 = ss[0]['param2'] - self.assertEqual(param2[0]['field'], 'base.age') - self.assertEqual(param2[0]['operation'], 'greaterthan') - self.assertEqual(param2[0]['param1'], 30) - -if __name__ == '__main__': - unittest.main() \ No newline at end of file + param1 = ss[0]["param1"] + self.assertEqual(param1[0]["field"], "base.name") + self.assertEqual(param1[0]["operation"], "exact_string") + self.assertEqual(param1[0]["param1"], "myname") + + param2 = ss[0]["param2"] + self.assertEqual(param2[0]["field"], "base.age") + self.assertEqual(param2[0]["operation"], "greaterthan") + self.assertEqual(param2[0]["param1"], 30) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_table_cross_join.py b/tests/test_table_cross_join.py index b896ec3..d0ea0f6 100644 --- a/tests/test_table_cross_join.py +++ b/tests/test_table_cross_join.py @@ -1,23 +1,27 @@ import unittest -import os from did.datastructures import table_cross_join + class TestTableCrossJoin(unittest.TestCase): def test_cross_join(self): - t1 = [{'a': 1, 'b': 2}, {'a': 3, 'b': 4}] - t2 = [{'c': 5, 'd': 6}, {'c': 7, 'd': 8}] + t1 = [{"a": 1, "b": 2}, {"a": 3, "b": 4}] + t2 = [{"c": 5, "d": 6}, {"c": 7, "d": 8}] expected_result = [ - {'a': 1, 'b': 2, 'c': 5, 'd': 6}, - {'a': 1, 'b': 2, 'c': 7, 'd': 8}, - {'a': 3, 'b': 4, 'c': 5, 'd': 6}, - {'a': 3, 'b': 4, 'c': 7, 'd': 8}, + {"a": 1, "b": 2, "c": 5, "d": 6}, + {"a": 1, "b": 2, "c": 7, "d": 8}, + {"a": 3, "b": 4, "c": 5, "d": 6}, + {"a": 3, "b": 4, "c": 7, "d": 8}, ] result = table_cross_join(t1, t2) # The result of the list comprehension is a set, so we need to sort both lists to compare - self.assertEqual(sorted(result, key=lambda x: (x['a'], x['c'])), sorted(expected_result, key=lambda x: (x['a'], x['c']))) + self.assertEqual( + sorted(result, key=lambda x: (x["a"], x["c"])), + sorted(expected_result, key=lambda x: (x["a"], x["c"])), + ) + -if __name__ == '__main__': - unittest.main() \ No newline at end of file +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_valid_modification.py b/tests/test_valid_modification.py index c1a6f92..73ac007 100644 --- a/tests/test_valid_modification.py +++ b/tests/test_valid_modification.py @@ -1,11 +1,11 @@ import unittest import os -from did.document import Document from did.implementations.sqlitedb import SQLiteDB from tests.helpers import make_doc_tree + class TestValidModification(unittest.TestCase): - DB_FILENAME = 'test_valid_modification.sqlite' + DB_FILENAME = "test_valid_modification.sqlite" def setUp(self): # Create a temporary database for testing @@ -13,14 +13,14 @@ def setUp(self): if os.path.exists(self.db_path): os.remove(self.db_path) self.db = SQLiteDB(self.db_path) - self.db.add_branch('a') + self.db.add_branch("a") # Ensure at least one document is created _, _, self.docs = make_doc_tree([1, 1, 1]) while not self.docs: _, _, self.docs = make_doc_tree([1, 1, 1]) for doc in self.docs: - self.db._do_add_doc(doc, 'a') + self.db._do_add_doc(doc, "a") def tearDown(self): # Clean up the database file @@ -33,19 +33,20 @@ def test_remove_and_readd_doc(self): doc_id = doc.id() # Remove the document - self.db.remove_docs(doc_id, 'a') + self.db.remove_docs(doc_id, "a") # Verify it's gone - retrieved_doc = self.db.get_docs(doc_id, OnMissing='ignore') + retrieved_doc = self.db.get_docs(doc_id, OnMissing="ignore") self.assertIsNone(retrieved_doc) # Re-add the document - self.db._do_add_doc(doc, 'a') + self.db._do_add_doc(doc, "a") # Verify it's back retrieved_doc = self.db.get_docs(doc_id) self.assertIsNotNone(retrieved_doc) self.assertEqual(retrieved_doc.id(), doc_id) -if __name__ == '__main__': - unittest.main() \ No newline at end of file + +if __name__ == "__main__": + unittest.main()