diff --git a/pipeline/spanner/pom.xml b/pipeline/spanner/pom.xml
index 1db64212..035b825b 100644
--- a/pipeline/spanner/pom.xml
+++ b/pipeline/spanner/pom.xml
@@ -76,4 +76,15 @@
       <scope>test</scope>
     </dependency>
   </dependencies>
+
+  <build>
+    <resources>
+      <resource>
+        <directory>../workflow/ingestion-helper</directory>
+        <includes>
+          <include>schema.sql</include>
+        </includes>
+      </resource>
+    </resources>
+  </build>
 </project>
diff --git a/pipeline/spanner/src/main/java/org/datacommons/ingestion/spanner/SpannerClient.java b/pipeline/spanner/src/main/java/org/datacommons/ingestion/spanner/SpannerClient.java
index 87fb58d9..b0483371 100644
--- a/pipeline/spanner/src/main/java/org/datacommons/ingestion/spanner/SpannerClient.java
+++ b/pipeline/spanner/src/main/java/org/datacommons/ingestion/spanner/SpannerClient.java
@@ -231,17 +231,17 @@ public void validateOrInitializeDatabase() {
}
}
- /** Reads DDL statements from the spanner_schema.sql file in the resources directory. */
+ /** Reads DDL statements from the schema.sql file in the resources directory. */
List<String> readDdlStatements() {
- InputStream inputStream = getClass().getClassLoader().getResourceAsStream("spanner_schema.sql");
+ InputStream inputStream = getClass().getClassLoader().getResourceAsStream("schema.sql");
if (inputStream == null) {
- throw new IllegalStateException("Could not find spanner_schema.sql in resources.");
+ throw new IllegalStateException("Could not find schema.sql in resources.");
}
try (BufferedReader reader =
new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8))) {
return parseDdlStatements(reader);
} catch (IOException e) {
- throw new IllegalStateException("Failed to read spanner_schema.sql", e);
+ throw new IllegalStateException("Failed to read schema.sql", e);
}
}
diff --git a/pipeline/spanner/src/main/resources/spanner_schema.sql b/pipeline/spanner/src/main/resources/spanner_schema.sql
deleted file mode 100644
index c6932226..00000000
--- a/pipeline/spanner/src/main/resources/spanner_schema.sql
+++ /dev/null
@@ -1,34 +0,0 @@
-CREATE PROTO BUNDLE (
- `org.datacommons.Observations`
-)
-
-CREATE TABLE Node (
- subject_id STRING(1024) NOT NULL,
- value STRING(MAX),
- bytes BYTES(MAX),
- name STRING(MAX),
- types ARRAY<STRING(MAX)>,
- name_tokenlist TOKENLIST AS (TOKENIZE_FULLTEXT(name)) HIDDEN,
-) PRIMARY KEY(subject_id)
-
-CREATE TABLE Edge (
- subject_id STRING(1024) NOT NULL,
- predicate STRING(1024) NOT NULL,
- object_id STRING(1024) NOT NULL,
- provenance STRING(1024) NOT NULL,
-) PRIMARY KEY(subject_id, predicate, object_id, provenance),
-INTERLEAVE IN Node
-
-CREATE TABLE Observation (
- observation_about STRING(1024) NOT NULL,
- variable_measured STRING(1024) NOT NULL,
- facet_id STRING(1024) NOT NULL,
- observation_period STRING(1024),
- measurement_method STRING(1024),
- unit STRING(1024),
- scaling_factor STRING(1024),
- observations org.datacommons.Observations,
- import_name STRING(1024),
- provenance_url STRING(1024),
- is_dc_aggregate BOOL,
-) PRIMARY KEY(observation_about, variable_measured, facet_id)
\ No newline at end of file
diff --git a/pipeline/terraform/main.tf b/pipeline/terraform/main.tf
new file mode 100644
index 00000000..862636af
--- /dev/null
+++ b/pipeline/terraform/main.tf
@@ -0,0 +1,379 @@
+# Terraform deployment for Data Commons Import Automation Workflow
+#
+# Usage:
+# - Authenticate and set up application default credentials so that Terraform can access GCP: 'gcloud auth login --update-adc'.
+# - Obtain a Data Commons API key from the API key portal (https://apikeys.datacommons.org/) to use as the `dc_api_key` variable.
+# - Deploy the infrastructure and resources defined in this configuration with 'terraform apply'.
+# - Grant the output service account the permissions it needs to access external resources.
+#
+# Input variables:
+# - GCP project ID
+# - Data Commons API key
+#
+# This file sets up:
+# - Necessary GCP APIs
+# - Secret Manager for the import-config secret
+# - GCS Buckets for imports, mounting, and Dataflow templates
+# - Spanner Instance and Database
+# - Artifact Registry for hosting Docker images (Flex Template & Executor)
+# - Pub/Sub Topic and Subscription for triggering imports
+# - Cloud Functions, Workflows, and Ingestion Pipeline
+# - Unified Service Account with necessary IAM roles for Workflows, Functions, and Pub/Sub
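+#
+# Example invocation (illustrative values; substitute your own project ID and API key):
+#   terraform init
+#   terraform apply -var="project_id=my-gcp-project" -var="dc_api_key=YOUR_DC_API_KEY"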
+
+terraform {
+ required_providers {
+ google = {
+ source = "hashicorp/google"
+ version = ">= 5.0.0"
+ }
+ archive = {
+ source = "hashicorp/archive"
+ }
+ }
+}
+
+variable "project_id" {
+ description = "The GCP Project ID"
+ type = string
+}
+
+variable "region" {
+ description = "The GCP Region"
+ type = string
+ default = "us-central1"
+}
+
+variable "spanner_instance_id" {
+ description = "Spanner Instance ID"
+ type = string
+ default = "datcom-import-instance"
+}
+
+variable "spanner_database_id" {
+ description = "Spanner Database ID"
+ type = string
+ default = "dc-import-db"
+}
+
+variable "spanner_graph_database_id" {
+ description = "Spanner Graph Database ID"
+ type = string
+ default = "dc-import-db"
+}
+
+variable "bq_dataset_id" {
+ description = "BigQuery Dataset ID for aggregation"
+ type = string
+ default = "datacommons"
+}
+
+variable "dc_api_key" {
+ description = "Data Commons API Key"
+ type = string
+ sensitive = true
+}
+
+variable "artifact_registry_url" {
+ description = "Artifact Registry URL for Cloud Run images"
+ type = string
+ default = "us-docker.pkg.dev/datcom-ci/gcr.io"
+}
+
+# --- APIs ---
+
+locals {
+ services = [
+ "artifactregistry.googleapis.com",
+ "batch.googleapis.com",
+ "cloudbuild.googleapis.com",
+ "cloudfunctions.googleapis.com",
+ "cloudscheduler.googleapis.com",
+ "compute.googleapis.com",
+ "dataflow.googleapis.com",
+ "iam.googleapis.com",
+ "pubsub.googleapis.com",
+ "run.googleapis.com",
+ "secretmanager.googleapis.com",
+ "spanner.googleapis.com",
+ "storage.googleapis.com",
+ "workflows.googleapis.com",
+ ]
+}
+
+resource "google_project_service" "services" {
+ for_each = toset(local.services)
+ project = var.project_id
+ service = each.key
+
+ disable_on_destroy = false
+}
+
+# --- Secret Manager ---
+
+resource "google_secret_manager_secret" "import_config" {
+ secret_id = "import-config"
+ project = var.project_id
+
+ replication {
+ auto {}
+ }
+
+ depends_on = [google_project_service.services]
+}
+
+resource "google_secret_manager_secret_version" "import_config_v1" {
+ secret = google_secret_manager_secret.import_config.id
+ secret_data = jsonencode({
+ dc_api_key = var.dc_api_key
+ })
+}
+
+resource "google_secret_manager_secret" "dc_api_key" {
+ secret_id = "dc-api-key"
+ project = var.project_id
+
+ replication {
+ auto {}
+ }
+
+ depends_on = [google_project_service.services]
+}
+
+resource "google_secret_manager_secret_version" "dc_api_key_v1" {
+ secret = google_secret_manager_secret.dc_api_key.id
+ secret_data = var.dc_api_key
+}
+
+# --- GCS Buckets ---
+
+resource "google_storage_bucket" "import_bucket" {
+ name = "${var.project_id}-imports"
+ location = var.region
+ project = var.project_id
+ uniform_bucket_level_access = true
+
+ depends_on = [google_project_service.services]
+}
+
+resource "google_storage_bucket" "mount_bucket" {
+ name = "${var.project_id}-mount"
+ location = var.region
+ project = var.project_id
+ uniform_bucket_level_access = true
+
+ depends_on = [google_project_service.services]
+}
+
+# --- Cloud Run Services ---
+
+resource "google_cloud_run_v2_service" "ingestion_helper" {
+ name = "ingestion-helper-service"
+ location = var.region
+ project = var.project_id
+
+ template {
+ service_account = google_service_account.automation_sa.email
+ containers {
+ image = "${var.artifact_registry_url}/datacommons-ingestion-helper:latest"
+ env {
+ name = "PROJECT_ID"
+ value = var.project_id
+ }
+ env {
+ name = "SPANNER_PROJECT_ID"
+ value = var.project_id
+ }
+ env {
+ name = "SPANNER_INSTANCE_ID"
+ value = var.spanner_instance_id
+ }
+ env {
+ name = "SPANNER_DATABASE_ID"
+ value = var.spanner_database_id
+ }
+ env {
+ name = "SPANNER_GRAPH_DATABASE_ID"
+ value = var.spanner_graph_database_id
+ }
+ env {
+ name = "GCS_BUCKET_ID"
+ value = google_storage_bucket.import_bucket.name
+ }
+ env {
+ name = "LOCATION"
+ value = var.region
+ }
+ env {
+ name = "BQ_DATASET_ID"
+ value = var.bq_dataset_id
+ }
+ }
+ }
+
+ depends_on = [google_project_service.services]
+}
+
+resource "google_cloud_run_v2_service" "import_helper" {
+ name = "import-helper-service"
+ location = var.region
+ project = var.project_id
+
+ template {
+ service_account = google_service_account.automation_sa.email
+ containers {
+ image = "${var.artifact_registry_url}/datacommons-import-helper:latest"
+ env {
+ name = "PROJECT_ID"
+ value = var.project_id
+ }
+ env {
+ name = "LOCATION"
+ value = var.region
+ }
+ env {
+ name = "GCS_BUCKET_ID"
+ value = google_storage_bucket.import_bucket.name
+ }
+ env {
+ name = "INGESTION_HELPER_URL"
+ value = google_cloud_run_v2_service.ingestion_helper.uri
+ }
+ }
+ }
+
+ depends_on = [google_project_service.services]
+}
+
+# --- Cloud Workflows ---
+
+resource "google_workflows_workflow" "import_automation_workflow" {
+ name = "import-automation-workflow"
+ region = var.region
+ project = var.project_id
+ description = "Orchestrates the import automation process"
+ service_account = google_service_account.automation_sa.id
+ source_contents = file("${path.module}/../workflow/import-automation-workflow.yaml")
+
+ user_env_vars = {
+ LOCATION = var.region
+ GCS_BUCKET_ID = google_storage_bucket.import_bucket.name
+ GCS_MOUNT_BUCKET = google_storage_bucket.mount_bucket.name
+ INGESTION_HELPER_URL = google_cloud_run_v2_service.ingestion_helper.uri
+ }
+
+ depends_on = [google_project_service.services]
+}
+
+resource "google_workflows_workflow" "spanner_ingestion_workflow" {
+ name = "spanner-ingestion-workflow"
+ region = var.region
+ project = var.project_id
+ description = "Orchestrates Spanner ingestion"
+ service_account = google_service_account.automation_sa.id
+ source_contents = file("${path.module}/../workflow/spanner-ingestion-workflow.yaml")
+
+ user_env_vars = {
+ LOCATION = var.region
+ PROJECT_ID = var.project_id
+ SPANNER_PROJECT_ID = var.project_id
+ SPANNER_INSTANCE_ID = var.spanner_instance_id
+ SPANNER_DATABASE_ID = var.spanner_database_id
+ INGESTION_HELPER_URL = google_cloud_run_v2_service.ingestion_helper.uri
+ }
+
+ depends_on = [google_project_service.services]
+}
+
+# --- Spanner ---
+
+resource "google_spanner_instance" "import_instance" {
+ name = var.spanner_instance_id
+ config = "regional-${var.region}"
+ display_name = "Import Automation"
+ num_nodes = 1
+ project = var.project_id
+
+ depends_on = [google_project_service.services]
+}
+
+resource "google_spanner_database" "import_db" {
+ instance = google_spanner_instance.import_instance.name
+ name = var.spanner_database_id
+ project = var.project_id
+ deletion_protection = false
+}
+
+# --- IAM ---
+
+resource "google_service_account" "automation_sa" {
+ account_id = "import-automation-sa"
+ display_name = "Service Account for Import Automation (Workflows & Functions)"
+ project = var.project_id
+}
+
+resource "google_project_iam_member" "automation_roles" {
+ for_each = toset([
+ "roles/workflows.admin",
+ "roles/cloudfunctions.admin",
+ "roles/run.admin",
+ "roles/run.invoker",
+ "roles/batch.jobsEditor",
+ "roles/dataflow.admin",
+ "roles/logging.logWriter",
+ "roles/storage.objectAdmin",
+ "roles/iam.serviceAccountUser",
+ "roles/spanner.databaseAdmin",
+ "roles/bigquery.dataEditor",
+ "roles/bigquery.jobUser",
+ "roles/artifactregistry.admin",
+ "roles/secretmanager.secretAccessor",
+ "roles/cloudbuild.builds.builder",
+ ])
+ project = var.project_id
+ role = each.key
+ member = "serviceAccount:${google_service_account.automation_sa.email}"
+}
+
+# --- Artifact Registry ---
+
+resource "google_artifact_registry_repository" "automation_repo" {
+ location = var.region
+ repository_id = "import-automation"
+ description = "Docker repository for import automation images"
+ format = "DOCKER"
+ project = var.project_id
+
+ depends_on = [google_project_service.services]
+}
+
+# --- Pub/Sub ---
+
+resource "google_pubsub_topic" "import_automation_trigger" {
+ name = "import-automation-trigger"
+ project = var.project_id
+}
+
+resource "google_pubsub_subscription" "import_automation_sub" {
+ name = "import-automation-sub"
+ topic = google_pubsub_topic.import_automation_trigger.name
+ project = var.project_id
+
+ filter = "attributes.transfer_status=\"TRANSFER_COMPLETED\""
+
+ push_config {
+ push_endpoint = google_cloud_run_v2_service.import_helper.uri
+ oidc_token {
+ service_account_email = google_service_account.automation_sa.email
+ }
+ }
+}
+
+# Outputs
+output "automation_service_account_email" {
+ value = google_service_account.automation_sa.email
+ description = "The email of the service account used for import automation."
+}
+
+
diff --git a/pipeline/workflow/cloudbuild.yaml b/pipeline/workflow/cloudbuild.yaml
new file mode 100644
index 00000000..ed1f8720
--- /dev/null
+++ b/pipeline/workflow/cloudbuild.yaml
@@ -0,0 +1,50 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Child Cloud Build configuration for deploying to a specific environment.
+# Defaults are set to PRODUCTION. Staging builds must override these values.
+#
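+# Example (illustrative values): a staging build can override the defaults with
+#   gcloud builds submit . --config=cloudbuild.yaml \
+#     --substitutions=_PROJECT_ID=my-staging-project,_GCS_BUCKET_ID=my-staging-imports
+#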
+substitutions:
+ # Production config.
+ _PROJECT_ID: 'datcom-import-automation-prod'
+ _SPANNER_PROJECT_ID: 'datcom-store'
+ _SPANNER_INSTANCE_ID: 'dc-kg-test'
+ _SPANNER_DATABASE_ID: 'dc_graph_import'
+ _SPANNER_GRAPH_DATABASE_ID: 'dc_graph_2025_11_07'
+ _GCS_BUCKET_ID: 'datcom-prod-imports'
+ _LOCATION: 'us-central1'
+ _GCS_MOUNT_BUCKET: 'datcom-volume-mount'
+ _BQ_DATASET_ID: 'datacommons'
+ _PROJECT_NUMBER: '965988403328'
+ _AR_REPO_URL: 'us-docker.pkg.dev/datcom-ci/gcr.io'
+
+steps:
+- id: 'ingestion-helper-service'
+ name: 'gcr.io/cloud-builders/gcloud'
+ args: ['run', 'deploy', 'ingestion-helper-service', '--image', '${_AR_REPO_URL}/datacommons-ingestion-helper:latest', '--region', '${_LOCATION}', '--project', '${_PROJECT_ID}', '--no-allow-unauthenticated', '--set-env-vars', 'PROJECT_ID=${_PROJECT_ID},LOCATION=${_LOCATION},SPANNER_PROJECT_ID=${_SPANNER_PROJECT_ID},SPANNER_INSTANCE_ID=${_SPANNER_INSTANCE_ID},SPANNER_DATABASE_ID=${_SPANNER_DATABASE_ID},SPANNER_GRAPH_DATABASE_ID=${_SPANNER_GRAPH_DATABASE_ID},GCS_BUCKET_ID=${_GCS_BUCKET_ID},BQ_DATASET_ID=${_BQ_DATASET_ID}']
+
+- id: 'import-helper-service'
+ name: 'gcr.io/cloud-builders/gcloud'
+ args: ['run', 'deploy', 'import-helper-service', '--image', '${_AR_REPO_URL}/datacommons-import-helper:latest', '--region', '${_LOCATION}', '--project', '${_PROJECT_ID}', '--no-allow-unauthenticated', '--set-env-vars', 'PROJECT_ID=${_PROJECT_ID},LOCATION=${_LOCATION},PROJECT_NUMBER=${_PROJECT_NUMBER},GCS_BUCKET_ID=${_GCS_BUCKET_ID}']
+
+- id: 'import-automation-workflow'
+ name: 'gcr.io/cloud-builders/gcloud'
+ args: ['workflows', 'deploy', 'import-automation-workflow', '--project', '${_PROJECT_ID}', '--location', '${_LOCATION}', '--source', 'import-automation-workflow.yaml', '--set-env-vars', 'LOCATION=${_LOCATION},GCS_BUCKET_ID=${_GCS_BUCKET_ID},GCS_MOUNT_BUCKET=${_GCS_MOUNT_BUCKET},PROJECT_NUMBER=${_PROJECT_NUMBER}']
+
+- id: 'spanner-ingestion-workflow'
+ name: 'gcr.io/cloud-builders/gcloud'
+ args: ['workflows', 'deploy', 'spanner-ingestion-workflow', '--project', '${_PROJECT_ID}', '--location', '${_LOCATION}', '--source', 'spanner-ingestion-workflow.yaml', '--set-env-vars', 'LOCATION=${_LOCATION},PROJECT_ID=${_PROJECT_ID},SPANNER_PROJECT_ID=${_SPANNER_PROJECT_ID},SPANNER_INSTANCE_ID=${_SPANNER_INSTANCE_ID},SPANNER_DATABASE_ID=${_SPANNER_GRAPH_DATABASE_ID},PROJECT_NUMBER=${_PROJECT_NUMBER}']
+
+options:
+ logging: CLOUD_LOGGING_ONLY
diff --git a/pipeline/workflow/cloudbuild_main.yaml b/pipeline/workflow/cloudbuild_main.yaml
new file mode 100644
index 00000000..8f7343a4
--- /dev/null
+++ b/pipeline/workflow/cloudbuild_main.yaml
@@ -0,0 +1,89 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Parent Cloud Build configuration that orchestrates Staging and Production deployments.
+# Usage: gcloud builds submit . --config=cloudbuild_main.yaml --project=datcom-ci
+
+substitutions:
+ # Staging Configuration (Overrides defaults in child build)
+ _PROJECT_ID: 'datcom-ci'
+ _SPANNER_PROJECT_ID: 'datcom-ci'
+ _SPANNER_INSTANCE_ID: 'datcom-spanner-test'
+ _SPANNER_DATABASE_ID: 'dc-test-db'
+ _SPANNER_GRAPH_DATABASE_ID: 'dc-test-db'
+ _GCS_BUCKET_ID: 'datcom-ci-test'
+ _GCS_MOUNT_BUCKET: 'datcom-ci-test'
+ _BQ_DATASET_ID: 'datacommons'
+ _LOCATION: 'us-central1'
+ _PROJECT_NUMBER: '879489846695'
+ _AR_REPO_URL: 'us-docker.pkg.dev/datcom-ci/gcr.io'
+
+steps:
+
+# 1. Build and push helper images
+- id: 'build-ingestion-helper'
+ name: 'gcr.io/cloud-builders/gcloud'
+ args: ['builds', 'submit', 'ingestion-helper', '--config', 'ingestion-helper/cloudbuild.yaml', '--substitutions', '_AR_REPO_URL=${_AR_REPO_URL}']
+ dir: 'import-automation/workflow'
+
+- id: 'build-import-helper'
+ name: 'gcr.io/cloud-builders/gcloud'
+ args: ['builds', 'submit', 'import-helper', '--config', 'import-helper/cloudbuild.yaml', '--substitutions', '_AR_REPO_URL=${_AR_REPO_URL}']
+ dir: 'import-automation/workflow'
+
+# 2. Trigger Staging Build (Child)
+# Overrides default (Production) values with Staging values.
+- id: 'deploy-staging'
+ name: 'gcr.io/cloud-builders/gcloud'
+ args:
+ - 'builds'
+ - 'submit'
+ - '.'
+ - '--config=cloudbuild.yaml'
+ - '--project=${_PROJECT_ID}'
+ - '--substitutions=_PROJECT_ID=${_PROJECT_ID},_SPANNER_PROJECT_ID=${_SPANNER_PROJECT_ID},_SPANNER_INSTANCE_ID=${_SPANNER_INSTANCE_ID},_SPANNER_DATABASE_ID=${_SPANNER_DATABASE_ID},_SPANNER_GRAPH_DATABASE_ID=${_SPANNER_GRAPH_DATABASE_ID},_GCS_BUCKET_ID=${_GCS_BUCKET_ID},_LOCATION=${_LOCATION},_GCS_MOUNT_BUCKET=${_GCS_MOUNT_BUCKET},_BQ_DATASET_ID=${_BQ_DATASET_ID},_PROJECT_NUMBER=${_PROJECT_NUMBER}'
+ dir: 'import-automation/workflow'
+
+# 3. Run E2E Tests on Staging
+- id: 'e2e-test-staging'
+ name: 'python:3.11'
+ entrypoint: 'bash'
+ args:
+ - '-c'
+ - |
+ pip install google-cloud-spanner google-cloud-workflows absl-py
+ python spanner_ingestion_test.py
+ env:
+ - 'PROJECT_ID=${_PROJECT_ID}'
+ - 'LOCATION=${_LOCATION}'
+ - 'SPANNER_PROJECT_ID=${_SPANNER_PROJECT_ID}'
+ - 'SPANNER_INSTANCE_ID=${_SPANNER_INSTANCE_ID}'
+ - 'SPANNER_DATABASE_ID=${_SPANNER_DATABASE_ID}'
+ - 'GCS_BUCKET_ID=${_GCS_BUCKET_ID}'
+ dir: 'import-automation/workflow'
+
+# 4. Trigger Production Build (Child)
+# Uses default (Production) values defined in cloudbuild.yaml.
+- id: 'deploy-prod'
+ name: 'gcr.io/cloud-builders/gcloud'
+ args:
+ - 'builds'
+ - 'submit'
+ - '.'
+ - '--config=cloudbuild.yaml'
+ - '--project=${_PROJECT_ID}' # Build runs in CI project, deploys to Prod
+ dir: 'import-automation/workflow'
+
+options:
+ logging: CLOUD_LOGGING_ONLY
diff --git a/pipeline/workflow/import-automation-workflow.yaml b/pipeline/workflow/import-automation-workflow.yaml
new file mode 100644
index 00000000..d6198b8c
--- /dev/null
+++ b/pipeline/workflow/import-automation-workflow.yaml
@@ -0,0 +1,120 @@
+main:
+ params: [args]
+ steps:
+ - init:
+ assign:
+ - projectId: ${sys.get_env("GOOGLE_CLOUD_PROJECT_ID")}
+ - region: ${sys.get_env("LOCATION")}
+ - imageUri: ${default(map.get(args, "imageUri"), "us-docker.pkg.dev/datcom-ci/gcr.io/dc-import-executor:stable")}
+ - jobId: ${text.replace_all(text.to_lower(text.substring(text.split(args.importName, ":")[1], 0, 50) + "-" + string(int(sys.now()))), "_", "-")}
+ - importName: ${args.importName}
+ - importConfig: ${default(map.get(args, "importConfig"), "{}")}
+ - gcsMountBucket: ${sys.get_env("GCS_MOUNT_BUCKET")}
+ - gcsImportBucket: ${sys.get_env("GCS_BUCKET_ID")}
+ - gcsMountPath: "/tmp/gcs"
+ - helperUrl: ${"https://ingestion-helper-service-" + sys.get_env("PROJECT_NUMBER") + "." + region + ".run.app"}
+ - startTime: ${sys.now()}
+ - defaultResources:
+ machine: "n2-standard-8"
+ cpu: 8000
+ memory: 32768
+ disk: 100
+ - resources: ${default(map.get(args, "resources"), defaultResources)}
+ - runIngestion: ${default(map.get(args, "runIngestion"), false)}
+ - ingestionArgs:
+ importList:
+ - ${text.split(importName, ":")[1]}
+ - runImportJob:
+ try:
+ call: googleapis.batch.v1.projects.locations.jobs.create
+ args:
+ parent: ${"projects/" + projectId + "/locations/" + region}
+ jobId: ${jobId}
+ body:
+ allocationPolicy:
+ instances:
+ - policy:
+ machineType: ${resources.machine}
+ provisioningModel: "STANDARD"
+ bootDisk:
+ image: "projects/debian-cloud/global/images/family/debian-12"
+ size_gb: ${resources.disk}
+ installOpsAgent: true
+ taskGroups:
+ taskSpec:
+ volumes:
+ - gcs:
+ remotePath: ${gcsMountBucket}
+ mountPath: ${gcsMountPath}
+ computeResource:
+ cpuMilli: ${resources.cpu}
+ memoryMib: ${resources.memory}
+ runnables:
+ - container:
+ imageUri: ${imageUri}
+ commands:
+ - ${"--import_name=" + importName}
+ - ${"--import_config=" + importConfig}
+ environment:
+ variables:
+ IMPORT_NAME: ${importName}
+ BATCH_JOB_NAME: ${jobId}
+ taskCount: 1
+ parallelism: 1
+ logsPolicy:
+ destination: CLOUD_LOGGING
+ connector_params:
+ timeout: 604800 #7 days
+ polling_policy:
+ initial_delay: 60
+ multiplier: 2
+ max_delay: 600
+ result: importJobResponse
+ except:
+ as: e
+ steps:
+ - updateImportStatus:
+ call: http.post
+ args:
+ url: ${helperUrl}
+ auth:
+ type: OIDC
+ body:
+ actionType: 'update_import_status'
+ jobId: ${jobId}
+ importName: ${importName}
+ status: 'FAILURE'
+ executionTime: ${int(sys.now() - startTime)}
+ latestVersion: ${"gs://" + gcsImportBucket + "/" + text.replace_all(importName, ":", "/")}
+ result: functionResponse
+ - failWorkflow:
+ raise: ${e}
+ - updateImportVersion:
+ call: http.post
+ args:
+ url: ${helperUrl}
+ auth:
+ type: OIDC
+ body:
+ actionType: 'update_import_version'
+ importName: ${importName}
+ version: 'STAGING'
+ override: false
+ comment: '${"import-workflow:" + sys.get_env("GOOGLE_CLOUD_WORKFLOW_EXECUTION_ID")}'
+ result: functionResponse
+ - runIngestion:
+ switch:
+ - condition: ${runIngestion}
+ steps:
+ - runSpannerIngestion:
+ call: googleapis.workflowexecutions.v1.projects.locations.workflows.executions.create
+ args:
+ parent: ${"projects/" + projectId + "/locations/" + region + "/workflows/spanner-ingestion-workflow"}
+ body:
+ argument: ${json.encode_to_string(ingestionArgs)}
+ connector_params:
+ skip_polling: true
+ - returnResult:
+ return:
+ jobId: ${jobId}
+ importName: ${importName}
\ No newline at end of file
diff --git a/pipeline/workflow/import-helper/Dockerfile b/pipeline/workflow/import-helper/Dockerfile
new file mode 100644
index 00000000..2473221b
--- /dev/null
+++ b/pipeline/workflow/import-helper/Dockerfile
@@ -0,0 +1,32 @@
+# Copyright 2026 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+FROM python:3.12-slim
+
+# Copy uv binary
+COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
+
+# Allow statements and log messages to immediately appear in the logs
+ENV PYTHONUNBUFFERED True
+
+WORKDIR /app
+
+# Copy local code to the container image.
+COPY . .
+
+# Install production dependencies using uv.
+RUN uv pip install --system --no-cache -r requirements.txt
+
+# Run the functions framework
+CMD ["functions-framework", "--target", "handle_feed_event"]
diff --git a/pipeline/workflow/import-helper/cloudbuild.yaml b/pipeline/workflow/import-helper/cloudbuild.yaml
new file mode 100644
index 00000000..760d7b6b
--- /dev/null
+++ b/pipeline/workflow/import-helper/cloudbuild.yaml
@@ -0,0 +1,31 @@
+# Copyright 2026 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+steps:
+ # Build the container image
+ - name: 'gcr.io/cloud-builders/docker'
+ args: ['build', '-t', '${_AR_REPO_URL}/${_IMAGE_NAME}:${_TAG}', '-t', '${_AR_REPO_URL}/${_IMAGE_NAME}:latest', '.']
+
+ # Push the container image
+ - name: 'gcr.io/cloud-builders/docker'
+ args: ['push', '${_AR_REPO_URL}/${_IMAGE_NAME}:${_TAG}']
+
+substitutions:
+ _AR_REPO_URL: 'us-docker.pkg.dev/datcom-ci/gcr.io'
+ _IMAGE_NAME: 'datacommons-import-helper'
+ _TAG: 'latest'
+
+images:
+ - '${_AR_REPO_URL}/${_IMAGE_NAME}:${_TAG}'
+ - '${_AR_REPO_URL}/${_IMAGE_NAME}:latest'
diff --git a/pipeline/workflow/import-helper/import_helper.py b/pipeline/workflow/import-helper/import_helper.py
new file mode 100644
index 00000000..50aeaf08
--- /dev/null
+++ b/pipeline/workflow/import-helper/import_helper.py
@@ -0,0 +1,197 @@
+# Copyright 2026 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import base64
+import json
+import logging
+import os
+import croniter
+from datetime import datetime, timezone
+from google.auth.transport.requests import Request
+from google.oauth2 import id_token
+from google.cloud import storage
+from google.cloud.workflows import executions_v1
+import requests
+
+logging.getLogger().setLevel(logging.INFO)
+
+PROJECT_ID = os.environ.get('PROJECT_ID')
+PROJECT_NUMBER = os.environ.get('PROJECT_NUMBER')
+LOCATION = os.environ.get('LOCATION')
+GCS_BUCKET_ID = os.environ.get('GCS_BUCKET_ID')
+INGESTION_HELPER_URL = f"https://ingestion-helper-service-{PROJECT_NUMBER}.{LOCATION}.run.app"
+SPANNER_INGESTION_WORKFLOW_ID = 'spanner-ingestion-workflow'
+IMPORT_AUTOMATION_WORKFLOW_ID = 'import-automation-workflow'
+
+
+def invoke_spanner_ingestion_workflow(import_name: str):
+ """Triggers the spanner ingestion workflow.
+
+ Args:
+ import_name: The name of the import.
+ """
+ workflow_args = {"importList": [import_name.split(':')[-1]]}
+
+ logging.info(f"Invoking {SPANNER_INGESTION_WORKFLOW_ID} for {import_name}")
+ execution_client = executions_v1.ExecutionsClient()
+ parent = f"projects/{PROJECT_ID}/locations/{LOCATION}/workflows/{SPANNER_INGESTION_WORKFLOW_ID}"
+ execution_req = executions_v1.Execution(argument=json.dumps(workflow_args))
+ response = execution_client.create_execution(parent=parent,
+ execution=execution_req)
+ logging.info(
+ f"Triggered workflow {SPANNER_INGESTION_WORKFLOW_ID} for {import_name}. Execution ID: {response.name}"
+ )
+
+
+def invoke_import_automation_workflow(import_name: str,
+ latest_version: str,
+ import_size: str,
+ graph_path: str, cron_schedule: str,
+ run_ingestion: bool = False):
+ """Triggers the import automation workflow.
+
+ Args:
+ import_name: The name of the import.
+ latest_version: The version of the import.
+ import_size: The size of the import ('small', 'medium', 'large').
+ graph_path: The graph path for the import.
+ cron_schedule: The cron schedule for the import.
+ run_ingestion: Whether to run the ingestion workflow after the import.
+ """
+ import_config = {
+ "user_script_args": [f"--version={latest_version}"],
+ "import_version_override": latest_version,
+ "graph_data_path": graph_path,
+ "cron_schedule_override": cron_schedule
+ }
+ workflow_args = {
+ "importName": import_name,
+ "importConfig": json.dumps(import_config),
+ "runIngestion": run_ingestion
+ }
+
+ if import_size == 'large':
+ workflow_args["resources"] = {
+ "machine": "n2-highmem-16",
+ "cpu": 16000,
+ "memory": 131072,
+ "disk": 100
+ }
+
+ logging.info(f"Invoking {IMPORT_AUTOMATION_WORKFLOW_ID} for {import_name}")
+ execution_client = executions_v1.ExecutionsClient()
+ parent = f"projects/{PROJECT_ID}/locations/{LOCATION}/workflows/{IMPORT_AUTOMATION_WORKFLOW_ID}"
+ execution_req = executions_v1.Execution(argument=json.dumps(workflow_args))
+ response = execution_client.create_execution(parent=parent,
+ execution=execution_req)
+ logging.info(
+ f"Triggered workflow {IMPORT_AUTOMATION_WORKFLOW_ID} for {import_name}. Execution ID: {response.name}"
+ )
+
+
+def update_import_status(import_name,
+ import_status,
+ import_version,
+ graph_path,
+ job_id,
+ cron_schedule=None):
+ """Updates the status for the specified import job.
+
+ Args:
+ import_name: The name of the import.
+ import_status: The new status of the import.
+ import_version: The version of the import.
+ graph_path: The graph path for the import.
+ job_id: The job ID associated with the import.
+ cron_schedule: The cron schedule for the import (optional).
+ """
+ logging.info(f"Updating {import_name} status: {import_status}")
+ latest_version = 'gs://' + GCS_BUCKET_ID + '/' + import_name.replace(
+ ':', '/') + '/' + import_version
+ request = {
+ 'actionType': 'update_import_status',
+ 'importName': import_name,
+ 'status': import_status,
+        'jobId': job_id,
+ 'latestVersion': latest_version,
+ 'graphPath': graph_path
+ }
+ if cron_schedule:
+ try:
+ next_refresh = croniter.croniter(
+ cron_schedule,
+ datetime.now(timezone.utc)).get_next(datetime).isoformat()
+ request['nextRefresh'] = next_refresh
+ except (croniter.CroniterError) as e:
+ logging.error(
+ f"Error calculating next refresh from schedule '{cron_schedule}': {e}"
+ )
+ logging.info(f"Update request: {request}")
+ auth_req = Request()
+ token = id_token.fetch_id_token(auth_req, INGESTION_HELPER_URL)
+ headers = {'Authorization': f'Bearer {token}'}
+ response = requests.post(INGESTION_HELPER_URL,
+ json=request,
+ headers=headers)
+ response.raise_for_status()
+ logging.info(f"Updated status for {import_name}")
+
+
+def parse_message(request) -> dict:
+ """Processes the incoming Pub/Sub message.
+
+ Args:
+ request: The flask request object.
+
+ Returns:
+ A dictionary containing the message data, or None if invalid.
+ """
+ request_json = request.get_json(silent=True)
+ if not request_json or 'message' not in request_json:
+ logging.error('Invalid Pub/Sub message format')
+ return None
+
+ pubsub_message = request_json['message']
+ logging.info(f"Received Pub/Sub message: {pubsub_message}")
+ try:
+ data_bytes = base64.b64decode(pubsub_message["data"])
+ notification_json = data_bytes.decode("utf-8")
+ logging.info(f"Notification content: {notification_json}")
+ except Exception as e:
+ logging.error(f"Error decoding message data: {e}")
+
+ return pubsub_message
+
+
+def check_duplicate(message_id: str):
+ """Checks for duplicate messages using a GCS file.
+
+ Args:
+ message_id: The ID of the message to check.
+
+ Returns:
+ True if the message is a duplicate, False otherwise.
+ """
+ duplicate = False
+ if not message_id:
+ return duplicate
+ logging.info(f"Checking for existing message: {message_id}")
+ storage_client = storage.Client()
+ bucket = storage_client.bucket(GCS_BUCKET_ID)
+ blob = bucket.blob(f"google3/transfers/{message_id}")
+ try:
+ blob.upload_from_string("", if_generation_match=0)
+ except Exception:
+ duplicate = True
+ return duplicate
diff --git a/pipeline/workflow/import-helper/main.py b/pipeline/workflow/import-helper/main.py
new file mode 100644
index 00000000..c825ec2b
--- /dev/null
+++ b/pipeline/workflow/import-helper/main.py
@@ -0,0 +1,67 @@
+# Copyright 2026 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import functions_framework
+import logging
+from datetime import datetime, timezone
+import import_helper as helper
+
+logging.getLogger().setLevel(logging.INFO)
+
+
+# Triggered from a message on a Cloud Pub/Sub topic.
+@functions_framework.http
+def handle_feed_event(request):
+ # Updates status in spanner and triggers ingestion workflow
+ # for an import using CDA feed
+ message = helper.parse_message(request)
+ if not message:
+ return 'Invalid Pub/Sub message format', 400
+
+ attributes = message.get('attributes', {})
+ message_id = message.get('messageId', '')
+ if attributes.get('transfer_status') != 'TRANSFER_COMPLETED':
+ return 'OK', 200
+
+ duplicate = helper.check_duplicate(message_id)
+ if duplicate:
+ logging.info(f"Message {message_id} already processed. Skipping.")
+ return 'OK', 200
+
+ import_name = attributes.get('import_name')
+ latest_version = attributes.get(
+ 'import_version',
+ datetime.now(timezone.utc).strftime("%Y-%m-%d"))
+ import_step = attributes.get('import_step', '')
+ graph_path = attributes.get('graph_path', "/**/*.mcf*")
+ import_size = attributes.get('import_size', '')
+ cron_schedule = attributes.get('cron_schedule', '')
+ if import_step == 'ingestion_workflow_single' or import_step == 'ingestion_workflow_batch':
+ import_status = 'STAGING'
+ job_id = attributes.get('feed_name', 'cda_feed')
+ helper.update_import_status(import_name, import_status, latest_version,
+ graph_path, job_id, cron_schedule)
+ if import_step == 'ingestion_workflow_single':
+ # Invoke ingestion workflow to trigger dataflow job
+ helper.invoke_spanner_ingestion_workflow(import_name)
+ elif import_step == 'import_automation_job' or import_step == 'import_automation_e2e':
+ # Invoke batch import job and optionally ingestion workflow to trigger dataflow job
+ run_ingestion = True if import_step == 'import_automation_e2e' else False
+ helper.invoke_import_automation_workflow(import_name, latest_version,
+ import_size, graph_path,
+ cron_schedule, run_ingestion)
+ else:
+ logging.info(f"Skipping import post processing.")
+
+ return 'OK', 200
diff --git a/pipeline/workflow/import-helper/requirements.txt b/pipeline/workflow/import-helper/requirements.txt
new file mode 100644
index 00000000..9d321e81
--- /dev/null
+++ b/pipeline/workflow/import-helper/requirements.txt
@@ -0,0 +1,6 @@
+functions-framework==3.*
+google-cloud-workflows
+google-auth
+requests
+google-cloud-storage
+croniter
diff --git a/pipeline/workflow/ingestion-helper/Dockerfile b/pipeline/workflow/ingestion-helper/Dockerfile
new file mode 100644
index 00000000..7f95ae8b
--- /dev/null
+++ b/pipeline/workflow/ingestion-helper/Dockerfile
@@ -0,0 +1,41 @@
+# Copyright 2026 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+FROM python:3.12-slim
+
+# Copy uv binary
+COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
+
+# Allow statements and log messages to immediately appear in the logs
+ENV PYTHONUNBUFFERED True
+
+WORKDIR /app
+
+# Install protobuf compiler and curl
+RUN apt-get update && apt-get install -y protobuf-compiler curl && rm -rf /var/lib/apt/lists/*
+
+# Copy local code to the container image.
+COPY . .
+
+# Install production dependencies using uv.
+RUN uv pip install --system --no-cache .
+
+# Fetch proto file from GitHub
+RUN curl -o storage.proto https://raw.githubusercontent.com/datacommonsorg/import/master/pipeline/data/src/main/proto/storage.proto
+
+# Generate proto descriptor set
+RUN protoc --include_imports --descriptor_set_out=storage.pb storage.proto
+
+# Run the functions framework
+CMD ["functions-framework", "--target", "ingestion_helper"]
diff --git a/pipeline/workflow/ingestion-helper/README.md b/pipeline/workflow/ingestion-helper/README.md
new file mode 100644
index 00000000..7de6d49a
--- /dev/null
+++ b/pipeline/workflow/ingestion-helper/README.md
@@ -0,0 +1,98 @@
+# Ingestion Helper Cloud Function
+
+This Cloud Function provides helper routines for the Data Commons Spanner ingestion workflow. It handles tasks such as locking, status updates, and import list retrieval.
+
+## Usage
+
+The function expects a JSON payload with a required `actionType` parameter, which determines the operation to perform.
+
+### Common Parameters
+
+* `actionType` (Required): A string specifying the action to execute.
+
+### Supported Actions and Parameters
+
+#### `get_import_info`
+Gets the details of imports that are ready for ingestion.
+
+* `importList` (Optional): list of imports to ingest.
+
+#### `acquire_ingestion_lock`
+Attempts to acquire the global lock for ingestion to prevent concurrent modifications.
+
+* `workflowId` (Required): The ID of the workflow attempting to acquire the lock.
+* `timeout` (Required): The duration (in seconds) for which the lock should be held.
+
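+An illustrative request payload (the workflow ID and timeout shown are examples only):
+
+```json
+{
+  "actionType": "acquire_ingestion_lock",
+  "workflowId": "spanner-ingestion-workflow-abc123",
+  "timeout": 3600
+}
+```
+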
+#### `release_ingestion_lock`
+Releases the global ingestion lock.
+
+* `workflowId` (Required): The ID of the workflow releasing the lock.
+
+#### `update_ingestion_status`
+Updates the status of imports after an ingestion job completes.
+
+* `importList` (Required): A list of import names involved in the ingestion.
+* `workflowId` (Required): The ID of the workflow.
+* `status` (Required): Import status.
+* `jobId` (Required): The Dataflow job ID associated with the ingestion.
+
+#### `update_import_status`
+Updates the status of a specific import job.
+
+* `importName` (Required): The name of the import.
+* `status` (Required): The new status to set.
+* `jobId` (Optional): The Dataflow job ID.
+* `executionTime` (Optional): Execution time in seconds.
+* `dataVolume` (Optional): Data volume in bytes.
+* `latestVersion` (Optional): Latest version string.
+* `graphPath` (Optional): Graph path regex.
+* `nextRefresh` (Optional): Next refresh timestamp.
+
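+An illustrative request payload (import name, job ID, and GCS path are examples only):
+
+```json
+{
+  "actionType": "update_import_status",
+  "importName": "scripts/us_census:acs5",
+  "status": "SUCCESS",
+  "jobId": "dataflow-job-123",
+  "latestVersion": "gs://my-imports-bucket/scripts/us_census/acs5/2025-01-01"
+}
+```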
+
+#### `update_import_version`
+Updates the version of an import, records version history, and updates the status.
+
+* `importName` (Required): The name of the import.
+* `version` (Required): The version string. If set to `'STAGING'`, it resolves to the current staging version.
+* `comment` (Required): A comment for the audit log explaining the version update.
+* `override` (Optional): Override version without checking import status (boolean)
+
+#### `initialize_database`
+Initializes the Spanner database by creating all necessary tables and uploading proto descriptors.
+
+* This action requires no payload parameters. It automatically reads `schema.sql` and `storage.pb` from the container directory to provision the database schema and proto descriptors.
+* `enableEmbeddings` (Optional): Boolean to enable creation of embedding tables and models.
+* **Note on Protos**: The `storage.pb` file is generated during the Docker build process. The `Dockerfile` fetches `storage.proto` from the `datacommonsorg/import` GitHub repository and compiles it into `storage.pb`.
+
+#### `embedding_ingestion`
+Triggers the generation of embeddings for updated nodes in Spanner. It fetches nodes of specific types (e.g., `StatisticalVariable`, `Topic`) that have been updated, generates embeddings using a remote ML model in Spanner, and stores the results in the `NodeEmbedding` table.
+
+* `enableEmbeddings` (Optional): Boolean that overrides the default setting for enabling embeddings. If it is false or unset and the default is also false, embedding generation is skipped.
+* **Flags**:
+ - `--node_types`: A comma-separated list of node types to process (default: `StatisticalVariable,Topic`). This is a command-line flag for the service, not a request parameter.
+
+## Local Development and Testing
+
+To run the helper service locally and test its functionality:
+
+### Running the Server
+Ensure you have installed the requirements (`uv pip install -r requirements.txt`), then start the functions framework:
+
+```bash
+uv run functions-framework --target ingestion_helper
+```
+By default, this will start serving on `http://localhost:8080`.
+
+### Triggering Actions
+You can test specific actions by sending a POST request with a JSON payload. For example, to trigger database initialization locally:
+```bash
+curl -X POST http://localhost:8080 \
+ -H "Content-Type: application/json" \
+ -d '{"actionType": "initialize_database"}'
+```
+### Running unit tests
+Run unit tests with uv using:
+
+```bash
+uv run pytest
+```
diff --git a/pipeline/workflow/ingestion-helper/aggregation_utils.py b/pipeline/workflow/ingestion-helper/aggregation_utils.py
new file mode 100644
index 00000000..16fe06c7
--- /dev/null
+++ b/pipeline/workflow/ingestion-helper/aggregation_utils.py
@@ -0,0 +1,69 @@
+# Copyright 2026 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+from google.cloud import bigquery
+
+logging.getLogger().setLevel(logging.INFO)
+
+class AggregationUtils:
+ def __init__(self):
+ # Initialize BigQuery Client
+ try:
+ self.bq_client = bigquery.Client()
+ except Exception as e:
+ logging.warning(f"Failed to initialize BigQuery client: {e}")
+ self.bq_client = None
+
+ self.bq_dataset_id = os.environ.get('BQ_DATASET_ID')
+
+ def run_aggregation(self, import_list):
+ """
+ Runs a BQ query for each import in the import_list.
+ """
+ logging.info(f"Received request for importList: {import_list}")
+ results = []
+ if not self.bq_client:
+ logging.error("BigQuery client not initialized")
+ return False
+
+ try:
+ for import_item in import_list:
+ import_name = import_item.get('importName')
+
+ query = None
+ # Define specific queries based on importName
+ if import_name:
+ query = """
+ SELECT @import_name as import_name, CURRENT_TIMESTAMP() as execution_time
+ """
+ else:
+ logging.info('Skipping aggregation logic')
+ continue
+
+ if query:
+ job_config = bigquery.QueryJobConfig(query_parameters=[
+ bigquery.ScalarQueryParameter("import_name", "STRING",
+ import_name),
+ ])
+ query_job = self.bq_client.query(query, job_config=job_config)
+ query_results = query_job.result()
+ for row in query_results:
+ results.append(dict(row))
+ return True
+
+ except Exception as e:
+ logging.error(f"Aggregation failed: {e}")
+ raise e
diff --git a/pipeline/workflow/ingestion-helper/cloudbuild.yaml b/pipeline/workflow/ingestion-helper/cloudbuild.yaml
new file mode 100644
index 00000000..632b3bf1
--- /dev/null
+++ b/pipeline/workflow/ingestion-helper/cloudbuild.yaml
@@ -0,0 +1,31 @@
+# Copyright 2026 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+steps:
+ # Build the container image
+ - name: 'gcr.io/cloud-builders/docker'
+ args: ['build', '-t', '${_AR_REPO_URL}/${_IMAGE_NAME}:${_TAG}', '-t', '${_AR_REPO_URL}/${_IMAGE_NAME}:latest', '.']
+
+ # Push the container image
+ - name: 'gcr.io/cloud-builders/docker'
+ args: ['push', '${_AR_REPO_URL}/${_IMAGE_NAME}:${_TAG}']
+
+substitutions:
+ _AR_REPO_URL: 'us-docker.pkg.dev/datcom-ci/gcr.io'
+ _IMAGE_NAME: 'datacommons-ingestion-helper'
+ _TAG: 'latest'
+
+images:
+ - '${_AR_REPO_URL}/${_IMAGE_NAME}:${_TAG}'
+ - '${_AR_REPO_URL}/${_IMAGE_NAME}:latest'
diff --git a/pipeline/workflow/ingestion-helper/embedding_schema.sql b/pipeline/workflow/ingestion-helper/embedding_schema.sql
new file mode 100644
index 00000000..50d57ce3
--- /dev/null
+++ b/pipeline/workflow/ingestion-helper/embedding_schema.sql
@@ -0,0 +1,44 @@
+-- Copyright 2026 Google LLC
+--
+-- Licensed under the Apache License, Version 2.0 (the "License")
+-- you may not use this file except in compliance with the License.
+-- You may obtain a copy of the License at
+--
+-- http://www.apache.org/licenses/LICENSE-2.0
+--
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+
+CREATE TABLE NodeEmbedding (
+ subject_id STRING(1024) NOT NULL,
+ embedding_content STRING(MAX),
+  types ARRAY<STRING(MAX)>,
+  embeddings ARRAY<FLOAT64>(vector_length=>768)
+) PRIMARY KEY(subject_id),
+INTERLEAVE IN PARENT Node ON DELETE CASCADE;
+
+CREATE VECTOR INDEX NodeEmbeddingIndex
+ON NodeEmbedding(embeddings)
+WHERE embeddings IS NOT NULL
+OPTIONS (
+ distance_type = 'COSINE',
+ flat_index = true
+);
+
+CREATE MODEL NodeEmbeddingModel
+INPUT(
+ content STRING(MAX),
+ task_type STRING(MAX),
+)
+OUTPUT(
+  embeddings
+    STRUCT<
+      statistics STRUCT<truncated BOOL, token_count FLOAT64>,
+      values ARRAY<FLOAT64>>
+)
+REMOTE OPTIONS (
+ endpoint = '{{ embeddings_endpoint }}'
+);
diff --git a/pipeline/workflow/ingestion-helper/embedding_utils.py b/pipeline/workflow/ingestion-helper/embedding_utils.py
new file mode 100644
index 00000000..333705e3
--- /dev/null
+++ b/pipeline/workflow/ingestion-helper/embedding_utils.py
@@ -0,0 +1,169 @@
+# Copyright 2026 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Helper utilities for embedding workflows."""
+
+import itertools
+import logging
+import time
+from datetime import datetime
+from google.cloud.spanner_v1.param_types import TIMESTAMP, STRING, Array, Struct, StructField
+
+
+_BATCH_SIZE = 1000
+
+def get_latest_lock_timestamp(database):
+ """Gets the latest AcquiredTimestamp from IngestionLock table.
+
+ Args:
+ database: google.cloud.spanner.Database object.
+
+ Returns:
+ The latest AcquiredTimestamp as a datetime object, or None if no entries exist.
+ """
+ time_lock_sql = "SELECT MAX(AcquiredTimestamp) FROM IngestionLock"
+ try:
+ with database.snapshot() as snapshot:
+ results = snapshot.execute_sql(time_lock_sql)
+ for row in results:
+ return row[0]
+ except Exception as e:
+ logging.error(f"Error fetching latest lock timestamp: {e}")
+ raise
+ return None
+
+def get_updated_nodes(database, timestamp, node_types):
+ """Gets subject_ids and names from Node table where update_timestamp > timestamp.
+ Yields results to avoid loading all into memory.
+
+ Args:
+ database: google.cloud.spanner.Database object.
+ timestamp: datetime object to filter by.
+ node_types: A list of strings representing the node types to filter by.
+
+ Yields:
+ Dictionaries containing subject_id and name.
+ """
+ timestamp_condition = "update_timestamp > @timestamp" if timestamp else "TRUE"
+
+ updated_node_sql = f"""
+ SELECT subject_id, name, types FROM Node
+ WHERE name IS NOT NULL
+ AND {timestamp_condition}
+ AND EXISTS (
+ SELECT 1 FROM UNNEST(types) AS t WHERE t IN UNNEST(@node_types)
+ )
+ """
+
+ params = {"node_types": node_types}
+ param_types = {"node_types": Array(STRING)}
+
+ if timestamp:
+ logging.info(f"Filtering valid nodes updated after {timestamp}")
+ params["timestamp"] = timestamp
+ param_types["timestamp"] = TIMESTAMP
+ else:
+ logging.info("No timestamp provided, reading all valid nodes.")
+
+ try:
+ with database.snapshot() as snapshot:
+ results = snapshot.execute_sql(updated_node_sql, params=params, param_types=param_types, timeout=300)
+ fields = None
+ for row in results:
+ if fields is None:
+ fields = [field.name for field in results.fields]
+ yield dict(zip(fields, row))
+ except Exception as e:
+ logging.error(f"Error fetching updated nodes: {e}")
+ raise
+
+
+def filter_and_convert_nodes(nodes_generator):
+ """Filters out nodes without a name and converts dictionaries to tuples.
+ Reads from a generator and yields results.
+
+ Args:
+ nodes_generator: A generator yielding dictionaries containing subject_id, name, and types.
+
+ Yields:
+ Tuples (subject_id, embedding_content, types).
+ """
+ for node in nodes_generator:
+ if node.get("name"):
+ yield (node.get("subject_id"), node.get("name"), node.get("types"))
+
+
+def generate_embeddings_partitioned(database, nodes_generator):
+ """Generates embeddings in batches using standard transactions.
+ Processes nodes in chunks of 500 to avoid transaction size limits.
+ Accepts a generator to avoid loading all nodes into memory.
+
+ Args:
+ database: google.cloud.spanner.Database object.
+ nodes_generator: A generator yielding tuples containing (subject_id, embedding_content).
+
+ Returns:
+ The number of affected rows.
+ """
+ global _BATCH_SIZE
+ total_rows_affected = 0
+
+ logging.info(f"Generating embeddings in batches of {_BATCH_SIZE}.")
+
+ embeddings_sql = """
+ INSERT OR UPDATE INTO NodeEmbedding (subject_id, embedding_content, embeddings, types)
+ SELECT subject_id, content, embeddings.values, types
+ FROM ML.PREDICT(
+ MODEL NodeEmbeddingModel,
+ (SELECT subject_id, embedding_content AS content, types, "RETRIEVAL_QUERY" AS task_type FROM UNNEST(@nodes))
+ )
+ """
+
+ struct_type = Struct([
+ StructField("subject_id", STRING),
+ StructField("embedding_content", STRING),
+ StructField("types", Array(STRING))
+ ])
+
+ def chunked(iterable, n):
+ it = iter(iterable)
+ while True:
+ chunk = list(itertools.islice(it, n))
+ if not chunk:
+ break
+ yield chunk
+
+ for batch in chunked(nodes_generator, _BATCH_SIZE):
+ params = {"nodes": batch}
+ param_types = {"nodes": Array(struct_type)}
+
+ def _execute_dml(transaction):
+ return transaction.execute_update(embeddings_sql, params=params, param_types=param_types, timeout=300)
+
+ try:
+ row_count = database.run_in_transaction(_execute_dml)
+ total_rows_affected += row_count
+ logging.info(f"Processed batch of {len(batch)} nodes. Affected total {total_rows_affected} rows.")
+ time.sleep(0.5)
+ except Exception as e:
+ logging.error(f"Error executing batch transaction: {e}")
+ raise
+
+ logging.info(f"Completed batch processing. Total affected rows: {total_rows_affected}")
+ return total_rows_affected
+
+
+
+
+
diff --git a/pipeline/workflow/ingestion-helper/embedding_utils_test.py b/pipeline/workflow/ingestion-helper/embedding_utils_test.py
new file mode 100644
index 00000000..299b293d
--- /dev/null
+++ b/pipeline/workflow/ingestion-helper/embedding_utils_test.py
@@ -0,0 +1,166 @@
+# Copyright 2026 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+from unittest.mock import MagicMock, patch
+from datetime import datetime
+
+from embedding_utils import (
+ get_latest_lock_timestamp,
+ get_updated_nodes,
+ filter_and_convert_nodes,
+ generate_embeddings_partitioned
+)
+
+class TestEmbeddingUtils(unittest.TestCase):
+
+ def test_get_latest_lock_timestamp(self):
+ mock_database = MagicMock()
+ mock_snapshot = MagicMock()
+ mock_database.snapshot.return_value.__enter__.return_value = mock_snapshot
+ expected_timestamp = datetime(2026, 4, 20, 12, 0, 0)
+ mock_snapshot.execute_sql.return_value = [(expected_timestamp,)]
+
+ timestamp = get_latest_lock_timestamp(mock_database)
+ self.assertEqual(timestamp, expected_timestamp)
+
+ def test_get_updated_nodes(self):
+ mock_database = MagicMock()
+ mock_snapshot = MagicMock()
+ mock_database.snapshot.return_value.__enter__.return_value = mock_snapshot
+
+ class MockField:
+ def __init__(self, name):
+ self.name = name
+
+ class MockResults:
+ def __init__(self, rows, field_names):
+ self.rows = rows
+ self.fields = [MockField(name) for name in field_names]
+
+ def __iter__(self):
+ return iter(self.rows)
+
+ mock_snapshot.execute_sql.return_value = MockResults(
+ rows=[("dc/1", "Node 1", ["Topic"])],
+ field_names=["subject_id", "name", "types"]
+ )
+
+ nodes = list(get_updated_nodes(mock_database, None, ["Topic"]))
+
+ # Verify Spanner call
+ mock_snapshot.execute_sql.assert_called_once()
+ args, kwargs = mock_snapshot.execute_sql.call_args
+ query = args[0]
+ self.assertIn("SELECT subject_id, name, types FROM Node", query)
+ self.assertIn("TRUE", query)
+ self.assertEqual(kwargs["params"], {"node_types": ["Topic"]})
+
+ self.assertEqual(len(nodes), 1)
+ self.assertEqual(nodes[0]["subject_id"], "dc/1")
+ self.assertEqual(nodes[0]["name"], "Node 1")
+ self.assertEqual(nodes[0]["types"], ["Topic"])
+
+ def test_get_updated_nodes_with_timestamp(self):
+ mock_database = MagicMock()
+ mock_snapshot = MagicMock()
+ mock_database.snapshot.return_value.__enter__.return_value = mock_snapshot
+
+ class MockField:
+ def __init__(self, name):
+ self.name = name
+
+ class MockResults:
+ def __init__(self, rows, field_names):
+ self.rows = rows
+ self.fields = [MockField(name) for name in field_names]
+
+ def __iter__(self):
+ return iter(self.rows)
+
+ mock_snapshot.execute_sql.return_value = MockResults(
+ rows=[("dc/2", "Node 2", ["Topic"])],
+ field_names=["subject_id", "name", "types"]
+ )
+
+ test_timestamp = datetime(2026, 4, 25, 0, 0, 0)
+ nodes = list(get_updated_nodes(mock_database, test_timestamp, ["Topic"]))
+
+ # Verify Spanner call
+ mock_snapshot.execute_sql.assert_called_once()
+ args, kwargs = mock_snapshot.execute_sql.call_args
+ query = args[0]
+ self.assertIn("SELECT subject_id, name, types FROM Node", query)
+ self.assertIn("update_timestamp > @timestamp", query)
+ self.assertEqual(kwargs["params"], {"node_types": ["Topic"], "timestamp": test_timestamp})
+
+ self.assertEqual(len(nodes), 1)
+ self.assertEqual(nodes[0]["subject_id"], "dc/2")
+
+ def test_filter_and_convert_nodes(self):
+ nodes = [
+ {"subject_id": "dc/1", "name": "Node 1", "types": ["Topic"]},
+ {"subject_id": "dc/2", "name": None, "types": ["StatisticalVariable"]},
+ {"subject_id": "dc/3", "name": "Node 3", "types": ["Topic", "StatisticalVariable"]},
+ {"subject_id": "dc/4", "name": "", "types": ["StatisticalVariable"]}
+ ]
+
+ converted = list(filter_and_convert_nodes(nodes))
+ self.assertEqual(len(converted), 2)
+ self.assertEqual(converted[0], ("dc/1", "Node 1", ["Topic"]))
+ self.assertEqual(converted[1], ("dc/3", "Node 3", ["Topic", "StatisticalVariable"]))
+
+ @patch('embedding_utils._BATCH_SIZE', 2)
+ def test_generate_embeddings_partitioned(self):
+ mock_database = MagicMock()
+
+ nodes = [
+ ("dc/1", "Node 1", ["Topic"]),
+ ("dc/2", "Node 2", ["Topic"]),
+ ("dc/3", "Node 3", ["Topic"]),
+ ("dc/4", "Node 4", ["Topic"]),
+ ("dc/5", "Node 5", ["Topic"]),
+ ("dc/6", "Node 6", ["Topic"]),
+ ("dc/7", "Node 7", ["Topic"]),
+ ("dc/8", "Node 8", ["Topic"])
+ ]
+
+ transactions = []
+ def side_effect(func):
+ mock_transaction = MagicMock()
+ mock_transaction.execute_update.return_value = 2
+ transactions.append(mock_transaction)
+ return func(mock_transaction)
+
+ mock_database.run_in_transaction.side_effect = side_effect
+
+ affected_rows = generate_embeddings_partitioned(mock_database, nodes)
+ self.assertEqual(affected_rows, 8)
+ self.assertEqual(mock_database.run_in_transaction.call_count, 4)
+
+ # Verify execute_update calls
+ self.assertEqual(len(transactions), 4)
+ for i, tx in enumerate(transactions):
+ tx.execute_update.assert_called_once()
+ args, kwargs = tx.execute_update.call_args
+ self.assertIn("INSERT OR UPDATE INTO NodeEmbedding", args[0])
+
+ # Verify batch content
+ batch = kwargs["params"]["nodes"]
+ self.assertEqual(len(batch), 2)
+ self.assertEqual(batch[0][0], f"dc/{i*2 + 1}")
+ self.assertEqual(batch[1][0], f"dc/{i*2 + 2}")
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/pipeline/workflow/ingestion-helper/import_utils.py b/pipeline/workflow/ingestion-helper/import_utils.py
new file mode 100644
index 00000000..33f9d1fa
--- /dev/null
+++ b/pipeline/workflow/ingestion-helper/import_utils.py
@@ -0,0 +1,175 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Utility functions for the ingestion helper."""
+
+import logging
+import re
+from datetime import datetime, timezone
+from googleapiclient.discovery import build
+from googleapiclient.errors import HttpError
+from google.oauth2 import id_token
+from google.auth.transport import requests
+from google.auth import jwt
+
+
+def get_next_refresh(project_id: str, location: str, import_name: str) -> str:
+ """Fetches the next scheduled run time for the import job from Cloud Scheduler.
+
+ Args:
+ project_id: The GCP project ID.
+ location: The location of the Cloud Scheduler job.
+ import_name: The name of the import (used as the job name).
+
+ Returns:
+ The next scheduled run time as an ISO formatted string, or None if not found/error.
+ """
+ try:
+ scheduler = build('cloudscheduler', 'v1', cache_discovery=False)
+ job_id = import_name.split(':')[-1]
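+ # e.g. an import_name like "scripts/us_fed/treasury_constant_maturity_rates:USFed_ConstantMaturityRates"
+ # yields the scheduler job id "USFed_ConstantMaturityRates".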
+ job_name = f"projects/{project_id}/locations/{location}/jobs/{job_id}"
+ job = scheduler.projects().locations().jobs().get(
+ name=job_name).execute()
+ return job.get('scheduleTime')
+ except HttpError as e:
+ logging.warning(f"Could not fetch scheduler job {import_name}: {e}")
+ return None
+
+
+def get_caller_identity(request):
+ """Extracts the caller's email from the Authorization header (JWT).
+
+ Args:
+ request: The HTTP request object.
+
+ Returns:
+ The email of the caller, or an error string ('decode_error' or 'no_auth_header') if extraction fails.
+ """
+ auth_header = request.headers.get('Authorization')
+ if auth_header:
+ parts = auth_header.split()
+ if len(parts) == 2 and parts[0].lower() == 'bearer':
+ token = parts[1]
+ unverified_claims = {}
+ try:
+ unverified_claims = jwt.decode(token, verify=False)
+ id_info = id_token.verify_oauth2_token(token,
+ requests.Request())
+ return id_info.get('email', 'unknown_email')
+ except Exception as e:
+ if unverified_claims:
+ logging.warning(
+ f"Token verification failed; falling back to unverified claims: {e}")
+ return unverified_claims.get('email', 'unknown_email')
+ return 'decode_error'
+ else:
+ logging.warning(
+ f"Invalid Authorization header format. Parts: {len(parts)}")
+ else:
+ logging.warning("No Authorization header received.")
+ return 'no_auth_header'
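+ # Usage sketch (illustrative outcome; the actual value depends on the request's
+ # Authorization header): get_caller_identity(request) returns e.g. "user@example.com",
+ # or 'decode_error' / 'no_auth_header' when the token cannot be verified or is missing.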
+
+
+def get_import_params(request) -> dict:
+ """Extracts and calculates import parameters from the request JSON.
+
+ Args:
+ request: The HTTP request whose JSON body contains the import parameters.
+
+ Returns:
+ A dictionary with import params.
+ """
+ # Convert CamelCase or mixedCase to snake_case.
+ request_json = {
+ re.sub(r'(?<!^)(?=[A-Z])', '_', key).lower(): value
+ last_update_timestamp TIMESTAMP OPTIONS (allow_commit_timestamp=true),
+ name_tokenlist TOKENLIST AS (TOKENIZE_FULLTEXT(name)) HIDDEN,
+) PRIMARY KEY(subject_id);
+
+CREATE TABLE Edge (
+ subject_id STRING(1024) NOT NULL,
+ predicate STRING(1024) NOT NULL,
+ object_id STRING(1024) NOT NULL,
+ provenance STRING(1024) NOT NULL,
+) PRIMARY KEY(subject_id, predicate, object_id, provenance),
+INTERLEAVE IN Node;
+
+CREATE TABLE Observation (
+ observation_about STRING(1024) NOT NULL,
+ variable_measured STRING(1024) NOT NULL,
+ facet_id STRING(1024) NOT NULL,
+ observation_period STRING(1024),
+ measurement_method STRING(1024),
+ unit STRING(1024),
+ scaling_factor STRING(1024),
+ observations org.datacommons.Observations,
+ import_name STRING(1024),
+ provenance_url STRING(1024),
+ is_dc_aggregate BOOL,
+) PRIMARY KEY(observation_about, variable_measured, facet_id);
+
+CREATE TABLE ImportStatus (
+ ImportName STRING(MAX) NOT NULL,
+ LatestVersion STRING(MAX),
+ GraphPath STRING(MAX),
+ State STRING(1024) NOT NULL,
+ JobId STRING(1024),
+ WorkflowId STRING(1024),
+ ExecutionTime INT64,
+ DataVolume INT64,
+ DataImportTimestamp TIMESTAMP OPTIONS ( allow_commit_timestamp = TRUE ),
+ StatusUpdateTimestamp TIMESTAMP OPTIONS ( allow_commit_timestamp = TRUE ),
+ NextRefreshTimestamp TIMESTAMP,
+) PRIMARY KEY(ImportName);
+
+CREATE TABLE IngestionHistory (
+ CompletionTimestamp TIMESTAMP NOT NULL OPTIONS ( allow_commit_timestamp = TRUE ),
+ IngestionFailure BOOL NOT NULL,
+ WorkflowExecutionID STRING(1024) NOT NULL,
+ DataflowJobID STRING(1024),
+ IngestedImports ARRAY<STRING(MAX)>,
+ ExecutionTime INT64,
+ NodeCount INT64,
+ EdgeCount INT64,
+ ObservationCount INT64,
+) PRIMARY KEY(CompletionTimestamp DESC);
+
+CREATE TABLE ImportVersionHistory (
+ ImportName STRING(MAX) NOT NULL,
+ Version STRING(MAX) NOT NULL,
+ UpdateTimestamp TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp=true),
+ Comment STRING(MAX),
+) PRIMARY KEY (ImportName, UpdateTimestamp DESC);
+
+CREATE TABLE IngestionLock (
+ LockID STRING(1024) NOT NULL,
+ LockOwner STRING(1024),
+ AcquiredTimestamp TIMESTAMP OPTIONS ( allow_commit_timestamp = TRUE ),
+) PRIMARY KEY(LockID);
+
+CREATE PROPERTY GRAPH DCGraph
+ NODE TABLES(
+ Node
+ KEY(subject_id)
+ LABEL Node PROPERTIES(
+ bytes,
+ name,
+ subject_id,
+ types,
+ value)
+ )
+ EDGE TABLES(
+ Edge
+ KEY(subject_id, predicate, object_id, provenance)
+ SOURCE KEY(subject_id) REFERENCES Node(subject_id)
+ DESTINATION KEY(object_id) REFERENCES Node(subject_id)
+ LABEL Edge PROPERTIES(
+ object_id,
+ predicate,
+ provenance,
+ subject_id)
+ );
+
+CREATE TABLE Cache (
+ type STRING(1024) NOT NULL,
+ key STRING(1024) NOT NULL,
+ provenance STRING(1024) NOT NULL,
+ value JSON,
+) PRIMARY KEY(type, key, provenance);
+
+CREATE TABLE VariableMetadata (
+ variable_measured STRING(1024) NOT NULL,
+ import_name STRING(1024) NOT NULL,
+ facet_id STRING(1024) NOT NULL,
+ observation_period STRING(1024),
+ measurement_method STRING(1024),
+ unit STRING(1024),
+ scaling_factor STRING(1024),
+ is_dc_aggregate BOOL,
+ total_observations INT64,
+ observed_places INT64,
+ min_date STRING(1024),
+ max_date STRING(1024),
+ place_types ARRAY<STRING(MAX)>,
+) PRIMARY KEY(variable_measured, import_name);
+
+
diff --git a/pipeline/workflow/ingestion-helper/spanner_client.py b/pipeline/workflow/ingestion-helper/spanner_client.py
new file mode 100644
index 00000000..27b30a08
--- /dev/null
+++ b/pipeline/workflow/ingestion-helper/spanner_client.py
@@ -0,0 +1,557 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+from google.cloud import spanner
+from google.cloud.spanner_admin_database_v1 import DatabaseAdminClient
+from google.cloud.spanner_admin_database_v1.types import UpdateDatabaseDdlRequest
+from google.cloud.spanner_v1 import Transaction
+from google.cloud.spanner_v1.param_types import STRING, TIMESTAMP, Array, INT64
+from datetime import datetime, timezone
+from jinja2 import Template
+
+logging.getLogger().setLevel(logging.INFO)
+
+
+class SpannerClient:
+ """
+ Spanner client to handle tasks like acquiring/releasing lock
+ and getting/updating import statuses.
+ """
+ _LOCK_ID = "global_ingestion_lock"
+ _EMBEDDING_MODEL_PATH = "projects/{project}/locations/{location}/publishers/google/models/{model}"
+
+ def __init__(self,
+ project_id: str,
+ instance_id: str,
+ database_id: str,
+ graph_database_id: str = None,
+ location: str = None,
+ model_id: str = None):
+ """Initializes a Spanner client and connects to a specific database."""
+ spanner_client = spanner.Client(
+ project=project_id,
+ client_options={'quota_project_id': project_id},
+ disable_builtin_metrics=True)
+ instance = spanner_client.instance(instance_id)
+ database = instance.database(database_id)
+ logging.info(f"Successfully initialized database: {database.name}")
+ self.database = database
+ self.graph_database = database
+ if graph_database_id:
+ self.graph_database = instance.database(graph_database_id)
+ logging.info(
+ f"Successfully initialized graph database: {self.graph_database.name}"
+ )
+ self.project_id = project_id
+ self.location = location
+ self.model_id = model_id
+
+ def _get_embeddings_endpoint(self) -> str:
+ """Returns the parameterized embedding model endpoint."""
+ return self._EMBEDDING_MODEL_PATH.format(project=self.project_id,
+ location=self.location or
+ "us-central1",
+ model=self.model_id or
+ "text-embedding-005")
+
+ def acquire_lock(self, workflow_id: str, timeout: int) -> bool:
+ """Attempts to acquire the global ingestion lock.
+
+ Args:
+ workflow_id: The ID of the workflow attempting to acquire the lock.
+ timeout: The duration in seconds after which a lock is considered stale.
+
+ Returns:
+ True if the lock was acquired, False otherwise.
+ """
+ logging.info(f"Attempting to acquire lock for {workflow_id}")
+
+ def _acquire(transaction: Transaction) -> bool:
+ sql = "SELECT LockOwner, AcquiredTimestamp FROM IngestionLock WHERE LockID = @lockId"
+ params = {"lockId": self._LOCK_ID}
+ param_types = {"lockId": STRING}
+
+ row_found = False
+ results = transaction.execute_sql(sql, params, param_types)
+ for row in results:
+ row_found = True
+ current_owner, acquired_at = row[0], row[1]
+
+ lock_is_available = False
+ if not row_found:
+ lock_is_available = True
+ elif current_owner is None:
+ lock_is_available = True
+ else:
+ timeout_threshold = datetime.now(timezone.utc) - acquired_at
+ if timeout_threshold.total_seconds() > timeout:
+ logging.info(
+ f"Stale lock found, owned by {current_owner}. Acquiring."
+ )
+ lock_is_available = True
+
+ if lock_is_available:
+ if not row_found:
+ sql_statement = """
+ INSERT INTO IngestionLock (LockID, LockOwner, AcquiredTimestamp)
+ VALUES (@lockId, @workflowId, PENDING_COMMIT_TIMESTAMP())
+ """
+ log_msg = f"Lock successfully acquired by {workflow_id} (new row created)"
+ else:
+ sql_statement = """
+ UPDATE IngestionLock
+ SET LockOwner = @workflowId, AcquiredTimestamp = PENDING_COMMIT_TIMESTAMP()
+ WHERE LockID = @lockId
+ """
+ log_msg = f"Lock successfully acquired by {workflow_id} (existing row updated)"
+
+ transaction.execute_update(sql_statement,
+ params={
+ "workflowId": workflow_id,
+ "lockId": self._LOCK_ID
+ },
+ param_types={
+ "workflowId": STRING,
+ "lockId": STRING
+ })
+ logging.info(log_msg)
+ return True
+ else:
+ logging.info(f"Lock is currently held by {current_owner}")
+ return False
+
+ try:
+ return self.database.run_in_transaction(_acquire)
+ except Exception as e:
+ logging.error(f'Error acquiring lock for {workflow_id}: {e}')
+ raise
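+ # Stale-lock note: acquire_lock treats a row whose AcquiredTimestamp is older than
+ # `timeout` seconds as abandoned and takes it over, so a crashed workflow cannot
+ # block future ingestions indefinitely.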
+
+ def release_lock(self, workflow_id: str) -> bool:
+ """Releases the global lock.
+
+ Args:
+ workflow_id: The ID of the workflow attempting to release the lock.
+
+ Returns:
+ True if the lock was released, False otherwise.
+ """
+ logging.info(f"Attempting to release lock for {workflow_id}")
+
+ def _release(transaction: Transaction) -> bool:
+ sql = "SELECT LockOwner, AcquiredTimestamp FROM IngestionLock WHERE LockID = @lockId"
+ params = {"lockId": self._LOCK_ID}
+ param_types = {"lockId": STRING}
+
+ current_owner = None
+ results = transaction.execute_sql(sql, params, param_types)
+ for row in results:
+ current_owner = row[0]
+
+ if current_owner == workflow_id:
+ sql = """
+ UPDATE IngestionLock
+ SET LockOwner = NULL, AcquiredTimestamp = NULL
+ WHERE LockID = @lockId
+ """
+ transaction.execute_update(sql,
+ params={"lockId": self._LOCK_ID},
+ param_types={"lockId": STRING})
+ logging.info(f"Lock successfully released by {workflow_id}")
+ return True
+ else:
+ logging.info(f"Lock is currently held by {current_owner}")
+ return False
+
+ try:
+ return self.database.run_in_transaction(_release)
+ except Exception as e:
+ logging.error(f'Error releasing lock for {workflow_id}: {e}')
+ raise
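+ # Typical lock usage (a sketch with illustrative IDs and timeout):
+ #   client = SpannerClient("my-project", "my-instance", "my-db")
+ #   if client.acquire_lock("workflow-123", timeout=82800):
+ #       try:
+ #           ...  # run ingestion
+ #       finally:
+ #           client.release_lock("workflow-123")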
+
+ def get_import_info(self, import_list: list) -> list:
+ """Get the details of imports to ingest.
+
+ If import_list is empty, return info for ready imports (STAGING).
+ If import_list is not empty, return info for the imports in the list that are in 'STAGING' status.
+
+ Args:
+ import_list: A list of import names to fetch details for.
+
+ Returns:
+ A list of dictionaries, where each dictionary contains 'importName', 'latestVersion', and 'graphPath'.
+ """
+ pending_imports = []
+ logging.info(f"Fetching imports from import list {import_list}.")
+
+ params = {}
+ param_types = {}
+ if import_list:
+ sql = "SELECT ImportName, LatestVersion, GraphPath FROM ImportStatus WHERE State = 'STAGING' AND ImportName IN UNNEST(@importNames)"
+ params = {"importNames": import_list}
+ param_types = {"importNames": Array(STRING)}
+ else:
+ sql = "SELECT ImportName, LatestVersion, GraphPath FROM ImportStatus WHERE State = 'STAGING'"
+
+ # Use a read-only snapshot for this query
+ try:
+ with self.database.snapshot() as snapshot:
+ results = snapshot.execute_sql(sql,
+ params=params,
+ param_types=param_types)
+ for row in results:
+ import_json = {}
+ import_json['importName'] = row[0]
+ import_json['latestVersion'] = os.path.basename(row[1])
+ import_json[
+ 'graphPath'] = f"{row[1].rstrip('/')}/{row[2].lstrip('/')}"
+ pending_imports.append(import_json)
+
+ logging.info(f"Found {len(pending_imports)} import jobs.")
+ return pending_imports
+ except Exception as e:
+ logging.error(f'Error getting import list: {e}')
+ raise
+
+ def update_ingestion_status(self, import_names: list, workflow_id: str,
+ status: str):
+ """Updates the ImportStatus table.
+
+ Args:
+ import_names: List of import names.
+ workflow_id: The ID of the workflow.
+ status: The status of the ingestion.
+ """
+ if not import_names:
+ return
+
+ logging.info(f"Updated ingestion status for {import_names}")
+
+ def _update(transaction: Transaction):
+ update_sql = "UPDATE ImportStatus SET State = @importStatus, WorkflowId = @workflowId, StatusUpdateTimestamp = PENDING_COMMIT_TIMESTAMP() WHERE ImportName IN UNNEST(@importNames)"
+ transaction.execute_update(update_sql,
+ params={
+ "importNames": import_names,
+ "workflowId": workflow_id,
+ "importStatus": status
+ },
+ param_types={
+ "importNames": Array(STRING),
+ "workflowId": STRING,
+ "importStatus": STRING
+ })
+
+ try:
+ self.database.run_in_transaction(_update)
+ logging.info(f"Marked {len(import_names)} import jobs as {status}.")
+ except Exception as e:
+ logging.error(f'Error updating ImportStatus table: {e}')
+ raise
+
+ def update_ingestion_history(self, workflow_id: str, job_id: str,
+ ingested_imports: list, metrics: dict):
+ """Updates the IngestionHistory table.
+
+ Args:
+ workflow_id: The ID of the workflow.
+ job_id: The Dataflow job ID.
+ ingested_imports: List of ingested import names.
+ metrics: A dictionary containing metrics about the ingestion.
+ """
+
+ logging.info(
+ f"Updating IngestionHistory table for workflow {workflow_id}")
+
+ def _insert(transaction: Transaction):
+ columns = [
+ "CompletionTimestamp", "IngestionFailure",
+ "WorkflowExecutionID", "DataflowJobId", "IngestedImports",
+ "ExecutionTime", "NodeCount", "EdgeCount", "ObservationCount"
+ ]
+ values = [[
+ spanner.COMMIT_TIMESTAMP,
+ self.check_failed_imports(), workflow_id, job_id,
+ ingested_imports, metrics['execution_time'],
+ metrics['node_count'], metrics['edge_count'],
+ metrics['obs_count']
+ ]]
+ transaction.insert_or_update(table="IngestionHistory",
+ columns=columns,
+ values=values)
+
+ try:
+ self.database.run_in_transaction(_insert)
+ # TODO: remove dual writes after switching to the prod setup.
+ if self.graph_database and self.graph_database.name != self.database.name:
+ self.graph_database.run_in_transaction(_insert)
+ logging.info(
+ f"Updated IngestionHistory table for workflow {workflow_id}")
+ except Exception as e:
+ logging.error(f'Error updating IngestionHistory table: {e}')
+ raise
+
+ def update_import_version_history(self, import_list_json: list,
+ workflow_id: str):
+ """Updates the ImportVersionHistory table.
+
+ Args:
+ import_list_json: A list of dictionaries containing import details.
+ workflow_id: The ID of the workflow.
+ """
+ if not import_list_json:
+ return
+
+ logging.info(
+ f"Updating ImportVersionHistory table for workflow {workflow_id}")
+
+ def _insert(transaction: Transaction):
+ version_history_columns = [
+ "ImportName", "Version", "UpdateTimestamp", "Comment"
+ ]
+ version_history_values = []
+ for import_json in import_list_json:
+ version_history_values.append([
+ import_json['importName'], import_json['latestVersion'],
+ spanner.COMMIT_TIMESTAMP,
+ "ingestion-workflow:" + workflow_id
+ ])
+
+ if version_history_values:
+ transaction.insert(table="ImportVersionHistory",
+ columns=version_history_columns,
+ values=version_history_values)
+
+ try:
+ self.database.run_in_transaction(_insert)
+ logging.info(
+ f"Updated ImportVersionHistory table for workflow {workflow_id}"
+ )
+ except Exception as e:
+ logging.error(f'Error updating ImportVersionHistory table: {e}')
+ raise
+
+ def check_failed_imports(self) -> bool:
+ """Checks if there are any failed imports."""
+ try:
+ with self.database.snapshot() as snapshot:
+ results = snapshot.execute_sql(
+ "SELECT 1 FROM ImportStatus WHERE State = 'PENDING' LIMIT 1"
+ )
+ return any(results)
+ except Exception as e:
+ logging.error(f'Error checking for pending imports: {e}')
+ return True
+
+ def update_import_status(self, params: dict):
+ """Updates the status for the specified import job.
+
+ Args:
+ params: A dictionary containing import parameters.
+ """
+ import_name = params['import_name']
+ job_id = params['job_id']
+ execution_time = params['execution_time']
+ data_volume = params['data_volume']
+ status = params['status']
+ latest_version = params['latest_version']
+ next_refresh = datetime.fromisoformat(params['next_refresh'])
+ graph_path = params['graph_path']
+ logging.info(f"Updating import status in spanner {params}")
+
+ def _record(transaction: Transaction):
+ columns = [
+ "ImportName", "State", "JobId", "ExecutionTime", "DataVolume",
+ "NextRefreshTimestamp", "LatestVersion", "GraphPath",
+ "StatusUpdateTimestamp"
+ ]
+
+ row_values = [
+ import_name, status, job_id, execution_time, data_volume,
+ next_refresh, latest_version, graph_path,
+ spanner.COMMIT_TIMESTAMP
+ ]
+
+ if status == 'STAGING':
+ columns.append("DataImportTimestamp")
+ row_values.append(spanner.COMMIT_TIMESTAMP)
+
+ transaction.insert_or_update(table="ImportStatus",
+ columns=columns,
+ values=[row_values])
+
+ logging.info(f"Marked {import_name} as {status}.")
+
+ try:
+ self.database.run_in_transaction(_record)
+ except Exception as e:
+ logging.error(
+ f'Error updating import status for {import_name}: {e}')
+ raise
+
+ def update_version_history(self, import_name: str, version: str,
+ comment: str):
+ """Updates the version history table.
+
+ Args:
+ import_name: The name of the import.
+ version: The version string.
+ comment: The comment for the update.
+ """
+ import_name = import_name.split(':')[-1]
+ logging.info(f"Updating version history for {import_name} to {version}")
+
+ def _record(transaction: Transaction):
+ columns = ["ImportName", "Version", "UpdateTimestamp", "Comment"]
+ values = [[import_name, version, spanner.COMMIT_TIMESTAMP, comment]]
+ transaction.insert(table="ImportVersionHistory",
+ columns=columns,
+ values=values)
+ logging.info(f"Added version history entry for {import_name}")
+
+ try:
+ self.database.run_in_transaction(_record)
+ except Exception as e:
+ logging.error(
+ f'Error updating version history for {import_name}: {e}')
+ raise
+
+ def initialize_database(self, enable_embeddings=False):
+ """Initializes the database by creating all required tables and proto bundles."""
+ logging.info("Initializing database...")
+
+ query = """
+ SELECT 'table' as type, table_name as name FROM information_schema.tables WHERE table_schema = ''
+ UNION ALL
+ SELECT 'index' as type, index_name as name FROM information_schema.indexes WHERE table_schema = '' AND table_name = 'NodeEmbedding'
+ UNION ALL
+ SELECT 'model' as type, model_name as name FROM information_schema.models WHERE model_schema = ''
+ """
+
+ existing_tables = []
+ existing_indexes = []
+ existing_models = []
+
+ with self.database.snapshot() as snapshot:
+ results = snapshot.execute_sql(query)
+ for row in results:
+ if len(row) < 2:
+ logging.warning(f"Invalid row from query: {row}")
+ continue
+ obj_type = row[0]
+ obj_name = row[1]
+ if obj_type == 'table':
+ existing_tables.append(obj_name)
+ elif obj_type == 'index':
+ existing_indexes.append(obj_name)
+ elif obj_type == 'model':
+ existing_models.append(obj_name)
+
+ logging.info(f"Existing tables: {existing_tables}")
+ logging.info(f"Existing indexes: {existing_indexes}")
+ logging.info(f"Existing models: {existing_models}")
+
+ required_tables = [
+ "Node", "Edge", "Observation", "ImportStatus", "IngestionHistory",
+ "ImportVersionHistory", "IngestionLock", "Cache", "VariableMetadata"
+ ]
+ required_indexes = []
+ required_models = []
+
+ if enable_embeddings:
+ required_tables.append("NodeEmbedding")
+ required_indexes.append("NodeEmbeddingIndex")
+ required_models.append("NodeEmbeddingModel")
+
+ missing_tables = [
+ t for t in required_tables if t not in existing_tables
+ ]
+ missing_indexes = [
+ i for i in required_indexes if i not in existing_indexes
+ ]
+ missing_models = [
+ m for m in required_models if m not in existing_models
+ ]
+
+ total_required = len(required_tables) + len(required_indexes) + len(
+ required_models)
+ total_missing = len(missing_tables) + len(missing_indexes) + len(
+ missing_models)
+
+ if total_missing == 0:
+ logging.info("Database is properly initialized.")
+ return
+
+ if total_missing < total_required:
+ raise RuntimeError(
+ f"Database inconsistent state. Missing tables: {missing_tables}, missing indexes: {missing_indexes}, missing models: {missing_models}. Please clean up manually."
+ )
+
+ logging.info("Creating all tables and proto bundles...")
+
+ schema_path = os.path.join(os.path.dirname(__file__), 'schema.sql')
+ logging.info(f"Reading schema from {schema_path}")
+ try:
+ with open(schema_path, 'r') as f:
+ schema_content = f.read()
+
+ ddl_statements = [
+ s.strip() for s in schema_content.split(';') if s.strip()
+ ]
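+ # e.g. "CREATE TABLE Node (...);\nCREATE TABLE Edge (...);" becomes
+ # ["CREATE TABLE Node (...)", "CREATE TABLE Edge (...)"]; the DDL API expects
+ # statements without trailing semicolons.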
+
+ if enable_embeddings:
+ embeddings_endpoint = self._get_embeddings_endpoint()
+ embedding_schema_path = os.path.join(os.path.dirname(__file__),
+ 'embedding_schema.sql')
+ logging.info(
+ f"Reading embedding schema from {embedding_schema_path}")
+ with open(embedding_schema_path, 'r') as f:
+ embedding_schema_content = f.read()
+ embedding_schema_content = Template(
+ embedding_schema_content).render(
+ embeddings_endpoint=embeddings_endpoint)
+ embedding_ddl_statements = [
+ s.strip()
+ for s in embedding_schema_content.split(';')
+ if s.strip()
+ ]
+ ddl_statements.extend(embedding_ddl_statements)
+ except Exception as e:
+ logging.error(f"Failed to read schema file: {e}")
+ raise
+
+ proto_path = os.path.join(os.path.dirname(__file__), 'storage.pb')
+ logging.info(f"Reading proto descriptors from {proto_path}")
+ try:
+ with open(proto_path, 'rb') as f:
+ proto_descriptors_bytes = f.read()
+ except Exception as e:
+ logging.error(f"Failed to read proto descriptors file: {e}")
+ raise
+
+ database_path = self.database.name
+ logging.info(f"Updating DDL for {database_path} with protos")
+
+ try:
+ admin_client = DatabaseAdminClient()
+ request = UpdateDatabaseDdlRequest(
+ database=database_path,
+ statements=ddl_statements,
+ proto_descriptors=proto_descriptors_bytes)
+ operation = admin_client.update_database_ddl(request=request)
+ operation.result()
+ logging.info("Database initialized successfully with protos.")
+ except Exception as e:
+ logging.error(f"Failed to update DDL with protos: {e}")
+ raise
diff --git a/pipeline/workflow/ingestion-helper/spanner_client_test.py b/pipeline/workflow/ingestion-helper/spanner_client_test.py
new file mode 100644
index 00000000..3a961db4
--- /dev/null
+++ b/pipeline/workflow/ingestion-helper/spanner_client_test.py
@@ -0,0 +1,224 @@
+# Copyright 2026 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+from unittest.mock import MagicMock, patch
+import sys
+import os
+
+# Add the current directory to path so we can import spanner_client
+sys.path.append(os.path.dirname(__file__))
+from spanner_client import SpannerClient
+
+class TestSpannerClient(unittest.TestCase):
+
+ @patch('google.cloud.spanner.Client')
+ def test_initialize_database_all_exist(self, mock_spanner_client):
+ # Setup mock
+ mock_instance = MagicMock()
+ mock_db = MagicMock()
+ mock_spanner_client.return_value.instance.return_value = mock_instance
+ mock_instance.database.return_value = mock_db
+
+ # Mock snapshot results (all tables exist)
+ mock_snapshot = MagicMock()
+ mock_db.snapshot.return_value.__enter__.return_value = mock_snapshot
+ mock_snapshot.execute_sql.return_value = [
+ ["table", "Node"], ["table", "Edge"], ["table", "Observation"],
+ ["table", "NodeEmbedding"], ["table", "ImportStatus"],
+ ["table", "IngestionHistory"], ["table", "ImportVersionHistory"],
+ ["table", "IngestionLock"],
+ ["index", "NodeEmbeddingIndex"],
+ ["model", "NodeEmbeddingModel"]
+ ]
+
+ client = SpannerClient("project", "instance", "database")
+
+ # Run method
+ client.initialize_database()
+
+ # Verify update_ddl was NOT called
+ mock_db.update_ddl.assert_not_called()
+
+ @patch('spanner_client.DatabaseAdminClient')
+ @patch('google.cloud.spanner.Client')
+ def test_initialize_database_none_exist(self, mock_spanner_client,
+ mock_admin_client):
+ # Setup mock
+ mock_instance = MagicMock()
+ mock_db = MagicMock()
+ mock_db.name = "projects/test-project/instances/test-instance/databases/test-db"
+ mock_spanner_client.return_value.instance.return_value = mock_instance
+ mock_instance.database.return_value = mock_db
+
+ # Mock DatabaseAdminClient
+ mock_admin_instance = MagicMock()
+ mock_admin_client.return_value = mock_admin_instance
+ mock_operation = MagicMock()
+ mock_admin_instance.update_database_ddl.return_value = mock_operation
+
+ # Mock snapshot results (no tables exist)
+ mock_snapshot = MagicMock()
+ mock_db.snapshot.return_value.__enter__.return_value = mock_snapshot
+ mock_snapshot.execute_sql.return_value = []
+
+ client = SpannerClient("project", "instance", "database")
+
+ def open_side_effect(file_path, mode='r', *args, **kwargs):
+ m = MagicMock()
+ if 'storage.pb' in str(file_path):
+ m.__enter__.return_value.read.return_value = b'dummy proto data'
+ else:
+ m.__enter__.return_value.read.return_value = 'CREATE TABLE Node;'
+ return m
+
+ # Run method with patched open
+ with patch('builtins.open', side_effect=open_side_effect):
+ client.initialize_database()
+
+ # Verify update_database_ddl WAS called
+ mock_admin_instance.update_database_ddl.assert_called_once()
+ mock_operation.result.assert_called_once()
+
+ # Verify the parsed DDL statements
+ args, kwargs = mock_admin_instance.update_database_ddl.call_args
+ request = kwargs.get('request') if kwargs else args[0]
+ statements = request.statements
+ self.assertEqual(len(statements), 1)
+ self.assertEqual(statements[0], "CREATE TABLE Node")
+
+ @patch('spanner_client.DatabaseAdminClient')
+ @patch('google.cloud.spanner.Client')
+ def test_initialize_database_with_embeddings(self, mock_spanner_client, mock_admin_client):
+ # Setup mock
+ mock_instance = MagicMock()
+ mock_db = MagicMock()
+ mock_db.name = "projects/test-project/instances/test-instance/databases/test-db"
+ mock_spanner_client.return_value.instance.return_value = mock_instance
+ mock_instance.database.return_value = mock_db
+
+ # Mock DatabaseAdminClient
+ mock_admin_instance = MagicMock()
+ mock_admin_client.return_value = mock_admin_instance
+ mock_operation = MagicMock()
+ mock_admin_instance.update_database_ddl.return_value = mock_operation
+
+ # Mock snapshot results (no tables exist)
+ mock_snapshot = MagicMock()
+ mock_db.snapshot.return_value.__enter__.return_value = mock_snapshot
+ mock_snapshot.execute_sql.return_value = []
+
+ client = SpannerClient("project", "instance", "database")
+
+ def open_side_effect(file_path, mode='r', *args, **kwargs):
+ m = MagicMock()
+ if 'storage.pb' in str(file_path):
+ m.__enter__.return_value.read.return_value = b'dummy proto data'
+ elif 'embedding_schema.sql' in str(file_path):
+ m.__enter__.return_value.read.return_value = 'CREATE TABLE NodeEmbedding; CREATE MODEL M REMOTE OPTIONS (endpoint = \'{{ embeddings_endpoint }}\');'
+ else:
+ m.__enter__.return_value.read.return_value = 'CREATE TABLE Node;'
+ return m
+
+ # Run method with patched open and parameter
+ with patch('builtins.open', side_effect=open_side_effect):
+ client.initialize_database(enable_embeddings=True)
+
+ # Verify update_database_ddl WAS called
+ mock_admin_instance.update_database_ddl.assert_called_once()
+
+ # Verify that both schemas were loaded
+ args, kwargs = mock_admin_instance.update_database_ddl.call_args
+ request = kwargs.get('request') if kwargs else args[0]
+ statements = request.statements
+ self.assertEqual(len(statements), 3)
+ self.assertEqual(statements[0], "CREATE TABLE Node")
+ self.assertEqual(statements[1], "CREATE TABLE NodeEmbedding")
+ self.assertIn("projects/project/locations", statements[2])
+
+ @patch('google.cloud.spanner.Client')
+ def test_initialize_database_inconsistent_state(self, mock_spanner_client):
+ # Setup mock
+ mock_instance = MagicMock()
+ mock_db = MagicMock()
+ mock_spanner_client.return_value.instance.return_value = mock_instance
+ mock_instance.database.return_value = mock_db
+
+ # Mock snapshot results (some tables exist)
+ mock_snapshot = MagicMock()
+ mock_db.snapshot.return_value.__enter__.return_value = mock_snapshot
+ mock_snapshot.execute_sql.return_value = [["table", "Node"]]
+
+ client = SpannerClient("project", "instance", "database")
+
+ # Run method and expect exception
+ with self.assertRaises(RuntimeError):
+ client.initialize_database()
+
+ @patch('google.cloud.spanner.Client')
+ def test_acquire_lock_new_row(self, mock_spanner_client):
+ # Setup mock
+ mock_instance = MagicMock()
+ mock_db = MagicMock()
+ mock_spanner_client.return_value.instance.return_value = mock_instance
+ mock_instance.database.return_value = mock_db
+
+ mock_transaction = MagicMock()
+ def run_in_transaction_side_effect(callback, *args, **kwargs):
+ return callback(mock_transaction, *args, **kwargs)
+ mock_db.run_in_transaction.side_effect = run_in_transaction_side_effect
+
+ # Mock execute_sql to return empty results (no row found)
+ mock_transaction.execute_sql.return_value = []
+
+ client = SpannerClient("project", "instance", "database")
+
+ # Run method
+ result = client.acquire_lock("workflow-123", 3600)
+
+ # Verify
+ self.assertTrue(result)
+ mock_transaction.execute_update.assert_called_once()
+ args, _ = mock_transaction.execute_update.call_args
+ self.assertIn("INSERT INTO IngestionLock", args[0])
+
+ @patch('google.cloud.spanner.Client')
+ def test_acquire_lock_existing_row(self, mock_spanner_client):
+ # Setup mock
+ mock_instance = MagicMock()
+ mock_db = MagicMock()
+ mock_spanner_client.return_value.instance.return_value = mock_instance
+ mock_instance.database.return_value = mock_db
+
+ mock_transaction = MagicMock()
+ def run_in_transaction_side_effect(callback, *args, **kwargs):
+ return callback(mock_transaction, *args, **kwargs)
+ mock_db.run_in_transaction.side_effect = run_in_transaction_side_effect
+
+ # Mock execute_sql to return existing lock (owner is None)
+ mock_transaction.execute_sql.return_value = [[None, None]]
+
+ client = SpannerClient("project", "instance", "database")
+
+ # Run method
+ result = client.acquire_lock("workflow-123", 3600)
+
+ # Verify
+ self.assertTrue(result)
+ mock_transaction.execute_update.assert_called_once()
+ args, _ = mock_transaction.execute_update.call_args
+ self.assertIn("UPDATE IngestionLock", args[0])
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/pipeline/workflow/ingestion-helper/storage_client.py b/pipeline/workflow/ingestion-helper/storage_client.py
new file mode 100644
index 00000000..12f57a50
--- /dev/null
+++ b/pipeline/workflow/ingestion-helper/storage_client.py
@@ -0,0 +1,156 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Storage client for the ingestion helper."""
+
+import logging
+from google.cloud import storage
+from google.cloud import exceptions
+import json
+import os
+
+logging.getLogger().setLevel(logging.INFO)
+
+_STAGING_VERSION_FILE = 'staging_version.txt'
+_LATEST_VERSION_FILE = 'latest_version.txt'
+_IMPORT_METADATA_MCF = 'import_metadata_mcf.mcf'
+_IMPORT_SUMMARY_JSON = 'import_summary.json'
+
+
+class StorageClient:
+
+ def __init__(self, bucket_name: str):
+ """Initializes a GCS client."""
+ self.storage = storage.Client()
+ self.bucket = self.storage.bucket(bucket_name)
+
+ def get_import_summary(self, import_name: str, version: str) -> dict:
+ """Retrieves the import summary from GCS.
+
+ Args:
+ import_name: The name of the import.
+ version: The version of the import.
+
+ Returns:
+ A dictionary containing the import summary, or an empty dict if not found.
+ """
+ output_dir = import_name.replace(':', '/')
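+ # e.g. "scripts/us_fed/treasury_constant_maturity_rates:USFed_ConstantMaturityRates"
+ #   -> "scripts/us_fed/treasury_constant_maturity_rates/USFed_ConstantMaturityRates"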
+ summary_file = os.path.join(output_dir, version, _IMPORT_SUMMARY_JSON)
+ logging.info(f'Reading import summary from {summary_file}')
+ try:
+ blob = self.bucket.blob(summary_file)
+ json_data_string = blob.download_as_text()
+ data = json.loads(json_data_string)
+ logging.info(f"Successfully read {summary_file}")
+ return data
+ except (exceptions.NotFound, json.JSONDecodeError) as e:
+ logging.error(
+ f'Error reading import summary file {summary_file}: {e}')
+ return {}
+
+ def update_import_summary(self, import_summary: dict):
+ """Updates the import summary in GCS.
+
+ Args:
+ import_summary: A dictionary containing the summary of the import.
+ """
+ latest_version = import_summary.get('latest_version')
+ path = latest_version.removeprefix('gs://').split('/', 1)
+ summary_file = os.path.join(path[1], _IMPORT_SUMMARY_JSON)
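+ # e.g. a latest_version of "gs://my-bucket/scripts/foo/2025_01_01" (illustrative path)
+ # yields summary_file "scripts/foo/2025_01_01/import_summary.json" in this client's bucket.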
+ logging.info(
+ f'Updating import summary at {summary_file} {import_summary}')
+ blob = self.bucket.blob(summary_file)
+ blob.upload_from_string(json.dumps(import_summary))
+ logging.info(f'Updated import summary at {summary_file}')
+
+ def get_staging_version(self, import_name: str) -> str:
+ """Retrieves the latest version from the staging directory.
+
+ Args:
+ import_name: The name of the import.
+
+ Returns:
+ The version string, or an empty string if not found.
+ """
+ output_dir = import_name.replace(':', '/')
+ version_file = os.path.join(output_dir, _STAGING_VERSION_FILE)
+ logging.info(f'Reading version file {version_file}')
+ try:
+ blob = self.bucket.blob(version_file)
+ return blob.download_as_text()
+ except exceptions.NotFound:
+ logging.error(f"Version file {version_file} not found")
+ return ''
+
+ def update_version_file(self,
+ import_name: str,
+ version: str,
+ is_staging: bool = False):
+ """Updates the version file (staging or latest) in GCS.
+
+ Args:
+ import_name: The name of the import.
+ version: The new version string.
+ is_staging: Whether to update the staging version file or the latest version file.
+ """
+ file_name = _STAGING_VERSION_FILE if is_staging else _LATEST_VERSION_FILE
+ file_type = "staging" if is_staging else "latest"
+ logging.info(
+ f'Updating {file_type} version file for import {import_name} to {version}'
+ )
+ output_dir = import_name.replace(':', '/')
+ version_file = self.bucket.blob(os.path.join(output_dir, file_name))
+ version_file.upload_from_string(version)
+ logging.info(
+ f'Updated {file_type} version file {version_file.name} to {version}'
+ )
+
+ def update_provenance_file(self, import_name: str, version: str):
+ """Updates the provenance file for the import.
+
+ Args:
+ import_name: The name of the import.
+ version: The version of the import.
+ """
+ logging.info(
+ f'Updating provenance file for import {import_name} to add {version}'
+ )
+ output_dir = import_name.replace(':', '/')
+ metadata_blob = self.bucket.blob(
+ os.path.join(output_dir, version, 'provenance', 'genmcf',
+ _IMPORT_METADATA_MCF))
+ if metadata_blob.exists():
+ self.bucket.copy_blob(
+ metadata_blob, self.bucket,
+ os.path.join(output_dir, 'import_metadata_mcf.mcf'))
+ else:
+ logging.warning(
+ f'Generating default metadata for import {import_name}')
+ base_name = import_name.split(':')[-1]
+ default_provenance = f"Node: dcid:dc/base/{base_name}\ntypeOf: dcid:Provenance\n"
+ new_blob = self.bucket.blob(
+ os.path.join(output_dir, version, 'provenance', 'genmcf',
+ 'import_metadata_mcf.mcf'))
+ new_blob.upload_from_string(default_provenance)
+
+ provenance_file = import_name.split(':')[-1] + '.mcf'
+ provenance_blob = self.bucket.blob(
+ os.path.join('provenance', provenance_file))
+ if provenance_blob.exists():
+ self.bucket.copy_blob(
+ provenance_blob, self.bucket,
+ os.path.join(output_dir, version, 'provenance', 'genmcf',
+ provenance_file))
+ logging.info(
+ f'Updated provenance file for import {import_name} to add {version}'
+ )
diff --git a/pipeline/workflow/spanner-ingestion-workflow.yaml b/pipeline/workflow/spanner-ingestion-workflow.yaml
new file mode 100644
index 00000000..01f2976d
--- /dev/null
+++ b/pipeline/workflow/spanner-ingestion-workflow.yaml
@@ -0,0 +1,183 @@
+main:
+ params: [args]
+ steps:
+ - init:
+ assign:
+ - lock_timeout: 82800 # 23 hours
+ - wait_period: 300 # seconds
+ - project_id: '${sys.get_env("PROJECT_ID")}'
+ - dataflow_job_name: ${"ingestion-job-" + string(int(sys.now()))}
+ - dataflow_gcs_path: ${default(map.get(args, "templateGcsPath"), "gs://datcom-templates/templates/flex/ingestion.json")}
+ - location: '${sys.get_env("LOCATION")}'
+ - spanner_project_id: '${sys.get_env("SPANNER_PROJECT_ID")}'
+ - spanner_instance_id: '${sys.get_env("SPANNER_INSTANCE_ID")}'
+ - spanner_database_id: '${sys.get_env("SPANNER_DATABASE_ID")}'
+ - helper_url: ${"https://ingestion-helper-service-" + sys.get_env("PROJECT_NUMBER") + "." + location + ".run.app"}
+ - import_list: ${default(map.get(args, "importList"), [])}
+ - execution_error: null
+ - acquire_ingestion_lock:
+ try:
+ call: http.post
+ args:
+ url: ${helper_url}
+ auth:
+ type: OIDC
+ body:
+ actionType: acquire_ingestion_lock
+ workflowId: '${sys.get_env("GOOGLE_CLOUD_WORKFLOW_EXECUTION_ID")}'
+ timeout: ${lock_timeout}
+ result: lock_status
+ retry:
+ predicate: ${http.default_retry_predicate}
+ max_retries: 20
+ backoff:
+ initial_delay: 300
+ max_delay: 600
+ multiplier: 2
+ - process_ingestion:
+ try:
+ steps:
+ - get_import_info:
+ call: http.post
+ args:
+ url: ${helper_url}
+ auth:
+ type: OIDC
+ body:
+ actionType: get_import_info
+ importList: ${import_list}
+ result: import_info
+ - run_ingestion_job:
+ call: run_dataflow_job
+ args:
+ import_list: '${json.encode_to_string(import_info.body)}'
+ project_id: ${project_id}
+ job_name: ${dataflow_job_name}
+ template_gcs_path: ${dataflow_gcs_path}
+ location: ${location}
+ spanner_project_id: ${spanner_project_id}
+ spanner_instance_id: ${spanner_instance_id}
+ spanner_database_id: ${spanner_database_id}
+ wait_period: ${wait_period}
+ helper_url: ${helper_url}
+ workflow_id: '${sys.get_env("GOOGLE_CLOUD_WORKFLOW_EXECUTION_ID")}'
+ result: dataflow_job_id
+ - run_aggregation:
+ call: http.post
+ args:
+ url: ${helper_url}
+ auth:
+ type: OIDC
+ body:
+ actionType: run_aggregation
+ importList: ${import_info.body}
+ - update_ingestion_status:
+ call: http.post
+ args:
+ url: ${helper_url}
+ auth:
+ type: OIDC
+ body:
+ actionType: update_ingestion_status
+ workflowId: '${sys.get_env("GOOGLE_CLOUD_WORKFLOW_EXECUTION_ID")}'
+ jobId: '${dataflow_job_id}'
+ importList: '${import_info.body}'
+ status: 'SUCCESS'
+ result: function_response
+ except:
+ as: e
+ steps:
+ - capture_error:
+ assign:
+ - execution_error: ${e}
+ - release_ingestion_lock:
+ call: http.post
+ args:
+ url: ${helper_url}
+ auth:
+ type: OIDC
+ body:
+ actionType: release_ingestion_lock
+ workflowId: '${sys.get_env("GOOGLE_CLOUD_WORKFLOW_EXECUTION_ID")}'
+ result: function_response
+ - fail_workflow:
+ switch:
+ - condition: ${execution_error != null}
+ raise: ${execution_error}
+ - return_import_info:
+ return: '${import_info.body}'
+
+# This sub-workflow launches a Dataflow job and waits for it to complete.
+run_dataflow_job:
+ params: [import_list, project_id, job_name, template_gcs_path, location, spanner_project_id, spanner_instance_id, spanner_database_id, wait_period, helper_url, workflow_id]
+ steps:
+ - init:
+ assign:
+ - jobName: ${job_name}
+ - machineType: 'n2-highmem-8'
+ - numWorkers: 3
+ - log_imports:
+ call: sys.log
+ args:
+ text: '${"Dataflow job: " + jobName + " Import list: " + import_list}'
+ severity: INFO
+ - check_if_empty:
+ switch:
+ - condition: ${import_list == "[]"}
+ return: ''
+ - launch_dataflow_job:
+ call: googleapis.dataflow.v1b3.projects.locations.flexTemplates.launch
+ args:
+ projectId: '${project_id}'
+ location: '${location}'
+ body:
+ launchParameter:
+ containerSpecGcsPath: '${template_gcs_path}'
+ jobName: '${jobName}'
+ parameters:
+ importList: '${import_list}'
+ projectId: '${spanner_project_id}'
+ spannerInstanceId: '${spanner_instance_id}'
+ spannerDatabaseId: '${spanner_database_id}'
+ environment:
+ numWorkers: ${numWorkers}
+ machineType: ${machineType}
+ result: launch_result
+ - wait_for_job_completion:
+ call: sys.sleep
+ args:
+ seconds: ${wait_period}
+ next: check_job_status
+ - check_job_status:
+ call: googleapis.dataflow.v1b3.projects.locations.jobs.get
+ args:
+ projectId: '${project_id}'
+ location: '${location}'
+ jobId: '${launch_result.job.id}'
+ view: 'JOB_VIEW_SUMMARY'
+ result: job_status
+ next: check_if_done
+ - check_if_done:
+ switch:
+ - condition: ${job_status.currentState == "JOB_STATE_DONE"}
+ return: ${launch_result.job.id}
+ - condition: ${job_status.currentState == "JOB_STATE_FAILED" or job_status.currentState == "JOB_STATE_CANCELLED"}
+ next: record_failed_imports
+ next: wait_for_job_completion
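+ # Any other state (e.g. JOB_STATE_RUNNING) falls through to wait_for_job_completion,
+ # so the Dataflow job is polled every wait_period seconds until it reaches a terminal state.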
+ - record_failed_imports:
+ call: http.post
+ args:
+ url: ${helper_url}
+ auth:
+ type: OIDC
+ body:
+ actionType: update_ingestion_status
+ workflowId: '${workflow_id}'
+ jobId: '${launch_result.job.id}'
+ importList: '${json.decode(import_list)}'
+ status: 'PENDING'
+ result: retry_response
+ - fail_workflow:
+ raise:
+ message: '${jobName + " dataflow job failed with status: " + job_status.currentState}'
+ code: 500
\ No newline at end of file
diff --git a/pipeline/workflow/spanner_ingestion_test.py b/pipeline/workflow/spanner_ingestion_test.py
new file mode 100644
index 00000000..7e3d24b6
--- /dev/null
+++ b/pipeline/workflow/spanner_ingestion_test.py
@@ -0,0 +1,174 @@
+# Copyright 2026 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+End-to-end test for import automation and Spanner ingestion workflows.
+"""
+
+import json
+import os
+import sys
+
+from absl import app
+from absl import logging
+
+# Add path for executor modules
+sys.path.append(
+ os.path.abspath(os.path.join(os.path.dirname(__file__), '../executor/app')))
+from executor import cloud_workflow
+
+from google.cloud import spanner
+
+PROJECT_ID = os.environ.get('PROJECT_ID', 'datcom-ci')
+LOCATION = os.environ.get('LOCATION', 'us-central1')
+SPANNER_PROJECT_ID = os.environ.get('SPANNER_PROJECT_ID', 'datcom-ci')
+SPANNER_INSTANCE_ID = os.environ.get('SPANNER_INSTANCE_ID',
+ 'datcom-spanner-test')
+SPANNER_DATABASE_ID = os.environ.get('SPANNER_DATABASE_ID', 'dc-test-db')
+GCS_BUCKET_ID = os.environ.get('GCS_BUCKET_ID', 'datcom-ci-test')
+IMPORT_WORKFLOW_ID = 'import-automation-workflow'
+INGESTION_WORKFLOW_ID = 'spanner-ingestion-workflow'
+
+# Test Import Configuration
+TEST_IMPORT_NAME = 'scripts/us_fed/treasury_constant_maturity_rates:USFed_ConstantMaturityRates_Test'
+
+
+def verify_spanner_data(import_name):
+ """Verifies that the import data exists and is marked as SUCCESS in Spanner."""
+ logging.info(f"Verifying Spanner data for import: {import_name}")
+ spanner_client = spanner.Client(project=SPANNER_PROJECT_ID)
+ instance = spanner_client.instance(SPANNER_INSTANCE_ID)
+ database = instance.database(SPANNER_DATABASE_ID)
+
+ try:
+ with database.snapshot(multi_use=True) as snapshot:
+ # Check ImportStatus table
+ query = "SELECT State FROM ImportStatus WHERE ImportName = @import_name"
+ params = {"import_name": import_name}
+ param_types = {"import_name": spanner.param_types.STRING}
+
+ results = list(
+ snapshot.execute_sql(query,
+ params=params,
+ param_types=param_types))
+
+ if not results:
+ raise AssertionError(
+ f"Import {import_name} not found in ImportStatus table.")
+
+ state = results[0][0]
+ if state != 'SUCCESS':
+ raise AssertionError(
+ f"Import {import_name} state is {state}, expected 'SUCCESS'."
+ )
+
+ logging.info(
+ f"Import {import_name} verified in ImportStatus with state: {state}"
+ )
+
+ # Check IngestionHistory table (optional, but good for E2E)
+ # We look for a recent entry containing this import
+ query_history = """
+ SELECT count(*)
+ FROM IngestionHistory
+ WHERE @import_name IN UNNEST(IngestedImports)
+ """
+ results_history = list(
+ snapshot.execute_sql(query_history,
+ params=params,
+ param_types=param_types))
+ count = results_history[0][0]
+
+ if count == 0:
+ raise AssertionError(
+ f"Import {import_name} not found in IngestionHistory table."
+ )
+
+ logging.info(f"Import {import_name} verified in IngestionHistory.")
+
+ except Exception as e:
+ logging.error(f"Spanner verification failed: {e}")
+ raise
+
+
+def cleanup_spanner(import_name):
+ """Cleans up the import data from Spanner to ensure a clean state."""
+ logging.info(f"Cleaning up Spanner data for import: {import_name}")
+ spanner_client = spanner.Client(project=SPANNER_PROJECT_ID)
+ instance = spanner_client.instance(SPANNER_INSTANCE_ID)
+ database = instance.database(SPANNER_DATABASE_ID)
+
+ def _delete_import(transaction):
+ query = "DELETE FROM ImportStatus WHERE ImportName = @import_name"
+ params = {"import_name": import_name}
+ param_types = {"import_name": spanner.param_types.STRING}
+ transaction.execute_update(query,
+ params=params,
+ param_types=param_types)
+
+ try:
+ database.run_in_transaction(_delete_import)
+ logging.info(
+ f"Successfully cleaned up {import_name} from ImportStatus table.")
+ except Exception as e:
+ logging.warning(f"Error during Spanner cleanup: {e}")
+
+
+def main(argv):
+ del argv # Unused.
+ try:
+ # 0. Cleanup Spanner
+ logging.info("Step 0: Cleanup Spanner...")
+ short_import_name = TEST_IMPORT_NAME.split(':')[-1]
+ cleanup_spanner(short_import_name)
+
+ # 1. Trigger Import Automation Workflow
+ job_name = "test-import"
+ import_config = {
+ "gcp_project_id": PROJECT_ID,
+ "gcs_project_id": PROJECT_ID,
+ "storage_prod_bucket_name": GCS_BUCKET_ID,
+ "gcs_bucket_volume_mount": GCS_BUCKET_ID
+ }
+
+ import_workflow_args = {
+ "importName": TEST_IMPORT_NAME,
+ "importConfig": json.dumps(import_config),
+ }
+
+ logging.info("Step 1: Running Import Automation Workflow...")
+ cloud_workflow.trigger_workflow_and_wait(PROJECT_ID, LOCATION,
+ IMPORT_WORKFLOW_ID,
+ import_workflow_args)
+
+ # 2. Trigger Spanner Ingestion Workflow
+ ingestion_workflow_args = {"importList": [short_import_name]}
+
+ logging.info("Step 2: Running Spanner Ingestion Workflow...")
+ cloud_workflow.trigger_workflow_and_wait(PROJECT_ID, LOCATION,
+ INGESTION_WORKFLOW_ID,
+ ingestion_workflow_args)
+
+ # 3. Verify Data in Spanner
+ logging.info("Step 3: Verifying Data in Spanner...")
+ verify_spanner_data(short_import_name)
+
+ logging.info("Spanner ingestion test completed successfully.")
+
+ except Exception as e:
+ logging.error(f"Spanner ingestion test Failed: {e}")
+ sys.exit(1)
+
+
+if __name__ == "__main__":
+ app.run(main)