diff --git a/mario/data_extractor.py b/mario/data_extractor.py index 31d5d46..268d905 100644 --- a/mario/data_extractor.py +++ b/mario/data_extractor.py @@ -410,7 +410,6 @@ def stream_sql_to_hyper(self, file_path: str, **kwargs): frame_to_hyper(df, database=file_path, table=table_name, table_mode='a') def stream_sql_query_to_csv(self, file_path, query, connection, row_counter=0, **kwargs) -> int: - from mario.query_builder import get_formatted_query options = CsvOptions(**kwargs) if options.compress_using_gzip: compression_options = dict(method='gzip') @@ -421,7 +420,7 @@ def stream_sql_query_to_csv(self, file_path, query, connection, row_counter=0, * mode = 'w' header = True - for df in pd.read_sql(get_formatted_query(query[0], query[1]), connection, chunksize=options.chunk_size): + for df in pd.read_sql(sql=query[0], params=query[1], con=connection, chunksize=options.chunk_size): if options.validate or options.minimise: self._data = df if options.validate: @@ -645,7 +644,6 @@ def stream_sql_to_hyper(self, file_path: str, **kwargs): logger.info("Executing query") from tableauhyperapi import TableName from pantab import frame_to_hyper - from mario.query_builder import get_formatted_query options = HyperOptions(**kwargs) connection = self.get_connection() @@ -655,7 +653,7 @@ def stream_sql_to_hyper(self, file_path: str, **kwargs): for partition_value in self.__get_partition_values__(): query = self.__build_query_using_partition__(partition_value=partition_value) - for df in pd.read_sql(get_formatted_query(query[0], query[1]), connection, chunksize=options.chunk_size): + for df in pd.read_sql(sql=query[0], params=query[1], con=connection, chunksize=options.chunk_size): if options.validate or options.minimise: self._data = df if options.validate: diff --git a/test/test_data_extractor.py b/test/test_data_extractor.py index 0e146cf..cee500b 100644 --- a/test/test_data_extractor.py +++ b/test/test_data_extractor.py @@ -724,3 +724,75 @@ def 
test_partitioning_extractor_with_row_numbers(): columns = pd.read_csv(csv_file).columns # Check row_number omitted in CSV output assert 'row_number' not in columns + + +def test_partitioning_extractor_with_row_numbers_apostrophes(): + # Skip this test if we don't have a connection string + if not os.environ.get('CONNECTION_STRING'): + pytest.skip("Skipping SQL test as no database configured") + + from mario.hyper_utils import get_column_list + + dataset = dataset_from_json(os.path.join('test', 'dataset.json')) + constraint = Constraint() + constraint.item = 'Product Name' + constraint.allowed_values = [ + "Honeywell Enviracaire Portable HEPA Air Cleaner for 17' x 22' Room", + "Global Manager's Adjustable Task Chair, Storm" + ] + + dataset.constraints.append(constraint) + constraint2 = Constraint() + constraint2.item = 'City' + constraint2.allowed_values = ['Jacksonville', 'Laredo', 'Springfield'] + dataset.constraints.append(constraint2) + metadata = metadata_from_json(os.path.join('test', 'metadata.json')) + configuration = Configuration( + connection_string=os.environ.get('CONNECTION_STRING'), + schema="dev", + view="superstore", + query_builder=SubsetQueryBuilder + ) + extractor = PartitioningExtractor( + configuration=configuration, + dataset_specification=dataset, + metadata=metadata, + partition_column='City' + ) + + path = os.path.join('output', 'test_partitioning_extractor_with_row_numbers_apostrophes') + os.makedirs(path, exist_ok=True) + + file = os.path.join(path, 'test.hyper') + csv_file = os.path.join(path, 'test.csv') + + # drop existing (use a distinct loop variable so the output-dir 'path' is not shadowed) + for stale in [file, csv_file]: + if os.path.exists(stale): + os.remove(stale) + + extractor.stream_sql_to_hyper( + file_path=file, + include_row_numbers=True + ) + + # Check row_number is in hyper output + assert 'row_number' in get_column_list(hyper_file_path=file, schema='Extract', table='Extract') + + # Load it up and export a CSV + hyper_config = Configuration( + file_path=file + ) + hyper = HyperFile( + 
configuration=hyper_config, + dataset_specification=dataset, + metadata=metadata + ) + + hyper.save_data_as_csv( + file_path=csv_file, + compress_using_gzip=False + ) + columns = pd.read_csv(csv_file).columns + # Check row_number omitted in CSV output + assert 'row_number' not in columns