diff --git a/pipeline/ingestion/src/main/java/org/datacommons/ingestion/pipeline/IngestionPipeline.java b/pipeline/ingestion/src/main/java/org/datacommons/ingestion/pipeline/IngestionPipeline.java index 76d53f6b..88405557 100644 --- a/pipeline/ingestion/src/main/java/org/datacommons/ingestion/pipeline/IngestionPipeline.java +++ b/pipeline/ingestion/src/main/java/org/datacommons/ingestion/pipeline/IngestionPipeline.java @@ -31,7 +31,7 @@ public static void main(String[] args) { .build(); LOGGER.info("Starting Spanner DDL creation..."); - spannerClient.validateOrInitializeDatabase(); + // spannerClient.validateOrInitializeDatabase(); LOGGER.info("Spanner DDL creation complete."); Pipeline pipeline = Pipeline.create(options); diff --git a/pipeline/ingestion/src/main/java/org/datacommons/ingestion/pipeline/Transforms.java b/pipeline/ingestion/src/main/java/org/datacommons/ingestion/pipeline/Transforms.java index adc1cc3d..821783cd 100644 --- a/pipeline/ingestion/src/main/java/org/datacommons/ingestion/pipeline/Transforms.java +++ b/pipeline/ingestion/src/main/java/org/datacommons/ingestion/pipeline/Transforms.java @@ -4,6 +4,8 @@ import static org.datacommons.ingestion.pipeline.SkipProcessing.SKIP_OBS; import com.google.cloud.spanner.Mutation; +import com.google.common.base.Joiner; +import java.util.ArrayList; import java.util.HashSet; import java.util.LinkedHashMap; import java.util.List; @@ -11,6 +13,7 @@ import java.util.Set; import org.apache.beam.sdk.Pipeline; import org.apache.beam.sdk.io.TextIO; +import org.apache.beam.sdk.io.gcp.spanner.MutationGroup; import org.apache.beam.sdk.metrics.Counter; import org.apache.beam.sdk.metrics.Metrics; import org.apache.beam.sdk.transforms.*; @@ -88,9 +91,12 @@ public void processElement(@Element String row, MultiOutputReceiver out) { NodesEdges nodesEdges = cacheReader.parseArcRow(row, MCF_NODES_WITHOUT_TYPE_COUNTER); outputGraphMutations(nodesEdges, out); } else if (CacheReader.isObsTimeSeriesCacheRow(row) && skipProcessing 
!= SKIP_OBS) { - var obs = cacheReader.parseTimeSeriesRow(row); - var kvs = spannerClient.toObservationKVMutations(obs); - var filtered = spannerClient.filterObservationKVMutations(kvs, seenObs); + var obsList = cacheReader.parseTimeSeriesRow(row); + List<KV<String, Mutation>> kvs = new ArrayList<>(); + for (var obs : obsList) { + kvs.addAll(toNewSchemaMutations(obs)); + } + var filtered = filterNewSchemaMutations(kvs, seenObs); filtered.forEach(out.get(observationTag)::output); var dups = kvs.size() - filtered.size(); @@ -99,7 +105,7 @@ public void processElement(@Element String row, MultiOutputReceiver out) { } if (writeObsGraph) { - obs.stream() + obsList.stream() .map(Observation::getObsGraph) .forEach(obsGraph -> outputGraphMutations(obsGraph, out)); } @@ -212,19 +218,57 @@ private PCollection<Void> groupByGraphOnly(PCollection<String> cacheRows) { .withOutputTags(graphTag, TupleTagList.of(observationTag))); var observations = - kvs.get(observationTag).apply("ExtractObservationMutations", Values.create()); + kvs.get(observationTag) + .apply("GroupObsMutations", GroupByKey.create()) + .apply( + "CreateMutationGroups", + ParDo.of( + new DoFn<KV<String, Iterable<Mutation>>, MutationGroup>() { + @ProcessElement + public void processElement(ProcessContext c) { + KV<String, Iterable<Mutation>> element = c.element(); + Mutation primary = null; + List<Mutation> secondaries = new ArrayList<>(); + for (Mutation m : element.getValue()) { + if (m.getTable().equals("TimeSeries")) { + primary = m; + } else { + secondaries.add(m); + } + } + if (primary != null) { + c.output(MutationGroup.create(primary, secondaries)); + } else { + LOGGER.warn( + "No TimeSeries mutation found for group: " + element.getKey()); + } + } + })); // TODO: Explore emitting protos instead of mutations to reduce shuffle cost.
- var graph = - kvs.get(graphTag) - .apply("GroupGraphMutations", GroupByKey.create()) - .apply("ExtractGraphMutations", ParDo.of(new ExtractKVMutationsDoFn(spannerClient))); - - var write = - PCollectionList.of(graph) - .and(observations) - .apply("MergeMutations", Flatten.pCollections()) - .apply("WriteToSpanner", spannerClient.getWriteTransform()); - return write.getOutput(); + var graphMutations = kvs.get(graphTag); + + var nodeMutations = + graphMutations + .apply("FilterNodes", Filter.by(kv -> kv.getValue().getTable().equals("Node"))) + .apply("GroupNodeMutations", GroupByKey.create()) + .apply("ExtractNodeMutations", ParDo.of(new ExtractKVMutationsDoFn(spannerClient))); + + var edgeMutations = + graphMutations + .apply("FilterEdges", Filter.by(kv -> kv.getValue().getTable().equals("Edge"))) + .apply("GroupEdgeMutations", GroupByKey.create()) + .apply("ExtractEdgeMutations", ParDo.of(new ExtractKVMutationsDoFn(spannerClient))); + + var writtenNodes = + nodeMutations.apply("WriteNodesToSpanner", spannerClient.getWriteTransform()); + + var waitingEdges = edgeMutations.apply("EdgesWaitOnNodes", Wait.on(writtenNodes.getOutput())); + + waitingEdges.apply("WriteEdgesToSpanner", spannerClient.getWriteTransform()); + + var writeObs = + observations.apply("WriteObsToSpanner", spannerClient.getWriteGroupedTransform()); + return writeObs.getOutput(); } } @@ -266,7 +310,7 @@ static void buildIngestionPipeline( * This method updates the key's usage, unlike `containsKey()`, which doesn't and would therefore * disrupt the LRU sequence. 
*/ - private static class LRUCache<K, V> extends LinkedHashMap<K, V> { + static class LRUCache<K, V> extends LinkedHashMap<K, V> { private final int capacity; public LRUCache(int capacity) { @@ -279,4 +323,105 @@ protected boolean removeEldestEntry(Map.Entry<K, V> eldest) { return size() > capacity; } } + + static List<KV<String, Mutation>> toNewSchemaMutations(Observation obs) { + List<KV<String, Mutation>> mutations = new ArrayList<>(); + String seriesDcid = + "dc/os/" + + Joiner.on("_") + .join( + obs.getVariableMeasured().replace('/', '_'), + obs.getObservationAbout().replace('/', '_'), + obs.getFacetId()); + + // 1. TimeSeries + mutations.add( + KV.of( + seriesDcid, + Mutation.newInsertOrUpdateBuilder("TimeSeries") + .set("id") + .to(seriesDcid) + .set("variable_measured") + .to(obs.getVariableMeasured()) + .set("provenance") + .to("dc/base/" + obs.getImportName()) + .build())); + + // 2. TimeSeriesAttribute + mutations.add( + KV.of( + seriesDcid, + Mutation.newInsertOrUpdateBuilder("TimeSeriesAttribute") + .set("id") + .to(seriesDcid) + .set("property") + .to("observationAbout") + .set("value") + .to(obs.getObservationAbout()) + .build())); + + addIfNotEmpty(mutations, seriesDcid, "unit", obs.getUnit()); + addIfNotEmpty(mutations, seriesDcid, "scalingFactor", obs.getScalingFactor()); + addIfNotEmpty(mutations, seriesDcid, "measurementMethod", obs.getMeasurementMethod()); + addIfNotEmpty(mutations, seriesDcid, "observationPeriod", obs.getObservationPeriod()); + + // 3.
StatVarObservation + for (Map.Entry<String, String> entry : obs.getObservations().getValuesMap().entrySet()) { + mutations.add( + KV.of( + seriesDcid, + Mutation.newInsertOrUpdateBuilder("StatVarObservation") + .set("id") + .to(seriesDcid) + .set("date") + .to(entry.getKey()) + .set("value") + .to(entry.getValue()) + .build())); + } + + return mutations; + } + + static void addIfNotEmpty( + List<KV<String, Mutation>> mutations, String id, String property, String value) { + if (value != null && !value.isEmpty()) { + mutations.add( + KV.of( + id, + Mutation.newInsertOrUpdateBuilder("TimeSeriesAttribute") + .set("id") + .to(id) + .set("property") + .to(property) + .set("value") + .to(value) + .build())); + } + } + + static List<KV<String, Mutation>> filterNewSchemaMutations( + List<KV<String, Mutation>> kvs, LRUCache<String, Boolean> seenObs) { + List<KV<String, Mutation>> filtered = new ArrayList<>(); + for (var kv : kvs) { + Mutation mutation = kv.getValue(); + String table = mutation.getTable(); + String id = kv.getKey(); + String dedupKey = ""; + + if (table.equals("TimeSeries") || table.equals("TimeSeriesAttribute")) { + dedupKey = table + "::" + id; + } else if (table.equals("StatVarObservation")) { + String date = mutation.asMap().get("date").getString(); + dedupKey = table + "::" + id + "::" + date; + } + + if (seenObs.get(dedupKey) != null) { + continue; + } + seenObs.put(dedupKey, true); + filtered.add(kv); + } + return filtered; + } } diff --git a/pipeline/ingestion/src/test/java/org/datacommons/ingestion/pipeline/TransformsTest.java b/pipeline/ingestion/src/test/java/org/datacommons/ingestion/pipeline/TransformsTest.java new file mode 100644 index 00000000..4ba53e69 --- /dev/null +++ b/pipeline/ingestion/src/test/java/org/datacommons/ingestion/pipeline/TransformsTest.java @@ -0,0 +1,118 @@ +package org.datacommons.ingestion.pipeline; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import com.google.cloud.spanner.Mutation; +import java.util.List; +import org.apache.beam.sdk.values.KV; +import
org.datacommons.Storage.Observations; +import org.datacommons.ingestion.data.Observation; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class TransformsTest { + + @Test + public void testToNewSchemaMutations() { + Observations.Builder obsBuilder = Observations.newBuilder(); + obsBuilder.putValues("2020", "10.0"); + obsBuilder.putValues("2021", "12.0"); + + Observation obs = + Observation.builder() + .variableMeasured("count") + .observationAbout("country/USA") + .importName("test_import") + .unit("Count") + .measurementMethod("Census") + .observationPeriod("P1Y") + .scalingFactor("1.0") + .observations(obsBuilder.build()) + .build(); + + List<KV<String, Mutation>> mutations = Transforms.toNewSchemaMutations(obs); + + // Expected mutations: + // 1 TimeSeries + // 5 TimeSeriesAttribute (observationAbout, unit, scalingFactor, measurementMethod, + // observationPeriod) + // 2 StatVarObservation (for 2020 and 2021) + // Total = 8 + + assertEquals(8, mutations.size()); + + // Verify TimeSeries mutation + Mutation tsMutation = findMutationByTable(mutations, "TimeSeries"); + assertEquals("TimeSeries", tsMutation.getTable()); + assertEquals( + "dc/os/count_country_USA_" + obs.getFacetId(), tsMutation.asMap().get("id").getString()); + assertEquals("count", tsMutation.asMap().get("variable_measured").getString()); + assertEquals("dc/base/test_import", tsMutation.asMap().get("provenance").getString()); + + // Verify StatVarObservation mutations + List<Mutation> svoMutations = findMutationsByTable(mutations, "StatVarObservation"); + assertEquals(2, svoMutations.size()); + + Mutation m2020 = findMutationByDate(svoMutations, "2020"); + assertEquals("10.0", m2020.asMap().get("value").getString()); + + Mutation m2021 = findMutationByDate(svoMutations, "2021"); + assertEquals("12.0", m2021.asMap().get("value").getString()); + } + + @Test + public void testFilterNewSchemaMutations() { + Transforms.LRUCache<String, Boolean> seenObs = new
Transforms.LRUCache<>(100); + + Mutation m1 = Mutation.newInsertOrUpdateBuilder("TimeSeries").set("id").to("ts1").build(); + Mutation m2 = + Mutation.newInsertOrUpdateBuilder("StatVarObservation") + .set("id") + .to("ts1") + .set("date") + .to("2020") + .set("value") + .to("10") + .build(); + + List<KV<String, Mutation>> kvs = + List.of( + KV.of("ts1", m1), + KV.of("ts1", m2), + KV.of("ts1", m1), // Duplicate + KV.of("ts1", m2) // Duplicate + ); + + List<KV<String, Mutation>> filtered = Transforms.filterNewSchemaMutations(kvs, seenObs); + + assertEquals(2, filtered.size()); + assertTrue( + filtered.stream().map(KV::getValue).anyMatch(m -> m.getTable().equals("TimeSeries"))); + assertTrue( + filtered.stream() + .map(KV::getValue) + .anyMatch(m -> m.getTable().equals("StatVarObservation"))); + } + + private Mutation findMutationByTable(List<KV<String, Mutation>> kvs, String table) { + return kvs.stream() + .map(KV::getValue) + .filter(m -> m.getTable().equals(table)) + .findFirst() + .orElseThrow(() -> new AssertionError("Mutation for table " + table + " not found")); + } + + private List<Mutation> findMutationsByTable(List<KV<String, Mutation>> kvs, String table) { + return kvs.stream().map(KV::getValue).filter(m -> m.getTable().equals(table)).toList(); + } + + private Mutation findMutationByDate(List<Mutation> mutations, String date) { + return mutations.stream() + .filter(m -> m.asMap().get("date").getString().equals(date)) + .findFirst() + .orElseThrow(() -> new AssertionError("Mutation for date " + date + " not found")); + } +}