diff --git a/CHANGELOG.md b/CHANGELOG.md index eea6446a48..3bbf6ea87f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ #### Bug Fixes +- Fixed a bug where the `TRY_CAST` reader option was ignored when calling `DataFrameReader.schema().csv()`. - Fixed a Snowflake platform compatibility issue (SNOW-3259059) where `concat(lit('"'), ...)` could lose the leading quote through some `EXCEPT` / chained set-operation plans by lowering that literal to `CHR(34)` in generated SQL. - Fixed a bug where chained `DataFrame.filter()` calls with raw SQL text containing `OR` produced incorrect results. diff --git a/src/snowflake/snowpark/dataframe_reader.py b/src/snowflake/snowpark/dataframe_reader.py index daf0f5d3f6..99d6ef5e6e 100644 --- a/src/snowflake/snowpark/dataframe_reader.py +++ b/src/snowflake/snowpark/dataframe_reader.py @@ -797,6 +797,7 @@ def csv(self, path: str, _emit_ast: bool = True) -> DataFrame: self._file_type = "CSV" schema_to_cast, transformations = None, None + use_user_schema = False if not self._user_schema: if not self._infer_schema: @@ -833,7 +834,19 @@ def csv(self, path: str, _emit_ast: bool = True) -> DataFrame: transformations = [] else: self._cur_options["INFER_SCHEMA"] = False - schema = self._user_schema._to_attributes() + try_cast = self._cur_options.get("TRY_CAST", False) + if try_cast: + ( + schema, + schema_to_cast, + transformations, + ) = self._get_schema_from_csv_user_input( + self._user_schema, + try_cast, + ) + use_user_schema = True + else: + schema = self._user_schema._to_attributes() metadata_project, metadata_schema = self._get_metadata_project_and_schema() @@ -859,6 +872,7 @@ def csv(self, path: str, _emit_ast: bool = True) -> DataFrame: transformations=transformations, metadata_project=metadata_project, metadata_schema=metadata_schema, + use_user_schema=use_user_schema, ), analyzer=self._session._analyzer, ), @@ -879,6 +893,7 @@ def csv(self, path: str, _emit_ast: bool = True) -> DataFrame: transformations=transformations, 
metadata_project=metadata_project, metadata_schema=metadata_schema, + use_user_schema=use_user_schema, ), _ast_stmt=stmt, _emit_ast=_emit_ast, @@ -1387,6 +1402,37 @@ def _infer_schema_for_file_format( return new_schema, schema_to_cast, read_file_transformations, None + def _get_schema_from_csv_user_input( + self, user_schema: StructType, try_cast: bool + ) -> Tuple[List, List, List]: + """ + This function accepts a user-provided StructType and returns the schemas needed for reading a CSV file. + CSV files are processed differently from semi-structured files, so they need a different helper function. + """ + schema_to_cast = [] + transformations = [] + new_schema = [] + for index, field in enumerate(user_schema.fields, start=1): + new_schema.append( + Attribute( + field.column_identifier.quoted_name, + field.datatype, + field.nullable, + ) + ) + sf_type = convert_sp_to_sf_type(field.datatype) + source_column = f"${index}" + identifier = ( + f"TRY_CAST({source_column} AS {sf_type})" + if try_cast + else f"{source_column}::{sf_type}" + ) + schema_to_cast.append((identifier, field.name)) + transformations.append(sql_expr(identifier)) + + read_file_transformations = [t._expression.sql for t in transformations] + return new_schema, schema_to_cast, read_file_transformations + def _get_schema_from_user_input( self, user_schema: StructType ) -> Tuple[List, List, List]: diff --git a/tests/integ/scala/test_dataframe_reader_suite.py b/tests/integ/scala/test_dataframe_reader_suite.py index be91633c10..2e6b855c53 100644 --- a/tests/integ/scala/test_dataframe_reader_suite.py +++ b/tests/integ/scala/test_dataframe_reader_suite.py @@ -375,6 +375,25 @@ def test_read_csv(session, mode): assert "is out of range" in str(ex_info.value) +@pytest.mark.parametrize("mode", ["select", "copy"]) +def test_read_csv_with_user_schema_try_cast(session, mode): + reader = get_reader(session, mode) + test_file_on_stage = f"@{tmp_stage_name1}/{test_file_csv}" + try_cast_schema = StructType( + [ + StructField("a", 
IntegerType()), + StructField("b", IntegerType()), + StructField("c", DoubleType()), + ] + ) + df_try_cast = ( + reader.schema(try_cast_schema).option("TRY_CAST", True).csv(test_file_on_stage) + ) + try_cast_res = df_try_cast.collect() + try_cast_res.sort(key=lambda x: x[0]) + assert try_cast_res == [Row(1, None, 1.2), Row(2, None, 2.2)] + + @pytest.mark.xfail( "config.getoption('local_testing_mode', default=False)", reason="SNOW-1435112: csv infer schema option is not supported",