Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ public static void main(String[] args) {
.build();

LOGGER.info("Starting Spanner DDL creation...");
spannerClient.validateOrInitializeDatabase();
// spannerClient.validateOrInitializeDatabase();
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

Commenting out spannerClient.validateOrInitializeDatabase() disables critical database schema validation and initialization. If this was done for testing purposes, please ensure it is reverted before merging, or implement a conditional check if skipping is intended for specific environments.

LOGGER.info("Spanner DDL creation complete.");

Pipeline pipeline = Pipeline.create(options);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,16 @@
import static org.datacommons.ingestion.pipeline.SkipProcessing.SKIP_OBS;

import com.google.cloud.spanner.Mutation;
import com.google.common.base.Joiner;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.io.TextIO;
import org.apache.beam.sdk.io.gcp.spanner.MutationGroup;
import org.apache.beam.sdk.metrics.Counter;
import org.apache.beam.sdk.metrics.Metrics;
import org.apache.beam.sdk.transforms.*;
Expand Down Expand Up @@ -88,9 +91,12 @@ public void processElement(@Element String row, MultiOutputReceiver out) {
NodesEdges nodesEdges = cacheReader.parseArcRow(row, MCF_NODES_WITHOUT_TYPE_COUNTER);
outputGraphMutations(nodesEdges, out);
} else if (CacheReader.isObsTimeSeriesCacheRow(row) && skipProcessing != SKIP_OBS) {
var obs = cacheReader.parseTimeSeriesRow(row);
var kvs = spannerClient.toObservationKVMutations(obs);
var filtered = spannerClient.filterObservationKVMutations(kvs, seenObs);
var obsList = cacheReader.parseTimeSeriesRow(row);
List<KV<String, Mutation>> kvs = new ArrayList<>();
for (var obs : obsList) {
kvs.addAll(toNewSchemaMutations(obs));
}
var filtered = filterNewSchemaMutations(kvs, seenObs);
filtered.forEach(out.get(observationTag)::output);

var dups = kvs.size() - filtered.size();
Expand All @@ -99,7 +105,7 @@ public void processElement(@Element String row, MultiOutputReceiver out) {
}

if (writeObsGraph) {
obs.stream()
obsList.stream()
.map(Observation::getObsGraph)
.forEach(obsGraph -> outputGraphMutations(obsGraph, out));
}
Expand Down Expand Up @@ -212,19 +218,57 @@ private PCollection<Void> groupByGraphOnly(PCollection<String> cacheRows) {
.withOutputTags(graphTag, TupleTagList.of(observationTag)));

var observations =
kvs.get(observationTag).apply("ExtractObservationMutations", Values.create());
kvs.get(observationTag)
.apply("GroupObsMutations", GroupByKey.create())
.apply(
"CreateMutationGroups",
ParDo.of(
new DoFn<KV<String, Iterable<Mutation>>, MutationGroup>() {
@ProcessElement
public void processElement(ProcessContext c) {
KV<String, Iterable<Mutation>> element = c.element();
Mutation primary = null;
List<Mutation> secondaries = new ArrayList<>();
for (Mutation m : element.getValue()) {
if (m.getTable().equals("TimeSeries")) {
primary = m;
} else {
secondaries.add(m);
}
}
if (primary != null) {
c.output(MutationGroup.create(primary, secondaries));
} else {
LOGGER.warn(
"No TimeSeries mutation found for group: " + element.getKey());
}
}
}));
// TODO: Explore emitting protos instead of mutations to reduce shuffle cost.
var graph =
kvs.get(graphTag)
.apply("GroupGraphMutations", GroupByKey.create())
.apply("ExtractGraphMutations", ParDo.of(new ExtractKVMutationsDoFn(spannerClient)));

var write =
PCollectionList.of(graph)
.and(observations)
.apply("MergeMutations", Flatten.<Mutation>pCollections())
.apply("WriteToSpanner", spannerClient.getWriteTransform());
return write.getOutput();
var graphMutations = kvs.get(graphTag);

var nodeMutations =
graphMutations
.apply("FilterNodes", Filter.by(kv -> kv.getValue().getTable().equals("Node")))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Avoid hardcoding the table name "Node". Use spannerClient.getNodeTableName() to ensure the filter remains correct if the table name is customized via configuration.

Suggested change
.apply("FilterNodes", Filter.by(kv -> kv.getValue().getTable().equals("Node")))
.apply("FilterNodes", Filter.by(kv -> kv.getValue().getTable().equals(spannerClient.getNodeTableName())))

.apply("GroupNodeMutations", GroupByKey.create())
.apply("ExtractNodeMutations", ParDo.of(new ExtractKVMutationsDoFn(spannerClient)));

var edgeMutations =
graphMutations
.apply("FilterEdges", Filter.by(kv -> kv.getValue().getTable().equals("Edge")))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Avoid hardcoding the table name "Edge". Use spannerClient.getEdgeTableName() to ensure the filter remains correct if the table name is customized via configuration.

Suggested change
.apply("FilterEdges", Filter.by(kv -> kv.getValue().getTable().equals("Edge")))
.apply("FilterEdges", Filter.by(kv -> kv.getValue().getTable().equals(spannerClient.getEdgeTableName())))

.apply("GroupEdgeMutations", GroupByKey.create())
.apply("ExtractEdgeMutations", ParDo.of(new ExtractKVMutationsDoFn(spannerClient)));

var writtenNodes =
nodeMutations.apply("WriteNodesToSpanner", spannerClient.getWriteTransform());

var waitingEdges = edgeMutations.apply("EdgesWaitOnNodes", Wait.on(writtenNodes.getOutput()));

waitingEdges.apply("WriteEdgesToSpanner", spannerClient.getWriteTransform());

var writeObs =
observations.apply("WriteObsToSpanner", spannerClient.getWriteGroupedTransform());
return writeObs.getOutput();
Comment on lines +262 to +271
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The method now only returns the output of writeObs, ignoring the completion of writtenNodes and the edge writes. This is a regression from the previous implementation which merged all mutation writes. If a downstream transform depends on the output of groupByGraphOnly, it may start before nodes and edges are fully committed to Spanner. You should merge the outputs of all write transforms before returning.

      var writtenNodes =
          nodeMutations.apply("WriteNodesToSpanner", spannerClient.getWriteTransform());

      var writtenEdges =
          edgeMutations
              .apply("EdgesWaitOnNodes", Wait.on(writtenNodes.getOutput()))
              .apply("WriteEdgesToSpanner", spannerClient.getWriteTransform());

      var writtenObs =
          observations.apply("WriteObsToSpanner", spannerClient.getWriteGroupedTransform());

      return PCollectionList.of(writtenNodes.getOutput())
          .and(writtenEdges.getOutput())
          .and(writtenObs.getOutput())
          .apply("MergeOutputs", Flatten.pCollections());

}
}

