Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 59 additions & 1 deletion src/aind_data_access_api/helpers/docdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def get_projection_by_id(
"""
Download a record from docdb using the record _id and a projection.

Projections return fields set to 1 {"field": 1}
Projections return fields set to 1, e.g. {"field": 1}

Parameters
----------
Expand Down Expand Up @@ -137,3 +137,61 @@ def get_name_from_id(
if len(records) == 0:
raise ValueError(f"No record found with _id {_id}")
return records[0]["name"]


def get_record_by_name(
client: MetadataDbClient,
name: str,
) -> Optional[dict]:
"""
Download a record from docdb using the record name.

Parameters
----------
client : MetadataDbClient
name : str

Returns
-------
Optional[dict]
None if record does not exist. Otherwise, it will return the record as
a dict.
"""
records = client.retrieve_docdb_records(
filter_query={"name": name}, limit=1
)
if len(records) > 0:
return records[0]
else:
return None


def get_projection_by_name(
client: MetadataDbClient,
name: str,
projection: dict,
) -> Optional[dict]:
"""
Download a record from docdb using the record name and a projection.

Projections return fields set to 1, e.g. {"field": 1}

Parameters
----------
client : MetadataDbClient
name : str
projection : dict

Returns
-------
Optional[dict]
None if record does not exist. Otherwise, it will return the projected
record as a dict.
"""
records = client.retrieve_docdb_records(
filter_query={"name": name}, projection=projection, limit=1
)
if len(records) > 0:
return records[0]
else:
return None
36 changes: 34 additions & 2 deletions tests/helpers/test_docdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
get_id_from_name,
get_name_from_id,
get_projection_by_id,
get_projection_by_name,
get_record_by_id,
get_record_by_name,
)


Expand Down Expand Up @@ -75,8 +77,8 @@ def test_get_record_by_id(self):
record = get_record_by_id(client, _id="abcd")
self.assertIsNone(record)

def test_get_projected_record_from_docdb(self):
"""Tests get_projected_record_from_docdb"""
def test_get_projection_by_id(self):
"""Tests get_projection_by_id"""
client = MagicMock()
client.retrieve_docdb_records.return_value = [
{"quality_control": {"a": 1}}
Expand All @@ -95,6 +97,36 @@ def test_get_field_from_docdb(self):
field = get_field_by_id(client, _id="abcd", field="quality_control")
self.assertEqual({"quality_control": {"a": 1}}, field)

def test_get_record_by_name(self):
"""Tests get_record_by_name"""
client = MagicMock()
client.retrieve_docdb_records.return_value = [{"name": "abcd"}]
record = get_record_by_name(client=client, name="abcd")
self.assertEqual({"name": "abcd"}, record)

# test the empty case
client.retrieve_docdb_records.return_value = []
record = get_record_by_name(client=client, name="abcd")
self.assertIsNone(record)

def test_get_projection_by_name(self):
"""Tests get_projection_by_name"""
client = MagicMock()
client.retrieve_docdb_records.return_value = [
{"quality_control": {"a": 1}}
]
record = get_projection_by_name(
client=client, name="abcd", projection={"quality_control": 1}
)
self.assertEqual({"quality_control": {"a": 1}}, record)

# test the empty case
client.retrieve_docdb_records.return_value = []
record = get_projection_by_name(
client=client, name="abcd", projection={"quality_control": 1}
)
self.assertIsNone(record)


if __name__ == "__main__":
unittest.main()
Loading