diff --git a/.gitignore b/.gitignore index 5bcfb2a..9bc2bad 100644 --- a/.gitignore +++ b/.gitignore @@ -72,3 +72,7 @@ hs_err_pid* .apt_generated **/.sts4-cache/* + +# Media files +*.mp4 +*.jpg \ No newline at end of file diff --git a/build.gradle b/build.gradle index 2299f69..fecb51a 100644 --- a/build.gradle +++ b/build.gradle @@ -1,5 +1,5 @@ /* - * Copyright 2020 Google LLC + * Copyright 2022 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,18 +15,18 @@ */ buildscript { ext { - dataflowBeamVersion = '2.24.0' - visionApiVersion = '1.99.3' + beamVersion = '2.37.0' + visionApiVersion = '2.0.21' + videoApiVersion = '2.0.18' } repositories { mavenCentral() - jcenter() maven { url "https://plugins.gradle.org/m2/" } dependencies { - classpath "gradle.plugin.com.google.cloud.tools:jib-gradle-plugin:2.5.0" - classpath "com.diffplug.spotless:spotless-plugin-gradle:3.24.2" + classpath "gradle.plugin.com.google.cloud.tools:jib-gradle-plugin:3.2.0" + classpath "com.diffplug.spotless:spotless-plugin-gradle:6.3.0" } } } @@ -43,16 +43,16 @@ apply plugin: 'application' apply plugin: 'eclipse' apply plugin: 'idea' apply plugin: 'com.google.cloud.tools.jib' -apply plugin: "com.diffplug.gradle.spotless" +apply plugin: 'com.diffplug.spotless' // Licence header enforced by spotless def javaLicenseHeader = """/* - * Copyright 2020 Google LLC + * Copyright 2022 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * https://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, @@ -62,15 +62,15 @@ def javaLicenseHeader = """/* */ """ java { - sourceCompatibility = JavaVersion.VERSION_1_8 - targetCompatibility = JavaVersion.VERSION_1_8 + sourceCompatibility = JavaVersion.VERSION_11 + targetCompatibility = JavaVersion.VERSION_11 } -mainClassName = System.getProperty('mainClass', 'com.google.solutions.ml.api.vision.VisionAnalyticsPipeline') +mainClassName = System.getProperty('mainClass', 'com.google.solutions.annotation.AnnotationPipeline') jib { from { - image = 'gcr.io/dataflow-templates-base/java8-template-launcher-base:latest' + image = 'gcr.io/dataflow-templates-base/java11-template-launcher-base:20220124_RC00' } to { credHelper = 'gcloud' @@ -92,19 +92,24 @@ repositories { mavenCentral() } dependencies { - implementation group: 'org.apache.beam', name: 'beam-sdks-java-core', version: dataflowBeamVersion - implementation group: 'org.apache.beam', name: 'beam-runners-google-cloud-dataflow-java', version: dataflowBeamVersion - implementation group: 'org.apache.beam', name: 'beam-runners-direct-java', version: dataflowBeamVersion - implementation group: 'org.apache.beam', name: 'beam-sdks-java-extensions-ml', version: dataflowBeamVersion - implementation group: 'org.slf4j', name: 'slf4j-jdk14', version: '1.7.5' - implementation "com.google.auto.value:auto-value-annotations:1.6.2" - annotationProcessor "com.google.auto.value:auto-value:1.6.2" + implementation group: 'org.apache.beam', name: 'beam-sdks-java-core', version: beamVersion + implementation(group: 'org.apache.beam', name: 'beam-runners-google-cloud-dataflow-java', version: beamVersion) { + exclude group: 'io.confluent', module: 'kafka-schema-registry-client' + exclude group: 'io.confluent', module: 'kafka-avro-serializer' + } + implementation group: 'org.apache.beam', name: 'beam-runners-direct-java', version: beamVersion + implementation group: 'org.apache.beam', name: 'beam-sdks-java-extensions-ml', version: beamVersion + implementation group: 'org.slf4j', name: 'slf4j-jdk14', version: '1.7.36' + implementation "com.google.auto.value:auto-value-annotations:1.9" + annotationProcessor "com.google.auto.value:auto-value:1.9" + implementation 'com.google.http-client:google-http-client:1.41.4' implementation group: 'com.google.cloud', name: 'google-cloud-vision', version: visionApiVersion - testImplementation group: 'org.apache.beam', name: 'beam-runners-direct-java', version: dataflowBeamVersion - testImplementation group: 'org.slf4j', name: 'slf4j-jdk14', version: '1.7.5' - testImplementation group: 'org.hamcrest', name: 'hamcrest-core', version: '1.3' - testImplementation group: 'org.hamcrest', name: 'hamcrest-library', version: '1.3' - testImplementation 'junit:junit:4.12' + implementation group: 'com.google.cloud', name: 'google-cloud-video-intelligence', version: videoApiVersion + testImplementation group: 'org.apache.beam', name: 'beam-runners-direct-java', version: beamVersion + testImplementation group: 'org.slf4j', name: 'slf4j-jdk14', version: '1.7.36' + testImplementation group: 'org.hamcrest', name: 'hamcrest-core', version: '2.2' + testImplementation group: 'org.hamcrest', name: 'hamcrest-library', version: '2.2' + testImplementation 'junit:junit:4.13.2' } jar { @@ -134,7 +139,7 @@ run { task execPipeline(type: JavaExec) { dependsOn(assemble) - main = System.getProperty("mainClass") + mainClass = System.getProperty("mainClass") classpath = sourceSets.main.runtimeClasspath systemProperties System.getProperties() def execArgs = System.getProperty("exec.args") @@ -143,7 +148,7 @@ task execPipeline(type: JavaExec) { // Spotless configuration -Boolean enableSpotlessCheck = project.hasProperty('enableSpotlessCheck') && project.enableSpotlessCheck == 'true' +def enableSpotlessCheck = project.hasProperty('enableSpotlessCheck') && project.enableSpotlessCheck == 'true' spotless { enforceCheck enableSpotlessCheck java { @@ -151,5 +156,22 @@ spotless { googleJavaFormat() } } + run.mustRunAfter 'resources' +// Tests + +test { + useJUnit() + testLogging { + events "passed", "skipped", "failed" + showStandardStreams = true + exceptionFormat "full" + } +} + +sourceSets { + test { + java.srcDir file('src/test') + } +} \ No newline at end of file diff --git a/gradle.properties b/gradle.properties index cb7dbdd..7e90283 100644 --- a/gradle.properties +++ b/gradle.properties @@ -1,5 +1,5 @@ # -# Copyright 2020 Google LLC +# Copyright 2022 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/gradle/wrapper/gradle-wrapper.properties b/gradle/wrapper/gradle-wrapper.properties index 8fbc6bd..e5160ab 100644 --- a/gradle/wrapper/gradle-wrapper.properties +++ b/gradle/wrapper/gradle-wrapper.properties @@ -1,5 +1,5 @@ # -# Copyright 2020 Google LLC +# Copyright 2022 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,6 +16,6 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists -distributionUrl=https\://services.gradle.org/distributions/gradle-6.5.1-bin.zip +distributionUrl=https\://services.gradle.org/distributions/gradle-7.4-bin.zip zipStoreBase=GRADLE_USER_HOME zipStorePath=wrapper/dists diff --git a/settings.gradle b/settings.gradle index 6510743..3363e86 100644 --- a/settings.gradle +++ b/settings.gradle @@ -1,5 +1,5 @@ /* - * Copyright 2020 Google LLC + * Copyright 2022 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/src/main/java/com/google/solutions/annotation/AnnotateFilesDoFn.java b/src/main/java/com/google/solutions/annotation/AnnotateFilesDoFn.java new file mode 100644 index 0000000..c987285 --- /dev/null +++ b/src/main/java/com/google/solutions/annotation/AnnotateFilesDoFn.java @@ -0,0 +1,86 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.solutions.annotation; + +import com.google.cloud.videointelligence.v1p3beta1.StreamingFeature; +import com.google.cloud.vision.v1.Feature; +import com.google.protobuf.GeneratedMessageV3; +import com.google.solutions.annotation.gcs.GCSFileInfo; +import com.google.solutions.annotation.ml.videointelligence.VideoAnnotator; +import com.google.solutions.annotation.ml.vision.ImageAnnotator; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.values.KV; + +public class AnnotateFilesDoFn + extends DoFn, KV> { + + private static final long serialVersionUID = 1L; + private final List videoFeatures; + private VideoAnnotator videoAnnotator; + private ImageAnnotator imageAnnotator; + private final List imageFeatures; + + public AnnotateFilesDoFn(List imageFeatures, List videoFeatures) { + this.imageFeatures = imageFeatures; + this.videoFeatures = videoFeatures; + } + + @Setup + public void setup() throws IOException { + imageAnnotator = new ImageAnnotator(imageFeatures); + videoAnnotator = new VideoAnnotator(videoFeatures); + } + + @Teardown + public void teardown() { + imageAnnotator.teardown(); + videoAnnotator.teardown(); + } + + @ProcessElement + public void processElement( + @Element Iterable fileInfos, + OutputReceiver> out) { + List videoFiles = new ArrayList<>(); + List imageFiles = new ArrayList<>(); + for (GCSFileInfo fileInfo : fileInfos) { + if (AnnotationPipeline.SUPPORTED_IMAGE_CONTENT_TYPES.stream() + .anyMatch(fileInfo.getContentType()::equalsIgnoreCase)) { + imageFiles.add(fileInfo); + } else if (AnnotationPipeline.SUPPORTED_VIDEO_CONTENT_TYPES.stream() + .anyMatch(fileInfo.getContentType()::equalsIgnoreCase)) { + videoFiles.add(fileInfo); + } else { + throw new RuntimeException("Unsupported content type: " + fileInfo.getContentType()); + } + } + + List> responses = new ArrayList<>(); + if (!imageFiles.isEmpty()) { + responses.addAll(imageAnnotator.processFiles(imageFiles)); + } + if (!videoFiles.isEmpty()) { + responses.addAll(videoAnnotator.processFiles(videoFiles)); + } + + for (KV response : responses) { + out.output(response); + } + } +} diff --git a/src/main/java/com/google/solutions/ml/api/vision/AnnotateImagesSimulatorDoFn.java b/src/main/java/com/google/solutions/annotation/AnnotateFilesSimulatorDoFn.java similarity index 52% rename from src/main/java/com/google/solutions/ml/api/vision/AnnotateImagesSimulatorDoFn.java rename to src/main/java/com/google/solutions/annotation/AnnotateFilesSimulatorDoFn.java index 5a8c1a6..5fb2155 100644 --- a/src/main/java/com/google/solutions/ml/api/vision/AnnotateImagesSimulatorDoFn.java +++ b/src/main/java/com/google/solutions/annotation/AnnotateFilesSimulatorDoFn.java @@ -1,5 +1,5 @@ /* - * Copyright 2020 Google LLC + * Copyright 2022 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,27 +13,31 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.google.solutions.ml.api.vision; +package com.google.solutions.annotation; +import com.google.cloud.videointelligence.v1p3beta1.StreamingFeature; import com.google.cloud.vision.v1.AnnotateImageResponse; import com.google.cloud.vision.v1.AnnotateImageResponse.Builder; import com.google.cloud.vision.v1.Feature; +import com.google.protobuf.GeneratedMessageV3; +import com.google.solutions.annotation.gcs.GCSFileInfo; import java.util.List; import java.util.Random; import org.apache.beam.sdk.transforms.DoFn; import org.apache.beam.sdk.values.KV; /** - * Image annotation simulation class to test batching logic without incurring Vision API costs. + * Annotation simulation class to test batching logic without incurring API costs. * - * It simulates the delay of calling the API and produces a single annotation. + *

