diff --git a/src/pir_pipeline/utils/dashboard_utils.py b/src/pir_pipeline/utils/dashboard_utils.py index 3334954..38ebd67 100644 --- a/src/pir_pipeline/utils/dashboard_utils.py +++ b/src/pir_pipeline/utils/dashboard_utils.py @@ -2,6 +2,7 @@ __all__ = ["get_matches"] +import re from collections import OrderedDict, namedtuple from hashlib import sha1 @@ -92,6 +93,10 @@ def get_search_results( columns = OrderedDict([(col, None) for col in columns]) columns = tuple(columns.keys()) + # Escape keyword (https://stackoverflow.com/questions/4202538/escape-special-characters-in-a-python-string) + keyword = re.escape(keyword) + + # Create regex match query keyword_query = select(table.c[columns]) # Adapted from @@ -117,7 +122,28 @@ def get_search_results( with db.engine.connect() as conn: # Get all search results - result = conn.execute(keyword_query, {"keyword": keyword}) + question_ids = conn.execute( + select(keyword_query.c["question_id"]) + .distinct() + .where(keyword_query.c["question_id"].is_not(None)), + {"keyword": keyword}, + ).scalars() + uqids = conn.execute( + select(keyword_query.c["uqid"]) + .distinct() + .where(keyword_query.c["uqid"].is_not(None)), + {"keyword": keyword}, + ).scalars() + search_query = ( + select(table.c[columns]) + .where( + or_( + table.c["question_id"].in_(question_ids), table.c["uqid"].in_(uqids) + ) + ) + .order_by(table.c["uqid"], table.c["year"].desc()) + ) + result = conn.execute(search_query) # Convert results to dictionary for res in result.all(): diff --git a/tests/utils/test_dashboard_utils.py b/tests/utils/test_dashboard_utils.py index a56cc97..0355039 100644 --- a/tests/utils/test_dashboard_utils.py +++ b/tests/utils/test_dashboard_utils.py @@ -76,12 +76,6 @@ def test_get_matches(self, sql_utils): def test_get_search_results(self, sql_utils): Check = namedtuple("Check", ["kwargs", "ids"]) checks = [ - Check( - { - "keyword": "^Staff$", - }, - {"0e93c25d3a95604f40d3a64e2298093b4faed6f2"}, - ), Check( { "keyword": "child development staff - qualifications",