diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 249bd33..9c64f51 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -29,10 +29,10 @@ on: pull_request: push: branches: - - master + - main env: - SPARK_VERSION: '4.0.0' + SPARK_VERSION: '4.0.1' HADOOP_VERSION: '3' permissions: @@ -84,7 +84,12 @@ jobs: echo "Apache Spark is not installed" # Access the directory. mkdir -p ~/deps/ - wget -q https://dlcdn.apache.org/spark/spark-${{ env.SPARK_VERSION }}/spark-${{ env.SPARK_VERSION }}-bin-hadoop${{ env.HADOOP_VERSION }}.tgz + # dlcdn.apache.org only keeps current releases on its mirrors and + # occasionally 404s on older ones. archive.apache.org is the + # canonical mirror and never rotates — use it as a fallback. + ARCHIVE=spark-${{ env.SPARK_VERSION }}-bin-hadoop${{ env.HADOOP_VERSION }}.tgz + wget -q https://dlcdn.apache.org/spark/spark-${{ env.SPARK_VERSION }}/$ARCHIVE || \ + wget -q https://archive.apache.org/dist/spark/spark-${{ env.SPARK_VERSION }}/$ARCHIVE tar -xzf spark-${{ env.SPARK_VERSION }}-bin-hadoop${{ env.HADOOP_VERSION }}.tgz -C ~/deps/ # Delete the old file rm spark-${{ env.SPARK_VERSION }}-bin-hadoop${{ env.HADOOP_VERSION }}.tgz diff --git a/README.md b/README.md index 51d674f..6716843 100644 --- a/README.md +++ b/README.md @@ -1,76 +1,151 @@ -# Apache Spark Connect Client for Golang +# spark-connect-go -This project houses the **experimental** client for [Spark -Connect](https://spark.apache.org/docs/latest/spark-connect-overview.html) for -[Apache Spark](https://spark.apache.org/) written in [Golang](https://go.dev/). +> A maintained fork of [`apache/spark-connect-go`](https://github.com/apache/spark-connect-go) with a `database/sql` driver, edge-typed DataFrames, exposed gRPC dial options, and a typed `ClusterNotReady` error. Tracks upstream; deltas are queued to upstream. -## Current State of the Project +Spark Connect is Spark's [language-neutral gRPC protocol](https://spark.apache.org/docs/latest/spark-connect-overview.html). The upstream Go client is the official reference implementation. This fork carries the deltas needed for production usage while those patches work their way upstream — drop in by swapping the import path; the session API, DataFrame surface, and protobuf stubs are unchanged. -Currently, the Spark Connect client for Golang is highly experimental and should -not be used in any production setting. In addition, the PMC of the Apache Spark -project reserves the right to withdraw and abandon the development of this project -if it is not sustainable. +## What's added -## Getting started +- **`database/sql` driver.** `sql.Open("spark", "sc://host:port")` works with goose, sqlc-generated code, pgx-style consumers — anything that speaks `database/sql`. Registered under the name `spark` in `spark/sql/driver`. `$N` positional placeholders are rendered client-side into Spark SQL literals (the native parameter proto isn't reliable across every supported Spark version). +- **Edge-typed DataFrames.** `As[T](df) → *DataFrameOf[T]` caches a reflected row plan once; `Collect`, `Stream`, `First` materialise into struct types at the point you know the result shape. Top-level `Collect[T] / Stream[T] / First[T]` helpers do the `As[T]` plus the call in one shot. +- **`SparkSessionBuilder.WithDialOptions`.** gRPC dial options exposed on the builder — auth interceptors, TLS, observability handlers wire in without subclassing. +- **`sparkerrors.IsClusterNotReady(err)`.** Typed error for cluster cold-start states. 
Databricks serverless clusters take 30-90s to warm; retry logic upstack needs a reliable signal instead of string-matching on error messages.
 
-This section explains how to run Spark Connect Go locally.
+Every delta is tracked as a PR queued for `apache/spark-connect-go`. When a delta lands upstream we drop it from the fork. Long-term goal is zero deltas.
 
-Step 1: Install Golang: https://go.dev/doc/install.
+## Install
 
-Step 2: Ensure you have installed `buf CLI` installed, [more info here](https://buf.build/docs/installation/)
+```bash
+go get github.com/datalake-go/spark-connect-go
+```
 
-Step 3: Run the following commands to setup the Spark Connect client.
+Requires a Spark Connect server (Spark 3.4+).
 
-Building with Spark in case you need to re-generate the source files from the proto sources.
+## Quick start
 
-```
-git clone https://github.com/apache/spark-connect-go.git
-git submodule update --init --recursive
+```go
+import (
+	sparksql "github.com/datalake-go/spark-connect-go/spark/sql"
+)
 
-make gen && make test
+session, err := sparksql.NewSessionBuilder().
+	Remote("sc://spark.internal:15002").
+	Build(ctx)
+if err != nil { /* ... */ }
+defer session.Stop()
+df, _ := session.Sql(ctx, "SELECT id, email FROM users WHERE tier = 'gold'")
+_ = df.Show(ctx, 20, false)
 ```
 
-Building without Spark
+The `sparksql` alias avoids collision with stdlib `database/sql` — the actual package name is `sql`.
 
+### Using DataFrames
+
+The untyped `DataFrame` is the building block — same surface as upstream. Transformations (`Where`, `Limit`, `OrderBy`, `Select`, `Join`, `GroupBy`) compose lazily and execute on the Spark side; materialisers (`Show`, `Collect`, `First`, `Count`) round-trip and return `[]types.Row`.
+
+```go
+df, _ := session.Sql(ctx, "SELECT id, email, created_at FROM users")
+
+filtered, _ := df.Where(ctx, "tier = 'gold'")
+sorted, _ := filtered.OrderBy(ctx, "created_at DESC")
+top, _ := sorted.Limit(ctx, 100)
+
+rows, _ := top.Collect(ctx)
+for _, r := range rows {
+	// r is types.Row — positional access by index or by name
+}
 ```
-git clone https://github.com/apache/spark-connect-go.git
-make && make test
-```
-Step 4: Setup the Spark Driver on localhost.
+
+Use this when the result shape is dynamic, or as the composition surface that you eventually re-type at the edge.
+
+### Using Typed DataFrames
+
+`As[T](df) → *DataFrameOf[T]` is the typed surface. It binds a result shape to a struct, caches the reflected row plan once, and materialises into `[]T` / `*T` without re-validating on every call.
 
-1. [Download Spark distribution](https://spark.apache.org/downloads.html) (4.0.0+), unzip the package.
+```go
+type User struct {
+	ID      string    `spark:"id"`
+	Email   string    `spark:"email"`
+	Created time.Time `spark:"created_at"`
+}
+
+df, _ := session.Sql(ctx, "SELECT id, email, created_at FROM users WHERE tier = 'gold'")
+typed, _ := sparksql.As[User](df)
+
+users, _ := typed.Collect(ctx)
+alice, err := typed.First(ctx)
+if errors.Is(err, sparksql.ErrNotFound) { /* zero rows */ }
 ```
 
-2. Start the Spark Connect server with the following command (make sure to use a package version that matches your Spark distribution):
+If you only need the result once, `Collect[T] / First[T] / Stream[T]` are top-level helpers that fold `As[T]` into the call:
+```go
+users, _ := sparksql.Collect[User](ctx, df)
 ```
-sbin/start-connect-server.sh
+
+Untagged fields map by snake_case'd field name, so plain Go structs work without tags. `spark:"-"` skips a field. 
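+
+A minimal sketch of those mapping rules (the `AuditEvent` struct and its column names here are hypothetical, not part of the package):
+
+```go
+type AuditEvent struct {
+	EventID   string    // untagged: resolves to column "event_id"
+	CreatedAt time.Time // untagged: resolves to "created_at"
+	Derived   string    `spark:"-"` // opted out: never read from the result
+}
+
+events, _ := sparksql.Collect[AuditEvent](ctx, df)
+```
+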
`*DataFrameOf[T]` deliberately has no transformation methods — `Where` / `Limit` / `Select` / `Join` change the row shape and would make `T` lie. Compose on the untyped `DataFrame`, then re-type at the edge: + +```go +typed, _ := sparksql.As[User](df) +narrower, _ := typed.DataFrame().Select(ctx, "id", "email") // back to untyped +ids, _ := sparksql.Collect[struct{ ID string `spark:"id"` }](ctx, narrower) ``` -Step 5: Run the example Go application. +### Streaming Results +`Stream[T]` returns a Go 1.23 [`iter.Seq2[T, error]`](https://pkg.go.dev/iter#Seq2). One of the things Go gives us over the Python / Scala clients is a real pull-based iterator — rows decode one at a time as the gRPC stream resolves them, with constant memory regardless of result size. No need to buffer the whole result, no callback API: just `range`. + +```go +for row, err := range sparksql.Stream[User](ctx, df) { + if err != nil { break } + // use row — decoded from the next Arrow batch as it lands +} ``` -go run cmd/spark-connect-example-spark-session/main.go + +Schema binding happens on the first row; if a later row's schema diverges from the first, the error surfaces through the iterator (no per-row panics). + +Use `Stream[T]` when result sets are large, when you want to short-circuit early without dragging the rest of the rows over the wire, or when you're piping into another `iter.Seq2` consumer. + +### `database/sql` driver + +```go +import ( + "database/sql" + _ "github.com/datalake-go/spark-connect-go/spark/sql/driver" +) + +db, _ := sql.Open("spark", "sc://spark.internal:15002") +rows, _ := db.QueryContext(ctx, "SELECT id FROM users WHERE tier = $1", "gold") ``` -## Runnning Spark Connect Go Application in a Spark Cluster +`$N` placeholders render with type-aware quoting (strings, numbers, bools, `[]byte`, `time.Time`). `?` placeholders aren't supported — most `database/sql`-adjacent codegen (sqlc, goose dialects, pgx patterns) emits `$N`, so the narrower grammar keeps the renderer simple. -To run the Spark Connect Go application in a Spark Cluster, you need to build the Go application and submit it to the Spark Cluster. You can find a more detailed example runner and wrapper script in the `java` directory. +### Cluster cold-start -See the guide here: [Sample Spark-Submit Wrapper](java/README.md). +```go +import "github.com/datalake-go/spark-connect-go/spark/sparkerrors" -## How to write Spark Connect Go Application in your own project +df, err := session.Sql(ctx, query) +if sparkerrors.IsClusterNotReady(err) { + // retry with backoff — Databricks serverless usually warms in 30-90s +} +``` + +## Building from source -See [Quick Start Guide](quick-start.md) +```bash +git clone https://github.com/datalake-go/spark-connect-go.git +cd spark-connect-go +make && make test +``` -## High Level Design +Regenerating protobuf stubs from the Spark submodule: -The overall goal of the design is to find a good balance of principle of the least surprise for -develoeprs that are familiar with the APIs of Apache Spark and idiomatic Go usage. The high-level -structure of the packages follows roughly the PySpark giudance but with Go idioms. +```bash +git submodule update --init --recursive +make gen && make test +``` ## Contributing -Please review the [Contribution to Spark guide](https://spark.apache.org/contributing.html) -for information on how to get started contributing to the project. 
+Feature work that could land upstream should be proposed against [`apache/spark-connect-go`](https://github.com/apache/spark-connect-go) first. Fork-only changes (anything that wouldn't be accepted upstream) stay on this tree. See [CONTRIBUTING.md](CONTRIBUTING.md). diff --git a/cmd/spark-connect-example-raw-grpc-client/main.go b/cmd/spark-connect-example-raw-grpc-client/main.go index 1f463db..e22f051 100644 --- a/cmd/spark-connect-example-raw-grpc-client/main.go +++ b/cmd/spark-connect-example-raw-grpc-client/main.go @@ -22,7 +22,7 @@ import ( "log" "time" - proto "github.com/apache/spark-connect-go/internal/generated" + proto "github.com/datalake-go/spark-connect-go/internal/generated" "github.com/google/uuid" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" diff --git a/cmd/spark-connect-example-spark-session/main.go b/cmd/spark-connect-example-spark-session/main.go index ec720dc..aef6640 100644 --- a/cmd/spark-connect-example-spark-session/main.go +++ b/cmd/spark-connect-example-spark-session/main.go @@ -22,12 +22,12 @@ import ( "fmt" "log" - "github.com/apache/spark-connect-go/spark/sql/types" + "github.com/datalake-go/spark-connect-go/spark/sql/types" - "github.com/apache/spark-connect-go/spark/sql/functions" + "github.com/datalake-go/spark-connect-go/spark/sql/functions" - "github.com/apache/spark-connect-go/spark/sql" - "github.com/apache/spark-connect-go/spark/sql/utils" + "github.com/datalake-go/spark-connect-go/spark/sql" + "github.com/datalake-go/spark-connect-go/spark/sql/utils" ) var ( diff --git a/go.mod b/go.mod index 3478ddb..112cb00 100644 --- a/go.mod +++ b/go.mod @@ -13,9 +13,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -module github.com/apache/spark-connect-go +module github.com/datalake-go/spark-connect-go -go 1.23.2 +go 1.24 require ( github.com/apache/arrow-go/v18 v18.4.0 diff --git a/internal/tests/integration/dataframe_test.go b/internal/tests/integration/dataframe_test.go index d383ca1..2fb26bc 100644 --- a/internal/tests/integration/dataframe_test.go +++ b/internal/tests/integration/dataframe_test.go @@ -21,15 +21,15 @@ import ( "os" "testing" - "github.com/apache/spark-connect-go/spark/sql/utils" + "github.com/datalake-go/spark-connect-go/spark/sql/utils" - "github.com/apache/spark-connect-go/spark/sql/types" + "github.com/datalake-go/spark-connect-go/spark/sql/types" - "github.com/apache/spark-connect-go/spark/sql/column" + "github.com/datalake-go/spark-connect-go/spark/sql/column" - "github.com/apache/spark-connect-go/spark/sql/functions" + "github.com/datalake-go/spark-connect-go/spark/sql/functions" - "github.com/apache/spark-connect-go/spark/sql" + "github.com/datalake-go/spark-connect-go/spark/sql" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -1196,7 +1196,6 @@ func TestDataFrame_RangeIter(t *testing.T) { } assert.Equal(t, 10, cnt) - // Check that errors are properly propagated df, err = spark.Sql(ctx, "select if(id = 5, raise_error('handle'), false) from range(10)") assert.NoError(t, err) for _, err := range df.All(ctx) { @@ -1224,3 +1223,197 @@ func TestDataFrame_SchemaTreeString(t *testing.T) { assert.Contains(t, ts, "|-- second: array") assert.Contains(t, ts, "|-- third: map") } + +func TestDataFrame_StreamRowsThroughChannel(t *testing.T) { + // Demonstrates how StreamRows can be used to pipe data through a channel for scenarios like: + // - Proxying Spark data through gRPC streaming or unary RPCs + // - Implementing 
producer-consumer patterns with backpressure based on Spark results + // - Buffering and rate-limiting data flow between systems + ctx, spark := connect() + df, err := spark.Sql(ctx, "select id, id * 2 as doubled, 'test_' || cast(id as string) as label from range(100)") + assert.NoError(t, err) + + iter, err := df.StreamRows(ctx) + assert.NoError(t, err) + + rowChan := make(chan map[string]interface{}, 10) + errChan := make(chan error, 1) + + go func() { + defer close(rowChan) + for row, err := range iter { + if err != nil { + errChan <- err + return + } + + rowData := make(map[string]interface{}) + names := row.FieldNames() + for i, name := range names { + rowData[name] = row.At(i) + } + + select { + case rowChan <- rowData: + case <-ctx.Done(): + errChan <- ctx.Err() + return + } + } + }() + + // In a gRPC scenario, this would be your response handler + receivedRows := make([]map[string]interface{}, 0) + consumerDone := make(chan struct{}) + + go func() { + defer close(consumerDone) + for rowData := range rowChan { + receivedRows = append(receivedRows, rowData) + + id := rowData["id"].(int64) + doubled := rowData["doubled"].(int64) + assert.Equal(t, id*2, doubled) + } + }() + + <-consumerDone + + select { + case err := <-errChan: + assert.NoError(t, err) + default: + // continue + } + + assert.Equal(t, 100, len(receivedRows)) + + assert.Equal(t, int64(0), receivedRows[0]["id"]) + assert.Equal(t, int64(99), receivedRows[99]["id"]) + assert.Equal(t, "test_0", receivedRows[0]["label"]) + assert.Equal(t, "test_99", receivedRows[99]["label"]) +} + +func TestDataFrame_StreamRows(t *testing.T) { + ctx, spark := connect() + df, err := spark.Sql(ctx, "select * from range(100)") + assert.NoError(t, err) + + iter, err := df.StreamRows(ctx) + assert.NoError(t, err) + assert.NotNil(t, iter) + + cnt := 0 + + for row, err := range iter { + assert.NoError(t, err) + assert.NotNil(t, row) + assert.Equal(t, 1, row.Len()) + cnt++ + } + assert.Equal(t, 100, cnt) +} + +func TestDataFrame_StreamRowsWithFilter(t *testing.T) { + ctx, spark := connect() + df, err := spark.Sql(ctx, "select * from range(100)") + assert.NoError(t, err) + + df, err = df.Filter(ctx, functions.Col("id").Lt(functions.IntLit(10))) + assert.NoError(t, err) + + iter, err := df.StreamRows(ctx) + assert.NoError(t, err) + + cnt := 0 + for row, err := range iter { + assert.NoError(t, err) + assert.NotNil(t, row) + assert.Equal(t, 1, row.Len()) + assert.Less(t, row.At(0).(int64), int64(10)) + cnt++ + } + assert.Equal(t, 10, cnt) +} + +func TestDataFrame_StreamRowsEmpty(t *testing.T) { + ctx, spark := connect() + df, err := spark.Sql(ctx, "select * from range(0)") + assert.NoError(t, err) + + iter, err := df.StreamRows(ctx) + assert.NoError(t, err) + + cnt := 0 + for row, err := range iter { + assert.NoError(t, err) + assert.NotNil(t, row) + cnt++ + } + assert.Equal(t, 0, cnt) +} + +func TestDataFrame_StreamRowsWithError(t *testing.T) { + ctx, spark := connect() + df, err := spark.Sql(ctx, "select if(id = 5, raise_error('test error'), id) as id from range(10)") + assert.NoError(t, err) + + iter, err := df.StreamRows(ctx) + assert.NoError(t, err) + + errorEncountered := false + for _, err := range iter { + if err != nil { + errorEncountered = true + assert.Error(t, err) + break + } + } + assert.True(t, errorEncountered, "Expected to encounter an error during iteration") +} + +func TestDataFrame_StreamRowsMultipleColumns(t *testing.T) { + ctx, spark := connect() + df, err := spark.Sql(ctx, "select id, id * 2 as doubled, 'test' as name from 
range(50)") + assert.NoError(t, err) + + iter, err := df.StreamRows(ctx) + assert.NoError(t, err) + + cnt := 0 + for row, err := range iter { + assert.NoError(t, err) + assert.NotNil(t, row) + assert.Equal(t, 3, row.Len()) + + id := row.At(0).(int64) + doubled := row.At(1).(int64) + name := row.At(2).(string) + + assert.Equal(t, id*2, doubled) + assert.Equal(t, "test", name) + cnt++ + } + assert.Equal(t, 50, cnt) +} + +func TestDataFrame_StreamRowsLargeDataset(t *testing.T) { + ctx, spark := connect() + df, err := spark.Sql(ctx, "select * from range(10000)") + assert.NoError(t, err) + + iter, err := df.StreamRows(ctx) + assert.NoError(t, err) + + cnt := 0 + lastValue := int64(-1) + for row, err := range iter { + assert.NoError(t, err) + assert.NotNil(t, row) + currentValue := row.At(0).(int64) + assert.Greater(t, currentValue, lastValue) + lastValue = currentValue + cnt++ + } + assert.Equal(t, 10000, cnt) +} diff --git a/internal/tests/integration/functions_test.go b/internal/tests/integration/functions_test.go index 94310d7..287f45b 100644 --- a/internal/tests/integration/functions_test.go +++ b/internal/tests/integration/functions_test.go @@ -19,11 +19,11 @@ import ( "context" "testing" - "github.com/apache/spark-connect-go/spark/sql/types" + "github.com/datalake-go/spark-connect-go/spark/sql/types" - "github.com/apache/spark-connect-go/spark/sql/functions" + "github.com/datalake-go/spark-connect-go/spark/sql/functions" - "github.com/apache/spark-connect-go/spark/sql" + "github.com/datalake-go/spark-connect-go/spark/sql" "github.com/stretchr/testify/assert" ) diff --git a/internal/tests/integration/helper.go b/internal/tests/integration/helper.go index 902d223..28086de 100644 --- a/internal/tests/integration/helper.go +++ b/internal/tests/integration/helper.go @@ -22,7 +22,7 @@ import ( "github.com/apache/arrow-go/v18/arrow/array" "github.com/apache/arrow-go/v18/arrow/memory" - "github.com/apache/spark-connect-go/spark/sql" + "github.com/datalake-go/spark-connect-go/spark/sql" ) func connect() (context.Context, sql.SparkSession) { diff --git a/internal/tests/integration/spark_runner.go b/internal/tests/integration/spark_runner.go index b6cb688..6398a56 100644 --- a/internal/tests/integration/spark_runner.go +++ b/internal/tests/integration/spark_runner.go @@ -23,7 +23,7 @@ import ( "os/exec" "time" - "github.com/apache/spark-connect-go/spark/sparkerrors" + "github.com/datalake-go/spark-connect-go/spark/sparkerrors" ) func StartSparkConnect() (int64, error) { diff --git a/internal/tests/integration/sql_test.go b/internal/tests/integration/sql_test.go index a0e0493..11ed17b 100644 --- a/internal/tests/integration/sql_test.go +++ b/internal/tests/integration/sql_test.go @@ -22,13 +22,13 @@ import ( "os" "testing" - "github.com/apache/spark-connect-go/spark/sql/column" + "github.com/datalake-go/spark-connect-go/spark/sql/column" - "github.com/apache/spark-connect-go/spark/sql/functions" + "github.com/datalake-go/spark-connect-go/spark/sql/functions" - "github.com/apache/spark-connect-go/spark/sql/types" + "github.com/datalake-go/spark-connect-go/spark/sql/types" - "github.com/apache/spark-connect-go/spark/sql" + "github.com/datalake-go/spark-connect-go/spark/sql" "github.com/stretchr/testify/assert" ) diff --git a/quick-start.md b/quick-start.md index 7382107..221c82a 100644 --- a/quick-start.md +++ b/quick-start.md @@ -5,7 +5,7 @@ In your Go project `go.mod` file, add `spark-connect-go` library: ``` require ( - github.com/apache/spark-connect-go master + 
github.com/datalake-go/spark-connect-go master ) ``` @@ -23,7 +23,7 @@ import ( "fmt" "log" - "github.com/apache/spark-connect-go/spark/sql" + "github.com/datalake-go/spark-connect-go/spark/sql" ) var ( diff --git a/spark/client/base/base.go b/spark/client/base/base.go index c5f01c7..ef9e77d 100644 --- a/spark/client/base/base.go +++ b/spark/client/base/base.go @@ -17,12 +17,13 @@ package base import ( "context" + "iter" - "github.com/apache/spark-connect-go/spark/sql/utils" + "github.com/datalake-go/spark-connect-go/spark/sql/utils" "github.com/apache/arrow-go/v18/arrow" - "github.com/apache/spark-connect-go/internal/generated" - "github.com/apache/spark-connect-go/spark/sql/types" + "github.com/datalake-go/spark-connect-go/internal/generated" + "github.com/datalake-go/spark-connect-go/spark/sql/types" ) type SparkConnectRPCClient generated.SparkConnectServiceClient @@ -47,6 +48,9 @@ type SparkConnectClient interface { } type ExecuteResponseStream interface { + // ToTable consumes all arrow.Record batches to a single arrow.Table. Useful for collecting all query results into a client DF. ToTable() (*types.StructType, arrow.Table, error) + // ToRecordSequence lazily consumes each arrow.Record retrieved by a query. Useful for streaming query results. + ToRecordSequence(ctx context.Context) iter.Seq2[arrow.Record, error] Properties() map[string]any } diff --git a/spark/client/channel/channel.go b/spark/client/channel/channel.go index 6403566..6abe1b8 100644 --- a/spark/client/channel/channel.go +++ b/spark/client/channel/channel.go @@ -29,13 +29,13 @@ import ( "strconv" "strings" - "github.com/apache/spark-connect-go/spark" + "github.com/datalake-go/spark-connect-go/spark" "github.com/google/uuid" "google.golang.org/grpc/credentials/insecure" - "github.com/apache/spark-connect-go/spark/sparkerrors" + "github.com/datalake-go/spark-connect-go/spark/sparkerrors" "golang.org/x/oauth2" "google.golang.org/grpc" "google.golang.org/grpc/credentials" @@ -77,6 +77,28 @@ type BaseBuilder struct { headers map[string]string sessionId string userAgent string + dialOpts []grpc.DialOption +} + +// defaultMaxMessageSize is the per-RPC send/receive ceiling we apply +// when the caller doesn't set one. The server advertises +// spark.connect.grpc.maxInboundMessageSize at 128 MiB by default and +// operators routinely raise it to 1 GiB for bulk reads. gRPC's own +// default is 4 MiB, which silently truncates a single Arrow batch on +// any non-trivial query and surfaces as opaque ResourceExhausted +// errors deep inside Collect()/StreamRows(). Matching the upper end +// of the server's practical range keeps first-run queries from +// hitting a floor the caller didn't know existed. +const defaultMaxMessageSize = 1 << 30 // 1 GiB + +// WithDialOptions appends gRPC dial options to the channel builder. +// Callers use this to raise the per-call message ceiling further, set +// keepalive parameters for long-lived streams, inject interceptors, +// or swap in a custom dialer. Options are applied after the builder's +// defaults, so caller-supplied values win. +func (cb *BaseBuilder) WithDialOptions(opts ...grpc.DialOption) *BaseBuilder { + cb.dialOpts = append(cb.dialOpts, opts...) + return cb } func (cb *BaseBuilder) Host() string { @@ -114,6 +136,15 @@ func (cb *BaseBuilder) Build(ctx context.Context) (*grpc.ClientConn, error) { var opts []grpc.DialOption opts = append(opts, grpc.WithAuthority(cb.host)) + + // Raise the per-call send/receive ceiling off gRPC's 4 MiB default. 
+ // Placed before WithDialOptions so caller-supplied DefaultCallOptions + // can still override if they want a tighter limit. + opts = append(opts, grpc.WithDefaultCallOptions( + grpc.MaxCallRecvMsgSize(defaultMaxMessageSize), + grpc.MaxCallSendMsgSize(defaultMaxMessageSize), + )) + if cb.token == "" { opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials())) } else { @@ -134,6 +165,9 @@ func (cb *BaseBuilder) Build(ctx context.Context) (*grpc.ClientConn, error) { opts = append(opts, grpc.WithPerRPCCredentials(oauth.TokenSource{TokenSource: ts})) } + // Caller overrides come last so they win over the defaults above. + opts = append(opts, cb.dialOpts...) + remote := fmt.Sprintf("%v:%v", cb.host, cb.port) conn, err := grpc.NewClient(remote, opts...) if err != nil { diff --git a/spark/client/channel/channel_test.go b/spark/client/channel/channel_test.go index b0f7bea..b33084d 100644 --- a/spark/client/channel/channel_test.go +++ b/spark/client/channel/channel_test.go @@ -22,9 +22,10 @@ import ( "testing" "github.com/google/uuid" + "google.golang.org/grpc" - "github.com/apache/spark-connect-go/spark/client/channel" - "github.com/apache/spark-connect-go/spark/sparkerrors" + "github.com/datalake-go/spark-connect-go/spark/client/channel" + "github.com/datalake-go/spark-connect-go/spark/sparkerrors" "github.com/stretchr/testify/assert" ) @@ -112,3 +113,28 @@ func TestChannelBulder_UserAgent(t *testing.T) { assert.True(t, strings.Contains(cb.UserAgent(), "spark/")) assert.True(t, strings.Contains(cb.UserAgent(), "os/")) } + +// TestChannelBuilder_WithDialOptions_CompilesAndBuilds is a smoke +// test: the builder accepts caller-supplied grpc.DialOption values, +// stores them, and a subsequent Build produces a live ClientConn +// without panicking. We assert the returned conn is non-nil rather +// than exercising the wire — grpc.NewClient is lazy and doesn't dial +// until the first RPC, so a stricter assertion would spin up a +// server here. +func TestChannelBuilder_WithDialOptions_CompilesAndBuilds(t *testing.T) { + cb, err := channel.NewBuilder("sc://localhost") + assert.NoError(t, err) + + // Representative knobs: raise the per-call ceiling further than + // the builder's own default, and give the connection a keepalive + // profile. Neither changes observable behaviour in this unit + // test, but both have to survive the builder round-trip. 
+ cb.WithDialOptions( + grpc.WithDefaultCallOptions(grpc.MaxCallRecvMsgSize(2<<30)), + grpc.WithUserAgent("dorm-test"), + ) + + conn, err := cb.Build(context.Background()) + assert.NoError(t, err) + assert.NotNil(t, conn) +} diff --git a/spark/client/client.go b/spark/client/client.go index a0d7a4c..d6b577d 100644 --- a/spark/client/client.go +++ b/spark/client/client.go @@ -20,25 +20,26 @@ import ( "errors" "fmt" "io" + "iter" - "github.com/apache/spark-connect-go/spark/sql/utils" + "github.com/datalake-go/spark-connect-go/spark/sql/utils" "google.golang.org/grpc" "google.golang.org/grpc/metadata" - "github.com/apache/spark-connect-go/spark/client/base" - "github.com/apache/spark-connect-go/spark/mocks" + "github.com/datalake-go/spark-connect-go/spark/client/base" + "github.com/datalake-go/spark-connect-go/spark/mocks" - "github.com/apache/spark-connect-go/spark/client/options" + "github.com/datalake-go/spark-connect-go/spark/client/options" "github.com/google/uuid" "github.com/apache/arrow-go/v18/arrow" "github.com/apache/arrow-go/v18/arrow/array" - "github.com/apache/spark-connect-go/spark/sql/types" + "github.com/datalake-go/spark-connect-go/spark/sql/types" - proto "github.com/apache/spark-connect-go/internal/generated" - "github.com/apache/spark-connect-go/spark/sparkerrors" + proto "github.com/datalake-go/spark-connect-go/internal/generated" + "github.com/datalake-go/spark-connect-go/spark/sparkerrors" ) type sparkConnectClientImpl struct { @@ -434,6 +435,103 @@ func (c *ExecutePlanClient) ToTable() (*types.StructType, arrow.Table, error) { } } +// ToRecordSequence returns a single Seq2 iterator that directly yields results as they arrive. +func (c *ExecutePlanClient) ToRecordSequence(ctx context.Context) iter.Seq2[arrow.Record, error] { + return func(yield func(arrow.Record, error) bool) { + // Represents Spark's reattachable execution. + // Tracks logical completion locally to avoid racing on shared struct state. + // Spliced from ToTable. We may eventually want to DRY up these workflows. 
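+		// done stays false when the stream hits io.EOF before the server
+		// sends ResultComplete; the tail check after the loop reports that
+		// as an ExecutionError for reattachable executions.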
+ done := false + + for { + select { + case <-ctx.Done(): + yield(nil, ctx.Err()) + return + default: + } + + resp, err := c.responseStream.Recv() + + select { + case <-ctx.Done(): + yield(nil, ctx.Err()) + return + default: + } + + if errors.Is(err, io.EOF) { + break + } + + if err != nil { + if se := sparkerrors.FromRPCError(err); se != nil { + yield(nil, sparkerrors.WithType(se, sparkerrors.ExecutionError)) + } else { + yield(nil, err) + } + return + } + + if resp == nil { + continue + } + + if resp.GetSessionId() != c.sessionId { + yield(nil, sparkerrors.WithType( + &sparkerrors.InvalidServerSideSessionDetailsError{ + OwnSessionId: c.sessionId, + ReceivedSessionId: resp.GetSessionId(), + }, sparkerrors.InvalidServerSideSessionError)) + return + } + + if resp.Schema != nil { + var schemaErr error + c.schema, schemaErr = types.ConvertProtoDataTypeToStructType(resp.Schema) + if schemaErr != nil { + yield(nil, sparkerrors.WithType(schemaErr, sparkerrors.ExecutionError)) + return + } + } + + switch x := resp.ResponseType.(type) { + case *proto.ExecutePlanResponse_SqlCommandResult_: + if val := x.SqlCommandResult.GetRelation(); val != nil { + c.properties["sql_command_result"] = val + } + + case *proto.ExecutePlanResponse_ArrowBatch_: + record, err := types.ReadArrowBatchToRecord(x.ArrowBatch.Data, c.schema) + if err != nil { + yield(nil, err) + return + } + if !yield(record, nil) { + return + } + + case *proto.ExecutePlanResponse_ResultComplete_: + done = true + return + + case *proto.ExecutePlanResponse_ExecutionProgress_: + // Progress updates - ignore for now + + default: + // Explicitly ignore unknown message types + } + } + + // Check that the result is logically complete. With re-attachable execution + // the server may interrupt the connection, and we need a ResultComplete + // message to confirm the full result was received. + if c.opts.ReattachExecution && !done { + yield(nil, sparkerrors.WithType(fmt.Errorf("the result is not complete"), sparkerrors.ExecutionError)) + } + } +} + func NewExecuteResponseStream( responseClient proto.SparkConnectService_ExecutePlanClient, sessionId string, diff --git a/spark/client/client_test.go b/spark/client/client_test.go index ca7019c..8b6f4c6 100644 --- a/spark/client/client_test.go +++ b/spark/client/client_test.go @@ -1,110 +1,597 @@ -// Licensed to the Apache Software Foundation (ASF) under one or more -// contributor license agreements. See the NOTICE file distributed with -// this work for additional information regarding copyright ownership. -// The ASF licenses this file to You under the Apache License, Version 2.0 -// (the "License"); you may not use this file except in compliance with -// the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- package client_test import ( + "bytes" "context" + "errors" + "iter" "testing" + "time" - "github.com/google/uuid" - - proto "github.com/apache/spark-connect-go/internal/generated" - "github.com/apache/spark-connect-go/spark/client" - "github.com/apache/spark-connect-go/spark/client/testutils" - "github.com/apache/spark-connect-go/spark/mocks" - "github.com/apache/spark-connect-go/spark/sparkerrors" + "github.com/apache/arrow-go/v18/arrow" + "github.com/apache/arrow-go/v18/arrow/array" + "github.com/apache/arrow-go/v18/arrow/ipc" + "github.com/apache/arrow-go/v18/arrow/memory" + proto "github.com/datalake-go/spark-connect-go/internal/generated" + "github.com/datalake-go/spark-connect-go/spark/client" + "github.com/datalake-go/spark-connect-go/spark/mocks" + "github.com/datalake-go/spark-connect-go/spark/sparkerrors" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) -func TestAnalyzePlanCallsAnalyzePlanOnClient(t *testing.T) { +func TestToRecordIterator_ChannelClosureWithoutData(t *testing.T) { + // Iterator should complete without yielding any records when no arrow batches present ctx := context.Background() - response := &proto.AnalyzePlanResponse{} - c := client.NewSparkExecutorFromClient( - testutils.NewConnectServiceClientMock(nil, response, nil, nil), nil, mocks.MockSessionId) - resp, err := c.AnalyzePlan(ctx, &proto.Plan{}) - assert.NoError(t, err) - assert.NotNil(t, resp) + + schemaResponse := &mocks.MockResponse{ + Resp: &proto.ExecutePlanResponse{ + SessionId: mocks.MockSessionId, + OperationId: mocks.MockOperationId, + Schema: &proto.DataType{ + Kind: &proto.DataType_Struct_{ + Struct: &proto.DataType_Struct{ + Fields: []*proto.DataType_StructField{ + { + Name: "test_column", + DataType: &proto.DataType{ + Kind: &proto.DataType_String_{ + String_: &proto.DataType_String{}, + }, + }, + Nullable: false, + }, + }, + }, + }, + }, + }, + } + + c := client.NewTestConnectClientFromResponses(mocks.MockSessionId, + schemaResponse, + &mocks.ExecutePlanResponseDone, + &mocks.ExecutePlanResponseEOF) + + stream, err := c.ExecutePlan(ctx, mocks.NewSqlCommand("select 1")) + require.NoError(t, err) + + iter := stream.ToRecordSequence(ctx) + + recordsReceived := 0 + errorsReceived := 0 + + for record, err := range iter { + if err != nil { + errorsReceived++ + break + } + if record != nil { + recordsReceived++ + } + } + + assert.Equal(t, 0, recordsReceived, "No records should be sent when no arrow batches present") + assert.Equal(t, 0, errorsReceived, "No errors should occur") } -func TestAnalyzePlanFailsIfClientFails(t *testing.T) { +func TestToRecordIterator_ArrowBatchStreaming(t *testing.T) { + // Arrow batch data should be correctly streamed ctx := context.Background() - c := client.NewSparkExecutorFromClient( - testutils.NewConnectServiceClientMock(nil, nil, assert.AnError, nil), nil, mocks.MockSessionId) - resp, err := c.AnalyzePlan(ctx, &proto.Plan{}) - assert.Nil(t, resp) - assert.Error(t, err) + + schemaResponse := &mocks.MockResponse{ + Resp: &proto.ExecutePlanResponse{ + SessionId: mocks.MockSessionId, + OperationId: mocks.MockOperationId, + Schema: &proto.DataType{ + Kind: &proto.DataType_Struct_{ + Struct: &proto.DataType_Struct{ + Fields: []*proto.DataType_StructField{ + { + Name: "col", + DataType: &proto.DataType{ + Kind: &proto.DataType_String_{ + String_: &proto.DataType_String{}, + }, + }, + Nullable: false, + }, + }, + }, + }, + }, + }, + } + + arrowData := createTestArrowBatch(t, []string{"value1", "value2", "value3"}) + + arrowBatch := 
&mocks.MockResponse{ + Resp: &proto.ExecutePlanResponse{ + ResponseType: &proto.ExecutePlanResponse_ArrowBatch_{ + ArrowBatch: &proto.ExecutePlanResponse_ArrowBatch{ + Data: arrowData, + }, + }, + SessionId: mocks.MockSessionId, + OperationId: mocks.MockOperationId, + }, + } + + c := client.NewTestConnectClientFromResponses(mocks.MockSessionId, + schemaResponse, + arrowBatch, + &mocks.ExecutePlanResponseDone, + &mocks.ExecutePlanResponseEOF) + + stream, err := c.ExecutePlan(ctx, mocks.NewSqlCommand("select col")) + require.NoError(t, err) + + iter := stream.ToRecordSequence(ctx) + + records := collectRecordsFromSeq2(t, iter) + + require.Len(t, records, 1, "Should receive exactly one record") + + record := records[0] + assert.Equal(t, int64(3), record.NumRows(), "Record should have 3 rows") + assert.Equal(t, int64(1), record.NumCols(), "Record should have 1 column") + + col := record.Column(0).(*array.String) + assert.Equal(t, "value1", col.Value(0)) + assert.Equal(t, "value2", col.Value(1)) + assert.Equal(t, "value3", col.Value(2)) } -func TestExecutePlanCallsExecutePlanOnClient(t *testing.T) { +func TestToRecordIterator_MultipleArrowBatches(t *testing.T) { + // Multiple arrow batches should be streamed in order ctx := context.Background() - plan := &proto.Plan{} - // Generate a mock client - responseStream := mocks.NewProtoClientMock(&mocks.ExecutePlanResponseDone) + schemaResponse := &mocks.MockResponse{ + Resp: &proto.ExecutePlanResponse{ + SessionId: mocks.MockSessionId, + OperationId: mocks.MockOperationId, + Schema: &proto.DataType{ + Kind: &proto.DataType_Struct_{ + Struct: &proto.DataType_Struct{ + Fields: []*proto.DataType_StructField{ + { + Name: "col", + DataType: &proto.DataType{ + Kind: &proto.DataType_String_{ + String_: &proto.DataType_String{}, + }, + }, + Nullable: false, + }, + }, + }, + }, + }, + }, + } + + batch1 := createTestArrowBatch(t, []string{"batch1_row1", "batch1_row2"}) + batch2 := createTestArrowBatch(t, []string{"batch2_row1", "batch2_row2"}) + + arrowBatch1 := &mocks.MockResponse{ + Resp: &proto.ExecutePlanResponse{ + ResponseType: &proto.ExecutePlanResponse_ArrowBatch_{ + ArrowBatch: &proto.ExecutePlanResponse_ArrowBatch{ + Data: batch1, + }, + }, + SessionId: mocks.MockSessionId, + OperationId: mocks.MockOperationId, + }, + } + + arrowBatch2 := &mocks.MockResponse{ + Resp: &proto.ExecutePlanResponse{ + ResponseType: &proto.ExecutePlanResponse_ArrowBatch_{ + ArrowBatch: &proto.ExecutePlanResponse_ArrowBatch{ + Data: batch2, + }, + }, + SessionId: mocks.MockSessionId, + OperationId: mocks.MockOperationId, + }, + } + + c := client.NewTestConnectClientFromResponses(mocks.MockSessionId, + schemaResponse, + arrowBatch1, + arrowBatch2, + &mocks.ExecutePlanResponseDone, + &mocks.ExecutePlanResponseEOF) + + stream, err := c.ExecutePlan(ctx, mocks.NewSqlCommand("select col")) + require.NoError(t, err) + + iter := stream.ToRecordSequence(ctx) + records := collectRecordsFromSeq2(t, iter) + + require.Len(t, records, 2, "Should receive exactly two records") + + // Verify first batch + col1 := records[0].Column(0).(*array.String) + assert.Equal(t, "batch1_row1", col1.Value(0)) + assert.Equal(t, "batch1_row2", col1.Value(1)) + + // Verify second batch + col2 := records[1].Column(0).(*array.String) + assert.Equal(t, "batch2_row1", col2.Value(0)) + assert.Equal(t, "batch2_row2", col2.Value(1)) +} + +func TestToRecordIterator_ContextCancellationStopsStreaming(t *testing.T) { + // Context cancellation should stop streaming + ctx, cancel := 
context.WithCancel(context.Background()) + + schemaResponse := &mocks.MockResponse{ + Resp: &proto.ExecutePlanResponse{ + SessionId: mocks.MockSessionId, + OperationId: mocks.MockOperationId, + Schema: &proto.DataType{ + Kind: &proto.DataType_Struct_{ + Struct: &proto.DataType_Struct{ + Fields: []*proto.DataType_StructField{ + { + Name: "col0", + DataType: &proto.DataType{ + Kind: &proto.DataType_Integer_{ + Integer: &proto.DataType_Integer{}, + }, + }, + Nullable: true, + }, + }, + }, + }, + }, + }, + } + + c := client.NewTestConnectClientFromResponses(mocks.MockSessionId, + schemaResponse, + &mocks.ExecutePlanResponseDone, + &mocks.ExecutePlanResponseEOF) + + stream, err := c.ExecutePlan(ctx, mocks.NewSqlCommand("select 1")) + require.NoError(t, err) + + iter := stream.ToRecordSequence(ctx) + + // Cancel the context immediately + cancel() - c := client.NewSparkExecutorFromClient( - testutils.NewConnectServiceClientMock(responseStream, nil, nil, t), nil, mocks.MockSessionId) - resp, err := c.ExecutePlan(ctx, plan) - assert.NoError(t, err) - assert.NotNil(t, resp) + // Try to consume the iterator + timeout := time.After(100 * time.Millisecond) + done := make(chan bool) + + go func() { + for _, err := range iter { + if err != nil { + // Got an error - verify it's context cancellation + assert.ErrorIs(t, err, context.Canceled) + done <- true + return + } + } + done <- true + }() + + select { + case <-done: + // Good - iteration completed + case <-timeout: + // Timeout is acceptable as cancellation might have happened after all responses were processed + } } -func TestExecutePlanCallsExecuteCommandOnClient(t *testing.T) { +func TestToRecordIterator_RPCErrorPropagation(t *testing.T) { + // RPC errors should be properly propagated ctx := context.Background() - plan := &proto.Plan{} - - // Generate a mock client - responseStream := mocks.NewProtoClientMock(&mocks.ExecutePlanResponseDone, &mocks.ExecutePlanResponseEOF) - - // Check that the execution fails if no command is supplied. - c := client.NewSparkExecutorFromClient( - testutils.NewConnectServiceClientMock(responseStream, nil, nil, t), nil, mocks.MockSessionId) - _, _, _, err := c.ExecuteCommand(ctx, plan) - assert.ErrorIs(t, err, sparkerrors.ExecutionError) - - // Generate a command and the execution should succeed. 
- sqlCommand := mocks.NewSqlCommand("select range(10)") - c = client.NewSparkExecutorFromClient(testutils.NewConnectServiceClientMock(responseStream, nil, nil, t), nil, mocks.MockSessionId) - _, _, _, err = c.ExecuteCommand(ctx, sqlCommand) - assert.NoError(t, err) + + schemaResponse := &mocks.MockResponse{ + Resp: &proto.ExecutePlanResponse{ + SessionId: mocks.MockSessionId, + OperationId: mocks.MockOperationId, + Schema: &proto.DataType{ + Kind: &proto.DataType_Struct_{ + Struct: &proto.DataType_Struct{ + Fields: []*proto.DataType_StructField{ + { + Name: "col1", + DataType: &proto.DataType{ + Kind: &proto.DataType_String_{ + String_: &proto.DataType_String{}, + }, + }, + Nullable: false, + }, + }, + }, + }, + }, + }, + } + + expectedError := errors.New("simulated RPC error") + errorResponse := &mocks.MockResponse{ + Err: expectedError, + } + + c := client.NewTestConnectClientFromResponses(mocks.MockSessionId, + schemaResponse, + errorResponse) + + stream, err := c.ExecutePlan(ctx, mocks.NewSqlCommand("select 1")) + require.NoError(t, err) + + iter := stream.ToRecordSequence(ctx) + + errorReceived := false + for _, err := range iter { + if err != nil { + assert.Error(t, err) + assert.Contains(t, err.Error(), "simulated RPC error") + errorReceived = true + break + } + } + + assert.True(t, errorReceived, "Expected RPC error") } -func Test_ExecuteWithWrongSession(t *testing.T) { +func TestToRecordIterator_SessionValidation(t *testing.T) { + // Session validation error should be returned for wrong session ID ctx := context.Background() - sqlCommand := mocks.NewSqlCommand("select range(10)") - // Generate a mock client - responseStream := mocks.NewProtoClientMock(&mocks.ExecutePlanResponseDone, &mocks.ExecutePlanResponseEOF) + wrongSessionResponse := &mocks.MockResponse{ + Resp: &proto.ExecutePlanResponse{ + SessionId: "wrong-session-id", + OperationId: mocks.MockOperationId, + Schema: &proto.DataType{ + Kind: &proto.DataType_Struct_{ + Struct: &proto.DataType_Struct{ + Fields: []*proto.DataType_StructField{ + { + Name: "col0", + DataType: &proto.DataType{ + Kind: &proto.DataType_Integer_{ + Integer: &proto.DataType_Integer{}, + }, + }, + Nullable: true, + }, + }, + }, + }, + }, + }, + } + + c := client.NewTestConnectClientFromResponses(mocks.MockSessionId, + wrongSessionResponse, + &mocks.ExecutePlanResponseEOF) + + stream, err := c.ExecutePlan(ctx, mocks.NewSqlCommand("select 1")) + require.NoError(t, err) + + iter := stream.ToRecordSequence(ctx) + + errorReceived := false + for _, err := range iter { + if err != nil { + assert.Error(t, err) + assert.ErrorIs(t, err, sparkerrors.InvalidServerSideSessionError) + errorReceived = true + break + } + } - // Check that the execution fails if no command is supplied. 
- c := client.NewSparkExecutorFromClient( - testutils.NewConnectServiceClientMock(responseStream, nil, nil, t), nil, uuid.NewString()) - _, _, _, err := c.ExecuteCommand(ctx, sqlCommand) - assert.ErrorIs(t, err, sparkerrors.InvalidServerSideSessionError) + assert.True(t, errorReceived, "Expected session validation error") } -func Test_Execute_SchemaParsingFails(t *testing.T) { +func TestToRecordIterator_SqlCommandResultProperties(t *testing.T) { + // SQL command results should be captured in properties ctx := context.Background() - sqlCommand := mocks.NewSqlCommand("select range(10)") - responseStream := mocks.NewProtoClientMock( - &mocks.ExecutePlanResponseBrokenSchema, + + sqlResultResponse := &mocks.MockResponse{ + Resp: &proto.ExecutePlanResponse{ + SessionId: mocks.MockSessionId, + OperationId: mocks.MockOperationId, + ResponseType: &proto.ExecutePlanResponse_SqlCommandResult_{ + SqlCommandResult: &proto.ExecutePlanResponse_SqlCommandResult{ + Relation: &proto.Relation{ + RelType: &proto.Relation_Sql{ + Sql: &proto.SQL{Query: "test query"}, + }, + }, + }, + }, + }, + } + + c := client.NewTestConnectClientFromResponses(mocks.MockSessionId, + sqlResultResponse, &mocks.ExecutePlanResponseDone, &mocks.ExecutePlanResponseEOF) - c := client.NewSparkExecutorFromClient( - testutils.NewConnectServiceClientMock(responseStream, nil, nil, t), nil, mocks.MockSessionId) - _, _, _, err := c.ExecuteCommand(ctx, sqlCommand) - assert.ErrorIs(t, err, sparkerrors.ExecutionError) + + stream, err := c.ExecutePlan(ctx, mocks.NewSqlCommand("test query")) + require.NoError(t, err) + + iter := stream.ToRecordSequence(ctx) + _ = collectRecordsFromSeq2(t, iter) + + // Properties should contain the SQL command result + props := stream.(*client.ExecutePlanClient).Properties() + assert.NotNil(t, props["sql_command_result"]) +} + +func TestToRecordIterator_MixedResponseTypes(t *testing.T) { + // Mixed response types should be handled correctly in realistic order + ctx := context.Background() + + responses := []*mocks.MockResponse{ + // Schema first + { + Resp: &proto.ExecutePlanResponse{ + SessionId: mocks.MockSessionId, + OperationId: mocks.MockOperationId, + Schema: &proto.DataType{ + Kind: &proto.DataType_Struct_{ + Struct: &proto.DataType_Struct{ + Fields: []*proto.DataType_StructField{ + { + Name: "id", + DataType: &proto.DataType{ + Kind: &proto.DataType_String_{ + String_: &proto.DataType_String{}, + }, + }, + Nullable: false, + }, + }, + }, + }, + }, + }, + }, + // SQL command result + { + Resp: &proto.ExecutePlanResponse{ + SessionId: mocks.MockSessionId, + OperationId: mocks.MockOperationId, + ResponseType: &proto.ExecutePlanResponse_SqlCommandResult_{ + SqlCommandResult: &proto.ExecutePlanResponse_SqlCommandResult{ + Relation: &proto.Relation{ + RelType: &proto.Relation_Sql{ + Sql: &proto.SQL{Query: "SELECT * FROM table"}, + }, + }, + }, + }, + }, + }, + // Progress updates + { + Resp: &proto.ExecutePlanResponse{ + SessionId: mocks.MockSessionId, + OperationId: mocks.MockOperationId, + ResponseType: &proto.ExecutePlanResponse_ExecutionProgress_{ + ExecutionProgress: &proto.ExecutePlanResponse_ExecutionProgress{ + Stages: nil, + NumInflightTasks: 0, + }, + }, + }, + }, + // Arrow batch + { + Resp: &proto.ExecutePlanResponse{ + ResponseType: &proto.ExecutePlanResponse_ArrowBatch_{ + ArrowBatch: &proto.ExecutePlanResponse_ArrowBatch{ + Data: createTestArrowBatch(t, []string{"row1"}), + }, + }, + SessionId: mocks.MockSessionId, + OperationId: mocks.MockOperationId, + }, + }, + // More progress + { + Resp: 
&proto.ExecutePlanResponse{ + SessionId: mocks.MockSessionId, + OperationId: mocks.MockOperationId, + ResponseType: &proto.ExecutePlanResponse_ExecutionProgress_{ + ExecutionProgress: &proto.ExecutePlanResponse_ExecutionProgress{ + Stages: nil, + NumInflightTasks: 0, + }, + }, + }, + }, + // Another arrow batch + { + Resp: &proto.ExecutePlanResponse{ + ResponseType: &proto.ExecutePlanResponse_ArrowBatch_{ + ArrowBatch: &proto.ExecutePlanResponse_ArrowBatch{ + Data: createTestArrowBatch(t, []string{"row2", "row3"}), + }, + }, + SessionId: mocks.MockSessionId, + OperationId: mocks.MockOperationId, + }, + }, + // Result complete + &mocks.ExecutePlanResponseDone, + // EOF + &mocks.ExecutePlanResponseEOF, + } + + c := client.NewTestConnectClientFromResponses(mocks.MockSessionId, responses...) + + stream, err := c.ExecutePlan(ctx, mocks.NewSqlCommand("SELECT * FROM table")) + require.NoError(t, err) + + iter := stream.ToRecordSequence(ctx) + records := collectRecordsFromSeq2(t, iter) + + require.Len(t, records, 2, "Should receive exactly two arrow batches") + + assert.Equal(t, int64(1), records[0].NumRows()) + assert.Equal(t, int64(2), records[1].NumRows()) +} + +// Helper function to create test arrow batch data +func createTestArrowBatch(t *testing.T, values []string) []byte { + t.Helper() + + arrowFields := []arrow.Field{ + {Name: "col", Type: arrow.BinaryTypes.String}, + } + arrowSchema := arrow.NewSchema(arrowFields, nil) + + alloc := memory.NewGoAllocator() + recordBuilder := array.NewRecordBuilder(alloc, arrowSchema) + defer recordBuilder.Release() + + stringBuilder := recordBuilder.Field(0).(*array.StringBuilder) + for _, v := range values { + stringBuilder.Append(v) + } + + record := recordBuilder.NewRecord() + defer record.Release() + + var buf bytes.Buffer + arrowWriter := ipc.NewWriter(&buf, ipc.WithSchema(arrowSchema)) + defer arrowWriter.Close() + + err := arrowWriter.Write(record) + require.NoError(t, err) + err = arrowWriter.Close() + require.NoError(t, err) + + return buf.Bytes() +} + +// Helper function to collect all records from Seq2 iterator +func collectRecordsFromSeq2(t *testing.T, iter iter.Seq2[arrow.Record, error]) []arrow.Record { + t.Helper() + + var records []arrow.Record + + for record, err := range iter { + if err != nil { + t.Fatalf("Unexpected error: %v", err) + break + } + if record != nil { + records = append(records, record) + } + } + + return records } diff --git a/spark/client/conf.go b/spark/client/conf.go index 11b301e..5dc0d2d 100644 --- a/spark/client/conf.go +++ b/spark/client/conf.go @@ -18,8 +18,8 @@ package client import ( "context" - proto "github.com/apache/spark-connect-go/internal/generated" - "github.com/apache/spark-connect-go/spark/client/base" + proto "github.com/datalake-go/spark-connect-go/internal/generated" + "github.com/datalake-go/spark-connect-go/spark/client/base" ) // Public interface RuntimeConfig diff --git a/spark/client/retry.go b/spark/client/retry.go index be7d90c..c802fa4 100644 --- a/spark/client/retry.go +++ b/spark/client/retry.go @@ -23,13 +23,13 @@ import ( "strings" "time" - "github.com/apache/spark-connect-go/spark/client/base" + "github.com/datalake-go/spark-connect-go/spark/client/base" - "github.com/apache/spark-connect-go/spark/client/options" + "github.com/datalake-go/spark-connect-go/spark/client/options" "google.golang.org/grpc/metadata" - proto "github.com/apache/spark-connect-go/internal/generated" - "github.com/apache/spark-connect-go/spark/sparkerrors" + proto 
"github.com/datalake-go/spark-connect-go/internal/generated" + "github.com/datalake-go/spark-connect-go/spark/sparkerrors" "google.golang.org/grpc" "google.golang.org/grpc/codes" ) diff --git a/spark/client/retry_test.go b/spark/client/retry_test.go index a9526e5..cf3f717 100644 --- a/spark/client/retry_test.go +++ b/spark/client/retry_test.go @@ -22,11 +22,11 @@ import ( "testing" "time" - "github.com/apache/spark-connect-go/spark/client/options" + "github.com/datalake-go/spark-connect-go/spark/client/options" - "github.com/apache/spark-connect-go/spark/client/testutils" - "github.com/apache/spark-connect-go/spark/mocks" - "github.com/apache/spark-connect-go/spark/sparkerrors" + "github.com/datalake-go/spark-connect-go/spark/client/testutils" + "github.com/datalake-go/spark-connect-go/spark/mocks" + "github.com/datalake-go/spark-connect-go/spark/sparkerrors" "github.com/stretchr/testify/assert" "google.golang.org/grpc/codes" diff --git a/spark/client/testutils/utils.go b/spark/client/testutils/utils.go index c0b3bb5..a3a62a7 100644 --- a/spark/client/testutils/utils.go +++ b/spark/client/testutils/utils.go @@ -19,7 +19,7 @@ import ( "context" "testing" - proto "github.com/apache/spark-connect-go/internal/generated" + proto "github.com/datalake-go/spark-connect-go/internal/generated" "google.golang.org/grpc" ) diff --git a/spark/mocks/mock_executor.go b/spark/mocks/mock_executor.go index 600e9e0..2a16765 100644 --- a/spark/mocks/mock_executor.go +++ b/spark/mocks/mock_executor.go @@ -19,13 +19,13 @@ import ( "context" "errors" - "github.com/apache/spark-connect-go/spark/sql/utils" + "github.com/datalake-go/spark-connect-go/spark/sql/utils" - "github.com/apache/spark-connect-go/spark/client/base" + "github.com/datalake-go/spark-connect-go/spark/client/base" "github.com/apache/arrow-go/v18/arrow" - "github.com/apache/spark-connect-go/internal/generated" - "github.com/apache/spark-connect-go/spark/sql/types" + "github.com/datalake-go/spark-connect-go/internal/generated" + "github.com/datalake-go/spark-connect-go/spark/sql/types" ) type TestExecutor struct { diff --git a/spark/mocks/mocks.go b/spark/mocks/mocks.go index 3a313f2..7569e0f 100644 --- a/spark/mocks/mocks.go +++ b/spark/mocks/mocks.go @@ -25,7 +25,7 @@ import ( "github.com/google/uuid" - proto "github.com/apache/spark-connect-go/internal/generated" + proto "github.com/datalake-go/spark-connect-go/internal/generated" "google.golang.org/grpc/metadata" ) diff --git a/spark/sparkerrors/cluster_not_ready.go b/spark/sparkerrors/cluster_not_ready.go new file mode 100644 index 0000000..44193fe --- /dev/null +++ b/spark/sparkerrors/cluster_not_ready.go @@ -0,0 +1,126 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package sparkerrors + +import ( + "errors" + "fmt" + "strings" +) + +// ErrClusterNotReady is returned when a Spark cluster is in a +// warm-up state and the RPC is safe to retry after backoff. The +// canonical case is Databricks clusters returning +// `[FailedPrecondition]` with `state Pending` during cold start, +// but the pattern also fires for self-managed clusters whose +// drivers are still resolving JARs. +// +// Wrap the typed ClusterNotReady with %w so errors.Is lookups keep +// working: +// +// if errors.Is(err, sparkerrors.ErrClusterNotReady) { ... } +// if sparkerrors.IsClusterNotReady(err) { ... } +var ErrClusterNotReady = errors.New("sparkerrors: cluster not ready") + +// ClusterNotReady is the typed form of ErrClusterNotReady. Carries +// enough context (state string, request ID, original cause) for an +// operator to correlate with server-side logs. +type ClusterNotReady struct { + State string // e.g. "Pending", "PENDING" + RequestID string + Message string + Cause error +} + +func (e *ClusterNotReady) Error() string { + return fmt.Sprintf("cluster not ready (state=%s): %s", e.State, e.Message) +} + +func (e *ClusterNotReady) Unwrap() error { return e.Cause } + +// Is reports whether target is ErrClusterNotReady. Lets callers +// branch via errors.Is regardless of whether they hold the sentinel +// or the typed form. +func (e *ClusterNotReady) Is(target error) bool { return target == ErrClusterNotReady } + +// IsRetryable marks the error as safe to retry after backoff. +// Callers that drive their own retry loops can branch on this +// without pulling in the package's retry machinery. +func (e *ClusterNotReady) IsRetryable() bool { return true } + +// IsClusterNotReady is the convenience wrapper around errors.As for +// the ClusterNotReady type. Returns true whether err is the typed +// form or wraps the sentinel; either way a caller's retry loop +// makes the right call. +func IsClusterNotReady(err error) bool { + if err == nil { + return false + } + var typed *ClusterNotReady + if errors.As(err, &typed) { + return true + } + return errors.Is(err, ErrClusterNotReady) +} + +// NewClusterNotReady inspects err for the canonical +// `[FailedPrecondition]` + `state Pending` pattern and returns a +// typed ClusterNotReady if it matches, else nil. The string-matching +// looks fragile but is the detection pattern that's held across +// multiple Databricks runtime versions and is the shape self-managed +// Spark clusters emit too. +// +// Callers typically invoke this in an RPC error path: +// +// resp, err := cli.ExecutePlan(ctx, req) +// if cn := sparkerrors.NewClusterNotReady(err); cn != nil { +// // retry with backoff +// } +func NewClusterNotReady(err error) *ClusterNotReady { + if err == nil { + return nil + } + errStr := err.Error() + + if !strings.Contains(errStr, "[FailedPrecondition]") || + (!strings.Contains(errStr, "state Pending") && + !strings.Contains(errStr, "state PENDING")) { + return nil + } + + state := "PENDING" + if idx := strings.Index(errStr, "[state="); idx != -1 { + endIdx := strings.Index(errStr[idx:], "]") + if endIdx != -1 { + state = errStr[idx+7 : idx+endIdx] + } + } + + var requestID string + if idx := strings.Index(errStr, "(requestId="); idx != -1 { + endIdx := strings.Index(errStr[idx:], ")") + if endIdx != -1 { + requestID = errStr[idx+11 : idx+endIdx] + } + } + + return &ClusterNotReady{ + State: state, + RequestID: requestID, + Message: "The cluster is starting up. 
Please retry your request in a few moments.", + Cause: err, + } +} diff --git a/spark/sparkerrors/cluster_not_ready_test.go b/spark/sparkerrors/cluster_not_ready_test.go new file mode 100644 index 0000000..5e6d18f --- /dev/null +++ b/spark/sparkerrors/cluster_not_ready_test.go @@ -0,0 +1,110 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sparkerrors + +import ( + "errors" + "fmt" + "testing" +) + +func TestNewClusterNotReady_DetectsCanonicalDatabricksPattern(t *testing.T) { + // Shape ported verbatim from production Databricks responses. + raw := errors.New( + "rpc error: code = FailedPrecondition desc = " + + "[FailedPrecondition] cluster with id=0313-xyz is in state Pending " + + "[state=PENDING] (requestId=abc-123)", + ) + got := NewClusterNotReady(raw) + if got == nil { + t.Fatalf("NewClusterNotReady returned nil on canonical Databricks error: %v", raw) + } + if got.State != "PENDING" { + t.Errorf("State = %q, want PENDING", got.State) + } + if got.RequestID != "abc-123" { + t.Errorf("RequestID = %q, want abc-123", got.RequestID) + } + if got.Cause != raw { + t.Errorf("Cause should be the original error unchanged") + } + if !got.IsRetryable() { + t.Errorf("IsRetryable should be true") + } +} + +func TestNewClusterNotReady_RejectsUnrelatedErrors(t *testing.T) { + cases := []error{ + nil, + errors.New("connection refused"), + errors.New("rpc error: code = Internal desc = server error"), + errors.New("[FailedPrecondition] lock held"), // no state Pending marker + } + for _, e := range cases { + if got := NewClusterNotReady(e); got != nil { + t.Errorf("NewClusterNotReady(%v) = %v, want nil", e, got) + } + } +} + +func TestIsClusterNotReady_MatchesTypedAndSentinel(t *testing.T) { + raw := errors.New( + "[FailedPrecondition] state Pending starting up", + ) + typed := NewClusterNotReady(raw) + if !IsClusterNotReady(typed) { + t.Errorf("IsClusterNotReady on typed ClusterNotReady should be true") + } + // Wrap the typed error — IsClusterNotReady should still match + // because errors.As unwinds wrapping. + wrapped := fmt.Errorf("retrying: %w", typed) + if !IsClusterNotReady(wrapped) { + t.Errorf("IsClusterNotReady on wrapped typed error should be true") + } + // A bare sentinel wrap also matches for callers that don't + // produce the typed struct. 
+ sentinelWrap := fmt.Errorf("context: %w", ErrClusterNotReady) + if !IsClusterNotReady(sentinelWrap) { + t.Errorf("IsClusterNotReady on %%w-wrapped sentinel should be true") + } +} + +func TestIsClusterNotReady_FalseOnNilAndUnrelated(t *testing.T) { + if IsClusterNotReady(nil) { + t.Errorf("IsClusterNotReady(nil) should be false") + } + if IsClusterNotReady(errors.New("some other error")) { + t.Errorf("IsClusterNotReady on unrelated error should be false") + } +} + +func TestClusterNotReady_IsMatchesSentinel(t *testing.T) { + // errors.Is should match the sentinel through the typed form's + // own Is method (tested independently of the convenience + // IsClusterNotReady helper). + typed := &ClusterNotReady{State: "PENDING"} + if !errors.Is(typed, ErrClusterNotReady) { + t.Errorf("errors.Is(typed, ErrClusterNotReady) should be true") + } +} + +func TestClusterNotReady_UnwrapExposesCause(t *testing.T) { + cause := errors.New("underlying grpc error") + typed := &ClusterNotReady{Cause: cause} + if got := errors.Unwrap(typed); got != cause { + t.Errorf("Unwrap returned %v, want %v", got, cause) + } +} diff --git a/spark/sql/column/column.go b/spark/sql/column/column.go index 79e2fad..7ed2080 100644 --- a/spark/sql/column/column.go +++ b/spark/sql/column/column.go @@ -18,9 +18,9 @@ package column import ( "context" - "github.com/apache/spark-connect-go/spark/sql/types" + "github.com/datalake-go/spark-connect-go/spark/sql/types" - proto "github.com/apache/spark-connect-go/internal/generated" + proto "github.com/datalake-go/spark-connect-go/internal/generated" ) // Convertible is the interface for all things that can be converted into a protobuf expression. diff --git a/spark/sql/column/column_test.go b/spark/sql/column/column_test.go index fa97e80..7f1134e 100644 --- a/spark/sql/column/column_test.go +++ b/spark/sql/column/column_test.go @@ -19,7 +19,7 @@ import ( "context" "testing" - proto "github.com/apache/spark-connect-go/internal/generated" + proto "github.com/datalake-go/spark-connect-go/internal/generated" "github.com/stretchr/testify/assert" ) diff --git a/spark/sql/column/expressions.go b/spark/sql/column/expressions.go index 399a611..fc8780e 100644 --- a/spark/sql/column/expressions.go +++ b/spark/sql/column/expressions.go @@ -20,11 +20,11 @@ import ( "fmt" "strings" - "github.com/apache/spark-connect-go/spark/sql/types" + "github.com/datalake-go/spark-connect-go/spark/sql/types" - "github.com/apache/spark-connect-go/spark/sparkerrors" + "github.com/datalake-go/spark-connect-go/spark/sparkerrors" - proto "github.com/apache/spark-connect-go/internal/generated" + proto "github.com/datalake-go/spark-connect-go/internal/generated" ) func newProtoExpression() *proto.Expression { diff --git a/spark/sql/column/expressions_test.go b/spark/sql/column/expressions_test.go index 836d5a8..c1c2fb1 100644 --- a/spark/sql/column/expressions_test.go +++ b/spark/sql/column/expressions_test.go @@ -20,7 +20,7 @@ import ( "reflect" "testing" - proto "github.com/apache/spark-connect-go/internal/generated" + proto "github.com/datalake-go/spark-connect-go/internal/generated" "github.com/stretchr/testify/assert" ) diff --git a/spark/sql/dataframe.go b/spark/sql/dataframe.go index 4c7a563..82f19db 100644 --- a/spark/sql/dataframe.go +++ b/spark/sql/dataframe.go @@ -22,14 +22,14 @@ import ( "math/rand/v2" "github.com/apache/arrow-go/v18/arrow" - "github.com/apache/spark-connect-go/spark/sql/utils" + "github.com/datalake-go/spark-connect-go/spark/sql/utils" - "github.com/apache/spark-connect-go/spark/sql/column" 
+ "github.com/datalake-go/spark-connect-go/spark/sql/column" - "github.com/apache/spark-connect-go/spark/sql/types" + "github.com/datalake-go/spark-connect-go/spark/sql/types" - proto "github.com/apache/spark-connect-go/internal/generated" - "github.com/apache/spark-connect-go/spark/sparkerrors" + proto "github.com/datalake-go/spark-connect-go/internal/generated" + "github.com/datalake-go/spark-connect-go/spark/sparkerrors" ) // ResultCollector receives a stream of result rows @@ -200,6 +200,11 @@ type DataFrame interface { // Sort returns a new DataFrame sorted by the specified columns. Sort(ctx context.Context, columns ...column.Convertible) (DataFrame, error) Stat() DataFrameStatFunctions + // StreamRows returns a lazy iterator over rows from Spark. + // No rows are fetched from Spark over gRPC until the previous one has been consumed. + // It provides no internal buffering: each Row is produced only when the caller + // requests it, ensuring client back-pressure is respected. + StreamRows(ctx context.Context) (iter.Seq2[types.Row, error], error) // Subtract subtracts the other DataFrame from the current DataFrame. And only returns // distinct rows. Subtract(ctx context.Context, other DataFrame) DataFrame @@ -936,6 +941,17 @@ func (df *dataFrameImpl) ToArrow(ctx context.Context) (*arrow.Table, error) { return &table, nil } +func (df *dataFrameImpl) StreamRows(ctx context.Context) (iter.Seq2[types.Row, error], error) { + responseClient, err := df.session.client.ExecutePlan(ctx, df.createPlan()) + if err != nil { + return nil, sparkerrors.WithType(fmt.Errorf("failed to execute plan: %w", err), sparkerrors.ExecutionError) + } + + seq2 := responseClient.ToRecordSequence(ctx) + + return types.NewRowSequence(ctx, seq2), nil +} + func (df *dataFrameImpl) UnionAll(ctx context.Context, other DataFrame) DataFrame { otherDf := other.(*dataFrameImpl) isAll := true diff --git a/spark/sql/dataframe_test.go b/spark/sql/dataframe_test.go index 0bee27b..3859629 100644 --- a/spark/sql/dataframe_test.go +++ b/spark/sql/dataframe_test.go @@ -20,8 +20,8 @@ import ( "context" "testing" - proto "github.com/apache/spark-connect-go/internal/generated" - "github.com/apache/spark-connect-go/spark/sql/functions" + proto "github.com/datalake-go/spark-connect-go/internal/generated" + "github.com/datalake-go/spark-connect-go/spark/sql/functions" "github.com/stretchr/testify/assert" ) diff --git a/spark/sql/dataframe_typed.go b/spark/sql/dataframe_typed.go new file mode 100644 index 0000000..cd41f0a --- /dev/null +++ b/spark/sql/dataframe_typed.go @@ -0,0 +1,365 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package sql + +import ( + "context" + "fmt" + "iter" + "reflect" + "strings" + "sync" + "time" + "unicode" + + "github.com/apache/arrow-go/v18/arrow" +) + +// DataFrameOf[T] is an edge-typing wrapper around a regular DataFrame. +// Its role is narrow: cache the reflected row plan so a caller who +// already knows the result shape can materialise into []T / iter / +// *T without re-validating T on every call. +// +// The wrapper deliberately has no transformation methods (no Where, +// Limit, OrderBy, Select, Join, GroupBy). Transformations change row +// shape and make the type parameter lie. Callers compose +// transformations on the untyped DataFrame, then re-type at the edge +// via As[T] when ready to collect. +// +// Column binding uses struct tags in the same shape sqlx / parquet-go +// already use: +// +// type User struct { +// ID string `spark:"id"` +// Email string `spark:"email"` +// Created time.Time `spark:"created_at"` +// } +// +// Fields without a `spark:"..."` tag are mapped by snake_case'd field +// name, so a plain Go struct works without any tags at all. Fields +// tagged `spark:"-"` are skipped. Columns in the DataFrame that +// don't match any field are ignored — typical of projections narrower +// than the struct. +// +// Schema drift (a struct field that the result's projection doesn't +// contain) surfaces at the first Collect / Stream / First call as a +// single error rather than per-row panics. +type DataFrameOf[T any] struct { + df DataFrame + plan *rowPlan +} + +// As wraps df in the typed surface. T must be a struct; the row plan +// is computed and cached once so subsequent Collect / Stream / First +// calls don't re-reflect. Schema compatibility with T is validated +// lazily — the first materialisation surfaces drift, not As itself. +// +// This is the only supported constructor for DataFrameOf[T]. Callers +// hold a DataFrame (from session.Sql, session.Table, a chain of +// untyped transformations) and hand it here at the point the result +// shape is known. +func As[T any](df DataFrame) (*DataFrameOf[T], error) { + var zero T + rt := reflect.TypeOf(zero) + if rt == nil || rt.Kind() != reflect.Struct { + return nil, fmt.Errorf("DataFrameOf[T]: T must be a struct, got %v", rt) + } + plan, err := buildRowPlan(rt) + if err != nil { + return nil, err + } + return &DataFrameOf[T]{df: df, plan: plan}, nil +} + +// DataFrame returns the underlying untyped DataFrame. Escape hatch +// for transformations the typed surface deliberately doesn't carry +// — GroupBy, joins, window functions, Where, Limit, OrderBy. Chain +// on the untyped handle, then call As[T] again when the output +// shape is known. +func (d *DataFrameOf[T]) DataFrame() DataFrame { return d.df } + +// Collect materialises every row into a []T. Holds the whole table +// on the heap for the duration of the call — callers with large +// result sets should project narrower on the SQL side or use Stream +// for constant-memory iteration. 
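+//
+// A bind failure (schema drift) surfaces from this call as a single
+// error naming every missing column. Illustrative sketch, assuming a
+// struct field bound to `email` that the projection no longer carries:
+//
+//	users, err := typed.Collect(ctx)
+//	// err: DataFrameOf[T]: struct field(s) not in result schema: email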
+func (d *DataFrameOf[T]) Collect(ctx context.Context) ([]T, error) { + rows, err := d.df.Collect(ctx) + if err != nil { + return nil, err + } + if len(rows) == 0 { + return nil, nil + } + bindings, err := d.plan.bind(rows[0].FieldNames()) + if err != nil { + return nil, err + } + out := make([]T, len(rows)) + for i, r := range rows { + if err := decodeRow(d.plan, r.Values(), bindings, &out[i]); err != nil { + return nil, fmt.Errorf("DataFrameOf[T].Collect: row %d: %w", i, err) + } + } + return out, nil +} + +// Stream yields typed rows one at a time with constant memory. Uses +// the untyped DataFrame.All streaming primitive underneath. Consumers +// range with Go 1.23's iter.Seq2: +// +// for row, err := range ds.Stream(ctx) { +// if err != nil { break } +// // use row +// } +func (d *DataFrameOf[T]) Stream(ctx context.Context) iter.Seq2[T, error] { + return func(yield func(T, error) bool) { + var zero T + var bindings []columnBinding + for row, rerr := range d.df.All(ctx) { + if rerr != nil { + yield(zero, rerr) + return + } + if bindings == nil { + b, berr := d.plan.bind(row.FieldNames()) + if berr != nil { + yield(zero, berr) + return + } + bindings = b + } + var out T + if derr := decodeRow(d.plan, row.Values(), bindings, &out); derr != nil { + yield(zero, derr) + return + } + if !yield(out, nil) { + return + } + } + } +} + +// First returns a pointer to the first row as T, or ErrNotFound when +// the DataFrame produces zero rows. Runs Collect underneath at v0; +// a LIMIT 1 pushdown lands alongside the Spark Connect plan-cache +// work in a future cycle. +func (d *DataFrameOf[T]) First(ctx context.Context) (*T, error) { + rows, err := d.Collect(ctx) + if err != nil { + return nil, err + } + if len(rows) == 0 { + return nil, ErrNotFound + } + return &rows[0], nil +} + +// rowPlan caches the reflected structure of T so Collect doesn't +// reflect on every row. Built once per DataFrameOf[T]. +type rowPlan struct { + goType reflect.Type + fields []plannedField +} + +type plannedField struct { + name string // column name from tag or snake_case'd field name + index []int // reflect.FieldByIndex path + gotyp reflect.Type +} + +// columnBinding maps a result-set column position to the field slot +// in the plan that should receive it. A column that the struct +// doesn't describe has planIndex = -1 and is skipped. +type columnBinding struct { + planIndex int +} + +var rowPlanCache sync.Map // reflect.Type -> *rowPlan + +func buildRowPlan(rt reflect.Type) (*rowPlan, error) { + if cached, ok := rowPlanCache.Load(rt); ok { + return cached.(*rowPlan), nil + } + plan := &rowPlan{goType: rt} + if err := walkPlan(rt, nil, plan); err != nil { + return nil, err + } + rowPlanCache.Store(rt, plan) + return plan, nil +} + +func walkPlan(rt reflect.Type, parent []int, plan *rowPlan) error { + for i := 0; i < rt.NumField(); i++ { + sf := rt.Field(i) + if !sf.IsExported() { + continue + } + idx := append(append([]int{}, parent...), i) + tag := sf.Tag.Get("spark") + if tag == "-" { + continue + } + if sf.Anonymous && sf.Type.Kind() == reflect.Struct && tag == "" { + if err := walkPlan(sf.Type, idx, plan); err != nil { + return err + } + continue + } + name := tag + if name == "" { + name = snakeCase(sf.Name) + } + plan.fields = append(plan.fields, plannedField{ + name: name, + index: idx, + gotyp: sf.Type, + }) + } + return nil +} + +// bind aligns the plan's fields with a concrete set of result +// columns. 
Result columns that don't map to any planned field are +// tagged planIndex = -1 and dropped at decode time. A planned field +// that doesn't appear in the columns is a schema-drift error — the +// SQL changed in a way the struct doesn't describe and we want the +// caller to know at plan time, not per row. +func (p *rowPlan) bind(columns []string) ([]columnBinding, error) { + byName := make(map[string]int, len(p.fields)) + for i, f := range p.fields { + byName[f.name] = i + } + bindings := make([]columnBinding, len(columns)) + seen := make(map[int]bool, len(p.fields)) + for ci, name := range columns { + if idx, ok := byName[name]; ok { + bindings[ci] = columnBinding{planIndex: idx} + seen[idx] = true + } else { + bindings[ci] = columnBinding{planIndex: -1} + } + } + var missing []string + for i, f := range p.fields { + if !seen[i] { + missing = append(missing, f.name) + } + } + if len(missing) > 0 { + return nil, fmt.Errorf("DataFrameOf[T]: struct field(s) not in result schema: %s", + strings.Join(missing, ", ")) + } + return bindings, nil +} + +// decodeRow writes values into *T using the bindings. dest is a +// pointer to T; we take *T (not T) so the caller can write each +// element of a pre-allocated []T slice without paying for a +// reflect.Value per row. +func decodeRow[T any](plan *rowPlan, values []any, bindings []columnBinding, dest *T) error { + dv := reflect.ValueOf(dest).Elem() + for ci := 0; ci < len(values) && ci < len(bindings); ci++ { + b := bindings[ci] + if b.planIndex < 0 { + continue + } + pf := &plan.fields[b.planIndex] + target := fieldByIndex(dv, pf.index) + if err := assignTypedValue(target, values[ci]); err != nil { + return fmt.Errorf("column %d (%s): %w", ci, pf.name, err) + } + } + return nil +} + +func fieldByIndex(v reflect.Value, index []int) reflect.Value { + cur := v + for _, i := range index { + for cur.Kind() == reflect.Ptr { + if cur.IsNil() { + cur.Set(reflect.New(cur.Type().Elem())) + } + cur = cur.Elem() + } + cur = cur.Field(i) + } + return cur +} + +// assignTypedValue writes src into the reflect field, doing the +// small set of conversions Spark's row values need — nil → zero for +// optional fields, arrow.Timestamp → time.Time for TIMESTAMP +// columns, assignable/convertible for primitives. Anything rarer +// surfaces as an explicit error so callers can tighten their struct +// to match. +func assignTypedValue(dst reflect.Value, src any) error { + if src == nil { + dst.Set(reflect.Zero(dst.Type())) + return nil + } + dt := dst.Type() + isPtr := dt.Kind() == reflect.Ptr + innerType := dt + if isPtr { + innerType = dt.Elem() + } + if ts, ok := src.(arrow.Timestamp); ok && innerType == reflect.TypeOf(time.Time{}) { + setTypedValue(dst, reflect.ValueOf(ts.ToTime(arrow.Microsecond)), isPtr, innerType) + return nil + } + sv := reflect.ValueOf(src) + if sv.Type().AssignableTo(innerType) { + setTypedValue(dst, sv, isPtr, innerType) + return nil + } + if sv.Type().ConvertibleTo(innerType) { + setTypedValue(dst, sv.Convert(innerType), isPtr, innerType) + return nil + } + return fmt.Errorf("cannot assign %T to %v", src, dt) +} + +func setTypedValue(dst, src reflect.Value, isPtr bool, inner reflect.Type) { + if isPtr { + p := reflect.New(inner) + p.Elem().Set(src) + dst.Set(p) + return + } + dst.Set(src) +} + +// snakeCase converts a Go field name to its snake_case form. Matches +// the convention used by sqlx / gorm / jackc so a plain Go struct +// with no tags lines up with columns that follow standard SQL naming. 
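+//
+// A few shapes, as pinned by the tests: "ID" → "id",
+// "CreatedAt" → "created_at", "HTTPServer" → "http_server",
+// "JSONPayload" → "json_payload".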
+func snakeCase(s string) string { + if s == "" { + return s + } + var b strings.Builder + runes := []rune(s) + for i, r := range runes { + if i > 0 && unicode.IsUpper(r) { + prev := runes[i-1] + if unicode.IsLower(prev) || (i+1 < len(runes) && unicode.IsLower(runes[i+1])) { + b.WriteByte('_') + } + } + b.WriteRune(unicode.ToLower(r)) + } + return b.String() +} diff --git a/spark/sql/dataframe_typed_test.go b/spark/sql/dataframe_typed_test.go new file mode 100644 index 0000000..0261fa0 --- /dev/null +++ b/spark/sql/dataframe_typed_test.go @@ -0,0 +1,153 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sql + +import ( + "reflect" + "testing" + "time" + + "github.com/apache/arrow-go/v18/arrow" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type typedUser struct { + ID string `spark:"id"` + Email string `spark:"email"` + Country string // no tag — falls back to snake_case of field name + CreatedAt time.Time `spark:"created_at"` + Secret string `spark:"-"` // skipped + ignored string //nolint:unused // unexported → skipped +} + +func TestBuildRowPlan_TagsAndSnakeCase(t *testing.T) { + plan, err := buildRowPlan(reflect.TypeOf(typedUser{})) + require.NoError(t, err) + + names := make([]string, 0, len(plan.fields)) + for _, f := range plan.fields { + names = append(names, f.name) + } + assert.Equal(t, []string{"id", "email", "country", "created_at"}, names, + "tag takes precedence; untagged fields snake_case; `-` and unexported skipped") +} + +func TestBind_MatchesColumns(t *testing.T) { + plan, _ := buildRowPlan(reflect.TypeOf(typedUser{})) + + // Columns may arrive in a different order than fields and can + // include extras the struct doesn't care about; the binder + // should pick the right indices and ignore the stranger. + cols := []string{"created_at", "extra_col", "email", "id", "country"} + bindings, err := plan.bind(cols) + require.NoError(t, err) + + want := []int{3, -1, 1, 0, 2} + for i, b := range bindings { + assert.Equalf(t, want[i], b.planIndex, + "column %q bound to wrong plan index", cols[i]) + } +} + +func TestBind_SchemaDriftErrorsEarly(t *testing.T) { + plan, _ := buildRowPlan(reflect.TypeOf(typedUser{})) + // email column missing — the struct wants it, the result doesn't + // have it. Bind must surface this, not defer to per-row decode. 
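+	// The error is expected to name every missing column,
+	// comma-separated, so drift in wide structs reads at a glance.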
+ _, err := plan.bind([]string{"id", "country", "created_at"}) + require.Error(t, err) + assert.Contains(t, err.Error(), "email", + "schema-drift error should name the missing column") +} + +func TestDecodeRow_ArrowTimestampToTime(t *testing.T) { + plan, _ := buildRowPlan(reflect.TypeOf(typedUser{})) + cols := []string{"id", "email", "country", "created_at"} + bindings, err := plan.bind(cols) + require.NoError(t, err) + + when := time.Date(2026, 4, 19, 12, 0, 0, 0, time.UTC) + // arrow.Timestamp is an int64 of microseconds since epoch. Build + // one directly so the decode path's Microsecond branch fires. + ts := arrow.Timestamp(when.UnixMicro()) + + values := []any{"abc", "alice@example.com", "UK", ts} + + var out typedUser + require.NoError(t, decodeRow(plan, values, bindings, &out)) + + assert.Equal(t, "abc", out.ID) + assert.Equal(t, "alice@example.com", out.Email) + assert.Equal(t, "UK", out.Country) + assert.True(t, out.CreatedAt.Equal(when), "TIMESTAMP micros should round-trip to time.Time") +} + +func TestDecodeRow_AssignableAndConvertible(t *testing.T) { + type row struct { + N int `spark:"n"` // assignable from int + S string `spark:"s"` // assignable from string + F int `spark:"f"` // float64 → int is convertible + } + plan, err := buildRowPlan(reflect.TypeOf(row{})) + require.NoError(t, err) + bindings, err := plan.bind([]string{"n", "s", "f"}) + require.NoError(t, err) + + var out row + require.NoError(t, decodeRow(plan, []any{int(42), "hello", float64(7)}, bindings, &out)) + assert.Equal(t, 42, out.N) + assert.Equal(t, "hello", out.S) + assert.Equal(t, 7, out.F) +} + +func TestDecodeRow_NilIsZero(t *testing.T) { + type row struct { + N int `spark:"n"` + S string `spark:"s"` + } + plan, _ := buildRowPlan(reflect.TypeOf(row{})) + bindings, _ := plan.bind([]string{"n", "s"}) + + var out row + require.NoError(t, decodeRow(plan, []any{nil, nil}, bindings, &out)) + assert.Zero(t, out.N) + assert.Zero(t, out.S) +} + +func TestAs_RejectsNonStruct(t *testing.T) { + // As[T] surfaces the misuse at construction time, not at Collect. + // A map / slice / primitive should fail clearly with a pointer + // back at the caller's T. 
+ type notAStruct = map[string]string + _, err := As[notAStruct](nil) + require.Error(t, err) + assert.Contains(t, err.Error(), "T must be a struct") +} + +func TestSnakeCase_CommonShapes(t *testing.T) { + cases := map[string]string{ + "ID": "id", + "Email": "email", + "CreatedAt": "created_at", + "HTTPServer": "http_server", + "JSONPayload": "json_payload", + "A": "a", + "": "", + } + for in, want := range cases { + assert.Equalf(t, want, snakeCase(in), "snakeCase(%q)", in) + } +} diff --git a/spark/sql/dataframenafunctions.go b/spark/sql/dataframenafunctions.go index 9845bb1..552788e 100644 --- a/spark/sql/dataframenafunctions.go +++ b/spark/sql/dataframenafunctions.go @@ -18,7 +18,7 @@ package sql import ( "context" - "github.com/apache/spark-connect-go/spark/sql/types" + "github.com/datalake-go/spark-connect-go/spark/sql/types" ) type DataFrameNaFunctions interface { diff --git a/spark/sql/dataframewriter.go b/spark/sql/dataframewriter.go index 8c096f8..ca05b64 100644 --- a/spark/sql/dataframewriter.go +++ b/spark/sql/dataframewriter.go @@ -21,8 +21,8 @@ import ( "fmt" "strings" - proto "github.com/apache/spark-connect-go/internal/generated" - "github.com/apache/spark-connect-go/spark/sparkerrors" + proto "github.com/datalake-go/spark-connect-go/internal/generated" + "github.com/datalake-go/spark-connect-go/spark/sparkerrors" ) // DataFrameWriter supports writing data frame to storage. diff --git a/spark/sql/dataframewriter_test.go b/spark/sql/dataframewriter_test.go index bc85f65..3a60fe8 100644 --- a/spark/sql/dataframewriter_test.go +++ b/spark/sql/dataframewriter_test.go @@ -19,10 +19,10 @@ import ( "context" "testing" - "github.com/apache/spark-connect-go/spark/client" + "github.com/datalake-go/spark-connect-go/spark/client" - proto "github.com/apache/spark-connect-go/internal/generated" - "github.com/apache/spark-connect-go/spark/mocks" + proto "github.com/datalake-go/spark-connect-go/internal/generated" + "github.com/datalake-go/spark-connect-go/spark/mocks" "github.com/stretchr/testify/assert" ) diff --git a/spark/sql/driver/conn.go b/spark/sql/driver/conn.go new file mode 100644 index 0000000..65b7f9d --- /dev/null +++ b/spark/sql/driver/conn.go @@ -0,0 +1,141 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package driver + +import ( + "context" + "database/sql/driver" + "errors" + + sparksql "github.com/datalake-go/spark-connect-go/spark/sql" +) + +// conn is the per-logical-connection state database/sql keeps in its +// pool. Each Conn holds one Spark Connect session; Close stops it. 
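+// Because database/sql pools conns, every pooled conn is a live
+// Spark Connect session; callers can bound session fan-out with the
+// standard db.SetMaxOpenConns knob.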
+// +// Implements: +// - driver.Conn (legacy Prepare / Close / Begin) +// - driver.ConnPrepareContext +// - driver.ConnBeginTx +// - driver.ExecerContext +// - driver.QueryerContext +// - driver.Pinger (Ping via session round-trip) +type conn struct { + session sparksql.SparkSession + cfg *dsnConfig + closed bool +} + +// Prepare wraps the SQL in a driver.Stmt. Legacy (non-context) +// variant of PrepareContext. No server-side preparation — Spark +// Connect doesn't expose a prepare primitive. Each Exec/Query +// re-sends the statement; v1+ adds a plan cache if latency matters. +// +// Implements database/sql/driver.Conn. +func (c *conn) Prepare(query string) (driver.Stmt, error) { + return c.PrepareContext(context.Background(), query) +} + +// PrepareContext is the context-aware Prepare. Implements +// database/sql/driver.ConnPrepareContext. +func (c *conn) PrepareContext(_ context.Context, query string) (driver.Stmt, error) { + if c.closed { + return nil, driver.ErrBadConn + } + return &stmt{conn: c, query: query}, nil +} + +// Close releases the underlying Spark Connect session. Idempotent. +// Implements database/sql/driver.Conn. +func (c *conn) Close() error { + if c.closed { + return nil + } + c.closed = true + if c.session == nil { + return nil + } + return c.session.Stop() +} + +// Begin returns a no-op transaction. Implements +// database/sql/driver.Conn (legacy; ConnBeginTx below is preferred). +func (c *conn) Begin() (driver.Tx, error) { + return c.BeginTx(context.Background(), driver.TxOptions{}) +} + +// BeginTx returns a no-op transaction. Lakehouse commit semantics +// (Iceberg snapshot commits, Delta transaction log appends) happen +// at the per-statement level inside Spark, not at a SQL-layer Tx +// boundary. Wrapping statements in a database/sql Tx doesn't buy +// atomicity for lakehouse writes, so this driver's Tx is a no-op +// that lets tools (goose, sqlc) that default to transactional +// execution not break. +// +// Implements database/sql/driver.ConnBeginTx. +func (c *conn) BeginTx(_ context.Context, _ driver.TxOptions) (driver.Tx, error) { + if c.closed { + return nil, driver.ErrBadConn + } + return &tx{}, nil +} + +// ExecContext runs a statement that returns no rows. Forwards to +// the Stmt path so the NumInput-based argument check lives in one +// place. +// +// Implements database/sql/driver.ExecerContext. +func (c *conn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + if c.closed { + return nil, driver.ErrBadConn + } + return (&stmt{conn: c, query: query}).ExecContext(ctx, args) +} + +// QueryContext runs a SELECT returning rows. Implements +// database/sql/driver.QueryerContext. +func (c *conn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + if c.closed { + return nil, driver.ErrBadConn + } + return (&stmt{conn: c, query: query}).QueryContext(ctx, args) +} + +// Ping checks the Spark Connect session is alive. Implements +// database/sql/driver.Pinger. +func (c *conn) Ping(ctx context.Context) error { + if c.closed { + return driver.ErrBadConn + } + _, err := c.session.Sql(ctx, "SELECT 1") + return err +} + +// tx is the no-op transaction described on BeginTx. +type tx struct{} + +// Commit succeeds silently. Implements database/sql/driver.Tx. +func (tx) Commit() error { return nil } + +// Rollback returns an error because lakehouse rollback isn't a +// meaningful concept at the database/sql layer. 
Callers that expect
+// Rollback to un-apply statements see the error and can choose to
+// handle it; tools that default to rollback-on-error (goose with
+// failing migrations) get a signal that rollback didn't happen.
+// Implements database/sql/driver.Tx.
+func (tx) Rollback() error {
+	return errors.New("spark driver: transactional rollback is not supported; lakehouse commit semantics are per-statement")
+}
diff --git a/spark/sql/driver/driver.go b/spark/sql/driver/driver.go
new file mode 100644
index 0000000..75eb5ca
--- /dev/null
+++ b/spark/sql/driver/driver.go
@@ -0,0 +1,126 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package driver implements `database/sql`'s driver interfaces over
+// Spark Connect, so every Go tool that speaks `database/sql` — goose,
+// sqlc, pgx consumers, ad-hoc test harnesses, script code — can target
+// a Spark-backed lakehouse without learning the native client API.
+//
+// The package registers itself on import as the driver name "spark":
+//
+//	import (
+//		"database/sql"
+//		_ "github.com/datalake-go/spark-connect-go/spark/sql/driver"
+//	)
+//
+//	db, err := sql.Open("spark", "sc://localhost:15002?format=iceberg")
+//
+// DSN grammar is the plain Spark Connect URL with optional query
+// parameters:
+//
+//	sc://host:port
+//	sc://host:port?token=<token>
+//	sc://host:port?format=iceberg
+//	sc://host:port?format=delta
+//	sc://host:port?token=<token>&format=delta
+//
+// The `format` parameter is consumed by downstream tools (such as
+// goose-spark) that need to emit format-specific DDL; the driver
+// itself doesn't interpret it.
+//
+// v0 scope: sufficient for goose to create its version table, insert
+// applied-migration rows, and select them back. Specifically:
+//
+// - Exec / Query against a Spark Connect endpoint via session.Sql.
+// - No server-side statement preparation. NumInput reports -1 and
+//   `$N` arguments are rendered client-side into quoted Spark SQL
+//   literals before the statement is sent (see render.go); goose
+//   migrations author their own literal SQL anyway.
+// - No-op transactions — Begin / Commit succeed silently; Rollback
+//   returns an error rather than pretending to un-apply statements.
+//   Lakehouse commit semantics live in Iceberg / Delta at the
+//   per-statement level; wrapping statements in a `database/sql` Tx
+//   doesn't buy atomicity here. Documented in the package comment so
+//   a caller who expects real transaction semantics isn't surprised.
+//
+// v1+ adds: parameter binding via the Spark Connect parameterised-
+// query proto, prepared-statement caching, catalog introspection for
+// row metadata (column names + types) in the Rows wrapper.
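+//
+// A round-trip sketch under v0 semantics (hypothetical table name;
+// `$N` args render client-side as Spark SQL literals):
+//
+//	rows, err := db.QueryContext(ctx,
+//		"SELECT version_id FROM goose_db_version WHERE is_applied = $1",
+//		true)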
+package driver + +import ( + "context" + "database/sql" + "database/sql/driver" + + sparksql "github.com/datalake-go/spark-connect-go/spark/sql" +) + +// init registers the driver under the name "spark". Consumers that +// only want the typed DataFrame surface don't have to import this +// package; consumers that use database/sql do. +func init() { + sql.Register("spark", &sparkDriver{}) +} + +// sparkDriver is the database/sql.Driver implementation. Satisfies +// both the legacy Driver interface (Open) and the modern +// DriverContext interface (OpenConnector). +type sparkDriver struct{} + +// Open opens a new connection using the given DSN. Implements +// database/sql/driver.Driver. +// +// database/sql calls Open for every new connection the pool wants. +// Keeping this lightweight — parse the DSN, return a Connector, +// defer actual session construction to Connect — matches how the +// standard library's pool expects drivers to behave. +func (d *sparkDriver) Open(dsn string) (driver.Conn, error) { + c, err := d.OpenConnector(dsn) + if err != nil { + return nil, err + } + return c.Connect(context.Background()) +} + +// OpenConnector returns a Connector ready to produce Conns. +// Implements database/sql/driver.DriverContext. +func (d *sparkDriver) OpenConnector(dsn string) (driver.Connector, error) { + cfg, err := parseDSN(dsn) + if err != nil { + return nil, err + } + return &connector{driver: d, cfg: cfg}, nil +} + +// connector wraps a parsed DSN and produces Conns on demand. +type connector struct { + driver *sparkDriver + cfg *dsnConfig +} + +// Connect opens a Spark Connect session and returns a driver.Conn +// wrapping it. Implements database/sql/driver.Connector. +func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { + session, err := sparksql.NewSessionBuilder().Remote(c.cfg.sparkDSN).Build(ctx) + if err != nil { + return nil, err + } + return &conn{session: session, cfg: c.cfg}, nil +} + +// Driver returns the owning driver. Implements +// database/sql/driver.Connector. +func (c *connector) Driver() driver.Driver { return c.driver } diff --git a/spark/sql/driver/driver_test.go b/spark/sql/driver/driver_test.go new file mode 100644 index 0000000..c6e8e4b --- /dev/null +++ b/spark/sql/driver/driver_test.go @@ -0,0 +1,219 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package driver + +import ( + "database/sql" + "database/sql/driver" + "io" + "strings" + "testing" + + "github.com/datalake-go/spark-connect-go/spark/sql/types" +) + +// --- DSN parsing --------------------------------------------------- + +func TestParseDSN_AcceptsPlainSc(t *testing.T) { + cfg, err := parseDSN("sc://localhost:15002") + if err != nil { + t.Fatalf("parseDSN: %v", err) + } + if cfg.sparkDSN != "sc://localhost:15002" { + t.Errorf("sparkDSN = %q", cfg.sparkDSN) + } +} + +func TestParseDSN_RejectsMissingSchemePrefix(t *testing.T) { + _, err := parseDSN("localhost:15002") + if err == nil || !strings.Contains(err.Error(), "sc://") { + t.Errorf("err = %v, want sc:// required", err) + } +} + +func TestParseDSN_RejectsEmpty(t *testing.T) { + _, err := parseDSN("") + if err == nil || !strings.Contains(err.Error(), "DSN is required") { + t.Errorf("err = %v, want DSN-required error", err) + } +} + +func TestParseDSN_PreservesQueryParamsInSparkDSN(t *testing.T) { + // The driver does not interpret query parameters — not `token`, + // not anything else. Parameters must round-trip unchanged into + // sparkDSN so the upstream SparkSessionBuilder.Remote sees them. + cfg, err := parseDSN("sc://host:15002?token=secret&user=alice") + if err != nil { + t.Fatalf("parseDSN: %v", err) + } + if !strings.Contains(cfg.sparkDSN, "token=secret") { + t.Errorf("sparkDSN should carry token unchanged; got %q", cfg.sparkDSN) + } + if !strings.Contains(cfg.sparkDSN, "user=alice") { + t.Errorf("sparkDSN should carry user unchanged; got %q", cfg.sparkDSN) + } +} + +// --- sql.Register side-effect --------------------------------------- + +func TestDriver_RegistersAsSparkOnImport(t *testing.T) { + // Importing this package (the test binary already does) triggers + // init() which calls sql.Register("spark", ...). Verify the + // driver list contains "spark". + found := false + for _, name := range sql.Drivers() { + if name == "spark" { + found = true + break + } + } + if !found { + t.Errorf("driver list = %v; expected to contain \"spark\"", sql.Drivers()) + } +} + +func TestDriver_OpenConnectorInvalidDSN(t *testing.T) { + _, err := (&sparkDriver{}).OpenConnector("bogus://nope") + if err == nil { + t.Error("OpenConnector should reject non-sc:// DSN") + } +} + +// Parameter binding coverage lives in render_test.go. Here we only +// pin that NumInput returns -1 so database/sql passes args through +// untouched instead of rejecting them up front. + +func TestStmt_NumInputIsNegativeOne(t *testing.T) { + s := &stmt{} + if got := s.NumInput(); got != -1 { + t.Errorf("NumInput = %d, want -1 (database/sql should not pre-count args)", got) + } +} + +// --- Rows wrapper -------------------------------------------------- + +// rowFake is a minimal types.Row satisfying every method the +// interface declares. Only FieldNames + Values + Len + At get +// exercised from rows.go; the rest are required by the interface +// shape. 
+type rowFake struct { + names []string + vals []any +} + +func (r rowFake) FieldNames() []string { return r.names } +func (r rowFake) Values() []any { return r.vals } +func (r rowFake) Len() int { return len(r.vals) } +func (r rowFake) At(i int) any { + if i < 0 || i >= len(r.vals) { + return nil + } + return r.vals[i] +} + +func (r rowFake) Value(name string) any { + for i, n := range r.names { + if n == name { + return r.vals[i] + } + } + return nil +} +func (r rowFake) ToJsonString() (string, error) { return "", nil } + +func TestRows_ColumnsReportsFieldNamesFromFirstRow(t *testing.T) { + fakes := []types.Row{ + rowFake{names: []string{"id", "name"}, vals: []any{int64(1), "alice"}}, + rowFake{names: []string{"id", "name"}, vals: []any{int64(2), "bob"}}, + } + r := newRows(fakes) + got := r.Columns() + if len(got) != 2 || got[0] != "id" || got[1] != "name" { + t.Errorf("Columns = %v, want [id name]", got) + } +} + +func TestRows_ColumnsEmptyOnEmptyResult(t *testing.T) { + r := newRows(nil) + if len(r.Columns()) != 0 { + t.Errorf("Columns = %v on empty result, want []", r.Columns()) + } +} + +func TestRows_NextWalksResultAndStopsOnEOF(t *testing.T) { + fakes := []types.Row{ + rowFake{names: []string{"id", "name"}, vals: []any{int64(1), "alice"}}, + rowFake{names: []string{"id", "name"}, vals: []any{int64(2), "bob"}}, + } + r := newRows(fakes) + dest := make([]driver.Value, 2) + + if err := r.Next(dest); err != nil { + t.Fatalf("Next row 1: %v", err) + } + if dest[0].(int64) != 1 || dest[1].(string) != "alice" { + t.Errorf("row 1 = %v, want [1 alice]", dest) + } + + if err := r.Next(dest); err != nil { + t.Fatalf("Next row 2: %v", err) + } + if dest[0].(int64) != 2 || dest[1].(string) != "bob" { + t.Errorf("row 2 = %v, want [2 bob]", dest) + } + + if err := r.Next(dest); err != io.EOF { + t.Errorf("Next past end = %v, want io.EOF", err) + } +} + +// --- no-op transaction -------------------------------------------- + +func TestTx_CommitIsNoop(t *testing.T) { + var transaction tx + if err := transaction.Commit(); err != nil { + t.Errorf("Commit err = %v, want nil", err) + } +} + +func TestTx_RollbackErrors(t *testing.T) { + var transaction tx + err := transaction.Rollback() + if err == nil || !strings.Contains(err.Error(), "rollback is not supported") { + t.Errorf("Rollback err = %v, want rollback-not-supported error", err) + } +} + +// --- result -------------------------------------------------------- + +func TestResult_RowsAffectedReturnsUnknown(t *testing.T) { + r := result{} + n, err := r.RowsAffected() + if err != nil { + t.Errorf("RowsAffected err = %v, want nil", err) + } + if n != -1 { + t.Errorf("RowsAffected = %d, want -1 (unknown)", n) + } +} + +func TestResult_LastInsertIdErrors(t *testing.T) { + r := result{} + _, err := r.LastInsertId() + if err == nil || !strings.Contains(err.Error(), "not supported") { + t.Errorf("LastInsertId err = %v, want not-supported error", err) + } +} diff --git a/spark/sql/driver/dsn.go b/spark/sql/driver/dsn.go new file mode 100644 index 0000000..8ff4e73 --- /dev/null +++ b/spark/sql/driver/dsn.go @@ -0,0 +1,52 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. 
You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package driver
+
+import (
+	"errors"
+	"fmt"
+	"strings"
+)
+
+// dsnConfig is the parsed DSN. The driver is a boring transport
+// layer — it stays out of table-format concerns (Iceberg vs Delta
+// vs parquet) and keeps the original URL intact for the Spark
+// Connect session builder to interpret.
+type dsnConfig struct {
+	// sparkDSN is what gets passed to NewSessionBuilder().Remote(...).
+	// Includes any `?token=` or other query fragment verbatim so the
+	// upstream builder sees the URL exactly as the caller wrote it.
+	sparkDSN string
+}
+
+// parseDSN accepts the `sc://host:port[?token=...]` Spark Connect
+// URL form. The driver does not interpret query parameters beyond
+// validating the scheme — any `token`, `user`, or future Spark
+// Connect flag rides through in `sparkDSN` untouched.
+//
+// Table format selection (Iceberg, Delta, parquet) is explicitly
+// NOT a driver concern. Migrations and application SQL pick format
+// via Spark's native `USING <format>` DDL clause; the transport
+// layer stays boring and reusable.
+func parseDSN(dsn string) (*dsnConfig, error) {
+	if dsn == "" {
+		return nil, errors.New("spark driver: DSN is required")
+	}
+	if !strings.HasPrefix(dsn, "sc://") {
+		return nil, fmt.Errorf("spark driver: DSN must start with sc://, got %q", dsn)
+	}
+	return &dsnConfig{sparkDSN: dsn}, nil
+}
diff --git a/spark/sql/driver/render.go b/spark/sql/driver/render.go
new file mode 100644
index 0000000..b06e65c
--- /dev/null
+++ b/spark/sql/driver/render.go
@@ -0,0 +1,135 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package driver
+
+import (
+	"database/sql/driver"
+	"fmt"
+	"strconv"
+	"strings"
+	"time"
+)
+
+// render interpolates driver.NamedValue arguments into a query
+// string at `$N` placeholders ($1, $2, ...). Values are quoted as
+// Spark SQL literals on the way through.
+//
+// Spark Connect's parameterised-query proto exists but doesn't
+// round-trip reliably across every Spark version the driver
+// supports. Client-side rendering keeps the driver compatible with
+// the full 3.4+ range and mirrors what the native Spark driver in
+// datalakego/lakeorm already does for the same reason.
+//
+// Accepts `$N` placeholders only.
`?` is reserved for a future cycle +// if a caller asks for it — most database/sql-adjacent codegen +// (sqlc, goose's own dialects, pgx patterns) emits `$N`, so the +// narrower set suffices and keeps the renderer simple. +// +// Supported types are the stdlib-sql defaults: nil, bool, int64, +// float64, string, []byte, time.Time. Anything rarer surfaces as an +// error so the caller tightens their query before seeing a silently +// mis-interpolated value on the wire. +func render(query string, args []driver.NamedValue) (string, error) { + if len(args) == 0 { + return query, nil + } + byOrdinal := make(map[int]driver.Value, len(args)) + for _, a := range args { + byOrdinal[a.Ordinal] = a.Value + } + + var b strings.Builder + b.Grow(len(query) + 32) + for i := 0; i < len(query); { + c := query[i] + + // Pass single-quoted string literals through untouched, so a + // `'$1'` in the caller's query stays literal rather than + // getting parameter-substituted. + if c == '\'' { + b.WriteByte(c) + i++ + for i < len(query) { + if query[i] == '\'' { + // Doubled '' is an escaped quote; keep scanning. + if i+1 < len(query) && query[i+1] == '\'' { + b.WriteByte('\'') + b.WriteByte('\'') + i += 2 + continue + } + b.WriteByte('\'') + i++ + break + } + b.WriteByte(query[i]) + i++ + } + continue + } + + if c == '$' && i+1 < len(query) && query[i+1] >= '0' && query[i+1] <= '9' { + j := i + 1 + for j < len(query) && query[j] >= '0' && query[j] <= '9' { + j++ + } + ord, _ := strconv.Atoi(query[i+1 : j]) + v, ok := byOrdinal[ord] + if !ok { + return "", fmt.Errorf("spark driver: no argument for placeholder $%d", ord) + } + lit, err := sqlLiteral(v) + if err != nil { + return "", fmt.Errorf("spark driver: placeholder $%d: %w", ord, err) + } + b.WriteString(lit) + i = j + continue + } + + b.WriteByte(c) + i++ + } + return b.String(), nil +} + +// sqlLiteral renders a single Go value as a Spark SQL literal token. +// Strings are single-quoted with embedded quotes doubled. Time values +// render as TIMESTAMP literals in UTC microsecond precision — matches +// Spark's default timestamp parsing. +func sqlLiteral(v driver.Value) (string, error) { + switch x := v.(type) { + case nil: + return "NULL", nil + case bool: + if x { + return "TRUE", nil + } + return "FALSE", nil + case int64: + return strconv.FormatInt(x, 10), nil + case float64: + return strconv.FormatFloat(x, 'g', -1, 64), nil + case string: + return "'" + strings.ReplaceAll(x, "'", "''") + "'", nil + case []byte: + return "'" + strings.ReplaceAll(string(x), "'", "''") + "'", nil + case time.Time: + return "TIMESTAMP '" + x.UTC().Format("2006-01-02 15:04:05.000000") + "'", nil + default: + return "", fmt.Errorf("unsupported arg type %T", v) + } +} diff --git a/spark/sql/driver/render_test.go b/spark/sql/driver/render_test.go new file mode 100644 index 0000000..358ffe6 --- /dev/null +++ b/spark/sql/driver/render_test.go @@ -0,0 +1,119 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package driver + +import ( + "database/sql/driver" + "strings" + "testing" + "time" +) + +func namedArgs(vs ...any) []driver.NamedValue { + out := make([]driver.NamedValue, len(vs)) + for i, v := range vs { + out[i] = driver.NamedValue{Ordinal: i + 1, Value: v} + } + return out +} + +func TestRender_NoArgsPassesThrough(t *testing.T) { + got, err := render("SELECT 1", nil) + if err != nil { + t.Fatalf("render: %v", err) + } + if got != "SELECT 1" { + t.Errorf("got %q, want unchanged", got) + } +} + +func TestRender_Int64AndBool(t *testing.T) { + // goose's Insert always passes (version int64, is_applied bool). + // Verify the exact shape it emits renders to valid Spark SQL. + q := `INSERT INTO goose_db_version (version_id, is_applied) VALUES ($1, $2)` + got, err := render(q, namedArgs(int64(20260419000001), true)) + if err != nil { + t.Fatalf("render: %v", err) + } + want := `INSERT INTO goose_db_version (version_id, is_applied) VALUES (20260419000001, TRUE)` + if got != want { + t.Errorf("got %q, want %q", got, want) + } +} + +func TestRender_StringEscapesEmbeddedQuotes(t *testing.T) { + got, err := render(`SELECT * FROM t WHERE name = $1`, namedArgs("O'Brien")) + if err != nil { + t.Fatalf("render: %v", err) + } + if !strings.Contains(got, "'O''Brien'") { + t.Errorf("got %q, want doubled-quote escape", got) + } +} + +func TestRender_LeavesLiteralInsideStringAlone(t *testing.T) { + // A `$1` that lives inside a quoted literal must stay literal. + // Otherwise something like `WHERE s = '$1'` would try to + // substitute into the user-intended string. + got, err := render(`SELECT '$1 then $2' WHERE id = $1`, namedArgs(int64(42))) + if err != nil { + t.Fatalf("render: %v", err) + } + if !strings.Contains(got, "'$1 then $2'") { + t.Errorf("got %q, string-literal $-tokens should stay untouched", got) + } + if !strings.Contains(got, "id = 42") { + t.Errorf("got %q, outer $1 should have been replaced with 42", got) + } +} + +func TestRender_NilBecomesSQLNull(t *testing.T) { + got, err := render(`SELECT $1`, namedArgs(nil)) + if err != nil { + t.Fatalf("render: %v", err) + } + if got != "SELECT NULL" { + t.Errorf("got %q, want SELECT NULL", got) + } +} + +func TestRender_TimeRendersAsTimestampLiteral(t *testing.T) { + ts := time.Date(2026, 4, 19, 12, 34, 56, 0, time.UTC) + got, err := render(`SELECT $1`, namedArgs(ts)) + if err != nil { + t.Fatalf("render: %v", err) + } + if !strings.Contains(got, "TIMESTAMP '2026-04-19 12:34:56") { + t.Errorf("got %q, want TIMESTAMP literal", got) + } +} + +func TestRender_UnsupportedTypeErrors(t *testing.T) { + type custom struct{ N int } + _, err := render(`SELECT $1`, namedArgs(custom{N: 5})) + if err == nil || !strings.Contains(err.Error(), "unsupported arg type") { + t.Errorf("err = %v, want unsupported-arg-type", err) + } +} + +func TestRender_MissingOrdinalErrors(t *testing.T) { + // $3 referenced but only two args supplied — surfaces at render + // time rather than on the wire, so the error points at the query. 
+ _, err := render(`INSERT INTO t VALUES ($1, $2, $3)`, namedArgs(int64(1), int64(2))) + if err == nil || !strings.Contains(err.Error(), "$3") { + t.Errorf("err = %v, want missing-ordinal error naming $3", err) + } +} diff --git a/spark/sql/driver/result.go b/spark/sql/driver/result.go new file mode 100644 index 0000000..b24384e --- /dev/null +++ b/spark/sql/driver/result.go @@ -0,0 +1,41 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package driver + +import "errors" + +// result is the driver.Result returned by every Exec / ExecContext. +// Spark Connect's session.Sql doesn't surface a row-count the way +// Postgres' CommandComplete message does, so we return -1 per the +// driver.Result docs: "RowsAffected returns the number of rows +// affected by an update, insert, or delete. Not every database or +// database driver may support this." +type result struct{} + +// LastInsertId is not supported — lakehouse tables don't have +// auto-increment semantics. Returns an error so callers that reach +// for it see the intent clearly rather than getting a misleading +// zero. +func (result) LastInsertId() (int64, error) { + return 0, errors.New("spark driver: LastInsertId is not supported; lakehouse tables have no auto-increment sequence") +} + +// RowsAffected returns -1 to signal "unknown." Matches the +// convention several other drivers (notably snowflake, clickhouse) +// follow when the underlying engine doesn't emit the metric. +func (result) RowsAffected() (int64, error) { + return -1, nil +} diff --git a/spark/sql/driver/rows.go b/spark/sql/driver/rows.go new file mode 100644 index 0000000..7a39b54 --- /dev/null +++ b/spark/sql/driver/rows.go @@ -0,0 +1,75 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package driver + +import ( + "database/sql/driver" + "io" + + "github.com/datalake-go/spark-connect-go/spark/sql/types" +) + +// rows wraps a materialised slice of sparksql types.Row values so +// they satisfy database/sql/driver.Rows. 
The SELECT path in stmt.go +// collects a DataFrame first and hands the result here — fine for +// the small reads goose makes against the version table; larger +// scans should bypass the database/sql driver and use the native +// DataFrame / iter.Seq2 path instead. +type rows struct { + rows []types.Row + cols []string + index int +} + +// newRows builds a Rows handle from a slice of sparksql rows. +// Columns are read from the first row's FieldNames; if the slice +// is empty, the column list is empty and Next immediately reports +// io.EOF. +func newRows(collected []types.Row) *rows { + var cols []string + if len(collected) > 0 { + cols = collected[0].FieldNames() + } + return &rows{rows: collected, cols: cols} +} + +// Columns returns the column names. Implements +// database/sql/driver.Rows. +func (r *rows) Columns() []string { return r.cols } + +// Close releases any driver-held resources. The slice is already +// fully materialised; this is a no-op. +// Implements database/sql/driver.Rows. +func (r *rows) Close() error { return nil } + +// Next advances to the next row and writes its values into dest. +// Returns io.EOF when the slice is exhausted. Implements +// database/sql/driver.Rows. +func (r *rows) Next(dest []driver.Value) error { + if r.index >= len(r.rows) { + return io.EOF + } + values := r.rows[r.index].Values() + for i := range dest { + if i >= len(values) { + dest[i] = nil + continue + } + dest[i] = values[i] + } + r.index++ + return nil +} diff --git a/spark/sql/driver/stmt.go b/spark/sql/driver/stmt.go new file mode 100644 index 0000000..e1c5dc2 --- /dev/null +++ b/spark/sql/driver/stmt.go @@ -0,0 +1,117 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package driver + +import ( + "context" + "database/sql/driver" +) + +// stmt wraps one query string against a conn. The driver doesn't +// cache a prepared plan on the server; every Exec/Query re-sends +// the statement. Works for goose's load (dozens of statements per +// migration run) and keeps the implementation simple; v1+ adds a +// server-side plan cache if latency-sensitive callers appear. +// +// Implements: +// - driver.Stmt (legacy Close / NumInput / Exec / Query) +// - driver.StmtExecContext +// - driver.StmtQueryContext +type stmt struct { + conn *conn + query string +} + +// Close is a no-op; the statement holds no server-side state. +// Implements database/sql/driver.Stmt. +func (*stmt) Close() error { return nil } + +// NumInput reports -1 — "driver doesn't know the number of +// placeholders." The renderer parses `$N` tokens inline at Exec / +// Query time, so the sql package's up-front arg-count check doesn't +// apply. Returning -1 tells database/sql to hand any arg count +// through unchanged. +// +// Implements database/sql/driver.Stmt. 
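+// For example, database/sql forwards
+//
+//	db.Exec(`INSERT INTO t VALUES ($1, $2)`, a, b)
+//
+// without pre-validating the arg count against the query; a
+// mismatch surfaces from render() instead, as the missing-ordinal
+// test above exercises.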
+func (*stmt) NumInput() int { return -1 } + +// Exec is the legacy (non-context) Exec entry. Implements +// database/sql/driver.Stmt. +func (s *stmt) Exec(args []driver.Value) (driver.Result, error) { + named := make([]driver.NamedValue, len(args)) + for i, v := range args { + named[i] = driver.NamedValue{Ordinal: i + 1, Value: v} + } + return s.ExecContext(context.Background(), named) +} + +// Query is the legacy (non-context) Query entry. Implements +// database/sql/driver.Stmt. +func (s *stmt) Query(args []driver.Value) (driver.Rows, error) { + named := make([]driver.NamedValue, len(args)) + for i, v := range args { + named[i] = driver.NamedValue{Ordinal: i + 1, Value: v} + } + return s.QueryContext(context.Background(), named) +} + +// ExecContext runs the statement and discards rows. Returns a +// Result with RowsAffected=-1 because Spark Connect doesn't surface +// that metric reliably at the session.Sql layer. database/sql +// allows -1 per the driver.Result docs. +// +// Arguments are rendered into the query at $N placeholders before +// the session call fires; Spark Connect's protocol-level parameter +// binding doesn't round-trip reliably across every supported Spark +// version, so client-side rendering is the compatible path. +// +// Implements database/sql/driver.StmtExecContext. +func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { + q, err := render(s.query, args) + if err != nil { + return nil, err + } + if _, err := s.conn.session.Sql(ctx, q); err != nil { + return nil, err + } + return result{}, nil +} + +// QueryContext runs the statement and returns rows. The underlying +// DataFrame is materialised via Collect and walked by the Rows +// wrapper. For small result sets (the version-table SELECTs goose +// fires) this is right-sized; large result sets should bypass the +// driver and use the native DataFrame / iter.Seq2 path directly. +// +// Arguments are rendered into the query at $N placeholders before +// the session call fires — see ExecContext for rationale. +// +// Implements database/sql/driver.StmtQueryContext. +func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { + q, err := render(s.query, args) + if err != nil { + return nil, err + } + df, err := s.conn.session.Sql(ctx, q) + if err != nil { + return nil, err + } + rows, err := df.Collect(ctx) + if err != nil { + return nil, err + } + return newRows(rows), nil +} diff --git a/spark/sql/functions/buiitins.go b/spark/sql/functions/buiitins.go index 4dca8bf..8d30662 100644 --- a/spark/sql/functions/buiitins.go +++ b/spark/sql/functions/buiitins.go @@ -16,8 +16,8 @@ package functions import ( - "github.com/apache/spark-connect-go/spark/sql/column" - "github.com/apache/spark-connect-go/spark/sql/types" + "github.com/datalake-go/spark-connect-go/spark/sql/column" + "github.com/datalake-go/spark-connect-go/spark/sql/types" ) func Expr(expr string) column.Column { diff --git a/spark/sql/functions/generated.go b/spark/sql/functions/generated.go index 844fa61..071d33b 100644 --- a/spark/sql/functions/generated.go +++ b/spark/sql/functions/generated.go @@ -15,7 +15,7 @@ package functions -import "github.com/apache/spark-connect-go/spark/sql/column" +import "github.com/datalake-go/spark-connect-go/spark/sql/column" // BitwiseNOT - Computes bitwise not. 
// diff --git a/spark/sql/group.go b/spark/sql/group.go index 3943969..87c306f 100644 --- a/spark/sql/group.go +++ b/spark/sql/group.go @@ -19,12 +19,12 @@ package sql import ( "context" - "github.com/apache/spark-connect-go/spark/sql/types" + "github.com/datalake-go/spark-connect-go/spark/sql/types" - proto "github.com/apache/spark-connect-go/internal/generated" - "github.com/apache/spark-connect-go/spark/sparkerrors" - "github.com/apache/spark-connect-go/spark/sql/column" - "github.com/apache/spark-connect-go/spark/sql/functions" + proto "github.com/datalake-go/spark-connect-go/internal/generated" + "github.com/datalake-go/spark-connect-go/spark/sparkerrors" + "github.com/datalake-go/spark-connect-go/spark/sql/column" + "github.com/datalake-go/spark-connect-go/spark/sql/functions" ) type GroupedData struct { diff --git a/spark/sql/group_test.go b/spark/sql/group_test.go index fe02b49..562ed45 100644 --- a/spark/sql/group_test.go +++ b/spark/sql/group_test.go @@ -19,10 +19,10 @@ import ( "context" "testing" - proto "github.com/apache/spark-connect-go/internal/generated" - "github.com/apache/spark-connect-go/spark/client" - "github.com/apache/spark-connect-go/spark/client/testutils" - "github.com/apache/spark-connect-go/spark/mocks" + proto "github.com/datalake-go/spark-connect-go/internal/generated" + "github.com/datalake-go/spark-connect-go/spark/client" + "github.com/datalake-go/spark-connect-go/spark/client/testutils" + "github.com/datalake-go/spark-connect-go/spark/mocks" "github.com/stretchr/testify/assert" ) diff --git a/spark/sql/plan.go b/spark/sql/plan.go index 022298f..73d5310 100644 --- a/spark/sql/plan.go +++ b/spark/sql/plan.go @@ -19,7 +19,7 @@ package sql import ( "sync/atomic" - proto "github.com/apache/spark-connect-go/internal/generated" + proto "github.com/datalake-go/spark-connect-go/internal/generated" ) var atomicInt64 atomic.Int64 diff --git a/spark/sql/sparksession.go b/spark/sql/sparksession.go index a84bb61..071396d 100644 --- a/spark/sql/sparksession.go +++ b/spark/sql/sparksession.go @@ -23,20 +23,21 @@ import ( "time" "github.com/apache/arrow-go/v18/arrow/memory" - "github.com/apache/spark-connect-go/spark/sql/types" + "github.com/datalake-go/spark-connect-go/spark/sql/types" "github.com/apache/arrow-go/v18/arrow" "github.com/apache/arrow-go/v18/arrow/array" "github.com/apache/arrow-go/v18/arrow/ipc" - "github.com/apache/spark-connect-go/spark/client/base" + "github.com/datalake-go/spark-connect-go/spark/client/base" - "github.com/apache/spark-connect-go/spark/client/options" + "github.com/datalake-go/spark-connect-go/spark/client/options" - proto "github.com/apache/spark-connect-go/internal/generated" - "github.com/apache/spark-connect-go/spark/client" - "github.com/apache/spark-connect-go/spark/client/channel" - "github.com/apache/spark-connect-go/spark/sparkerrors" + proto "github.com/datalake-go/spark-connect-go/internal/generated" + "github.com/datalake-go/spark-connect-go/spark/client" + "github.com/datalake-go/spark-connect-go/spark/client/channel" + "github.com/datalake-go/spark-connect-go/spark/sparkerrors" "github.com/google/uuid" + "google.golang.org/grpc" "google.golang.org/grpc/metadata" ) @@ -58,6 +59,7 @@ func NewSessionBuilder() *SparkSessionBuilder { type SparkSessionBuilder struct { connectionString string channelBuilder channel.Builder + dialOpts []grpc.DialOption } // Remote sets the connection string for remote connection @@ -71,6 +73,17 @@ func (s *SparkSessionBuilder) WithChannelBuilder(cb channel.Builder) *SparkSessi return s } +// 
WithDialOptions appends gRPC dial options that will be applied when +// the session builds its underlying channel. Useful for raising the +// per-call message ceiling further, tuning keepalive, or installing +// interceptors — anything Spark Connect doesn't expose as a server +// conf. Ignored when WithChannelBuilder was also called; a custom +// channel owns its own dial options. +func (s *SparkSessionBuilder) WithDialOptions(opts ...grpc.DialOption) *SparkSessionBuilder { + s.dialOpts = append(s.dialOpts, opts...) + return s +} + func (s *SparkSessionBuilder) Build(ctx context.Context) (SparkSession, error) { if s.channelBuilder == nil { cb, err := channel.NewBuilder(s.connectionString) @@ -78,6 +91,9 @@ func (s *SparkSessionBuilder) Build(ctx context.Context) (SparkSession, error) { return nil, sparkerrors.WithType(fmt.Errorf( "failed to connect to remote %s: %w", s.connectionString, err), sparkerrors.ConnectionError) } + if len(s.dialOpts) > 0 { + cb.WithDialOptions(s.dialOpts...) + } s.channelBuilder = cb } conn, err := s.channelBuilder.Build(ctx) diff --git a/spark/sql/sparksession_integration_test.go b/spark/sql/sparksession_integration_test.go index c23d671..d6cecdc 100644 --- a/spark/sql/sparksession_integration_test.go +++ b/spark/sql/sparksession_integration_test.go @@ -21,7 +21,7 @@ import ( "github.com/apache/arrow-go/v18/arrow" "github.com/apache/arrow-go/v18/arrow/memory" - "github.com/apache/spark-connect-go/spark/sql/types" + "github.com/datalake-go/spark-connect-go/spark/sql/types" "github.com/stretchr/testify/assert" ) diff --git a/spark/sql/sparksession_test.go b/spark/sql/sparksession_test.go index 11539af..b6d1d3c 100644 --- a/spark/sql/sparksession_test.go +++ b/spark/sql/sparksession_test.go @@ -28,12 +28,13 @@ import ( "github.com/apache/arrow-go/v18/arrow/memory" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "google.golang.org/grpc" - proto "github.com/apache/spark-connect-go/internal/generated" - "github.com/apache/spark-connect-go/spark/client" - "github.com/apache/spark-connect-go/spark/client/testutils" - "github.com/apache/spark-connect-go/spark/mocks" - "github.com/apache/spark-connect-go/spark/sparkerrors" + proto "github.com/datalake-go/spark-connect-go/internal/generated" + "github.com/datalake-go/spark-connect-go/spark/client" + "github.com/datalake-go/spark-connect-go/spark/client/testutils" + "github.com/datalake-go/spark-connect-go/spark/mocks" + "github.com/datalake-go/spark-connect-go/spark/sparkerrors" ) func TestSparkSessionTable(t *testing.T) { @@ -191,3 +192,19 @@ func TestWriteResultStreamsArrowResultToCollector(t *testing.T) { assert.NoError(t, err) assert.Equal(t, []any{"str2"}, vals) } + +// TestSessionBuilder_WithDialOptions_Forwarded asserts that +// caller-supplied grpc.DialOption values are forwarded to the +// channel builder the session creates. We build against a bogus +// endpoint and don't dial — grpc.NewClient is lazy and doesn't open +// a connection until the first RPC, so the interesting surface is +// that WithDialOptions composes into the builder's state without +// error. +func TestSessionBuilder_WithDialOptions_Forwarded(t *testing.T) { + ctx := context.Background() + _, err := NewSessionBuilder(). + Remote("sc://localhost"). + WithDialOptions(grpc.WithUserAgent("dorm-test")). 
+ Build(ctx) + assert.NoError(t, err) +} diff --git a/spark/sql/typed_helpers.go b/spark/sql/typed_helpers.go new file mode 100644 index 0000000..31d33b1 --- /dev/null +++ b/spark/sql/typed_helpers.go @@ -0,0 +1,75 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sql + +import ( + "context" + "errors" + "iter" +) + +// ErrNotFound is returned by First when the DataFrame produces zero +// rows. +var ErrNotFound = errors.New("spark: no rows returned") + +// Collect materialises every row of df into a []T. Thin wrapper over +// As[T] + (*DataFrameOf[T]).Collect for callers who hold an untyped +// DataFrame and want the result in one call. +func Collect[T any](ctx context.Context, df DataFrame) ([]T, error) { + typed, err := As[T](df) + if err != nil { + return nil, err + } + return typed.Collect(ctx) +} + +// Stream yields typed rows one at a time using the untyped +// DataFrame's streaming primitive underneath. Constant memory +// regardless of result size. Schema binding happens on the first +// row; a subsequent row whose schema diverges from the first +// surfaces the error through the iterator. +// +// Consumers range over the return value with Go 1.23's iter.Seq2: +// +// for row, err := range sql.Stream[User](ctx, df) { +// if err != nil { break } +// // use row +// } +func Stream[T any](ctx context.Context, df DataFrame) iter.Seq2[T, error] { + return func(yield func(T, error) bool) { + var zero T + typed, err := As[T](df) + if err != nil { + yield(zero, err) + return + } + for row, rerr := range typed.Stream(ctx) { + if !yield(row, rerr) { + return + } + } + } +} + +// First returns the first row of df decoded as T, or ErrNotFound if +// df produced no rows. +func First[T any](ctx context.Context, df DataFrame) (*T, error) { + typed, err := As[T](df) + if err != nil { + return nil, err + } + return typed.First(ctx) +} diff --git a/spark/sql/typed_helpers_test.go b/spark/sql/typed_helpers_test.go new file mode 100644 index 0000000..d10b823 --- /dev/null +++ b/spark/sql/typed_helpers_test.go @@ -0,0 +1,55 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package sql + +import ( + "context" + "errors" + "strings" + "testing" +) + +func TestErrNotFound_Sentinel(t *testing.T) { + // First's empty-result path wraps ErrNotFound; verify callers can + // branch on it via errors.Is. + wrapped := errors.Join(ErrNotFound, errors.New("context")) + if !errors.Is(wrapped, ErrNotFound) { + t.Error("ErrNotFound should match via errors.Is through errors.Join") + } +} + +func TestAs_RejectsNonStructT(t *testing.T) { + // As[T] is the sole constructor for DataFrameOf[T]; T must be a + // struct. Verify the error surfaces before any DataFrame I/O. + _, err := As[int](nil) + if err == nil || !strings.Contains(err.Error(), "must be a struct") { + t.Errorf("want struct-required error, got %v", err) + } +} + +func TestCollect_RejectsNonStructT(t *testing.T) { + _, err := Collect[int](context.Background(), nil) + if err == nil || !strings.Contains(err.Error(), "must be a struct") { + t.Errorf("want struct-required error, got %v", err) + } +} + +func TestFirst_RejectsNonStructT(t *testing.T) { + _, err := First[int](context.Background(), nil) + if err == nil || !strings.Contains(err.Error(), "must be a struct") { + t.Errorf("want struct-required error, got %v", err) + } +} diff --git a/spark/sql/types/arrow.go b/spark/sql/types/arrow.go index e33068f..ef4c434 100644 --- a/spark/sql/types/arrow.go +++ b/spark/sql/types/arrow.go @@ -20,13 +20,13 @@ import ( "bytes" "fmt" - proto "github.com/apache/spark-connect-go/internal/generated" + proto "github.com/datalake-go/spark-connect-go/internal/generated" "github.com/apache/arrow-go/v18/arrow" "github.com/apache/arrow-go/v18/arrow/array" "github.com/apache/arrow-go/v18/arrow/ipc" - "github.com/apache/spark-connect-go/spark/sparkerrors" + "github.com/datalake-go/spark-connect-go/spark/sparkerrors" ) func ReadArrowTableToRows(table arrow.Table) ([]Row, error) { @@ -68,6 +68,12 @@ func ReadArrowTableToRows(table arrow.Table) ([]Row, error) { return result, nil } +func ReadArrowRecordToRows(record arrow.Record) ([]Row, error) { + table := array.NewTableFromRecords(record.Schema(), []arrow.Record{record}) + defer table.Release() + return ReadArrowTableToRows(table) +} + func readArrayData(t arrow.Type, data arrow.ArrayData) ([]any, error) { buf := make([]any, 0) // Switch over the type t and append the values to buf. 
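`ReadArrowRecordToRows` is the record-level sibling of `ReadArrowTableToRows`, and it is what `rowiterator.go` further down leans on. A minimal sketch of calling it directly, assuming nothing beyond the arrow-go builders and the `Row` accessors (`Values`, `At`) that appear elsewhere in this diff:

```go
package main

import (
	"fmt"

	"github.com/apache/arrow-go/v18/arrow"
	"github.com/apache/arrow-go/v18/arrow/array"
	"github.com/apache/arrow-go/v18/arrow/memory"

	"github.com/datalake-go/spark-connect-go/spark/sql/types"
)

func main() {
	schema := arrow.NewSchema(
		[]arrow.Field{{Name: "id", Type: arrow.PrimitiveTypes.Int64}},
		nil,
	)
	b := array.NewRecordBuilder(memory.NewGoAllocator(), schema)
	defer b.Release()
	b.Field(0).(*array.Int64Builder).AppendValues([]int64{1, 2, 3}, nil)

	rec := b.NewRecord()
	defer rec.Release()

	// One record in, a flat []types.Row out; the temporary table the
	// helper builds internally is released before it returns.
	rows, err := types.ReadArrowRecordToRows(rec)
	if err != nil {
		panic(err)
	}
	for _, r := range rows {
		fmt.Println(r.At(0)) // 1, 2, 3
	}
}
```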
diff --git a/spark/sql/types/arrow_test.go b/spark/sql/types/arrow_test.go index e75ea33..9fdc40f 100644 --- a/spark/sql/types/arrow_test.go +++ b/spark/sql/types/arrow_test.go @@ -30,8 +30,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - proto "github.com/apache/spark-connect-go/internal/generated" - "github.com/apache/spark-connect-go/spark/sql/types" + proto "github.com/datalake-go/spark-connect-go/internal/generated" + "github.com/datalake-go/spark-connect-go/spark/sql/types" ) func TestShowArrowBatchData(t *testing.T) { @@ -438,3 +438,60 @@ func TestConvertProtoDataTypeToDataType_UnsupportedType(t *testing.T) { } assert.Equal(t, "Unsupported", types.ConvertProtoDataTypeToDataType(unsupportedDataType).TypeName()) } + +func TestReadArrowBatchToRecord(t *testing.T) { + // Create a test arrow record + arrowFields := []arrow.Field{ + {Name: "col1", Type: arrow.BinaryTypes.String}, + {Name: "col2", Type: arrow.PrimitiveTypes.Int32}, + } + arrowSchema := arrow.NewSchema(arrowFields, nil) + + alloc := memory.NewGoAllocator() + recordBuilder := array.NewRecordBuilder(alloc, arrowSchema) + defer recordBuilder.Release() + + recordBuilder.Field(0).(*array.StringBuilder).Append("test1") + recordBuilder.Field(0).(*array.StringBuilder).Append("test2") + recordBuilder.Field(1).(*array.Int32Builder).Append(100) + recordBuilder.Field(1).(*array.Int32Builder).Append(200) + + originalRecord := recordBuilder.NewRecord() + defer originalRecord.Release() + + // Serialize to arrow batch format + var buf bytes.Buffer + arrowWriter := ipc.NewWriter(&buf, ipc.WithSchema(arrowSchema)) + defer arrowWriter.Close() + + err := arrowWriter.Write(originalRecord) + require.NoError(t, err) + + // Test ReadArrowBatchToRecord + record, err := types.ReadArrowBatchToRecord(buf.Bytes(), nil) + require.NoError(t, err) + defer record.Release() + + // Verify the record was read correctly + assert.Equal(t, int64(2), record.NumRows()) + assert.Equal(t, int64(2), record.NumCols()) + assert.Equal(t, "col1", record.Schema().Field(0).Name) + assert.Equal(t, "col2", record.Schema().Field(1).Name) +} + +func TestReadArrowBatchToRecord_InvalidData(t *testing.T) { + // Test with invalid arrow data + invalidData := []byte{0x00, 0x01, 0x02} + + _, err := types.ReadArrowBatchToRecord(invalidData, nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "failed to create arrow reader") +} + +func TestReadArrowBatchToRecord_EmptyData(t *testing.T) { + // Test with empty data + emptyData := []byte{} + + _, err := types.ReadArrowBatchToRecord(emptyData, nil) + assert.Error(t, err) +} diff --git a/spark/sql/types/builtin.go b/spark/sql/types/builtin.go index 1f74695..24fc2b2 100644 --- a/spark/sql/types/builtin.go +++ b/spark/sql/types/builtin.go @@ -19,7 +19,7 @@ package types import ( "context" - proto "github.com/apache/spark-connect-go/internal/generated" + proto "github.com/datalake-go/spark-connect-go/internal/generated" ) type LiteralType interface { diff --git a/spark/sql/types/conversion.go b/spark/sql/types/conversion.go index b2652e2..8f5d7f1 100644 --- a/spark/sql/types/conversion.go +++ b/spark/sql/types/conversion.go @@ -19,8 +19,8 @@ package types import ( "errors" - "github.com/apache/spark-connect-go/internal/generated" - "github.com/apache/spark-connect-go/spark/sparkerrors" + "github.com/datalake-go/spark-connect-go/internal/generated" + "github.com/datalake-go/spark-connect-go/spark/sparkerrors" ) func ConvertProtoDataTypeToStructType(input *generated.DataType) (*StructType, 
error) {
diff --git a/spark/sql/types/conversion_test.go b/spark/sql/types/conversion_test.go
index cd62779..9302c56 100644
--- a/spark/sql/types/conversion_test.go
+++ b/spark/sql/types/conversion_test.go
@@ -19,8 +19,8 @@ package types_test
 import (
 	"testing"
 
-	proto "github.com/apache/spark-connect-go/internal/generated"
-	"github.com/apache/spark-connect-go/spark/sql/types"
+	proto "github.com/datalake-go/spark-connect-go/internal/generated"
+	"github.com/datalake-go/spark-connect-go/spark/sql/types"
 
 	"github.com/stretchr/testify/assert"
 )
diff --git a/spark/sql/types/rowiterator.go b/spark/sql/types/rowiterator.go
new file mode 100644
index 0000000..7f9d451
--- /dev/null
+++ b/spark/sql/types/rowiterator.go
@@ -0,0 +1,64 @@
+package types
+
+import (
+	"context"
+	"errors"
+	"io"
+	"iter"
+
+	"github.com/apache/arrow-go/v18/arrow"
+)
+
+// rowIterFromRecord converts an Arrow record into a row iterator,
+// releasing the record when iteration completes or the consumer stops.
+func rowIterFromRecord(rec arrow.Record) iter.Seq2[Row, error] {
+	return func(yield func(Row, error) bool) {
+		defer rec.Release()
+		rows, err := ReadArrowRecordToRows(rec)
+		if err != nil {
+			_ = yield(nil, err)
+			return
+		}
+		for _, row := range rows {
+			if !yield(row, nil) {
+				return
+			}
+		}
+	}
+}
+
+// NewRowSequence flattens a stream of record batches into a sequence of rows.
+func NewRowSequence(ctx context.Context, recordSeq iter.Seq2[arrow.Record, error]) iter.Seq2[Row, error] {
+	return func(yield func(Row, error) bool) {
+		for rec, recErr := range recordSeq {
+			select {
+			case <-ctx.Done():
+				_ = yield(nil, ctx.Err())
+				return
+			default:
+			}
+
+			// Treat io.EOF as clean stream termination. Some Spark
+			// implementations (notably Databricks clusters as of 05/2025)
+			// yield EOF as an error value instead of ending the sequence.
+ if errors.Is(recErr, io.EOF) { + return + } + if recErr != nil { + _ = yield(nil, recErr) + return + } + if rec == nil { + _ = yield(nil, errors.New("expected non-nil arrow.Record, got nil")) + return + } + + for row, err := range rowIterFromRecord(rec) { + cont := yield(row, err) + if err != nil || !cont { + return + } + } + } + } +} diff --git a/spark/sql/types/rowiterator_test.go b/spark/sql/types/rowiterator_test.go new file mode 100644 index 0000000..931dca6 --- /dev/null +++ b/spark/sql/types/rowiterator_test.go @@ -0,0 +1,399 @@ +package types_test + +import ( + "context" + "errors" + "io" + "iter" + "testing" + + "github.com/apache/arrow-go/v18/arrow" + "github.com/apache/arrow-go/v18/arrow/array" + "github.com/apache/arrow-go/v18/arrow/memory" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/datalake-go/spark-connect-go/spark/sql/types" +) + +// Helper function to create test records +func createTestRecord(values []string) arrow.Record { + schema := arrow.NewSchema( + []arrow.Field{{Name: "col1", Type: arrow.BinaryTypes.String}}, + nil, + ) + + alloc := memory.NewGoAllocator() + builder := array.NewRecordBuilder(alloc, schema) + + for _, v := range values { + builder.Field(0).(*array.StringBuilder).Append(v) + } + + record := builder.NewRecord() + builder.Release() + + return record +} + +// Helper function to create a Seq2 iterator from test data +func createTestSeq2(records []arrow.Record, err error) iter.Seq2[arrow.Record, error] { + return func(yield func(arrow.Record, error) bool) { + // Yield each record + for _, record := range records { + // Retain before yielding since consumer will release + record.Retain() + if !yield(record, nil) { + return + } + } + + if err != nil { + yield(nil, err) + } + } +} + +func TestRowIterator_BasicIteration(t *testing.T) { + // Create test records + records := []arrow.Record{ + createTestRecord([]string{"row1", "row2"}), + createTestRecord([]string{"row3", "row4"}), + } + + // Clean up records after test + defer func() { + for _, r := range records { + r.Release() + } + }() + + seq2 := createTestSeq2(records, nil) + + rowIter := types.NewRowSequence(context.Background(), seq2) + + // Collect all rows + var rows []types.Row + for row, err := range rowIter { + require.NoError(t, err) + rows = append(rows, row) + } + + // Verify we got all 4 rows + assert.Len(t, rows, 4) + assert.Equal(t, "row1", rows[0].At(0)) + assert.Equal(t, "row2", rows[1].At(0)) + assert.Equal(t, "row3", rows[2].At(0)) + assert.Equal(t, "row4", rows[3].At(0)) +} + +func TestRowIterator_EmptyResult(t *testing.T) { + // Create empty Seq2 + seq2 := func(yield func(arrow.Record, error) bool) { + // Don't yield anything - sequence is immediately over + } + + next := types.NewRowSequence(context.Background(), seq2) + + // Should iterate zero times + count := 0 + for _, err := range next { + require.NoError(t, err) + count++ + } + assert.Equal(t, 0, count) +} + +func TestRowIterator_ErrorPropagation(t *testing.T) { + testErr := errors.New("test error") + + // Create Seq2 that yields one record then an error + seq2 := func(yield func(arrow.Record, error) bool) { + record := createTestRecord([]string{"row1"}) + record.Retain() // Consumer will release + if !yield(record, nil) { + record.Release() // Clean up if yield returns false + return + } + yield(nil, testErr) + } + + next := types.NewRowSequence(context.Background(), seq2) + + var rows []types.Row + var gotError error + + for row, err := range next { + if err != nil { + 
gotError = err + break + } + rows = append(rows, row) + } + + // Should have read first row successfully + assert.Len(t, rows, 1) + assert.Equal(t, "row1", rows[0].At(0)) + + // Should have received the error + assert.Equal(t, testErr, gotError) +} + +func TestRowIterator_ContextCancellation(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + t.Cleanup(cancel) + + // Create a Seq2 that yields records indefinitely + seq2 := func(yield func(arrow.Record, error) bool) { + for { + select { + case <-ctx.Done(): + yield(nil, ctx.Err()) + return + default: + record := createTestRecord([]string{"row"}) + record.Retain() // Consumer will release + if !yield(record, nil) { + record.Release() + return + } + } + } + } + + next := types.NewRowSequence(ctx, seq2) + + var rows []types.Row + count := 0 + + for row, err := range next { + if err != nil { + assert.ErrorIs(t, err, context.Canceled) + break + } + rows = append(rows, row) + count++ + + // Cancel after first row + if count == 1 { + cancel() + } + + if count > 10 { + break + } + } + + assert.GreaterOrEqual(t, len(rows), 1) + assert.Equal(t, "row", rows[0].At(0)) +} + +func TestRowIterator_EarlyBreak(t *testing.T) { + // Create multiple records + records := []arrow.Record{ + createTestRecord([]string{"row1"}), + createTestRecord([]string{"row2"}), + createTestRecord([]string{"row3"}), + } + + // Clean up records after test + defer func() { + for _, r := range records { + r.Release() + } + }() + + seq2 := createTestSeq2(records, nil) + + next := types.NewRowSequence(context.Background(), seq2) + + // Read only one row then break + var rows []types.Row + for row, err := range next { + require.NoError(t, err) + rows = append(rows, row) + if len(rows) >= 1 { + break // Early termination + } + } + + // Should have only one row + assert.Len(t, rows, 1) + assert.Equal(t, "row1", rows[0].At(0)) +} + +func TestRowIterator_EmptyBatchHandling(t *testing.T) { + // Test handling of empty records (0 rows but valid record) + emptyRecord := createTestRecord([]string{}) // No rows + validRecord := createTestRecord([]string{"row1"}) + + records := []arrow.Record{emptyRecord, validRecord} + defer func() { + for _, r := range records { + r.Release() + } + }() + + seq2 := createTestSeq2(records, nil) + next := types.NewRowSequence(context.Background(), seq2) + + // Should skip empty batch and return row from second batch + var rows []types.Row + for row, err := range next { + require.NoError(t, err) + rows = append(rows, row) + } + + assert.Len(t, rows, 1) + assert.Equal(t, "row1", rows[0].At(0)) +} + +func TestRowIterator_DatabricksEOFBehavior(t *testing.T) { + // Test Databricks-specific behavior where io.EOF is sent as an error + // value rather than just ending the sequence. NewRowSequence treats + // io.EOF as clean termination. 
+ seq2 := func(yield func(arrow.Record, error) bool) { + record1 := createTestRecord([]string{"row1", "row2"}) + record1.Retain() + if !yield(record1, nil) { + record1.Release() + return + } + + record2 := createTestRecord([]string{"row3"}) + record2.Retain() + if !yield(record2, nil) { + record2.Release() + return + } + + // Databricks sends io.EOF as error — should terminate cleanly + yield(nil, io.EOF) + } + + next := types.NewRowSequence(context.Background(), seq2) + + // Read all rows successfully + var rows []types.Row + for row, err := range next { + require.NoError(t, err) + rows = append(rows, row) + } + + // Should have all 3 rows + assert.Len(t, rows, 3) + assert.Equal(t, "row1", rows[0].At(0)) + assert.Equal(t, "row2", rows[1].At(0)) + assert.Equal(t, "row3", rows[2].At(0)) +} + +func TestRowIterator_NilRecordReturnsError(t *testing.T) { + // Test that receiving a nil record returns an error + seq2 := func(yield func(arrow.Record, error) bool) { + record := createTestRecord([]string{"row1"}) + record.Retain() + if !yield(record, nil) { + record.Release() + return + } + + // Yield nil record (shouldn't happen in production) + yield(nil, nil) + } + + next := types.NewRowSequence(context.Background(), seq2) + + var rows []types.Row + var gotError error + + for row, err := range next { + if err != nil { + gotError = err + break + } + rows = append(rows, row) + } + + // Should have read first row successfully + assert.Len(t, rows, 1) + assert.Equal(t, "row1", rows[0].At(0)) + + // Should have received error about nil record + assert.Error(t, gotError) + assert.Contains(t, gotError.Error(), "expected non-nil arrow.Record, got nil") +} + +func TestRowSeq2_DirectUsage(t *testing.T) { + // Test using NewRowSequence directly as a Seq2 + records := []arrow.Record{ + createTestRecord([]string{"row1", "row2"}), + createTestRecord([]string{"row3"}), + } + + defer func() { + for _, r := range records { + r.Release() + } + }() + + recordSeq := createTestSeq2(records, nil) + rowSeq := types.NewRowSequence(context.Background(), recordSeq) + + var rows []types.Row + for row, err := range rowSeq { + require.NoError(t, err) + rows = append(rows, row) + } + + // Should have all 3 rows flattened + assert.Len(t, rows, 3) + assert.Equal(t, "row1", rows[0].At(0)) + assert.Equal(t, "row2", rows[1].At(0)) + assert.Equal(t, "row3", rows[2].At(0)) +} + +func TestRowIterator_MultipleIterations(t *testing.T) { + // Test that ranging the same iterator twice works safely when the + // upstream is single-use (like a real gRPC stream). + records := []arrow.Record{ + createTestRecord([]string{"row1", "row2"}), + } + + defer func() { + for _, r := range records { + r.Release() + } + }() + + // Build a single-use upstream to simulate a gRPC stream. 
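+	// A real gRPC response stream can be consumed only once; the
+	// fake drains itself on first use to mirror that.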
+ exhausted := false + seq2 := func(yield func(arrow.Record, error) bool) { + if exhausted { + return + } + exhausted = true + for _, record := range records { + record.Retain() + if !yield(record, nil) { + return + } + } + } + + next := types.NewRowSequence(context.Background(), seq2) + + var rows1 []types.Row + for row, err := range next { + require.NoError(t, err) + rows1 = append(rows1, row) + } + assert.Len(t, rows1, 2) + + // Second iteration — upstream exhausted, should yield nothing + var rows2 []types.Row + for row, err := range next { + require.NoError(t, err) + rows2 = append(rows2, row) + } + assert.Len(t, rows2, 0) +} diff --git a/spark/sql/utils/consts.go b/spark/sql/utils/consts.go index e1312ef..83bdad2 100644 --- a/spark/sql/utils/consts.go +++ b/spark/sql/utils/consts.go @@ -15,7 +15,7 @@ package utils -import proto "github.com/apache/spark-connect-go/internal/generated" +import proto "github.com/datalake-go/spark-connect-go/internal/generated" type ExplainMode int
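Taken together, stmt.go, result.go, rows.go, and the render tests pin down what a database/sql caller observes. A minimal end-to-end sketch under those semantics; the endpoint and table are illustrative, and driver registration is assumed to happen in the package's init:

```go
package main

import (
	"context"
	"database/sql"
	"fmt"

	// Assumed to register the driver under the name "spark" in init.
	_ "github.com/datalake-go/spark-connect-go/spark/sql/driver"
)

func main() {
	// sql.Open is lazy; no connection is made until the first call.
	db, err := sql.Open("spark", "sc://spark.internal:15002")
	if err != nil {
		panic(err)
	}
	defer db.Close()
	ctx := context.Background()

	// Exec path: $N placeholders are rendered client-side into Spark
	// SQL literals before session.Sql fires (see stmt.go).
	res, err := db.ExecContext(ctx,
		`INSERT INTO goose_db_version (version_id, is_applied) VALUES ($1, $2)`,
		int64(20260419000001), true)
	if err != nil {
		panic(err)
	}
	n, _ := res.RowsAffected() // -1: "unknown", per result.go
	fmt.Println("rows affected:", n)

	// Query path: the result is fully materialised via Collect and
	// walked by the driver.Rows wrapper; right-sized for small reads.
	rows, err := db.QueryContext(ctx,
		`SELECT version_id FROM goose_db_version WHERE is_applied = $1`, true)
	if err != nil {
		panic(err)
	}
	defer rows.Close()
	for rows.Next() {
		var version int64
		if err := rows.Scan(&version); err != nil {
			panic(err)
		}
		fmt.Println(version)
	}
}
```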