
SparkParquetWriterWithVariantShredding fails on wide mixed-schema Iceberg v3 tables with spark.sql.iceberg.shred-variants=true #16782

Description

@soumilshah1995

Summary

With spark.sql.iceberg.shred-variants=true, INSERT into an Iceberg v3 table that mixes scalars, ARRAYs, and many VARIANT columns fails:

ClassCastException: optional binary element (STRING) = N is not a group
  at SparkParquetWriterWithVariantShredding.initialize(...)

With the same JSON, the same parse_json() expressions, and only 2 rows, a 3-column table (id + 2 VARIANT) passes. This points to the writer/schema layout rather than bad JSON.


Environment

Component | Version
--- | ---
Spark | 4.0.2-amzn-0 (EMR 8.x)
Iceberg | 1.11.0 (1.10.1-amzn-0 also on EMR classpath)
Catalog | S3 Tables (s3-tables-catalog-for-iceberg-runtime:0.1.5)
Table | format-version 3, VARIANT via parse_json()

Note: Wide write passes on local Apache Spark + Hadoop warehouse + Iceberg 1.11. Failure seen on EMR + S3 Tables.


Reproduce

Repro script: emr_s3tables_shred_repro.py (mock data, no external files).

spark-submit --deploy-mode client \
  --driver-memory 2g --executor-memory 2g \
  --executor-cores 2 --num-executors 2 \
  --conf spark.sql.iceberg.shred-variants=true \
  emr_s3tables_shred_repro.py

Case | Table shape | Result
--- | --- | ---
minimal | id + 2 VARIANT | PASS
wide | 20 cols (scalars + 11 VARIANT + 2 ARRAY) | FAIL

Wide DDL (core columns):

CREATE TABLE t (
  id STRING, region STRING, user_key STRING, email STRING,
  payload_a VARIANT, payload_b VARIANT, updated_at TIMESTAMP,
  status STRING, hash_key STRING,
  extra_a VARIANT, extra_b VARIANT, extra_c VARIANT, extra_d VARIANT,
  extra_e VARIANT, extra_f VARIANT, extra_g VARIANT,
  id_list ARRAY<STRING>, extra_h VARIANT, extra_i VARIANT,
  tag_list ARRAY<STRING>
) USING iceberg
PARTITIONED BY (bucket(16, id))
TBLPROPERTIES ('format-version' = '3');
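To relate the field id in the abbreviated stack trace (21) to this DDL, here is a rough sketch of one possible id layout. It assumes that all top-level columns are numbered first (1 to 20) and that nested fields such as array elements are numbered afterwards in column order. That numbering is an assumption about how the table ids were assigned, not something verified against this table's metadata, and `assign_ids` is a hypothetical helper.

```python
# Column names in DDL order, copied from the wide table above.
COLUMNS = (
    "id region user_key email payload_a payload_b updated_at status hash_key "
    "extra_a extra_b extra_c extra_d extra_e extra_f extra_g "
    "id_list extra_h extra_i tag_list"
).split()
ARRAYS = {"id_list", "tag_list"}  # the two ARRAY<STRING> columns


def assign_ids(columns, arrays):
    """Map field id -> path, ASSUMING top-level ids first, then nested ids."""
    ids = {}
    next_id = 1
    for name in columns:  # pass 1: every top-level column gets an id
        ids[next_id] = name
        next_id += 1
    for name in columns:  # pass 2: nested fields (here only array elements)
        if name in arrays:
            ids[next_id] = f"{name}.element"
            next_id += 1
    return ids


FIELD_IDS = assign_ids(COLUMNS, ARRAYS)
print(FIELD_IDS[21])  # id_list.element under the stated assumption
```

If that assumption holds, the failing field would be the STRING element of id_list rather than any VARIANT column. Checking the field ids in the table's metadata JSON would confirm or refute this.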

Insert (both rows identical):

INSERT OVERWRITE t (...)
SELECT
  id, region, user_key, email,
  parse_json(payload_a), parse_json(payload_b),
  COALESCE(updated_at, current_timestamp()), status, hash_key,
  parse_json(extra_a), /* ... other VARIANT cols ... */
  CASE WHEN id_list = '[]' THEN array() ELSE from_json(id_list, 'ARRAY<STRING>') END,
  /* ... */
FROM mock_source;

Mock JSON (same on every row): {"channel":{"type":"email"}}, {"segment":"alpha","score":42}, etc.


Expected vs actual

  • Expected: Wide table writes with shredding enabled when JSON is valid and consistent.
  • Actual: ClassCastException at Parquet write. The N in "= N" is an internal Parquet field id, not a JSON key.
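
For anyone comparing failures across runs, a small sketch for pulling the Parquet type description and the field id out of the exception text. The message shape is taken from the trace in this report; `parse_not_a_group` is a hypothetical helper and other Parquet versions may word the message differently.

```python
import re

# Matches "<parquet type> = <field id> is not a group", as seen in this report.
NOT_A_GROUP = re.compile(
    r"ClassCastException: (?P<type>.+?) = (?P<field_id>\d+) is not a group"
)


def parse_not_a_group(message: str):
    """Return (parquet_type, field_id), or None if the message does not match."""
    m = NOT_A_GROUP.search(message)
    if m is None:
        return None
    return m.group("type"), int(m.group("field_id"))


print(parse_not_a_group(
    "ClassCastException: optional binary element (STRING) = 21 is not a group"
))
```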

Ruled out

  • Mixed JSON types at the same path → the same JSON passes in the minimal 3-col table
  • Row volume → fails with 2 rows
  • INSERT SQL → same SQL works in minimal table
  • shred-variants=false → full schema works (workaround)
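
The gap between the 3-column PASS and the 20-column FAIL could be narrowed by testing column prefixes of the wide DDL. The sketch below generates the DDL for the first n columns and binary-searches for the shortest failing prefix. It assumes failure is monotonic (once a prefix fails, every longer prefix fails too), which has not been verified. `ddl_for_prefix` and `first_failing_prefix` are hypothetical helpers, and `fails` is a callback you would supply that creates the table, runs the insert, and reports whether the ClassCastException occurred.

```python
# (name, type) pairs in the same order as the wide DDL in this report.
WIDE_COLUMNS = [
    ("id", "STRING"), ("region", "STRING"), ("user_key", "STRING"),
    ("email", "STRING"), ("payload_a", "VARIANT"), ("payload_b", "VARIANT"),
    ("updated_at", "TIMESTAMP"), ("status", "STRING"), ("hash_key", "STRING"),
    ("extra_a", "VARIANT"), ("extra_b", "VARIANT"), ("extra_c", "VARIANT"),
    ("extra_d", "VARIANT"), ("extra_e", "VARIANT"), ("extra_f", "VARIANT"),
    ("extra_g", "VARIANT"), ("id_list", "ARRAY<STRING>"),
    ("extra_h", "VARIANT"), ("extra_i", "VARIANT"),
    ("tag_list", "ARRAY<STRING>"),
]


def ddl_for_prefix(table: str, n: int) -> str:
    """CREATE TABLE statement using only the first n columns."""
    cols = ",\n  ".join(f"{name} {typ}" for name, typ in WIDE_COLUMNS[:n])
    return (
        f"CREATE TABLE {table} (\n  {cols}\n) USING iceberg\n"
        "PARTITIONED BY (bucket(16, id))\n"
        "TBLPROPERTIES ('format-version' = '3')"
    )


def first_failing_prefix(fails, lo: int = 1, hi: int = len(WIDE_COLUMNS)) -> int:
    """Smallest n with fails(n) true, ASSUMING the full width fails and
    failure is monotonic in n."""
    while lo < hi:
        mid = (lo + hi) // 2
        if fails(mid):
            hi = mid       # prefix of length mid already fails
        else:
            lo = mid + 1   # need more columns to trigger the failure
    return lo
```

With 20 columns this needs at most 5 create/insert rounds instead of 20.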

Workaround

spark.sql.iceberg.shred-variants=false
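
A minimal sketch of building the submit command with the workaround applied, so the same invocation can be toggled between the failing and passing configuration. `submit_command` is a hypothetical helper; the flags are the ones already used in this report (memory and executor options omitted for brevity).

```python
import shlex


def submit_command(script: str, shred: bool) -> str:
    """spark-submit line with shredding switched on or off."""
    args = [
        "spark-submit", "--deploy-mode", "client",
        "--conf", f"spark.sql.iceberg.shred-variants={str(shred).lower()}",
        script,
    ]
    return shlex.join(args)


# Workaround form: shredding disabled.
print(submit_command("emr_s3tables_shred_repro.py", shred=False))
```

The attached script exposes the same switch as --no-shred, which sets the property on the session before the insert runs.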

Stack trace (abbreviated)

ClassCastException: optional binary element (STRING) = 21 is not a group
  at Type.asGroupType(Type.java:247)
  at ParquetWithSparkSchemaVisitor.visit(...)
  at SparkParquetWriters.buildWriter(...)
  at SparkParquetWriterWithVariantShredding.initialize(...)
  at SparkWrite$PartitionedDataWriter.write(...)

Also file if Iceberg says “catalog-specific”


Attachments

#!/usr/bin/env python3
"""
EMR S3 Tables VARIANT shred bug repro (minimal PASS control + wide FAIL case).

Usage:

spark-submit --deploy-mode client \
  --driver-memory 2g \
  --executor-memory 2g \
  --executor-cores 2 \
  --num-executors 2 \
  --conf spark.sql.iceberg.shred-variants=true \
  emr_s3tables_shred_repro.py
"""
from __future__ import annotations

import argparse
import time
import traceback

import boto3
from botocore.exceptions import ClientError
from pyspark.sql import SparkSession
from pyspark.sql.types import StringType, StructField, StructType

CATALOG = "crm_user_props"
NAMESPACE = "test"
SOURCE_VIEW = "mock_source"
TABLE_BUCKET_ARN = (
    "arn:aws:s3tables:us-east-1:XX:bucket/XX"
)
AWS_REGION = "us-east-1"
DELETE_POLL_SEC = 5
DELETE_TIMEOUT_SEC = 90

JSON_A = '{"channel": {"type": "email", "value": "user@example.com"}}'
JSON_B = '{"segment": "alpha", "score": 42}'
JSON_SIMPLE = '{"enabled": true}'

WIDE_SOURCE_SCHEMA = StructType(
    [
        StructField("id", StringType()),
        StructField("region", StringType()),
        StructField("user_key", StringType()),
        StructField("email", StringType()),
        StructField("payload_a", StringType()),
        StructField("payload_b", StringType()),
        StructField("updated_at", StringType()),
        StructField("status", StringType()),
        StructField("hash_key", StringType()),
        StructField("extra_a", StringType()),
        StructField("extra_b", StringType()),
        StructField("extra_c", StringType()),
        StructField("extra_d", StringType()),
        StructField("extra_e", StringType()),
        StructField("extra_f", StringType()),
        StructField("extra_g", StringType()),
        StructField("id_list", StringType()),
        StructField("extra_h", StringType()),
        StructField("extra_i", StringType()),
        StructField("tag_list", StringType()),
    ]
)


def parse_args() -> argparse.Namespace:
    p = argparse.ArgumentParser(description="EMR S3 Tables VARIANT shred bug repro.")
    p.add_argument(
        "--case",
        choices=("all", "minimal", "wide"),
        default="all",
        help="minimal=PASS control, wide=FAIL repro (default: all)",
    )
    p.add_argument(
        "--no-shred",
        action="store_true",
        help="Set spark.sql.iceberg.shred-variants=false (wide should pass)",
    )
    p.add_argument("--skip-delete", action="store_true")
    p.add_argument("--catalog", default=CATALOG)
    p.add_argument("--namespace", default=NAMESPACE)
    p.add_argument("--bucket-arn", default=TABLE_BUCKET_ARN)
    p.add_argument("--region", default=AWS_REGION)
    return p.parse_args()


def _table_exists(client, bucket_arn: str, namespace: str, name: str) -> bool:
    try:
        client.get_table(
            tableBucketARN=bucket_arn, namespace=namespace, name=name
        )
        return True
    except ClientError as exc:
        code = exc.response.get("Error", {}).get("Code", "")
        if code in {"NotFoundException", "ResourceNotFoundException"}:
            return False
        raise


def delete_table_if_exists(
    name: str, bucket_arn: str, namespace: str, region: str
) -> None:
    client = boto3.client("s3tables", region_name=region)
    if not _table_exists(client, bucket_arn, namespace, name):
        return
    client.delete_table(
        tableBucketARN=bucket_arn, namespace=namespace, name=name
    )
    print(f"Delete requested: {namespace}.{name}")
    deadline = time.time() + DELETE_TIMEOUT_SEC
    while time.time() < deadline:
        if not _table_exists(client, bucket_arn, namespace, name):
            print(f"Confirmed gone: {namespace}.{name}")
            return
        time.sleep(DELETE_POLL_SEC)
    raise TimeoutError(f"Table {namespace}.{name} still exists after delete")


def sql_create_minimal(catalog: str, namespace: str, table: str) -> str:
    return f"""
    CREATE TABLE `{catalog}`.`{namespace}`.`{table}` (
        id STRING,
        payload_a VARIANT,
        payload_b VARIANT
    ) USING iceberg
    PARTITIONED BY (bucket(16, id))
    TBLPROPERTIES ('format-version' = '3')
    """


def sql_create_wide(catalog: str, namespace: str, table: str) -> str:
    return f"""
    CREATE TABLE `{catalog}`.`{namespace}`.`{table}` (
        id STRING,
        region STRING,
        user_key STRING,
        email STRING,
        payload_a VARIANT,
        payload_b VARIANT,
        updated_at TIMESTAMP,
        status STRING,
        hash_key STRING,
        extra_a VARIANT,
        extra_b VARIANT,
        extra_c VARIANT,
        extra_d VARIANT,
        extra_e VARIANT,
        extra_f VARIANT,
        extra_g VARIANT,
        id_list ARRAY<STRING>,
        extra_h VARIANT,
        extra_i VARIANT,
        tag_list ARRAY<STRING>
    ) USING iceberg
    PARTITIONED BY (bucket(16, id))
    TBLPROPERTIES ('format-version' = '3')
    """


def sql_insert_minimal(catalog: str, namespace: str, table: str) -> str:
    return f"""
    INSERT OVERWRITE `{catalog}`.`{namespace}`.`{table}` (id, payload_a, payload_b)
    SELECT id, parse_json(payload_a), parse_json(payload_b)
    FROM {SOURCE_VIEW}
    """


def sql_insert_wide(catalog: str, namespace: str, table: str) -> str:
    return f"""
    INSERT OVERWRITE `{catalog}`.`{namespace}`.`{table}` (
        id, region, user_key, email, payload_a, payload_b, updated_at,
        status, hash_key, extra_a, extra_b, extra_c, extra_d, extra_e,
        extra_f, extra_g, id_list, extra_h, extra_i, tag_list
    )
    SELECT
        id,
        region,
        user_key,
        email,
        parse_json(payload_a),
        parse_json(payload_b),
        COALESCE(updated_at, CURRENT_TIMESTAMP()),
        status,
        hash_key,
        parse_json(extra_a),
        parse_json(extra_b),
        parse_json(extra_c),
        parse_json(extra_d),
        parse_json(extra_e),
        parse_json(extra_f),
        parse_json(extra_g),
        CASE
            WHEN id_list = '[]' OR id_list IS NULL THEN array()
            ELSE from_json(id_list, 'ARRAY<STRING>')
        END,
        parse_json(extra_h),
        parse_json(extra_i),
        CASE
            WHEN tag_list = '[]' OR tag_list IS NULL THEN array()
            ELSE from_json(tag_list, 'ARRAY<STRING>')
        END
    FROM {SOURCE_VIEW}
    """


def _wide_row(row_id: str) -> tuple[str, ...]:
    return (
        row_id,
        "us-east-1",
        "user-100",
        "user@example.com",
        JSON_A,
        JSON_B,
        None,
        "active",
        "deadbeef",
        JSON_SIMPLE,
        JSON_SIMPLE,
        JSON_SIMPLE,
        JSON_SIMPLE,
        JSON_SIMPLE,
        JSON_SIMPLE,
        JSON_SIMPLE,
        "[]",
        JSON_SIMPLE,
        JSON_SIMPLE,
        "[]",
    )


def _is_shred_bug(exc: BaseException) -> bool:
    msg = str(exc)
    return "ClassCastException" in msg and "is not a group" in msg


def _prepare_table(
    spark: SparkSession,
    args: argparse.Namespace,
    table: str,
    create_sql: str,
) -> None:
    if not args.skip_delete:
        delete_table_if_exists(
            table, args.bucket_arn, args.namespace, args.region
        )
    spark.catalog.clearCache()
    spark.sql(f"CREATE NAMESPACE IF NOT EXISTS `{args.catalog}`.`{args.namespace}`")
    spark.sql(create_sql)


def run_minimal(spark: SparkSession, args: argparse.Namespace) -> None:
    table = "shred_bug_minimal"
    print(f"\n=== minimal (id + 2 VARIANT) -> table {args.namespace}.{table} ===")
    _prepare_table(
        spark, args, table, sql_create_minimal(args.catalog, args.namespace, table)
    )
    rows = [("row-1", JSON_A, JSON_B), ("row-2", JSON_A, JSON_B)]
    spark.createDataFrame(rows, ["id", "payload_a", "payload_b"]).createOrReplaceTempView(
        SOURCE_VIEW
    )
    spark.sql(sql_insert_minimal(args.catalog, args.namespace, table))
    n = spark.table(f"{args.catalog}.{args.namespace}.{table}").count()
    print(f"OK: inserted {n} rows (expected PASS)")


def run_wide(spark: SparkSession, args: argparse.Namespace) -> None:
    table = "shred_bug_wide"
    shred = not args.no_shred
    print(
        f"\n=== wide mixed schema -> table {args.namespace}.{table} "
        f"(shred={shred}, expected {'FAIL' if shred else 'PASS'}) ==="
    )
    _prepare_table(
        spark, args, table, sql_create_wide(args.catalog, args.namespace, table)
    )
    rows = [_wide_row("row-1"), _wide_row("row-2")]
    spark.createDataFrame(rows, WIDE_SOURCE_SCHEMA).createOrReplaceTempView(
        SOURCE_VIEW
    )
    spark.sql(sql_insert_wide(args.catalog, args.namespace, table))
    n = spark.table(f"{args.catalog}.{args.namespace}.{table}").count()
    print(f"OK: inserted {n} rows")


def main() -> int:
    args = parse_args()
    shred = not args.no_shred

    print("EMR S3 Tables VARIANT shred bug repro")
    print(f"  catalog={args.catalog}  namespace={args.namespace}")
    print(f"  bucket={args.bucket_arn}")
    print(f"  spark.sql.iceberg.shred-variants={shred}")

    spark = SparkSession.builder.appName("emr_s3tables_shred_repro").getOrCreate()
    spark.sparkContext.setLogLevel("WARN")
    spark.conf.set("spark.sql.iceberg.shred-variants", str(shred).lower())
    print(f"  Spark version: {spark.version}")

    wide_result = "PASS"
    try:
        if args.case in ("all", "minimal"):
            run_minimal(spark, args)

        if args.case in ("all", "wide"):
            try:
                run_wide(spark, args)
            except Exception as exc:
                if _is_shred_bug(exc):
                    wide_result = "FAIL"
                    print(f"BUG REPRODUCED: {str(exc).split(chr(10))[0]}")
                else:
                    wide_result = "ERROR"
                    print(f"ERROR: {exc}")
                    traceback.print_exc(limit=10)
                    return 1
    finally:
        spark.catalog.dropTempView(SOURCE_VIEW)
        spark.stop()

    print("\n=== SUMMARY ===")
    if args.case in ("all", "minimal"):
        print("  minimal: PASS (expected PASS)")
    if args.case in ("all", "wide"):
        expected = "FAIL" if shred else "PASS"
        ok = wide_result == expected
        print(f"  wide:    {wide_result} (expected {expected})  {'OK' if ok else 'UNEXPECTED'}")
        if wide_result == "FAIL" and shred:
            print("\nBug reproduced on S3 Tables + EMR.")
        if not ok:
            return 1
    return 0


if __name__ == "__main__":
    raise SystemExit(main())

Willingness to contribute

  • I can contribute a fix for this bug independently
  • I would be willing to contribute a fix for this bug with guidance from the Iceberg community
  • I cannot contribute a fix for this bug at this time