It simulates the delay of calling the API and produces a single annotation. */ -public class AnnotateImagesSimulatorDoFn extends - DoFn, KV> { +public class AnnotateFilesSimulatorDoFn + extends DoFn, KV> { private static final long serialVersionUID = 1L; - public AnnotateImagesSimulatorDoFn(List featureTypes) { + public AnnotateFilesSimulatorDoFn( + List imageFeatures, List videoFeatures) { /* * Feature types are ignored at the moment. But the simulation logic can be enhanced if needed to produce annotations * based on the requested features. @@ -41,27 +45,32 @@ public AnnotateImagesSimulatorDoFn(List featureTypes) { } @ProcessElement - public void processElement(@Element Iterable imageUris, - OutputReceiver> out) { - VisionAnalyticsPipeline.numberOfRequests.inc(); + public void processElement( + @Element Iterable fileInfos, + OutputReceiver> out) { + AnnotationPipeline.numberOfImageApiRequests.inc(); try { /** - * It creates a pattern similar to using the actual APIs with 16 requests per batch and two features requested. - * If more sophisticated simulation is needed - externalize the values or make these parameters - * dependent on batch size and the number of features requested. + * It creates a pattern similar to using the actual APIs with 16 requests per batch and two + * features requested. If more sophisticated simulation is needed - externalize the values or + * make these parameters dependent on batch size and the number of features requested. */ Thread.sleep(500 + (new Random().nextInt(1000))); } catch (InterruptedException e) { // Do nothing } - imageUris.forEach( - imageUri -> { + fileInfos.forEach( + fileInfo -> { Builder responseBuilder = AnnotateImageResponse.newBuilder(); - responseBuilder.addLabelAnnotationsBuilder(0).setDescription("Test").setScore(.5F) - .setTopicality(.6F).setMid("/m/test"); + responseBuilder + .addLabelAnnotationsBuilder(0) + .setDescription("Test") + .setScore(.5F) + .setTopicality(.6F) + .setMid("/m/test"); AnnotateImageResponse response = responseBuilder.build(); - out.output(KV.of(imageUri, response)); + out.output(KV.of(fileInfo, response)); }); } } diff --git a/src/main/java/com/google/solutions/annotation/AnnotationPipeline.java b/src/main/java/com/google/solutions/annotation/AnnotationPipeline.java new file mode 100644 index 0000000..31a6432 --- /dev/null +++ b/src/main/java/com/google/solutions/annotation/AnnotationPipeline.java @@ -0,0 +1,452 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.solutions.annotation; + +import com.google.api.services.bigquery.model.TableFieldSchema; +import com.google.api.services.bigquery.model.TableReference; +import com.google.api.services.bigquery.model.TableRow; +import com.google.api.services.bigquery.model.TableSchema; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Iterables; +import com.google.common.collect.Sets; +import com.google.protobuf.GeneratedMessageV3; +import com.google.solutions.annotation.bigquery.*; +import com.google.solutions.annotation.bigquery.BigQueryConstants.Mode; +import com.google.solutions.annotation.bigquery.BigQueryConstants.Type; +import com.google.solutions.annotation.gcs.GCSFileInfo; +import com.google.solutions.annotation.gcs.GCSUtils; +import com.google.solutions.annotation.ml.MLApiResponseProcessor; +import com.google.solutions.annotation.ml.ProcessMLApiResponseDoFn; +import com.google.solutions.annotation.ml.ProcessorUtils; +import com.google.solutions.annotation.ml.videointelligence.processors.VideoLabelAnnotationProcessor; +import com.google.solutions.annotation.ml.videointelligence.processors.VideoObjectTrackingAnnotationProcessor; +import com.google.solutions.annotation.ml.vision.processors.CropHintAnnotationProcessor; +import com.google.solutions.annotation.ml.vision.processors.ErrorProcessor; +import com.google.solutions.annotation.ml.vision.processors.FaceAnnotationProcessor; +import com.google.solutions.annotation.ml.vision.processors.ImagePropertiesProcessor; +import com.google.solutions.annotation.ml.vision.processors.LabelAnnotationProcessor; +import com.google.solutions.annotation.ml.vision.processors.LandmarkAnnotationProcessor; +import com.google.solutions.annotation.ml.vision.processors.LogoAnnotationProcessor; +import com.google.solutions.annotation.pubsub.PubSubNotificationToGCSInfoDoFn; +import com.google.solutions.annotation.pubsub.WriteRelevantAnnotationsToPubSubTransform; +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import org.apache.beam.sdk.Pipeline; +import org.apache.beam.sdk.PipelineResult; +import org.apache.beam.sdk.io.FileIO; +import org.apache.beam.sdk.io.fs.MatchResult.Metadata; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO; +import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.WriteDisposition; +import org.apache.beam.sdk.io.gcp.pubsub.PubsubIO; +import org.apache.beam.sdk.io.gcp.pubsub.PubsubMessage; +import org.apache.beam.sdk.metrics.Counter; +import org.apache.beam.sdk.metrics.Distribution; +import org.apache.beam.sdk.metrics.Metrics; +import org.apache.beam.sdk.options.PipelineOptionsFactory; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.Filter; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.transforms.SerializableFunction; +import org.apache.beam.sdk.transforms.windowing.AfterWatermark; +import org.apache.beam.sdk.transforms.windowing.BoundedWindow; +import org.apache.beam.sdk.transforms.windowing.FixedWindows; +import org.apache.beam.sdk.transforms.windowing.Window; +import org.apache.beam.sdk.values.*; +import org.joda.time.Duration; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class AnnotationPipeline { + + public static final Logger LOG = LoggerFactory.getLogger(AnnotationPipeline.class); + + public static final Counter totalFiles = Metrics.counter(AnnotationPipeline.class, "totalFiles"); + public static final Counter rejectedFiles = + Metrics.counter(AnnotationPipeline.class, "rejectedFiles"); + public static final Counter numberOfImageApiRequests = + Metrics.counter(AnnotationPipeline.class, "numberOfImageApiRequests"); + public static final Counter numberOfVideoApiRequests = + Metrics.counter(AnnotationPipeline.class, "numberOfVideoApiRequests"); + public static final Counter numberOfQuotaExceededRequests = + Metrics.counter(AnnotationPipeline.class, "numberOfQuotaExceededRequests"); + + public static final Distribution batchSizeDistribution = + Metrics.distribution(AnnotationPipeline.class, "batchSizeDistribution"); + + public static final Set SUPPORTED_IMAGE_CONTENT_TYPES = + ImmutableSet.of("image/jpeg", "image/png", "image/tiff", "image/tif", "image/gif"); + + public static final Set SUPPORTED_VIDEO_CONTENT_TYPES = + ImmutableSet.of("video/mov", "video/mpeg4", "video/mp4", "video/avi"); + + public static final Set SUPPORTED_CONTENT_TYPES = + Sets.union(SUPPORTED_IMAGE_CONTENT_TYPES, SUPPORTED_VIDEO_CONTENT_TYPES).immutableCopy(); + + public static final String ACCEPTED_IMAGE_FILE_PATTERN = "jpeg|jpg|png|gif|tiff|tif"; + + public static final String ACCEPTED_VIDEO_FILE_PATTERN = "mp4|mov|mpeg4|avi"; + + public static final String ACCEPTED_FILE_PATTERN = + ACCEPTED_IMAGE_FILE_PATTERN + "|" + ACCEPTED_VIDEO_FILE_PATTERN; + + private static final TupleTag> allRows = + new TupleTag>() {}; + private static final TupleTag> relevantRows = + new TupleTag>() {}; + + /** + * Main entry point for executing the pipeline. This will run the pipeline asynchronously. If + * blocking execution is required, use the {@link + * AnnotationPipeline#run(AnnotationPipelineOptions)} method to start the pipeline and invoke + * {@code result.waitUntilFinish()} on the {@link PipelineResult} + * + * @param args The command-line arguments to the pipeline. + */ + public static void main(String[] args) throws IOException { + + AnnotationPipelineOptions options = + PipelineOptionsFactory.fromArgs(args).withValidation().as(AnnotationPipelineOptions.class); + + run(options); + } + + /** + * Runs the pipeline + * + * @return result + */ + public static PipelineResult run(AnnotationPipelineOptions options) throws IOException { + Pipeline p = Pipeline.create(options); + + PCollection fileInfos; + if (options.getInputNotificationSubscription() != null) { + fileInfos = convertPubSubNotificationsToGCSFileInfos(p, options); + } else if (options.getFileList() != null) { + fileInfos = listGCSFiles(p, options); + } else { + throw new RuntimeException("Either the subscriber id or the file list should be provided."); + } + + PCollection> batchedFileInfos = + fileInfos.apply( + "Batch files", + BatchRequestsTransform.create(options.getBatchSize(), options.getKeyRange())); + + PCollection> annotatedFiles = + options.isSimulate() + ? batchedFileInfos.apply( + "Simulate Annotation", + ParDo.of( + new AnnotateFilesSimulatorDoFn( + options.getImageFeatures(), options.getVideoFeatures()))) + : batchedFileInfos.apply( + "Annotate files", + ParDo.of( + new AnnotateFilesDoFn(options.getImageFeatures(), options.getVideoFeatures()))); + + Map processors = configureProcessors(options); + + PCollectionTuple annotationOutcome = + annotatedFiles.apply( + "Process Annotations", + ParDo.of( + ProcessMLApiResponseDoFn.create( + ImmutableSet.copyOf(processors.values()), allRows, relevantRows)) + .withOutputTags(allRows, TupleTagList.of(relevantRows))); + + annotationOutcome + .get(allRows) + .apply( + "Write All Annotations To BigQuery", + new BigQueryDynamicWriteTransform( + BigQueryDynamicDestinations.builder() + .project(options.getProject()) + .datasetName(options.getDatasetName()) + .metadataKeys(options.getMetadataKeys()) + .tableNameToTableDetailsMap(tableNameToTableDetailsMap(processors)) + .build())); + + annotationOutcome + .get(relevantRows) + .apply( + WriteRelevantAnnotationsToPubSubTransform.newBuilder() + .setTopicId(options.getRelevantAnnotationOutputTopic()) + .build()); + + collectBatchStatistics(batchedFileInfos, options); + + return p.run(); + } + + /** + * Collect the statistics on batching the requests. The results are published to a metric. If + * {@link AnnotationPipelineOptions#isCollectBatchData()} is true the batch data is saved to + * BigQuery table "batch_info". + */ + static void collectBatchStatistics( + PCollection> batchedFileInfos, AnnotationPipelineOptions options) { + + PCollection batchInfo = + batchedFileInfos.apply( + "Collect Batch Stats", + ParDo.of( + new DoFn, TableRow>() { + private static final long serialVersionUID = 1L; + + @ProcessElement + public void processElement( + @Element Iterable fileInfos, + BoundedWindow window, + OutputReceiver out, + ProcessContext context) { + int size = Iterables.size(fileInfos); + batchSizeDistribution.update(size); + if (context + .getPipelineOptions() + .as(AnnotationPipelineOptions.class) + .isCollectBatchData()) { + TableRow row = new TableRow(); + row.put("window", window.toString()); + row.put("timestamp", ProcessorUtils.getTimeStamp()); + row.put("size", size); + List fileUris = new ArrayList<>(); + fileInfos.forEach( + (fileInfo) -> { + fileUris.add(fileInfo.getUri()); + }); + row.put("items", fileUris); + out.output(row); + } + } + })); + if (!options.isCollectBatchData()) { + return; + } + batchInfo.apply( + BigQueryIO.writeTableRows() + .to( + new TableReference() + .setProjectId(options.getProject()) + .setDatasetId(options.getDatasetName()) + .setTableId("batch_info")) + .withWriteDisposition(WriteDisposition.WRITE_APPEND) + .withoutValidation() + .withClustering() + .ignoreInsertIds() + .withSchema( + new TableSchema() + .setFields( + ImmutableList.of( + new TableFieldSchema().setName("window").setType(Type.STRING), + new TableFieldSchema().setName("timestamp").setType(Type.TIMESTAMP), + new TableFieldSchema().setName("size").setType(Type.NUMERIC), + new TableFieldSchema() + .setName("items") + .setType(Type.STRING) + .setMode(Mode.REPEATED)))) + .withCreateDisposition(BigQueryIO.Write.CreateDisposition.CREATE_IF_NEEDED)); + } + + /** + * Create a map of the table details. Each processor will produce TableRows destined + * to a different table. Each processor will provide the details about that table. + * + * @return map of table details keyed by table name + */ + static Map tableNameToTableDetailsMap( + Map processors) { + Map tableNameToTableDetailsMap = new HashMap<>(); + processors.forEach( + (tableName, processor) -> + tableNameToTableDetailsMap.put(tableName, processor.destinationTableDetails())); + return tableNameToTableDetailsMap; + } + + /** + * Reads PubSub messages from the subscription provided by {@link + * AnnotationPipelineOptions#getInputNotificationSubscription()}. + * + *

The messages are expected to confirm to the GCS notification message format defined in + * https://cloud.google.com/storage/docs/pubsub-notifications + * + *

Notifications are filtered to have one of the supported content types: {@link + * AnnotationPipeline#SUPPORTED_CONTENT_TYPES}. + * + * @return PCollection of GCS URIs + */ + static PCollection convertPubSubNotificationsToGCSFileInfos( + Pipeline p, AnnotationPipelineOptions options) { + PCollection gcsFileInfos; + PCollection pubSubNotifications = + p.begin() + .apply( + "Read PubSub", + PubsubIO.readMessagesWithAttributes() + .fromSubscription(options.getInputNotificationSubscription())); + gcsFileInfos = + pubSubNotifications + .apply( + "PubSub to GCS URIs", + ParDo.of(PubSubNotificationToGCSInfoDoFn.create(SUPPORTED_CONTENT_TYPES))) + .apply( + "Fixed Window", + Window.into( + FixedWindows.of(Duration.standardSeconds(options.getWindowInterval()))) + .triggering(AfterWatermark.pastEndOfWindow()) + .discardingFiredPanes() + .withAllowedLateness(Duration.standardMinutes(15))); + return gcsFileInfos; + } + + /** + * Reads the GCS objects provided by {@link AnnotationPipelineOptions#getFileList()}. + * + *

The file list can contain multiple entries. Each entry can contain wildcards supported by + * {@link FileIO#matchAll()}. + * + *

Files are filtered based on their suffixes as defined in {@link + * AnnotationPipeline#ACCEPTED_FILE_PATTERN}. + * + * @return PCollection of GCS URIs + */ + static PCollection listGCSFiles(Pipeline p, AnnotationPipelineOptions options) { + PCollection fileInfos; + PCollection allFiles = + p.begin() + .apply("Get File List", Create.of(options.getFileList())) + .apply("Match GCS Files", FileIO.matchAll()); + fileInfos = + allFiles + .apply( + ParDo.of( + new DoFn() { + private static final long serialVersionUID = 1L; + + @ProcessElement + public void processElement( + @Element Metadata metadata, OutputReceiver out) { + out.output(GCSUtils.getFileInfo(metadata.resourceId().toString())); + } + })) + .apply( + "Filter out non-image files", + Filter.by( + (SerializableFunction) + fileName -> { + totalFiles.inc(); + if (fileName + .getUri() + .matches("(^(?i).*\\.(" + ACCEPTED_FILE_PATTERN + ")$)")) { + return true; + } + LOG.warn("File {} does not contain a valid extension", fileName); + rejectedFiles.inc(); + return false; + })); + return fileInfos; + } + + /** + * Creates a map of well-known {@link MLApiResponseProcessor}s. + * + *

If additional processors are needed they should be configured in this method. + */ + private static Map configureProcessors( + AnnotationPipelineOptions options) { + Map result = new HashMap<>(); + + // Image processors + // ------------------------------------------------------------------------------ + + String tableName = options.getImageLabelAnnotationTable(); + result.put( + tableName, + new LabelAnnotationProcessor( + tableName, + options.getMetadataKeys(), + options.getRelevantImageLabels(), + options.getImageLabelAnnotationScoreThreshold())); + + tableName = options.getImageLandmarkAnnotationTable(); + result.put( + tableName, + new LandmarkAnnotationProcessor( + tableName, + options.getMetadataKeys(), + options.getRelevantImageLandmarks(), + options.getImageLandmarkAnnotationScoreThreshold())); + + tableName = options.getImageLogoAnnotationTable(); + result.put( + tableName, + new LogoAnnotationProcessor( + tableName, + options.getMetadataKeys(), + options.getRelevantLogos(), + options.getLogoAnnotationScoreThreshold())); + + tableName = options.getImageFaceAnnotationTable(); + result.put( + tableName, + new FaceAnnotationProcessor( + tableName, + options.getMetadataKeys(), + options.getFaceAnnotationDetectionConfidenceThreshold())); + + tableName = options.getImagePropertiesTable(); + result.put(tableName, new ImagePropertiesProcessor(tableName, options.getMetadataKeys())); + + tableName = options.getImageCropHintAnnotationTable(); + result.put( + tableName, + new CropHintAnnotationProcessor( + tableName, + options.getMetadataKeys(), + options.getImageCropAnnotationConfidenceThreshold())); + + tableName = options.getErrorLogTable(); + result.put(tableName, new ErrorProcessor(tableName, options.getMetadataKeys())); + + // Video processors + // ------------------------------------------------------------------------------ + + tableName = options.getVideoObjectTrackingAnnotationTable(); + result.put( + tableName, + new VideoObjectTrackingAnnotationProcessor( + tableName, + options.getMetadataKeys(), + options.getRelevantObjectTrackingEntities(), + options.getObjectTrackingConfidenceThreshold())); + + tableName = options.getVideoLabelAnnotationTable(); + result.put( + tableName, + new VideoLabelAnnotationProcessor( + tableName, + options.getMetadataKeys(), + options.getRelevantVideoLabelEntities(), + options.getVideoLabelConfidenceThreshold())); + + return result; + } +} diff --git a/src/main/java/com/google/solutions/annotation/AnnotationPipelineOptions.java b/src/main/java/com/google/solutions/annotation/AnnotationPipelineOptions.java new file mode 100644 index 0000000..2aa165a --- /dev/null +++ b/src/main/java/com/google/solutions/annotation/AnnotationPipelineOptions.java @@ -0,0 +1,235 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.solutions.annotation; + +import com.google.cloud.videointelligence.v1p3beta1.StreamingFeature; +import com.google.cloud.vision.v1.Feature; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import org.apache.beam.runners.dataflow.options.DataflowPipelineOptions; +import org.apache.beam.sdk.options.*; + +/** Interface to store pipeline options provided by the user */ +public interface AnnotationPipelineOptions extends DataflowPipelineOptions { + + @Description("Pub/Sub subscription ID to receive input Cloud Storage notifications from") + String getInputNotificationSubscription(); + + void setInputNotificationSubscription(String value); + + @Description("Google Cloud Storage files to process") + List getFileList(); + + void setFileList(List value); + + @Description("Key range") + @Default.Integer(1) + Integer getKeyRange(); + + void setKeyRange(Integer value); + + @Description("Image annotation request batch size") + @Default.Integer(1) + Integer getBatchSize(); + + void setBatchSize(Integer value); + + @Description("Window interval in seconds (default is 5)") + @Default.Integer(5) + Integer getWindowInterval(); + + void setWindowInterval(Integer value); + + @Description("BigQuery dataset") + @Validation.Required + String getDatasetName(); + + void setDatasetName(String value); + + @Description("Vision API features to use") + List getImageFeatures(); + + void setImageFeatures(List value); + + @Description("Streaming video features") + List getVideoFeatures(); + + void setVideoFeatures(List value); + + @Description("Simulate annotations") + @Default.Boolean(false) + boolean isSimulate(); + + void setSimulate(boolean value); + + @Description("Collect batch data") + @Default.Boolean(false) + boolean isCollectBatchData(); + + void setCollectBatchData(boolean value); + + @Description("Table name for image label annotations") + @Default.String("image_label_annotation") + String getImageLabelAnnotationTable(); + + void setImageLabelAnnotationTable(String value); + + @Description("Table name for image landmark annotations") + @Default.String("image_landmark_annotation") + String getImageLandmarkAnnotationTable(); + + void setImageLandmarkAnnotationTable(String value); + + @Description("Table name for image logo annotations") + @Default.String("image_logo_annotation") + String getImageLogoAnnotationTable(); + + void setImageLogoAnnotationTable(String value); + + @Description("Table name for image face annotations") + @Default.String("image_face_annotation") + String getImageFaceAnnotationTable(); + + void setImageFaceAnnotationTable(String value); + + @Description("Table name for image properties") + @Default.String("image_properties") + String getImagePropertiesTable(); + + void setImagePropertiesTable(String value); + + @Description("Table name for image crop hint annotations") + @Default.String("image_crop_hint_annotation") + String getImageCropHintAnnotationTable(); + + void setImageCropHintAnnotationTable(String value); + + @Description("Table name for video object tracking annotations") + @Default.String("video_object_tracking_annotation") + String getVideoObjectTrackingAnnotationTable(); + + void setVideoObjectTrackingAnnotationTable(String value); + + @Description("Table name for video label annotations") + @Default.String("video_label_annotation") + String getVideoLabelAnnotationTable(); + + void setVideoLabelAnnotationTable(String value); + + @Description("Table name for error logs") + @Default.String("error_log") + String getErrorLogTable(); + + void setErrorLogTable(String value); + + class EmptySet implements DefaultValueFactory> { + @Override + public Set create(PipelineOptions options) { + return new HashSet<>(); + } + } + + @Description("GCS metadata values to store in BigQuery") + @Default.InstanceFactory(EmptySet.class) + Set getMetadataKeys(); + + void setMetadataKeys(Set value); + + @Description( + "Minimum score level (value between 0 and 1) that the image label annotations must meet to be considered significant and be published the output Pub/Sub topic") + @Default.Float(0.8f) + Float getImageLabelAnnotationScoreThreshold(); + + void setImageLabelAnnotationScoreThreshold(Float value); + + @Description( + "Comma-separated list of image labels. Labels annotations must contain at least one of those values to be considered significant and be published the output Pub/Sub topic") + Set getRelevantImageLabels(); + + void setRelevantImageLabels(Set value); + + @Description( + "Minimum score level (value between 0 and 1) that the image landmark annotations must meet to be considered significant and be published the output Pub/Sub topic") + @Default.Float(0.8f) + Float getImageLandmarkAnnotationScoreThreshold(); + + void setImageLandmarkAnnotationScoreThreshold(Float value); + + @Description( + "Comma-separated list of landmarks. Landmark annotations must contain at least one of those values to be considered significant and be published the output Pub/Sub topic") + Set getRelevantImageLandmarks(); + + void setRelevantImageLandmarks(Set value); + + @Description( + "Minimum score level (value between 0 and 1) that the logo annotations must meet to be considered significant and be published the output Pub/Sub topic") + @Default.Float(0.8f) + Float getLogoAnnotationScoreThreshold(); + + void setLogoAnnotationScoreThreshold(Float value); + + @Description( + "Comma-separated list of logos. Logo annotations must contain at least one of those values to be considered significant and be published the output Pub/Sub topic") + Set getRelevantLogos(); + + void setRelevantLogos(Set value); + + @Description( + "Minimum detection confidence level (value between 0 and 1) that the image face annotations must meet to be considered significant and be published the output Pub/Sub topic") + @Default.Float(0.8f) + Float getFaceAnnotationDetectionConfidenceThreshold(); + + void setFaceAnnotationDetectionConfidenceThreshold(Float value); + + @Description( + "Minimum confidence level (value between 0 and 1) that the image crop annotations must meet to be considered significant and be published the output Pub/Sub topic") + @Default.Float(0.8f) + Float getImageCropAnnotationConfidenceThreshold(); + + void setImageCropAnnotationConfidenceThreshold(Float value); + + @Description( + "Minimum confidence level (value between 0 and 1) that the video object tracking annotations must meet to be considered significant and be published the output Pub/Sub topic") + @Default.Float(0.8f) + Float getObjectTrackingConfidenceThreshold(); + + void setObjectTrackingConfidenceThreshold(Float value); + + @Description( + "Comma-separated list of object tracking entities. Video object tracking annotations must contain at least one of those values to be considered significant and be published the output Pub/Sub topic") + Set getRelevantObjectTrackingEntities(); + + void setRelevantObjectTrackingEntities(Set value); + + @Description( + "Minimum confidence level (value between 0 and 1) that the video label annotations must meet to be considered significant and be published the output Pub/Sub topic") + @Default.Float(0.8f) + Float getVideoLabelConfidenceThreshold(); + + void setVideoLabelConfidenceThreshold(Float value); + + @Description( + "Comma-separated list of video label entities. Video label annotations must contain at least one of those values to be considered significant and be published the output Pub/Sub topic") + Set getRelevantVideoLabelEntities(); + + void setRelevantVideoLabelEntities(Set value); + + @Description("Pub/Sub topic ID to publish the results to") + String getRelevantAnnotationOutputTopic(); + + void setRelevantAnnotationOutputTopic(String value); +} diff --git a/src/main/java/com/google/solutions/ml/api/vision/BatchRequestsTransform.java b/src/main/java/com/google/solutions/annotation/BatchRequestsTransform.java similarity index 59% rename from src/main/java/com/google/solutions/ml/api/vision/BatchRequestsTransform.java rename to src/main/java/com/google/solutions/annotation/BatchRequestsTransform.java index b210209..4801692 100644 --- a/src/main/java/com/google/solutions/ml/api/vision/BatchRequestsTransform.java +++ b/src/main/java/com/google/solutions/annotation/BatchRequestsTransform.java @@ -1,5 +1,5 @@ /* - * Copyright 2020 Google LLC + * Copyright 2022 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,10 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - -package com.google.solutions.ml.api.vision; +package com.google.solutions.annotation; import com.google.auto.value.AutoValue; +import com.google.solutions.annotation.gcs.GCSFileInfo; import java.util.Collections; import java.util.Random; import org.apache.beam.sdk.transforms.DoFn; @@ -29,12 +29,12 @@ import org.apache.beam.sdk.values.PCollection; /** - * Groups the requests into certain size batches. See {@link GroupIntoBatches} for effects - * of windowing on the output of this transform. + * Groups the requests into certain size batches. See {@link GroupIntoBatches} for effects of + * windowing on the output of this transform. */ @AutoValue -public abstract class BatchRequestsTransform extends - PTransform, PCollection>> { +public abstract class BatchRequestsTransform + extends PTransform, PCollection>> { private static final long serialVersionUID = 1L; @@ -42,43 +42,46 @@ public abstract class BatchRequestsTransform extends public abstract int getKeyRange(); - /** * @param batchSize should be between 1 and 16 * @param keyRange determines the level of parallelism. Should be a positive non-zero integer. * @return a new transform */ public static BatchRequestsTransform create(long batchSize, int keyRange) { - return builder() - .setBatchSize(batchSize) - .setKeyRange(keyRange) - .build(); + return builder().setBatchSize(batchSize).setKeyRange(keyRange).build(); } @Override - public PCollection> expand(PCollection input) { + public PCollection> expand(PCollection input) { if (getBatchSize() > 1) { return input - .apply("Assign Keys", WithKeys.of(new SerializableFunction() { - private static final long serialVersionUID = 1L; - private Random random = new Random(); - - @Override - public Integer apply(String input) { - return random.nextInt(getKeyRange()); - } - })) + .apply( + "Assign Keys", + WithKeys.of( + new SerializableFunction() { + private static final long serialVersionUID = 1L; + private final Random random = new Random(); + + @Override + public Integer apply(GCSFileInfo input) { + return random.nextInt(getKeyRange()); + } + })) .apply("Group Into Batches", GroupIntoBatches.ofSize(getBatchSize())) .apply("Convert to Batches", Values.create()); } else { - return input.apply("Convert to Iterable", ParDo.of(new DoFn>() { - private final static long serialVersionUID = 1L; - - @ProcessElement - public void process(@Element String element, OutputReceiver> out) { - out.output(Collections.singleton(element)); - } - })); + return input.apply( + "Convert to Iterable", + ParDo.of( + new DoFn>() { + private static final long serialVersionUID = 1L; + + @ProcessElement + public void process( + @Element GCSFileInfo element, OutputReceiver> out) { + out.output(Collections.singleton(element)); + } + })); } } diff --git a/src/main/java/com/google/solutions/ml/api/vision/BigQueryConstants.java b/src/main/java/com/google/solutions/annotation/bigquery/BigQueryConstants.java similarity index 92% rename from src/main/java/com/google/solutions/ml/api/vision/BigQueryConstants.java rename to src/main/java/com/google/solutions/annotation/bigquery/BigQueryConstants.java index 9e19d52..1606a21 100644 --- a/src/main/java/com/google/solutions/ml/api/vision/BigQueryConstants.java +++ b/src/main/java/com/google/solutions/annotation/bigquery/BigQueryConstants.java @@ -1,5 +1,5 @@ /* - * Copyright 2020 Google LLC + * Copyright 2022 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,8 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - -package com.google.solutions.ml.api.vision; +package com.google.solutions.annotation.bigquery; public interface BigQueryConstants { diff --git a/src/main/java/com/google/solutions/ml/api/vision/BQDestination.java b/src/main/java/com/google/solutions/annotation/bigquery/BigQueryDestination.java similarity index 82% rename from src/main/java/com/google/solutions/ml/api/vision/BQDestination.java rename to src/main/java/com/google/solutions/annotation/bigquery/BigQueryDestination.java index a749546..0c958ae 100644 --- a/src/main/java/com/google/solutions/ml/api/vision/BQDestination.java +++ b/src/main/java/com/google/solutions/annotation/bigquery/BigQueryDestination.java @@ -1,5 +1,5 @@ /* - * Copyright 2020 Google LLC + * Copyright 2022 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,8 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - -package com.google.solutions.ml.api.vision; +package com.google.solutions.annotation.bigquery; import java.io.Serializable; import java.util.Objects; @@ -22,17 +21,17 @@ import org.apache.beam.sdk.coders.DefaultCoder; @DefaultCoder(AvroCoder.class) -public class BQDestination implements Serializable { +public class BigQueryDestination implements Serializable { public static final long serialVersionUID = 1L; private String tableName; - BQDestination() { + BigQueryDestination() { // Needed for AvroCoder } - public BQDestination(String tableName) { + public BigQueryDestination(String tableName) { this.tableName = tableName; } @@ -48,7 +47,7 @@ public boolean equals(Object o) { if (o == null || getClass() != o.getClass()) { return false; } - BQDestination that = (BQDestination) o; + BigQueryDestination that = (BigQueryDestination) o; return Objects.equals(tableName, that.tableName); } diff --git a/src/main/java/com/google/solutions/ml/api/vision/BQDynamicDestinations.java b/src/main/java/com/google/solutions/annotation/bigquery/BigQueryDynamicDestinations.java similarity index 53% rename from src/main/java/com/google/solutions/ml/api/vision/BQDynamicDestinations.java rename to src/main/java/com/google/solutions/annotation/bigquery/BigQueryDynamicDestinations.java index fc726bd..9e0ed02 100644 --- a/src/main/java/com/google/solutions/ml/api/vision/BQDynamicDestinations.java +++ b/src/main/java/com/google/solutions/annotation/bigquery/BigQueryDynamicDestinations.java @@ -1,5 +1,5 @@ /* - * Copyright 2020 Google LLC + * Copyright 2022 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,48 +13,50 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +package com.google.solutions.annotation.bigquery; -package com.google.solutions.ml.api.vision; - +import com.google.api.services.bigquery.model.TableFieldSchema; import com.google.api.services.bigquery.model.TableReference; import com.google.api.services.bigquery.model.TableRow; import com.google.api.services.bigquery.model.TableSchema; import com.google.auto.value.AutoValue; +import com.google.solutions.annotation.ml.Constants; import java.util.Map; import java.util.Objects; +import java.util.Set; import org.apache.beam.sdk.io.gcp.bigquery.DynamicDestinations; import org.apache.beam.sdk.io.gcp.bigquery.TableDestination; import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.ValueInSingleWindow; -/** - * Provides details of the target table. - */ +/** Provides details of the target table. */ @AutoValue -abstract public class BQDynamicDestinations extends - DynamicDestinations, BQDestination> { +public abstract class BigQueryDynamicDestinations + extends DynamicDestinations, BigQueryDestination> { + + private static final long serialVersionUID = 1L; - private final static long serialVersionUID = 1L; + abstract String project(); - abstract String projectId(); + abstract String datasetName(); - abstract String datasetId(); + abstract Set metadataKeys(); abstract Map tableNameToTableDetailsMap(); @Override - public BQDestination getDestination( - ValueInSingleWindow> element) { + public BigQueryDestination getDestination( + ValueInSingleWindow> element) { return Objects.requireNonNull(element.getValue()).getKey(); } @Override - public TableDestination getTable(BQDestination destination) { + public TableDestination getTable(BigQueryDestination destination) { TableDetails tableDetails = tableDetails(destination); return new TableDestination( new TableReference() - .setProjectId(projectId()) - .setDatasetId(datasetId()) + .setProjectId(project()) + .setDatasetId(datasetName()) .setTableId(destination.getTableName()), tableDetails.description(), tableDetails.timePartitioningJson(), @@ -62,11 +64,23 @@ public TableDestination getTable(BQDestination destination) { } @Override - public TableSchema getSchema(BQDestination destination) { - return tableDetails(destination).schemaProducer().getTableSchema(); + public TableSchema getSchema(BigQueryDestination destination) { + TableSchema schema = tableDetails(destination).schemaProducer().getTableSchema(); + + // Add metadata fields, if any. + TableFieldSchema metadataSchema = + new TableFieldSchema() + .setName(Constants.Field.METADATA) + .setType("RECORD") + .setMode("REPEATED"); + for (String key : metadataKeys()) { + metadataSchema.set(key, new TableFieldSchema().setName(key).setType("STRING")); + } + schema.set("metadata", metadataSchema); + return schema; } - private TableDetails tableDetails(BQDestination destination) { + private TableDetails tableDetails(BigQueryDestination destination) { TableDetails result = tableNameToTableDetailsMap().get(destination.getTableName()); if (result == null) { throw new RuntimeException("Unable to find schema for table " + destination.getTableName()); @@ -75,19 +89,21 @@ private TableDetails tableDetails(BQDestination destination) { } public static Builder builder() { - return new AutoValue_BQDynamicDestinations.Builder(); + return new AutoValue_BigQueryDynamicDestinations.Builder(); } @AutoValue.Builder public abstract static class Builder { - public abstract Builder projectId(String projectId); + public abstract Builder project(String projectId); + + public abstract Builder datasetName(String datasetId); - public abstract Builder datasetId(String datasetId); + public abstract Builder metadataKeys(Set metadataKeys); public abstract Builder tableNameToTableDetailsMap( Map tableNameToTableDetailsMap); - public abstract BQDynamicDestinations build(); + public abstract BigQueryDynamicDestinations build(); } } diff --git a/src/main/java/com/google/solutions/ml/api/vision/BigQueryDynamicWriteTransform.java b/src/main/java/com/google/solutions/annotation/bigquery/BigQueryDynamicWriteTransform.java similarity index 72% rename from src/main/java/com/google/solutions/ml/api/vision/BigQueryDynamicWriteTransform.java rename to src/main/java/com/google/solutions/annotation/bigquery/BigQueryDynamicWriteTransform.java index 85ebd43..bba59c1 100644 --- a/src/main/java/com/google/solutions/ml/api/vision/BigQueryDynamicWriteTransform.java +++ b/src/main/java/com/google/solutions/annotation/bigquery/BigQueryDynamicWriteTransform.java @@ -1,5 +1,5 @@ /* - * Copyright 2020 Google LLC + * Copyright 2022 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,7 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package com.google.solutions.ml.api.vision; +package com.google.solutions.annotation.bigquery; import com.google.api.services.bigquery.model.TableRow; import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO; @@ -24,26 +24,25 @@ import org.apache.beam.sdk.values.KV; import org.apache.beam.sdk.values.PCollection; -/** - * Writes TableRows to a {@link DynamicDestinations}. - */ +/** Writes TableRows to a {@link DynamicDestinations}. */ public class BigQueryDynamicWriteTransform - extends PTransform>, WriteResult> { + extends PTransform>, WriteResult> { private static final long serialVersionUID = 1L; - private final DynamicDestinations, BQDestination> destinations; + private final DynamicDestinations, BigQueryDestination> + destinations; public BigQueryDynamicWriteTransform( - DynamicDestinations, BQDestination> destinations) { + DynamicDestinations, BigQueryDestination> destinations) { this.destinations = destinations; } @Override - public WriteResult expand(PCollection> input) { + public WriteResult expand(PCollection> input) { return input.apply( - "BQ Write", - BigQueryIO.>write() + "BigQuery Write", + BigQueryIO.>write() .to(destinations) .withFormatFunction(KV::getValue) .withWriteDisposition(WriteDisposition.WRITE_APPEND) diff --git a/src/main/java/com/google/solutions/ml/api/vision/TableDetails.java b/src/main/java/com/google/solutions/annotation/bigquery/TableDetails.java similarity index 88% rename from src/main/java/com/google/solutions/ml/api/vision/TableDetails.java rename to src/main/java/com/google/solutions/annotation/bigquery/TableDetails.java index 797fb24..b80b18b 100644 --- a/src/main/java/com/google/solutions/ml/api/vision/TableDetails.java +++ b/src/main/java/com/google/solutions/annotation/bigquery/TableDetails.java @@ -1,5 +1,5 @@ /* - * Copyright 2020 Google LLC + * Copyright 2022 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,8 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - -package com.google.solutions.ml.api.vision; +package com.google.solutions.annotation.bigquery; import com.google.api.services.bigquery.model.Clustering; import com.google.api.services.bigquery.model.TimePartitioning; @@ -35,8 +34,11 @@ public abstract class TableDetails implements Serializable { public abstract TableSchemaProducer schemaProducer(); - public static TableDetails create(String description, Clustering clustering, - TimePartitioning timePartitioning, TableSchemaProducer schemaProducer) { + public static TableDetails create( + String description, + Clustering clustering, + TimePartitioning timePartitioning, + TableSchemaProducer schemaProducer) { return builder() .description(description) .clusteringJson(clustering == null ? null : BigQueryHelpers.toJsonString(clustering)) @@ -63,5 +65,4 @@ public abstract static class Builder { public abstract TableDetails build(); } - } diff --git a/src/main/java/com/google/solutions/ml/api/vision/TableSchemaProducer.java b/src/main/java/com/google/solutions/annotation/bigquery/TableSchemaProducer.java similarity index 74% rename from src/main/java/com/google/solutions/ml/api/vision/TableSchemaProducer.java rename to src/main/java/com/google/solutions/annotation/bigquery/TableSchemaProducer.java index 02634ce..2ba7722 100644 --- a/src/main/java/com/google/solutions/ml/api/vision/TableSchemaProducer.java +++ b/src/main/java/com/google/solutions/annotation/bigquery/TableSchemaProducer.java @@ -1,5 +1,5 @@ /* - * Copyright 2020 Google LLC + * Copyright 2022 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,15 +13,15 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - -package com.google.solutions.ml.api.vision; +package com.google.solutions.annotation.bigquery; import com.google.api.services.bigquery.model.TableSchema; import java.io.Serializable; /** - * Interface used in combination with {@link org.apache.beam.sdk.io.gcp.bigquery.DynamicDestinations} - * implementations. See {@link BQDynamicDestinations} for details. + * Interface used in combination with {@link + * org.apache.beam.sdk.io.gcp.bigquery.DynamicDestinations} implementations. See {@link + * BigQueryDynamicDestinations} for details. */ public interface TableSchemaProducer extends Serializable { TableSchema getTableSchema(); diff --git a/src/main/java/com/google/solutions/annotation/gcs/GCSFileInfo.java b/src/main/java/com/google/solutions/annotation/gcs/GCSFileInfo.java new file mode 100644 index 0000000..ad19186 --- /dev/null +++ b/src/main/java/com/google/solutions/annotation/gcs/GCSFileInfo.java @@ -0,0 +1,73 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.solutions.annotation.gcs; + +import java.io.Serializable; +import java.util.Map; +import java.util.Objects; +import org.apache.beam.sdk.coders.AvroCoder; +import org.apache.beam.sdk.coders.DefaultCoder; + +@DefaultCoder(AvroCoder.class) +public class GCSFileInfo implements Serializable { + + public static final long serialVersionUID = 1L; + + private String uri; + private String contentType; + private Map metadata; + + GCSFileInfo() { + // Needed for AvroCoder + } + + public GCSFileInfo(String uri, String contentType, Map metadata) { + this.uri = uri; + this.contentType = contentType; + this.metadata = metadata; + } + + public String getUri() { + return uri; + } + + public String getContentType() { + return contentType; + } + + public Map getMetadata() { + return metadata; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + GCSFileInfo that = (GCSFileInfo) o; + return Objects.equals(uri, that.uri) + && Objects.equals(contentType, that.contentType) + && Objects.equals(metadata, that.metadata); + } + + @Override + public int hashCode() { + return Objects.hash(uri, contentType); + } +} diff --git a/src/main/java/com/google/solutions/annotation/gcs/GCSUtils.java b/src/main/java/com/google/solutions/annotation/gcs/GCSUtils.java new file mode 100644 index 0000000..489d615 --- /dev/null +++ b/src/main/java/com/google/solutions/annotation/gcs/GCSUtils.java @@ -0,0 +1,98 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.solutions.annotation.gcs; + +import static com.google.protobuf.ByteString.*; + +import com.google.api.client.googleapis.javanet.GoogleNetHttpTransport; +import com.google.api.client.http.HttpRequestInitializer; +import com.google.api.client.http.HttpTransport; +import com.google.api.client.json.jackson2.JacksonFactory; +import com.google.api.services.storage.Storage; +import com.google.api.services.storage.StorageScopes; +import com.google.api.services.storage.model.StorageObject; +import com.google.auth.http.HttpCredentialsAdapter; +import com.google.auth.oauth2.GoogleCredentials; +import com.google.protobuf.ByteString; +import java.io.IOException; +import java.net.URI; +import java.net.URISyntaxException; +import com.google.api.client.json.gson.GsonFactory; + +public class GCSUtils { + + private static final String APPLICATION_NAME = "my-app-name"; // FIXME + private static Storage storageService; + + private static URI getURI(String gcsURI) { + try { + return new URI(gcsURI); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + } + + public static GCSFileInfo getFileInfo(String gcsURI) { + URI uri = getURI(gcsURI); + Storage storageClient = getStorageClient(); + Storage.Objects.Get getObject; + StorageObject object; + try { + getObject = storageClient.objects().get(uri.getHost(), uri.getPath().substring(1)); + object = getObject.execute(); + } catch (IOException e) { + throw new RuntimeException(e); + } + + return new GCSFileInfo(gcsURI, object.getContentType(), object.getMetadata()); + } + + public static ByteString getBytes(String gcsURI) { + URI uri = getURI(gcsURI); + Storage storageClient = getStorageClient(); + Storage.Objects.Get getObject; + try { + getObject = storageClient.objects().get(uri.getHost(), uri.getPath().substring(1)); + getObject.getMediaHttpDownloader().setDirectDownloadEnabled(true); + Output output = newOutput(); + getObject.executeMediaAndDownloadTo(output); + return output.toByteString(); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + private static Storage getStorageClient() { + if (null == storageService) { + HttpTransport httpTransport; + GoogleCredentials credentials; + try { + httpTransport = GoogleNetHttpTransport.newTrustedTransport(); + credentials = + GoogleCredentials.getApplicationDefault().createScoped(StorageScopes.CLOUD_PLATFORM); + } catch (Exception e) { + throw new RuntimeException(e); + } + HttpRequestInitializer requestInitializer = new HttpCredentialsAdapter(credentials); + GsonFactory gsonFactory = GsonFactory.getDefaultInstance(); + storageService = + new Storage.Builder(httpTransport, gsonFactory, requestInitializer) + .setApplicationName(APPLICATION_NAME) + .build(); + } + return storageService; + } +} diff --git a/src/main/java/com/google/solutions/annotation/ml/BackOffUtils.java b/src/main/java/com/google/solutions/annotation/ml/BackOffUtils.java new file mode 100644 index 0000000..aba9e6b --- /dev/null +++ b/src/main/java/com/google/solutions/annotation/ml/BackOffUtils.java @@ -0,0 +1,66 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.solutions.annotation.ml; + +import com.google.api.client.util.ExponentialBackOff; +import com.google.api.gax.rpc.ResourceExhaustedException; +import com.google.solutions.annotation.AnnotationPipeline; +import java.io.IOException; +import java.util.concurrent.TimeUnit; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class BackOffUtils { + + public static final Logger LOG = LoggerFactory.getLogger(BackOffUtils.class); + + public static ExponentialBackOff createBackOff() { + return new ExponentialBackOff.Builder() + .setInitialIntervalMillis(10 * 1000 /* 10 seconds */) + .setMaxElapsedTimeMillis(10 * 60 * 1000 /* 10 minutes */) + .setMaxIntervalMillis(90 * 1000 /* 90 seconds */) + .setMultiplier(1.5) + .setRandomizationFactor(0.5) + .build(); + } + + /** + * Attempts to backoff unless reaches the max elapsed time. + * + * @param backoff + * @param e + */ + public static void handleQuotaReachedException( + ExponentialBackOff backoff, ResourceExhaustedException e) { + AnnotationPipeline.numberOfQuotaExceededRequests.inc(); + long waitInMillis = 0; + try { + waitInMillis = backoff.nextBackOffMillis(); + } catch (IOException ioException) { + // Will not occur with this implementation of Backoff. + } + if (waitInMillis == ExponentialBackOff.STOP) { + LOG.warn("Reached the limit of backoff retries. Throwing the exception to the pipeline"); + throw e; + } + LOG.info("Received {}. Will retry in {} seconds.", e.getClass().getName(), waitInMillis / 1000); + try { + TimeUnit.MILLISECONDS.sleep(waitInMillis); + } catch (InterruptedException interruptedException) { + // Do nothing + } + } +} diff --git a/src/main/java/com/google/solutions/ml/api/vision/processor/Constants.java b/src/main/java/com/google/solutions/annotation/ml/Constants.java similarity index 50% rename from src/main/java/com/google/solutions/ml/api/vision/processor/Constants.java rename to src/main/java/com/google/solutions/annotation/ml/Constants.java index 5c377ad..78df771 100644 --- a/src/main/java/com/google/solutions/ml/api/vision/processor/Constants.java +++ b/src/main/java/com/google/solutions/annotation/ml/Constants.java @@ -1,5 +1,5 @@ /* - * Copyright 2020 Google LLC + * Copyright 2022 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,22 +13,19 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - -package com.google.solutions.ml.api.vision.processor; +package com.google.solutions.annotation.ml; import com.google.api.services.bigquery.model.TableFieldSchema; -import com.google.solutions.ml.api.vision.BigQueryConstants.Mode; -import com.google.solutions.ml.api.vision.BigQueryConstants.Type; +import com.google.solutions.annotation.bigquery.BigQueryConstants; import java.util.Arrays; import java.util.Collections; import java.util.List; -/** - * Helper interface for common image annotation response processor constants - */ -interface Constants { +/** Helper interface for common annotation response processor constants */ +public interface Constants { interface Field { + String METADATA = "metadata"; String BOUNDING_POLY = "bounding_poly"; String FD_BOUNDING_POLY = "fd_bounding_poly"; String LOCATIONS = "locations"; @@ -41,7 +38,7 @@ interface Field { String JOY_LIKELIHOOD = "joy_likelihood"; String SORROW_LIKELIHOOD = "sorrow_likelihood"; String ANGER_LIKELIHOOD = "anger_likelihood"; - String SURPISE_LIKELIHOOD = "surprise_likelihood"; + String SURPRISE_LIKELIHOOD = "surprise_likelihood"; String GCS_URI_FIELD = "gcs_uri"; String TIMESTAMP_FIELD = "transaction_timestamp"; String MID_FIELD = "mid"; @@ -65,21 +62,51 @@ interface Field { String CONFIDENCE = "confidence"; String IMPORTANCE_FRACTION = "importance_fraction"; String CROP_HINTS = "crop_hints"; + + // Video field names: + String ENTITY = "entity"; + String SEGMENTS = "segments"; + String START_TIME_OFFSET = "start_time_offset"; + String END_TIME_OFFSET = "end_time_offset"; + String FRAMES = "frames"; + String TIME_OFFSET = "time_offset"; + String LEFT = "left"; + String TOP = "top"; + String RIGHT = "right"; + String BOTTOM = "bottom"; } - List VERTEX_FIELDS = Arrays.asList( - new TableFieldSchema().setName(Field.VERTEX_X).setType(Type.FLOAT).setMode(Mode.REQUIRED), - new TableFieldSchema().setName(Field.VERTEX_Y).setType(Type.FLOAT).setMode(Mode.REQUIRED) - ); + List VERTEX_FIELDS = + Arrays.asList( + new TableFieldSchema() + .setName(Field.VERTEX_X) + .setType(BigQueryConstants.Type.FLOAT) + .setMode(BigQueryConstants.Mode.REQUIRED), + new TableFieldSchema() + .setName(Field.VERTEX_Y) + .setType(BigQueryConstants.Type.FLOAT) + .setMode(BigQueryConstants.Mode.REQUIRED)); - List POSITION_FIELDS = Arrays.asList( - new TableFieldSchema().setName(Field.VERTEX_X).setType(Type.FLOAT).setMode(Mode.REQUIRED), - new TableFieldSchema().setName(Field.VERTEX_Y).setType(Type.FLOAT).setMode(Mode.REQUIRED), - new TableFieldSchema().setName(Field.VERTEX_Z).setType(Type.FLOAT).setMode(Mode.NULLABLE) - ); + List POSITION_FIELDS = + Arrays.asList( + new TableFieldSchema() + .setName(Field.VERTEX_X) + .setType(BigQueryConstants.Type.FLOAT) + .setMode(BigQueryConstants.Mode.REQUIRED), + new TableFieldSchema() + .setName(Field.VERTEX_Y) + .setType(BigQueryConstants.Type.FLOAT) + .setMode(BigQueryConstants.Mode.REQUIRED), + new TableFieldSchema() + .setName(Field.VERTEX_Z) + .setType(BigQueryConstants.Type.FLOAT) + .setMode(BigQueryConstants.Mode.NULLABLE)); - List POLYGON_FIELDS = Collections.singletonList( - new TableFieldSchema().setName(Field.VERTICES).setType(Type.RECORD).setMode(Mode.REPEATED) - .setFields(VERTEX_FIELDS) - ); + List POLYGON_FIELDS = + Collections.singletonList( + new TableFieldSchema() + .setName(Field.VERTICES) + .setType(BigQueryConstants.Type.RECORD) + .setMode(BigQueryConstants.Mode.REPEATED) + .setFields(VERTEX_FIELDS)); } diff --git a/src/main/java/com/google/solutions/annotation/ml/MLApiResponseProcessor.java b/src/main/java/com/google/solutions/annotation/ml/MLApiResponseProcessor.java new file mode 100644 index 0000000..361d750 --- /dev/null +++ b/src/main/java/com/google/solutions/annotation/ml/MLApiResponseProcessor.java @@ -0,0 +1,69 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.solutions.annotation.ml; + +import com.google.api.services.bigquery.model.TableRow; +import com.google.protobuf.GeneratedMessageV3; +import com.google.solutions.annotation.bigquery.BigQueryDestination; +import com.google.solutions.annotation.bigquery.TableDetails; +import com.google.solutions.annotation.gcs.GCSFileInfo; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.List; + +/** + * Implementors of this interface will process zero to many TableRows to persist to a specific + * BigTable table. + */ +public interface MLApiResponseProcessor extends Serializable { + + class ProcessorResult { + public static final String IMAGE_ERROR = "image-error"; + public static final String IMAGE_PROPERTIES = "image-properties"; + public static final String IMAGE_LABEL = "image-label"; + public static final String IMAGE_LOGO = "image-logo"; + public static final String IMAGE_LANDMARK = "image-landmark"; + public static final String IMAGE_FACE = "image-face"; + public static final String IMAGE_CROP = "image-crop"; + public static final String VIDEO_OBJECT_TRACKING = "video-object-tracking"; + public static final String VIDEO_LABEL = "video-label"; + + public String type; + public BigQueryDestination destination; + public List allRows; + public List relevantRows; + + public ProcessorResult(String type, BigQueryDestination destination) { + this.type = type; + this.destination = destination; + this.allRows = new ArrayList<>(); + this.relevantRows = new ArrayList<>(); + } + } + + /** + * @param fileInfo annotation source + * @param response from Google Cloud ML Video Intelligence or Cloud ML Vision API + * @return key/value pair of a BigQuery destination and a TableRow to persist. + */ + ProcessorResult process(GCSFileInfo fileInfo, GeneratedMessageV3 response); + + /** @return details of the table to persist to. */ + TableDetails destinationTableDetails(); + + /** @return true if the processor is meant to processor this type of response object. */ + boolean shouldProcess(GeneratedMessageV3 response); +} diff --git a/src/main/java/com/google/solutions/annotation/ml/ProcessMLApiResponseDoFn.java b/src/main/java/com/google/solutions/annotation/ml/ProcessMLApiResponseDoFn.java new file mode 100644 index 0000000..b220a6e --- /dev/null +++ b/src/main/java/com/google/solutions/annotation/ml/ProcessMLApiResponseDoFn.java @@ -0,0 +1,106 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.solutions.annotation.ml; + +import com.google.api.services.bigquery.model.TableRow; +import com.google.auto.value.AutoValue; +import com.google.protobuf.GeneratedMessageV3; +import com.google.solutions.annotation.bigquery.BigQueryDestination; +import com.google.solutions.annotation.gcs.GCSFileInfo; +import java.util.Collection; +import java.util.Objects; +import org.apache.beam.sdk.metrics.Counter; +import org.apache.beam.sdk.metrics.Metrics; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.TupleTag; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * ProcessImageResponse {@link ProcessMLApiResponseDoFn} class parses the API response for specific + * annotation and using response builder output the table and table row for BigQuery + */ +@AutoValue +public abstract class ProcessMLApiResponseDoFn + extends DoFn, KV> { + + private static final long serialVersionUID = 1L; + private static final Logger LOG = LoggerFactory.getLogger(ProcessMLApiResponseDoFn.class); + + abstract Collection processors(); + + abstract TupleTag> allRows(); + + abstract TupleTag> relevantRows(); + + abstract Counter processedFileCounter(); + + public static ProcessMLApiResponseDoFn create( + Collection processors, + TupleTag> allRows, + TupleTag> relevantRows) { + return builder() + .processors(processors) + .allRows(allRows) + .relevantRows(relevantRows) + .processedFileCounter(Metrics.counter(ProcessMLApiResponseDoFn.class, "processedFiles")) + .build(); + } + + @ProcessElement + public void processElement( + @Element KV element, MultiOutputReceiver out) { + GCSFileInfo fileInfo = element.getKey(); + GeneratedMessageV3 annotationResponse = element.getValue(); + + LOG.debug("Processing annotations for file: {}", Objects.requireNonNull(fileInfo).getUri()); + processedFileCounter().inc(); + + processors() + .forEach( + processor -> { + if (processor.shouldProcess(annotationResponse)) { + MLApiResponseProcessor.ProcessorResult result = + processor.process(fileInfo, annotationResponse); + if (result != null) { + result.allRows.forEach( + (TableRow row) -> out.get(allRows()).output(KV.of(result.destination, row))); + result.relevantRows.forEach( + (TableRow row) -> out.get(relevantRows()).output(KV.of(result.type, row))); + } + } + }); + } + + public static Builder builder() { + return new AutoValue_ProcessMLApiResponseDoFn.Builder(); + } + + @AutoValue.Builder + public abstract static class Builder { + + public abstract Builder processors(Collection processors); + + public abstract Builder processedFileCounter(Counter processedFileCounter); + + public abstract Builder allRows(TupleTag> allRows); + + public abstract Builder relevantRows(TupleTag> relevantRows); + + public abstract ProcessMLApiResponseDoFn build(); + } +} diff --git a/src/main/java/com/google/solutions/annotation/ml/ProcessorUtils.java b/src/main/java/com/google/solutions/annotation/ml/ProcessorUtils.java new file mode 100644 index 0000000..20f672e --- /dev/null +++ b/src/main/java/com/google/solutions/annotation/ml/ProcessorUtils.java @@ -0,0 +1,115 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.solutions.annotation.ml; + +import com.google.api.services.bigquery.model.TableFieldSchema; +import com.google.api.services.bigquery.model.TableRow; +import com.google.cloud.vision.v1.BoundingPoly; +import com.google.cloud.vision.v1.EntityAnnotation; +import com.google.solutions.annotation.gcs.GCSFileInfo; +import java.util.ArrayList; +import java.util.List; +import java.util.Set; +import org.joda.time.DateTimeZone; +import org.joda.time.Instant; +import org.joda.time.format.DateTimeFormat; +import org.joda.time.format.DateTimeFormatter; + +/** Various utility functions used by processors */ +public class ProcessorUtils { + + /** Extracts the bounding polygon if one exists and adds it to the row. */ + public static void extractBoundingPoly(EntityAnnotation annotation, TableRow row) { + if (annotation.hasBoundingPoly()) { + TableRow boundingPoly = getBoundingPolyAsRow(annotation.getBoundingPoly()); + row.put(Constants.Field.BOUNDING_POLY, boundingPoly); + } + } + + /** + * Converts {@link BoundingPoly} to a {@link TableRow}. + * + * @return table row + */ + public static TableRow getBoundingPolyAsRow(BoundingPoly boundingPoly) { + List vertices = new ArrayList<>(); + boundingPoly + .getVerticesList() + .forEach( + vertex -> { + TableRow vertexRow = new TableRow(); + vertexRow.put(Constants.Field.VERTEX_X, vertex.getX()); + vertexRow.put(Constants.Field.VERTEX_Y, vertex.getY()); + vertices.add(vertexRow); + }); + TableRow result = new TableRow(); + result.put(Constants.Field.VERTICES, vertices); + return result; + } + + /** + * Creates a TableRow and populates with two fields used in all processors: {@link + * Constants.Field#GCS_URI_FIELD} and {@link Constants.Field#TIMESTAMP_FIELD} + * + * @return new TableRow + */ + public static TableRow startRow(GCSFileInfo fileInfo) { + TableRow row = new TableRow(); + row.put(Constants.Field.GCS_URI_FIELD, fileInfo.getUri()); + row.put(Constants.Field.TIMESTAMP_FIELD, getTimeStamp()); + return row; + } + + public static void setMetadataFieldsSchema( + List fields, Set metadataKeys) { + if (!metadataKeys.isEmpty()) { + List metadataFields = new ArrayList<>(); + for (String key : metadataKeys) { + metadataFields.add(new TableFieldSchema().setName(key).setType("STRING")); + } + fields.add( + new TableFieldSchema() + .setName(Constants.Field.METADATA) + .setType("RECORD") + .setFields(metadataFields)); + } + } + + public static void addMetadataValues( + TableRow row, GCSFileInfo fileInfo, Set metadataKeys) { + // Add metadata to the row, if any + TableRow metadataRow = new TableRow(); + if (fileInfo.getMetadata() != null) { + for (String key : metadataKeys) { + String value = fileInfo.getMetadata().get(key); + if (value != null) { + metadataRow.put(key, value); + } + } + } + if (!metadataRow.isEmpty()) { + row.put(Constants.Field.METADATA, metadataRow); + } + } + + private static final DateTimeFormatter TIMESTAMP_FORMATTER = + DateTimeFormat.forPattern("yyyy-MM-dd HH:mm:ss.SSSSSS"); + + /** Formats the current timestamp in BigQuery compliant format */ + public static String getTimeStamp() { + return TIMESTAMP_FORMATTER.print(Instant.now().toDateTime(DateTimeZone.UTC)); + } +} diff --git a/src/main/java/com/google/solutions/annotation/ml/videointelligence/VideoAnnotator.java b/src/main/java/com/google/solutions/annotation/ml/videointelligence/VideoAnnotator.java new file mode 100644 index 0000000..f52257e --- /dev/null +++ b/src/main/java/com/google/solutions/annotation/ml/videointelligence/VideoAnnotator.java @@ -0,0 +1,120 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.solutions.annotation.ml.videointelligence; + +import com.google.api.client.util.ExponentialBackOff; +import com.google.api.gax.rpc.BidiStream; +import com.google.api.gax.rpc.ResourceExhaustedException; +import com.google.cloud.videointelligence.v1p3beta1.*; +import com.google.protobuf.ByteString; +import com.google.protobuf.GeneratedMessageV3; +import com.google.solutions.annotation.AnnotationPipeline; +import com.google.solutions.annotation.gcs.GCSFileInfo; +import com.google.solutions.annotation.gcs.GCSUtils; +import com.google.solutions.annotation.ml.BackOffUtils; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import org.apache.beam.sdk.values.KV; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class VideoAnnotator { + + public static final Logger LOG = LoggerFactory.getLogger(VideoAnnotator.class); + private final List features; + private final StreamingVideoIntelligenceServiceClient client; + BidiStream streamCall; + + public VideoAnnotator(List features) throws IOException { + this.features = features; + this.client = StreamingVideoIntelligenceServiceClient.create(); + } + + public void teardown() { + if (client != null) { + client.shutdown(); + try { + int waitTime = 10; + if (!client.awaitTermination(waitTime, TimeUnit.SECONDS)) { + LOG.warn( + "Failed to shutdown the annotation client after {} seconds. Closing client anyway.", + waitTime); + } + } catch (InterruptedException e) { + // Do nothing + } + client.close(); + } + } + + private static StreamingVideoConfig getConfig(StreamingFeature feature) { + StreamingObjectTrackingConfig objectTrackingConfig = + StreamingObjectTrackingConfig.newBuilder().build(); + StreamingLabelDetectionConfig labelConfig = StreamingLabelDetectionConfig.newBuilder().build(); + StreamingVideoConfig.Builder builder = + StreamingVideoConfig.newBuilder() + .setObjectTrackingConfig(objectTrackingConfig) + .setLabelDetectionConfig(labelConfig) + .setFeature(feature); + return builder.build(); + } + + public List> processFiles(Iterable fileInfos) { + List> result = new ArrayList<>(); + + // Download files' contents from GCS + List gcsBytes = new ArrayList<>(); + fileInfos.forEach( + fileInfo -> { + ByteString bytes = GCSUtils.getBytes(fileInfo.getUri()); + gcsBytes.add(bytes); + }); + + ExponentialBackOff backoff = BackOffUtils.createBackOff(); + AtomicInteger counter = new AtomicInteger(); + fileInfos.forEach( + fileInfo -> { + ByteString bytes = gcsBytes.get(counter.get()); + // The Streaming API only accepts one feature at a time, so we send multiple requests. + for (StreamingFeature feature : features) { + while (true) { + try { + // Send the bytes to the Streaming Video API + streamCall = client.streamingAnnotateVideoCallable().call(); + StreamingVideoConfig config = getConfig(feature); + streamCall.send( + StreamingAnnotateVideoRequest.newBuilder().setVideoConfig(config).build()); + streamCall.send( + StreamingAnnotateVideoRequest.newBuilder().setInputContent(bytes).build()); + AnnotationPipeline.numberOfVideoApiRequests.inc(); + streamCall.closeSend(); + for (StreamingAnnotateVideoResponse response : streamCall) { + result.add(KV.of(fileInfo, response)); + } + break; + } catch (ResourceExhaustedException e) { + BackOffUtils.handleQuotaReachedException(backoff, e); + } + } + } + counter.getAndIncrement(); + }); + return result; + } +} diff --git a/src/main/java/com/google/solutions/annotation/ml/videointelligence/processors/VideoLabelAnnotationProcessor.java b/src/main/java/com/google/solutions/annotation/ml/videointelligence/processors/VideoLabelAnnotationProcessor.java new file mode 100644 index 0000000..6780017 --- /dev/null +++ b/src/main/java/com/google/solutions/annotation/ml/videointelligence/processors/VideoLabelAnnotationProcessor.java @@ -0,0 +1,172 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.solutions.annotation.ml.videointelligence.processors; + +import com.google.api.services.bigquery.model.*; +import com.google.cloud.videointelligence.v1p3beta1.LabelAnnotation; +import com.google.cloud.videointelligence.v1p3beta1.StreamingAnnotateVideoResponse; +import com.google.cloud.videointelligence.v1p3beta1.StreamingVideoAnnotationResults; +import com.google.common.collect.ImmutableList; +import com.google.protobuf.GeneratedMessageV3; +import com.google.solutions.annotation.bigquery.BigQueryDestination; +import com.google.solutions.annotation.bigquery.TableDetails; +import com.google.solutions.annotation.bigquery.TableSchemaProducer; +import com.google.solutions.annotation.gcs.GCSFileInfo; +import com.google.solutions.annotation.ml.Constants.Field; +import com.google.solutions.annotation.ml.MLApiResponseProcessor; +import com.google.solutions.annotation.ml.ProcessorUtils; +import java.util.*; +import java.util.concurrent.atomic.AtomicReference; +import org.apache.beam.sdk.metrics.Counter; +import org.apache.beam.sdk.metrics.Metrics; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class VideoLabelAnnotationProcessor implements MLApiResponseProcessor { + + private static final long serialVersionUID = 1L; + + private static final Logger LOG = LoggerFactory.getLogger(VideoLabelAnnotationProcessor.class); + private final BigQueryDestination destination; + private final Set metadataKeys; + private final Set relevantEntities; + private final float confidenceThreshold; + public static final Counter counter = + Metrics.counter(MLApiResponseProcessor.class, "numberOfVideoLabelAnnotations"); + + public VideoLabelAnnotationProcessor( + String tableId, + Set metadataKeys, + Set relevantEntities, + float confidenceThreshold) { + this.destination = new BigQueryDestination(tableId); + this.metadataKeys = metadataKeys; + this.relevantEntities = relevantEntities; + this.confidenceThreshold = confidenceThreshold; + } + + private static class SchemaProducer implements TableSchemaProducer { + + private static final long serialVersionUID = 1L; + private final Set metadataKeys; + + SchemaProducer(Set metadataKeys) { + this.metadataKeys = metadataKeys; + } + + /** + * Note: Only frame level label detection results are emitted by the Streaming API. This is due + * to streaming's nature. Because in live stream, one cannot predict when a shot change will + * happen. + */ + @Override + public TableSchema getTableSchema() { + ArrayList fields = new ArrayList<>(); + fields.add( + new TableFieldSchema() + .setName(Field.GCS_URI_FIELD) + .setType("STRING") + .setMode("REQUIRED")); + fields.add( + new TableFieldSchema() + .setName(Field.TIMESTAMP_FIELD) + .setType("TIMESTAMP") + .setMode("REQUIRED")); + fields.add( + new TableFieldSchema().setName(Field.ENTITY).setType("STRING").setMode("REQUIRED")); + fields.add( + new TableFieldSchema() + .setName(Field.FRAMES) + .setType("RECORD") + .setMode("REPEATED") + .setFields( + ImmutableList.of( + new TableFieldSchema() + .setName(Field.CONFIDENCE) + .setType("FLOAT") + .setMode("REQUIRED"), + new TableFieldSchema() + .setName(Field.TIME_OFFSET) + .setType("INT64") + .setMode("REQUIRED")))); + + ProcessorUtils.setMetadataFieldsSchema(fields, metadataKeys); + + return new TableSchema().setFields(fields); + } + } + + @Override + public boolean shouldProcess(GeneratedMessageV3 response) { + return (response instanceof StreamingAnnotateVideoResponse + && ((StreamingAnnotateVideoResponse) response) + .getAnnotationResults() + .getLabelAnnotationsCount() + > 0); + } + + @Override + public ProcessorResult process(GCSFileInfo fileInfo, GeneratedMessageV3 r) { + StreamingAnnotateVideoResponse response = (StreamingAnnotateVideoResponse) r; + StreamingVideoAnnotationResults annotationResults = response.getAnnotationResults(); + counter.inc(annotationResults.getLabelAnnotationsCount()); + ProcessorResult result = new ProcessorResult(ProcessorResult.VIDEO_LABEL, destination); + for (LabelAnnotation annotation : annotationResults.getLabelAnnotationsList()) { + TableRow row = ProcessorUtils.startRow(fileInfo); + row.set( + Field.ENTITY, + annotation.hasEntity() + ? annotation.getEntity().getDescription() + : "NOT_FOUND"); // FIXME: Seems like sometimes it's an empty string? + + List frames = new ArrayList<>(annotation.getFramesCount()); + AtomicReference maxConfidence = new AtomicReference<>((float) 0); + annotation + .getFramesList() + .forEach( + frame -> { + TableRow frameRow = new TableRow(); + frameRow.set(Field.CONFIDENCE, frame.getConfidence()); + frameRow.set(Field.TIME_OFFSET, frame.getTimeOffset().getSeconds()); + frames.add(frameRow); + maxConfidence.set(Math.max(maxConfidence.get(), frame.getConfidence())); + }); + row.put(Field.FRAMES, frames); + + ProcessorUtils.addMetadataValues(row, fileInfo, metadataKeys); + + LOG.debug("Processing {}", row); + result.allRows.add(row); + + if (relevantEntities != null + && relevantEntities.stream() + .anyMatch(annotation.getEntity().getDescription()::equalsIgnoreCase) + && maxConfidence.get() >= confidenceThreshold) { + result.relevantRows.add(row); + } + } + return result; + } + + @Override + public TableDetails destinationTableDetails() { + return TableDetails.create( + "Google Video Intelligence API label annotations", + new Clustering().setFields(Collections.singletonList(Field.GCS_URI_FIELD)), + new TimePartitioning().setField(Field.TIMESTAMP_FIELD), + new SchemaProducer(metadataKeys)); + } +} diff --git a/src/main/java/com/google/solutions/annotation/ml/videointelligence/processors/VideoObjectTrackingAnnotationProcessor.java b/src/main/java/com/google/solutions/annotation/ml/videointelligence/processors/VideoObjectTrackingAnnotationProcessor.java new file mode 100644 index 0000000..e83a3b2 --- /dev/null +++ b/src/main/java/com/google/solutions/annotation/ml/videointelligence/processors/VideoObjectTrackingAnnotationProcessor.java @@ -0,0 +1,178 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.solutions.annotation.ml.videointelligence.processors; + +import com.google.api.services.bigquery.model.*; +import com.google.cloud.videointelligence.v1p3beta1.NormalizedBoundingBox; +import com.google.cloud.videointelligence.v1p3beta1.ObjectTrackingAnnotation; +import com.google.cloud.videointelligence.v1p3beta1.StreamingAnnotateVideoResponse; +import com.google.cloud.videointelligence.v1p3beta1.StreamingVideoAnnotationResults; +import com.google.common.collect.ImmutableList; +import com.google.protobuf.GeneratedMessageV3; +import com.google.solutions.annotation.bigquery.BigQueryDestination; +import com.google.solutions.annotation.bigquery.TableDetails; +import com.google.solutions.annotation.bigquery.TableSchemaProducer; +import com.google.solutions.annotation.gcs.GCSFileInfo; +import com.google.solutions.annotation.ml.Constants.Field; +import com.google.solutions.annotation.ml.MLApiResponseProcessor; +import com.google.solutions.annotation.ml.ProcessorUtils; +import java.util.*; +import org.apache.beam.sdk.metrics.Counter; +import org.apache.beam.sdk.metrics.Metrics; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class VideoObjectTrackingAnnotationProcessor implements MLApiResponseProcessor { + + private static final long serialVersionUID = 1L; + + private static final Logger LOG = + LoggerFactory.getLogger(VideoObjectTrackingAnnotationProcessor.class); + private final BigQueryDestination destination; + private final Set metadataKeys; + private final Set relevantEntities; + private final float confidenceThreshold; + public static final Counter counter = + Metrics.counter(MLApiResponseProcessor.class, "numberOfVideoObjectTrackingAnnotations"); + + public VideoObjectTrackingAnnotationProcessor( + String tableId, + Set metadataKeys, + Set relevantEntities, + float confidenceThreshold) { + this.destination = new BigQueryDestination(tableId); + this.metadataKeys = metadataKeys; + this.relevantEntities = relevantEntities; + this.confidenceThreshold = confidenceThreshold; + } + + private static class SchemaProducer implements TableSchemaProducer { + + private static final long serialVersionUID = 1L; + private final Set metadataKeys; + + SchemaProducer(Set metadataKeys) { + this.metadataKeys = metadataKeys; + } + + @Override + public TableSchema getTableSchema() { + ArrayList fields = new ArrayList<>(); + fields.add( + new TableFieldSchema() + .setName(Field.GCS_URI_FIELD) + .setType("STRING") + .setMode("REQUIRED")); + fields.add( + new TableFieldSchema() + .setName(Field.TIMESTAMP_FIELD) + .setType("TIMESTAMP") + .setMode("REQUIRED")); + fields.add( + new TableFieldSchema().setName(Field.ENTITY).setType("STRING").setMode("REQUIRED")); + fields.add(new TableFieldSchema().setName(Field.CONFIDENCE).setType("FLOAT")); + fields.add( + new TableFieldSchema() + .setName(Field.FRAMES) + .setType("RECORD") + .setMode("REPEATED") + .setFields( + ImmutableList.of( + new TableFieldSchema() + .setName(Field.TIME_OFFSET) + .setType("INT64") + .setMode("REQUIRED"), + new TableFieldSchema() + .setName(Field.LEFT) + .setType("FLOAT") + .setMode("REQUIRED"), + new TableFieldSchema() + .setName(Field.TOP) + .setType("FLOAT") + .setMode("REQUIRED"), + new TableFieldSchema() + .setName(Field.RIGHT) + .setType("FLOAT") + .setMode("REQUIRED"), + new TableFieldSchema() + .setName(Field.BOTTOM) + .setType("FLOAT") + .setMode("REQUIRED")))); + ProcessorUtils.setMetadataFieldsSchema(fields, metadataKeys); + return new TableSchema().setFields(fields); + } + } + + @Override + public boolean shouldProcess(GeneratedMessageV3 response) { + return (response instanceof StreamingAnnotateVideoResponse + && ((StreamingAnnotateVideoResponse) response) + .getAnnotationResults() + .getObjectAnnotationsCount() + > 0); + } + + @Override + public ProcessorResult process(GCSFileInfo fileInfo, GeneratedMessageV3 r) { + StreamingAnnotateVideoResponse response = (StreamingAnnotateVideoResponse) r; + StreamingVideoAnnotationResults annotationResults = response.getAnnotationResults(); + counter.inc(annotationResults.getObjectAnnotationsCount()); + ProcessorResult result = + new ProcessorResult(ProcessorResult.VIDEO_OBJECT_TRACKING, destination); + for (ObjectTrackingAnnotation annotation : annotationResults.getObjectAnnotationsList()) { + TableRow row = ProcessorUtils.startRow(fileInfo); + row.set(Field.CONFIDENCE, annotation.getConfidence()); + row.set( + Field.ENTITY, + annotation.hasEntity() ? annotation.getEntity().getDescription() : "NOT_FOUND"); + List frames = new ArrayList<>(annotation.getFramesCount()); + annotation + .getFramesList() + .forEach( + frame -> { + TableRow frameRow = new TableRow(); + NormalizedBoundingBox normalizedBoundingBox = frame.getNormalizedBoundingBox(); + frameRow.set(Field.TIME_OFFSET, frame.getTimeOffset().getSeconds()); + frameRow.set(Field.LEFT, normalizedBoundingBox.getLeft()); + frameRow.set(Field.TOP, normalizedBoundingBox.getTop()); + frameRow.set(Field.RIGHT, normalizedBoundingBox.getRight()); + frameRow.set(Field.BOTTOM, normalizedBoundingBox.getBottom()); + frames.add(frameRow); + }); + row.put(Field.FRAMES, frames); + ProcessorUtils.addMetadataValues(row, fileInfo, metadataKeys); + LOG.debug("Processing {}", row); + result.allRows.add(row); + + if (relevantEntities != null + && relevantEntities.stream() + .anyMatch(annotation.getEntity().getDescription()::equalsIgnoreCase) + && annotation.getConfidence() >= confidenceThreshold) { + result.relevantRows.add(row); + } + } + return result; + } + + @Override + public TableDetails destinationTableDetails() { + return TableDetails.create( + "Google Video Intelligence API object tracking annotations", + new Clustering().setFields(Collections.singletonList(Field.GCS_URI_FIELD)), + new TimePartitioning().setField(Field.TIMESTAMP_FIELD), + new SchemaProducer(metadataKeys)); + } +} diff --git a/src/main/java/com/google/solutions/annotation/ml/vision/ImageAnnotator.java b/src/main/java/com/google/solutions/annotation/ml/vision/ImageAnnotator.java new file mode 100644 index 0000000..56012ae --- /dev/null +++ b/src/main/java/com/google/solutions/annotation/ml/vision/ImageAnnotator.java @@ -0,0 +1,125 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.solutions.annotation.ml.vision; + +import com.google.api.client.util.ExponentialBackOff; +import com.google.api.gax.rpc.ResourceExhaustedException; +import com.google.cloud.vision.v1.AnnotateImageRequest; +import com.google.cloud.vision.v1.AnnotateImageResponse; +import com.google.cloud.vision.v1.Feature; +import com.google.cloud.vision.v1.Image; +import com.google.cloud.vision.v1.ImageAnnotatorClient; +import com.google.cloud.vision.v1.ImageSource; +import com.google.protobuf.GeneratedMessageV3; +import com.google.solutions.annotation.AnnotationPipeline; +import com.google.solutions.annotation.gcs.GCSFileInfo; +import com.google.solutions.annotation.ml.BackOffUtils; +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; +import org.apache.beam.sdk.values.KV; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Calls Google Cloud Vision API to annotate a batch of GCS files. + * + *

The GCS file URIs are provided in the incoming PCollection and should not exceed the limit + * imposed by the API (maximum of 16 images per request). + * + *

The resulting PCollection contains key/value pair with the GCS file URI as the key and the API + * response as the value. + */ +public class ImageAnnotator { + + public static final Logger LOG = LoggerFactory.getLogger(ImageAnnotator.class); + + private final List featureList = new ArrayList<>(); + private final ImageAnnotatorClient client; + + public ImageAnnotator(List featureTypes) { + if (featureTypes != null) { + featureTypes.forEach(type -> featureList.add(Feature.newBuilder().setType(type).build())); + } + + try { + client = ImageAnnotatorClient.create(); + } catch (IOException e) { + LOG.error("Failed to create Vision API Service Client: {}", e.getMessage()); + throw new RuntimeException(e); + } + } + + public void teardown() { + if (client != null) { + client.shutdownNow(); + try { + int waitTime = 10; + if (!client.awaitTermination(waitTime, TimeUnit.SECONDS)) { + LOG.warn( + "Failed to shutdown the annotation client after {} seconds. Closing client anyway.", + waitTime); + } + } catch (InterruptedException e) { + // Do nothing + } + client.close(); + } + } + + public List> processFiles(Iterable fileInfos) { + List requests = new ArrayList<>(); + + Map uriToFileInfo = new HashMap<>(); + + fileInfos.forEach( + fileInfo -> { + uriToFileInfo.put(fileInfo.getUri(), fileInfo); + Image image = + Image.newBuilder() + .setSource(ImageSource.newBuilder().setImageUri(fileInfo.getUri()).build()) + .build(); + AnnotateImageRequest.Builder request = + AnnotateImageRequest.newBuilder().setImage(image).addAllFeatures(featureList); + requests.add(request.build()); + }); + + List responses; + + ExponentialBackOff backoff = BackOffUtils.createBackOff(); + while (true) { + try { + AnnotationPipeline.numberOfImageApiRequests.inc(); + responses = client.batchAnnotateImages(requests).getResponsesList(); + break; + } catch (ResourceExhaustedException e) { + BackOffUtils.handleQuotaReachedException(backoff, e); + } + } + + int index = 0; + List> result = new ArrayList<>(); + for (AnnotateImageResponse response : responses) { + String uri = requests.get(index++).getImage().getSource().getImageUri(); + GCSFileInfo fileInfo = uriToFileInfo.get(uri); + result.add(KV.of(fileInfo, response)); + } + return result; + } +} diff --git a/src/main/java/com/google/solutions/annotation/ml/vision/processors/CropHintAnnotationProcessor.java b/src/main/java/com/google/solutions/annotation/ml/vision/processors/CropHintAnnotationProcessor.java new file mode 100644 index 0000000..66ad28f --- /dev/null +++ b/src/main/java/com/google/solutions/annotation/ml/vision/processors/CropHintAnnotationProcessor.java @@ -0,0 +1,167 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.solutions.annotation.ml.vision.processors; + +import com.google.api.services.bigquery.model.Clustering; +import com.google.api.services.bigquery.model.TableFieldSchema; +import com.google.api.services.bigquery.model.TableRow; +import com.google.api.services.bigquery.model.TableSchema; +import com.google.api.services.bigquery.model.TimePartitioning; +import com.google.cloud.vision.v1.AnnotateImageResponse; +import com.google.cloud.vision.v1.CropHint; +import com.google.cloud.vision.v1.CropHintsAnnotation; +import com.google.common.collect.ImmutableList; +import com.google.protobuf.GeneratedMessageV3; +import com.google.solutions.annotation.bigquery.BigQueryConstants; +import com.google.solutions.annotation.bigquery.BigQueryDestination; +import com.google.solutions.annotation.bigquery.TableDetails; +import com.google.solutions.annotation.bigquery.TableSchemaProducer; +import com.google.solutions.annotation.gcs.GCSFileInfo; +import com.google.solutions.annotation.ml.Constants; +import com.google.solutions.annotation.ml.Constants.Field; +import com.google.solutions.annotation.ml.MLApiResponseProcessor; +import com.google.solutions.annotation.ml.ProcessorUtils; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Set; +import org.apache.beam.sdk.metrics.Counter; +import org.apache.beam.sdk.metrics.Metrics; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Extracts crop hint annotations (https://cloud.google.com/vision/docs/detecting-crop-hints) + * + *

Note: requests for either CROP_HINT feature or IMAGE_PROPERTIES feature will produce crop + * hints + */ +public class CropHintAnnotationProcessor implements MLApiResponseProcessor { + + private static final long serialVersionUID = 1L; + + private static final Logger LOG = LoggerFactory.getLogger(CropHintAnnotationProcessor.class); + public static final Counter counter = + Metrics.counter(MLApiResponseProcessor.class, "numberOfCropHintAnnotations"); + + private final BigQueryDestination destination; + private final Set metadataKeys; + private final float confidenceThreshold; + + /** Creates a processor and specifies the table id to persist to. */ + public CropHintAnnotationProcessor( + String tableId, Set metadataKeys, float confidenceThreshold) { + this.destination = new BigQueryDestination(tableId); + this.metadataKeys = metadataKeys; + this.confidenceThreshold = confidenceThreshold; + } + + private static class SchemaProducer implements TableSchemaProducer { + + private static final long serialVersionUID = 1L; + private final Set metadataKeys; + + SchemaProducer(Set metadataKeys) { + this.metadataKeys = metadataKeys; + } + + @Override + public TableSchema getTableSchema() { + ArrayList fields = new ArrayList<>(); + fields.add( + new TableFieldSchema() + .setName(Field.GCS_URI_FIELD) + .setType(BigQueryConstants.Type.STRING) + .setMode(BigQueryConstants.Mode.REQUIRED)); + fields.add( + new TableFieldSchema() + .setName(Field.CROP_HINTS) + .setType(BigQueryConstants.Type.RECORD) + .setMode(BigQueryConstants.Mode.REPEATED) + .setFields( + ImmutableList.of( + new TableFieldSchema() + .setName(Field.CONFIDENCE) + .setType(BigQueryConstants.Type.FLOAT) + .setMode(BigQueryConstants.Mode.REQUIRED), + new TableFieldSchema() + .setName(Field.IMPORTANCE_FRACTION) + .setType(BigQueryConstants.Type.FLOAT) + .setMode(BigQueryConstants.Mode.REQUIRED), + new TableFieldSchema() + .setName(Field.BOUNDING_POLY) + .setType(BigQueryConstants.Type.RECORD) + .setMode(BigQueryConstants.Mode.REQUIRED) + .setFields(Constants.POLYGON_FIELDS)))); + fields.add( + new TableFieldSchema() + .setName(Field.TIMESTAMP_FIELD) + .setType(BigQueryConstants.Type.TIMESTAMP) + .setMode(BigQueryConstants.Mode.REQUIRED)); + ProcessorUtils.setMetadataFieldsSchema(fields, metadataKeys); + return new TableSchema().setFields(fields); + } + } + + @Override + public TableDetails destinationTableDetails() { + return TableDetails.create( + "Google Vision API Crop Hint Annotations", + new Clustering().setFields(Collections.singletonList(Field.GCS_URI_FIELD)), + new TimePartitioning().setField(Field.TIMESTAMP_FIELD), + new SchemaProducer(metadataKeys)); + } + + @Override + public boolean shouldProcess(GeneratedMessageV3 response) { + return (response instanceof AnnotateImageResponse + && ((AnnotateImageResponse) response).getCropHintsAnnotation().getCropHintsCount() > 0); + } + + @Override + public ProcessorResult process(GCSFileInfo fileInfo, GeneratedMessageV3 r) { + AnnotateImageResponse response = (AnnotateImageResponse) r; + CropHintsAnnotation cropHintsAnnotation = response.getCropHintsAnnotation(); + counter.inc(); + + List cropHintRows = + new ArrayList<>(response.getCropHintsAnnotation().getCropHintsCount()); + float maxConfidence = 0; + for (CropHint cropHint : cropHintsAnnotation.getCropHintsList()) { + TableRow cropHintRow = new TableRow(); + cropHintRow.put( + Field.BOUNDING_POLY, ProcessorUtils.getBoundingPolyAsRow(cropHint.getBoundingPoly())); + cropHintRow.put(Field.CONFIDENCE, cropHint.getConfidence()); + cropHintRow.put(Field.IMPORTANCE_FRACTION, cropHint.getImportanceFraction()); + cropHintRows.add(cropHintRow); + maxConfidence = Math.max(maxConfidence, cropHint.getConfidence()); + } + + TableRow row = ProcessorUtils.startRow(fileInfo); + row.put(Field.CROP_HINTS, cropHintRows); + ProcessorUtils.addMetadataValues(row, fileInfo, metadataKeys); + LOG.debug("Processing {}", row); + + ProcessorResult result = new ProcessorResult(ProcessorResult.IMAGE_CROP, destination); + result.allRows.add(row); + + if (maxConfidence >= confidenceThreshold) { + result.relevantRows.add(row); + } + + return result; + } +} diff --git a/src/main/java/com/google/solutions/annotation/ml/vision/processors/ErrorProcessor.java b/src/main/java/com/google/solutions/annotation/ml/vision/processors/ErrorProcessor.java new file mode 100644 index 0000000..cddf5f0 --- /dev/null +++ b/src/main/java/com/google/solutions/annotation/ml/vision/processors/ErrorProcessor.java @@ -0,0 +1,129 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.solutions.annotation.ml.vision.processors; + +import com.google.api.services.bigquery.model.Clustering; +import com.google.api.services.bigquery.model.TableFieldSchema; +import com.google.api.services.bigquery.model.TableRow; +import com.google.api.services.bigquery.model.TableSchema; +import com.google.api.services.bigquery.model.TimePartitioning; +import com.google.cloud.vision.v1.AnnotateImageResponse; +import com.google.protobuf.GeneratedMessageV3; +import com.google.solutions.annotation.bigquery.BigQueryConstants; +import com.google.solutions.annotation.bigquery.BigQueryDestination; +import com.google.solutions.annotation.bigquery.TableDetails; +import com.google.solutions.annotation.bigquery.TableSchemaProducer; +import com.google.solutions.annotation.gcs.GCSFileInfo; +import com.google.solutions.annotation.ml.Constants.Field; +import com.google.solutions.annotation.ml.MLApiResponseProcessor; +import com.google.solutions.annotation.ml.ProcessorUtils; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Set; +import org.apache.beam.sdk.metrics.Counter; +import org.apache.beam.sdk.metrics.Metrics; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Captures the error occurred during processing. Note, that there could be some valid annotations + * returned in the response even though the response contains an error. + */ +public class ErrorProcessor implements MLApiResponseProcessor { + + private static final long serialVersionUID = 1L; + + public static final Counter counter = + Metrics.counter(MLApiResponseProcessor.class, "numberOfErrors"); + public static final Logger LOG = LoggerFactory.getLogger(ErrorProcessor.class); + + private final BigQueryDestination destination; + private final Set metadataKeys; + + /** Creates a processor and specifies the table id to persist to. */ + public ErrorProcessor(String tableId, Set metadataKeys) { + this.destination = new BigQueryDestination(tableId); + this.metadataKeys = metadataKeys; + } + + private static class SchemaProducer implements TableSchemaProducer { + + private static final long serialVersionUID = 1L; + private final Set metadataKeys; + + SchemaProducer(Set metadataKeys) { + this.metadataKeys = metadataKeys; + } + + @Override + public TableSchema getTableSchema() { + ArrayList fields = new ArrayList<>(); + fields.add( + new TableFieldSchema() + .setName(Field.GCS_URI_FIELD) + .setType(BigQueryConstants.Type.STRING) + .setMode(BigQueryConstants.Mode.REQUIRED)); + fields.add( + new TableFieldSchema() + .setName(Field.DESCRIPTION_FIELD) + .setType(BigQueryConstants.Type.STRING) + .setMode(BigQueryConstants.Mode.REQUIRED)); + fields.add( + new TableFieldSchema() + .setName(Field.STACK_TRACE) + .setType(BigQueryConstants.Type.STRING) + .setMode(BigQueryConstants.Mode.NULLABLE)); + fields.add( + new TableFieldSchema() + .setName(Field.TIMESTAMP_FIELD) + .setType(BigQueryConstants.Type.TIMESTAMP) + .setMode(BigQueryConstants.Mode.REQUIRED)); + ProcessorUtils.setMetadataFieldsSchema(fields, metadataKeys); + return new TableSchema().setFields(fields); + } + } + + @Override + public TableDetails destinationTableDetails() { + return TableDetails.create( + "Google Vision API Processing Errors", + new Clustering().setFields(Collections.singletonList(Field.GCS_URI_FIELD)), + new TimePartitioning().setField(Field.TIMESTAMP_FIELD), + new SchemaProducer(metadataKeys)); + } + + @Override + public boolean shouldProcess(GeneratedMessageV3 response) { + return (response instanceof AnnotateImageResponse + && ((AnnotateImageResponse) response).hasError()); + } + + @Override + public ProcessorResult process(GCSFileInfo fileInfo, GeneratedMessageV3 r) { + AnnotateImageResponse response = (AnnotateImageResponse) r; + counter.inc(); + + TableRow row = ProcessorUtils.startRow(fileInfo); + row.put(Field.DESCRIPTION_FIELD, response.getError().toString()); + ProcessorUtils.addMetadataValues(row, fileInfo, metadataKeys); + + LOG.debug("Processing {}", row); + + ProcessorResult result = new ProcessorResult(ProcessorResult.IMAGE_ERROR, destination); + result.allRows.add(row); + return result; + } +} diff --git a/src/main/java/com/google/solutions/annotation/ml/vision/processors/FaceAnnotationProcessor.java b/src/main/java/com/google/solutions/annotation/ml/vision/processors/FaceAnnotationProcessor.java new file mode 100644 index 0000000..4bb7ed4 --- /dev/null +++ b/src/main/java/com/google/solutions/annotation/ml/vision/processors/FaceAnnotationProcessor.java @@ -0,0 +1,216 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.solutions.annotation.ml.vision.processors; + +import com.google.api.services.bigquery.model.Clustering; +import com.google.api.services.bigquery.model.TableFieldSchema; +import com.google.api.services.bigquery.model.TableRow; +import com.google.api.services.bigquery.model.TableSchema; +import com.google.api.services.bigquery.model.TimePartitioning; +import com.google.cloud.vision.v1.AnnotateImageResponse; +import com.google.cloud.vision.v1.FaceAnnotation; +import com.google.cloud.vision.v1.Position; +import com.google.protobuf.GeneratedMessageV3; +import com.google.solutions.annotation.bigquery.BigQueryConstants; +import com.google.solutions.annotation.bigquery.BigQueryDestination; +import com.google.solutions.annotation.bigquery.TableDetails; +import com.google.solutions.annotation.bigquery.TableSchemaProducer; +import com.google.solutions.annotation.gcs.GCSFileInfo; +import com.google.solutions.annotation.ml.Constants; +import com.google.solutions.annotation.ml.MLApiResponseProcessor; +import com.google.solutions.annotation.ml.ProcessorUtils; +import java.util.*; +import org.apache.beam.sdk.metrics.Counter; +import org.apache.beam.sdk.metrics.Metrics; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** Extracts face annotations (https://cloud.google.com/vision/docs/detecting-faces) */ +public class FaceAnnotationProcessor implements MLApiResponseProcessor { + + private static final long serialVersionUID = 1L; + + private static final Logger LOG = LoggerFactory.getLogger(FaceAnnotationProcessor.class); + public static final Counter counter = + Metrics.counter(MLApiResponseProcessor.class, "numberOfFaceAnnotations"); + private final BigQueryDestination destination; + private final Set metadataKeys; + private final float detectionConfidenceThreshold; + + /** Creates a processor and specifies the table id to persist to. */ + public FaceAnnotationProcessor( + String tableId, Set metadataKeys, float detectionConfidenceThreshold) { + this.destination = new BigQueryDestination(tableId); + this.metadataKeys = metadataKeys; + this.detectionConfidenceThreshold = detectionConfidenceThreshold; + } + + /** + * The schema doesn't represent the complete list of all attributes returned by the APIs. For more + * details see + * https://cloud.google.com/vision/docs/reference/rest/v1/AnnotateImageResponse?hl=pl#FaceAnnotation + */ + private static class SchemaProducer implements TableSchemaProducer { + + private static final long serialVersionUID = 1L; + private final Set metadataKeys; + + SchemaProducer(Set metadataKeys) { + this.metadataKeys = metadataKeys; + } + + @Override + public TableSchema getTableSchema() { + ArrayList fields = new ArrayList<>(); + fields.add( + new TableFieldSchema() + .setName(Constants.Field.GCS_URI_FIELD) + .setType(BigQueryConstants.Type.STRING) + .setMode(BigQueryConstants.Mode.REQUIRED)); + fields.add( + new TableFieldSchema() + .setName(Constants.Field.BOUNDING_POLY) + .setType(BigQueryConstants.Type.RECORD) + .setMode(BigQueryConstants.Mode.REQUIRED) + .setFields(Constants.POLYGON_FIELDS)); + fields.add( + new TableFieldSchema() + .setName(Constants.Field.FD_BOUNDING_POLY) + .setType(BigQueryConstants.Type.RECORD) + .setMode(BigQueryConstants.Mode.REQUIRED) + .setFields(Constants.POLYGON_FIELDS)); + fields.add( + new TableFieldSchema() + .setName(Constants.Field.LANDMARKS) + .setType(BigQueryConstants.Type.RECORD) + .setMode(BigQueryConstants.Mode.REPEATED) + .setFields( + Arrays.asList( + new TableFieldSchema() + .setName(Constants.Field.FACE_LANDMARK_TYPE) + .setType(BigQueryConstants.Type.STRING) + .setMode(BigQueryConstants.Mode.REQUIRED), + new TableFieldSchema() + .setName(Constants.Field.FACE_LANDMARK_POSITION) + .setType(BigQueryConstants.Type.RECORD) + .setMode(BigQueryConstants.Mode.REQUIRED) + .setFields(Constants.POSITION_FIELDS)))); + fields.add( + new TableFieldSchema() + .setName(Constants.Field.DETECTION_CONFIDENCE) + .setType(BigQueryConstants.Type.FLOAT) + .setMode(BigQueryConstants.Mode.REQUIRED)); + fields.add( + new TableFieldSchema() + .setName(Constants.Field.LANDMARKING_CONFIDENCE) + .setType(BigQueryConstants.Type.FLOAT) + .setMode(BigQueryConstants.Mode.REQUIRED)); + fields.add( + new TableFieldSchema() + .setName(Constants.Field.JOY_LIKELIHOOD) + .setType(BigQueryConstants.Type.STRING) + .setMode(BigQueryConstants.Mode.REQUIRED)); + fields.add( + new TableFieldSchema() + .setName(Constants.Field.SORROW_LIKELIHOOD) + .setType(BigQueryConstants.Type.STRING) + .setMode(BigQueryConstants.Mode.REQUIRED)); + fields.add( + new TableFieldSchema() + .setName(Constants.Field.ANGER_LIKELIHOOD) + .setType(BigQueryConstants.Type.STRING) + .setMode(BigQueryConstants.Mode.REQUIRED)); + fields.add( + new TableFieldSchema() + .setName(Constants.Field.SURPRISE_LIKELIHOOD) + .setType(BigQueryConstants.Type.STRING) + .setMode(BigQueryConstants.Mode.REQUIRED)); + fields.add( + new TableFieldSchema() + .setName(Constants.Field.TIMESTAMP_FIELD) + .setType(BigQueryConstants.Type.TIMESTAMP) + .setMode(BigQueryConstants.Mode.REQUIRED)); + ProcessorUtils.setMetadataFieldsSchema(fields, metadataKeys); + return new TableSchema().setFields(fields); + } + } + + @Override + public TableDetails destinationTableDetails() { + return TableDetails.create( + "Google Vision API Face Annotations", + new Clustering().setFields(Collections.singletonList(Constants.Field.GCS_URI_FIELD)), + new TimePartitioning().setField(Constants.Field.TIMESTAMP_FIELD), + new SchemaProducer(metadataKeys)); + } + + @Override + public boolean shouldProcess(GeneratedMessageV3 response) { + return (response instanceof AnnotateImageResponse + && ((AnnotateImageResponse) response).getFaceAnnotationsCount() > 0); + } + + @Override + public ProcessorResult process(GCSFileInfo fileInfo, GeneratedMessageV3 r) { + AnnotateImageResponse response = (AnnotateImageResponse) r; + counter.inc(response.getFaceAnnotationsCount()); + ProcessorResult result = new ProcessorResult(ProcessorResult.IMAGE_FACE, destination); + for (FaceAnnotation annotation : response.getFaceAnnotationsList()) { + TableRow row = ProcessorUtils.startRow(fileInfo); + + row.put( + Constants.Field.BOUNDING_POLY, + ProcessorUtils.getBoundingPolyAsRow(annotation.getBoundingPoly())); + row.put( + Constants.Field.FD_BOUNDING_POLY, + ProcessorUtils.getBoundingPolyAsRow(annotation.getFdBoundingPoly())); + List landmarks = new ArrayList<>(annotation.getLandmarksCount()); + annotation + .getLandmarksList() + .forEach( + landmark -> { + TableRow landmarkRow = new TableRow(); + landmarkRow.put(Constants.Field.FACE_LANDMARK_TYPE, landmark.getType().toString()); + + Position position = landmark.getPosition(); + TableRow positionRow = new TableRow(); + positionRow.put(Constants.Field.VERTEX_X, position.getX()); + positionRow.put(Constants.Field.VERTEX_Y, position.getY()); + positionRow.put(Constants.Field.VERTEX_Z, position.getZ()); + landmarkRow.put(Constants.Field.FACE_LANDMARK_POSITION, positionRow); + + landmarks.add(landmarkRow); + }); + row.put(Constants.Field.LANDMARKS, landmarks); + row.put(Constants.Field.DETECTION_CONFIDENCE, annotation.getDetectionConfidence()); + row.put(Constants.Field.LANDMARKING_CONFIDENCE, annotation.getLandmarkingConfidence()); + row.put(Constants.Field.JOY_LIKELIHOOD, annotation.getJoyLikelihood().toString()); + row.put(Constants.Field.SORROW_LIKELIHOOD, annotation.getSorrowLikelihood().toString()); + row.put(Constants.Field.ANGER_LIKELIHOOD, annotation.getAngerLikelihood().toString()); + row.put(Constants.Field.SURPRISE_LIKELIHOOD, annotation.getSurpriseLikelihood().toString()); + ProcessorUtils.addMetadataValues(row, fileInfo, metadataKeys); + + LOG.debug("Processing {}", row); + result.allRows.add(row); + + if (annotation.getDetectionConfidence() >= detectionConfidenceThreshold) { + result.relevantRows.add(row); + } + } + + return result; + } +} diff --git a/src/main/java/com/google/solutions/annotation/ml/vision/processors/ImagePropertiesProcessor.java b/src/main/java/com/google/solutions/annotation/ml/vision/processors/ImagePropertiesProcessor.java new file mode 100644 index 0000000..7a263bf --- /dev/null +++ b/src/main/java/com/google/solutions/annotation/ml/vision/processors/ImagePropertiesProcessor.java @@ -0,0 +1,186 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.solutions.annotation.ml.vision.processors; + +import com.google.api.services.bigquery.model.Clustering; +import com.google.api.services.bigquery.model.TableFieldSchema; +import com.google.api.services.bigquery.model.TableRow; +import com.google.api.services.bigquery.model.TableSchema; +import com.google.api.services.bigquery.model.TimePartitioning; +import com.google.cloud.vision.v1.AnnotateImageResponse; +import com.google.cloud.vision.v1.DominantColorsAnnotation; +import com.google.common.collect.ImmutableList; +import com.google.protobuf.GeneratedMessageV3; +import com.google.solutions.annotation.bigquery.BigQueryConstants; +import com.google.solutions.annotation.bigquery.BigQueryDestination; +import com.google.solutions.annotation.bigquery.TableDetails; +import com.google.solutions.annotation.bigquery.TableSchemaProducer; +import com.google.solutions.annotation.gcs.GCSFileInfo; +import com.google.solutions.annotation.ml.Constants.Field; +import com.google.solutions.annotation.ml.MLApiResponseProcessor; +import com.google.solutions.annotation.ml.ProcessorUtils; +import com.google.type.Color; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Set; +import org.apache.beam.sdk.metrics.Counter; +import org.apache.beam.sdk.metrics.Metrics; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** Extracts image properties (https://cloud.google.com/vision/docs/detecting-properties) */ +public class ImagePropertiesProcessor implements MLApiResponseProcessor { + + private static final long serialVersionUID = 1L; + private static final Logger LOG = LoggerFactory.getLogger(ImagePropertiesProcessor.class); + + public static final Counter counter = + Metrics.counter(MLApiResponseProcessor.class, "numberOfImagePropertiesAnnotations"); + + private final BigQueryDestination destination; + private final Set metadataKeys; + + /** Creates a processor and specifies the table id to persist to. */ + public ImagePropertiesProcessor(String tableId, Set metadataKeys) { + this.destination = new BigQueryDestination(tableId); + this.metadataKeys = metadataKeys; + } + + private static class SchemaProducer implements TableSchemaProducer { + + private static final long serialVersionUID = 1L; + private final Set metadataKeys; + + SchemaProducer(Set metadataKeys) { + this.metadataKeys = metadataKeys; + } + + @Override + public TableSchema getTableSchema() { + ArrayList fields = new ArrayList<>(); + fields.add( + new TableFieldSchema() + .setName(Field.GCS_URI_FIELD) + .setType(BigQueryConstants.Type.STRING) + .setMode(BigQueryConstants.Mode.REQUIRED)); + fields.add( + new TableFieldSchema() + .setName(Field.DOMINANT_COLORS) + .setType(BigQueryConstants.Type.RECORD) + .setMode(BigQueryConstants.Mode.REQUIRED) + .setFields( + ImmutableList.of( + new TableFieldSchema() + .setName(Field.COLORS) + .setType(BigQueryConstants.Type.RECORD) + .setMode(BigQueryConstants.Mode.REPEATED) + .setFields( + ImmutableList.of( + new TableFieldSchema() + .setName(Field.SCORE_FIELD) + .setType(BigQueryConstants.Type.FLOAT) + .setMode(BigQueryConstants.Mode.REQUIRED), + new TableFieldSchema() + .setName(Field.PIXEL_FRACTION) + .setType(BigQueryConstants.Type.FLOAT) + .setMode(BigQueryConstants.Mode.REQUIRED), + new TableFieldSchema() + .setName(Field.COLOR) + .setType(BigQueryConstants.Type.RECORD) + .setMode(BigQueryConstants.Mode.REQUIRED) + .setFields( + ImmutableList.of( + new TableFieldSchema() + .setName(Field.COLOR_RED) + .setType(BigQueryConstants.Type.FLOAT) + .setMode(BigQueryConstants.Mode.REQUIRED), + new TableFieldSchema() + .setName(Field.COLOR_BLUE) + .setType(BigQueryConstants.Type.FLOAT) + .setMode(BigQueryConstants.Mode.REQUIRED), + new TableFieldSchema() + .setName(Field.COLOR_GREEN) + .setType(BigQueryConstants.Type.FLOAT) + .setMode(BigQueryConstants.Mode.REQUIRED), + new TableFieldSchema() + .setName(Field.COLOR_ALPHA) + .setType(BigQueryConstants.Type.FLOAT) + .setMode(BigQueryConstants.Mode.NULLABLE)))))))); + fields.add( + new TableFieldSchema() + .setName(Field.TIMESTAMP_FIELD) + .setType(BigQueryConstants.Type.TIMESTAMP) + .setMode(BigQueryConstants.Mode.REQUIRED)); + ProcessorUtils.setMetadataFieldsSchema(fields, metadataKeys); + return new TableSchema().setFields(fields); + } + } + + @Override + public TableDetails destinationTableDetails() { + return TableDetails.create( + "Google Vision API Image Properties", + new Clustering().setFields(Collections.singletonList(Field.GCS_URI_FIELD)), + new TimePartitioning().setField(Field.TIMESTAMP_FIELD), + new SchemaProducer(metadataKeys)); + } + + @Override + public boolean shouldProcess(GeneratedMessageV3 response) { + return (response instanceof AnnotateImageResponse + && ((AnnotateImageResponse) response).getImagePropertiesAnnotation().hasDominantColors()); + } + + @Override + public ProcessorResult process(GCSFileInfo fileInfo, GeneratedMessageV3 r) { + AnnotateImageResponse response = (AnnotateImageResponse) r; + counter.inc(); + TableRow row = ProcessorUtils.startRow(fileInfo); + DominantColorsAnnotation dominantColors = + response.getImagePropertiesAnnotation().getDominantColors(); + List colors = new ArrayList<>(dominantColors.getColorsCount()); + dominantColors + .getColorsList() + .forEach( + colorInfo -> { + TableRow colorInfoRow = new TableRow(); + colorInfoRow.put(Field.SCORE_FIELD, colorInfo.getScore()); + colorInfoRow.put(Field.PIXEL_FRACTION, colorInfo.getPixelFraction()); + Color color = colorInfo.getColor(); + TableRow colorRow = new TableRow(); + colorRow.put(Field.COLOR_RED, color.getRed()); + colorRow.put(Field.COLOR_GREEN, color.getGreen()); + colorRow.put(Field.COLOR_BLUE, color.getBlue()); + if (color.hasAlpha()) { + colorRow.put(Field.COLOR_ALPHA, color.getAlpha()); + } + colorInfoRow.put(Field.COLOR, colorRow); + colors.add(colorInfoRow); + }); + TableRow colorsRow = new TableRow(); + colorsRow.put(Field.COLORS, colors); + row.put(Field.DOMINANT_COLORS, colorsRow); + + ProcessorUtils.addMetadataValues(row, fileInfo, metadataKeys); + LOG.debug("Processing {}", row); + + ProcessorResult result = new ProcessorResult(ProcessorResult.IMAGE_PROPERTIES, destination); + result.allRows.add(row); + + return result; + } +} diff --git a/src/main/java/com/google/solutions/annotation/ml/vision/processors/LabelAnnotationProcessor.java b/src/main/java/com/google/solutions/annotation/ml/vision/processors/LabelAnnotationProcessor.java new file mode 100644 index 0000000..c9082f6 --- /dev/null +++ b/src/main/java/com/google/solutions/annotation/ml/vision/processors/LabelAnnotationProcessor.java @@ -0,0 +1,148 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.solutions.annotation.ml.vision.processors; + +import com.google.api.services.bigquery.model.Clustering; +import com.google.api.services.bigquery.model.TableFieldSchema; +import com.google.api.services.bigquery.model.TableRow; +import com.google.api.services.bigquery.model.TableSchema; +import com.google.api.services.bigquery.model.TimePartitioning; +import com.google.cloud.vision.v1.AnnotateImageResponse; +import com.google.cloud.vision.v1.EntityAnnotation; +import com.google.protobuf.GeneratedMessageV3; +import com.google.solutions.annotation.bigquery.BigQueryConstants; +import com.google.solutions.annotation.bigquery.BigQueryDestination; +import com.google.solutions.annotation.bigquery.TableDetails; +import com.google.solutions.annotation.bigquery.TableSchemaProducer; +import com.google.solutions.annotation.gcs.GCSFileInfo; +import com.google.solutions.annotation.ml.Constants.Field; +import com.google.solutions.annotation.ml.MLApiResponseProcessor; +import com.google.solutions.annotation.ml.ProcessorUtils; +import java.util.*; +import org.apache.beam.sdk.metrics.Counter; +import org.apache.beam.sdk.metrics.Metrics; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** Extracts label annotations (https://cloud.google.com/vision/docs/labels) */ +public class LabelAnnotationProcessor implements MLApiResponseProcessor { + + private static final long serialVersionUID = 1L; + private final BigQueryDestination destination; + private final Set metadataKeys; + private final Set relevantLabels; + private final float scoreThreshold; + private static final Logger LOG = LoggerFactory.getLogger(LabelAnnotationProcessor.class); + + public static final Counter counter = + Metrics.counter(MLApiResponseProcessor.class, "numberOfLabelAnnotations"); + + /** Creates a processor and specifies the table id to persist to. */ + public LabelAnnotationProcessor( + String tableId, Set metadataKeys, Set relevantLabels, float scoreThreshold) { + this.destination = new BigQueryDestination(tableId); + this.metadataKeys = metadataKeys; + this.relevantLabels = relevantLabels; + this.scoreThreshold = scoreThreshold; + } + + private static class SchemaProducer implements TableSchemaProducer { + + private static final long serialVersionUID = 1L; + private final Set metadataKeys; + + SchemaProducer(Set metadataKeys) { + this.metadataKeys = metadataKeys; + } + + @Override + public TableSchema getTableSchema() { + ArrayList fields = new ArrayList<>(); + fields.add( + new TableFieldSchema() + .setName(Field.GCS_URI_FIELD) + .setType(BigQueryConstants.Type.STRING) + .setMode(BigQueryConstants.Mode.REQUIRED)); + fields.add( + new TableFieldSchema() + .setName(Field.MID_FIELD) + .setType(BigQueryConstants.Type.STRING) + .setMode(BigQueryConstants.Mode.NULLABLE)); + fields.add( + new TableFieldSchema() + .setName(Field.DESCRIPTION_FIELD) + .setType(BigQueryConstants.Type.STRING) + .setMode(BigQueryConstants.Mode.REQUIRED)); + fields.add( + new TableFieldSchema() + .setName(Field.SCORE_FIELD) + .setType(BigQueryConstants.Type.FLOAT) + .setMode(BigQueryConstants.Mode.REQUIRED)); + fields.add( + new TableFieldSchema() + .setName(Field.TOPICALITY_FIELD) + .setType(BigQueryConstants.Type.FLOAT) + .setMode(BigQueryConstants.Mode.REQUIRED)); + fields.add( + new TableFieldSchema() + .setName(Field.TIMESTAMP_FIELD) + .setType(BigQueryConstants.Type.TIMESTAMP) + .setMode(BigQueryConstants.Mode.REQUIRED)); + ProcessorUtils.setMetadataFieldsSchema(fields, metadataKeys); + return new TableSchema().setFields(fields); + } + } + + @Override + public TableDetails destinationTableDetails() { + return TableDetails.create( + "Google Vision API Label Annotations", + new Clustering().setFields(Collections.singletonList(Field.GCS_URI_FIELD)), + new TimePartitioning().setField(Field.TIMESTAMP_FIELD), + new SchemaProducer(metadataKeys)); + } + + @Override + public boolean shouldProcess(GeneratedMessageV3 response) { + return (response instanceof AnnotateImageResponse + && ((AnnotateImageResponse) response).getLabelAnnotationsCount() > 0); + } + + @Override + public ProcessorResult process(GCSFileInfo fileInfo, GeneratedMessageV3 r) { + AnnotateImageResponse response = (AnnotateImageResponse) r; + counter.inc(response.getLabelAnnotationsCount()); + ProcessorResult result = new ProcessorResult(ProcessorResult.IMAGE_LABEL, destination); + for (EntityAnnotation annotation : response.getLabelAnnotationsList()) { + TableRow row = ProcessorUtils.startRow(fileInfo); + row.put(Field.MID_FIELD, annotation.getMid()); + row.put(Field.DESCRIPTION_FIELD, annotation.getDescription()); + row.put(Field.SCORE_FIELD, annotation.getScore()); + row.put(Field.TOPICALITY_FIELD, annotation.getTopicality()); + ProcessorUtils.addMetadataValues(row, fileInfo, metadataKeys); + + LOG.debug("Processing {}", row); + result.allRows.add(row); + + if (relevantLabels != null + && relevantLabels.stream().anyMatch(annotation.getDescription()::equalsIgnoreCase) + && annotation.getScore() >= scoreThreshold) { + result.relevantRows.add(row); + } + } + return result; + } +} diff --git a/src/main/java/com/google/solutions/annotation/ml/vision/processors/LandmarkAnnotationProcessor.java b/src/main/java/com/google/solutions/annotation/ml/vision/processors/LandmarkAnnotationProcessor.java new file mode 100644 index 0000000..7107b87 --- /dev/null +++ b/src/main/java/com/google/solutions/annotation/ml/vision/processors/LandmarkAnnotationProcessor.java @@ -0,0 +1,176 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.solutions.annotation.ml.vision.processors; + +import com.google.api.services.bigquery.model.Clustering; +import com.google.api.services.bigquery.model.TableFieldSchema; +import com.google.api.services.bigquery.model.TableRow; +import com.google.api.services.bigquery.model.TableSchema; +import com.google.api.services.bigquery.model.TimePartitioning; +import com.google.cloud.vision.v1.AnnotateImageResponse; +import com.google.cloud.vision.v1.EntityAnnotation; +import com.google.protobuf.GeneratedMessageV3; +import com.google.solutions.annotation.bigquery.BigQueryConstants; +import com.google.solutions.annotation.bigquery.BigQueryDestination; +import com.google.solutions.annotation.bigquery.TableDetails; +import com.google.solutions.annotation.bigquery.TableSchemaProducer; +import com.google.solutions.annotation.gcs.GCSFileInfo; +import com.google.solutions.annotation.ml.Constants; +import com.google.solutions.annotation.ml.MLApiResponseProcessor; +import com.google.solutions.annotation.ml.ProcessorUtils; +import java.util.*; +import org.apache.beam.sdk.metrics.Counter; +import org.apache.beam.sdk.metrics.Metrics; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** Extracts landmark annotations (https://cloud.google.com/vision/docs/detecting-landmarks) */ +public class LandmarkAnnotationProcessor implements MLApiResponseProcessor { + + private static final long serialVersionUID = 1L; + + private static final Logger LOG = LoggerFactory.getLogger(LandmarkAnnotationProcessor.class); + public static final Counter counter = + Metrics.counter(MLApiResponseProcessor.class, "numberOfLandmarkAnnotations"); + + private final BigQueryDestination destination; + private final Set metadataKeys; + private final Set relevantLandmarks; + private final float scoreThreshold; + + /** Creates a processor and specifies the table id to persist to. */ + public LandmarkAnnotationProcessor( + String tableId, + Set metadataKeys, + Set relevantLandmarks, + float scoreThreshold) { + this.destination = new BigQueryDestination(tableId); + this.metadataKeys = metadataKeys; + this.relevantLandmarks = relevantLandmarks; + this.scoreThreshold = scoreThreshold; + } + + private static class SchemaProducer implements TableSchemaProducer { + + private static final long serialVersionUID = 1L; + private final Set metadataKeys; + + SchemaProducer(Set metadataKeys) { + this.metadataKeys = metadataKeys; + } + + @Override + public TableSchema getTableSchema() { + ArrayList fields = new ArrayList<>(); + fields.add( + new TableFieldSchema() + .setName(Constants.Field.GCS_URI_FIELD) + .setType(BigQueryConstants.Type.STRING) + .setMode(BigQueryConstants.Mode.REQUIRED)); + fields.add( + new TableFieldSchema() + .setName(Constants.Field.MID_FIELD) + .setType(BigQueryConstants.Type.STRING) + .setMode(BigQueryConstants.Mode.NULLABLE)); + fields.add( + new TableFieldSchema() + .setName(Constants.Field.DESCRIPTION_FIELD) + .setType(BigQueryConstants.Type.STRING) + .setMode(BigQueryConstants.Mode.REQUIRED)); + fields.add( + new TableFieldSchema() + .setName(Constants.Field.SCORE_FIELD) + .setType(BigQueryConstants.Type.FLOAT) + .setMode(BigQueryConstants.Mode.REQUIRED)); + fields.add( + new TableFieldSchema() + .setName(Constants.Field.BOUNDING_POLY) + .setType(BigQueryConstants.Type.RECORD) + .setMode(BigQueryConstants.Mode.NULLABLE) + .setFields(Constants.POLYGON_FIELDS)); + fields.add( + new TableFieldSchema() + .setName(Constants.Field.LOCATIONS) + .setType(BigQueryConstants.Type.GEOGRAPHY) + .setMode(BigQueryConstants.Mode.REPEATED)); + fields.add( + new TableFieldSchema() + .setName(Constants.Field.TIMESTAMP_FIELD) + .setType(BigQueryConstants.Type.TIMESTAMP) + .setMode(BigQueryConstants.Mode.REQUIRED)); + ProcessorUtils.setMetadataFieldsSchema(fields, metadataKeys); + return new TableSchema().setFields(fields); + } + } + + @Override + public TableDetails destinationTableDetails() { + return TableDetails.create( + "Google Vision API Landmark Annotations", + new Clustering().setFields(Collections.singletonList(Constants.Field.GCS_URI_FIELD)), + new TimePartitioning().setField(Constants.Field.TIMESTAMP_FIELD), + new SchemaProducer(metadataKeys)); + } + + @Override + public boolean shouldProcess(GeneratedMessageV3 response) { + return (response instanceof AnnotateImageResponse + && ((AnnotateImageResponse) response).getLandmarkAnnotationsCount() > 0); + } + + @Override + public ProcessorResult process(GCSFileInfo fileInfo, GeneratedMessageV3 r) { + AnnotateImageResponse response = (AnnotateImageResponse) r; + counter.inc(response.getLandmarkAnnotationsCount()); + ProcessorResult result = new ProcessorResult(ProcessorResult.IMAGE_LANDMARK, destination); + for (EntityAnnotation annotation : response.getLandmarkAnnotationsList()) { + TableRow row = ProcessorUtils.startRow(fileInfo); + row.put(Constants.Field.MID_FIELD, annotation.getMid()); + row.put(Constants.Field.DESCRIPTION_FIELD, annotation.getDescription()); + row.put(Constants.Field.SCORE_FIELD, annotation.getScore()); + + ProcessorUtils.extractBoundingPoly(annotation, row); + + if (annotation.getLocationsCount() > 0) { + List locations = new ArrayList<>(annotation.getLocationsCount()); + annotation + .getLocationsList() + .forEach( + location -> + locations.add( + "POINT(" + + location.getLatLng().getLongitude() + + " " + + location.getLatLng().getLatitude() + + ")")); + row.put(Constants.Field.LOCATIONS, locations); + } + + ProcessorUtils.addMetadataValues(row, fileInfo, metadataKeys); + + LOG.debug("Processing {}", row); + result.allRows.add(row); + + if (relevantLandmarks != null + && relevantLandmarks.stream().anyMatch(annotation.getDescription()::equalsIgnoreCase) + && annotation.getScore() >= scoreThreshold) { + result.relevantRows.add(row); + } + } + + return result; + } +} diff --git a/src/main/java/com/google/solutions/annotation/ml/vision/processors/LogoAnnotationProcessor.java b/src/main/java/com/google/solutions/annotation/ml/vision/processors/LogoAnnotationProcessor.java new file mode 100644 index 0000000..d5fe0dd --- /dev/null +++ b/src/main/java/com/google/solutions/annotation/ml/vision/processors/LogoAnnotationProcessor.java @@ -0,0 +1,152 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.solutions.annotation.ml.vision.processors; + +import com.google.api.services.bigquery.model.Clustering; +import com.google.api.services.bigquery.model.TableFieldSchema; +import com.google.api.services.bigquery.model.TableRow; +import com.google.api.services.bigquery.model.TableSchema; +import com.google.api.services.bigquery.model.TimePartitioning; +import com.google.cloud.vision.v1.AnnotateImageResponse; +import com.google.cloud.vision.v1.EntityAnnotation; +import com.google.protobuf.GeneratedMessageV3; +import com.google.solutions.annotation.bigquery.BigQueryConstants; +import com.google.solutions.annotation.bigquery.BigQueryDestination; +import com.google.solutions.annotation.bigquery.TableDetails; +import com.google.solutions.annotation.bigquery.TableSchemaProducer; +import com.google.solutions.annotation.gcs.GCSFileInfo; +import com.google.solutions.annotation.ml.Constants; +import com.google.solutions.annotation.ml.MLApiResponseProcessor; +import com.google.solutions.annotation.ml.ProcessorUtils; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Set; +import org.apache.beam.sdk.metrics.Counter; +import org.apache.beam.sdk.metrics.Metrics; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** Extracts logo annotations (https://cloud.google.com/vision/docs/detecting-logos) */ +public class LogoAnnotationProcessor implements MLApiResponseProcessor { + + private static final long serialVersionUID = 1L; + + public static final Counter counter = + Metrics.counter(MLApiResponseProcessor.class, "numberOfLogoAnnotations"); + public static final Logger LOG = LoggerFactory.getLogger(LogoAnnotationProcessor.class); + + private final BigQueryDestination destination; + private final Set metadataKeys; + private final Set relevantLogos; + private final float scoreThreshold; + + /** Creates a processor and specifies the table id to persist to. */ + public LogoAnnotationProcessor( + String tableId, Set metadataKeys, Set relevantLogos, float scoreThreshold) { + this.destination = new BigQueryDestination(tableId); + this.metadataKeys = metadataKeys; + this.relevantLogos = relevantLogos; + this.scoreThreshold = scoreThreshold; + } + + private static class SchemaProducer implements TableSchemaProducer { + + private static final long serialVersionUID = 1L; + private final Set metadataKeys; + + SchemaProducer(Set metadataKeys) { + this.metadataKeys = metadataKeys; + } + + @Override + public TableSchema getTableSchema() { + ArrayList fields = new ArrayList<>(); + fields.add( + new TableFieldSchema() + .setName(Constants.Field.GCS_URI_FIELD) + .setType(BigQueryConstants.Type.STRING) + .setMode(BigQueryConstants.Mode.REQUIRED)); + fields.add( + new TableFieldSchema() + .setName(Constants.Field.MID_FIELD) + .setType(BigQueryConstants.Type.STRING) + .setMode(BigQueryConstants.Mode.NULLABLE)); + fields.add( + new TableFieldSchema() + .setName(Constants.Field.DESCRIPTION_FIELD) + .setType(BigQueryConstants.Type.STRING) + .setMode(BigQueryConstants.Mode.REQUIRED)); + fields.add( + new TableFieldSchema() + .setName(Constants.Field.SCORE_FIELD) + .setType(BigQueryConstants.Type.FLOAT) + .setMode(BigQueryConstants.Mode.REQUIRED)); + fields.add( + new TableFieldSchema() + .setName(Constants.Field.BOUNDING_POLY) + .setType(BigQueryConstants.Type.RECORD) + .setMode(BigQueryConstants.Mode.NULLABLE) + .setFields(Constants.POLYGON_FIELDS)); + fields.add( + new TableFieldSchema() + .setName(Constants.Field.TIMESTAMP_FIELD) + .setType(BigQueryConstants.Type.TIMESTAMP) + .setMode(BigQueryConstants.Mode.REQUIRED)); + ProcessorUtils.setMetadataFieldsSchema(fields, metadataKeys); + return new TableSchema().setFields(fields); + } + } + + @Override + public TableDetails destinationTableDetails() { + return TableDetails.create( + "Google Vision API Logo Annotations", + new Clustering().setFields(Collections.singletonList(Constants.Field.GCS_URI_FIELD)), + new TimePartitioning().setField(Constants.Field.TIMESTAMP_FIELD), + new SchemaProducer(metadataKeys)); + } + + @Override + public boolean shouldProcess(GeneratedMessageV3 response) { + return (response instanceof AnnotateImageResponse + && ((AnnotateImageResponse) response).getLogoAnnotationsCount() > 0); + } + + @Override + public ProcessorResult process(GCSFileInfo fileInfo, GeneratedMessageV3 r) { + AnnotateImageResponse response = (AnnotateImageResponse) r; + counter.inc(response.getLogoAnnotationsCount()); + ProcessorResult result = new ProcessorResult(ProcessorResult.IMAGE_LOGO, destination); + for (EntityAnnotation annotation : response.getLabelAnnotationsList()) { + TableRow row = ProcessorUtils.startRow(fileInfo); + row.put(Constants.Field.MID_FIELD, annotation.getMid()); + row.put(Constants.Field.DESCRIPTION_FIELD, annotation.getDescription()); + row.put(Constants.Field.SCORE_FIELD, annotation.getScore()); + ProcessorUtils.extractBoundingPoly(annotation, row); + ProcessorUtils.addMetadataValues(row, fileInfo, metadataKeys); + + LOG.debug("Processing {}", row); + result.allRows.add(row); + + if (relevantLogos != null + && relevantLogos.stream().anyMatch(annotation.getDescription()::equalsIgnoreCase) + && annotation.getScore() >= scoreThreshold) { + result.relevantRows.add(row); + } + } + return result; + } +} diff --git a/src/main/java/com/google/solutions/ml/api/vision/PubSubNotificationToGCSUriDoFn.java b/src/main/java/com/google/solutions/annotation/pubsub/PubSubNotificationToGCSInfoDoFn.java similarity index 58% rename from src/main/java/com/google/solutions/ml/api/vision/PubSubNotificationToGCSUriDoFn.java rename to src/main/java/com/google/solutions/annotation/pubsub/PubSubNotificationToGCSInfoDoFn.java index 68ad489..5f35273 100644 --- a/src/main/java/com/google/solutions/ml/api/vision/PubSubNotificationToGCSUriDoFn.java +++ b/src/main/java/com/google/solutions/annotation/pubsub/PubSubNotificationToGCSInfoDoFn.java @@ -1,5 +1,5 @@ /* - * Copyright 2020 Google LLC + * Copyright 2022 Google LLC * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,13 +13,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +package com.google.solutions.annotation.pubsub; -package com.google.solutions.ml.api.vision; - +import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; import com.google.auto.value.AutoValue; +import com.google.solutions.annotation.AnnotationPipeline; +import com.google.solutions.annotation.gcs.GCSFileInfo; import java.io.IOException; +import java.util.HashMap; +import java.util.Map; import java.util.Objects; import java.util.Set; import org.apache.beam.sdk.extensions.gcp.util.gcsfs.GcsPath; @@ -28,21 +32,17 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -/** - * Converts PubSub notifications into GCS File URIs. - */ +/** Converts PubSub notifications into GCS File URIs. */ @AutoValue -public abstract class PubSubNotificationToGCSUriDoFn extends DoFn { +public abstract class PubSubNotificationToGCSInfoDoFn extends DoFn { private static final long serialVersionUID = 1L; - private static final Logger LOG = LoggerFactory.getLogger(PubSubNotificationToGCSUriDoFn.class); + private static final Logger LOG = LoggerFactory.getLogger(PubSubNotificationToGCSInfoDoFn.class); - abstract public Set supportedContentTypes(); + public abstract Set supportedContentTypes(); - public static PubSubNotificationToGCSUriDoFn create(Set supportedContentTypes) { - return builder() - .supportedContentTypes(supportedContentTypes) - .build(); + public static PubSubNotificationToGCSInfoDoFn create(Set supportedContentTypes) { + return builder().supportedContentTypes(supportedContentTypes).build(); } @ProcessElement @@ -50,10 +50,11 @@ public void processElement(ProcessContext c) { PubsubMessage message = c.element(); String eventType = message.getAttribute("eventType"); if (!Objects.equals(eventType, "OBJECT_FINALIZE")) { + // TODO: Output rejected messages to a queue (e.g. in a BigQuery table) LOG.warn("PubSub event type '{}' will not be processed", eventType); return; } - VisionAnalyticsPipeline.totalFiles.inc(); + AnnotationPipeline.totalFiles.inc(); String bucket = message.getAttribute("bucketId"); String object = message.getAttribute("objectId"); @@ -62,15 +63,15 @@ public void processElement(ProcessContext c) { String contentType = getContentType(message); - if (contentType != null && !supportedContentTypes().contains(contentType)) { - VisionAnalyticsPipeline.rejectedFiles.inc(); - LOG.warn("File {} is rejected - content type '{}' is not supported. " - + "Refer to https://cloud.google.com/vision/docs/supported-files for details.", - fileName, contentType); + if (contentType != null + && supportedContentTypes().stream().noneMatch(contentType::equalsIgnoreCase)) { + AnnotationPipeline.rejectedFiles.inc(); + // TODO: Output rejected files to a queue (e.g. in a BigQuery table) + LOG.warn("File {} is rejected - content type '{}' is not supported.", fileName, contentType); return; } - c.output(fileName); + c.output(new GCSFileInfo(fileName, contentType, getMetadata(message))); LOG.debug("GCS URI: {}", fileName); } @@ -91,8 +92,27 @@ private String getContentType(PubsubMessage message) { } } + /** + * Extract GCS object's metadata from PubSub payload + * + * @return metadata or null if none found. + */ + private Map getMetadata(PubsubMessage message) { + try { + ObjectMapper mapper = new ObjectMapper(); + JsonNode payloadJson = mapper.readTree(message.getPayload()); + JsonNode metadata = payloadJson.get("metadata"); + if (metadata != null) { + return mapper.convertValue(metadata, new TypeReference>() {}); + } + } catch (IOException e) { + LOG.warn("Failed to parse pubsub payload: ", e); + } + return new HashMap<>(); + } + public static Builder builder() { - return new AutoValue_PubSubNotificationToGCSUriDoFn.Builder(); + return new AutoValue_PubSubNotificationToGCSInfoDoFn.Builder(); } @AutoValue.Builder @@ -100,6 +120,6 @@ public abstract static class Builder { public abstract Builder supportedContentTypes(Set supportedContentTypes); - public abstract PubSubNotificationToGCSUriDoFn build(); + public abstract PubSubNotificationToGCSInfoDoFn build(); } } diff --git a/src/main/java/com/google/solutions/annotation/pubsub/WriteRelevantAnnotationsToPubSubTransform.java b/src/main/java/com/google/solutions/annotation/pubsub/WriteRelevantAnnotationsToPubSubTransform.java new file mode 100644 index 0000000..0c8392b --- /dev/null +++ b/src/main/java/com/google/solutions/annotation/pubsub/WriteRelevantAnnotationsToPubSubTransform.java @@ -0,0 +1,94 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.solutions.annotation.pubsub; + +import com.google.api.services.bigquery.model.TableRow; +import com.google.auto.value.AutoValue; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.Objects; +import org.apache.beam.sdk.io.gcp.bigquery.TableRowJsonCoder; +import org.apache.beam.sdk.io.gcp.pubsub.PubsubIO; +import org.apache.beam.sdk.io.gcp.pubsub.PubsubMessage; +import org.apache.beam.sdk.transforms.DoFn; +import org.apache.beam.sdk.transforms.PTransform; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.values.KV; +import org.apache.beam.sdk.values.PCollection; +import org.apache.beam.sdk.values.PDone; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +@AutoValue +public abstract class WriteRelevantAnnotationsToPubSubTransform + extends PTransform>, PDone> { + + private static final long serialVersionUID = 1L; + + private static final Logger LOG = + LoggerFactory.getLogger(WriteRelevantAnnotationsToPubSubTransform.class); + + public abstract String topicId(); + + @AutoValue.Builder + public abstract static class Builder { + public abstract Builder setTopicId(String topic); + + public abstract WriteRelevantAnnotationsToPubSubTransform build(); + } + + public static Builder newBuilder() { + return new AutoValue_WriteRelevantAnnotationsToPubSubTransform.Builder(); + } + + @Override + public PDone expand(PCollection> input) { + + return input + .apply( + "ConvertToJSON", + ParDo.of( + new DoFn, String>() { + @ProcessElement + public void processContext(ProcessContext c) throws IOException { + String type = Objects.requireNonNull(c.element()).getKey(); + TableRow row = Objects.requireNonNull(c.element()).getValue(); + ByteArrayOutputStream jsonStream = new ByteArrayOutputStream(); + TableRowJsonCoder.of().encode(row, jsonStream); + String json = jsonStream.toString(StandardCharsets.UTF_8.name()); + + // TODO: Figure out why the json string has a leading "?" character + // Remove leading "?" character + json = json.substring(1); + + c.output(String.format("{\"type\": \"%s\", \"annotation\": %s}", type, json)); + } + })) + .apply( + "ConvertToPubSubMessage", + ParDo.of( + new DoFn() { + @ProcessElement + public void processContext(ProcessContext c) { + LOG.info("Json {}", c.element()); + c.output( + new PubsubMessage(Objects.requireNonNull(c.element()).getBytes(), null)); + } + })) + .apply("PublishToPubSub", PubsubIO.writeMessages().to(topicId())); + } +} diff --git a/src/main/java/com/google/solutions/ml/api/vision/AnnotateImagesDoFn.java b/src/main/java/com/google/solutions/ml/api/vision/AnnotateImagesDoFn.java deleted file mode 100644 index e7851ea..0000000 --- a/src/main/java/com/google/solutions/ml/api/vision/AnnotateImagesDoFn.java +++ /dev/null @@ -1,155 +0,0 @@ -/* - * Copyright 2020 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.google.solutions.ml.api.vision; - -import com.google.api.client.util.ExponentialBackOff; -import com.google.api.gax.rpc.ResourceExhaustedException; -import com.google.cloud.vision.v1.AnnotateImageRequest; -import com.google.cloud.vision.v1.AnnotateImageResponse; -import com.google.cloud.vision.v1.Feature; -import com.google.cloud.vision.v1.Image; -import com.google.cloud.vision.v1.ImageAnnotatorClient; -import com.google.cloud.vision.v1.ImageSource; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; -import java.util.concurrent.TimeUnit; -import org.apache.beam.sdk.transforms.DoFn; -import org.apache.beam.sdk.values.KV; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * Calls Google Cloud Vision API to annotate a batch of GCS files. - * - * The GCS file URIs are provided in the incoming PCollection and should not exceed the limit - * imposed by the API (maximum of 16 images per request). - * - * The resulting PCollection contains key/value pair with the GCS file URI as the key and the API - * response as the value. - */ -public class AnnotateImagesDoFn extends DoFn, KV> { - - private static final long serialVersionUID = 1L; - - public static final Logger LOG = LoggerFactory.getLogger(AnnotateImagesDoFn.class); - - private final List featureList = new ArrayList<>(); - private ImageAnnotatorClient visionApiClient; - - public AnnotateImagesDoFn(List featureTypes) { - featureTypes.forEach( - type -> featureList.add(Feature.newBuilder().setType(type).build())); - } - - @Setup - public void setupAPIClient() { - try { - visionApiClient = ImageAnnotatorClient.create(); - } catch (IOException e) { - LOG.error("Failed to create Vision API Service Client: {}", e.getMessage()); - throw new RuntimeException(e); - } - } - - @Teardown - public void tearDownAPIClient() { - if (visionApiClient != null) { - visionApiClient.shutdownNow(); - try { - int waitTime = 10; - if (!visionApiClient.awaitTermination(waitTime, TimeUnit.SECONDS)) { - LOG.warn( - "Failed to shutdown the annotation client after {} seconds. Closing client anyway.", - waitTime); - } - } catch (InterruptedException e) { - // Do nothing - } - visionApiClient.close(); - } - } - - @ProcessElement - public void processElement(@Element Iterable imageFileURIs, - OutputReceiver> out) { - List requests = new ArrayList<>(); - - imageFileURIs.forEach( - imageUri -> { - Image image = - Image.newBuilder() - .setSource(ImageSource.newBuilder().setImageUri(imageUri).build()) - .build(); - AnnotateImageRequest.Builder request = - AnnotateImageRequest.newBuilder().setImage(image).addAllFeatures(featureList); - requests.add(request.build()); - }); - - List responses; - - ExponentialBackOff backoff = new ExponentialBackOff.Builder() - .setInitialIntervalMillis(10 * 1000 /* 10 seconds */) - .setMaxElapsedTimeMillis(10 * 60 * 1000 /* 10 minutes */) - .setMaxIntervalMillis(90 * 1000 /* 90 seconds */) - .setMultiplier(1.5) - .setRandomizationFactor(0.5) - .build(); - while (true) { - try { - VisionAnalyticsPipeline.numberOfRequests.inc(); - responses = visionApiClient.batchAnnotateImages(requests).getResponsesList(); - break; - } catch (ResourceExhaustedException e) { - handleQuotaReachedException(backoff, e); - } - } - - int index = 0; - for (AnnotateImageResponse response : responses) { - String imageUri = requests.get(index++).getImage().getSource().getImageUri(); - out.output(KV.of(imageUri, response)); - } - } - - - /** - * Attempts to backoff unless reaches the max elapsed time. - * - * @param backoff - * @param e - */ - void handleQuotaReachedException(ExponentialBackOff backoff, ResourceExhaustedException e) { - VisionAnalyticsPipeline.numberOfQuotaExceededRequests.inc(); - long waitInMillis = 0; - try { - waitInMillis = backoff.nextBackOffMillis(); - } catch (IOException ioException) { - // Will not occur with this implementation of Backoff. - } - if (waitInMillis == ExponentialBackOff.STOP) { - LOG.warn("Reached the limit of backoff retries. Throwing the exception to the pipeline"); - throw e; - } - LOG.info("Received {}. Will retry in {} seconds.", e.getClass().getName(), - waitInMillis / 1000); - try { - TimeUnit.MILLISECONDS.sleep(waitInMillis); - } catch (InterruptedException interruptedException) { - // Do nothing - } - } -} diff --git a/src/main/java/com/google/solutions/ml/api/vision/ProcessImageResponseDoFn.java b/src/main/java/com/google/solutions/ml/api/vision/ProcessImageResponseDoFn.java deleted file mode 100644 index 01edd09..0000000 --- a/src/main/java/com/google/solutions/ml/api/vision/ProcessImageResponseDoFn.java +++ /dev/null @@ -1,86 +0,0 @@ -/* - * Copyright 2020 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.google.solutions.ml.api.vision; - -import com.google.api.services.bigquery.model.TableRow; -import com.google.auto.value.AutoValue; -import com.google.cloud.vision.v1.AnnotateImageResponse; -import com.google.solutions.ml.api.vision.processor.AnnotateImageResponseProcessor; -import java.util.Collection; -import org.apache.beam.sdk.metrics.Counter; -import org.apache.beam.sdk.metrics.Metrics; -import org.apache.beam.sdk.transforms.DoFn; -import org.apache.beam.sdk.values.KV; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * ProcessImageResponse {@link ProcessImageResponseDoFn} class parses the image response for - * specific annotation and using image response builder output the table and table row for BigQuery - */ -@AutoValue -abstract public class ProcessImageResponseDoFn - extends DoFn, KV> { - - private static final long serialVersionUID = 1L; - private static final Logger LOG = LoggerFactory.getLogger(ProcessImageResponseDoFn.class); - - abstract Collection processors(); - - abstract Counter processedFileCounter(); - - public static ProcessImageResponseDoFn create( - Collection processors) { - return builder() - .processors(processors) - .processedFileCounter(Metrics - .counter(ProcessImageResponseDoFn.class, "processedFiles")) - .build(); - } - - @ProcessElement - public void processElement(@Element KV element, - OutputReceiver> out) { - String imageFileURI = element.getKey(); - AnnotateImageResponse annotationResponse = element.getValue(); - - LOG.debug("Processing annotations for file: {}", imageFileURI); - processedFileCounter().inc(); - - processors().forEach(processor -> { - Iterable> processingResult = processor - .process(imageFileURI, annotationResponse); - if (processingResult != null) { - processingResult.forEach(out::output); - } - }); - } - - public static Builder builder() { - return new AutoValue_ProcessImageResponseDoFn.Builder(); - } - - - @AutoValue.Builder - public abstract static class Builder { - - public abstract Builder processors(Collection processors); - - public abstract Builder processedFileCounter(Counter processedFileCounter); - - public abstract ProcessImageResponseDoFn build(); - } -} diff --git a/src/main/java/com/google/solutions/ml/api/vision/VisionAnalyticsPipeline.java b/src/main/java/com/google/solutions/ml/api/vision/VisionAnalyticsPipeline.java deleted file mode 100644 index e14c93f..0000000 --- a/src/main/java/com/google/solutions/ml/api/vision/VisionAnalyticsPipeline.java +++ /dev/null @@ -1,326 +0,0 @@ -/* - * Copyright 2020 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.google.solutions.ml.api.vision; - - -import com.google.api.services.bigquery.model.TableFieldSchema; -import com.google.api.services.bigquery.model.TableReference; -import com.google.api.services.bigquery.model.TableRow; -import com.google.api.services.bigquery.model.TableSchema; -import com.google.cloud.vision.v1.AnnotateImageResponse; -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableSet; -import com.google.common.collect.Iterables; -import com.google.solutions.ml.api.vision.BigQueryConstants.Mode; -import com.google.solutions.ml.api.vision.BigQueryConstants.Type; -import com.google.solutions.ml.api.vision.processor.AnnotateImageResponseProcessor; -import com.google.solutions.ml.api.vision.processor.CropHintAnnotationProcessor; -import com.google.solutions.ml.api.vision.processor.ErrorProcessor; -import com.google.solutions.ml.api.vision.processor.FaceAnnotationProcessor; -import com.google.solutions.ml.api.vision.processor.ImagePropertiesProcessor; -import com.google.solutions.ml.api.vision.processor.LabelAnnotationProcessor; -import com.google.solutions.ml.api.vision.processor.LandmarkAnnotationProcessor; -import com.google.solutions.ml.api.vision.processor.LogoAnnotationProcessor; -import com.google.solutions.ml.api.vision.processor.ProcessorUtils; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Set; -import org.apache.beam.sdk.Pipeline; -import org.apache.beam.sdk.PipelineResult; -import org.apache.beam.sdk.io.FileIO; -import org.apache.beam.sdk.io.fs.MatchResult.Metadata; -import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO; -import org.apache.beam.sdk.io.gcp.bigquery.BigQueryIO.Write.WriteDisposition; -import org.apache.beam.sdk.io.gcp.pubsub.PubsubIO; -import org.apache.beam.sdk.io.gcp.pubsub.PubsubMessage; -import org.apache.beam.sdk.metrics.Counter; -import org.apache.beam.sdk.metrics.Distribution; -import org.apache.beam.sdk.metrics.Metrics; -import org.apache.beam.sdk.options.PipelineOptionsFactory; -import org.apache.beam.sdk.transforms.Create; -import org.apache.beam.sdk.transforms.DoFn; -import org.apache.beam.sdk.transforms.Filter; -import org.apache.beam.sdk.transforms.ParDo; -import org.apache.beam.sdk.transforms.SerializableFunction; -import org.apache.beam.sdk.transforms.windowing.AfterWatermark; -import org.apache.beam.sdk.transforms.windowing.BoundedWindow; -import org.apache.beam.sdk.transforms.windowing.FixedWindows; -import org.apache.beam.sdk.transforms.windowing.Window; -import org.apache.beam.sdk.values.KV; -import org.apache.beam.sdk.values.PCollection; -import org.joda.time.Duration; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * Main class for the vision analytics processing. - */ -public class VisionAnalyticsPipeline { - - public static final Logger LOG = LoggerFactory.getLogger(VisionAnalyticsPipeline.class); - - public static final Counter totalFiles = Metrics - .counter(VisionAnalyticsPipeline.class, "totalFiles"); - public static final Counter rejectedFiles = Metrics - .counter(VisionAnalyticsPipeline.class, "rejectedFiles"); - public static final Counter numberOfRequests = Metrics - .counter(VisionAnalyticsPipeline.class, "numberOfRequests"); - public static final Counter numberOfQuotaExceededRequests = Metrics - .counter(VisionAnalyticsPipeline.class, "numberOfQuotaExceededRequests"); - - public static final Distribution batchSizeDistribution = Metrics - .distribution(VisionAnalyticsPipeline.class, "batchSizeDistribution"); - - - private static final Set SUPPORTED_CONTENT_TYPES = ImmutableSet.of( - "image/jpeg", "image/png", "image/tiff", "image/tif", "image/gif" - ); - - private static final String ACCEPTED_FILE_PATTERN = "(^.*\\.(JPEG|jpeg|JPG|jpg|PNG|png|GIF|gif|TIFF|tiff|TIF|tif)$)"; - - /** - * Main entry point for executing the pipeline. This will run the pipeline asynchronously. If - * blocking execution is required, use the {@link VisionAnalyticsPipeline#run(VisionAnalyticsPipelineOptions)} - * method to start the pipeline and invoke {@code result.waitUntilFinish()} on the {@link - * PipelineResult} - * - * @param args The command-line arguments to the pipeline. - */ - public static void main(String[] args) { - - VisionAnalyticsPipelineOptions options = - PipelineOptionsFactory.fromArgs(args) - .withValidation() - .as(VisionAnalyticsPipelineOptions.class); - - run(options); - } - - /** - * Runs the pipeline - * - * @return result - */ - public static PipelineResult run(VisionAnalyticsPipelineOptions options) { - Pipeline p = Pipeline.create(options); - - PCollection imageFileUris; - if (options.getSubscriberId() != null) { - imageFileUris = convertPubSubNotificationsToGCSURIs(p, options); - } else if (options.getFileList() != null) { - imageFileUris = listGCSFiles(p, options); - } else { - throw new RuntimeException("Either the subscriber id or the file list should be provided."); - } - - PCollection> batchedImageURIs = imageFileUris - .apply("Batch images", - BatchRequestsTransform.create(options.getBatchSize(), options.getKeyRange())); - - PCollection> annotatedImages = - options.isSimulate() ? - batchedImageURIs.apply("Simulate Annotation", - ParDo.of(new AnnotateImagesSimulatorDoFn(options.getFeatures()))) : - batchedImageURIs.apply( - "Annotate Images", - ParDo.of(new AnnotateImagesDoFn(options.getFeatures()))); - - Map processors = configureProcessors(options); - - PCollection> annotationOutcome = - annotatedImages.apply( - "Process Annotations", - ParDo.of(ProcessImageResponseDoFn.create(ImmutableSet.copyOf(processors.values())))); - - annotationOutcome.apply("Write To BigQuery", new BigQueryDynamicWriteTransform( - BQDynamicDestinations.builder() - .projectId(options.getVisionApiProjectId()) - .datasetId(options.getDatasetName()) - .tableNameToTableDetailsMap( - tableNameToTableDetailsMap(processors)).build()) - ); - - collectBatchStatistics(batchedImageURIs, options); - - return p.run(); - } - - /** - * Collect the statistics on batching the requests. The results are published to a metric. If - * {@link VisionAnalyticsPipelineOptions#isCollectBatchData()} is true the batch data is saved to - * BigQuery table "batch_info". - */ - static void collectBatchStatistics(PCollection> batchedImageURIs, - VisionAnalyticsPipelineOptions options) { - - PCollection batchInfo = batchedImageURIs - .apply("Collect Batch Stats", ParDo.of(new DoFn, TableRow>() { - private static final long serialVersionUID = 1L; - - @ProcessElement - public void processElement(@Element Iterable element, BoundedWindow window, - OutputReceiver out, ProcessContext context) { - int size = Iterables.size(element); - batchSizeDistribution.update(size); - if (context.getPipelineOptions().as(VisionAnalyticsPipelineOptions.class) - .isCollectBatchData()) { - TableRow row = new TableRow(); - row.put("window", window.toString()); - row.put("timestamp", ProcessorUtils.getTimeStamp()); - row.put("size", size); - List items = new ArrayList<>(); - element.forEach(items::add); - row.put("items", items); - - out.output(row); - } - } - })); - if (!options.isCollectBatchData()) { - return; - } - batchInfo.apply( - BigQueryIO.writeTableRows() - .to(new TableReference().setProjectId(options.getVisionApiProjectId()) - .setDatasetId(options.getDatasetName()).setTableId("batch_info")) - .withWriteDisposition(WriteDisposition.WRITE_APPEND) - .withoutValidation() - .withClustering() - .ignoreInsertIds() - .withSchema(new TableSchema().setFields(ImmutableList.of( - new TableFieldSchema().setName("window").setType(Type.STRING), - new TableFieldSchema().setName("timestamp").setType(Type.TIMESTAMP), - new TableFieldSchema().setName("size").setType(Type.NUMERIC), - new TableFieldSchema().setName("items").setType(Type.STRING).setMode(Mode.REPEATED) - ))) - .withCreateDisposition(BigQueryIO.Write.CreateDisposition.CREATE_IF_NEEDED)); - } - - /** - * Create a map of the table details. Each processor will produce TableRows destined - * to a different table. Each processor will provide the details about that table. - * - * @return map of table details keyed by table name - */ - static Map tableNameToTableDetailsMap( - Map processors) { - Map tableNameToTableDetailsMap = new HashMap<>(); - processors.forEach( - (tableName, processor) -> tableNameToTableDetailsMap - .put(tableName, processor.destinationTableDetails())); - return tableNameToTableDetailsMap; - } - - /** - * Reads PubSub messages from the subscription provided by {@link VisionAnalyticsPipelineOptions#getSubscriberId()}. - * - * The messages are expected to confirm to the GCS notification message format defined in - * https://cloud.google.com/storage/docs/pubsub-notifications - * - * Notifications are filtered to have one of the supported content types: {@link - * VisionAnalyticsPipeline#SUPPORTED_CONTENT_TYPES}. - * - * @return PCollection of GCS URIs - */ - static PCollection convertPubSubNotificationsToGCSURIs( - Pipeline p, VisionAnalyticsPipelineOptions options) { - PCollection imageFileUris; - PCollection pubSubNotifications = p.begin().apply("Read PubSub", - PubsubIO.readMessagesWithAttributes().fromSubscription(options.getSubscriberId())); - imageFileUris = pubSubNotifications - .apply("PubSub to GCS URIs", - ParDo.of(PubSubNotificationToGCSUriDoFn.create(SUPPORTED_CONTENT_TYPES))) - .apply( - "Fixed Window", - Window.into( - FixedWindows.of(Duration.standardSeconds(options.getWindowInterval()))) - .triggering(AfterWatermark.pastEndOfWindow()) - .discardingFiredPanes() - .withAllowedLateness(Duration.standardMinutes(15))); - return imageFileUris; - } - - /** - * Reads the GCS buckets provided by {@link VisionAnalyticsPipelineOptions#getFileList()}. - * - * The file list can contain multiple entries. Each entry can contain wildcards supported by - * {@link FileIO#matchAll()}. - * - * Files are filtered based on their suffixes as defined in {@link VisionAnalyticsPipeline#ACCEPTED_FILE_PATTERN}. - * - * @return PCollection of GCS URIs - */ - static PCollection listGCSFiles(Pipeline p, VisionAnalyticsPipelineOptions options) { - PCollection imageFileUris; - PCollection allFiles = p.begin() - .apply("Get File List", Create.of(options.getFileList())) - .apply("Match GCS Files", FileIO.matchAll()); - imageFileUris = allFiles.apply(ParDo.of(new DoFn() { - private static final long serialVersionUID = 1L; - - @ProcessElement - public void processElement(@Element Metadata metadata, OutputReceiver out) { - out.output(metadata.resourceId().toString()); - } - })) - .apply("Filter out non-image files", - Filter.by((SerializableFunction) fileName -> { - totalFiles.inc(); - if (fileName.matches(ACCEPTED_FILE_PATTERN)) { - return true; - } - LOG.warn("File {} does not contain a valid extension", fileName); - rejectedFiles.inc(); - return false; - })); - return imageFileUris; - } - - /** - * Creates a map of well-known {@link AnnotateImageResponseProcessor}s. - * - * If additional processors are needed they should be configured in this method. - */ - private static Map configureProcessors( - VisionAnalyticsPipelineOptions options) { - Map result = new HashMap<>(); - - String tableName = options.getLabelAnnotationTable(); - result.put(tableName, new LabelAnnotationProcessor(tableName)); - - tableName = options.getLandmarkAnnotationTable(); - result.put(tableName, new LandmarkAnnotationProcessor(tableName)); - - tableName = options.getLogoAnnotationTable(); - result.put(tableName, new LogoAnnotationProcessor(tableName)); - - tableName = options.getFaceAnnotationTable(); - result.put(tableName, new FaceAnnotationProcessor(tableName)); - - tableName = options.getImagePropertiesTable(); - result.put(tableName, new ImagePropertiesProcessor(tableName)); - - tableName = options.getCropHintAnnotationTable(); - result.put(tableName, new CropHintAnnotationProcessor(tableName)); - - tableName = options.getErrorLogTable(); - result.put(tableName, new ErrorProcessor(tableName)); - - return result; - } -} diff --git a/src/main/java/com/google/solutions/ml/api/vision/VisionAnalyticsPipelineOptions.java b/src/main/java/com/google/solutions/ml/api/vision/VisionAnalyticsPipelineOptions.java deleted file mode 100644 index dcfee64..0000000 --- a/src/main/java/com/google/solutions/ml/api/vision/VisionAnalyticsPipelineOptions.java +++ /dev/null @@ -1,129 +0,0 @@ -/* - * Copyright 2020 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package com.google.solutions.ml.api.vision; - -import com.google.cloud.vision.v1.Feature; -import java.util.List; -import org.apache.beam.runners.dataflow.options.DataflowPipelineOptions; -import org.apache.beam.sdk.options.Default; -import org.apache.beam.sdk.options.Description; -import org.apache.beam.sdk.options.Validation; - -/** - * Interface to store pipeline options provided by the user - */ -public interface VisionAnalyticsPipelineOptions extends DataflowPipelineOptions { - - @Description("Pub/Sub subscription to receive messages from") - String getSubscriberId(); - - void setSubscriberId(String value); - - @Description("Google Cloud Storage files to process") - List getFileList(); - - void setFileList(List value); - - @Description("Key range") - @Default.Integer(1) - Integer getKeyRange(); - - void setKeyRange(Integer value); - - @Description("Image annotation request batch size") - @Default.Integer(1) - Integer getBatchSize(); - - void setBatchSize(Integer value); - - @Description("Window interval in seconds (default is 5)") - @Default.Integer(5) - Integer getWindowInterval(); - - void setWindowInterval(Integer value); - - @Description("BigQuery dataset") - @Validation.Required - String getDatasetName(); - - void setDatasetName(String value); - - @Description("Project id to be used for Vision API requests and BigQuery dataset") - @Validation.Required - String getVisionApiProjectId(); - - void setVisionApiProjectId(String value); - - @Description("Vision API features to use") - @Validation.Required - List getFeatures(); - - void setFeatures(List value); - - @Description("Simulate annotations") - @Default.Boolean(false) - boolean isSimulate(); - - void setSimulate(boolean value); - - @Description("Collect batch data") - @Default.Boolean(false) - boolean isCollectBatchData(); - - void setCollectBatchData(boolean value); - - @Description("Table name for label annotations") - @Default.String("label_annotation") - String getLabelAnnotationTable(); - - void setLabelAnnotationTable(String value); - - @Description("Table name for landmark annotations") - @Default.String("landmark_annotation") - String getLandmarkAnnotationTable(); - - void setLandmarkAnnotationTable(String value); - - @Description("Table name for logo annotations") - @Default.String("logo_annotation") - String getLogoAnnotationTable(); - - void setLogoAnnotationTable(String value); - - @Description("Table name for face annotations") - @Default.String("face_annotation") - String getFaceAnnotationTable(); - - void setFaceAnnotationTable(String value); - - @Description("Table name for image properties") - @Default.String("image_properties") - String getImagePropertiesTable(); - - void setImagePropertiesTable(String value); - - @Description("Table name for crop hint annotations") - @Default.String("crop_hint_annotation") - String getCropHintAnnotationTable(); - - void setCropHintAnnotationTable(String value); - - @Description("Table name for error logs") - @Default.String("error_log") - String getErrorLogTable(); - - void setErrorLogTable(String value); -} diff --git a/src/main/java/com/google/solutions/ml/api/vision/processor/AnnotateImageResponseProcessor.java b/src/main/java/com/google/solutions/ml/api/vision/processor/AnnotateImageResponseProcessor.java deleted file mode 100644 index 493baeb..0000000 --- a/src/main/java/com/google/solutions/ml/api/vision/processor/AnnotateImageResponseProcessor.java +++ /dev/null @@ -1,43 +0,0 @@ -/* - * Copyright 2020 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.google.solutions.ml.api.vision.processor; - -import com.google.api.services.bigquery.model.TableRow; -import com.google.cloud.vision.v1.AnnotateImageResponse; -import com.google.solutions.ml.api.vision.BQDestination; -import com.google.solutions.ml.api.vision.TableDetails; -import java.io.Serializable; -import org.apache.beam.sdk.values.KV; - -/** - * Implementors of this interface will process zero to many TableRows to persist to a specific - * BigTable table. - */ -public interface AnnotateImageResponseProcessor extends Serializable { - - /** - * @param gcsURI annotation source - * @param response from Google Cloud Vision API - * @return key/value pair of a BigQuery destination and a TableRow to persist. - */ - Iterable> process(String gcsURI, AnnotateImageResponse response); - - /** - * @return details of the table to persist to. - */ - TableDetails destinationTableDetails(); -} \ No newline at end of file diff --git a/src/main/java/com/google/solutions/ml/api/vision/processor/CropHintAnnotationProcessor.java b/src/main/java/com/google/solutions/ml/api/vision/processor/CropHintAnnotationProcessor.java deleted file mode 100644 index e873ebf..0000000 --- a/src/main/java/com/google/solutions/ml/api/vision/processor/CropHintAnnotationProcessor.java +++ /dev/null @@ -1,135 +0,0 @@ -/* - * Copyright 2020 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.google.solutions.ml.api.vision.processor; - -import com.google.api.services.bigquery.model.Clustering; -import com.google.api.services.bigquery.model.TableFieldSchema; -import com.google.api.services.bigquery.model.TableRow; -import com.google.api.services.bigquery.model.TableSchema; -import com.google.api.services.bigquery.model.TimePartitioning; -import com.google.cloud.vision.v1.AnnotateImageResponse; -import com.google.cloud.vision.v1.CropHint; -import com.google.cloud.vision.v1.CropHintsAnnotation; -import com.google.common.collect.ImmutableList; -import com.google.solutions.ml.api.vision.BQDestination; -import com.google.solutions.ml.api.vision.BigQueryConstants.Mode; -import com.google.solutions.ml.api.vision.BigQueryConstants.Type; -import com.google.solutions.ml.api.vision.TableDetails; -import com.google.solutions.ml.api.vision.TableSchemaProducer; -import com.google.solutions.ml.api.vision.processor.Constants.Field; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import org.apache.beam.sdk.metrics.Counter; -import org.apache.beam.sdk.metrics.Metrics; -import org.apache.beam.sdk.values.KV; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * Extracts crop hint annotations (https://cloud.google.com/vision/docs/detecting-crop-hints) - * - * Note: requests for either CROP_HINT feature or IMAGE_PROPERTIES feature will produce crop hints - */ -public class CropHintAnnotationProcessor implements AnnotateImageResponseProcessor { - - private static final long serialVersionUID = 1L; - - private static final Logger LOG = LoggerFactory.getLogger(CropHintAnnotationProcessor.class); - public final static Counter counter = - Metrics.counter(AnnotateImageResponseProcessor.class, "numberOfCropHintAnnotations"); - - private static class SchemaProducer implements TableSchemaProducer { - - private static final long serialVersionUID = 1L; - - @Override - public TableSchema getTableSchema() { - return new TableSchema().setFields( - ImmutableList.of( - new TableFieldSchema() - .setName(Field.GCS_URI_FIELD) - .setType(Type.STRING) - .setMode(Mode.REQUIRED), - new TableFieldSchema() - .setName(Field.CROP_HINTS).setType(Type.RECORD) - .setMode(Mode.REPEATED) - .setFields(ImmutableList.of( - new TableFieldSchema() - .setName(Field.CONFIDENCE).setType(Type.FLOAT) - .setMode(Mode.REQUIRED), - new TableFieldSchema() - .setName(Field.IMPORTANCE_FRACTION).setType(Type.FLOAT) - .setMode(Mode.REQUIRED), - new TableFieldSchema() - .setName(Field.BOUNDING_POLY).setType(Type.RECORD) - .setMode(Mode.REQUIRED).setFields(Constants.POLYGON_FIELDS) - )), - new TableFieldSchema() - .setName(Field.TIMESTAMP_FIELD).setType(Type.TIMESTAMP) - .setMode(Mode.REQUIRED)) - ); - } - } - - @Override - public TableDetails destinationTableDetails() { - return TableDetails.create("Google Vision API Crop Hint Annotations", - new Clustering().setFields(Collections.singletonList(Field.GCS_URI_FIELD)), - new TimePartitioning().setField(Field.TIMESTAMP_FIELD), new SchemaProducer()); - } - - private final BQDestination destination; - - /** - * Creates a processor and specifies the table id to persist to. - */ - public CropHintAnnotationProcessor(String tableId) { - destination = new BQDestination(tableId); - } - - @Override - public Iterable> process( - String gcsURI, AnnotateImageResponse response) { - CropHintsAnnotation cropHintsAnnotation = response.getCropHintsAnnotation(); - if (cropHintsAnnotation == null) { - return null; - } - int cropHintsCount = cropHintsAnnotation.getCropHintsCount(); - if (cropHintsCount == 0) { - return null; - } - - counter.inc(); - - List cropHintRows = new ArrayList<>(cropHintsCount); - for (CropHint cropHint : cropHintsAnnotation.getCropHintsList()) { - TableRow cropHintRow = new TableRow(); - cropHintRow.put(Field.BOUNDING_POLY, - ProcessorUtils.getBoundingPolyAsRow(cropHint.getBoundingPoly())); - cropHintRow.put(Field.CONFIDENCE, cropHint.getConfidence()); - cropHintRow.put(Field.IMPORTANCE_FRACTION, cropHint.getImportanceFraction()); - - cropHintRows.add(cropHintRow); - } - - TableRow result = ProcessorUtils.startRow(gcsURI); - result.put(Field.CROP_HINTS, cropHintRows); - LOG.debug("Processing {}", result); - return Collections.singletonList((KV.of(destination, result))); - } -} diff --git a/src/main/java/com/google/solutions/ml/api/vision/processor/ErrorProcessor.java b/src/main/java/com/google/solutions/ml/api/vision/processor/ErrorProcessor.java deleted file mode 100644 index c3ba474..0000000 --- a/src/main/java/com/google/solutions/ml/api/vision/processor/ErrorProcessor.java +++ /dev/null @@ -1,108 +0,0 @@ -/* - * Copyright 2020 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.google.solutions.ml.api.vision.processor; - -import com.google.api.services.bigquery.model.Clustering; -import com.google.api.services.bigquery.model.TableFieldSchema; -import com.google.api.services.bigquery.model.TableRow; -import com.google.api.services.bigquery.model.TableSchema; -import com.google.api.services.bigquery.model.TimePartitioning; -import com.google.cloud.vision.v1.AnnotateImageResponse; -import com.google.common.collect.ImmutableList; -import com.google.solutions.ml.api.vision.BQDestination; -import com.google.solutions.ml.api.vision.BigQueryConstants.Mode; -import com.google.solutions.ml.api.vision.BigQueryConstants.Type; -import com.google.solutions.ml.api.vision.TableDetails; -import com.google.solutions.ml.api.vision.TableSchemaProducer; -import com.google.solutions.ml.api.vision.processor.Constants.Field; -import java.util.Collections; -import org.apache.beam.sdk.metrics.Counter; -import org.apache.beam.sdk.metrics.Metrics; -import org.apache.beam.sdk.values.KV; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * Captures the error occurred during processing. Note, that there could be some valid annotations - * returned in the response even though the response contains an error. - */ -public class ErrorProcessor implements AnnotateImageResponseProcessor { - - private static final long serialVersionUID = 1L; - - public final static Counter counter = - Metrics.counter(AnnotateImageResponseProcessor.class, "numberOfErrors"); - public static final Logger LOG = LoggerFactory.getLogger(ErrorProcessor.class); - - private static class SchemaProducer implements TableSchemaProducer { - - private static final long serialVersionUID = 1L; - - @Override - public TableSchema getTableSchema() { - return new TableSchema().setFields( - ImmutableList.of( - new TableFieldSchema() - .setName(Field.GCS_URI_FIELD) - .setType(Type.STRING) - .setMode(Mode.REQUIRED), - new TableFieldSchema() - .setName(Field.DESCRIPTION_FIELD).setType(Type.STRING) - .setMode(Mode.REQUIRED), - new TableFieldSchema() - .setName(Field.STACK_TRACE).setType(Type.STRING) - .setMode(Mode.NULLABLE), - new TableFieldSchema() - .setName(Field.TIMESTAMP_FIELD).setType(Type.TIMESTAMP) - .setMode(Mode.REQUIRED)) - ); - } - } - - @Override - public TableDetails destinationTableDetails() { - return TableDetails.create("Google Vision API Processing Errors", - new Clustering().setFields(Collections.singletonList(Field.GCS_URI_FIELD)), - new TimePartitioning().setField(Field.TIMESTAMP_FIELD), new SchemaProducer()); - } - - private final BQDestination destination; - - /** - * Creates a processor and specifies the table id to persist to. - */ - public ErrorProcessor(String tableId) { - destination = new BQDestination(tableId); - } - - @Override - public Iterable> process( - String gcsURI, AnnotateImageResponse response) { - if (!response.hasError()) { - return null; - } - - counter.inc(); - - TableRow result = ProcessorUtils.startRow(gcsURI); - result.put(Field.DESCRIPTION_FIELD, response.getError().toString()); - - LOG.debug("Processing {}", result); - - return Collections.singletonList(KV.of(destination, result)); - } -} diff --git a/src/main/java/com/google/solutions/ml/api/vision/processor/FaceAnnotationProcessor.java b/src/main/java/com/google/solutions/ml/api/vision/processor/FaceAnnotationProcessor.java deleted file mode 100644 index bf41743..0000000 --- a/src/main/java/com/google/solutions/ml/api/vision/processor/FaceAnnotationProcessor.java +++ /dev/null @@ -1,169 +0,0 @@ -/* - * Copyright 2020 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.google.solutions.ml.api.vision.processor; - -import com.google.api.services.bigquery.model.Clustering; -import com.google.api.services.bigquery.model.TableFieldSchema; -import com.google.api.services.bigquery.model.TableRow; -import com.google.api.services.bigquery.model.TableSchema; -import com.google.api.services.bigquery.model.TimePartitioning; -import com.google.cloud.vision.v1.AnnotateImageResponse; -import com.google.cloud.vision.v1.FaceAnnotation; -import com.google.cloud.vision.v1.Position; -import com.google.common.collect.ImmutableList; -import com.google.solutions.ml.api.vision.BQDestination; -import com.google.solutions.ml.api.vision.BigQueryConstants.Mode; -import com.google.solutions.ml.api.vision.BigQueryConstants.Type; -import com.google.solutions.ml.api.vision.TableDetails; -import com.google.solutions.ml.api.vision.TableSchemaProducer; -import com.google.solutions.ml.api.vision.processor.Constants.Field; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collection; -import java.util.Collections; -import java.util.List; -import org.apache.beam.sdk.metrics.Counter; -import org.apache.beam.sdk.metrics.Metrics; -import org.apache.beam.sdk.values.KV; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * Extracts face annotations (https://cloud.google.com/vision/docs/detecting-faces) - */ -public class FaceAnnotationProcessor implements AnnotateImageResponseProcessor { - - private static final long serialVersionUID = 1L; - - private static final Logger LOG = LoggerFactory.getLogger(FaceAnnotationProcessor.class); - public final static Counter counter = - Metrics.counter(AnnotateImageResponseProcessor.class, "numberOfFaceAnnotations"); - - /** - * The schema doesn't represent the complete list of all attributes returned by the APIs. For more - * details see https://cloud.google.com/vision/docs/reference/rest/v1/AnnotateImageResponse?hl=pl#FaceAnnotation - */ - private static class SchemaProducer implements TableSchemaProducer { - - private static final long serialVersionUID = 1L; - - @Override - public TableSchema getTableSchema() { - return new TableSchema().setFields( - ImmutableList.of( - new TableFieldSchema() - .setName(Field.GCS_URI_FIELD).setType(Type.STRING).setMode(Mode.REQUIRED), - new TableFieldSchema() - .setName(Field.BOUNDING_POLY).setType(Type.RECORD) - .setMode(Mode.REQUIRED).setFields(Constants.POLYGON_FIELDS), - new TableFieldSchema() - .setName(Field.FD_BOUNDING_POLY).setType(Type.RECORD) - .setMode(Mode.REQUIRED).setFields(Constants.POLYGON_FIELDS), - new TableFieldSchema() - .setName(Field.LANDMARKS).setType(Type.RECORD).setMode(Mode.REPEATED).setFields( - Arrays.asList( - new TableFieldSchema().setName(Field.FACE_LANDMARK_TYPE).setType(Type.STRING) - .setMode(Mode.REQUIRED), - new TableFieldSchema().setName(Field.FACE_LANDMARK_POSITION).setType(Type.RECORD) - .setMode(Mode.REQUIRED).setFields(Constants.POSITION_FIELDS) - ) - ), - new TableFieldSchema() - .setName(Field.DETECTION_CONFIDENCE).setType(Type.FLOAT).setMode(Mode.REQUIRED), - new TableFieldSchema() - .setName(Field.LANDMARKING_CONFIDENCE).setType(Type.FLOAT).setMode(Mode.REQUIRED), - new TableFieldSchema() - .setName(Field.JOY_LIKELIHOOD).setType(Type.STRING).setMode(Mode.REQUIRED), - new TableFieldSchema() - .setName(Field.SORROW_LIKELIHOOD).setType(Type.STRING).setMode(Mode.REQUIRED), - new TableFieldSchema() - .setName(Field.ANGER_LIKELIHOOD).setType(Type.STRING).setMode(Mode.REQUIRED), - new TableFieldSchema() - .setName(Field.SURPISE_LIKELIHOOD).setType(Type.STRING).setMode(Mode.REQUIRED), - new TableFieldSchema() - .setName(Field.TIMESTAMP_FIELD).setType(Type.TIMESTAMP) - .setMode(Mode.REQUIRED)) - ); - } - } - - @Override - public TableDetails destinationTableDetails() { - return TableDetails.create("Google Vision API Face Annotations", - new Clustering().setFields(Collections.singletonList(Field.GCS_URI_FIELD)), - new TimePartitioning().setField(Field.TIMESTAMP_FIELD), new SchemaProducer()); - } - - private final BQDestination destination; - - /** - * Creates a processor and specifies the table id to persist to. - */ - public FaceAnnotationProcessor(String tableId) { - destination = new BQDestination(tableId); - } - - @Override - public Iterable> process( - String gcsURI, AnnotateImageResponse response) { - int numberOfAnnotations = response.getFaceAnnotationsCount(); - if (numberOfAnnotations == 0) { - return null; - } - - counter.inc(numberOfAnnotations); - - Collection> result = new ArrayList<>(numberOfAnnotations); - for (FaceAnnotation annotation : response.getFaceAnnotationsList()) { - TableRow row = ProcessorUtils.startRow(gcsURI); - - row.put(Field.BOUNDING_POLY, - ProcessorUtils.getBoundingPolyAsRow(annotation.getBoundingPoly())); - row.put(Field.FD_BOUNDING_POLY, - ProcessorUtils.getBoundingPolyAsRow(annotation.getFdBoundingPoly())); - List landmarks = new ArrayList<>(annotation.getLandmarksCount()); - annotation.getLandmarksList().forEach( - landmark -> { - TableRow landmarkRow = new TableRow(); - landmarkRow.put(Field.FACE_LANDMARK_TYPE, landmark.getType().toString()); - - Position position = landmark.getPosition(); - TableRow positionRow = new TableRow(); - positionRow.put(Field.VERTEX_X, position.getX()); - positionRow.put(Field.VERTEX_Y, position.getY()); - positionRow.put(Field.VERTEX_Z, position.getZ()); - landmarkRow.put(Field.FACE_LANDMARK_POSITION, positionRow); - - landmarks.add(landmarkRow); - } - ); - row.put(Field.LANDMARKS, landmarks); - row.put(Field.DETECTION_CONFIDENCE, annotation.getDetectionConfidence()); - row.put(Field.LANDMARKING_CONFIDENCE, annotation.getLandmarkingConfidence()); - row.put(Field.JOY_LIKELIHOOD, annotation.getJoyLikelihood().toString()); - row.put(Field.SORROW_LIKELIHOOD, annotation.getSorrowLikelihood().toString()); - row.put(Field.ANGER_LIKELIHOOD, annotation.getAngerLikelihood().toString()); - row.put(Field.SURPISE_LIKELIHOOD, annotation.getSurpriseLikelihood().toString()); - - LOG.debug("Processing {}", row); - result.add(KV.of(destination, row)); - } - - return result; - } - -} diff --git a/src/main/java/com/google/solutions/ml/api/vision/processor/ImagePropertiesProcessor.java b/src/main/java/com/google/solutions/ml/api/vision/processor/ImagePropertiesProcessor.java deleted file mode 100644 index 88aa7f4..0000000 --- a/src/main/java/com/google/solutions/ml/api/vision/processor/ImagePropertiesProcessor.java +++ /dev/null @@ -1,170 +0,0 @@ -/* - * Copyright 2020 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.google.solutions.ml.api.vision.processor; - -import com.google.api.services.bigquery.model.Clustering; -import com.google.api.services.bigquery.model.TableFieldSchema; -import com.google.api.services.bigquery.model.TableRow; -import com.google.api.services.bigquery.model.TableSchema; -import com.google.api.services.bigquery.model.TimePartitioning; -import com.google.cloud.vision.v1.AnnotateImageResponse; -import com.google.cloud.vision.v1.DominantColorsAnnotation; -import com.google.cloud.vision.v1.ImageProperties; -import com.google.common.collect.ImmutableList; -import com.google.solutions.ml.api.vision.BQDestination; -import com.google.solutions.ml.api.vision.BigQueryConstants.Mode; -import com.google.solutions.ml.api.vision.BigQueryConstants.Type; -import com.google.solutions.ml.api.vision.TableDetails; -import com.google.solutions.ml.api.vision.TableSchemaProducer; -import com.google.solutions.ml.api.vision.processor.Constants.Field; -import com.google.type.Color; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import org.apache.beam.sdk.metrics.Counter; -import org.apache.beam.sdk.metrics.Metrics; -import org.apache.beam.sdk.values.KV; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * Extracts image properties (https://cloud.google.com/vision/docs/detecting-properties) - */ -public class ImagePropertiesProcessor implements AnnotateImageResponseProcessor { - - private static final long serialVersionUID = 1L; - private static final Logger LOG = LoggerFactory - .getLogger(ImagePropertiesProcessor.class); - - public final static Counter counter = - Metrics.counter(AnnotateImageResponseProcessor.class, "numberOfImagePropertiesAnnotations"); - - private static class SchemaProducer implements TableSchemaProducer { - - private static final long serialVersionUID = 1L; - - @Override - public TableSchema getTableSchema() { - return new TableSchema().setFields( - ImmutableList.of( - new TableFieldSchema() - .setName(Field.GCS_URI_FIELD) - .setType(Type.STRING) - .setMode(Mode.REQUIRED), - new TableFieldSchema() - .setName(Field.DOMINANT_COLORS).setType(Type.RECORD) - .setMode(Mode.REQUIRED) - .setFields(ImmutableList.of( - new TableFieldSchema() - .setName(Field.COLORS).setType(Type.RECORD) - .setMode(Mode.REPEATED) - .setFields(ImmutableList.of( - new TableFieldSchema() - .setName(Field.SCORE_FIELD) - .setType(Type.FLOAT) - .setMode(Mode.REQUIRED), - new TableFieldSchema() - .setName(Field.PIXEL_FRACTION) - .setType(Type.FLOAT) - .setMode(Mode.REQUIRED), - new TableFieldSchema() - .setName(Field.COLOR) - .setType(Type.RECORD) - .setMode(Mode.REQUIRED) - .setFields(ImmutableList.of( - new TableFieldSchema() - .setName(Field.COLOR_RED) - .setType(Type.FLOAT) - .setMode(Mode.REQUIRED), - new TableFieldSchema() - .setName(Field.COLOR_BLUE) - .setType(Type.FLOAT) - .setMode(Mode.REQUIRED), - new TableFieldSchema() - .setName(Field.COLOR_GREEN) - .setType(Type.FLOAT) - .setMode(Mode.REQUIRED), - new TableFieldSchema() - .setName(Field.COLOR_ALPHA) - .setType(Type.FLOAT) - .setMode(Mode.NULLABLE) - )) - )) - )), - new TableFieldSchema() - .setName(Field.TIMESTAMP_FIELD).setType(Type.TIMESTAMP) - .setMode(Mode.REQUIRED)) - ); - } - } - - @Override - public TableDetails destinationTableDetails() { - return TableDetails.create("Google Vision API Image Properties", - new Clustering().setFields(Collections.singletonList(Field.GCS_URI_FIELD)), - new TimePartitioning().setField(Field.TIMESTAMP_FIELD), new SchemaProducer()); - } - - private final BQDestination destination; - - /** - * Creates a processor and specifies the table id to persist to. - */ - public ImagePropertiesProcessor(String tableId) { - destination = new BQDestination(tableId); - } - - @Override - public Iterable> process( - String gcsURI, AnnotateImageResponse response) { - ImageProperties imageProperties = response.getImagePropertiesAnnotation(); - if (imageProperties == null || !imageProperties.hasDominantColors()) { - return null; - } - - counter.inc(); - - TableRow result = ProcessorUtils.startRow(gcsURI); - - DominantColorsAnnotation dominantColors = imageProperties.getDominantColors(); - List colors = new ArrayList<>(dominantColors.getColorsCount()); - dominantColors.getColorsList().forEach( - colorInfo -> { - TableRow colorInfoRow = new TableRow(); - colorInfoRow.put(Field.SCORE_FIELD, colorInfo.getScore()); - colorInfoRow.put(Field.PIXEL_FRACTION, colorInfo.getPixelFraction()); - Color color = colorInfo.getColor(); - TableRow colorRow = new TableRow(); - colorRow.put(Field.COLOR_RED, color.getRed()); - colorRow.put(Field.COLOR_GREEN, color.getGreen()); - colorRow.put(Field.COLOR_BLUE, color.getBlue()); - if (color.hasAlpha()) { - colorRow.put(Field.COLOR_ALPHA, color.getAlpha()); - } - colorInfoRow.put(Field.COLOR, colorRow); - colors.add(colorInfoRow); - } - ); - TableRow colorsRow = new TableRow(); - colorsRow.put(Field.COLORS, colors); - result.put(Field.DOMINANT_COLORS, colorsRow); - - LOG.debug("Processing {}", result); - - return Collections.singletonList(KV.of(destination, result)); - } -} diff --git a/src/main/java/com/google/solutions/ml/api/vision/processor/LabelAnnotationProcessor.java b/src/main/java/com/google/solutions/ml/api/vision/processor/LabelAnnotationProcessor.java deleted file mode 100644 index 87c1588..0000000 --- a/src/main/java/com/google/solutions/ml/api/vision/processor/LabelAnnotationProcessor.java +++ /dev/null @@ -1,123 +0,0 @@ -/* - * Copyright 2020 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.google.solutions.ml.api.vision.processor; - -import com.google.api.services.bigquery.model.Clustering; -import com.google.api.services.bigquery.model.TableFieldSchema; -import com.google.api.services.bigquery.model.TableRow; -import com.google.api.services.bigquery.model.TableSchema; -import com.google.api.services.bigquery.model.TimePartitioning; -import com.google.cloud.vision.v1.AnnotateImageResponse; -import com.google.cloud.vision.v1.EntityAnnotation; -import com.google.common.collect.ImmutableList; -import com.google.solutions.ml.api.vision.BQDestination; -import com.google.solutions.ml.api.vision.BigQueryConstants.Mode; -import com.google.solutions.ml.api.vision.BigQueryConstants.Type; -import com.google.solutions.ml.api.vision.TableDetails; -import com.google.solutions.ml.api.vision.TableSchemaProducer; -import com.google.solutions.ml.api.vision.processor.Constants.Field; -import java.util.ArrayList; -import java.util.Collection; -import java.util.Collections; -import org.apache.beam.sdk.metrics.Counter; -import org.apache.beam.sdk.metrics.Metrics; -import org.apache.beam.sdk.values.KV; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * Extracts label annotations (https://cloud.google.com/vision/docs/labels) - */ -public class LabelAnnotationProcessor implements AnnotateImageResponseProcessor { - - private static final long serialVersionUID = 1L; - private static final Logger LOG = LoggerFactory.getLogger(LabelAnnotationProcessor.class); - - public final static Counter counter = - Metrics.counter(AnnotateImageResponseProcessor.class, "numberOfLabelAnnotations"); - - private static class SchemaProducer implements TableSchemaProducer { - - private static final long serialVersionUID = 1L; - - @Override - public TableSchema getTableSchema() { - return new TableSchema().setFields( - ImmutableList.of( - new TableFieldSchema() - .setName(Field.GCS_URI_FIELD) - .setType(Type.STRING) - .setMode(Mode.REQUIRED), - new TableFieldSchema() - .setName(Field.MID_FIELD).setType(Type.STRING) - .setMode(Mode.NULLABLE), - new TableFieldSchema() - .setName(Field.DESCRIPTION_FIELD).setType(Type.STRING) - .setMode(Mode.REQUIRED), - new TableFieldSchema() - .setName(Field.SCORE_FIELD).setType(Type.FLOAT) - .setMode(Mode.REQUIRED), - new TableFieldSchema() - .setName(Field.TOPICALITY_FIELD).setType(Type.FLOAT) - .setMode(Mode.REQUIRED), - new TableFieldSchema() - .setName(Field.TIMESTAMP_FIELD).setType(Type.TIMESTAMP) - .setMode(Mode.REQUIRED)) - ); - } - } - - @Override - public TableDetails destinationTableDetails() { - return TableDetails.create("Google Vision API Label Annotations", - new Clustering().setFields(Collections.singletonList(Field.GCS_URI_FIELD)), - new TimePartitioning().setField(Field.TIMESTAMP_FIELD), new SchemaProducer()); - } - - private final BQDestination destination; - - /** - * Creates a processor and specifies the table id to persist to. - */ - public LabelAnnotationProcessor(String tableId) { - destination = new BQDestination(tableId); - } - - @Override - public Iterable> process( - String gcsURI, AnnotateImageResponse response) { - int numberOfAnnotations = response.getLabelAnnotationsCount(); - if (numberOfAnnotations == 0) { - return null; - } - - counter.inc(numberOfAnnotations); - - Collection> result = new ArrayList<>(numberOfAnnotations); - for (EntityAnnotation annotation : response.getLabelAnnotationsList()) { - TableRow row = ProcessorUtils.startRow(gcsURI); - row.put(Field.MID_FIELD, annotation.getMid()); - row.put(Field.DESCRIPTION_FIELD, annotation.getDescription()); - row.put(Field.SCORE_FIELD, annotation.getScore()); - row.put(Field.TOPICALITY_FIELD, annotation.getTopicality()); - - LOG.debug("Processing {}", row); - result.add(KV.of(destination, row)); - } - return result; - } -} diff --git a/src/main/java/com/google/solutions/ml/api/vision/processor/LandmarkAnnotationProcessor.java b/src/main/java/com/google/solutions/ml/api/vision/processor/LandmarkAnnotationProcessor.java deleted file mode 100644 index 1300153..0000000 --- a/src/main/java/com/google/solutions/ml/api/vision/processor/LandmarkAnnotationProcessor.java +++ /dev/null @@ -1,138 +0,0 @@ -/* - * Copyright 2020 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.google.solutions.ml.api.vision.processor; - -import com.google.api.services.bigquery.model.Clustering; -import com.google.api.services.bigquery.model.TableFieldSchema; -import com.google.api.services.bigquery.model.TableRow; -import com.google.api.services.bigquery.model.TableSchema; -import com.google.api.services.bigquery.model.TimePartitioning; -import com.google.cloud.vision.v1.AnnotateImageResponse; -import com.google.cloud.vision.v1.EntityAnnotation; -import com.google.common.collect.ImmutableList; -import com.google.solutions.ml.api.vision.BQDestination; -import com.google.solutions.ml.api.vision.BigQueryConstants.Mode; -import com.google.solutions.ml.api.vision.BigQueryConstants.Type; -import com.google.solutions.ml.api.vision.TableDetails; -import com.google.solutions.ml.api.vision.TableSchemaProducer; -import com.google.solutions.ml.api.vision.processor.Constants.Field; -import java.util.ArrayList; -import java.util.Collection; -import java.util.Collections; -import java.util.List; -import org.apache.beam.sdk.metrics.Counter; -import org.apache.beam.sdk.metrics.Metrics; -import org.apache.beam.sdk.values.KV; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * Extracts landmark annotations (https://cloud.google.com/vision/docs/detecting-landmarks) - */ -public class LandmarkAnnotationProcessor implements AnnotateImageResponseProcessor { - - private static final long serialVersionUID = 1L; - - private static final Logger LOG = LoggerFactory.getLogger(LandmarkAnnotationProcessor.class); - public final static Counter counter = - Metrics.counter(AnnotateImageResponseProcessor.class, "numberOfLandmarkAnnotations"); - - private static class SchemaProducer implements TableSchemaProducer { - - private static final long serialVersionUID = 1L; - - @Override - public TableSchema getTableSchema() { - return new TableSchema().setFields( - ImmutableList.of( - new TableFieldSchema() - .setName(Field.GCS_URI_FIELD) - .setType(Type.STRING) - .setMode(Mode.REQUIRED), - new TableFieldSchema() - .setName(Field.MID_FIELD).setType(Type.STRING) - .setMode(Mode.NULLABLE), - new TableFieldSchema() - .setName(Field.DESCRIPTION_FIELD).setType(Type.STRING) - .setMode(Mode.REQUIRED), - new TableFieldSchema() - .setName(Field.SCORE_FIELD).setType(Type.FLOAT) - .setMode(Mode.REQUIRED), - new TableFieldSchema() - .setName(Field.BOUNDING_POLY).setType(Type.RECORD) - .setMode(Mode.NULLABLE).setFields(Constants.POLYGON_FIELDS), - new TableFieldSchema() - .setName(Field.LOCATIONS).setType(Type.GEOGRAPHY).setMode(Mode.REPEATED), - new TableFieldSchema() - .setName(Field.TIMESTAMP_FIELD).setType(Type.TIMESTAMP) - .setMode(Mode.REQUIRED)) - ); - } - } - - @Override - public TableDetails destinationTableDetails() { - return TableDetails.create("Google Vision API Landmark Annotations", - new Clustering().setFields(Collections.singletonList(Field.GCS_URI_FIELD)), - new TimePartitioning().setField(Field.TIMESTAMP_FIELD), new SchemaProducer()); - } - - private final BQDestination destination; - - /** - * Creates a processor and specifies the table id to persist to. - */ - public LandmarkAnnotationProcessor(String tableId) { - destination = new BQDestination(tableId); - } - - @Override - public Iterable> process( - String gcsURI, AnnotateImageResponse response) { - int numberOfAnnotations = response.getLandmarkAnnotationsCount(); - if (numberOfAnnotations == 0) { - return null; - } - - counter.inc(numberOfAnnotations); - - Collection> result = new ArrayList<>(numberOfAnnotations); - for (EntityAnnotation annotation : response.getLandmarkAnnotationsList()) { - TableRow row = ProcessorUtils.startRow(gcsURI); - row.put(Field.MID_FIELD, annotation.getMid()); - row.put(Field.DESCRIPTION_FIELD, annotation.getDescription()); - row.put(Field.SCORE_FIELD, annotation.getScore()); - - ProcessorUtils.extractBoundingPoly(annotation, row); - - if (annotation.getLocationsCount() > 0) { - List locations = new ArrayList<>(annotation.getLocationsCount()); - annotation.getLocationsList().forEach( - location -> locations.add( - "POINT(" + location.getLatLng().getLongitude() + " " + - location.getLatLng().getLatitude() + ")")); - row.put(Field.LOCATIONS, locations); - } - - LOG.debug("Processing {}", row); - result.add(KV.of(destination, row)); - } - - return result; - } - -} diff --git a/src/main/java/com/google/solutions/ml/api/vision/processor/LogoAnnotationProcessor.java b/src/main/java/com/google/solutions/ml/api/vision/processor/LogoAnnotationProcessor.java deleted file mode 100644 index a0fb72f..0000000 --- a/src/main/java/com/google/solutions/ml/api/vision/processor/LogoAnnotationProcessor.java +++ /dev/null @@ -1,123 +0,0 @@ -/* - * Copyright 2020 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.google.solutions.ml.api.vision.processor; - -import com.google.api.services.bigquery.model.Clustering; -import com.google.api.services.bigquery.model.TableFieldSchema; -import com.google.api.services.bigquery.model.TableRow; -import com.google.api.services.bigquery.model.TableSchema; -import com.google.api.services.bigquery.model.TimePartitioning; -import com.google.cloud.vision.v1.AnnotateImageResponse; -import com.google.cloud.vision.v1.EntityAnnotation; -import com.google.common.collect.ImmutableList; -import com.google.solutions.ml.api.vision.BQDestination; -import com.google.solutions.ml.api.vision.BigQueryConstants.Mode; -import com.google.solutions.ml.api.vision.BigQueryConstants.Type; -import com.google.solutions.ml.api.vision.TableDetails; -import com.google.solutions.ml.api.vision.TableSchemaProducer; -import com.google.solutions.ml.api.vision.processor.Constants.Field; -import java.util.ArrayList; -import java.util.Collection; -import java.util.Collections; -import org.apache.beam.sdk.metrics.Counter; -import org.apache.beam.sdk.metrics.Metrics; -import org.apache.beam.sdk.values.KV; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * Extracts logo annotations (https://cloud.google.com/vision/docs/detecting-logos) - */ -public class LogoAnnotationProcessor implements AnnotateImageResponseProcessor { - - private static final long serialVersionUID = 1L; - - public final static Counter counter = - Metrics.counter(AnnotateImageResponseProcessor.class, "numberOfLogoAnnotations"); - public static final Logger LOG = LoggerFactory.getLogger(LogoAnnotationProcessor.class); - - private static class SchemaProducer implements TableSchemaProducer { - - private static final long serialVersionUID = 1L; - - @Override - public TableSchema getTableSchema() { - return new TableSchema().setFields( - ImmutableList.of( - new TableFieldSchema() - .setName(Field.GCS_URI_FIELD) - .setType(Type.STRING) - .setMode(Mode.REQUIRED), - new TableFieldSchema() - .setName(Field.MID_FIELD).setType(Type.STRING) - .setMode(Mode.NULLABLE), - new TableFieldSchema() - .setName(Field.DESCRIPTION_FIELD).setType(Type.STRING) - .setMode(Mode.REQUIRED), - new TableFieldSchema() - .setName(Field.SCORE_FIELD).setType(Type.FLOAT) - .setMode(Mode.REQUIRED), - new TableFieldSchema() - .setName(Field.BOUNDING_POLY).setType(Type.RECORD) - .setMode(Mode.NULLABLE).setFields(Constants.POLYGON_FIELDS), - new TableFieldSchema() - .setName(Field.TIMESTAMP_FIELD).setType(Type.TIMESTAMP) - .setMode(Mode.REQUIRED)) - ); - } - } - - @Override - public TableDetails destinationTableDetails() { - return TableDetails.create("Google Vision API Logo Annotations", - new Clustering().setFields(Collections.singletonList(Field.GCS_URI_FIELD)), - new TimePartitioning().setField(Field.TIMESTAMP_FIELD), new SchemaProducer()); - } - - private final BQDestination destination; - - /** - * Creates a processor and specifies the table id to persist to. - */ - public LogoAnnotationProcessor(String tableId) { - destination = new BQDestination(tableId); - } - - @Override - public Iterable> process( - String gcsURI, AnnotateImageResponse response) { - int numberOfAnnotations = response.getLogoAnnotationsCount(); - if (numberOfAnnotations == 0) { - return null; - } - - counter.inc(numberOfAnnotations); - - Collection> result = new ArrayList<>(numberOfAnnotations); - for (EntityAnnotation annotation : response.getLabelAnnotationsList()) { - TableRow row = ProcessorUtils.startRow(gcsURI); - row.put(Field.MID_FIELD, annotation.getMid()); - row.put(Field.DESCRIPTION_FIELD, annotation.getDescription()); - row.put(Field.SCORE_FIELD, annotation.getScore()); - ProcessorUtils.extractBoundingPoly(annotation, row); - - LOG.debug("Processing {}", row); - result.add(KV.of(destination, row)); - } - return result; - } -} diff --git a/src/main/java/com/google/solutions/ml/api/vision/processor/ProcessorUtils.java b/src/main/java/com/google/solutions/ml/api/vision/processor/ProcessorUtils.java deleted file mode 100644 index ae7bfe7..0000000 --- a/src/main/java/com/google/solutions/ml/api/vision/processor/ProcessorUtils.java +++ /dev/null @@ -1,87 +0,0 @@ -/* - * Copyright 2020 Google LLC - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package com.google.solutions.ml.api.vision.processor; - -import com.google.api.services.bigquery.model.TableRow; -import com.google.cloud.vision.v1.BoundingPoly; -import com.google.cloud.vision.v1.EntityAnnotation; -import com.google.solutions.ml.api.vision.processor.Constants.Field; -import java.util.ArrayList; -import java.util.List; -import org.joda.time.DateTimeZone; -import org.joda.time.Instant; -import org.joda.time.format.DateTimeFormat; -import org.joda.time.format.DateTimeFormatter; - -/** - * Various utility functions used by processors - */ -public class ProcessorUtils { - - /** - * Extracts the bounding polygon if one exists and adds it to the row. - */ - static void extractBoundingPoly(EntityAnnotation annotation, TableRow row) { - if (annotation.hasBoundingPoly()) { - TableRow boundingPoly = getBoundingPolyAsRow(annotation.getBoundingPoly()); - row.put(Field.BOUNDING_POLY, boundingPoly); - } - } - - /** - * Converts {@link BoundingPoly} to a {@link TableRow}. - * - * @return table row - */ - static TableRow getBoundingPolyAsRow(BoundingPoly boundingPoly) { - List vertices = new ArrayList<>(); - boundingPoly.getVerticesList() - .forEach( - vertex -> { - TableRow vertexRow = new TableRow(); - vertexRow.put(Field.VERTEX_X, vertex.getX()); - vertexRow.put(Field.VERTEX_Y, vertex.getY()); - vertices.add(vertexRow); - }); - TableRow result = new TableRow(); - result.put(Field.VERTICES, vertices); - return result; - } - - /** - * Creates a TableRow and populates with two fields used in all processors: {@link - * Constants.Field#GCS_URI_FIELD} and {@link Constants.Field#TIMESTAMP_FIELD} - * - * @return new TableRow - */ - static TableRow startRow(String gcsURI) { - TableRow row = new TableRow(); - row.put(Field.GCS_URI_FIELD, gcsURI); - row.put(Field.TIMESTAMP_FIELD, getTimeStamp()); - return row; - } - - private static final DateTimeFormatter TIMESTAMP_FORMATTER = - DateTimeFormat.forPattern("yyyy-MM-dd HH:mm:ss.SSSSSS"); - - /** - * Formats the current timestamp in BigQuery compliant format - */ - public static String getTimeStamp() { - return TIMESTAMP_FORMATTER.print(Instant.now().toDateTime(DateTimeZone.UTC)); - } -} diff --git a/src/test/java/com/google/solutions/annotation/BatchRequestsTransformTest.java b/src/test/java/com/google/solutions/annotation/BatchRequestsTransformTest.java new file mode 100644 index 0000000..6da29ad --- /dev/null +++ b/src/test/java/com/google/solutions/annotation/BatchRequestsTransformTest.java @@ -0,0 +1,74 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.solutions.annotation; + +import com.google.common.collect.ImmutableList; +import com.google.solutions.annotation.gcs.GCSFileInfo; +import java.util.*; +import java.util.concurrent.atomic.AtomicReference; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.values.PCollection; +import org.junit.Rule; +import org.junit.Test; + +public class BatchRequestsTransformTest { + + @Rule public final transient TestPipeline pipeline = TestPipeline.create(); + + @Test + public void testBatching() { + GCSFileInfo fileInfo1 = + new GCSFileInfo("gs://bucket/example1.jpg", "image/jpeg", new HashMap<>()); + GCSFileInfo fileInfo2 = + new GCSFileInfo("gs://bucket/example2.jpg", "image/jpeg", new HashMap<>()); + GCSFileInfo fileInfo3 = + new GCSFileInfo("gs://bucket/example3.jpg", "image/jpeg", new HashMap<>()); + GCSFileInfo fileInfo4 = + new GCSFileInfo("gs://bucket/example4.jpg", "image/jpeg", new HashMap<>()); + GCSFileInfo fileInfo5 = + new GCSFileInfo("gs://bucket/example5.jpg", "image/jpeg", new HashMap<>()); + Create.Values fileInfos = + Create.of(fileInfo1, fileInfo2, fileInfo3, fileInfo4, fileInfo5); + PCollection> batchedFileInfos = + pipeline.apply(fileInfos).apply(BatchRequestsTransform.create(2, 50)); + List expectedItems = + new ArrayList<>(ImmutableList.of(fileInfo1, fileInfo2, fileInfo3, fileInfo4, fileInfo5)); + List expectedBatchSizes = new ArrayList<>(ImmutableList.of(2, 2, 1)); + PAssert.that(batchedFileInfos) + .satisfies( + batches -> { + batches.forEach( + batch -> { + AtomicReference batchSize = new AtomicReference<>(0); + batch.forEach( + fileInfo -> { + // Make sure the item is in the expected list + if (!expectedItems.remove(fileInfo)) throw new AssertionError(); + batchSize.getAndSet(batchSize.get() + 1); + }); + // Make sure the batch size is expected + if (!expectedBatchSizes.remove(batchSize.get())) throw new AssertionError(); + }); + // Make sure no other item exists other than the expected ones + assert expectedItems.size() == 0; + return null; + }); + + pipeline.run().waitUntilFinish(); + } +} diff --git a/src/test/java/com/google/solutions/annotation/ml/FakeProcessor.java b/src/test/java/com/google/solutions/annotation/ml/FakeProcessor.java new file mode 100644 index 0000000..f023a6e --- /dev/null +++ b/src/test/java/com/google/solutions/annotation/ml/FakeProcessor.java @@ -0,0 +1,100 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.solutions.annotation.ml; + +import com.google.api.services.bigquery.model.*; +import com.google.cloud.videointelligence.v1p3beta1.LabelAnnotation; +import com.google.cloud.videointelligence.v1p3beta1.StreamingAnnotateVideoResponse; +import com.google.cloud.videointelligence.v1p3beta1.StreamingVideoAnnotationResults; +import com.google.protobuf.GeneratedMessageV3; +import com.google.solutions.annotation.bigquery.BigQueryDestination; +import com.google.solutions.annotation.bigquery.TableDetails; +import com.google.solutions.annotation.bigquery.TableSchemaProducer; +import com.google.solutions.annotation.gcs.GCSFileInfo; +import java.util.ArrayList; +import java.util.Set; + +public class FakeProcessor implements MLApiResponseProcessor { + + private final BigQueryDestination destination; + private final Set metadataKeys; + + public FakeProcessor(String tableId, Set metadataKeys) { + this.destination = new BigQueryDestination(tableId); + this.metadataKeys = metadataKeys; + } + + @Override + public ProcessorResult process(GCSFileInfo fileInfo, GeneratedMessageV3 r) { + StreamingAnnotateVideoResponse response = (StreamingAnnotateVideoResponse) r; + StreamingVideoAnnotationResults annotationResults = response.getAnnotationResults(); + ProcessorResult result = new ProcessorResult("fake_type", destination); + for (LabelAnnotation annotation : annotationResults.getLabelAnnotationsList()) { + TableRow row = new TableRow(); + row.put(Constants.Field.GCS_URI_FIELD, fileInfo.getUri()); + row.set(Constants.Field.GCS_URI_FIELD, fileInfo.getUri()); + row.set(Constants.Field.ENTITY, annotation.getEntity().getDescription()); + ProcessorUtils.addMetadataValues(row, fileInfo, metadataKeys); + + result.allRows.add(row); + + if (annotation.getEntity().getDescription().equals("chocolate")) { + result.relevantRows.add(row); + } + } + return result; + } + + @Override + public TableDetails destinationTableDetails() { + return TableDetails.create("", null, null, new SchemaProducer(metadataKeys)); + } + + @Override + public boolean shouldProcess(GeneratedMessageV3 response) { + return (response instanceof StreamingAnnotateVideoResponse); + } + + private static class SchemaProducer implements TableSchemaProducer { + + private static final long serialVersionUID = 1L; + private final Set metadataKeys; + + SchemaProducer(Set metadataKeys) { + this.metadataKeys = metadataKeys; + } + + @Override + public TableSchema getTableSchema() { + + ArrayList fields = new ArrayList<>(); + fields.add( + new TableFieldSchema() + .setName(Constants.Field.GCS_URI_FIELD) + .setType("STRING") + .setMode("REQUIRED")); + fields.add( + new TableFieldSchema() + .setName(Constants.Field.ENTITY) + .setType("STRING") + .setMode("REQUIRED")); + + ProcessorUtils.setMetadataFieldsSchema(fields, metadataKeys); + + return new TableSchema().setFields(fields); + } + } +} diff --git a/src/test/java/com/google/solutions/annotation/ml/ProcessMLApiResponseDoFnTest.java b/src/test/java/com/google/solutions/annotation/ml/ProcessMLApiResponseDoFnTest.java new file mode 100644 index 0000000..ecfe921 --- /dev/null +++ b/src/test/java/com/google/solutions/annotation/ml/ProcessMLApiResponseDoFnTest.java @@ -0,0 +1,100 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.solutions.annotation.ml; + +import com.google.api.services.bigquery.model.TableRow; +import com.google.cloud.videointelligence.v1p3beta1.Entity; +import com.google.cloud.videointelligence.v1p3beta1.LabelAnnotation; +import com.google.cloud.videointelligence.v1p3beta1.StreamingAnnotateVideoResponse; +import com.google.cloud.videointelligence.v1p3beta1.StreamingVideoAnnotationResults; +import com.google.common.collect.ImmutableSet; +import com.google.protobuf.ByteString; +import com.google.protobuf.GeneratedMessageV3; +import com.google.solutions.annotation.bigquery.BigQueryDestination; +import com.google.solutions.annotation.gcs.GCSFileInfo; +import java.util.HashMap; +import java.util.Map; +import java.util.Set; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.values.*; +import org.junit.Rule; +import org.junit.Test; + +public class ProcessMLApiResponseDoFnTest { + + @Rule public final transient TestPipeline pipeline = TestPipeline.create(); + + private static final TupleTag> allRows = + new TupleTag>() {}; + private static final TupleTag> relevantRows = + new TupleTag>() {}; + + private StreamingAnnotateVideoResponse createResponse(String entity) { + return StreamingAnnotateVideoResponse.newBuilder() + .setAnnotationResults( + StreamingVideoAnnotationResults.newBuilder() + .addLabelAnnotations( + LabelAnnotation.newBuilder() + .setEntity( + Entity.newBuilder() + .setDescriptionBytes(ByteString.copyFrom(entity.getBytes())) + .build()) + .build()) + .build()) + .build(); + } + + @Test + public void testSuccess() { + Map processors = + Map.of("fake_table", new FakeProcessor("fake_table", Set.of())); + + GCSFileInfo fileInfo1 = + new GCSFileInfo("gs://mybucket/example1.jpg", "image/jpeg", new HashMap<>()); + StreamingAnnotateVideoResponse response1 = createResponse("chocolate"); + + GCSFileInfo fileInfo2 = + new GCSFileInfo("gs://mybucket/example2.jpg", "image/jpeg", new HashMap<>()); + StreamingAnnotateVideoResponse response2 = createResponse("coffee"); + + PCollection> annotatedFiles = + pipeline.apply(Create.of(KV.of(fileInfo1, response1), KV.of(fileInfo2, response2))); + PCollectionTuple annotationOutcome = + annotatedFiles.apply( + ParDo.of( + ProcessMLApiResponseDoFn.create( + ImmutableSet.copyOf(processors.values()), allRows, relevantRows)) + .withOutputTags(allRows, TupleTagList.of(relevantRows))); + + TableRow row1 = new TableRow(); + row1.put(Constants.Field.GCS_URI_FIELD, "gs://mybucket/example1.jpg"); + row1.set(Constants.Field.ENTITY, "chocolate"); + TableRow row2 = new TableRow(); + row2.put(Constants.Field.GCS_URI_FIELD, "gs://mybucket/example2.jpg"); + row2.set(Constants.Field.ENTITY, "coffee"); + + PAssert.that(annotationOutcome.get(allRows)) + .containsInAnyOrder( + KV.of(new BigQueryDestination("fake_table"), row1), + KV.of(new BigQueryDestination("fake_table"), row2)); + PAssert.that(annotationOutcome.get(relevantRows)).containsInAnyOrder(KV.of("fake_type", row1)); + + pipeline.run().waitUntilFinish(); + } +} diff --git a/src/test/java/com/google/solutions/annotation/pubsub/PubSubNotificationToGCSInfoDoFnTest.java b/src/test/java/com/google/solutions/annotation/pubsub/PubSubNotificationToGCSInfoDoFnTest.java new file mode 100644 index 0000000..c54cad2 --- /dev/null +++ b/src/test/java/com/google/solutions/annotation/pubsub/PubSubNotificationToGCSInfoDoFnTest.java @@ -0,0 +1,107 @@ +/* + * Copyright 2022 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.google.solutions.annotation.pubsub; + +import com.google.common.collect.ImmutableSet; +import com.google.solutions.annotation.gcs.GCSFileInfo; +import java.io.IOException; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; +import org.apache.beam.sdk.io.gcp.pubsub.PubsubMessage; +import org.apache.beam.sdk.testing.PAssert; +import org.apache.beam.sdk.testing.TestPipeline; +import org.apache.beam.sdk.transforms.Create; +import org.apache.beam.sdk.transforms.ParDo; +import org.apache.beam.sdk.values.PCollection; +import org.codehaus.jackson.map.ObjectMapper; +import org.junit.Rule; +import org.junit.Test; + +public class PubSubNotificationToGCSInfoDoFnTest { + + @Rule public final transient TestPipeline pipeline = TestPipeline.create(); + + private static final ObjectMapper jsonMapper = new ObjectMapper(); + + private static Set getSampleSupportedContentTypes() { + return new HashSet<>(ImmutableSet.of("image/jpeg", "image/png")); + } + + private static Map getSamplePayload() { + return new HashMap<>(Map.of("contentType", "image/jpeg", "metadata", Map.of("foo", "bar"))); + } + + private static Map getSampleAttributes() { + return new HashMap<>( + Map.of( + "eventType", "OBJECT_FINALIZE", + "bucketId", "mybucket", + "objectId", "example.jpg")); + } + + @Test + public void testSuccess() throws IOException { + PubsubMessage message = + new PubsubMessage(jsonMapper.writeValueAsBytes(getSamplePayload()), getSampleAttributes()); + PCollection fileInfos = + pipeline + .apply(Create.of(message)) + .apply( + ParDo.of(PubSubNotificationToGCSInfoDoFn.create(getSampleSupportedContentTypes()))); + // Make sure the event was accepted and properly processed + PAssert.that(fileInfos) + .containsInAnyOrder( + new GCSFileInfo("gs://mybucket/example.jpg", "image/jpeg", Map.of("foo", "bar"))); + pipeline.run().waitUntilFinish(); + } + + @Test + public void testUnsupportedContentType() throws IOException { + Map payload = getSamplePayload(); + PubsubMessage message = + new PubsubMessage(jsonMapper.writeValueAsBytes(payload), getSampleAttributes()); + + // Remove the input file's content type from the list of supported ones + Set supportedContentTypes = getSampleSupportedContentTypes(); + supportedContentTypes.remove("image/jpeg"); + + PCollection fileInfos = + pipeline + .apply(Create.of(message)) + .apply(ParDo.of(PubSubNotificationToGCSInfoDoFn.create(supportedContentTypes))); + // Make sure the event was rejected + PAssert.that(fileInfos).empty(); + pipeline.run().waitUntilFinish(); + } + + @Test + public void testUnsupportedEventType() throws IOException { + Map attributes = getSampleAttributes(); + attributes.put("eventType", "XXXXX"); + PubsubMessage message = + new PubsubMessage(jsonMapper.writeValueAsBytes(getSamplePayload()), attributes); + PCollection fileInfos = + pipeline + .apply(Create.of(message)) + .apply( + ParDo.of(PubSubNotificationToGCSInfoDoFn.create(getSampleSupportedContentTypes()))); + // Make sure the event was rejected + PAssert.that(fileInfos).empty(); + pipeline.run().waitUntilFinish(); + } +}