Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 27 additions & 9 deletions cds_dojson/matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,25 +28,46 @@
from invenio_query_parser.walkers.match_unit import MatchUnit
from invenio_query_parser.walkers.pypeg_to_ast import PypegConverter

# Cache of (name, model, parsed_query_ast) tuples keyed by entry_point_group.
# Entry points and their query ASTs are static for the lifetime of a process,
# so scanning importlib_metadata and re-parsing pypeg2 queries on every call
# was the dominant cost (~84s for 1320 calls).
_models_cache = {}


class Query(object):
    """Compiled matcher query.

    The pypeg2 query string is parsed exactly once, at construction
    time, and the resulting AST is kept on the instance.  Repeated
    ``match()`` calls therefore reuse the cached AST instead of
    re-parsing the grammar on every access.
    """

    def __init__(self, query):
        """Parse *query* eagerly and keep both the string and its AST.

        :param query: query string written in the pypeg2 grammar
            accepted by the module-level ``parser``.
        """
        self._query = query
        # Parse once at construction time; re-parsing on every match()
        # call via the @property was half of the measured cost.
        tree = pypeg2.parse(query, parser, whitespace="")
        self._parsed = tree.accept(PypegConverter())

    @property
    def query(self):
        """Return the pre-parsed query AST (kept for API compatibility)."""
        return self._parsed

    def match(self, record, user_info=None):
        """Return ``True`` if *record* matches the query.

        :param record: record (dict-like) walked by ``MatchUnit``.
        :param user_info: unused; retained for backward compatibility
            with existing callers.
        """
        return self._parsed.accept(MatchUnit(record))


def _load_models(entry_point_group):
    """Return the cached ``(name, model, Query)`` triples for a group.

    On the first call for *entry_point_group* the entry points are
    scanned, each model is loaded, and its ``__query__`` is compiled
    into a :class:`Query`; subsequent calls hit the module-level cache.
    """
    try:
        return _models_cache[entry_point_group]
    except KeyError:
        pass
    loaded = []
    for entry_point in set(importlib_metadata.entry_points(group=entry_point_group)):
        loaded_model = entry_point.load()
        loaded.append(
            (entry_point.name, loaded_model, Query(loaded_model.__query__))
        )
    _models_cache[entry_point_group] = loaded
    return loaded


def matcher(record, entry_point_group):
"""Matcher for DoJSON models.

Expand All @@ -60,15 +81,12 @@ def matcher(record, entry_point_group):
logger = logging.getLogger(__name__ + ".dojson_matcher")

_matches = []
entrypoints = set(importlib_metadata.entry_points(group=entry_point_group))
for entry_point in entrypoints:
model = entry_point.load()
query = Query(model.__query__)
for name, model, query in _load_models(entry_point_group):
if query.match(record):
logger.info("Model `{0}` found matching the query {1}.".format(
entry_point.name, model
name, model
))
_matches.append([entry_point.name, model])
_matches.append([name, model])
try:
if len(_matches) > 1:
logger.error(
Expand Down
Loading