From c7c2dc06feacc97f910f45aa1cc07114f84300e5 Mon Sep 17 00:00:00 2001
From: FPreta
Date: Mon, 17 Nov 2025 12:04:40 +0100
Subject: [PATCH 1/2] added script to fetch all relevant graph nodes from a given decade

---
NOTE(review): this series was recovered from a whitespace-collapsed copy.
Indentation and the position of blank context lines are reconstructed from
the hunk line counts and Python syntax; verify them against the target tree
before applying.

 hvectorspaces/io/cockroach_client.py | 27 +++++++++++++++++++++++++++
 tests/test_crdb_client.py            | 22 ++++++++++++++++++++++
 2 files changed, 49 insertions(+)

diff --git a/hvectorspaces/io/cockroach_client.py b/hvectorspaces/io/cockroach_client.py
index 684a503..f5b4dc7 100644
--- a/hvectorspaces/io/cockroach_client.py
+++ b/hvectorspaces/io/cockroach_client.py
@@ -2,6 +2,7 @@ import os
 import time
 
 from tqdm import tqdm
+from typing import Optional
 
 from dotenv import load_dotenv
 import psycopg2
@@ -188,3 +189,29 @@ def drop_table(self, table_name: str, cascade: bool = False):
             cur.execute(drop_query)
         self.conn.commit()
         print(f"✅ Dropped table '{table_name}'{' (CASCADE)' if cascade else ''}.")
+
+    def fetch_per_decade_data(
+        self, decade_start: int, additional_fields: Optional[list] = None
+    ):
+        if decade_start % 10 != 0:
+            raise ValueError("decade_start must be a multiple of 10.")
+        decade_end = decade_start + 9
+
+        # Build field list safely
+        fields = [sql.Identifier("oa_id"), sql.Identifier("in_decade_references")]
+        if additional_fields:
+            fields.extend(sql.Identifier(f) for f in additional_fields)
+
+        query = sql.SQL(
+            """
+            SELECT {fields}
+            FROM openalex_vector_spaces
+            WHERE publication_year BETWEEN %s AND %s
+        """
+        ).format(fields=sql.SQL(", ").join(fields))
+
+        def _fetch(cur):
+            cur.execute(query, (decade_start, decade_end))
+            return cur.fetchall()
+
+        return self.run_transaction(_fetch)
diff --git a/tests/test_crdb_client.py b/tests/test_crdb_client.py
index 186b840..7b3663c 100644
--- a/tests/test_crdb_client.py
+++ b/tests/test_crdb_client.py
@@ -61,3 +61,25 @@ def test_cockroach_upload():
         assert results == [("Alice", 10), ("Bob", 20), ("Charlie", 30)]
 
         client.drop_table(table_name)
+
+
+def test_fetch_in_decade_references():
+    decade_start = 1970
+    with CockroachClient() as client:
+        results = client.fetch_per_decade_data(
+            decade_start, additional_fields=["publication_year", "referenced_works"]
+        )
+        results = list(results)
+        assert len(results) > 0
+        oa_ids = {row[0] for row in results}
+        assert any(in_dec_ref for _, in_dec_ref, _, _ in results)
+        for row in results:
+            oa_id, in_decade_references, publication_year, referenced_works = row
+            assert isinstance(oa_id, str)
+            assert all(ref in oa_ids for ref in in_decade_references)
+            assert (
+                set(referenced_works)
+                .intersection(set(in_decade_references))
+                .issubset(oa_ids)
+            )
+            assert 1970 <= publication_year <= 1979

From 8a26ec87dd7d2af6aa9427edcc87f77a68ba15e6 Mon Sep 17 00:00:00 2001
From: FPreta
Date: Mon, 17 Nov 2025 12:13:34 +0100
Subject: [PATCH 2/2] addressed review comments

---
NOTE(review): the test still hard-codes 1970/1979 in its final assertion
instead of deriving the bounds from decade_start.

 hvectorspaces/io/cockroach_client.py | 13 +++++++++++++
 tests/test_crdb_client.py            |  5 +++--
 2 files changed, 16 insertions(+), 2 deletions(-)

diff --git a/hvectorspaces/io/cockroach_client.py b/hvectorspaces/io/cockroach_client.py
index f5b4dc7..f1533f1 100644
--- a/hvectorspaces/io/cockroach_client.py
+++ b/hvectorspaces/io/cockroach_client.py
@@ -193,6 +193,19 @@ def drop_table(self, table_name: str, cascade: bool = False):
     def fetch_per_decade_data(
         self, decade_start: int, additional_fields: Optional[list] = None
     ):
+        """
+        Fetch works from a specific decade with their in-decade references.
+
+        Args:
+            decade_start (int): The starting year of the decade (must be a multiple of 10).
+            additional_fields (Optional[list]): Additional field names to include in the query.
+
+        Returns:
+            list: Query results with oa_id, in_decade_references, and any additional fields.
+
+        Raises:
+            ValueError: If decade_start is not a multiple of 10.
+        """
         if decade_start % 10 != 0:
             raise ValueError("decade_start must be a multiple of 10.")
         decade_end = decade_start + 9
diff --git a/tests/test_crdb_client.py b/tests/test_crdb_client.py
index 7b3663c..83d4526 100644
--- a/tests/test_crdb_client.py
+++ b/tests/test_crdb_client.py
@@ -64,6 +64,7 @@ def test_cockroach_upload():
 
 
 def test_fetch_in_decade_references():
+    """Test fetching works from a specific decade with their in-decade references."""
     decade_start = 1970
     with CockroachClient() as client:
         results = client.fetch_per_decade_data(
@@ -79,7 +80,7 @@ def test_fetch_in_decade_references():
             assert all(ref in oa_ids for ref in in_decade_references)
             assert (
                 set(referenced_works)
-                .intersection(set(in_decade_references))
-                .issubset(oa_ids)
+                .intersection(oa_ids)
+                .issubset(set(in_decade_references))
             )
             assert 1970 <= publication_year <= 1979