diff --git a/pipeline/spanner/pom.xml b/pipeline/spanner/pom.xml
index 1db64212..035b825b 100644
--- a/pipeline/spanner/pom.xml
+++ b/pipeline/spanner/pom.xml
@@ -76,4 +76,15 @@
       <scope>test</scope>
     </dependency>
   </dependencies>
+
+  <build>
+    <resources>
+      <resource>
+        <directory>../workflow/ingestion-helper</directory>
+        <includes>
+          <include>schema.sql</include>
+        </includes>
+      </resource>
+    </resources>
+  </build>
 </project>
diff --git a/pipeline/spanner/src/main/java/org/datacommons/ingestion/spanner/SpannerClient.java b/pipeline/spanner/src/main/java/org/datacommons/ingestion/spanner/SpannerClient.java
index 87fb58d9..b0483371 100644
--- a/pipeline/spanner/src/main/java/org/datacommons/ingestion/spanner/SpannerClient.java
+++ b/pipeline/spanner/src/main/java/org/datacommons/ingestion/spanner/SpannerClient.java
@@ -231,17 +231,17 @@ public void validateOrInitializeDatabase() {
     }
   }
 
-  /** Reads DDL statements from the spanner_schema.sql file in the resources directory. */
+  /** Reads DDL statements from the schema.sql file in the resources directory. */
   List<String> readDdlStatements() {
-    InputStream inputStream = getClass().getClassLoader().getResourceAsStream("spanner_schema.sql");
+    InputStream inputStream = getClass().getClassLoader().getResourceAsStream("schema.sql");
     if (inputStream == null) {
-      throw new IllegalStateException("Could not find spanner_schema.sql in resources.");
+      throw new IllegalStateException("Could not find schema.sql in resources.");
    }
     try (BufferedReader reader =
         new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8))) {
       return parseDdlStatements(reader);
     } catch (IOException e) {
-      throw new IllegalStateException("Failed to read spanner_schema.sql", e);
+      throw new IllegalStateException("Failed to read schema.sql", e);
     }
   }
 
diff --git a/pipeline/spanner/src/main/resources/spanner_schema.sql b/pipeline/spanner/src/main/resources/spanner_schema.sql
deleted file mode 100644
index c6932226..00000000
--- a/pipeline/spanner/src/main/resources/spanner_schema.sql
+++ /dev/null
@@ -1,34 +0,0 @@
-CREATE PROTO BUNDLE (
-  `org.datacommons.Observations`
-)
-
-CREATE TABLE Node (
-  subject_id STRING(1024) NOT NULL,
-  value STRING(MAX),
-  bytes BYTES(MAX),
-  name STRING(MAX),
-  types ARRAY<STRING(MAX)>,
-  name_tokenlist TOKENLIST AS (TOKENIZE_FULLTEXT(name)) HIDDEN,
-) PRIMARY KEY(subject_id)
-
-CREATE TABLE Edge (
-  subject_id STRING(1024) NOT NULL,
-  predicate STRING(1024) NOT NULL,
-  object_id STRING(1024) NOT NULL,
-  provenance STRING(1024) NOT NULL,
-) PRIMARY KEY(subject_id, predicate, object_id, provenance),
-INTERLEAVE IN Node
-
-CREATE TABLE Observation (
-  observation_about STRING(1024) NOT NULL,
-  variable_measured STRING(1024) NOT NULL,
-  facet_id STRING(1024) NOT NULL,
-  observation_period STRING(1024),
-  measurement_method STRING(1024),
-  unit STRING(1024),
-  scaling_factor STRING(1024),
-  observations org.datacommons.Observations,
-  import_name STRING(1024),
-  provenance_url STRING(1024),
-  is_dc_aggregate BOOL,
-) PRIMARY KEY(observation_about, variable_measured, facet_id)
\ No newline at end of file
diff --git a/pipeline/terraform/main.tf b/pipeline/terraform/main.tf
new file mode 100644
index 00000000..862636af
--- /dev/null
+++ b/pipeline/terraform/main.tf
@@ -0,0 +1,379 @@
+# Terraform deployment for Data Commons Import Automation Workflow
+#
+# Usage:
+# - Authenticate and set up application default credentials for Terraform to access GCP using 'gcloud auth login --update-adc'.
+# - Obtain a DataCommons API key: get an API key from the portal at https://apikeys.datacommons.org/ to be used as the `dc_api_key` variable.
+# - Deploy the infrastructure and resources defined in this configuration using 'terraform apply'.
+# - The output service account needs to have required permissions to access external resources. +# +# Input variables: +# - GCP project id +# - DC API key +# +# This file sets up: +# - Necessary GCP APIs +# - Secret Manager for the import-config secret +# - GCS Buckets for imports, mounting, and Dataflow templates +# - Spanner Instance and Database +# - Artifact Registry for hosting Docker images (Flex Template & Executor) +# - Pub/Sub Topic and Subscription for triggering imports +# - Cloud Functions, Workflows, and Ingestion Pipeline +# - Unified Service Account with necessary IAM roles for Workflows, Functions, and Pub/Sub + +terraform { + required_providers { + google = { + source = "hashicorp/google" + version = ">= 5.0.0" + } + archive = { + source = "hashicorp/archive" + } + } +} + +variable "project_id" { + description = "The GCP Project ID" + type = string +} + +variable "region" { + description = "The GCP Region" + type = string + default = "us-central1" +} + +variable "spanner_instance_id" { + description = "Spanner Instance ID" + type = string + default = "datcom-import-instance" +} + +variable "spanner_database_id" { + description = "Spanner Database ID" + type = string + default = "dc-import-db" +} + +variable "spanner_graph_database_id" { + description = "Spanner Graph Database ID" + type = string + default = "dc-import-db" +} + +variable "bq_dataset_id" { + description = "BigQuery Dataset ID for aggregation" + type = string + default = "datacommons" +} + +variable "dc_api_key" { + description = "Data Commons API Key" + type = string + sensitive = true +} + +variable "artifact_registry_url" { + description = "Artifact Registry URL for Cloud Run images" + type = string + default = "us-docker.pkg.dev/datcom-ci/gcr.io" +} + +# --- APIs --- + +locals { + services = [ + "artifactregistry.googleapis.com", + "batch.googleapis.com", + "cloudbuild.googleapis.com", + "cloudfunctions.googleapis.com", + "cloudscheduler.googleapis.com", + "compute.googleapis.com", + "dataflow.googleapis.com", + "iam.googleapis.com", + "pubsub.googleapis.com", + "run.googleapis.com", + "secretmanager.googleapis.com", + "spanner.googleapis.com", + "storage.googleapis.com", + "workflows.googleapis.com", + ] +} + +resource "google_project_service" "services" { + for_each = toset(local.services) + project = var.project_id + service = each.key + + disable_on_destroy = false +} + +# --- Secret Manager --- + +resource "google_secret_manager_secret" "import_config" { + secret_id = "import-config" + project = var.project_id + + replication { + auto {} + } + + depends_on = [google_project_service.services] +} + +resource "google_secret_manager_secret_version" "import_config_v1" { + secret = google_secret_manager_secret.import_config.id + secret_data = jsonencode({ + dc_api_key = var.dc_api_key + }) +} + +resource "google_secret_manager_secret" "dc_api_key" { + secret_id = "dc-api-key" + project = var.project_id + + replication { + auto {} + } + + depends_on = [google_project_service.services] +} + +resource "google_secret_manager_secret_version" "dc_api_key_v1" { + secret = google_secret_manager_secret.dc_api_key.id + secret_data = var.dc_api_key +} + +# --- GCS Buckets --- + +resource "google_storage_bucket" "import_bucket" { + name = "${var.project_id}-imports" + location = var.region + project = var.project_id + uniform_bucket_level_access = true + + depends_on = [google_project_service.services] +} + +resource "google_storage_bucket" "mount_bucket" { + name = "${var.project_id}-mount" + location = var.region + 
project = var.project_id + uniform_bucket_level_access = true + + depends_on = [google_project_service.services] +} + +# --- Cloud Functions Source Packaging --- + +# --- Cloud Functions --- + +resource "google_cloud_run_v2_service" "ingestion_helper" { + name = "ingestion-helper-service" + location = var.region + project = var.project_id + + template { + service_account = google_service_account.automation_sa.email + containers { + image = "${var.artifact_registry_url}/datacommons-ingestion-helper:latest" + env { + name = "PROJECT_ID" + value = var.project_id + } + env { + name = "SPANNER_PROJECT_ID" + value = var.project_id + } + env { + name = "SPANNER_INSTANCE_ID" + value = var.spanner_instance_id + } + env { + name = "SPANNER_DATABASE_ID" + value = var.spanner_database_id + } + env { + name = "SPANNER_GRAPH_DATABASE_ID" + value = var.spanner_graph_database_id + } + env { + name = "GCS_BUCKET_ID" + value = google_storage_bucket.import_bucket.name + } + env { + name = "LOCATION" + value = var.region + } + env { + name = "BQ_DATASET_ID" + value = var.bq_dataset_id + } + } + } + + depends_on = [google_project_service.services] +} + +resource "google_cloud_run_v2_service" "import_helper" { + name = "import-helper-service" + location = var.region + project = var.project_id + + template { + service_account = google_service_account.automation_sa.email + containers { + image = "${var.artifact_registry_url}/datacommons-import-helper:latest" + env { + name = "PROJECT_ID" + value = var.project_id + } + env { + name = "LOCATION" + value = var.region + } + env { + name = "GCS_BUCKET_ID" + value = google_storage_bucket.import_bucket.name + } + env { + name = "INGESTION_HELPER_URL" + value = google_cloud_run_v2_service.ingestion_helper.uri + } + } + } + + depends_on = [google_project_service.services] +} + +# --- Cloud Workflows --- + +resource "google_workflows_workflow" "import_automation_workflow" { + name = "import-automation-workflow" + region = var.region + project = var.project_id + description = "Orchestrates the import automation process" + service_account = google_service_account.automation_sa.id + source_contents = file("${path.module}/../workflow/import-automation-workflow.yaml") + + user_env_vars = { + LOCATION = var.region + GCS_BUCKET_ID = google_storage_bucket.import_bucket.name + GCS_MOUNT_BUCKET = google_storage_bucket.mount_bucket.name + INGESTION_HELPER_URL = google_cloud_run_v2_service.ingestion_helper.uri + } + + depends_on = [google_project_service.services] +} + +resource "google_workflows_workflow" "spanner_ingestion_workflow" { + name = "spanner-ingestion-workflow" + region = var.region + project = var.project_id + description = "Orchestrates Spanner ingestion" + service_account = google_service_account.automation_sa.id + source_contents = file("${path.module}/../workflow/spanner-ingestion-workflow.yaml") + + user_env_vars = { + LOCATION = var.region + PROJECT_ID = var.project_id + SPANNER_PROJECT_ID = var.project_id + SPANNER_INSTANCE_ID = var.spanner_instance_id + SPANNER_DATABASE_ID = var.spanner_database_id + INGESTION_HELPER_URL = google_cloud_run_v2_service.ingestion_helper.uri + } + + depends_on = [google_project_service.services] +} + +# --- Spanner --- + +resource "google_spanner_instance" "import_instance" { + name = var.spanner_instance_id + config = "regional-${var.region}" + display_name = "Import Automation" + num_nodes = 1 + project = var.project_id + + depends_on = [google_project_service.services] +} + +resource "google_spanner_database" "import_db" { + 
instance = google_spanner_instance.import_instance.name + name = var.spanner_database_id + project = var.project_id + deletion_protection = false +} + +# --- IAM --- + +resource "google_service_account" "automation_sa" { + account_id = "import-automation-sa" + display_name = "Service Account for Import Automation (Workflows & Functions)" + project = var.project_id +} + +resource "google_project_iam_member" "automation_roles" { + for_each = toset([ + "roles/workflows.admin", + "roles/cloudfunctions.admin", + "roles/run.admin", + "roles/run.invoker", + "roles/batch.jobsEditor", + "roles/dataflow.admin", + "roles/logging.logWriter", + "roles/storage.objectAdmin", + "roles/iam.serviceAccountUser", + "roles/spanner.databaseAdmin", + "roles/bigquery.dataEditor", + "roles/bigquery.jobUser", + "roles/artifactregistry.admin", + "roles/secretmanager.secretAccessor", + "roles/cloudbuild.builds.builder", + ]) + project = var.project_id + role = each.key + member = "serviceAccount:${google_service_account.automation_sa.email}" +} + +# --- Artifact Registry --- + +resource "google_artifact_registry_repository" "automation_repo" { + location = var.region + repository_id = "import-automation" + description = "Docker repository for import automation images" + format = "DOCKER" + project = var.project_id + + depends_on = [google_project_service.services] +} + +# --- Pub/Sub --- + +resource "google_pubsub_topic" "import_automation_trigger" { + name = "import-automation-trigger" + project = var.project_id +} + +resource "google_pubsub_subscription" "import_automation_sub" { + name = "import-automation-sub" + topic = google_pubsub_topic.import_automation_trigger.name + project = var.project_id + + filter = "attributes.transfer_status=\"TRANSFER_COMPLETED\"" + + push_config { + push_endpoint = google_cloud_run_v2_service.import_helper.uri + oidc_token { + service_account_email = google_service_account.automation_sa.email + } + } +} + +# Outputs +output "automation_service_account_email" { + value = google_service_account.automation_sa.email + description = "The email of the service account used for import automation." +} + + diff --git a/pipeline/workflow/cloudbuild.yaml b/pipeline/workflow/cloudbuild.yaml new file mode 100644 index 00000000..ed1f8720 --- /dev/null +++ b/pipeline/workflow/cloudbuild.yaml @@ -0,0 +1,50 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Child Cloud Build configuration for deploying to a specific environment. +# Defaults are set to PRODUCTION. Staging builds must override these values. +# +substitutions: + # Production config. 
+ _PROJECT_ID: 'datcom-import-automation-prod' + _SPANNER_PROJECT_ID: 'datcom-store' + _SPANNER_INSTANCE_ID: 'dc-kg-test' + _SPANNER_DATABASE_ID: 'dc_graph_import' + _SPANNER_GRAPH_DATABASE_ID: 'dc_graph_2025_11_07' + _GCS_BUCKET_ID: 'datcom-prod-imports' + _LOCATION: 'us-central1' + _GCS_MOUNT_BUCKET: 'datcom-volume-mount' + _BQ_DATASET_ID: 'datacommons' + _PROJECT_NUMBER: '965988403328' + _AR_REPO_URL: 'us-docker.pkg.dev/datcom-ci/gcr.io' + +steps: +- id: 'ingestion-helper-service' + name: 'gcr.io/cloud-builders/gcloud' + args: ['run', 'deploy', 'ingestion-helper-service', '--image', '${_AR_REPO_URL}/datacommons-ingestion-helper:latest', '--region', '${_LOCATION}', '--project', '${_PROJECT_ID}', '--no-allow-unauthenticated', '--set-env-vars', 'PROJECT_ID=${_PROJECT_ID},LOCATION=${_LOCATION},SPANNER_PROJECT_ID=${_SPANNER_PROJECT_ID},SPANNER_INSTANCE_ID=${_SPANNER_INSTANCE_ID},SPANNER_DATABASE_ID=${_SPANNER_DATABASE_ID},SPANNER_GRAPH_DATABASE_ID=${_SPANNER_GRAPH_DATABASE_ID},GCS_BUCKET_ID=${_GCS_BUCKET_ID},BQ_DATASET_ID=${_BQ_DATASET_ID}'] + +- id: 'import-helper-service' + name: 'gcr.io/cloud-builders/gcloud' + args: ['run', 'deploy', 'import-helper-service', '--image', '${_AR_REPO_URL}/datacommons-import-helper:latest', '--region', '${_LOCATION}', '--project', '${_PROJECT_ID}', '--no-allow-unauthenticated', '--set-env-vars', 'PROJECT_ID=${_PROJECT_ID},LOCATION=${_LOCATION},PROJECT_NUMBER=${_PROJECT_NUMBER},GCS_BUCKET_ID=${_GCS_BUCKET_ID}'] + +- id: 'import-automation-workflow' + name: 'gcr.io/cloud-builders/gcloud' + args: ['workflows', 'deploy', 'import-automation-workflow', '--project', '${_PROJECT_ID}', '--location', '${_LOCATION}', '--source', 'import-automation-workflow.yaml', '--set-env-vars', 'LOCATION=${_LOCATION},GCS_BUCKET_ID=${_GCS_BUCKET_ID},GCS_MOUNT_BUCKET=${_GCS_MOUNT_BUCKET},PROJECT_NUMBER=${_PROJECT_NUMBER}'] + +- id: 'spanner-ingestion-workflow' + name: 'gcr.io/cloud-builders/gcloud' + args: ['workflows', 'deploy', 'spanner-ingestion-workflow', '--project', '${_PROJECT_ID}', '--location', '${_LOCATION}', '--source', 'spanner-ingestion-workflow.yaml', '--set-env-vars', 'LOCATION=${_LOCATION},PROJECT_ID=${_PROJECT_ID},SPANNER_PROJECT_ID=${_SPANNER_PROJECT_ID},SPANNER_INSTANCE_ID=${_SPANNER_INSTANCE_ID},SPANNER_DATABASE_ID=${_SPANNER_GRAPH_DATABASE_ID},PROJECT_NUMBER=${_PROJECT_NUMBER}'] + +options: + logging: CLOUD_LOGGING_ONLY diff --git a/pipeline/workflow/cloudbuild_main.yaml b/pipeline/workflow/cloudbuild_main.yaml new file mode 100644 index 00000000..8f7343a4 --- /dev/null +++ b/pipeline/workflow/cloudbuild_main.yaml @@ -0,0 +1,89 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Parent Cloud Build configuration that orchestrates Staging and Production deployments. +# Usage: gcloud builds submit . 
--config=cloudbuild_main.yaml --project=datcom-ci
+
+substitutions:
+  # Staging Configuration (Overrides defaults in child build)
+  _PROJECT_ID: 'datcom-ci'
+  _SPANNER_PROJECT_ID: 'datcom-ci'
+  _SPANNER_INSTANCE_ID: 'datcom-spanner-test'
+  _SPANNER_DATABASE_ID: 'dc-test-db'
+  _SPANNER_GRAPH_DATABASE_ID: 'dc-test-db'
+  _GCS_BUCKET_ID: 'datcom-ci-test'
+  _GCS_MOUNT_BUCKET: 'datcom-ci-test'
+  _BQ_DATASET_ID: 'datacommons'
+  _LOCATION: 'us-central1'
+  _PROJECT_NUMBER: '879489846695'
+  _AR_REPO_URL: 'us-docker.pkg.dev/datcom-ci/gcr.io'
+
+steps:
+
+# 1. Build and push helper images
+- id: 'build-ingestion-helper'
+  name: 'gcr.io/cloud-builders/gcloud'
+  args: ['builds', 'submit', 'ingestion-helper', '--config', 'ingestion-helper/cloudbuild.yaml', '--substitutions', '_AR_REPO_URL=${_AR_REPO_URL}']
+  dir: 'import-automation/workflow'
+
+- id: 'build-import-helper'
+  name: 'gcr.io/cloud-builders/gcloud'
+  args: ['builds', 'submit', 'import-helper', '--config', 'import-helper/cloudbuild.yaml', '--substitutions', '_AR_REPO_URL=${_AR_REPO_URL}']
+  dir: 'import-automation/workflow'
+
+# 2. Trigger Staging Build (Child)
+# Overrides default (Production) values with Staging values.
+- id: 'deploy-staging'
+  name: 'gcr.io/cloud-builders/gcloud'
+  args:
+    - 'builds'
+    - 'submit'
+    - '.'
+    - '--config=cloudbuild.yaml'
+    - '--project=${_PROJECT_ID}'
+    - '--substitutions=_PROJECT_ID=${_PROJECT_ID},_SPANNER_PROJECT_ID=${_SPANNER_PROJECT_ID},_SPANNER_INSTANCE_ID=${_SPANNER_INSTANCE_ID},_SPANNER_DATABASE_ID=${_SPANNER_DATABASE_ID},_SPANNER_GRAPH_DATABASE_ID=${_SPANNER_GRAPH_DATABASE_ID},_GCS_BUCKET_ID=${_GCS_BUCKET_ID},_LOCATION=${_LOCATION},_GCS_MOUNT_BUCKET=${_GCS_MOUNT_BUCKET},_BQ_DATASET_ID=${_BQ_DATASET_ID},_PROJECT_NUMBER=${_PROJECT_NUMBER}'
+  dir: 'import-automation/workflow'
+
+# 3. Run E2E Tests on Staging
+- id: 'e2e-test-staging'
+  name: 'python:3.11'
+  entrypoint: 'bash'
+  args:
+    - '-c'
+    - |
+      pip install google-cloud-spanner google-cloud-workflows absl-py
+      python spanner_ingestion_test.py
+  env:
+    - 'PROJECT_ID=${_PROJECT_ID}'
+    - 'LOCATION=${_LOCATION}'
+    - 'SPANNER_PROJECT_ID=${_SPANNER_PROJECT_ID}'
+    - 'SPANNER_INSTANCE_ID=${_SPANNER_INSTANCE_ID}'
+    - 'SPANNER_DATABASE_ID=${_SPANNER_DATABASE_ID}'
+    - 'GCS_BUCKET_ID=${_GCS_BUCKET_ID}'
+  dir: 'import-automation/workflow'
+
+# 4. Trigger Production Build (Child)
+# Uses default (Production) values defined in cloudbuild.yaml.
+- id: 'deploy-prod'
+  name: 'gcr.io/cloud-builders/gcloud'
+  args:
+    - 'builds'
+    - 'submit'
+    - '.'
+ - '--config=cloudbuild.yaml' + - '--project=${_PROJECT_ID}' # Build runs in CI project, deploys to Prod + dir: 'import-automation/workflow' + +options: + logging: CLOUD_LOGGING_ONLY diff --git a/pipeline/workflow/import-automation-workflow.yaml b/pipeline/workflow/import-automation-workflow.yaml new file mode 100644 index 00000000..d6198b8c --- /dev/null +++ b/pipeline/workflow/import-automation-workflow.yaml @@ -0,0 +1,120 @@ +main: + params: [args] + steps: + - init: + assign: + - projectId: ${sys.get_env("GOOGLE_CLOUD_PROJECT_ID")} + - region: ${sys.get_env("LOCATION")} + - imageUri: ${default(map.get(args, "imageUri"), "us-docker.pkg.dev/datcom-ci/gcr.io/dc-import-executor:stable")} + - jobId: ${text.replace_all(text.to_lower(text.substring(text.split(args.importName, ":")[1], 0, 50) + "-" + string(int(sys.now()))), "_", "-")} + - importName: ${args.importName} + - importConfig: ${default(map.get(args, "importConfig"), "{}")} + - gcsMountBucket: ${sys.get_env("GCS_MOUNT_BUCKET")} + - gcsImportBucket: ${sys.get_env("GCS_BUCKET_ID")} + - gcsMountPath: "/tmp/gcs" + - helperUrl: ${"https://ingestion-helper-service-" + sys.get_env("PROJECT_NUMBER") + "." + region + ".run.app"} + - startTime: ${sys.now()} + - defaultResources: + machine: "n2-standard-8" + cpu: 8000 + memory: 32768 + disk: 100 + - resources: ${default(map.get(args, "resources"), defaultResources)} + - runIngestion: ${default(map.get(args, "runIngestion"), false)} + - ingestionArgs: + importList: + - ${text.split(importName, ":")[1]} + - runImportJob: + try: + call: googleapis.batch.v1.projects.locations.jobs.create + args: + parent: ${"projects/" + projectId + "/locations/" + region} + jobId: ${jobId} + body: + allocationPolicy: + instances: + - policy: + machineType: ${resources.machine} + provisioningModel: "STANDARD" + bootDisk: + image: "projects/debian-cloud/global/images/family/debian-12" + size_gb: ${resources.disk} + installOpsAgent: true + taskGroups: + taskSpec: + volumes: + - gcs: + remotePath: ${gcsMountBucket} + mountPath: ${gcsMountPath} + computeResource: + cpuMilli: ${resources.cpu} + memoryMib: ${resources.memory} + runnables: + - container: + imageUri: ${imageUri} + commands: + - ${"--import_name=" + importName} + - ${"--import_config=" + importConfig} + environment: + variables: + IMPORT_NAME: ${importName} + BATCH_JOB_NAME: ${jobId} + taskCount: 1 + parallelism: 1 + logsPolicy: + destination: CLOUD_LOGGING + connector_params: + timeout: 604800 #7 days + polling_policy: + initial_delay: 60 + multiplier: 2 + max_delay: 600 + result: importJobResponse + except: + as: e + steps: + - updateImportStatus: + call: http.post + args: + url: ${helperUrl} + auth: + type: OIDC + body: + actionType: 'update_import_status' + jobId: ${jobId} + importName: ${importName} + status: 'FAILURE' + executionTime: ${int(sys.now() - startTime)} + latestVersion: ${"gs://" + gcsImportBucket + "/" + text.replace_all(importName, ":", "/")} + result: functionResponse + - failWorkflow: + raise: ${e} + - updateImportVersion: + call: http.post + args: + url: ${helperUrl} + auth: + type: OIDC + body: + actionType: 'update_import_version' + importName: ${importName} + version: 'STAGING' + override: false + comment: '${"import-workflow:" + sys.get_env("GOOGLE_CLOUD_WORKFLOW_EXECUTION_ID")}' + result: functionResponse + - runIngestion: + switch: + - condition: ${runIngestion} + steps: + - runSpannerIngestion: + call: googleapis.workflowexecutions.v1.projects.locations.workflows.executions.create + args: + parent: ${"projects/" + projectId + 
"/locations/" + region + "/workflows/spanner-ingestion-workflow"} + body: + argument: ${json.encode_to_string(ingestionArgs)} + connector_params: + skip_polling: true + - returnResult: + return: + jobId: ${jobId} + importName: ${importName} \ No newline at end of file diff --git a/pipeline/workflow/import-helper/Dockerfile b/pipeline/workflow/import-helper/Dockerfile new file mode 100644 index 00000000..2473221b --- /dev/null +++ b/pipeline/workflow/import-helper/Dockerfile @@ -0,0 +1,32 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +FROM python:3.12-slim + +# Copy uv binary +COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/ + +# Allow statements and log messages to immediately appear in the logs +ENV PYTHONUNBUFFERED True + +WORKDIR /app + +# Copy local code to the container image. +COPY . . + +# Install production dependencies using uv. +RUN uv pip install --system --no-cache -r requirements.txt + +# Run the functions framework +CMD ["functions-framework", "--target", "handle_feed_event"] diff --git a/pipeline/workflow/import-helper/cloudbuild.yaml b/pipeline/workflow/import-helper/cloudbuild.yaml new file mode 100644 index 00000000..760d7b6b --- /dev/null +++ b/pipeline/workflow/import-helper/cloudbuild.yaml @@ -0,0 +1,31 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +steps: + # Build the container image + - name: 'gcr.io/cloud-builders/docker' + args: ['build', '-t', '${_AR_REPO_URL}/${_IMAGE_NAME}:${_TAG}', '-t', '${_AR_REPO_URL}/${_IMAGE_NAME}:latest', '.'] + + # Push the container image + - name: 'gcr.io/cloud-builders/docker' + args: ['push', '${_AR_REPO_URL}/${_IMAGE_NAME}:${_TAG}'] + +substitutions: + _AR_REPO_URL: 'us-docker.pkg.dev/datcom-ci/gcr.io' + _IMAGE_NAME: 'datacommons-import-helper' + _TAG: 'latest' + +images: + - '${_AR_REPO_URL}/${_IMAGE_NAME}:${_TAG}' + - '${_AR_REPO_URL}/${_IMAGE_NAME}:latest' diff --git a/pipeline/workflow/import-helper/import_helper.py b/pipeline/workflow/import-helper/import_helper.py new file mode 100644 index 00000000..50aeaf08 --- /dev/null +++ b/pipeline/workflow/import-helper/import_helper.py @@ -0,0 +1,197 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import base64 +import json +import logging +import os +import croniter +from datetime import datetime, timezone +from google.auth.transport.requests import Request +from google.oauth2 import id_token +from google.cloud import storage +from google.cloud.workflows import executions_v1 +import requests + +logging.getLogger().setLevel(logging.INFO) + +PROJECT_ID = os.environ.get('PROJECT_ID') +PROJECT_NUMBER = os.environ.get('PROJECT_NUMBER') +LOCATION = os.environ.get('LOCATION') +GCS_BUCKET_ID = os.environ.get('GCS_BUCKET_ID') +INGESTION_HELPER_URL = f"https://ingestion-helper-service-{PROJECT_NUMBER}.{LOCATION}.run.app" +SPANNER_INGESTION_WORKFLOW_ID = 'spanner-ingestion-workflow' +IMPORT_AUTOMATION_WORKFLOW_ID = 'import-automation-workflow' + + +def invoke_spanner_ingestion_workflow(import_name: str): + """Triggers the spanner ingestion workflow. + + Args: + import_name: The name of the import. + """ + workflow_args = {"importList": [import_name.split(':')[-1]]} + + logging.info(f"Invoking {SPANNER_INGESTION_WORKFLOW_ID} for {import_name}") + execution_client = executions_v1.ExecutionsClient() + parent = f"projects/{PROJECT_ID}/locations/{LOCATION}/workflows/{SPANNER_INGESTION_WORKFLOW_ID}" + execution_req = executions_v1.Execution(argument=json.dumps(workflow_args)) + response = execution_client.create_execution(parent=parent, + execution=execution_req) + logging.info( + f"Triggered workflow {SPANNER_INGESTION_WORKFLOW_ID} for {import_name}. Execution ID: {response.name}" + ) + + +def invoke_import_automation_workflow(import_name: str, + latest_version: str, + import_size: str, + graph_path: str, cron_schedule: str, + run_ingestion: bool = False): + """Triggers the import automation workflow. + + Args: + import_name: The name of the import. + latest_version: The version of the import. + import_size: The size of the import ('small', 'medium', 'large'). + graph_path: The graph path for the import. + cron_schedule: The cron schedule for the import. + run_ingestion: Whether to run the ingestion workflow after the import. + """ + import_config = { + "user_script_args": [f"--version={latest_version}"], + "import_version_override": latest_version, + "graph_data_path": graph_path, + "cron_schedule_override": cron_schedule + } + workflow_args = { + "importName": import_name, + "importConfig": json.dumps(import_config), + "runIngestion": run_ingestion + } + + if import_size == 'large': + workflow_args["resources"] = { + "machine": "n2-highmem-16", + "cpu": 16000, + "memory": 131072, + "disk": 100 + } + + logging.info(f"Invoking {IMPORT_AUTOMATION_WORKFLOW_ID} for {import_name}") + execution_client = executions_v1.ExecutionsClient() + parent = f"projects/{PROJECT_ID}/locations/{LOCATION}/workflows/{IMPORT_AUTOMATION_WORKFLOW_ID}" + execution_req = executions_v1.Execution(argument=json.dumps(workflow_args)) + response = execution_client.create_execution(parent=parent, + execution=execution_req) + logging.info( + f"Triggered workflow {IMPORT_AUTOMATION_WORKFLOW_ID} for {import_name}. 
Execution ID: {response.name}" + ) + + +def update_import_status(import_name, + import_status, + import_version, + graph_path, + job_id, + cron_schedule=None): + """Updates the status for the specified import job. + + Args: + import_name: The name of the import. + import_status: The new status of the import. + import_version: The version of the import. + graph_path: The graph path for the import. + job_id: The job ID associated with the import. + cron_schedule: The cron schedule for the import (optional). + """ + logging.info(f"Updating {import_name} status: {import_status}") + latest_version = 'gs://' + GCS_BUCKET_ID + '/' + import_name.replace( + ':', '/') + '/' + import_version + request = { + 'actionType': 'update_import_status', + 'importName': import_name, + 'status': import_status, + 'job_id': job_id, + 'latestVersion': latest_version, + 'graphPath': graph_path + } + if cron_schedule: + try: + next_refresh = croniter.croniter( + cron_schedule, + datetime.now(timezone.utc)).get_next(datetime).isoformat() + request['nextRefresh'] = next_refresh + except (croniter.CroniterError) as e: + logging.error( + f"Error calculating next refresh from schedule '{cron_schedule}': {e}" + ) + logging.info(f"Update request: {request}") + auth_req = Request() + token = id_token.fetch_id_token(auth_req, INGESTION_HELPER_URL) + headers = {'Authorization': f'Bearer {token}'} + response = requests.post(INGESTION_HELPER_URL, + json=request, + headers=headers) + response.raise_for_status() + logging.info(f"Updated status for {import_name}") + + +def parse_message(request) -> dict: + """Processes the incoming Pub/Sub message. + + Args: + request: The flask request object. + + Returns: + A dictionary containing the message data, or None if invalid. + """ + request_json = request.get_json(silent=True) + if not request_json or 'message' not in request_json: + logging.error('Invalid Pub/Sub message format') + return None + + pubsub_message = request_json['message'] + logging.info(f"Received Pub/Sub message: {pubsub_message}") + try: + data_bytes = base64.b64decode(pubsub_message["data"]) + notification_json = data_bytes.decode("utf-8") + logging.info(f"Notification content: {notification_json}") + except Exception as e: + logging.error(f"Error decoding message data: {e}") + + return pubsub_message + + +def check_duplicate(message_id: str): + """Checks for duplicate messages using a GCS file. + + Args: + message_id: The ID of the message to check. + + Returns: + True if the message is a duplicate, False otherwise. + """ + duplicate = False + if not message_id: + return duplicate + logging.info(f"Checking for existing message: {message_id}") + storage_client = storage.Client() + bucket = storage_client.bucket(GCS_BUCKET_ID) + blob = bucket.blob(f"google3/transfers/{message_id}") + try: + blob.upload_from_string("", if_generation_match=0) + except Exception: + duplicate = True + return duplicate diff --git a/pipeline/workflow/import-helper/main.py b/pipeline/workflow/import-helper/main.py new file mode 100644 index 00000000..c825ec2b --- /dev/null +++ b/pipeline/workflow/import-helper/main.py @@ -0,0 +1,67 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functions_framework +import logging +from datetime import datetime, timezone +import import_helper as helper + +logging.getLogger().setLevel(logging.INFO) + + +# Triggered from a message on a Cloud Pub/Sub topic. +@functions_framework.http +def handle_feed_event(request): + # Updates status in spanner and triggers ingestion workflow + # for an import using CDA feed + message = helper.parse_message(request) + if not message: + return 'Invalid Pub/Sub message format', 400 + + attributes = message.get('attributes', {}) + message_id = message.get('messageId', '') + if attributes.get('transfer_status') != 'TRANSFER_COMPLETED': + return 'OK', 200 + + duplicate = helper.check_duplicate(message_id) + if duplicate: + logging.info(f"Message {message_id} already processed. Skipping.") + return 'OK', 200 + + import_name = attributes.get('import_name') + latest_version = attributes.get( + 'import_version', + datetime.now(timezone.utc).strftime("%Y-%m-%d")) + import_step = attributes.get('import_step', '') + graph_path = attributes.get('graph_path', "/**/*.mcf*") + import_size = attributes.get('import_size', '') + cron_schedule = attributes.get('cron_schedule', '') + if import_step == 'ingestion_workflow_single' or import_step == 'ingestion_workflow_batch': + import_status = 'STAGING' + job_id = attributes.get('feed_name', 'cda_feed') + helper.update_import_status(import_name, import_status, latest_version, + graph_path, job_id, cron_schedule) + if import_step == 'ingestion_workflow_single': + # Invoke ingestion workflow to trigger dataflow job + helper.invoke_spanner_ingestion_workflow(import_name) + elif import_step == 'import_automation_job' or import_step == 'import_automation_e2e': + # Invoke batch import job and optionally ingestion workflow to trigger dataflow job + run_ingestion = True if import_step == 'import_automation_e2e' else False + helper.invoke_import_automation_workflow(import_name, latest_version, + import_size, graph_path, + cron_schedule, run_ingestion) + else: + logging.info(f"Skipping import post processing.") + + return 'OK', 200 diff --git a/pipeline/workflow/import-helper/requirements.txt b/pipeline/workflow/import-helper/requirements.txt new file mode 100644 index 00000000..9d321e81 --- /dev/null +++ b/pipeline/workflow/import-helper/requirements.txt @@ -0,0 +1,6 @@ +functions-framework==3.* +google-cloud-workflows +google-auth +requests +google-cloud-storage +croniter diff --git a/pipeline/workflow/ingestion-helper/Dockerfile b/pipeline/workflow/ingestion-helper/Dockerfile new file mode 100644 index 00000000..7f95ae8b --- /dev/null +++ b/pipeline/workflow/ingestion-helper/Dockerfile @@ -0,0 +1,41 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +FROM python:3.12-slim + +# Copy uv binary +COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/ + +# Allow statements and log messages to immediately appear in the logs +ENV PYTHONUNBUFFERED True + +WORKDIR /app + +# Install protobuf compiler and curl +RUN apt-get update && apt-get install -y protobuf-compiler curl && rm -rf /var/lib/apt/lists/* + +# Copy local code to the container image. +COPY . . + +# Install production dependencies using uv. +RUN uv pip install --system --no-cache . + +# Fetch proto file from GitHub +RUN curl -o storage.proto https://raw.githubusercontent.com/datacommonsorg/import/master/pipeline/data/src/main/proto/storage.proto + +# Generate proto descriptor set +RUN protoc --include_imports --descriptor_set_out=storage.pb storage.proto + +# Run the functions framework +CMD ["functions-framework", "--target", "ingestion_helper"] diff --git a/pipeline/workflow/ingestion-helper/README.md b/pipeline/workflow/ingestion-helper/README.md new file mode 100644 index 00000000..7de6d49a --- /dev/null +++ b/pipeline/workflow/ingestion-helper/README.md @@ -0,0 +1,98 @@ +# Ingestion Helper Cloud Function + +This Cloud Function provides helper routines for the Data Commons Spanner ingestion workflow. It handles tasks such as locking, status updates, and import list retrieval. + +## Usage + +The function expects a JSON payload with a required `actionType` parameter, which determines the operation to perform. + +### Common Parameters + +* `actionType` (Required): A string specifying the action to execute. + +### Supported Actions and Parameters + +#### `get_import_info` +Gets the details of imports that are ready for ingestion. + +* `importList` (Optional): list of imports to ingest. + +#### `acquire_ingestion_lock` +Attempts to acquire the global lock for ingestion to prevent concurrent modifications. + +* `workflowId` (Required): The ID of the workflow attempting to acquire the lock. +* `timeout` (Required): The duration (in seconds) for which the lock should be held. + +#### `release_ingestion_lock` +Releases the global ingestion lock. + +* `workflowId` (Required): The ID of the workflow releasing the lock. + +#### `update_ingestion_status` +Updates the status of imports after an ingestion job completes. + +* `importList` (Required): A list of import names involved in the ingestion. +* `workflowId` (Required): The ID of the workflow. +* `status` (Required): Import status. +* `jobId` (Required): The Dataflow job ID associated with the ingestion. + +#### `update_import_status` +Updates the status of a specific import job. + +* `importName` (Required): The name of the import. +* `status` (Required): The new status to set. +* `jobId` (Optional): The Dataflow job ID. +* `executionTime` (Optional): Execution time in seconds. +* `dataVolume` (Optional): Data volume in bytes. +* `latestVersion` (Optional): Latest version string. +* `graphPath` (Optional): Graph path regex. +* `nextRefresh` (Optional): Next refresh timestamp. + + +#### `update_import_version` +Updates the version of an import, records version history, and updates the status. 
+ +* `importName` (Required): The name of the import. +* `version` (Required): The version string. If set to `'STAGING'`, it resolves to the current staging version. +* `comment` (Required): A comment for the audit log explaining the version update. +* `override` (Optional): Override version without checking import status (boolean) + +#### `initialize_database` +Initializes the Spanner database by creating all necessary tables and uploading proto descriptors. + +* This action requires no payload parameters. It automatically reads `schema.sql` and `storage.pb` from the container directory to provision the database schema and proto descriptors. +* `enableEmbeddings` (Optional): Boolean to enable creation of embedding tables and models. +* **Note on Protos**: The `storage.pb` file is generated during the Docker build process. The `Dockerfile` fetches `storage.proto` from the `datacommonsorg/import` GitHub repository and compiles it into `storage.pb`. + +#### `embedding_ingestion` +Triggers the generation of embeddings for updated nodes in Spanner. It fetches nodes of specific types (e.g., `StatisticalVariable`, `Topic`) that have been updated, generates embeddings using a remote ML model in Spanner, and stores the results in the `NodeEmbedding` table. + +* `enableEmbeddings` (Optional): Boolean to override the default setting for enabling embeddings. If false or missing and default is false, it skips embedding generation. +* **Flags**: + - `--node_types`: A comma-separated list of node types to process (default: `StatisticalVariable,Topic`). This is a command-line flag for the service, not a request parameter. + +## Local Development and Testing + +To run the helper service locally and test its functionality: + +### Running the Server +Ensure you have installed the requirements (`uv pip install -r requirements.txt`), then start the functions framework: + +```bash +uv run functions-framework --target ingestion_helper +``` +By default, this will start serving on `http://localhost:8080`. + +### Triggering Actions +You can test specific actions by sending a POST request with a JSON payload. For example, to trigger database initialization locally: +```bash +curl -X POST http://localhost:8080 \ + -H "Content-Type: application/json" \ + -d '{"actionType": "initialize_database"}' +``` +### Running unit tests +Run unit tests with uv using: + +```bash +uv run pytest +``` diff --git a/pipeline/workflow/ingestion-helper/aggregation_utils.py b/pipeline/workflow/ingestion-helper/aggregation_utils.py new file mode 100644 index 00000000..16fe06c7 --- /dev/null +++ b/pipeline/workflow/ingestion-helper/aggregation_utils.py @@ -0,0 +1,69 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
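+"""Helpers for running per-import BigQuery aggregation queries in the ingestion helper service."""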
+ +import logging +import os +from google.cloud import bigquery + +logging.getLogger().setLevel(logging.INFO) + +class AggregationUtils: + def __init__(self): + # Initialize BigQuery Client + try: + self.bq_client = bigquery.Client() + except Exception as e: + logging.warning(f"Failed to initialize BigQuery client: {e}") + self.bq_client = None + + self.bq_dataset_id = os.environ.get('BQ_DATASET_ID') + + def run_aggregation(self, import_list): + """ + Runs a BQ query for each import in the import_list. + """ + logging.info(f"Received request for importList: {import_list}") + results = [] + if not self.bq_client: + logging.error("BigQuery client not initialized") + return False + + try: + for import_item in import_list: + import_name = import_item.get('importName') + + query = None + # Define specific queries based on importName + if import_name: + query = """ + SELECT @import_name as import_name, CURRENT_TIMESTAMP() as execution_time + """ + else: + logging.info('Skipping aggregation logic') + continue + + if query: + job_config = bigquery.QueryJobConfig(query_parameters=[ + bigquery.ScalarQueryParameter("import_name", "STRING", + import_name), + ]) + query_job = self.bq_client.query(query, job_config=job_config) + query_results = query_job.result() + for row in query_results: + results.append(dict(row)) + return True + + except Exception as e: + logging.error(f"Aggregation failed: {e}") + raise e diff --git a/pipeline/workflow/ingestion-helper/cloudbuild.yaml b/pipeline/workflow/ingestion-helper/cloudbuild.yaml new file mode 100644 index 00000000..632b3bf1 --- /dev/null +++ b/pipeline/workflow/ingestion-helper/cloudbuild.yaml @@ -0,0 +1,31 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +steps: + # Build the container image + - name: 'gcr.io/cloud-builders/docker' + args: ['build', '-t', '${_AR_REPO_URL}/${_IMAGE_NAME}:${_TAG}', '-t', '${_AR_REPO_URL}/${_IMAGE_NAME}:latest', '.'] + + # Push the container image + - name: 'gcr.io/cloud-builders/docker' + args: ['push', '${_AR_REPO_URL}/${_IMAGE_NAME}:${_TAG}'] + +substitutions: + _AR_REPO_URL: 'us-docker.pkg.dev/datcom-ci/gcr.io' + _IMAGE_NAME: 'datacommons-ingestion-helper' + _TAG: 'latest' + +images: + - '${_AR_REPO_URL}/${_IMAGE_NAME}:${_TAG}' + - '${_AR_REPO_URL}/${_IMAGE_NAME}:latest' diff --git a/pipeline/workflow/ingestion-helper/embedding_schema.sql b/pipeline/workflow/ingestion-helper/embedding_schema.sql new file mode 100644 index 00000000..50d57ce3 --- /dev/null +++ b/pipeline/workflow/ingestion-helper/embedding_schema.sql @@ -0,0 +1,44 @@ +-- Copyright 2026 Google LLC +-- +-- Licensed under the Apache License, Version 2.0 (the "License") +-- you may not use this file except in compliance with the License. 
+-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +CREATE TABLE NodeEmbedding ( + subject_id STRING(1024) NOT NULL, + embedding_content STRING(MAX), + types ARRAY, + embeddings ARRAY(vector_length=>768) +) PRIMARY KEY(subject_id), +INTERLEAVE IN PARENT Node ON DELETE CASCADE; + +CREATE VECTOR INDEX NodeEmbeddingIndex +ON NodeEmbedding(embeddings) +WHERE embeddings IS NOT NULL +OPTIONS ( + distance_type = 'COSINE', + flat_index = true +); + +CREATE MODEL NodeEmbeddingModel +INPUT( + content STRING(MAX), + task_type STRING(MAX), +) +OUTPUT( + embeddings + STRUCT< + statistics STRUCT, + values ARRAY> +) +REMOTE OPTIONS ( + endpoint = '{{ embeddings_endpoint }}' +); diff --git a/pipeline/workflow/ingestion-helper/embedding_utils.py b/pipeline/workflow/ingestion-helper/embedding_utils.py new file mode 100644 index 00000000..333705e3 --- /dev/null +++ b/pipeline/workflow/ingestion-helper/embedding_utils.py @@ -0,0 +1,169 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Helper utilities for embedding workflows.""" + +import itertools +import logging +import time +from datetime import datetime +from google.cloud.spanner_v1.param_types import TIMESTAMP, STRING, Array, Struct, StructField + + +_BATCH_SIZE = 1000 + +def get_latest_lock_timestamp(database): + """Gets the latest AcquiredTimestamp from IngestionLock table. + + Args: + database: google.cloud.spanner.Database object. + + Returns: + The latest AcquiredTimestamp as a datetime object, or None if no entries exist. + """ + time_lock_sql = "SELECT MAX(AcquiredTimestamp) FROM IngestionLock" + try: + with database.snapshot() as snapshot: + results = snapshot.execute_sql(time_lock_sql) + for row in results: + return row[0] + except Exception as e: + logging.error(f"Error fetching latest lock timestamp: {e}") + raise + return None + +def get_updated_nodes(database, timestamp, node_types): + """Gets subject_ids and names from Node table where update_timestamp > timestamp. + Yields results to avoid loading all into memory. + + Args: + database: google.cloud.spanner.Database object. + timestamp: datetime object to filter by. + node_types: A list of strings representing the node types to filter by. + + Yields: + Dictionaries containing subject_id and name. 
+ """ + timestamp_condition = "update_timestamp > @timestamp" if timestamp else "TRUE" + + updated_node_sql = f""" + SELECT subject_id, name, types FROM Node + WHERE name IS NOT NULL + AND {timestamp_condition} + AND EXISTS ( + SELECT 1 FROM UNNEST(types) AS t WHERE t IN UNNEST(@node_types) + ) + """ + + params = {"node_types": node_types} + param_types = {"node_types": Array(STRING)} + + if timestamp: + logging.info(f"Filtering valid nodes updated after {timestamp}") + params["timestamp"] = timestamp + param_types["timestamp"] = TIMESTAMP + else: + logging.info("No timestamp provided, reading all valid nodes.") + + try: + with database.snapshot() as snapshot: + results = snapshot.execute_sql(updated_node_sql, params=params, param_types=param_types, timeout=300) + fields = None + for row in results: + if fields is None: + fields = [field.name for field in results.fields] + yield dict(zip(fields, row)) + except Exception as e: + logging.error(f"Error fetching updated nodes: {e}") + raise + + +def filter_and_convert_nodes(nodes_generator): + """Filters out nodes without a name and converts dictionaries to tuples. + Reads from a generator and yields results. + + Args: + nodes_generator: A generator yielding dictionaries containing subject_id, name, and types. + + Yields: + Tuples (subject_id, embedding_content, types). + """ + for node in nodes_generator: + if node.get("name"): + yield (node.get("subject_id"), node.get("name"), node.get("types")) + + +def generate_embeddings_partitioned(database, nodes_generator): + """Generates embeddings in batches using standard transactions. + Processes nodes in chunks of 500 to avoid transaction size limits. + Accepts a generator to avoid loading all nodes into memory. + + Args: + database: google.cloud.spanner.Database object. + nodes_generator: A generator yielding tuples containing (subject_id, embedding_content). + + Returns: + The number of affected rows. + """ + global _BATCH_SIZE + total_rows_affected = 0 + + logging.info(f"Generating embeddings in batches of {_BATCH_SIZE}.") + + embeddings_sql = """ + INSERT OR UPDATE INTO NodeEmbedding (subject_id, embedding_content, embeddings, types) + SELECT subject_id, content, embeddings.values, types + FROM ML.PREDICT( + MODEL NodeEmbeddingModel, + (SELECT subject_id, embedding_content AS content, types, "RETRIEVAL_QUERY" AS task_type FROM UNNEST(@nodes)) + ) + """ + + struct_type = Struct([ + StructField("subject_id", STRING), + StructField("embedding_content", STRING), + StructField("types", Array(STRING)) + ]) + + def chunked(iterable, n): + it = iter(iterable) + while True: + chunk = list(itertools.islice(it, n)) + if not chunk: + break + yield chunk + + for batch in chunked(nodes_generator, _BATCH_SIZE): + params = {"nodes": batch} + param_types = {"nodes": Array(struct_type)} + + def _execute_dml(transaction): + return transaction.execute_update(embeddings_sql, params=params, param_types=param_types, timeout=300) + + try: + row_count = database.run_in_transaction(_execute_dml) + total_rows_affected += row_count + logging.info(f"Processed batch of {len(batch)} nodes. Affected total {total_rows_affected} rows.") + time.sleep(0.5) + except Exception as e: + logging.error(f"Error executing batch transaction: {e}") + raise + + logging.info(f"Completed batch processing. 
Total affected rows: {total_rows_affected}") + return total_rows_affected + + + + + diff --git a/pipeline/workflow/ingestion-helper/embedding_utils_test.py b/pipeline/workflow/ingestion-helper/embedding_utils_test.py new file mode 100644 index 00000000..299b293d --- /dev/null +++ b/pipeline/workflow/ingestion-helper/embedding_utils_test.py @@ -0,0 +1,166 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from unittest.mock import MagicMock, patch +from datetime import datetime + +from embedding_utils import ( + get_latest_lock_timestamp, + get_updated_nodes, + filter_and_convert_nodes, + generate_embeddings_partitioned +) + +class TestEmbeddingUtils(unittest.TestCase): + + def test_get_latest_lock_timestamp(self): + mock_database = MagicMock() + mock_snapshot = MagicMock() + mock_database.snapshot.return_value.__enter__.return_value = mock_snapshot + expected_timestamp = datetime(2026, 4, 20, 12, 0, 0) + mock_snapshot.execute_sql.return_value = [(expected_timestamp,)] + + timestamp = get_latest_lock_timestamp(mock_database) + self.assertEqual(timestamp, expected_timestamp) + + def test_get_updated_nodes(self): + mock_database = MagicMock() + mock_snapshot = MagicMock() + mock_database.snapshot.return_value.__enter__.return_value = mock_snapshot + + class MockField: + def __init__(self, name): + self.name = name + + class MockResults: + def __init__(self, rows, field_names): + self.rows = rows + self.fields = [MockField(name) for name in field_names] + + def __iter__(self): + return iter(self.rows) + + mock_snapshot.execute_sql.return_value = MockResults( + rows=[("dc/1", "Node 1", ["Topic"])], + field_names=["subject_id", "name", "types"] + ) + + nodes = list(get_updated_nodes(mock_database, None, ["Topic"])) + + # Verify Spanner call + mock_snapshot.execute_sql.assert_called_once() + args, kwargs = mock_snapshot.execute_sql.call_args + query = args[0] + self.assertIn("SELECT subject_id, name, types FROM Node", query) + self.assertIn("TRUE", query) + self.assertEqual(kwargs["params"], {"node_types": ["Topic"]}) + + self.assertEqual(len(nodes), 1) + self.assertEqual(nodes[0]["subject_id"], "dc/1") + self.assertEqual(nodes[0]["name"], "Node 1") + self.assertEqual(nodes[0]["types"], ["Topic"]) + + def test_get_updated_nodes_with_timestamp(self): + mock_database = MagicMock() + mock_snapshot = MagicMock() + mock_database.snapshot.return_value.__enter__.return_value = mock_snapshot + + class MockField: + def __init__(self, name): + self.name = name + + class MockResults: + def __init__(self, rows, field_names): + self.rows = rows + self.fields = [MockField(name) for name in field_names] + + def __iter__(self): + return iter(self.rows) + + mock_snapshot.execute_sql.return_value = MockResults( + rows=[("dc/2", "Node 2", ["Topic"])], + field_names=["subject_id", "name", "types"] + ) + + test_timestamp = datetime(2026, 4, 25, 0, 0, 0) + nodes = list(get_updated_nodes(mock_database, test_timestamp, ["Topic"])) + + # Verify Spanner call + 
mock_snapshot.execute_sql.assert_called_once() + args, kwargs = mock_snapshot.execute_sql.call_args + query = args[0] + self.assertIn("SELECT subject_id, name, types FROM Node", query) + self.assertIn("update_timestamp > @timestamp", query) + self.assertEqual(kwargs["params"], {"node_types": ["Topic"], "timestamp": test_timestamp}) + + self.assertEqual(len(nodes), 1) + self.assertEqual(nodes[0]["subject_id"], "dc/2") + + def test_filter_and_convert_nodes(self): + nodes = [ + {"subject_id": "dc/1", "name": "Node 1", "types": ["Topic"]}, + {"subject_id": "dc/2", "name": None, "types": ["StatisticalVariable"]}, + {"subject_id": "dc/3", "name": "Node 3", "types": ["Topic", "StatisticalVariable"]}, + {"subject_id": "dc/4", "name": "", "types": ["StatisticalVariable"]} + ] + + converted = list(filter_and_convert_nodes(nodes)) + self.assertEqual(len(converted), 2) + self.assertEqual(converted[0], ("dc/1", "Node 1", ["Topic"])) + self.assertEqual(converted[1], ("dc/3", "Node 3", ["Topic", "StatisticalVariable"])) + + @patch('embedding_utils._BATCH_SIZE', 2) + def test_generate_embeddings_partitioned(self): + mock_database = MagicMock() + + nodes = [ + ("dc/1", "Node 1", ["Topic"]), + ("dc/2", "Node 2", ["Topic"]), + ("dc/3", "Node 3", ["Topic"]), + ("dc/4", "Node 4", ["Topic"]), + ("dc/5", "Node 5", ["Topic"]), + ("dc/6", "Node 6", ["Topic"]), + ("dc/7", "Node 7", ["Topic"]), + ("dc/8", "Node 8", ["Topic"]) + ] + + transactions = [] + def side_effect(func): + mock_transaction = MagicMock() + mock_transaction.execute_update.return_value = 2 + transactions.append(mock_transaction) + return func(mock_transaction) + + mock_database.run_in_transaction.side_effect = side_effect + + affected_rows = generate_embeddings_partitioned(mock_database, nodes) + self.assertEqual(affected_rows, 8) + self.assertEqual(mock_database.run_in_transaction.call_count, 4) + + # Verify execute_update calls + self.assertEqual(len(transactions), 4) + for i, tx in enumerate(transactions): + tx.execute_update.assert_called_once() + args, kwargs = tx.execute_update.call_args + self.assertIn("INSERT OR UPDATE INTO NodeEmbedding", args[0]) + + # Verify batch content + batch = kwargs["params"]["nodes"] + self.assertEqual(len(batch), 2) + self.assertEqual(batch[0][0], f"dc/{i*2 + 1}") + self.assertEqual(batch[1][0], f"dc/{i*2 + 2}") + +if __name__ == '__main__': + unittest.main() diff --git a/pipeline/workflow/ingestion-helper/import_utils.py b/pipeline/workflow/ingestion-helper/import_utils.py new file mode 100644 index 00000000..33f9d1fa --- /dev/null +++ b/pipeline/workflow/ingestion-helper/import_utils.py @@ -0,0 +1,175 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Utility functions for the ingestion helper.""" + +import logging +import re +from datetime import datetime, timezone +from googleapiclient.discovery import build +from googleapiclient.errors import HttpError +from google.oauth2 import id_token +from google.auth.transport import requests +from google.auth import jwt + + +def get_next_refresh(project_id: str, location: str, import_name: str) -> str: + """Fetches the next scheduled run time for the import job from Cloud Scheduler. + + Args: + project_id: The GCP project ID. + location: The location of the Cloud Scheduler job. + import_name: The name of the import (used as the job name). + + Returns: + The next scheduled run time as an ISO formatted string, or None if not found/error. + """ + try: + scheduler = build('cloudscheduler', 'v1', cache_discovery=False) + job_id = import_name.split(':')[-1] + job_name = f"projects/{project_id}/locations/{location}/jobs/{job_id}" + job = scheduler.projects().locations().jobs().get( + name=job_name).execute() + return job.get('scheduleTime') + except HttpError as e: + logging.warning(f"Could not fetch scheduler job {import_name}: {e}") + return None + + +def get_caller_identity(request): + """Extracts the caller's email from the Authorization header (JWT). + + Args: + request: The HTTP request object. + + Returns: + The email of the caller, or an error string/warning if extraction fails. + """ + auth_header = request.headers.get('Authorization') + if auth_header: + parts = auth_header.split() + if len(parts) == 2 and parts[0].lower() == 'bearer': + token = parts[1] + unverified_claims = {} + try: + unverified_claims = jwt.decode(token, verify=False) + id_info = id_token.verify_oauth2_token(token, + requests.Request()) + return id_info.get('email', 'unknown_email') + except Exception as e: + if unverified_claims: + logging.warning( + f"Could not decode unverified token for debugging: {e}") + email = unverified_claims.get('email', 'unknown_email') + return f"{email}" + return 'decode_error' + else: + logging.warning( + f"Invalid Authorization header format. Parts: {len(parts)}") + else: + logging.warning("No Authorization header received.") + return 'no_auth_header' + + +def get_import_params(request) -> dict: + """Extracts and calculates import parameters from the request JSON. + + Args: + request_json: A dictionary containing request parameters. + + Returns: + A dictionary with import params. + """ + # Convert CamelCase or mixedCase to snake_case. 
+ request_json = { + re.sub(r'(?, + last_update_timestamp TIMESTAMP OPTIONS (allow_commit_timestamp=true), + name_tokenlist TOKENLIST AS (TOKENIZE_FULLTEXT(name)) HIDDEN, +) PRIMARY KEY(subject_id); + +CREATE TABLE Edge ( + subject_id STRING(1024) NOT NULL, + predicate STRING(1024) NOT NULL, + object_id STRING(1024) NOT NULL, + provenance STRING(1024) NOT NULL, +) PRIMARY KEY(subject_id, predicate, object_id, provenance), +INTERLEAVE IN Node; + +CREATE TABLE Observation ( + observation_about STRING(1024) NOT NULL, + variable_measured STRING(1024) NOT NULL, + facet_id STRING(1024) NOT NULL, + observation_period STRING(1024), + measurement_method STRING(1024), + unit STRING(1024), + scaling_factor STRING(1024), + observations org.datacommons.Observations, + import_name STRING(1024), + provenance_url STRING(1024), + is_dc_aggregate BOOL, +) PRIMARY KEY(observation_about, variable_measured, facet_id); + +CREATE TABLE ImportStatus ( + ImportName STRING(MAX) NOT NULL, + LatestVersion STRING(MAX), + GraphPath STRING(MAX), + State STRING(1024) NOT NULL, + JobId STRING(1024), + WorkflowId STRING(1024), + ExecutionTime INT64, + DataVolume INT64, + DataImportTimestamp TIMESTAMP OPTIONS ( allow_commit_timestamp = TRUE ), + StatusUpdateTimestamp TIMESTAMP OPTIONS ( allow_commit_timestamp = TRUE ), + NextRefreshTimestamp TIMESTAMP, +) PRIMARY KEY(ImportName); + +CREATE TABLE IngestionHistory ( + CompletionTimestamp TIMESTAMP NOT NULL OPTIONS ( allow_commit_timestamp = TRUE ), + IngestionFailure Bool NOT NULL, + WorkflowExecutionID STRING(1024) NOT NULL, + DataflowJobID STRING(1024), + IngestedImports ARRAY, + ExecutionTime INT64, + NodeCount INT64, + EdgeCount INT64, + ObservationCount INT64, +) PRIMARY KEY(CompletionTimestamp DESC); + +CREATE TABLE ImportVersionHistory ( + ImportName STRING(MAX) NOT NULL, + Version STRING(MAX) NOT NULL, + UpdateTimestamp TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp=true), + Comment STRING(MAX), +) PRIMARY KEY (ImportName, UpdateTimestamp DESC); + +CREATE TABLE IngestionLock ( + LockID STRING(1024) NOT NULL, + LockOwner STRING(1024), + AcquiredTimestamp TIMESTAMP OPTIONS ( allow_commit_timestamp = TRUE ), +) PRIMARY KEY(LockID); + +CREATE PROPERTY GRAPH DCGraph + NODE TABLES( + Node + KEY(subject_id) + LABEL Node PROPERTIES( + bytes, + name, + subject_id, + types, + value) + ) + EDGE TABLES( + Edge + KEY(subject_id, predicate, object_id, provenance) + SOURCE KEY(subject_id) REFERENCES Node(subject_id) + DESTINATION KEY(object_id) REFERENCES Node(subject_id) + LABEL Edge PROPERTIES( + object_id, + predicate, + provenance, + subject_id) + ); + +CREATE TABLE Cache ( + type STRING(1024) NOT NULL, + key STRING(1024) NOT NULL, + provenance STRING(1024) NOT NULL, + value JSON, +) PRIMARY KEY(type, key, provenance); + +CREATE TABLE VariableMetadata ( + variable_measured STRING(1024) NOT NULL, + import_name STRING(1024) NOT NULL, + facet_id STRING(1024) NOT NULL, + observation_period STRING(1024), + measurement_method STRING(1024), + unit STRING(1024), + scaling_factor STRING(1024), + is_dc_aggregate BOOL, + total_observations INT64, + observed_places INT64, + min_date STRING(1024), + max_date STRING(1024), + place_types ARRAY, +) PRIMARY KEY(variable_measured, import_name); + + diff --git a/pipeline/workflow/ingestion-helper/spanner_client.py b/pipeline/workflow/ingestion-helper/spanner_client.py new file mode 100644 index 00000000..27b30a08 --- /dev/null +++ b/pipeline/workflow/ingestion-helper/spanner_client.py @@ -0,0 +1,557 @@ +# Copyright 2025 Google LLC +# +# 
Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +from google.cloud import spanner +from google.cloud.spanner_admin_database_v1 import DatabaseAdminClient +from google.cloud.spanner_admin_database_v1.types import UpdateDatabaseDdlRequest +from google.cloud.spanner_v1 import Transaction +from google.cloud.spanner_v1.param_types import STRING, TIMESTAMP, Array, INT64 +from datetime import datetime, timezone +from jinja2 import Template + +logging.getLogger().setLevel(logging.INFO) + + +class SpannerClient: + """ + Spanner client to handle tasks like acquiring/releasing lock + and getting/updating import statuses. + """ + _LOCK_ID = "global_ingestion_lock" + _EMBEDDING_MODEL_PATH = "projects/{project}/locations/{location}/publishers/google/models/{model}" + + def __init__(self, + project_id: str, + instance_id: str, + database_id: str, + graph_database_id: str = None, + location: str = None, + model_id: str = None): + """Initializes a Spanner client and connects to a specific database.""" + spanner_client = spanner.Client( + project=project_id, + client_options={'quota_project_id': project_id}, + disable_builtin_metrics=True) + instance = spanner_client.instance(instance_id) + database = instance.database(database_id) + logging.info(f"Successfully initialized database: {database.name}") + self.database = database + self.graph_database = database + if graph_database_id: + self.graph_database = instance.database(graph_database_id) + logging.info( + f"Successfully initialized graph database: {self.graph_database.name}" + ) + self.project_id = project_id + self.location = location + self.model_id = model_id + + def _get_embeddings_endpoint(self) -> str: + """Returns the parameterized embedding model endpoint.""" + return self._EMBEDDING_MODEL_PATH.format(project=self.project_id, + location=self.location or + "us-central1", + model=self.model_id or + "text-embedding-005") + + def acquire_lock(self, workflow_id: str, timeout: int) -> bool: + """Attempts to acquire the global ingestion lock. + + Args: + workflow_id: The ID of the workflow attempting to acquire the lock. + timeout: The duration in seconds after which a lock is considered stale. + + Returns: + True if the lock was acquired, False otherwise. + """ + logging.info(f"Attempting to acquire lock for {workflow_id}") + + def _acquire(transaction: Transaction) -> bool: + sql = "SELECT LockOwner, AcquiredTimestamp FROM IngestionLock WHERE LockID = @lockId" + params = {"lockId": self._LOCK_ID} + param_types = {"lockId": STRING} + + row_found = False + results = transaction.execute_sql(sql, params, param_types) + for row in results: + row_found = True + current_owner, acquired_at = row[0], row[1] + + lock_is_available = False + if not row_found: + lock_is_available = True + elif current_owner is None: + lock_is_available = True + else: + timeout_threshold = datetime.now(timezone.utc) - acquired_at + if timeout_threshold.total_seconds() > timeout: + logging.info( + f"Stale lock found, owned by {current_owner}. Acquiring." 
+ ) + lock_is_available = True + + if lock_is_available: + if not row_found: + sql_statement = """ + INSERT INTO IngestionLock (LockID, LockOwner, AcquiredTimestamp) + VALUES (@lockId, @workflowId, PENDING_COMMIT_TIMESTAMP()) + """ + log_msg = f"Lock successfully acquired by {workflow_id} (new row created)" + else: + sql_statement = """ + UPDATE IngestionLock + SET LockOwner = @workflowId, AcquiredTimestamp = PENDING_COMMIT_TIMESTAMP() + WHERE LockID = @lockId + """ + log_msg = f"Lock successfully acquired by {workflow_id} (existing row updated)" + + transaction.execute_update(sql_statement, + params={ + "workflowId": workflow_id, + "lockId": self._LOCK_ID + }, + param_types={ + "workflowId": STRING, + "lockId": STRING + }) + logging.info(log_msg) + return True + else: + logging.info(f"Lock is currently held by {current_owner}") + return False + + try: + return self.database.run_in_transaction(_acquire) + except Exception as e: + logging.error(f'Error acquiring lock for {workflow_id}: {e}') + raise + + def release_lock(self, workflow_id: str) -> bool: + """Releases the global lock. + + Args: + workflow_id: The ID of the workflow attempting to release the lock. + + Returns: + True if the lock was released, False otherwise. + """ + logging.info(f"Attempting to release lock for {workflow_id}") + + def _release(transaction: Transaction) -> None: + sql = "SELECT LockOwner, AcquiredTimestamp FROM IngestionLock WHERE LockID = @lockId" + params = {"lockId": self._LOCK_ID} + param_types = {"lockId": STRING} + + current_owner = None + results = transaction.execute_sql(sql, params, param_types) + for row in results: + current_owner = row[0] + + if current_owner == workflow_id: + sql = """ + UPDATE IngestionLock + SET LockOwner = NULL, AcquiredTimestamp = NULL + WHERE LockID = @lockId + """ + transaction.execute_update(sql, + params={"lockId": self._LOCK_ID}, + param_types={"lockId": STRING}) + logging.info(f"Lock successfully released by {workflow_id}") + return True + else: + logging.info(f"Lock is currently held by {current_owner}") + return False + + try: + return self.database.run_in_transaction(_release) + except Exception as e: + logging.error(f'Error releasing lock for {workflow_id}: {e}') + raise + + def get_import_info(self, import_list: list) -> list: + """Get the details of imports to ingest. + + If import_list is empty, return info for ready imports (STAGING). + If import_list is not empty, return info for the imports in the list that are in 'STAGING' status. + + Args: + import_list: A list of import names to fetch details for. + + Returns: + A list of dictionaries, where each dictionary contains 'importName', 'latestVersion', and 'graphPath'. 
+ """ + pending_imports = [] + logging.info(f"Fetching imports from import list {import_list}.") + + params = {} + param_types = {} + if import_list: + sql = "SELECT ImportName, LatestVersion, GraphPath FROM ImportStatus WHERE State = 'STAGING' AND ImportName IN UNNEST(@importNames)" + params = {"importNames": import_list} + param_types = {"importNames": Array(STRING)} + else: + sql = "SELECT ImportName, LatestVersion, GraphPath FROM ImportStatus WHERE State = 'STAGING'" + + # Use a read-only snapshot for this query + try: + with self.database.snapshot() as snapshot: + results = snapshot.execute_sql(sql, + params=params, + param_types=param_types) + for row in results: + import_json = {} + import_json['importName'] = row[0] + import_json['latestVersion'] = os.path.basename(row[1]) + import_json[ + 'graphPath'] = f"{row[1].rstrip('/')}/{row[2].lstrip('/')}" + pending_imports.append(import_json) + + logging.info(f"Found {len(pending_imports)} import jobs.") + return pending_imports + except Exception as e: + logging.error(f'Error getting import list: {e}') + raise + + def update_ingestion_status(self, import_names: list, workflow_id: str, + status: str): + """Updates the ImportStatus table. + + Args: + import_names: List of import names. + workflow_id: The ID of the workflow. + status: The status of the ingestion. + """ + if not import_names: + return + + logging.info(f"Updated ingestion status for {import_names}") + + def _update(transaction: Transaction): + update_sql = "UPDATE ImportStatus SET State = @importStatus, WorkflowId = @workflowId, StatusUpdateTimestamp = PENDING_COMMIT_TIMESTAMP() WHERE ImportName IN UNNEST(@importNames)" + transaction.execute_update(update_sql, + params={ + "importNames": import_names, + "workflowId": workflow_id, + "importStatus": status + }, + param_types={ + "importNames": Array(STRING), + "workflowId": STRING, + "importStatus": STRING + }) + + try: + self.database.run_in_transaction(_update) + logging.info(f"Marked {len(import_names)} import jobs as {status}.") + except Exception as e: + logging.error(f'Error updating ImportStatus table: {e}') + raise + + def update_ingestion_history(self, workflow_id: str, job_id: str, + ingested_imports: list, metrics: dict): + """Updates the IngestionHistory table. + + Args: + workflow_id: The ID of the workflow. + job_id: The Dataflow job ID. + ingested_imports: List of ingested import names. + metrics: A dictionary containing metrics about the ingestion. + """ + + logging.info( + f"Updating IngestionHistory table for workflow {workflow_id}") + + def _insert(transaction: Transaction): + columns = [ + "CompletionTimestamp", "IngestionFailure", + "WorkflowExecutionID", "DataflowJobId", "IngestedImports", + "ExecutionTime", "NodeCount", "EdgeCount", "ObservationCount" + ] + values = [[ + spanner.COMMIT_TIMESTAMP, + self.check_failed_imports(), workflow_id, job_id, + ingested_imports, metrics['execution_time'], + metrics['node_count'], metrics['edge_count'], + metrics['obs_count'] + ]] + transaction.insert_or_update(table="IngestionHistory", + columns=columns, + values=values) + + try: + self.database.run_in_transaction(_insert) + # TODO: remvoe dual writes after switching to the prod setup. 
+ if self.graph_database and self.graph_database.name != self.database.name: + self.graph_database.run_in_transaction(_insert) + logging.info( + f"Updated IngestionHistory table for workflow {workflow_id}") + except Exception as e: + logging.error(f'Error updating IngestionHistory table: {e}') + raise + + def update_import_version_history(self, import_list_json: list, + workflow_id: str): + """Updates the ImportVersionHistory table. + + Args: + import_list_json: A list of dictionaries containing import details. + workflow_id: The ID of the workflow. + """ + if not import_list_json: + return + + logging.info( + f"Updating ImportVersionHistory table for workflow {workflow_id}") + + def _insert(transaction: Transaction): + version_history_columns = [ + "ImportName", "Version", "UpdateTimestamp", "Comment" + ] + version_history_values = [] + for import_json in import_list_json: + version_history_values.append([ + import_json['importName'], import_json['latestVersion'], + spanner.COMMIT_TIMESTAMP, + "ingestion-workflow:" + workflow_id + ]) + + if version_history_values: + transaction.insert(table="ImportVersionHistory", + columns=version_history_columns, + values=version_history_values) + + try: + self.database.run_in_transaction(_insert) + logging.info( + f"Updated ImportVersionHistory table for workflow {workflow_id}" + ) + except Exception as e: + logging.error(f'Error updating ImportVersionHistory table: {e}') + raise + + def check_failed_imports(self) -> bool: + """Checks if there are any failed imports.""" + try: + with self.database.snapshot() as snapshot: + results = snapshot.execute_sql( + "SELECT 1 FROM ImportStatus WHERE State = 'PENDING' LIMIT 1" + ) + return any(results) + except Exception as e: + logging.error(f'Error checking for pending imports: {e}') + return True + + def update_import_status(self, params: dict): + """Updates the status for the specified import job. + + Args: + params: A dictionary containing import parameters. + """ + import_name = params['import_name'] + job_id = params['job_id'] + execution_time = params['execution_time'] + data_volume = params['data_volume'] + status = params['status'] + latest_version = params['latest_version'] + next_refresh = datetime.fromisoformat(params['next_refresh']) + graph_path = params['graph_path'] + logging.info(f"Updating import status in spanner {params}") + + def _record(transaction: Transaction): + columns = [ + "ImportName", "State", "JobId", "ExecutionTime", "DataVolume", + "NextRefreshTimestamp", "LatestVersion", "GraphPath", + "StatusUpdateTimestamp" + ] + + row_values = [ + import_name, status, job_id, execution_time, data_volume, + next_refresh, latest_version, graph_path, + spanner.COMMIT_TIMESTAMP + ] + + if status == 'STAGING': + columns.append("DataImportTimestamp") + row_values.append(spanner.COMMIT_TIMESTAMP) + + transaction.insert_or_update(table="ImportStatus", + columns=columns, + values=[row_values]) + + logging.info(f"Marked {import_name} as {status}.") + + try: + self.database.run_in_transaction(_record) + except Exception as e: + logging.error( + f'Error updating import status for {import_name}: {e}') + raise + + def update_version_history(self, import_name: str, version: str, + comment: str): + """Updates the version history table. + + Args: + import_name: The name of the import. + version: The version string. + comment: The comment for the update. 
+ """ + import_name = import_name.split(':')[-1] + logging.info(f"Updating version history for {import_name} to {version}") + + def _record(transaction: Transaction): + columns = ["ImportName", "Version", "UpdateTimestamp", "Comment"] + values = [[import_name, version, spanner.COMMIT_TIMESTAMP, comment]] + transaction.insert(table="ImportVersionHistory", + columns=columns, + values=values) + logging.info(f"Added version history entry for {import_name}") + + try: + self.database.run_in_transaction(_record) + except Exception as e: + logging.error( + f'Error updating version history for {import_name}: {e}') + raise + + def initialize_database(self, enable_embeddings=False): + """Initializes the database by creating all required tables and proto bundles.""" + logging.info("Initializing database...") + + query = """ + SELECT 'table' as type, table_name as name FROM information_schema.tables WHERE table_schema = '' + UNION ALL + SELECT 'index' as type, index_name as name FROM information_schema.indexes WHERE table_schema = '' AND table_name = 'NodeEmbedding' + UNION ALL + SELECT 'model' as type, model_name as name FROM information_schema.models WHERE model_schema = '' + """ + + existing_tables = [] + existing_indexes = [] + existing_models = [] + + with self.database.snapshot() as snapshot: + results = snapshot.execute_sql(query) + for row in results: + if len(row) < 2: + logging.warning(f"Invalid row from query: {row}") + continue + obj_type = row[0] + obj_name = row[1] + if obj_type == 'table': + existing_tables.append(obj_name) + elif obj_type == 'index': + existing_indexes.append(obj_name) + elif obj_type == 'model': + existing_models.append(obj_name) + + logging.info(f"Existing tables: {existing_tables}") + logging.info(f"Existing indexes: {existing_indexes}") + logging.info(f"Existing models: {existing_models}") + + required_tables = [ + "Node", "Edge", "Observation", "ImportStatus", "IngestionHistory", + "ImportVersionHistory", "IngestionLock", "Cache", "VariableMetadata" + ] + required_indexes = [] + required_models = [] + + if enable_embeddings: + required_tables.append("NodeEmbedding") + required_indexes.append("NodeEmbeddingIndex") + required_models.append("NodeEmbeddingModel") + + missing_tables = [ + t for t in required_tables if t not in existing_tables + ] + missing_indexes = [ + i for i in required_indexes if i not in existing_indexes + ] + missing_models = [ + m for m in required_models if m not in existing_models + ] + + total_required = len(required_tables) + len(required_indexes) + len( + required_models) + total_missing = len(missing_tables) + len(missing_indexes) + len( + missing_models) + + if total_missing == 0: + logging.info("Database is properly initialized.") + return + + if total_missing < total_required: + raise RuntimeError( + f"Database inconsistent state. Missing tables: {missing_tables}, missing indexes: {missing_indexes}, missing models: {missing_models}. Please clean up manually." 
+ ) + + logging.info("Creating all tables and proto bundles...") + + schema_path = os.path.join(os.path.dirname(__file__), 'schema.sql') + logging.info(f"Reading schema from {schema_path}") + try: + with open(schema_path, 'r') as f: + schema_content = f.read() + + ddl_statements = [ + s.strip() for s in schema_content.split(';') if s.strip() + ] + + if enable_embeddings: + embeddings_endpoint = self._get_embeddings_endpoint() + embedding_schema_path = os.path.join(os.path.dirname(__file__), + 'embedding_schema.sql') + logging.info( + f"Reading embedding schema from {embedding_schema_path}") + with open(embedding_schema_path, 'r') as f: + embedding_schema_content = f.read() + embedding_schema_content = Template( + embedding_schema_content).render( + embeddings_endpoint=embeddings_endpoint) + embedding_ddl_statements = [ + s.strip() + for s in embedding_schema_content.split(';') + if s.strip() + ] + ddl_statements.extend(embedding_ddl_statements) + except Exception as e: + logging.error(f"Failed to read schema file: {e}") + raise + + proto_path = os.path.join(os.path.dirname(__file__), 'storage.pb') + logging.info(f"Reading proto descriptors from {proto_path}") + try: + with open(proto_path, 'rb') as f: + proto_descriptors_bytes = f.read() + except Exception as e: + logging.error(f"Failed to read proto descriptors file: {e}") + raise + + database_path = self.database.name + logging.info(f"Updating DDL for {database_path} with protos") + + try: + admin_client = DatabaseAdminClient() + request = UpdateDatabaseDdlRequest( + database=database_path, + statements=ddl_statements, + proto_descriptors=proto_descriptors_bytes) + operation = admin_client.update_database_ddl(request=request) + operation.result() + logging.info("Database initialized successfully with protos.") + except Exception as e: + logging.error(f"Failed to update DDL with protos: {e}") + raise diff --git a/pipeline/workflow/ingestion-helper/spanner_client_test.py b/pipeline/workflow/ingestion-helper/spanner_client_test.py new file mode 100644 index 00000000..3a961db4 --- /dev/null +++ b/pipeline/workflow/ingestion-helper/spanner_client_test.py @@ -0,0 +1,224 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
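For orientation, a hedged usage sketch of the SpannerClient defined above, roughly as the Cloud Run ingestion helper might drive it. The environment variable names and the workflow ID here are illustrative assumptions; the lock timeout mirrors the 23-hour value used by spanner-ingestion-workflow.yaml:

import os
from spanner_client import SpannerClient

client = SpannerClient(
    project_id=os.environ["SPANNER_PROJECT_ID"],
    instance_id=os.environ["SPANNER_INSTANCE_ID"],
    database_id=os.environ["SPANNER_DATABASE_ID"])

# One-time setup: creates all tables and proto bundles if none exist, is a
# no-op when everything exists, and raises on a partially created schema.
client.initialize_database(enable_embeddings=False)

# Typical handshake around one ingestion run.
workflow_id = "workflow-123"  # illustrative
if client.acquire_lock(workflow_id, timeout=82800):
    try:
        imports = client.get_import_info([])  # all imports in STAGING
        # ... Dataflow ingestion would run here ...
        client.update_ingestion_status(
            [i["importName"] for i in imports], workflow_id, "SUCCESS")
    finally:
        client.release_lock(workflow_id)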
+ +import unittest +from unittest.mock import MagicMock, patch +import sys +import os + +# Add the current directory to path so we can import spanner_client +sys.path.append(os.path.dirname(__file__)) +from spanner_client import SpannerClient + +class TestSpannerClient(unittest.TestCase): + + @patch('google.cloud.spanner.Client') + def test_initialize_database_all_exist(self, mock_spanner_client): + # Setup mock + mock_instance = MagicMock() + mock_db = MagicMock() + mock_spanner_client.return_value.instance.return_value = mock_instance + mock_instance.database.return_value = mock_db + + # Mock snapshot results (all tables exist) + mock_snapshot = MagicMock() + mock_db.snapshot.return_value.__enter__.return_value = mock_snapshot + mock_snapshot.execute_sql.return_value = [ + ["table", "Node"], ["table", "Edge"], ["table", "Observation"], + ["table", "NodeEmbedding"], ["table", "ImportStatus"], + ["table", "IngestionHistory"], ["table", "ImportVersionHistory"], + ["table", "IngestionLock"], + ["index", "NodeEmbeddingIndex"], + ["model", "NodeEmbeddingModel"] + ] + + client = SpannerClient("project", "instance", "database") + + # Run method + client.initialize_database() + + # Verify update_ddl was NOT called + mock_db.update_ddl.assert_not_called() + + @patch('spanner_client.DatabaseAdminClient') + @patch('google.cloud.spanner.Client') + def test_initialize_database_none_exist(self, mock_spanner_client, + mock_admin_client): + # Setup mock + mock_instance = MagicMock() + mock_db = MagicMock() + mock_db.name = "projects/test-project/instances/test-instance/databases/test-db" + mock_spanner_client.return_value.instance.return_value = mock_instance + mock_instance.database.return_value = mock_db + + # Mock DatabaseAdminClient + mock_admin_instance = MagicMock() + mock_admin_client.return_value = mock_admin_instance + mock_operation = MagicMock() + mock_admin_instance.update_database_ddl.return_value = mock_operation + + # Mock snapshot results (no tables exist) + mock_snapshot = MagicMock() + mock_db.snapshot.return_value.__enter__.return_value = mock_snapshot + mock_snapshot.execute_sql.return_value = [] + + client = SpannerClient("project", "instance", "database") + + def open_side_effect(file_path, mode='r', *args, **kwargs): + m = MagicMock() + if 'storage.pb' in str(file_path): + m.__enter__.return_value.read.return_value = b'dummy proto data' + else: + m.__enter__.return_value.read.return_value = 'CREATE TABLE Node;' + return m + + # Run method with patched open + with patch('builtins.open', side_effect=open_side_effect): + client.initialize_database() + + # Verify update_database_ddl WAS called + mock_admin_instance.update_database_ddl.assert_called_once() + mock_operation.result.assert_called_once() + + # Verify placeholder replacement + args, kwargs = mock_admin_instance.update_database_ddl.call_args + request = kwargs.get('request') if kwargs else args[0] + statements = request.statements + self.assertEqual(len(statements), 1) + self.assertEqual(statements[0], "CREATE TABLE Node") + + @patch('spanner_client.DatabaseAdminClient') + @patch('google.cloud.spanner.Client') + def test_initialize_database_with_embeddings(self, mock_spanner_client, mock_admin_client): + # Setup mock + mock_instance = MagicMock() + mock_db = MagicMock() + mock_db.name = "projects/test-project/instances/test-instance/databases/test-db" + mock_spanner_client.return_value.instance.return_value = mock_instance + mock_instance.database.return_value = mock_db + + # Mock DatabaseAdminClient + mock_admin_instance = 
MagicMock() + mock_admin_client.return_value = mock_admin_instance + mock_operation = MagicMock() + mock_admin_instance.update_database_ddl.return_value = mock_operation + + # Mock snapshot results (no tables exist) + mock_snapshot = MagicMock() + mock_db.snapshot.return_value.__enter__.return_value = mock_snapshot + mock_snapshot.execute_sql.return_value = [] + + client = SpannerClient("project", "instance", "database") + + def open_side_effect(file_path, mode='r', *args, **kwargs): + m = MagicMock() + if 'storage.pb' in str(file_path): + m.__enter__.return_value.read.return_value = b'dummy proto data' + elif 'embedding_schema.sql' in str(file_path): + m.__enter__.return_value.read.return_value = 'CREATE TABLE NodeEmbedding; CREATE MODEL M REMOTE OPTIONS (endpoint = \'{{ embeddings_endpoint }}\');' + else: + m.__enter__.return_value.read.return_value = 'CREATE TABLE Node;' + return m + + # Run method with patched open and parameter + with patch('builtins.open', side_effect=open_side_effect): + client.initialize_database(enable_embeddings=True) + + # Verify update_database_ddl WAS called + mock_admin_instance.update_database_ddl.assert_called_once() + + # Verify that both schemas were loaded + args, kwargs = mock_admin_instance.update_database_ddl.call_args + request = kwargs.get('request') if kwargs else args[0] + statements = request.statements + self.assertEqual(len(statements), 3) + self.assertEqual(statements[0], "CREATE TABLE Node") + self.assertEqual(statements[1], "CREATE TABLE NodeEmbedding") + self.assertIn("projects/project/locations", statements[2]) + + @patch('google.cloud.spanner.Client') + def test_initialize_database_inconsistent_state(self, mock_spanner_client): + # Setup mock + mock_instance = MagicMock() + mock_db = MagicMock() + mock_spanner_client.return_value.instance.return_value = mock_instance + mock_instance.database.return_value = mock_db + + # Mock snapshot results (some tables exist) + mock_snapshot = MagicMock() + mock_db.snapshot.return_value.__enter__.return_value = mock_snapshot + mock_snapshot.execute_sql.return_value = [["table", "Node"]] + + client = SpannerClient("project", "instance", "database") + + # Run method and expect exception + with self.assertRaises(RuntimeError): + client.initialize_database() + + @patch('google.cloud.spanner.Client') + def test_acquire_lock_new_row(self, mock_spanner_client): + # Setup mock + mock_instance = MagicMock() + mock_db = MagicMock() + mock_spanner_client.return_value.instance.return_value = mock_instance + mock_instance.database.return_value = mock_db + + mock_transaction = MagicMock() + def run_in_transaction_side_effect(callback, *args, **kwargs): + return callback(mock_transaction, *args, **kwargs) + mock_db.run_in_transaction.side_effect = run_in_transaction_side_effect + + # Mock execute_sql to return empty results (no row found) + mock_transaction.execute_sql.return_value = [] + + client = SpannerClient("project", "instance", "database") + + # Run method + result = client.acquire_lock("workflow-123", 3600) + + # Verify + self.assertTrue(result) + mock_transaction.execute_update.assert_called_once() + args, _ = mock_transaction.execute_update.call_args + self.assertIn("INSERT INTO IngestionLock", args[0]) + + @patch('google.cloud.spanner.Client') + def test_acquire_lock_existing_row(self, mock_spanner_client): + # Setup mock + mock_instance = MagicMock() + mock_db = MagicMock() + mock_spanner_client.return_value.instance.return_value = mock_instance + mock_instance.database.return_value = mock_db + + 
mock_transaction = MagicMock() + def run_in_transaction_side_effect(callback, *args, **kwargs): + return callback(mock_transaction, *args, **kwargs) + mock_db.run_in_transaction.side_effect = run_in_transaction_side_effect + + # Mock execute_sql to return existing lock (owner is None) + mock_transaction.execute_sql.return_value = [[None, None]] + + client = SpannerClient("project", "instance", "database") + + # Run method + result = client.acquire_lock("workflow-123", 3600) + + # Verify + self.assertTrue(result) + mock_transaction.execute_update.assert_called_once() + args, _ = mock_transaction.execute_update.call_args + self.assertIn("UPDATE IngestionLock", args[0]) + +if __name__ == '__main__': + unittest.main() diff --git a/pipeline/workflow/ingestion-helper/storage_client.py b/pipeline/workflow/ingestion-helper/storage_client.py new file mode 100644 index 00000000..12f57a50 --- /dev/null +++ b/pipeline/workflow/ingestion-helper/storage_client.py @@ -0,0 +1,156 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Storage client for the ingestion helper.""" + +import logging +from google.cloud import storage +from google.cloud import exceptions +import json +import os + +logging.getLogger().setLevel(logging.INFO) + +_STAGING_VERSION_FILE = 'staging_version.txt' +_LATEST_VERSION_FILE = 'latest_version.txt' +_IMPORT_METADATA_MCF = 'import_metadata_mcf.mcf' +_IMPORT_SUMMARY_JSON = 'import_summary.json' + + +class StorageClient: + + def __init__(self, bucket_name: str): + """Initializes a GCS client.""" + self.storage = storage.Client() + self.bucket = self.storage.bucket(bucket_name) + + def get_import_summary(self, import_name: str, version: str) -> dict: + """Retrieves the import summary from GCS. + + Args: + import_name: The name of the import. + version: The version of the import. + + Returns: + A dictionary containing the import summary, or an empty dict if not found. + """ + output_dir = import_name.replace(':', '/') + summary_file = os.path.join(output_dir, version, _IMPORT_SUMMARY_JSON) + logging.info(f'Reading import summary from {summary_file}') + try: + blob = self.bucket.blob(summary_file) + json_data_string = blob.download_as_text() + data = json.loads(json_data_string) + logging.info(f"Successfully read {summary_file}") + return data + except (exceptions.NotFound, json.JSONDecodeError) as e: + logging.error( + f'Error reading import summary file {summary_file}: {e}') + return {} + + def update_import_summary(self, import_summary: dict): + """Updates the import summary in GCS. + + Args: + import_summary: A dictionary containing the summary of the import. 
+ """ + latest_version = import_summary.get('latest_version') + path = latest_version.removeprefix('gs://').split('/', 1) + summary_file = os.path.join(path[1], _IMPORT_SUMMARY_JSON) + logging.info( + f'Updating import summary at {summary_file} {import_summary}') + blob = self.bucket.blob(summary_file) + blob.upload_from_string(json.dumps(import_summary)) + logging.info(f'Updated import summary at {summary_file}') + + def get_staging_version(self, import_name: str) -> str: + """Retrieves the latest version from the staging directory. + + Args: + import_name: The name of the import. + + Returns: + The version string, or an empty string if not found. + """ + output_dir = import_name.replace(':', '/') + version_file = os.path.join(output_dir, _STAGING_VERSION_FILE) + logging.info(f'Reading version file {version_file}') + try: + blob = self.bucket.blob(version_file) + return blob.download_as_text() + except exceptions.NotFound: + logging.error(f"Version file {version_file} not found") + return '' + + def update_version_file(self, + import_name: str, + version: str, + is_staging: bool = False): + """Updates the version file (staging or latest) in GCS. + + Args: + import_name: The name of the import. + version: The new version string. + is_staging: Whether to update the staging version file or the latest version file. + """ + file_name = _STAGING_VERSION_FILE if is_staging else _LATEST_VERSION_FILE + file_type = "staging" if is_staging else "latest" + logging.info( + f'Updating {file_type} version file for import {import_name} to {version}' + ) + output_dir = import_name.replace(':', '/') + version_file = self.bucket.blob(os.path.join(output_dir, file_name)) + version_file.upload_from_string(version) + logging.info( + f'Updated {file_type} version file {version_file.name} to {version}' + ) + + def update_provenance_file(self, import_name: str, version: str): + """Updates the provenance file for the import. + + Args: + import_name: The name of the import. + version: The version of the import. 
+ """ + logging.info( + f'Updating provenance file for import {import_name} to add {version}' + ) + output_dir = import_name.replace(':', '/') + metadata_blob = self.bucket.blob( + os.path.join(output_dir, version, 'provenance', 'genmcf', + _IMPORT_METADATA_MCF)) + if metadata_blob.exists(): + self.bucket.copy_blob( + metadata_blob, self.bucket, + os.path.join(output_dir, 'import_metadata_mcf.mcf')) + else: + logging.warning( + f'Generating default metadata for import {import_name}') + base_name = import_name.split(':')[-1] + default_provenance = f"Node: dcid:dc/base/{base_name}\ntypeOf: dcid:Provenance\n" + new_blob = self.bucket.blob( + os.path.join(output_dir, version, 'provenance', 'genmcf', + 'import_metadata_mcf.mcf')) + new_blob.upload_from_string(default_provenance) + + provenance_file = import_name.split(':')[-1] + '.mcf' + provenance_blob = self.bucket.blob( + os.path.join('provenance', provenance_file)) + if provenance_blob.exists(): + self.bucket.copy_blob( + provenance_blob, self.bucket, + os.path.join(output_dir, version, 'provenance', 'genmcf', + provenance_file)) + logging.info( + f'Updated provenance file for import {import_name} to add {version}' + ) diff --git a/pipeline/workflow/spanner-ingestion-workflow.yaml b/pipeline/workflow/spanner-ingestion-workflow.yaml new file mode 100644 index 00000000..01f2976d --- /dev/null +++ b/pipeline/workflow/spanner-ingestion-workflow.yaml @@ -0,0 +1,183 @@ +main: + params: [args] + steps: + - init: + assign: + - lock_timeout: 82800 # 23 hours + - wait_period: 300 # seconds + - project_id: '${sys.get_env("PROJECT_ID")}' + - dataflow_job_name: ${"ingestion-job-" + string(int(sys.now()))} + - dataflow_gcs_path: ${default(map.get(args, "templateGcsPath"), "gs://datcom-templates/templates/flex/ingestion.json")} + - location: '${sys.get_env("LOCATION")}' + - spanner_project_id: '${sys.get_env("SPANNER_PROJECT_ID")}' + - spanner_instance_id: '${sys.get_env("SPANNER_INSTANCE_ID")}' + - spanner_database_id: '${sys.get_env("SPANNER_DATABASE_ID")}' + - helper_url: ${"https://ingestion-helper-service-" + sys.get_env("PROJECT_NUMBER") + "." 
+ location + ".run.app"} + - import_list: ${default(map.get(args, "importList"), [])} + - execution_error: null + - acquire_ingestion_lock: + try: + call: http.post + args: + url: ${helper_url} + auth: + type: OIDC + body: + actionType: acquire_ingestion_lock + workflowId: '${sys.get_env("GOOGLE_CLOUD_WORKFLOW_EXECUTION_ID")}' + timeout: ${lock_timeout} + result: lock_status + retry: + predicate: ${http.default_retry_predicate} + max_retries: 20 + backoff: + initial_delay: 300 + max_delay: 600 + multiplier: 2 + - process_ingestion: + try: + steps: + - get_import_info: + call: http.post + args: + url: ${helper_url} + auth: + type: OIDC + body: + actionType: get_import_info + importList: ${import_list} + result: import_info + - run_ingestion_job: + call: run_dataflow_job + args: + import_list: '${json.encode_to_string(import_info.body)}' + project_id: ${project_id} + job_name: ${dataflow_job_name} + template_gcs_path: ${dataflow_gcs_path} + location: ${location} + spanner_project_id: ${spanner_project_id} + spanner_instance_id: ${spanner_instance_id} + spanner_database_id: ${spanner_database_id} + wait_period: ${wait_period} + helper_url: ${helper_url} + workflow_id: '${sys.get_env("GOOGLE_CLOUD_WORKFLOW_EXECUTION_ID")}' + result: dataflow_job_id + - run_aggregation: + call: http.post + args: + url: ${helper_url} + auth: + type: OIDC + body: + actionType: run_aggregation + importList: ${import_info.body} + - update_ingestion_status: + call: http.post + args: + url: ${helper_url} + auth: + type: OIDC + body: + actionType: update_ingestion_status + workflowId: '${sys.get_env("GOOGLE_CLOUD_WORKFLOW_EXECUTION_ID")}' + jobId: '${dataflow_job_id}' + importList: '${import_info.body}' + status: 'SUCCESS' + result: function_response + except: + as: e + steps: + - capture_error: + assign: + - execution_error: ${e} + - release_ingestion_lock: + call: http.post + args: + url: ${helper_url} + auth: + type: OIDC + body: + actionType: release_ingestion_lock + workflowId: '${sys.get_env("GOOGLE_CLOUD_WORKFLOW_EXECUTION_ID")}' + result: function_response + - fail_workflow: + switch: + - condition: ${execution_error != null} + raise: ${execution_error} + - return_import_info: + return: '${import_info.body}' + +# This sub-workflow launches a Dataflow job and waits for it to complete. 
+run_dataflow_job: + params: [import_list, project_id, job_name, template_gcs_path, location, spanner_project_id, spanner_instance_id, spanner_database_id, wait_period, helper_url, workflow_id] + steps: + - init: + assign: + - jobName: ${job_name} + - machineType: 'n2-highmem-8' + - numWorkers: 3 + - log_imports: + call: sys.log + args: + text: '${"Dataflow job: " + jobName + " Import list: " + import_list}' + severity: INFO + - check_if_empty: + switch: + - condition: ${import_list == "[]"} + return: '' + - launch_dataflow_job: + call: googleapis.dataflow.v1b3.projects.locations.flexTemplates.launch + args: + projectId: '${project_id}' + location: '${location}' + body: + launchParameter: + containerSpecGcsPath: '${template_gcs_path}' + jobName: '${jobName}' + parameters: + importList: '${import_list}' + projectId: '${spanner_project_id}' + spannerInstanceId: '${spanner_instance_id}' + spannerDatabaseId: '${spanner_database_id}' + environment: + numWorkers: ${numWorkers} + machineType: ${machineType} + result: launch_result + - wait_for_job_completion: + call: sys.sleep + args: + seconds: ${wait_period} + next: check_job_status + - check_job_status: + call: googleapis.dataflow.v1b3.projects.locations.jobs.get + args: + projectId: '${project_id}' + location: '${location}' + jobId: '${launch_result.job.id}' + view: 'JOB_VIEW_SUMMARY' + result: job_status + next: check_if_done + - check_if_done: + switch: + - condition: ${job_status.currentState == "JOB_STATE_DONE"} + return: ${launch_result.job.id} + - condition: ${job_status.currentState == "JOB_STATE_FAILED" or job_status.currentState == "JOB_STATE_CANCELLED"} + next: record_failed_imports + next: wait_for_job_completion + - record_failed_imports: + call: http.post + args: + url: ${helper_url} + auth: + type: OIDC + body: + actionType: update_ingestion_status + workflowId: '${workflow_id}' + jobId: '${launch_result.job.id}' + importList: '${json.decode(import_list)}' + status: 'PENDING' + result: retry_response + - fail_workflow: + raise: + message: '${jobName + " dataflow job failed with status: " + job_status.currentState}' + code: 500 \ No newline at end of file diff --git a/pipeline/workflow/spanner_ingestion_test.py b/pipeline/workflow/spanner_ingestion_test.py new file mode 100644 index 00000000..7e3d24b6 --- /dev/null +++ b/pipeline/workflow/spanner_ingestion_test.py @@ -0,0 +1,174 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +End-to-end test for import automation and Spanner ingestion workflows. 
+""" + +import json +import os +import sys + +from absl import app +from absl import logging + +# Add path for executor modules +sys.path.append( + os.path.abspath(os.path.join(os.path.dirname(__file__), '../executor/app'))) +from executor import cloud_workflow + +from google.cloud import spanner + +PROJECT_ID = os.environ.get('PROJECT_ID', 'datcom-ci') +LOCATION = os.environ.get('LOCATION', 'us-central1') +SPANNER_PROJECT_ID = os.environ.get('SPANNER_PROJECT_ID', 'datcom-ci') +SPANNER_INSTANCE_ID = os.environ.get('SPANNER_INSTANCE_ID', + 'datcom-spanner-test') +SPANNER_DATABASE_ID = os.environ.get('SPANNER_DATABASE_ID', 'dc-test-db') +GCS_BUCKET_ID = os.environ.get('GCS_BUCKET_ID', 'datcom-ci-test') +IMPORT_WORKFLOW_ID = 'import-automation-workflow' +INGESTION_WORKFLOW_ID = 'spanner-ingestion-workflow' + +# Test Import Configuration +TEST_IMPORT_NAME = 'scripts/us_fed/treasury_constant_maturity_rates:USFed_ConstantMaturityRates_Test' + + +def verify_spanner_data(import_name): + """Verifies that the import data exists and is marked as SUCCESS in Spanner.""" + logging.info(f"Verifying Spanner data for import: {import_name}") + spanner_client = spanner.Client(project=SPANNER_PROJECT_ID) + instance = spanner_client.instance(SPANNER_INSTANCE_ID) + database = instance.database(SPANNER_DATABASE_ID) + + try: + with database.snapshot(multi_use=True) as snapshot: + # Check ImportStatus table + query = "SELECT State FROM ImportStatus WHERE ImportName = @import_name" + params = {"import_name": import_name} + param_types = {"import_name": spanner.param_types.STRING} + + results = list( + snapshot.execute_sql(query, + params=params, + param_types=param_types)) + + if not results: + raise AssertionError( + f"Import {import_name} not found in ImportStatus table.") + + state = results[0][0] + if state != 'SUCCESS': + raise AssertionError( + f"Import {import_name} state is {state}, expected 'SUCCESS'." + ) + + logging.info( + f"Import {import_name} verified in ImportStatus with state: {state}" + ) + + # Check IngestionHistory table (optional, but good for E2E) + # We look for a recent entry containing this import + query_history = """ + SELECT count(*) + FROM IngestionHistory + WHERE @import_name IN UNNEST(IngestedImports) + """ + results_history = list( + snapshot.execute_sql(query_history, + params=params, + param_types=param_types)) + count = results_history[0][0] + + if count == 0: + raise AssertionError( + f"Import {import_name} not found in IngestionHistory table." 
+ ) + + logging.info(f"Import {import_name} verified in IngestionHistory.") + + except Exception as e: + logging.error(f"Spanner verification failed: {e}") + raise + + +def cleanup_spanner(import_name): + """Cleans up the import data from Spanner to ensure a clean state.""" + logging.info(f"Cleaning up Spanner data for import: {import_name}") + spanner_client = spanner.Client(project=SPANNER_PROJECT_ID) + instance = spanner_client.instance(SPANNER_INSTANCE_ID) + database = instance.database(SPANNER_DATABASE_ID) + + def _delete_import(transaction): + query = "DELETE FROM ImportStatus WHERE ImportName = @import_name" + params = {"import_name": import_name} + param_types = {"import_name": spanner.param_types.STRING} + transaction.execute_update(query, + params=params, + param_types=param_types) + + try: + database.run_in_transaction(_delete_import) + logging.info( + f"Successfully cleaned up {import_name} from ImportStatus table.") + except Exception as e: + logging.warning(f"Error during Spanner cleanup: {e}") + + +def main(argv): + del argv # Unused. + try: + # 0. Cleanup Spanner + logging.info("Step 0: Cleanup Spanner...") + short_import_name = TEST_IMPORT_NAME.split(':')[-1] + cleanup_spanner(short_import_name) + + # 1. Trigger Import Automation Workflow + job_name = "test-import" + import_config = { + "gcp_project_id": PROJECT_ID, + "gcs_project_id": PROJECT_ID, + "storage_prod_bucket_name": GCS_BUCKET_ID, + "gcs_bucket_volume_mount": GCS_BUCKET_ID + } + + import_workflow_args = { + "importName": TEST_IMPORT_NAME, + "importConfig": json.dumps(import_config), + } + + logging.info("Step 1: Running Import Automation Workflow...") + cloud_workflow.trigger_workflow_and_wait(PROJECT_ID, LOCATION, + IMPORT_WORKFLOW_ID, + import_workflow_args) + + # 2. Trigger Spanner Ingestion Workflow + ingestion_workflow_args = {"importList": [short_import_name]} + + logging.info("Step 2: Running Spanner Ingestion Workflow...") + cloud_workflow.trigger_workflow_and_wait(PROJECT_ID, LOCATION, + INGESTION_WORKFLOW_ID, + ingestion_workflow_args) + + # 3. Verify Data in Spanner + logging.info("Step 3: Verifying Data in Spanner...") + verify_spanner_data(short_import_name) + + logging.info("Spanner ingestion test completed successfully.") + + except Exception as e: + logging.error(f"Spanner ingestion test Failed: {e}") + sys.exit(1) + + +if __name__ == "__main__": + app.run(main)
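Usage note for the end-to-end test above: it reads its targets from the PROJECT_ID, LOCATION, SPANNER_PROJECT_ID, SPANNER_INSTANCE_ID, SPANNER_DATABASE_ID and GCS_BUCKET_ID environment variables, so it can be pointed at a non-default environment with, for example, 'PROJECT_ID=my-project SPANNER_INSTANCE_ID=my-instance SPANNER_DATABASE_ID=my-db GCS_BUCKET_ID=my-bucket python3 pipeline/workflow/spanner_ingestion_test.py' (values here are placeholders). The test deletes the test import's ImportStatus row, runs the import-automation and Spanner-ingestion workflows via cloud_workflow.trigger_workflow_and_wait, and fails unless the import ends in state 'SUCCESS' and appears in IngestionHistory.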