diff --git a/cds_dojson/matcher.py b/cds_dojson/matcher.py index c6adf76..ae6f15a 100644 --- a/cds_dojson/matcher.py +++ b/cds_dojson/matcher.py @@ -28,6 +28,12 @@ from invenio_query_parser.walkers.match_unit import MatchUnit from invenio_query_parser.walkers.pypeg_to_ast import PypegConverter +# Cache of (name, model, parsed_query_ast) tuples keyed by entry_point_group. +# Entry points and their query ASTs are static for the lifetime of a process, +# so scanning importlib_metadata and re-parsing pypeg2 queries on every call +# was the dominant cost (~84s for 1320 calls). +_models_cache = {} + class Query(object): """Query object.""" @@ -35,18 +41,33 @@ class Query(object): def __init__(self, query): """Init.""" self._query = query + # Parse once at construction time; re-parsing on every match() call + # via a @property was the other half of the cost. + tree = pypeg2.parse(query, parser, whitespace="") + self._parsed = tree.accept(PypegConverter()) @property def query(self): - """Parse query string using given grammar.""" - tree = pypeg2.parse(self._query, parser, whitespace="") - return tree.accept(PypegConverter()) + """Return the pre-parsed query AST.""" + return self._parsed def match(self, record, user_info=None): """Return True if record match the query.""" return self.query.accept(MatchUnit(record)) +def _load_models(entry_point_group): + """Load and cache entry point models for a given group.""" + if entry_point_group not in _models_cache: + entrypoints = set(importlib_metadata.entry_points(group=entry_point_group)) + models = [] + for ep in entrypoints: + model = ep.load() + models.append((ep.name, model, Query(model.__query__))) + _models_cache[entry_point_group] = models + return _models_cache[entry_point_group] + + def matcher(record, entry_point_group): """Matcher for DoJSON models. @@ -60,15 +81,12 @@ def matcher(record, entry_point_group): logger = logging.getLogger(__name__ + ".dojson_matcher") _matches = [] - entrypoints = set(importlib_metadata.entry_points(group=entry_point_group)) - for entry_point in entrypoints: - model = entry_point.load() - query = Query(model.__query__) + for name, model, query in _load_models(entry_point_group): if query.match(record): logger.info("Model `{0}` found matching the query {1}.".format( - entry_point.name, model + name, model )) - _matches.append([entry_point.name, model]) + _matches.append([name, model]) try: if len(_matches) > 1: logger.error(