From 29202b72f35977d165c46e00a49b3f518d329220 Mon Sep 17 00:00:00 2001 From: Christie Ellks Date: Sat, 7 Mar 2026 18:49:08 -0800 Subject: [PATCH 1/2] add auth --- simple/requirements.txt | 1 + simple/stats/db.py | 21 +++++++++++++++++++-- simple/tests/stats/db_test.py | 11 +++++++++-- 3 files changed, 29 insertions(+), 4 deletions(-) diff --git a/simple/requirements.txt b/simple/requirements.txt index a06f59fe..2ad20545 100644 --- a/simple/requirements.txt +++ b/simple/requirements.txt @@ -27,6 +27,7 @@ requests==2.31.0 rdflib==7.4.0 s2sphere==0.2.5 six==1.16.0 +google-auth==2.49.0 tomli==2.0.1 tzdata==2023.3 urllib3==1.26.20 diff --git a/simple/stats/db.py b/simple/stats/db.py index 22e50fcb..07e6458d 100644 --- a/simple/stats/db.py +++ b/simple/stats/db.py @@ -22,8 +22,12 @@ import sqlite3 from typing import Any +from google.auth.exceptions import DefaultCredentialsError +import google.auth.transport.requests +from google.auth.transport.requests import AuthorizedSession from google.cloud.sql.connector.connector import Connector from google.cloud.sql.connector.connector import IPTypes +from google.oauth2 import id_token import pandas as pd from pyld import jsonld from pymysql.connections import Connection @@ -415,6 +419,20 @@ class DataCommonsPlatformDb(Db): def __init__(self, config: dict) -> None: self.url = config[FIELD_DB_PARAMS][DATA_COMMONS_PLATFORM_URL] + self.nodes_url = self.url + self.NODES_PATH + + try: + auth_req = google.auth.transport.requests.Request() + self.session = AuthorizedSession( + credentials=None, + refresh_handler=lambda: id_token.fetch_id_token(auth_req, self.url)) + id_token.fetch_id_token(auth_req, self.url) + logging.info("Using AUTHENTICATED session for %s", self.url) + except (DefaultCredentialsError, Exception) as e: + logging.warning( + "Could not fetch ID token (%s). Falling back to UNAUTHENTICATED session.", + e) + self.session = requests.Session() def maybe_clear_before_import(self): # Not applicable for Data Commons Platform. @@ -430,8 +448,7 @@ def insert_triples(self, triples: list[Triple]): "Writing %s triples (%s nodes) to Data Commons Platform at [%s]", len(triples), len(jsonld["@graph"]), self.url) logging.info("Writing jsonld: %s", json.dumps(jsonld, indent=2)) - nodes_url = self.url + self.NODES_PATH - response = requests.post(nodes_url, json=jsonld) + response = self.session.post(self.nodes_url, json=jsonld) if response.status_code != 200: # TODO: For now, we just log a warning, but we should raise an exception. logging.warning("Failed to write triples to Data Commons Platform: %s", diff --git a/simple/tests/stats/db_test.py b/simple/tests/stats/db_test.py index 46f1b014..47fb0d82 100644 --- a/simple/tests/stats/db_test.py +++ b/simple/tests/stats/db_test.py @@ -350,13 +350,20 @@ def test_get_datacommons_platform_config_from_env(self): } }) - @mock.patch('requests.post') + @mock.patch('stats.db.AuthorizedSession') + @mock.patch('stats.db.id_token.fetch_id_token') + @mock.patch('stats.db.google.auth.transport.requests.Request') @mock.patch.dict( os.environ, { "USE_DATA_COMMONS_PLATFORM": "true", "DATA_COMMONS_PLATFORM_URL": "https://test_url" }) - def test_insert_triples_into_datacommons_platform(self, mock_post): + def test_insert_triples_into_datacommons_platform(self, mock_auth_request, mock_fetch_id_token, mock_authorized_session): + + mock_session_instance = mock.Mock() + mock_authorized_session.return_value = mock_session_instance + mock_post = mock_session_instance.post + config = get_datacommons_platform_config_from_env() db = create_and_update_db(config) From f0384175bd61938d3837f830c2a63c62b7246bae Mon Sep 17 00:00:00 2001 From: Christie Ellks Date: Sat, 7 Mar 2026 21:10:04 -0800 Subject: [PATCH 2/2] use default creds (tests need updating) --- .gitignore | 3 ++- run_test.sh | 20 ++++++++++---------- simple/stats/db.py | 31 +++++++++++++++++++------------ 3 files changed, 31 insertions(+), 23 deletions(-) diff --git a/.gitignore b/.gitignore index 66262ad7..8c591086 100644 --- a/.gitignore +++ b/.gitignore @@ -2,7 +2,8 @@ __pycache__/ *.DS_Store .vscode/settings.json -.env/ +.env +.venv/ .idea/ .vscode/ *.iml diff --git a/run_test.sh b/run_test.sh index 829019af..c06dde4a 100755 --- a/run_test.sh +++ b/run_test.sh @@ -19,22 +19,22 @@ set -e # Fixes lint function run_lint_fix { echo -e "#### Fixing Python code" - python3 -m venv .env - source .env/bin/activate + python3 -m venv .venv + source .venv/bin/activate pip3 install yapf==0.40.2 -q if ! command -v isort &> /dev/null then pip3 install isort -q fi - yapf -r -i -p --style='{based_on_style: google, indent_width: 2}' simple/ -e=*pb2.py -e=**/.env/** - isort simple/ --skip-glob=*pb2.py --skip-glob=**/.env/** --profile google + yapf -r -i -p --style='{based_on_style: google, indent_width: 2}' simple/ -e=*pb2.py -e=**/.venv/** + isort simple/ --skip-glob=*pb2.py --skip-glob=**/.venv/** --profile google deactivate } # Lint test function run_lint_test { - python3 -m venv .env - source .env/bin/activate + python3 -m venv .venv + source .venv/bin/activate pip3 install yapf==0.40.2 -q if ! command -v isort &> /dev/null then @@ -42,13 +42,13 @@ function run_lint_test { fi echo -e "#### Checking Python style" - if ! yapf --recursive --diff --style='{based_on_style: google, indent_width: 2}' -p simple/ -e=*pb2.py -e=**/.env/**; then + if ! yapf --recursive --diff --style='{based_on_style: google, indent_width: 2}' -p simple/ -e=*pb2.py -e=**/.venv/**; then echo "Fix Python lint errors by running ./run_test.sh -f" exit 1 fi echo -e "#### Checking Python import order" - if ! isort simple/ -c --skip-glob=*pb2.py --skip-glob=**/.env/** --profile google; then + if ! isort simple/ -c --skip-glob=*pb2.py --skip-glob=**/.venv/** --profile google; then echo "Fix Python import sort orders by running ./run_test.sh -f" exit 1 fi @@ -72,8 +72,8 @@ function py_test { # Do not use Cloud SQL. export USE_CLOUDSQL=false - python3 -m venv .env - source .env/bin/activate + python3 -m venv .venv + source .venv/bin/activate cd simple pip3 install -r requirements.txt -q diff --git a/simple/stats/db.py b/simple/stats/db.py index 07e6458d..846f77f8 100644 --- a/simple/stats/db.py +++ b/simple/stats/db.py @@ -421,18 +421,25 @@ def __init__(self, config: dict) -> None: self.url = config[FIELD_DB_PARAMS][DATA_COMMONS_PLATFORM_URL] self.nodes_url = self.url + self.NODES_PATH - try: + def _get_id_token(url): + # 1. Try to get default credentials + creds, _ = google.auth.default() auth_req = google.auth.transport.requests.Request() - self.session = AuthorizedSession( - credentials=None, - refresh_handler=lambda: id_token.fetch_id_token(auth_req, self.url)) - id_token.fetch_id_token(auth_req, self.url) - logging.info("Using AUTHENTICATED session for %s", self.url) - except (DefaultCredentialsError, Exception) as e: - logging.warning( - "Could not fetch ID token (%s). Falling back to UNAUTHENTICATED session.", - e) - self.session = requests.Session() + + # 2. Refresh to ensure the token is loaded + creds.refresh(auth_req) + + # 3. Check if the credentials already have an id_token (typical for local gcloud) + if hasattr(creds, 'id_token') and creds.id_token: + return creds.id_token + + # 4. Fallback to fetching it (typical for Service Accounts/Cloud environments) + return google.oauth2.id_token.fetch_id_token(auth_req, url) + id_token = _get_id_token(self.url) + + # 2. Make the authenticated request + self.headers = {"Authorization": f"Bearer {id_token}"} + def maybe_clear_before_import(self): # Not applicable for Data Commons Platform. @@ -448,7 +455,7 @@ def insert_triples(self, triples: list[Triple]): "Writing %s triples (%s nodes) to Data Commons Platform at [%s]", len(triples), len(jsonld["@graph"]), self.url) logging.info("Writing jsonld: %s", json.dumps(jsonld, indent=2)) - response = self.session.post(self.nodes_url, json=jsonld) + response = requests.post(self.nodes_url, json=jsonld, headers=self.headers) if response.status_code != 200: # TODO: For now, we just log a warning, but we should raise an exception. logging.warning("Failed to write triples to Data Commons Platform: %s",