diff --git a/src/aind_data_access_api/helpers/docdb.py b/src/aind_data_access_api/helpers/docdb.py index 7a8f32e..2683da6 100644 --- a/src/aind_data_access_api/helpers/docdb.py +++ b/src/aind_data_access_api/helpers/docdb.py @@ -37,7 +37,7 @@ def get_projection_by_id( """ Download a record from docdb using the record _id and a projection. - Projections return fields set to 1 {"field": 1} + Projections return fields set to 1, e.g. {"field": 1} Parameters ---------- @@ -137,3 +137,61 @@ def get_name_from_id( if len(records) == 0: raise ValueError(f"No record found with _id {_id}") return records[0]["name"] + + +def get_record_by_name( + client: MetadataDbClient, + name: str, +) -> Optional[dict]: + """ + Download a record from docdb using the record name. + + Parameters + ---------- + client : MetadataDbClient + name : str + + Returns + ------- + Optional[dict] + None if record does not exist. Otherwise, it will return the record as + a dict. + """ + records = client.retrieve_docdb_records( + filter_query={"name": name}, limit=1 + ) + if len(records) > 0: + return records[0] + else: + return None + + +def get_projection_by_name( + client: MetadataDbClient, + name: str, + projection: dict, +) -> Optional[dict]: + """ + Download a record from docdb using the record name and a projection. + + Projections return fields set to 1, e.g. {"field": 1} + + Parameters + ---------- + client : MetadataDbClient + name : str + projection : dict + + Returns + ------- + Optional[dict] + None if record does not exist. Otherwise, it will return the projected + record as a dict. + """ + records = client.retrieve_docdb_records( + filter_query={"name": name}, projection=projection, limit=1 + ) + if len(records) > 0: + return records[0] + else: + return None diff --git a/tests/helpers/test_docdb.py b/tests/helpers/test_docdb.py index c072cb5..cd7369a 100644 --- a/tests/helpers/test_docdb.py +++ b/tests/helpers/test_docdb.py @@ -8,7 +8,9 @@ get_id_from_name, get_name_from_id, get_projection_by_id, + get_projection_by_name, get_record_by_id, + get_record_by_name, ) @@ -75,8 +77,8 @@ def test_get_record_by_id(self): record = get_record_by_id(client, _id="abcd") self.assertIsNone(record) - def test_get_projected_record_from_docdb(self): - """Tests get_projected_record_from_docdb""" + def test_get_projection_by_id(self): + """Tests get_projection_by_id""" client = MagicMock() client.retrieve_docdb_records.return_value = [ {"quality_control": {"a": 1}} @@ -95,6 +97,36 @@ def test_get_field_from_docdb(self): field = get_field_by_id(client, _id="abcd", field="quality_control") self.assertEqual({"quality_control": {"a": 1}}, field) + def test_get_record_by_name(self): + """Tests get_record_by_name""" + client = MagicMock() + client.retrieve_docdb_records.return_value = [{"name": "abcd"}] + record = get_record_by_name(client=client, name="abcd") + self.assertEqual({"name": "abcd"}, record) + + # test the empty case + client.retrieve_docdb_records.return_value = [] + record = get_record_by_name(client=client, name="abcd") + self.assertIsNone(record) + + def test_get_projection_by_name(self): + """Tests get_projection_by_name""" + client = MagicMock() + client.retrieve_docdb_records.return_value = [ + {"quality_control": {"a": 1}} + ] + record = get_projection_by_name( + client=client, name="abcd", projection={"quality_control": 1} + ) + self.assertEqual({"quality_control": {"a": 1}}, record) + + # test the empty case + client.retrieve_docdb_records.return_value = [] + record = get_projection_by_name( + client=client, name="abcd", projection={"quality_control": 1} + ) + self.assertIsNone(record) + if __name__ == "__main__": unittest.main()