diff --git a/python-wrapper/pyproject.toml b/python-wrapper/pyproject.toml index c7f57a28..fddad396 100644 --- a/python-wrapper/pyproject.toml +++ b/python-wrapper/pyproject.toml @@ -58,7 +58,7 @@ docs = [ "nbsphinx-link==1.3.1", ] pandas = ["pandas>=2, <3", "pandas-stubs>=2, <3"] -gds = ["graphdatascience>=1, <2"] # not compatible yet with Python 3.13 +gds = ["graphdatascience>=1, <2"] neo4j = ["neo4j"] notebook = [ "ipykernel==6.29.5", diff --git a/python-wrapper/tests/conftest.py b/python-wrapper/tests/conftest.py index 531d25f6..081a7007 100644 --- a/python-wrapper/tests/conftest.py +++ b/python-wrapper/tests/conftest.py @@ -31,32 +31,39 @@ def pytest_collection_modifyitems(config: Any, items: Any) -> None: @pytest.fixture(scope="package") -def gds() -> Generator[Any, None, None]: - from gds_helper import aura_api, connect_to_plugin_gds, create_aurads_instance - from graphdatascience import GraphDataScience +def aura_ds_instance() -> Generator[Any, None, None]: + if os.environ.get("AURA_API_CLIENT_ID", None) is None: + yield None + return + + from gds_helper import aura_api, create_aurads_instance + + api = aura_api() + id, dbms_connection_info = create_aurads_instance(api) + + # setting as environment variables to run notebooks with this connection + os.environ["NEO4J_URI"] = dbms_connection_info.uri + os.environ["NEO4J_USER"] = dbms_connection_info.username + os.environ["NEO4J_PASSWORD"] = dbms_connection_info.password + yield dbms_connection_info - use_cloud_setup = os.environ.get("AURA_API_CLIENT_ID", None) + # Clear Neo4j_URI after test (rerun should create a new instance) + os.environ["NEO4J_URI"] = "" + api.delete_instance(id) - if use_cloud_setup: - api = aura_api() - id, dbms_connection_info = create_aurads_instance(api) - # setting as environment variables to run notebooks with this connection - os.environ["NEO4J_URI"] = dbms_connection_info.uri - os.environ["NEO4J_USER"] = dbms_connection_info.username - os.environ["NEO4J_PASSWORD"] = dbms_connection_info.password +@pytest.fixture(scope="package") +def gds(aura_ds_instance: Any) -> Generator[Any, None, None]: + from gds_helper import connect_to_plugin_gds + from graphdatascience import GraphDataScience + if aura_ds_instance: yield GraphDataScience( - endpoint=dbms_connection_info.uri, - auth=(dbms_connection_info.username, dbms_connection_info.password), + endpoint=aura_ds_instance.uri, + auth=(aura_ds_instance.username, aura_ds_instance.password), aura_ds=True, database="neo4j", ) - - # Clear Neo4j_URI after test (rerun should create a new instance) - os.environ["NEO4J_URI"] = "" - - api.delete_instance(id) else: NEO4J_URI = os.environ.get("NEO4J_URI", "neo4j://localhost:7687") gds = connect_to_plugin_gds(NEO4J_URI) @@ -65,12 +72,24 @@ def gds() -> Generator[Any, None, None]: @pytest.fixture(scope="package") -def neo4j_session() -> Generator[Any, None, None]: +def neo4j_driver(aura_ds_instance: Any) -> Generator[Any, None, None]: import neo4j - NEO4J_URI = os.environ.get("NEO4J_URI", "neo4j://localhost:7687") + if aura_ds_instance: + driver = neo4j.GraphDatabase.driver( + aura_ds_instance.uri, auth=(aura_ds_instance.username, aura_ds_instance.password) + ) + else: + NEO4J_URI = os.environ.get("NEO4J_URI", "neo4j://localhost:7687") + driver = neo4j.GraphDatabase.driver(NEO4J_URI) + + driver.verify_connectivity() + yield driver - with neo4j.GraphDatabase.driver(NEO4J_URI) as driver: - driver.verify_connectivity() - with driver.session() as session: - yield session + driver.close() + + +@pytest.fixture(scope="package") +def neo4j_session(neo4j_driver: Any) -> Generator[Any, None, None]: + with neo4j_driver.session() as session: + yield session diff --git a/scripts/checkstyle.sh b/scripts/checkstyle.sh index 93ca36e5..6d5d70d3 100755 --- a/scripts/checkstyle.sh +++ b/scripts/checkstyle.sh @@ -7,7 +7,7 @@ set -o pipefail python -m ruff check . python -m ruff format --check . -mypy --config-file python-wrapper/pyproject.toml . +mypy --config-file "${GIT_ROOT}/python-wrapper/pyproject.toml" . if [ "${SKIP_NOTEBOOKS:-false}" == "true" ]; then