diff --git a/internal/tests/integration/dataframe_test.go b/internal/tests/integration/dataframe_test.go index d383ca1..69ee10d 100644 --- a/internal/tests/integration/dataframe_test.go +++ b/internal/tests/integration/dataframe_test.go @@ -19,6 +19,7 @@ import ( "context" "fmt" "os" + "path/filepath" "testing" "github.com/apache/spark-connect-go/spark/sql/utils" @@ -805,6 +806,33 @@ func TestDataFrame_WithOption(t *testing.T) { assert.Equal(t, int64(10), c) } +func TestDataFrame_WriteWithOption(t *testing.T) { + ctx, spark := connect() + df, err := spark.CreateDataFrame(ctx, [][]any{{1, "a"}, {2, "b"}}, types.StructOf( + types.NewStructField("f1-i32", types.INTEGER), + types.NewStructField("f2-string", types.STRING)), + ) + assert.NoError(t, err) + c, err := df.Count(ctx) + assert.NoError(t, err) + assert.Equal(t, int64(2), c) + outDir, err := os.MkdirTemp("", "example.out") + assert.NoError(t, err) + outfilePath := filepath.Join(outDir, "example.csv") + defer os.RemoveAll(outDir) + err = df.Writer().Format("csv"). + Option("header", "true"). + Save(ctx, outfilePath) + assert.NoError(t, err) + verifyDf, err := spark.Read().Format("csv"). + Option("header", "true"). + Load(outfilePath) + assert.NoError(t, err) + c, err = verifyDf.Count(ctx) + assert.NoError(t, err) + assert.Equal(t, int64(2), c) +} + func TestDataFrame_Sample(t *testing.T) { ctx, spark := connect() df, err := spark.Sql(ctx, "select * from range(100)") diff --git a/spark/sql/dataframewriter.go b/spark/sql/dataframewriter.go index 8c096f8..d663c94 100644 --- a/spark/sql/dataframewriter.go +++ b/spark/sql/dataframewriter.go @@ -33,12 +33,14 @@ type DataFrameWriter interface { Format(source string) DataFrameWriter // Save writes data frame to the given path. 
Save(ctx context.Context, path string) error + Option(key, value string) DataFrameWriter } func newDataFrameWriter(sparkExecutor *sparkSessionImpl, relation *proto.Relation) DataFrameWriter { return &dataFrameWriterImpl{ sparkExecutor: sparkExecutor, relation: relation, + options: nil, } } @@ -48,6 +50,7 @@ type dataFrameWriterImpl struct { relation *proto.Relation saveMode string formatSource string + options map[string]string } func (w *dataFrameWriterImpl) Mode(saveMode string) DataFrameWriter { @@ -80,6 +83,7 @@ func (w *dataFrameWriterImpl) Save(ctx context.Context, path string) error { SaveType: &proto.WriteOperation_Path{ Path: path, }, + Options: w.options, }, }, }, @@ -94,6 +98,14 @@ func (w *dataFrameWriterImpl) Save(ctx context.Context, path string) error { return err } +func (w *dataFrameWriterImpl) Option(key, value string) DataFrameWriter { + if w.options == nil { + w.options = make(map[string]string) + } + w.options[key] = value + return w +} + func getSaveMode(mode string) (proto.WriteOperation_SaveMode, error) { if mode == "" { return proto.WriteOperation_SAVE_MODE_UNSPECIFIED, nil diff --git a/spark/sql/dataframewriter_test.go b/spark/sql/dataframewriter_test.go index bc85f65..8abde16 100644 --- a/spark/sql/dataframewriter_test.go +++ b/spark/sql/dataframewriter_test.go @@ -66,6 +66,7 @@ func TestSaveExecutesWriteOperationUntilEOF(t *testing.T) { writer := newDataFrameWriter(session, relation) writer.Format("format") writer.Mode("append") + writer.Option("foo", "bar") err := writer.Save(ctx, path) assert.NoError(t, err) }