Merged
mario/data_extractor.py (6 changes: 2 additions & 4 deletions)

@@ -410,7 +410,6 @@ def stream_sql_to_hyper(self, file_path: str, **kwargs):
             frame_to_hyper(df, database=file_path, table=table_name, table_mode='a')
 
     def stream_sql_query_to_csv(self, file_path, query, connection, row_counter=0, **kwargs) -> int:
-        from mario.query_builder import get_formatted_query
         options = CsvOptions(**kwargs)
         if options.compress_using_gzip:
             compression_options = dict(method='gzip')
@@ -421,7 +420,7 @@ def stream_sql_query_to_csv(self, file_path, query, connection, row_counter=0, **kwargs) -> int:
             mode = 'w'
             header = True
 
-        for df in pd.read_sql(get_formatted_query(query[0], query[1]), connection, chunksize=options.chunk_size):
+        for df in pd.read_sql(sql=query[0], params=query[1], con=connection, chunksize=options.chunk_size):
             if options.validate or options.minimise:
                 self._data = df
                 if options.validate:
@@ -645,7 +644,6 @@ def stream_sql_to_hyper(self, file_path: str, **kwargs):
         logger.info("Executing query")
         from tableauhyperapi import TableName
         from pantab import frame_to_hyper
-        from mario.query_builder import get_formatted_query
 
         options = HyperOptions(**kwargs)
         connection = self.get_connection()
@@ -655,7 +653,7 @@ def stream_sql_to_hyper(self, file_path: str, **kwargs):

         for partition_value in self.__get_partition_values__():
             query = self.__build_query_using_partition__(partition_value=partition_value)
-            for df in pd.read_sql(get_formatted_query(query[0], query[1]), connection, chunksize=options.chunk_size):
+            for df in pd.read_sql(sql=query[0], params=query[1], con=connection, chunksize=options.chunk_size):
                 if options.validate or options.minimise:
                     self._data = df
                     if options.validate:
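The substance of this change: instead of splicing the parameter values into the SQL text with get_formatted_query, the query template and its values are now passed to pandas.read_sql separately, and the database driver performs the parameter binding. That is what makes values containing apostrophes safe. A minimal sketch of the difference, assuming (as the diff suggests) that query[0] is the SQL text and query[1] its parameters; the in-memory SQLite database, table, and values below are illustrative only, not mario's actual schema:

    # Minimal sketch: driver-level parameter binding vs. string formatting.
    import sqlite3
    import pandas as pd

    con = sqlite3.connect(':memory:')
    con.execute("CREATE TABLE superstore (product_name TEXT, city TEXT)")
    con.execute("INSERT INTO superstore VALUES (?, ?)",
                ("Global Manager's Adjustable Task Chair, Storm", 'Laredo'))
    con.commit()

    sql = "SELECT * FROM superstore WHERE product_name = ? AND city = ?"
    params = ("Global Manager's Adjustable Task Chair, Storm", 'Laredo')

    # The driver binds the values, so the embedded apostrophe is escaped safely.
    for df in pd.read_sql(sql=sql, params=params, con=con, chunksize=100):
        print(df)

    # Formatting the value into the SQL string instead yields unbalanced quotes:
    #   SELECT ... WHERE product_name = 'Global Manager's Adjustable Task Chair, Storm'

Binding also closes a SQL-injection vector, since values are never concatenated into the statement text.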
test/test_data_extractor.py (72 changes: 72 additions & 0 deletions)

@@ -724,3 +724,75 @@ def test_partitioning_extractor_with_row_numbers():
     columns = pd.read_csv(csv_file).columns
     # Check row_number omitted in CSV output
     assert 'row_number' not in columns
+
+
+def test_partitioning_extractor_with_row_numbers_apostrophes():
+    # Skip this test if we don't have a connection string
+    if not os.environ.get('CONNECTION_STRING'):
+        pytest.skip("Skipping SQL test as no database configured")
+
+    from mario.hyper_utils import get_column_list
+
+    dataset = dataset_from_json(os.path.join('test', 'dataset.json'))
+    constraint = Constraint()
+    constraint.item = 'Product Name'
+    constraint.allowed_values = [
+        "Honeywell Enviracaire Portable HEPA Air Cleaner for 17' x 22' Room",
+        "Global Manager's Adjustable Task Chair, Storm"
+    ]
+
+    dataset.constraints.append(constraint)
+    constraint2 = Constraint()
+    constraint2.item = 'City'
+    constraint2.allowed_values = ['Jacksonville', 'Laredo', 'Springfield']
+    dataset.constraints.append(constraint2)
+    metadata = metadata_from_json(os.path.join('test', 'metadata.json'))
+    configuration = Configuration(
+        connection_string=os.environ.get('CONNECTION_STRING'),
+        schema="dev",
+        view="superstore",
+        query_builder=SubsetQueryBuilder
+    )
+    extractor = PartitioningExtractor(
+        configuration=configuration,
+        dataset_specification=dataset,
+        metadata=metadata,
+        partition_column='City'
+    )
+
+    path = os.path.join('output', 'test_partitioning_extractor_with_row_numbers_apostrophes')
+    os.makedirs(path, exist_ok=True)
+
+    file = os.path.join(path, 'test.hyper')
+    csv_file = os.path.join(path, 'test.csv')
+
+    # drop any existing output files
+    for existing in [file, csv_file]:
+        if os.path.exists(existing):
+            os.remove(existing)
+
+    extractor.stream_sql_to_hyper(
+        file_path=file,
+        include_row_numbers=True
+    )
+
+    # Check row_number is in hyper output
+    assert 'row_number' in get_column_list(hyper_file_path=file, schema='Extract', table='Extract')
+
+    # Load it up and export a CSV
+    hyper_config = Configuration(
+        file_path=file
+    )
+    hyper = HyperFile(
+        configuration=hyper_config,
+        dataset_specification=dataset,
+        metadata=metadata
+    )
+
+    hyper.save_data_as_csv(
+        file_path=csv_file,
+        compress_using_gzip=False
+    )
+    columns = pd.read_csv(csv_file).columns
+    # Check row_number omitted in CSV output
+    assert 'row_number' not in columns
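The test exercises the parameter-binding fix end to end: both constraint lists contain embedded apostrophes, which the old string-formatted queries would have turned into unbalanced quotes. One natural extra assertion, not part of the merged test and assuming the CSV column is named 'Product Name' as in the constraint, would confirm the apostrophe-bearing values themselves survive into the CSV:

    # Hypothetical follow-on check, not in the merged PR:
    df = pd.read_csv(csv_file)
    assert "Global Manager's Adjustable Task Chair, Storm" in set(df['Product Name'])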