Expand Down Expand Up @@ -266,7 +310,7 @@ static void buildIngestionPipeline(
* This method updates the key's usage, unlike `containsKey()`, which doesn't and would therefore
* disrupt the LRU sequence.
*/
private static class LRUCache<K, V> extends LinkedHashMap<K, V> {
static class LRUCache<K, V> extends LinkedHashMap<K, V> {
private final int capacity;

public LRUCache(int capacity) {
Expand All @@ -279,4 +323,105 @@ protected boolean removeEldestEntry(Map.Entry<K, V> eldest) {
return size() > capacity;
}
}

static List<KV<String, Mutation>> toNewSchemaMutations(Observation obs) {
List<KV<String, Mutation>> mutations = new ArrayList<>();
String seriesDcid =
"dc/os/"
+ Joiner.on("_")
.join(
obs.getVariableMeasured().replace('/', '_'),
obs.getObservationAbout().replace('/', '_'),
obs.getFacetId());
Comment on lines +329 to +335
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The logic for generating seriesDcid is duplicated from Observation.toObsGraph(). To ensure consistency and simplify maintenance, this logic should be centralized, for example by adding a public static method to the Observation class that returns the series DCID.


// 1. TimeSeries
mutations.add(
KV.of(
seriesDcid,
Mutation.newInsertOrUpdateBuilder("TimeSeries")
.set("id")
.to(seriesDcid)
.set("variable_measured")
.to(obs.getVariableMeasured())
.set("provenance")
.to("dc/base/" + obs.getImportName())
.build()));

// 2. TimeSeriesAttribute
mutations.add(
KV.of(
seriesDcid,
Mutation.newInsertOrUpdateBuilder("TimeSeriesAttribute")
.set("id")
.to(seriesDcid)
.set("property")
.to("observationAbout")
.set("value")
.to(obs.getObservationAbout())
.build()));

addIfNotEmpty(mutations, seriesDcid, "unit", obs.getUnit());
addIfNotEmpty(mutations, seriesDcid, "scalingFactor", obs.getScalingFactor());
addIfNotEmpty(mutations, seriesDcid, "measurementMethod", obs.getMeasurementMethod());
addIfNotEmpty(mutations, seriesDcid, "observationPeriod", obs.getObservationPeriod());

// 3. StatVarObservation
for (Map.Entry<String, String> entry : obs.getObservations().getValuesMap().entrySet()) {
mutations.add(
KV.of(
seriesDcid,
Mutation.newInsertOrUpdateBuilder("StatVarObservation")
.set("id")
.to(seriesDcid)
.set("date")
.to(entry.getKey())
.set("value")
.to(entry.getValue())
.build()));
}

return mutations;
}

static void addIfNotEmpty(
List<KV<String, Mutation>> mutations, String id, String property, String value) {
if (value != null && !value.isEmpty()) {
mutations.add(
KV.of(
id,
Mutation.newInsertOrUpdateBuilder("TimeSeriesAttribute")
.set("id")
.to(id)
.set("property")
.to(property)
.set("value")
.to(value)
.build()));
}
}

static List<KV<String, Mutation>> filterNewSchemaMutations(
List<KV<String, Mutation>> kvs, LRUCache<String, Boolean> seenObs) {
List<KV<String, Mutation>> filtered = new ArrayList<>();
for (var kv : kvs) {
Mutation mutation = kv.getValue();
String table = mutation.getTable();
String id = kv.getKey();
String dedupKey = "";

if (table.equals("TimeSeries") || table.equals("TimeSeriesAttribute")) {
dedupKey = table + "::" + id;
} else if (table.equals("StatVarObservation")) {
String date = mutation.asMap().get("date").getString();
dedupKey = table + "::" + id + "::" + date;
}

if (seenObs.get(dedupKey) != null) {
continue;
}
seenObs.put(dedupKey, true);
filtered.add(kv);
}
return filtered;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
package org.datacommons.ingestion.pipeline;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

import com.google.cloud.spanner.Mutation;
import java.util.List;
import org.apache.beam.sdk.values.KV;
import org.datacommons.Storage.Observations;
import org.datacommons.ingestion.data.Observation;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

@RunWith(JUnit4.class)
public class TransformsTest {

  @Test
  public void testToNewSchemaMutations() {
    // Two data points -> two StatVarObservation rows.
    Observations.Builder valuesBuilder = Observations.newBuilder();
    valuesBuilder.putValues("2020", "10.0");
    valuesBuilder.putValues("2021", "12.0");

    Observation observation =
        Observation.builder()
            .variableMeasured("count")
            .observationAbout("country/USA")
            .importName("test_import")
            .unit("Count")
            .measurementMethod("Census")
            .observationPeriod("P1Y")
            .scalingFactor("1.0")
            .observations(valuesBuilder.build())
            .build();

    List<KV<String, Mutation>> mutations = Transforms.toNewSchemaMutations(observation);

    // 1 TimeSeries + 5 TimeSeriesAttribute (observationAbout, unit, scalingFactor,
    // measurementMethod, observationPeriod) + 2 StatVarObservation = 8 total.
    assertEquals(8, mutations.size());

    // The TimeSeries row carries the series id, variable, and provenance.
    Mutation timeSeries = singleMutationForTable(mutations, "TimeSeries");
    assertEquals("TimeSeries", timeSeries.getTable());
    assertEquals(
        "dc/os/count_country_USA_" + observation.getFacetId(),
        timeSeries.asMap().get("id").getString());
    assertEquals("count", timeSeries.asMap().get("variable_measured").getString());
    assertEquals("dc/base/test_import", timeSeries.asMap().get("provenance").getString());

    // One StatVarObservation per date, each with its own value.
    List<Mutation> observationRows = mutationsForTable(mutations, "StatVarObservation");
    assertEquals(2, observationRows.size());

    Mutation row2020 = mutationForDate(observationRows, "2020");
    assertEquals("10.0", row2020.asMap().get("value").getString());

    Mutation row2021 = mutationForDate(observationRows, "2021");
    assertEquals("12.0", row2021.asMap().get("value").getString());
  }

  @Test
  public void testFilterNewSchemaMutations() {
    Transforms.LRUCache<String, Boolean> seen = new Transforms.LRUCache<>(100);

    Mutation seriesMutation =
        Mutation.newInsertOrUpdateBuilder("TimeSeries").set("id").to("ts1").build();
    Mutation pointMutation =
        Mutation.newInsertOrUpdateBuilder("StatVarObservation")
            .set("id")
            .to("ts1")
            .set("date")
            .to("2020")
            .set("value")
            .to("10")
            .build();

    // Each mutation appears twice; the repeats must be filtered out.
    List<KV<String, Mutation>> input =
        List.of(
            KV.of("ts1", seriesMutation),
            KV.of("ts1", pointMutation),
            KV.of("ts1", seriesMutation), // Duplicate
            KV.of("ts1", pointMutation) // Duplicate
            );

    List<KV<String, Mutation>> deduped = Transforms.filterNewSchemaMutations(input, seen);

    assertEquals(2, deduped.size());
    assertTrue(
        deduped.stream().map(KV::getValue).anyMatch(m -> m.getTable().equals("TimeSeries")));
    assertTrue(
        deduped.stream()
            .map(KV::getValue)
            .anyMatch(m -> m.getTable().equals("StatVarObservation")));
  }

  /** Returns the first mutation targeting {@code table}, failing the test if none exists. */
  private Mutation singleMutationForTable(List<KV<String, Mutation>> kvs, String table) {
    for (KV<String, Mutation> kv : kvs) {
      if (kv.getValue().getTable().equals(table)) {
        return kv.getValue();
      }
    }
    throw new AssertionError("Mutation for table " + table + " not found");
  }

  /** Returns every mutation targeting {@code table}, in input order. */
  private List<Mutation> mutationsForTable(List<KV<String, Mutation>> kvs, String table) {
    return kvs.stream()
        .map(KV::getValue)
        .filter(mutation -> table.equals(mutation.getTable()))
        .toList();
  }

  /** Returns the first mutation whose "date" column equals {@code date}, failing if none does. */
  private Mutation mutationForDate(List<Mutation> mutations, String date) {
    for (Mutation mutation : mutations) {
      if (mutation.asMap().get("date").getString().equals(date)) {
        return mutation;
      }
    }
    throw new AssertionError("Mutation for date " + date + " not found");
  }
}