diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index 249bd33..ef218ab 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -32,7 +32,7 @@ on:
       - master
 
 env:
-  SPARK_VERSION: '4.0.0'
+  SPARK_VERSION: '4.0.1'
   HADOOP_VERSION: '3'
 
 permissions:
diff --git a/internal/tests/integration/dataframe_test.go b/internal/tests/integration/dataframe_test.go
index d383ca1..df16620 100644
--- a/internal/tests/integration/dataframe_test.go
+++ b/internal/tests/integration/dataframe_test.go
@@ -1196,7 +1196,6 @@ func TestDataFrame_RangeIter(t *testing.T) {
     }
     assert.Equal(t, 10, cnt)
 
-    // Check that errors are properly propagated
     df, err = spark.Sql(ctx, "select if(id = 5, raise_error('handle'), false) from range(10)")
     assert.NoError(t, err)
     for _, err := range df.All(ctx) {
@@ -1224,3 +1223,197 @@ func TestDataFrame_SchemaTreeString(t *testing.T) {
     assert.Contains(t, ts, "|-- second: array")
     assert.Contains(t, ts, "|-- third: map")
 }
+
+func TestDataFrame_StreamRowsThroughChannel(t *testing.T) {
+    // Demonstrates how StreamRows can be used to pipe data through a channel for scenarios like:
+    // - Proxying Spark data through gRPC streaming or unary RPCs
+    // - Implementing producer-consumer patterns with backpressure based on Spark results
+    // - Buffering and rate-limiting data flow between systems
+    ctx, spark := connect()
+    df, err := spark.Sql(ctx, "select id, id * 2 as doubled, 'test_' || cast(id as string) as label from range(100)")
+    assert.NoError(t, err)
+
+    iter, err := df.StreamRows(ctx)
+    assert.NoError(t, err)
+
+    rowChan := make(chan map[string]interface{}, 10)
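+    // Buffered so the producer goroutine can report a failure and exit even
+    // if the consumer has already stopped reading.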
+    errChan := make(chan error, 1)
+
+    go func() {
+        defer close(rowChan)
+        for row, err := range iter {
+            if err != nil {
+                errChan <- err
+                return
+            }
+
+            rowData := make(map[string]interface{})
+            names := row.FieldNames()
+            for i, name := range names {
+                rowData[name] = row.At(i)
+            }
+
+            select {
+            case rowChan <- rowData:
+            case <-ctx.Done():
+                errChan <- ctx.Err()
+                return
+            }
+        }
+    }()
+
+    // In a gRPC scenario, this would be your response handler
+    receivedRows := make([]map[string]interface{}, 0)
+    consumerDone := make(chan struct{})
+
+    go func() {
+        defer close(consumerDone)
+        for rowData := range rowChan {
+            receivedRows = append(receivedRows, rowData)
+
+            id := rowData["id"].(int64)
+            doubled := rowData["doubled"].(int64)
+            assert.Equal(t, id*2, doubled)
+        }
+    }()
+
+    <-consumerDone
+
+    select {
+    case err := <-errChan:
+        assert.NoError(t, err)
+    default:
+        // continue
+    }
+
+    assert.Equal(t, 100, len(receivedRows))
+
+    assert.Equal(t, int64(0), receivedRows[0]["id"])
+    assert.Equal(t, int64(99), receivedRows[99]["id"])
+    assert.Equal(t, "test_0", receivedRows[0]["label"])
+    assert.Equal(t, "test_99", receivedRows[99]["label"])
+}
+
+func TestDataFrame_StreamRows(t *testing.T) {
+    ctx, spark := connect()
+    df, err := spark.Sql(ctx, "select * from range(100)")
+    assert.NoError(t, err)
+
+    iter, err := df.StreamRows(ctx)
+    assert.NoError(t, err)
+    assert.NotNil(t, iter)
+
+    cnt := 0
+
+    for row, err := range iter {
+        assert.NoError(t, err)
+        assert.NotNil(t, row)
+        assert.Equal(t, 1, row.Len())
+        cnt++
+    }
+    assert.Equal(t, 100, cnt)
+}
+
+func TestDataFrame_StreamRowsWithFilter(t *testing.T) {
+    ctx, spark := connect()
+    df, err := spark.Sql(ctx, "select * from range(100)")
+    assert.NoError(t, err)
+
+    df, err = df.Filter(ctx, functions.Col("id").Lt(functions.IntLit(10)))
+    assert.NoError(t, err)
+
+    iter, err := df.StreamRows(ctx)
+    assert.NoError(t, err)
+
+    cnt := 0
+    for row, err := range iter {
+        assert.NoError(t, err)
+        assert.NotNil(t, row)
+        assert.Equal(t, 1, row.Len())
+        assert.Less(t, row.At(0).(int64), int64(10))
+        cnt++
+    }
+    assert.Equal(t, 10, cnt)
+}
+
+func TestDataFrame_StreamRowsEmpty(t *testing.T) {
+    ctx, spark := connect()
+    df, err := spark.Sql(ctx, "select * from range(0)")
+    assert.NoError(t, err)
+
+    iter, err := df.StreamRows(ctx)
+    assert.NoError(t, err)
+
+    cnt := 0
+    for row, err := range iter {
+        assert.NoError(t, err)
+        assert.NotNil(t, row)
+        cnt++
+    }
+    assert.Equal(t, 0, cnt)
+}
+
+func TestDataFrame_StreamRowsWithError(t *testing.T) {
+    ctx, spark := connect()
+    df, err := spark.Sql(ctx, "select if(id = 5, raise_error('test error'), id) as id from range(10)")
+    assert.NoError(t, err)
+
+    iter, err := df.StreamRows(ctx)
+    assert.NoError(t, err)
+
+    errorEncountered := false
+    for _, err := range iter {
+        if err != nil {
+            errorEncountered = true
+            assert.Error(t, err)
+            break
+        }
+    }
+    assert.True(t, errorEncountered, "Expected to encounter an error during iteration")
+}
+
+func TestDataFrame_StreamRowsMultipleColumns(t *testing.T) {
+    ctx, spark := connect()
+    df, err := spark.Sql(ctx, "select id, id * 2 as doubled, 'test' as name from range(50)")
+    assert.NoError(t, err)
+
+    iter, err := df.StreamRows(ctx)
+    assert.NoError(t, err)
+
+    cnt := 0
+    for row, err := range iter {
+        assert.NoError(t, err)
+        assert.NotNil(t, row)
+        assert.Equal(t, 3, row.Len())
+
+        id := row.At(0).(int64)
+        doubled := row.At(1).(int64)
+        name := row.At(2).(string)
+
+        assert.Equal(t, id*2, doubled)
+        assert.Equal(t, "test", name)
+        cnt++
+    }
+    assert.Equal(t, 50, cnt)
+}
+
+func TestDataFrame_StreamRowsLargeDataset(t *testing.T) {
+    ctx, spark := connect()
+    df, err := spark.Sql(ctx, "select * from range(10000)")
+    assert.NoError(t, err)
+
+    iter, err := df.StreamRows(ctx)
+    assert.NoError(t, err)
+
+    cnt := 0
+    lastValue := int64(-1)
+    for row, err := range iter {
+        assert.NoError(t, err)
+        assert.NotNil(t, row)
+        currentValue := row.At(0).(int64)
+        assert.Greater(t, currentValue, lastValue)
+        lastValue = currentValue
+        cnt++
+    }
+    assert.Equal(t, 10000, cnt)
+}
diff --git a/spark/client/base/base.go b/spark/client/base/base.go
index c5f01c7..ee1ae79 100644
--- a/spark/client/base/base.go
+++ b/spark/client/base/base.go
@@ -17,6 +17,7 @@ package base
 
 import (
     "context"
+    "iter"
 
     "github.com/apache/spark-connect-go/spark/sql/utils"
 
@@ -47,6 +48,9 @@ type SparkConnectClient interface {
 }
 
 type ExecuteResponseStream interface {
+    // ToTable consumes all arrow.Record batches into a single arrow.Table. Useful for collecting the complete query result into a client-side DataFrame.
     ToTable() (*types.StructType, arrow.Table, error)
+    // ToRecordSequence lazily consumes each arrow.Record retrieved by a query. Useful for streaming query results.
+    ToRecordSequence(ctx context.Context) iter.Seq2[arrow.Record, error]
     Properties() map[string]any
 }
diff --git a/spark/client/client.go b/spark/client/client.go
index a0d7a4c..7eceee1 100644
--- a/spark/client/client.go
+++ b/spark/client/client.go
@@ -20,6 +20,7 @@ import (
     "errors"
     "fmt"
     "io"
+    "iter"
 
     "github.com/apache/spark-connect-go/spark/sql/utils"
 
@@ -434,6 +435,103 @@ func (c *ExecutePlanClient) ToTable() (*types.StructType, arrow.Table, error) {
     }
 }
 
+// ToRecordSequence returns a Seq2 iterator that yields arrow.Record batches directly as they arrive from the server.
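+//
+// An illustrative consumption sketch (the stream variable and handle
+// function are examples only, not part of this API):
+//
+//	for record, err := range stream.ToRecordSequence(ctx) {
+//		if err != nil {
+//			return err // the sequence ends after the first error
+//		}
+//		handle(record) // e.g. convert via types.ReadArrowRecordToRows
+//		record.Release() // records yielded here are owned by the caller
+//	}
+//
+// Breaking out of the loop stops further reads from the underlying stream.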
+func (c *ExecutePlanClient) ToRecordSequence(ctx context.Context) iter.Seq2[arrow.Record, error] {
+    return func(yield func(arrow.Record, error) bool) {
+        // done tracks whether the server signaled ResultComplete. With Spark's
+        // reattachable execution the stream can end before the result is
+        // logically complete, so completion is tracked locally here rather
+        // than racing on shared struct state.
+        // Adapted from ToTable; the two code paths could eventually be unified.
+        done := false
+
+        for {
+            select {
+            case <-ctx.Done():
+                yield(nil, ctx.Err())
+                return
+            default:
+            }
+
+            resp, err := c.responseStream.Recv()
+
+            select {
+            case <-ctx.Done():
+                yield(nil, ctx.Err())
+                return
+            default:
+            }
+
+            if errors.Is(err, io.EOF) {
+                break
+            }
+
+            if err != nil {
+                if se := sparkerrors.FromRPCError(err); se != nil {
+                    yield(nil, sparkerrors.WithType(se, sparkerrors.ExecutionError))
+                } else {
+                    yield(nil, err)
+                }
+                return
+            }
+
+            if resp == nil {
+                continue
+            }
+
+            if resp.GetSessionId() != c.sessionId {
+                yield(nil, sparkerrors.WithType(
+                    &sparkerrors.InvalidServerSideSessionDetailsError{
+                        OwnSessionId:      c.sessionId,
+                        ReceivedSessionId: resp.GetSessionId(),
+                    }, sparkerrors.InvalidServerSideSessionError))
+                return
+            }
+
+            if resp.Schema != nil {
+                var schemaErr error
+                c.schema, schemaErr = types.ConvertProtoDataTypeToStructType(resp.Schema)
+                if schemaErr != nil {
+                    yield(nil, sparkerrors.WithType(schemaErr, sparkerrors.ExecutionError))
+                    return
+                }
+            }
+
+            switch x := resp.ResponseType.(type) {
+            case *proto.ExecutePlanResponse_SqlCommandResult_:
+                if val := x.SqlCommandResult.GetRelation(); val != nil {
+                    c.properties["sql_command_result"] = val
+                }
+
+            case *proto.ExecutePlanResponse_ArrowBatch_:
+                record, err := types.ReadArrowBatchToRecord(x.ArrowBatch.Data, c.schema)
+                if err != nil {
+                    yield(nil, err)
+                    return
+                }
+                if !yield(record, nil) {
+                    return
+                }
+
+            case *proto.ExecutePlanResponse_ResultComplete_:
+                done = true
+                return
+
+            case *proto.ExecutePlanResponse_ExecutionProgress_:
+                // Progress updates - ignore for now
+
+            default:
+                // Explicitly ignore unknown message types
+            }
+        }
+
+        // Check that the result is logically complete. With re-attachable execution
+        // the server may interrupt the connection, and we need a ResultComplete
+        // message to confirm the full result was received.
+        if c.opts.ReattachExecution && !done {
+            yield(nil, sparkerrors.WithType(fmt.Errorf("the result is not complete"), sparkerrors.ExecutionError))
+        }
+    }
+}
+
 func NewExecuteResponseStream(
     responseClient proto.SparkConnectService_ExecutePlanClient,
     sessionId string,
diff --git a/spark/client/client_test.go b/spark/client/client_test.go
index ca7019c..d9f9ade 100644
--- a/spark/client/client_test.go
+++ b/spark/client/client_test.go
@@ -1,110 +1,597 @@
-// Licensed to the Apache Software Foundation (ASF) under one or more
-// contributor license agreements. See the NOTICE file distributed with
-// this work for additional information regarding copyright ownership.
-// The ASF licenses this file to You under the Apache License, Version 2.0
-// (the "License"); you may not use this file except in compliance with
-// the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
 package client_test
 
 import (
+    "bytes"
     "context"
+    "errors"
+    "iter"
     "testing"
+    "time"
 
-    "github.com/google/uuid"
-
+    "github.com/apache/arrow-go/v18/arrow"
+    "github.com/apache/arrow-go/v18/arrow/array"
+    "github.com/apache/arrow-go/v18/arrow/ipc"
+    "github.com/apache/arrow-go/v18/arrow/memory"
     proto "github.com/apache/spark-connect-go/internal/generated"
     "github.com/apache/spark-connect-go/spark/client"
-    "github.com/apache/spark-connect-go/spark/client/testutils"
     "github.com/apache/spark-connect-go/spark/mocks"
     "github.com/apache/spark-connect-go/spark/sparkerrors"
     "github.com/stretchr/testify/assert"
+    "github.com/stretchr/testify/require"
 )
 
-func TestAnalyzePlanCallsAnalyzePlanOnClient(t *testing.T) {
+func TestToRecordIterator_ChannelClosureWithoutData(t *testing.T) {
+    // The iterator should complete without yielding any records when no arrow batches are present
     ctx := context.Background()
-    response := &proto.AnalyzePlanResponse{}
-    c := client.NewSparkExecutorFromClient(
-        testutils.NewConnectServiceClientMock(nil, response, nil, nil), nil, mocks.MockSessionId)
-    resp, err := c.AnalyzePlan(ctx, &proto.Plan{})
-    assert.NoError(t, err)
-    assert.NotNil(t, resp)
+
+    schemaResponse := &mocks.MockResponse{
+        Resp: &proto.ExecutePlanResponse{
+            SessionId:   mocks.MockSessionId,
+            OperationId: mocks.MockOperationId,
+            Schema: &proto.DataType{
+                Kind: &proto.DataType_Struct_{
+                    Struct: &proto.DataType_Struct{
+                        Fields: []*proto.DataType_StructField{
+                            {
+                                Name: "test_column",
+                                DataType: &proto.DataType{
+                                    Kind: &proto.DataType_String_{
+                                        String_: &proto.DataType_String{},
+                                    },
+                                },
+                                Nullable: false,
+                            },
+                        },
+                    },
+                },
+            },
+        },
+    }
+
+    c := client.NewTestConnectClientFromResponses(mocks.MockSessionId,
+        schemaResponse,
+        &mocks.ExecutePlanResponseDone,
+        &mocks.ExecutePlanResponseEOF)
+
+    stream, err := c.ExecutePlan(ctx, mocks.NewSqlCommand("select 1"))
+    require.NoError(t, err)
+
+    iter := stream.ToRecordSequence(ctx)
+
+    recordsReceived := 0
+    errorsReceived := 0
+
+    for record, err := range iter {
+        if err != nil {
+            errorsReceived++
+            break
+        }
+        if record != nil {
+            recordsReceived++
+        }
+    }
+
+    assert.Equal(t, 0, recordsReceived, "No records should be sent when no arrow batches are present")
+    assert.Equal(t, 0, errorsReceived, "No errors should occur")
 }
 
-func TestAnalyzePlanFailsIfClientFails(t *testing.T) {
+func TestToRecordIterator_ArrowBatchStreaming(t *testing.T) {
+    // Arrow batch data should be correctly streamed
     ctx := context.Background()
-    c := client.NewSparkExecutorFromClient(
-        testutils.NewConnectServiceClientMock(nil, nil, assert.AnError, nil), nil, mocks.MockSessionId)
-    resp, err := c.AnalyzePlan(ctx, &proto.Plan{})
-    assert.Nil(t, resp)
-    assert.Error(t, err)
+
+    schemaResponse := &mocks.MockResponse{
+        Resp: &proto.ExecutePlanResponse{
+            SessionId:   mocks.MockSessionId,
+            OperationId: mocks.MockOperationId,
+            Schema: &proto.DataType{
+                Kind: &proto.DataType_Struct_{
+                    Struct: &proto.DataType_Struct{
+                        Fields: []*proto.DataType_StructField{
+                            {
+                                Name: "col",
+                                DataType: &proto.DataType{
+                                    Kind: &proto.DataType_String_{
+                                        String_: &proto.DataType_String{},
+                                    },
+                                },
+                                Nullable: false,
+                            },
+                        },
+                    },
+                },
+            },
+        },
+    }
+
+    arrowData := createTestArrowBatch(t, []string{"value1", "value2", "value3"})
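+    // The mock stream below mirrors the server's framing: schema first,
+    // then the data batch, then ResultComplete, then EOF.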
+
+    arrowBatch := &mocks.MockResponse{
+        Resp: &proto.ExecutePlanResponse{
+            ResponseType: &proto.ExecutePlanResponse_ArrowBatch_{
+                ArrowBatch: &proto.ExecutePlanResponse_ArrowBatch{
+                    Data: arrowData,
+                },
+            },
+            SessionId:   mocks.MockSessionId,
+            OperationId: mocks.MockOperationId,
+        },
+    }
+
+    c := client.NewTestConnectClientFromResponses(mocks.MockSessionId,
+        schemaResponse,
+        arrowBatch,
+        &mocks.ExecutePlanResponseDone,
+        &mocks.ExecutePlanResponseEOF)
+
+    stream, err := c.ExecutePlan(ctx, mocks.NewSqlCommand("select col"))
+    require.NoError(t, err)
+
+    iter := stream.ToRecordSequence(ctx)
+
+    records := collectRecordsFromSeq2(t, iter)
+
+    require.Len(t, records, 1, "Should receive exactly one record")
+
+    record := records[0]
+    assert.Equal(t, int64(3), record.NumRows(), "Record should have 3 rows")
+    assert.Equal(t, int64(1), record.NumCols(), "Record should have 1 column")
+
+    col := record.Column(0).(*array.String)
+    assert.Equal(t, "value1", col.Value(0))
+    assert.Equal(t, "value2", col.Value(1))
+    assert.Equal(t, "value3", col.Value(2))
 }
 
-func TestExecutePlanCallsExecutePlanOnClient(t *testing.T) {
+func TestToRecordIterator_MultipleArrowBatches(t *testing.T) {
+    // Multiple arrow batches should be streamed in order
     ctx := context.Background()
-    plan := &proto.Plan{}
 
-    // Generate a mock client
-    responseStream := mocks.NewProtoClientMock(&mocks.ExecutePlanResponseDone)
+    schemaResponse := &mocks.MockResponse{
+        Resp: &proto.ExecutePlanResponse{
+            SessionId:   mocks.MockSessionId,
+            OperationId: mocks.MockOperationId,
+            Schema: &proto.DataType{
+                Kind: &proto.DataType_Struct_{
+                    Struct: &proto.DataType_Struct{
+                        Fields: []*proto.DataType_StructField{
+                            {
+                                Name: "col",
+                                DataType: &proto.DataType{
+                                    Kind: &proto.DataType_String_{
+                                        String_: &proto.DataType_String{},
+                                    },
+                                },
+                                Nullable: false,
+                            },
+                        },
+                    },
+                },
+            },
+        },
+    }
+
+    batch1 := createTestArrowBatch(t, []string{"batch1_row1", "batch1_row2"})
+    batch2 := createTestArrowBatch(t, []string{"batch2_row1", "batch2_row2"})
+
+    arrowBatch1 := &mocks.MockResponse{
+        Resp: &proto.ExecutePlanResponse{
+            ResponseType: &proto.ExecutePlanResponse_ArrowBatch_{
+                ArrowBatch: &proto.ExecutePlanResponse_ArrowBatch{
+                    Data: batch1,
+                },
+            },
+            SessionId:   mocks.MockSessionId,
+            OperationId: mocks.MockOperationId,
+        },
+    }
+
+    arrowBatch2 := &mocks.MockResponse{
+        Resp: &proto.ExecutePlanResponse{
+            ResponseType: &proto.ExecutePlanResponse_ArrowBatch_{
+                ArrowBatch: &proto.ExecutePlanResponse_ArrowBatch{
+                    Data: batch2,
+                },
+            },
+            SessionId:   mocks.MockSessionId,
+            OperationId: mocks.MockOperationId,
+        },
+    }
+
+    c := client.NewTestConnectClientFromResponses(mocks.MockSessionId,
+        schemaResponse,
+        arrowBatch1,
+        arrowBatch2,
+        &mocks.ExecutePlanResponseDone,
+        &mocks.ExecutePlanResponseEOF)
+
+    stream, err := c.ExecutePlan(ctx, mocks.NewSqlCommand("select col"))
+    require.NoError(t, err)
+
+    iter := stream.ToRecordSequence(ctx)
+    records := collectRecordsFromSeq2(t, iter)
+
+    require.Len(t, records, 2, "Should receive exactly two records")
+
+    // Verify first batch
+    col1 := records[0].Column(0).(*array.String)
+    assert.Equal(t, "batch1_row1", col1.Value(0))
+    assert.Equal(t, "batch1_row2", col1.Value(1))
+
+    // Verify second batch
+    col2 := records[1].Column(0).(*array.String)
+    assert.Equal(t, "batch2_row1", col2.Value(0))
+    assert.Equal(t, "batch2_row2", col2.Value(1))
+}
+
+func TestToRecordIterator_ContextCancellationStopsStreaming(t *testing.T) {
+    // Context cancellation should stop streaming
+    ctx, cancel := context.WithCancel(context.Background())
+
+    schemaResponse := &mocks.MockResponse{
+        Resp: &proto.ExecutePlanResponse{
+            SessionId:   mocks.MockSessionId,
+            OperationId: mocks.MockOperationId,
+            Schema: &proto.DataType{
+                Kind: &proto.DataType_Struct_{
+                    Struct: &proto.DataType_Struct{
+                        Fields: []*proto.DataType_StructField{
+                            {
+                                Name: "col0",
+                                DataType: &proto.DataType{
+                                    Kind: &proto.DataType_Integer_{
+                                        Integer: &proto.DataType_Integer{},
+                                    },
+                                },
+                                Nullable: true,
+                            },
+                        },
+                    },
+                },
+            },
+        },
+    }
+
+    c := client.NewTestConnectClientFromResponses(mocks.MockSessionId,
+        schemaResponse,
+        &mocks.ExecutePlanResponseDone,
+        &mocks.ExecutePlanResponseEOF)
+
+    stream, err := c.ExecutePlan(ctx, mocks.NewSqlCommand("select 1"))
+    require.NoError(t, err)
+
+    iter := stream.ToRecordSequence(ctx)
+
+    // Cancel the context immediately
+    cancel()
 
-    c := client.NewSparkExecutorFromClient(
-        testutils.NewConnectServiceClientMock(responseStream, nil, nil, t), nil, mocks.MockSessionId)
-    resp, err := c.ExecutePlan(ctx, plan)
-    assert.NoError(t, err)
-    assert.NotNil(t, resp)
+    // Try to consume the iterator
+    timeout := time.After(100 * time.Millisecond)
+    done := make(chan bool)
+
+    go func() {
+        for _, err := range iter {
+            if err != nil {
+                // Got an error - verify it's context cancellation
+                assert.ErrorIs(t, err, context.Canceled)
+                done <- true
+                return
+            }
+        }
+        done <- true
+    }()
+
+    select {
+    case <-done:
+        // Good - iteration completed
+    case <-timeout:
+        // Timeout is acceptable as cancellation might have happened after all responses were processed
+    }
 }
 
-func TestExecutePlanCallsExecuteCommandOnClient(t *testing.T) {
+func TestToRecordIterator_RPCErrorPropagation(t *testing.T) {
+    // RPC errors should be properly propagated
     ctx := context.Background()
-    plan := &proto.Plan{}
-
-    // Generate a mock client
-    responseStream := mocks.NewProtoClientMock(&mocks.ExecutePlanResponseDone, &mocks.ExecutePlanResponseEOF)
-
-    // Check that the execution fails if no command is supplied.
-    c := client.NewSparkExecutorFromClient(
-        testutils.NewConnectServiceClientMock(responseStream, nil, nil, t), nil, mocks.MockSessionId)
-    _, _, _, err := c.ExecuteCommand(ctx, plan)
-    assert.ErrorIs(t, err, sparkerrors.ExecutionError)
-
-    // Generate a command and the execution should succeed.
-    sqlCommand := mocks.NewSqlCommand("select range(10)")
-    c = client.NewSparkExecutorFromClient(testutils.NewConnectServiceClientMock(responseStream, nil, nil, t), nil, mocks.MockSessionId)
-    _, _, _, err = c.ExecuteCommand(ctx, sqlCommand)
-    assert.NoError(t, err)
+
+    schemaResponse := &mocks.MockResponse{
+        Resp: &proto.ExecutePlanResponse{
+            SessionId:   mocks.MockSessionId,
+            OperationId: mocks.MockOperationId,
+            Schema: &proto.DataType{
+                Kind: &proto.DataType_Struct_{
+                    Struct: &proto.DataType_Struct{
+                        Fields: []*proto.DataType_StructField{
+                            {
+                                Name: "col1",
+                                DataType: &proto.DataType{
+                                    Kind: &proto.DataType_String_{
+                                        String_: &proto.DataType_String{},
+                                    },
+                                },
+                                Nullable: false,
+                            },
+                        },
+                    },
+                },
+            },
+        },
+    }
+
+    expectedError := errors.New("simulated RPC error")
+    errorResponse := &mocks.MockResponse{
+        Err: expectedError,
+    }
+
+    c := client.NewTestConnectClientFromResponses(mocks.MockSessionId,
+        schemaResponse,
+        errorResponse)
+
+    stream, err := c.ExecutePlan(ctx, mocks.NewSqlCommand("select 1"))
+    require.NoError(t, err)
+
+    iter := stream.ToRecordSequence(ctx)
+
+    errorReceived := false
+    for _, err := range iter {
+        if err != nil {
+            assert.Error(t, err)
+            assert.Contains(t, err.Error(), "simulated RPC error")
+            errorReceived = true
+            break
+        }
+    }
+
+    assert.True(t, errorReceived, "Expected RPC error")
 }
 
-func Test_ExecuteWithWrongSession(t *testing.T) {
+func TestToRecordIterator_SessionValidation(t *testing.T) {
+    // Session validation error should be returned for wrong session ID
     ctx := context.Background()
-    sqlCommand := mocks.NewSqlCommand("select range(10)")
 
-    // Generate a mock client
-    responseStream := mocks.NewProtoClientMock(&mocks.ExecutePlanResponseDone, &mocks.ExecutePlanResponseEOF)
+    wrongSessionResponse := &mocks.MockResponse{
+        Resp: &proto.ExecutePlanResponse{
+            SessionId:   "wrong-session-id",
+            OperationId: mocks.MockOperationId,
+            Schema: &proto.DataType{
+                Kind: &proto.DataType_Struct_{
+                    Struct: &proto.DataType_Struct{
+                        Fields: []*proto.DataType_StructField{
+                            {
+                                Name: "col0",
+                                DataType: &proto.DataType{
+                                    Kind: &proto.DataType_Integer_{
+                                        Integer: &proto.DataType_Integer{},
+                                    },
+                                },
+                                Nullable: true,
+                            },
+                        },
+                    },
+                },
+            },
+        },
+    }
+
+    c := client.NewTestConnectClientFromResponses(mocks.MockSessionId,
+        wrongSessionResponse,
+        &mocks.ExecutePlanResponseEOF)
+
+    stream, err := c.ExecutePlan(ctx, mocks.NewSqlCommand("select 1"))
+    require.NoError(t, err)
+
+    iter := stream.ToRecordSequence(ctx)
+
+    errorReceived := false
+    for _, err := range iter {
+        if err != nil {
+            assert.Error(t, err)
+            assert.ErrorIs(t, err, sparkerrors.InvalidServerSideSessionError)
+            errorReceived = true
+            break
+        }
+    }
 
-    // Check that the execution fails if no command is supplied.
-    c := client.NewSparkExecutorFromClient(
-        testutils.NewConnectServiceClientMock(responseStream, nil, nil, t), nil, uuid.NewString())
-    _, _, _, err := c.ExecuteCommand(ctx, sqlCommand)
-    assert.ErrorIs(t, err, sparkerrors.InvalidServerSideSessionError)
+    assert.True(t, errorReceived, "Expected session validation error")
 }
 
-func Test_Execute_SchemaParsingFails(t *testing.T) {
+func TestToRecordIterator_SqlCommandResultProperties(t *testing.T) {
+    // SQL command results should be captured in properties
     ctx := context.Background()
-    sqlCommand := mocks.NewSqlCommand("select range(10)")
-    responseStream := mocks.NewProtoClientMock(
-        &mocks.ExecutePlanResponseBrokenSchema,
+
+    sqlResultResponse := &mocks.MockResponse{
+        Resp: &proto.ExecutePlanResponse{
+            SessionId:   mocks.MockSessionId,
+            OperationId: mocks.MockOperationId,
+            ResponseType: &proto.ExecutePlanResponse_SqlCommandResult_{
+                SqlCommandResult: &proto.ExecutePlanResponse_SqlCommandResult{
+                    Relation: &proto.Relation{
+                        RelType: &proto.Relation_Sql{
+                            Sql: &proto.SQL{Query: "test query"},
+                        },
+                    },
+                },
+            },
+        },
+    }
+
+    c := client.NewTestConnectClientFromResponses(mocks.MockSessionId,
+        sqlResultResponse,
         &mocks.ExecutePlanResponseDone,
         &mocks.ExecutePlanResponseEOF)
-    c := client.NewSparkExecutorFromClient(
-        testutils.NewConnectServiceClientMock(responseStream, nil, nil, t), nil, mocks.MockSessionId)
-    _, _, _, err := c.ExecuteCommand(ctx, sqlCommand)
-    assert.ErrorIs(t, err, sparkerrors.ExecutionError)
+
+    stream, err := c.ExecutePlan(ctx, mocks.NewSqlCommand("test query"))
+    require.NoError(t, err)
+
+    iter := stream.ToRecordSequence(ctx)
+    _ = collectRecordsFromSeq2(t, iter)
+
+    // Properties should contain the SQL command result
+    props := stream.(*client.ExecutePlanClient).Properties()
+    assert.NotNil(t, props["sql_command_result"])
+}
+
+func TestToRecordIterator_MixedResponseTypes(t *testing.T) {
+    // Mixed response types should be handled correctly in realistic order
+    ctx := context.Background()
+
+    responses := []*mocks.MockResponse{
+        // Schema first
+        {
+            Resp: &proto.ExecutePlanResponse{
+                SessionId:   mocks.MockSessionId,
+                OperationId: mocks.MockOperationId,
+                Schema: &proto.DataType{
+                    Kind: &proto.DataType_Struct_{
+                        Struct: &proto.DataType_Struct{
+                            Fields: []*proto.DataType_StructField{
+                                {
+                                    Name: "id",
+                                    DataType: &proto.DataType{
+                                        Kind: &proto.DataType_String_{
+                                            String_: &proto.DataType_String{},
+                                        },
+                                    },
+                                    Nullable: false,
+                                },
+                            },
+                        },
+                    },
+                },
+            },
+        },
+        // SQL command result
+        {
+            Resp: &proto.ExecutePlanResponse{
+                SessionId:   mocks.MockSessionId,
+                OperationId: mocks.MockOperationId,
+                ResponseType: &proto.ExecutePlanResponse_SqlCommandResult_{
+                    SqlCommandResult: &proto.ExecutePlanResponse_SqlCommandResult{
+                        Relation: &proto.Relation{
+                            RelType: &proto.Relation_Sql{
+                                Sql: &proto.SQL{Query: "SELECT * FROM table"},
+                            },
+                        },
+                    },
+                },
+            },
+        },
+        // Progress updates
+        {
+            Resp: &proto.ExecutePlanResponse{
+                SessionId:   mocks.MockSessionId,
+                OperationId: mocks.MockOperationId,
+                ResponseType: &proto.ExecutePlanResponse_ExecutionProgress_{
+                    ExecutionProgress: &proto.ExecutePlanResponse_ExecutionProgress{
+                        Stages:           nil,
+                        NumInflightTasks: 0,
+                    },
+                },
+            },
+        },
+        // Arrow batch
+        {
+            Resp: &proto.ExecutePlanResponse{
+                ResponseType: &proto.ExecutePlanResponse_ArrowBatch_{
+                    ArrowBatch: &proto.ExecutePlanResponse_ArrowBatch{
+                        Data: createTestArrowBatch(t, []string{"row1"}),
+                    },
+                },
+                SessionId:   mocks.MockSessionId,
+                OperationId: mocks.MockOperationId,
+            },
+        },
+        // More progress
+        {
+            Resp: &proto.ExecutePlanResponse{
+                SessionId:   mocks.MockSessionId,
+                OperationId: mocks.MockOperationId,
+                ResponseType: &proto.ExecutePlanResponse_ExecutionProgress_{
+                    ExecutionProgress: &proto.ExecutePlanResponse_ExecutionProgress{
+                        Stages:           nil,
+                        NumInflightTasks: 0,
+                    },
+                },
+            },
+        },
+        // Another arrow batch
+        {
+            Resp: &proto.ExecutePlanResponse{
+                ResponseType: &proto.ExecutePlanResponse_ArrowBatch_{
+                    ArrowBatch: &proto.ExecutePlanResponse_ArrowBatch{
+                        Data: createTestArrowBatch(t, []string{"row2", "row3"}),
+                    },
+                },
+                SessionId:   mocks.MockSessionId,
+                OperationId: mocks.MockOperationId,
+            },
+        },
+        // Result complete
+        &mocks.ExecutePlanResponseDone,
+        // EOF
+        &mocks.ExecutePlanResponseEOF,
+    }
+
+    c := client.NewTestConnectClientFromResponses(mocks.MockSessionId, responses...)
+
+    stream, err := c.ExecutePlan(ctx, mocks.NewSqlCommand("SELECT * FROM table"))
+    require.NoError(t, err)
+
+    iter := stream.ToRecordSequence(ctx)
+    records := collectRecordsFromSeq2(t, iter)
+
+    require.Len(t, records, 2, "Should receive exactly two arrow batches")
+
+    assert.Equal(t, int64(1), records[0].NumRows())
+    assert.Equal(t, int64(2), records[1].NumRows())
+}
+
+// Helper function to create test arrow batch data
+func createTestArrowBatch(t *testing.T, values []string) []byte {
+    t.Helper()
+
+    arrowFields := []arrow.Field{
+        {Name: "col", Type: arrow.BinaryTypes.String},
+    }
+    arrowSchema := arrow.NewSchema(arrowFields, nil)
+
+    alloc := memory.NewGoAllocator()
+    recordBuilder := array.NewRecordBuilder(alloc, arrowSchema)
+    defer recordBuilder.Release()
+
+    stringBuilder := recordBuilder.Field(0).(*array.StringBuilder)
+    for _, v := range values {
+        stringBuilder.Append(v)
+    }
+
+    record := recordBuilder.NewRecord()
+    defer record.Release()
+
+    var buf bytes.Buffer
+    arrowWriter := ipc.NewWriter(&buf, ipc.WithSchema(arrowSchema))
+    defer arrowWriter.Close()
+
+    err := arrowWriter.Write(record)
+    require.NoError(t, err)
+    err = arrowWriter.Close()
+    require.NoError(t, err)
+
+    return buf.Bytes()
+}
+
+// Helper function to collect all records from Seq2 iterator
+func collectRecordsFromSeq2(t *testing.T, iter iter.Seq2[arrow.Record, error]) []arrow.Record {
+    t.Helper()
+
+    var records []arrow.Record
+
+    for record, err := range iter {
+        if err != nil {
+            t.Fatalf("Unexpected error: %v", err)
+            break
+        }
+        if record != nil {
+            records = append(records, record)
+        }
+    }
+
+    return records
 }
diff --git a/spark/sql/dataframe.go b/spark/sql/dataframe.go
index 4c7a563..b827af9 100644
--- a/spark/sql/dataframe.go
+++ b/spark/sql/dataframe.go
@@ -200,6 +200,11 @@ type DataFrame interface {
     // Sort returns a new DataFrame sorted by the specified columns.
     Sort(ctx context.Context, columns ...column.Convertible) (DataFrame, error)
     Stat() DataFrameStatFunctions
+    // StreamRows returns a lazy iterator over the DataFrame's rows.
+    // Record batches are fetched from Spark over gRPC one at a time, and the
+    // next batch is not requested until the current batch's rows have been
+    // consumed, so buffering is bounded and client back-pressure is respected.
+    StreamRows(ctx context.Context) (iter.Seq2[types.Row, error], error)
     // Subtract subtracts the other DataFrame from the current DataFrame. And only returns
     // distinct rows.
     Subtract(ctx context.Context, other DataFrame) DataFrame
@@ -936,6 +941,17 @@ func (df *dataFrameImpl) ToArrow(ctx context.Context) (*arrow.Table, error) {
     return &table, nil
 }
 
+func (df *dataFrameImpl) StreamRows(ctx context.Context) (iter.Seq2[types.Row, error], error) {
+    responseClient, err := df.session.client.ExecutePlan(ctx, df.createPlan())
+    if err != nil {
+        return nil, sparkerrors.WithType(fmt.Errorf("failed to execute plan: %w", err), sparkerrors.ExecutionError)
+    }
+
+    seq2 := responseClient.ToRecordSequence(ctx)
+
+    return types.NewRowSequence(ctx, seq2), nil
+}
+
 func (df *dataFrameImpl) UnionAll(ctx context.Context, other DataFrame) DataFrame {
     otherDf := other.(*dataFrameImpl)
     isAll := true
diff --git a/spark/sql/types/arrow.go b/spark/sql/types/arrow.go
index e33068f..6c49bd0 100644
--- a/spark/sql/types/arrow.go
+++ b/spark/sql/types/arrow.go
@@ -68,6 +68,12 @@ func ReadArrowTableToRows(table arrow.Table) ([]Row, error) {
     return result, nil
 }
 
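+// ReadArrowRecordToRows converts a single Arrow record batch into materialized
+// Rows by wrapping it in a one-record table. The record itself is not released
+// here; the caller keeps ownership.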
+func ReadArrowRecordToRows(record arrow.Record) ([]Row, error) {
+    table := array.NewTableFromRecords(record.Schema(), []arrow.Record{record})
+    defer table.Release()
+    return ReadArrowTableToRows(table)
+}
+
 func readArrayData(t arrow.Type, data arrow.ArrayData) ([]any, error) {
     buf := make([]any, 0)
     // Switch over the type t and append the values to buf.
diff --git a/spark/sql/types/arrow_test.go b/spark/sql/types/arrow_test.go
index e75ea33..92d491b 100644
--- a/spark/sql/types/arrow_test.go
+++ b/spark/sql/types/arrow_test.go
@@ -438,3 +438,60 @@ func TestConvertProtoDataTypeToDataType_UnsupportedType(t *testing.T) {
     }
     assert.Equal(t, "Unsupported", types.ConvertProtoDataTypeToDataType(unsupportedDataType).TypeName())
 }
+
+func TestReadArrowBatchToRecord(t *testing.T) {
+    // Create a test arrow record
+    arrowFields := []arrow.Field{
+        {Name: "col1", Type: arrow.BinaryTypes.String},
+        {Name: "col2", Type: arrow.PrimitiveTypes.Int32},
+    }
+    arrowSchema := arrow.NewSchema(arrowFields, nil)
+
+    alloc := memory.NewGoAllocator()
+    recordBuilder := array.NewRecordBuilder(alloc, arrowSchema)
+    defer recordBuilder.Release()
+
+    recordBuilder.Field(0).(*array.StringBuilder).Append("test1")
+    recordBuilder.Field(0).(*array.StringBuilder).Append("test2")
+    recordBuilder.Field(1).(*array.Int32Builder).Append(100)
+    recordBuilder.Field(1).(*array.Int32Builder).Append(200)
+
+    originalRecord := recordBuilder.NewRecord()
+    defer originalRecord.Release()
+
+    // Serialize to arrow batch format
+    var buf bytes.Buffer
+    arrowWriter := ipc.NewWriter(&buf, ipc.WithSchema(arrowSchema))
+    defer arrowWriter.Close()
+
+    err := arrowWriter.Write(originalRecord)
+    require.NoError(t, err)
+
+    // Test ReadArrowBatchToRecord
+    record, err := types.ReadArrowBatchToRecord(buf.Bytes(), nil)
+    require.NoError(t, err)
+    defer record.Release()
+
+    // Verify the record was read correctly
+    assert.Equal(t, int64(2), record.NumRows())
+    assert.Equal(t, int64(2), record.NumCols())
+    assert.Equal(t, "col1", record.Schema().Field(0).Name)
+    assert.Equal(t, "col2", record.Schema().Field(1).Name)
+}
+
+func TestReadArrowBatchToRecord_InvalidData(t *testing.T) {
+    // Test with invalid arrow data
+    invalidData := []byte{0x00, 0x01, 0x02}
+
+    _, err := types.ReadArrowBatchToRecord(invalidData, nil)
+    assert.Error(t, err)
+    assert.Contains(t, err.Error(), "failed to create arrow reader")
+}
+
+func TestReadArrowBatchToRecord_EmptyData(t *testing.T) {
+    // Test with empty data
+    emptyData := []byte{}
+
+    _, err := types.ReadArrowBatchToRecord(emptyData, nil)
+    assert.Error(t, err)
+}
diff --git a/spark/sql/types/rowiterator.go b/spark/sql/types/rowiterator.go
new file mode 100644
index 0000000..7f9d451
--- /dev/null
+++ b/spark/sql/types/rowiterator.go
@@ -0,0 +1,64 @@
+package types
+
+import (
+    "context"
+    "errors"
+    "io"
+    "iter"
+
+    "github.com/apache/arrow-go/v18/arrow"
+)
+
+// rowIterFromRecord converts an Arrow record into a row iterator,
+// releasing the record when iteration completes or the consumer stops.
+func rowIterFromRecord(rec arrow.Record) iter.Seq2[Row, error] {
+    return func(yield func(Row, error) bool) {
+        defer rec.Release()
+        rows, err := ReadArrowRecordToRows(rec)
+        if err != nil {
+            _ = yield(nil, err)
+            return
+        }
+        for _, row := range rows {
+            if !yield(row, nil) {
+                return
+            }
+        }
+    }
+}
+
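+// Ownership note: NewRowSequence assumes it owns every record yielded by
+// recordSeq and releases each one after converting it to rows (see
+// rowIterFromRecord). Producers that keep a record beyond the yield must
+// Retain it first, as the tests in rowiterator_test.go do.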
+// NewRowSequence flattens a stream of Arrow record batches into a single sequence of rows.
+func NewRowSequence(ctx context.Context, recordSeq iter.Seq2[arrow.Record, error]) iter.Seq2[Row, error] {
+    return func(yield func(Row, error) bool) {
+        for rec, recErr := range recordSeq {
+            select {
+            case <-ctx.Done():
+                _ = yield(nil, ctx.Err())
+                return
+            default:
+            }
+
+            // Treat io.EOF as clean stream termination. Some Spark
+            // implementations (notably Databricks clusters as of 05/2025)
+            // yield EOF as an error value instead of ending the sequence.
+            if errors.Is(recErr, io.EOF) {
+                return
+            }
+            if recErr != nil {
+                _ = yield(nil, recErr)
+                return
+            }
+            if rec == nil {
+                _ = yield(nil, errors.New("expected non-nil arrow.Record, got nil"))
+                return
+            }
+
+            for row, err := range rowIterFromRecord(rec) {
+                cont := yield(row, err)
+                if err != nil || !cont {
+                    return
+                }
+            }
+        }
+    }
+}
diff --git a/spark/sql/types/rowiterator_test.go b/spark/sql/types/rowiterator_test.go
new file mode 100644
index 0000000..b9a2f46
--- /dev/null
+++ b/spark/sql/types/rowiterator_test.go
@@ -0,0 +1,399 @@
+package types_test
+
+import (
+    "context"
+    "errors"
+    "io"
+    "iter"
+    "testing"
+
+    "github.com/apache/arrow-go/v18/arrow"
+    "github.com/apache/arrow-go/v18/arrow/array"
+    "github.com/apache/arrow-go/v18/arrow/memory"
+    "github.com/stretchr/testify/assert"
+    "github.com/stretchr/testify/require"
+
+    "github.com/apache/spark-connect-go/spark/sql/types"
+)
+
+// Helper function to create test records
+func createTestRecord(values []string) arrow.Record {
+    schema := arrow.NewSchema(
+        []arrow.Field{{Name: "col1", Type: arrow.BinaryTypes.String}},
+        nil,
+    )
+
+    alloc := memory.NewGoAllocator()
+    builder := array.NewRecordBuilder(alloc, schema)
+
+    for _, v := range values {
+        builder.Field(0).(*array.StringBuilder).Append(v)
+    }
+
+    record := builder.NewRecord()
+    builder.Release()
+
+    return record
+}
+
+// Helper function to create a Seq2 iterator from test data
+func createTestSeq2(records []arrow.Record, err error) iter.Seq2[arrow.Record, error] {
+    return func(yield func(arrow.Record, error) bool) {
+        // Yield each record
+        for _, record := range records {
+            // Retain before yielding since consumer will release
+            record.Retain()
+            if !yield(record, nil) {
+                return
+            }
+        }
+
+        if err != nil {
+            yield(nil, err)
+        }
+    }
+}
+
+func TestRowIterator_BasicIteration(t *testing.T) {
+    // Create test records
+    records := []arrow.Record{
+        createTestRecord([]string{"row1", "row2"}),
+        createTestRecord([]string{"row3", "row4"}),
+    }
+
+    // Clean up records after test
+    defer func() {
+        for _, r := range records {
+            r.Release()
+        }
+    }()
+
+    seq2 := createTestSeq2(records, nil)
+
+    rowIter := types.NewRowSequence(context.Background(), seq2)
+
+    // Collect all rows
+    var rows []types.Row
+    for row, err := range rowIter {
+        require.NoError(t, err)
+        rows = append(rows, row)
+    }
+
+    // Verify we got all 4 rows
+    assert.Len(t, rows, 4)
+    assert.Equal(t, "row1", rows[0].At(0))
+    assert.Equal(t, "row2", rows[1].At(0))
+    assert.Equal(t, "row3", rows[2].At(0))
+    assert.Equal(t, "row4", rows[3].At(0))
+}
+
+func TestRowIterator_EmptyResult(t *testing.T) {
+    // Create empty Seq2
+    seq2 := func(yield func(arrow.Record, error) bool) {
+        // Don't yield anything - sequence is immediately over
+    }
+
+    next := types.NewRowSequence(context.Background(), seq2)
+
+    // Should iterate zero times
+    count := 0
+    for _, err := range next {
+        require.NoError(t, err)
+        count++
+    }
+    assert.Equal(t, 0, count)
+}
+
+func TestRowIterator_ErrorPropagation(t *testing.T) {
+    testErr := errors.New("test error")
+
+    // Create Seq2 that yields one record then an error
+    seq2 := func(yield func(arrow.Record, error) bool) {
+        record := createTestRecord([]string{"row1"})
+        record.Retain() // Consumer will release
+        if !yield(record, nil) {
+            record.Release() // Clean up if yield returns false
+            return
+        }
+        yield(nil, testErr)
+    }
+
+    next := types.NewRowSequence(context.Background(), seq2)
+
+    var rows []types.Row
+    var gotError error
+
+    for row, err := range next {
+        if err != nil {
+            gotError = err
+            break
+        }
+        rows = append(rows, row)
+    }
+
+    // Should have read first row successfully
+    assert.Len(t, rows, 1)
+    assert.Equal(t, "row1", rows[0].At(0))
+
+    // Should have received the error
+    assert.Equal(t, testErr, gotError)
+}
+
+func TestRowIterator_ContextCancellation(t *testing.T) {
+    ctx, cancel := context.WithCancel(context.Background())
+    t.Cleanup(cancel)
+
+    // Create a Seq2 that yields records indefinitely
+    seq2 := func(yield func(arrow.Record, error) bool) {
+        for {
+            select {
+            case <-ctx.Done():
+                yield(nil, ctx.Err())
+                return
+            default:
+                record := createTestRecord([]string{"row"})
+                record.Retain() // Consumer will release
+                if !yield(record, nil) {
+                    record.Release()
+                    return
+                }
+            }
+        }
+    }
+
+    next := types.NewRowSequence(ctx, seq2)
+
+    var rows []types.Row
+    count := 0
+
+    for row, err := range next {
+        if err != nil {
+            assert.ErrorIs(t, err, context.Canceled)
+            break
+        }
+        rows = append(rows, row)
+        count++
+
+        // Cancel after first row
+        if count == 1 {
+            cancel()
+        }
+
+        if count > 10 {
+            break
+        }
+    }
+
+    assert.GreaterOrEqual(t, len(rows), 1)
+    assert.Equal(t, "row", rows[0].At(0))
+}
+
+func TestRowIterator_EarlyBreak(t *testing.T) {
+    // Create multiple records
+    records := []arrow.Record{
+        createTestRecord([]string{"row1"}),
+        createTestRecord([]string{"row2"}),
+        createTestRecord([]string{"row3"}),
+    }
+
+    // Clean up records after test
+    defer func() {
+        for _, r := range records {
+            r.Release()
+        }
+    }()
+
+    seq2 := createTestSeq2(records, nil)
+
+    next := types.NewRowSequence(context.Background(), seq2)
+
+    // Read only one row then break
+    var rows []types.Row
+    for row, err := range next {
+        require.NoError(t, err)
+        rows = append(rows, row)
+        if len(rows) >= 1 {
+            break // Early termination
+        }
+    }
+
+    // Should have only one row
+    assert.Len(t, rows, 1)
+    assert.Equal(t, "row1", rows[0].At(0))
+}
+
+func TestRowIterator_EmptyBatchHandling(t *testing.T) {
+    // Test handling of empty records (0 rows but valid record)
+    emptyRecord := createTestRecord([]string{}) // No rows
+    validRecord := createTestRecord([]string{"row1"})
+
+    records := []arrow.Record{emptyRecord, validRecord}
+    defer func() {
+        for _, r := range records {
+            r.Release()
+        }
+    }()
+
+    seq2 := createTestSeq2(records, nil)
+    next := types.NewRowSequence(context.Background(), seq2)
+
+    // Should skip empty batch and return row from second batch
+    var rows []types.Row
+    for row, err := range next {
+        require.NoError(t, err)
+        rows = append(rows, row)
+    }
+
+    assert.Len(t, rows, 1)
+    assert.Equal(t, "row1", rows[0].At(0))
+}
+
+func TestRowIterator_DatabricksEOFBehavior(t *testing.T) {
+    // Test Databricks-specific behavior where io.EOF is sent as an error
+    // value rather than just ending the sequence. NewRowSequence treats
+    // io.EOF as clean termination.
+    seq2 := func(yield func(arrow.Record, error) bool) {
+        record1 := createTestRecord([]string{"row1", "row2"})
+        record1.Retain()
+        if !yield(record1, nil) {
+            record1.Release()
+            return
+        }
+
+        record2 := createTestRecord([]string{"row3"})
+        record2.Retain()
+        if !yield(record2, nil) {
+            record2.Release()
+            return
+        }
+
+        // Databricks sends io.EOF as an error; the sequence should terminate cleanly
+        yield(nil, io.EOF)
+    }
+
+    next := types.NewRowSequence(context.Background(), seq2)
+
+    // Read all rows successfully
+    var rows []types.Row
+    for row, err := range next {
+        require.NoError(t, err)
+        rows = append(rows, row)
+    }
+
+    // Should have all 3 rows
+    assert.Len(t, rows, 3)
+    assert.Equal(t, "row1", rows[0].At(0))
+    assert.Equal(t, "row2", rows[1].At(0))
+    assert.Equal(t, "row3", rows[2].At(0))
+}
+
+func TestRowIterator_NilRecordReturnsError(t *testing.T) {
+    // Test that receiving a nil record returns an error
+    seq2 := func(yield func(arrow.Record, error) bool) {
+        record := createTestRecord([]string{"row1"})
+        record.Retain()
+        if !yield(record, nil) {
+            record.Release()
+            return
+        }
+
+        // Yield nil record (shouldn't happen in production)
+        yield(nil, nil)
+    }
+
+    next := types.NewRowSequence(context.Background(), seq2)
+
+    var rows []types.Row
+    var gotError error
+
+    for row, err := range next {
+        if err != nil {
+            gotError = err
+            break
+        }
+        rows = append(rows, row)
+    }
+
+    // Should have read first row successfully
+    assert.Len(t, rows, 1)
+    assert.Equal(t, "row1", rows[0].At(0))
+
+    // Should have received error about nil record
+    assert.Error(t, gotError)
+    assert.Contains(t, gotError.Error(), "expected non-nil arrow.Record, got nil")
+}
+
+func TestRowSeq2_DirectUsage(t *testing.T) {
+    // Test using NewRowSequence directly as a Seq2
+    records := []arrow.Record{
+        createTestRecord([]string{"row1", "row2"}),
+        createTestRecord([]string{"row3"}),
+    }
+
+    defer func() {
+        for _, r := range records {
+            r.Release()
+        }
+    }()
+
+    recordSeq := createTestSeq2(records, nil)
+    rowSeq := types.NewRowSequence(context.Background(), recordSeq)
+
+    var rows []types.Row
+    for row, err := range rowSeq {
+        require.NoError(t, err)
+        rows = append(rows, row)
+    }
+
+    // Should have all 3 rows flattened
+    assert.Len(t, rows, 3)
+    assert.Equal(t, "row1", rows[0].At(0))
+    assert.Equal(t, "row2", rows[1].At(0))
+    assert.Equal(t, "row3", rows[2].At(0))
+}
+
+func TestRowIterator_MultipleIterations(t *testing.T) {
+    // Test that ranging the same iterator twice works safely when the
+    // upstream is single-use (like a real gRPC stream).
+    records := []arrow.Record{
+        createTestRecord([]string{"row1", "row2"}),
+    }
+
+    defer func() {
+        for _, r := range records {
+            r.Release()
+        }
+    }()
+
+    // Build a single-use upstream to simulate a gRPC stream.
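+    // A real ExecutePlan response stream can only be consumed once; after it
+    // is drained, re-ranging the derived row sequence must yield nothing
+    // rather than replay rows.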
+    exhausted := false
+    seq2 := func(yield func(arrow.Record, error) bool) {
+        if exhausted {
+            return
+        }
+        exhausted = true
+        for _, record := range records {
+            record.Retain()
+            if !yield(record, nil) {
+                return
+            }
+        }
+    }
+
+    next := types.NewRowSequence(context.Background(), seq2)
+
+    var rows1 []types.Row
+    for row, err := range next {
+        require.NoError(t, err)
+        rows1 = append(rows1, row)
+    }
+    assert.Len(t, rows1, 2)
+
+    // Second iteration: upstream exhausted, should yield nothing
+    var rows2 []types.Row
+    for row, err := range next {
+        require.NoError(t, err)
+        rows2 = append(rows2, row)
+    }
+    assert.Len(t, rows2, 0)
+}
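+
+// The helper below is an illustrative sketch, not exercised by the suite:
+// it shows how a caller might bridge an upstream record sequence (for
+// example, the one returned by ExecuteResponseStream.ToRecordSequence)
+// into a plain slice of rows via NewRowSequence. The name collectRowsSketch
+// is hypothetical.
+func collectRowsSketch(ctx context.Context, recs iter.Seq2[arrow.Record, error]) ([]types.Row, error) {
+    var out []types.Row
+    for row, err := range types.NewRowSequence(ctx, recs) {
+        if err != nil {
+            return nil, err // first error ends the sequence
+        }
+        out = append(out, row)
+    }
+    return out, nil
+}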