Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 51 additions & 6 deletions core/src/main/java/io/substrait/extension/SimpleExtension.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
Expand Down Expand Up @@ -276,6 +278,8 @@ public String description() {

public abstract Map<String, Option> options();

public abstract Optional<Map<String, Object>> metadata();
Copy link
Member

@nielspardon nielspardon Mar 3, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering whether it would be useful to have an interface instead of Object for e.g. creating custom YAML deserializers / serializers and whether we should allow to add them to the ObjectMapper in SimpleExtension.load().

Otherwise I'm fine with the changes.


public List<Argument> requiredArguments() {
return requiredArgsSupplier.get();
}
Expand Down Expand Up @@ -381,25 +385,29 @@ public abstract static class ScalarFunction {
@Nullable
public abstract String description();

public abstract Optional<Map<String, Object>> metadata();

public abstract List<ScalarFunctionVariant> impls();

public Stream<ScalarFunctionVariant> resolve(String urn) {
return impls().stream().map(f -> f.resolve(urn, name(), description()));
return impls().stream().map(f -> f.resolve(urn, name(), description(), metadata()));
}
}

@JsonDeserialize(as = ImmutableSimpleExtension.ScalarFunctionVariant.class)
@JsonSerialize(as = ImmutableSimpleExtension.ScalarFunctionVariant.class)
@Value.Immutable
public abstract static class ScalarFunctionVariant extends Function {
public ScalarFunctionVariant resolve(String urn, String name, String description) {
public ScalarFunctionVariant resolve(
String urn, String name, String description, Optional<Map<String, Object>> metadata) {
return ImmutableSimpleExtension.ScalarFunctionVariant.builder()
.urn(urn)
.name(name)
.description(description)
.nullability(nullability())
.args(args())
.options(options())
.metadata(metadata)
.ordered(ordered())
.variadic(variadic())
.returnType(returnType())
Expand All @@ -417,10 +425,12 @@ public abstract static class AggregateFunction {
@Nullable
public abstract String description();

public abstract Optional<Map<String, Object>> metadata();

public abstract List<AggregateFunctionVariant> impls();

public Stream<AggregateFunctionVariant> resolve(String urn) {
return impls().stream().map(f -> f.resolve(urn, name(), description()));
return impls().stream().map(f -> f.resolve(urn, name(), description(), metadata()));
}
}

Expand All @@ -434,10 +444,12 @@ public abstract static class WindowFunction {
@Nullable
public abstract String description();

public abstract Optional<Map<String, Object>> metadata();

public abstract List<WindowFunctionVariant> impls();

public Stream<WindowFunctionVariant> resolve(String urn) {
return impls().stream().map(f -> f.resolve(urn, name(), description()));
return impls().stream().map(f -> f.resolve(urn, name(), description(), metadata()));
}

public static ImmutableSimpleExtension.WindowFunction.Builder builder() {
Expand All @@ -463,14 +475,16 @@ public String toString() {
@Nullable
public abstract TypeExpression intermediate();

AggregateFunctionVariant resolve(String urn, String name, String description) {
AggregateFunctionVariant resolve(
String urn, String name, String description, Optional<Map<String, Object>> metadata) {
return ImmutableSimpleExtension.AggregateFunctionVariant.builder()
.urn(urn)
.name(name)
.description(description)
.nullability(nullability())
.args(args())
.options(options())
.metadata(metadata)
.ordered(ordered())
.variadic(variadic())
.decomposability(decomposability())
Expand Down Expand Up @@ -505,14 +519,16 @@ public String toString() {
return super.toString();
}

WindowFunctionVariant resolve(String urn, String name, String description) {
WindowFunctionVariant resolve(
String urn, String name, String description, Optional<Map<String, Object>> metadata) {
return ImmutableSimpleExtension.WindowFunctionVariant.builder()
.urn(urn)
.name(name)
.description(description)
.nullability(nullability())
.args(args())
.options(options())
.metadata(metadata)
.ordered(ordered())
.variadic(variadic())
.decomposability(decomposability())
Expand Down Expand Up @@ -549,6 +565,8 @@ public abstract static class Type {

protected abstract Optional<Boolean> variadic();

public abstract Optional<Map<String, Object>> metadata();

public TypeAnchor getAnchor() {
return anchorSupplier.get();
}
Expand All @@ -574,6 +592,9 @@ public abstract static class ExtensionSignatures {
@JsonProperty("window_functions")
public abstract List<WindowFunction> windows();

@JsonProperty("metadata")
public abstract Optional<Map<String, Object>> metadata();

public int size() {
return (types() == null ? 0 : types().size())
+ (scalars() == null ? 0 : scalars().size())
Expand Down Expand Up @@ -643,6 +664,11 @@ BidiMap<String, String> uriUrnMap() {
return new BidiMap<>();
}

@Value.Default
public Map<String, Map<String, Object>> extensionMetadata() {
return Collections.emptyMap();
}

public abstract List<Type> types();

public abstract List<ScalarFunctionVariant> scalarFunctions();
Expand All @@ -655,6 +681,16 @@ public static ImmutableSimpleExtension.ExtensionCollection.Builder builder() {
return ImmutableSimpleExtension.ExtensionCollection.builder();
}

/**
* Gets the top-level metadata for a specific extension by URN.
*
* @param urn The URN of the extension
* @return The metadata map if present, empty Optional otherwise
*/
public Optional<Map<String, Object>> getExtensionMetadata(String urn) {
return Optional.ofNullable(extensionMetadata().get(urn));
}

public Type getType(TypeAnchor anchor) {
Type type = typeLookup.get().get(anchor);
if (type != null) {
Expand Down Expand Up @@ -744,6 +780,10 @@ public ExtensionCollection merge(ExtensionCollection extensionCollection) {
mergedUriUrnMap.merge(uriUrnMap());
mergedUriUrnMap.merge(extensionCollection.uriUrnMap());

Map<String, Map<String, Object>> mergedExtensionMetadata = new HashMap<>();
mergedExtensionMetadata.putAll(extensionMetadata());
mergedExtensionMetadata.putAll(extensionCollection.extensionMetadata());

return ImmutableSimpleExtension.ExtensionCollection.builder()
.addAllAggregateFunctions(aggregateFunctions())
.addAllAggregateFunctions(extensionCollection.aggregateFunctions())
Expand All @@ -754,6 +794,7 @@ public ExtensionCollection merge(ExtensionCollection extensionCollection) {
.addAllTypes(types())
.addAllTypes(extensionCollection.types())
.uriUrnMap(mergedUriUrnMap)
.extensionMetadata(mergedExtensionMetadata)
.build();
}
}
Expand Down Expand Up @@ -859,13 +900,17 @@ public static ExtensionCollection buildExtensionCollection(
BidiMap<String, String> uriUrnMap = new BidiMap<>();
uriUrnMap.put(uri, urn);

Map<String, Map<String, Object>> extMetadata = new HashMap<>();
extensionSignatures.metadata().ifPresent(m -> extMetadata.put(urn, m));

ImmutableSimpleExtension.ExtensionCollection collection =
ImmutableSimpleExtension.ExtensionCollection.builder()
.scalarFunctions(scalarFunctionVariants)
.aggregateFunctions(aggregateFunctionVariants)
.windowFunctions(allWindowFunctionVariants)
.addAllTypes(extensionSignatures.types())
.uriUrnMap(uriUrnMap)
.extensionMetadata(extMetadata)
.build();

LOGGER.atDebug().log(
Expand Down
102 changes: 102 additions & 0 deletions core/src/test/java/io/substrait/extension/MetadataExtensionTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package io.substrait.extension;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;

import io.substrait.TestBase;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Map;
import org.junit.jupiter.api.Test;

/**
* Verifies that metadata can be read from extension YAML files at multiple levels:
*
* <ul>
* <li>Extension-level metadata (top-level)
* <li>Type-level metadata
* <li>Function-level metadata (scalar, aggregate, window)
* </ul>
*/
class MetadataExtensionTest extends TestBase {

static final String URN = "extension:test:metadata_extensions";
static final SimpleExtension.ExtensionCollection METADATA_EXTENSION;

static {
try {
String extensionStr = asString("extensions/metadata_extensions.yaml");
METADATA_EXTENSION = SimpleExtension.load(URN, extensionStr);
} catch (IOException e) {
throw new UncheckedIOException(e);
}
}

MetadataExtensionTest() {
super(METADATA_EXTENSION);
}

@Test
void testExtensionLevelMetadata() {
Map<String, Object> metadata = extensions.getExtensionMetadata(URN).orElseThrow();
assertEquals("1.0", metadata.get("version"));
assertEquals("test-team", metadata.get("author"));

@SuppressWarnings("unchecked")
Map<String, Object> customData = (Map<String, Object>) metadata.get("custom_data");
assertEquals(true, customData.get("nested_value"));
assertEquals(42, customData.get("numeric_value"));
}

@Test
void testExtensionLevelMetadataMissing() {
assertTrue(extensions.getExtensionMetadata("extension:nonexistent:urn").isEmpty());
}

@Test
void testTypeMetadata() {
SimpleExtension.TypeAnchor anchor = SimpleExtension.TypeAnchor.of(URN, "metadataType");
Map<String, Object> metadata = extensions.getType(anchor).metadata().orElseThrow();
assertEquals("custom-type-metadata", metadata.get("type_info"));
assertEquals("user-defined", metadata.get("category"));
}

@Test
void testScalarFunctionMetadata() {
SimpleExtension.FunctionAnchor anchor =
SimpleExtension.FunctionAnchor.of(URN, "metadataScalar:i64");
Map<String, Object> metadata = extensions.getScalarFunction(anchor).metadata().orElseThrow();
assertEquals("vectorized", metadata.get("perf_hint"));
assertEquals(1, metadata.get("cost"));
}

@Test
void testAggregateFunctionMetadata() {
SimpleExtension.FunctionAnchor anchor =
SimpleExtension.FunctionAnchor.of(URN, "metadataAggregate:i64");
assertEquals(
"incremental",
extensions.getAggregateFunction(anchor).metadata().orElseThrow().get("agg_info"));
}

@Test
void testWindowFunctionMetadata() {
SimpleExtension.FunctionAnchor anchor =
SimpleExtension.FunctionAnchor.of(URN, "metadataWindow:i64");
assertEquals(
"partitioned",
extensions.getWindowFunction(anchor).metadata().orElseThrow().get("window_info"));
}

@Test
void testMergePreservesMetadata() throws IOException {
String customExtensionStr = asString("extensions/custom_extensions.yaml");
SimpleExtension.ExtensionCollection customExtension =
SimpleExtension.load("extension:test:custom_extensions", customExtensionStr);

SimpleExtension.ExtensionCollection merged = METADATA_EXTENSION.merge(customExtension);

assertEquals("1.0", merged.getExtensionMetadata(URN).orElseThrow().get("version"));
assertTrue(merged.getExtensionMetadata("extension:test:custom_extensions").isEmpty());
}
}
39 changes: 39 additions & 0 deletions core/src/test/resources/extensions/metadata_extensions.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
%YAML 1.2
---
urn: extension:test:metadata_extensions
metadata:
version: "1.0"
author: "test-team"
custom_data:
nested_value: true
numeric_value: 42
types:
- name: "metadataType"
metadata:
type_info: "custom-type-metadata"
category: "user-defined"
scalar_functions:
- name: "metadataScalar"
metadata:
perf_hint: "vectorized"
cost: 1
impls:
- args:
- value: i64
return: i64
aggregate_functions:
- name: "metadataAggregate"
metadata:
agg_info: "incremental"
impls:
- args:
- value: i64
return: i64
window_functions:
- name: "metadataWindow"
metadata:
window_info: "partitioned"
impls:
- args:
- value: i64
return: i64