diff --git a/hvectorspaces/io/cockroach_client.py b/hvectorspaces/io/cockroach_client.py index 684a503..f1533f1 100644 --- a/hvectorspaces/io/cockroach_client.py +++ b/hvectorspaces/io/cockroach_client.py @@ -2,6 +2,7 @@ import os import time from tqdm import tqdm +from typing import Optional from dotenv import load_dotenv import psycopg2 @@ -188,3 +189,42 @@ def drop_table(self, table_name: str, cascade: bool = False): cur.execute(drop_query) self.conn.commit() print(f"✅ Dropped table '{table_name}'{' (CASCADE)' if cascade else ''}.") + + def fetch_per_decade_data( + self, decade_start: int, additional_fields: Optional[list] = None + ): + """ + Fetch works from a specific decade with their in-decade references. + + Args: + decade_start (int): The starting year of the decade (must be a multiple of 10). + additional_fields (Optional[list]): Additional field names to include in the query. + + Returns: + list: Query results with oa_id, in_decade_references, and any additional fields. + + Raises: + ValueError: If decade_start is not a multiple of 10. + """ + if decade_start % 10 != 0: + raise ValueError("decade_start must be a multiple of 10.") + decade_end = decade_start + 9 + + # Build field list safely + fields = [sql.Identifier("oa_id"), sql.Identifier("in_decade_references")] + if additional_fields: + fields.extend(sql.Identifier(f) for f in additional_fields) + + query = sql.SQL( + """ + SELECT {fields} + FROM openalex_vector_spaces + WHERE publication_year BETWEEN %s AND %s + """ + ).format(fields=sql.SQL(", ").join(fields)) + + def _fetch(cur): + cur.execute(query, (decade_start, decade_end)) + return cur.fetchall() + + return self.run_transaction(_fetch) diff --git a/tests/test_crdb_client.py b/tests/test_crdb_client.py index 186b840..83d4526 100644 --- a/tests/test_crdb_client.py +++ b/tests/test_crdb_client.py @@ -61,3 +61,26 @@ def test_cockroach_upload(): assert results == [("Alice", 10), ("Bob", 20), ("Charlie", 30)] client.drop_table(table_name) + + +def test_fetch_in_decade_references(): + """Test fetching works from a specific decade with their in-decade references.""" + decade_start = 1970 + with CockroachClient() as client: + results = client.fetch_per_decade_data( + decade_start, additional_fields=["publication_year", "referenced_works"] + ) + results = list(results) + assert len(results) > 0 + oa_ids = {row[0] for row in results} + assert any(in_dec_ref for _, in_dec_ref, _, _ in results) + for row in results: + oa_id, in_decade_references, publication_year, referenced_works = row + assert isinstance(oa_id, str) + assert all(ref in oa_ids for ref in in_decade_references) + assert ( + set(referenced_works) + .intersection(oa_ids) + .issubset(set(in_decade_references)) + ) + assert 1970 <= publication_year <= 1979