Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions hvectorspaces/io/cockroach_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
import time
from tqdm import tqdm
from typing import Optional

from dotenv import load_dotenv
import psycopg2
Expand Down Expand Up @@ -188,3 +189,42 @@ def drop_table(self, table_name: str, cascade: bool = False):
cur.execute(drop_query)
self.conn.commit()
print(f"✅ Dropped table '{table_name}'{' (CASCADE)' if cascade else ''}.")

def fetch_per_decade_data(
self, decade_start: int, additional_fields: Optional[list] = None
):
"""
Fetch works from a specific decade with their in-decade references.

Args:
decade_start (int): The starting year of the decade (must be a multiple of 10).
additional_fields (Optional[list]): Additional field names to include in the query.

Returns:
list: Query results with oa_id, in_decade_references, and any additional fields.

Raises:
ValueError: If decade_start is not a multiple of 10.
"""
if decade_start % 10 != 0:
raise ValueError("decade_start must be a multiple of 10.")
decade_end = decade_start + 9

# Build field list safely
fields = [sql.Identifier("oa_id"), sql.Identifier("in_decade_references")]
if additional_fields:
fields.extend(sql.Identifier(f) for f in additional_fields)

query = sql.SQL(
"""
SELECT {fields}
FROM openalex_vector_spaces
WHERE publication_year BETWEEN %s AND %s
"""
).format(fields=sql.SQL(", ").join(fields))

def _fetch(cur):
cur.execute(query, (decade_start, decade_end))
return cur.fetchall()

return self.run_transaction(_fetch)
23 changes: 23 additions & 0 deletions tests/test_crdb_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,26 @@ def test_cockroach_upload():
assert results == [("Alice", 10), ("Bob", 20), ("Charlie", 30)]

client.drop_table(table_name)


def test_fetch_in_decade_references():
"""Test fetching works from a specific decade with their in-decade references."""
decade_start = 1970
with CockroachClient() as client:
results = client.fetch_per_decade_data(
decade_start, additional_fields=["publication_year", "referenced_works"]
)
results = list(results)
assert len(results) > 0
oa_ids = {row[0] for row in results}
assert any(in_dec_ref for _, in_dec_ref, _, _ in results)
for row in results:
oa_id, in_decade_references, publication_year, referenced_works = row
assert isinstance(oa_id, str)
assert all(ref in oa_ids for ref in in_decade_references)
assert (
set(referenced_works)
.intersection(oa_ids)
.issubset(set(in_decade_references))
)
Comment on lines +81 to +85
Copy link

Copilot AI Nov 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This assertion is redundant and potentially confusing. By definition, if in_decade_references is a subset of oa_ids (as checked on line 79), then the intersection of any set with in_decade_references will also be a subset of oa_ids. This assertion doesn't add meaningful test coverage and could be removed to simplify the test logic.

Suggested change
assert (
set(referenced_works)
.intersection(set(in_decade_references))
.issubset(oa_ids)
)
# The following assertion is redundant and has been removed.

Copilot uses AI. Check for mistakes.
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@copilot it was a mistake. I changed it to the actual test I wanted (that is, in_decade_references contains all the referenced works that are in oa_ids).

assert (
                set(referenced_works)
                .intersection(oa_ids)
                .issubset(set(in_decade_references))
            )

What do you think?

assert 1970 <= publication_year <= 1979