-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
311 lines (254 loc) · 13 KB
/
main.py
File metadata and controls
311 lines (254 loc) · 13 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
#!/usr/bin/env python
"""
Data Analyzer - Main Application
This script provides a command-line interface for analyzing and visualizing
data from CSV files using the components from the data_analyzer package.
"""
import os
import sys
import argparse
import pandas as pd
from datetime import datetime
# Import project modules
from src.data_loader import DataLoader
from src.analyzer import DataAnalyzer
from src.visualizer import DataVisualizer
def parse_args():
    """Define the command-line interface and parse ``sys.argv``.

    The single positional argument is the CSV file to analyze; all other
    options are grouped into analysis, visualization and output sections.

    Returns:
        argparse.Namespace with one attribute per option (dashes in long
        option names become underscores, e.g. ``--n-top`` -> ``n_top``).
    """
    cli = argparse.ArgumentParser(description='Data Analysis Tool')
    cli.add_argument('file', help='Path to the CSV file to analyze')

    # --- analysis ----------------------------------------------------------
    analysis = cli.add_argument_group('Analysis Options')
    analysis.add_argument(
        '--analysis', '-a',
        default='summary',
        choices=['summary', 'time-series', 'distribution', 'top-categories',
                 'customer-segments', 'customer-metrics', 'correlation'],
        help='Type of analysis to perform (default: summary)',
    )
    analysis.add_argument('--groupby', '-g', help='Column to group by for analysis')
    analysis.add_argument(
        '--n-top', '-n',
        type=int,
        default=5,
        help='Number of top items to show (default: 5)',
    )
    analysis.add_argument(
        '--frequency', '-f',
        default='M',
        choices=['D', 'W', 'M', 'Q', 'Y'],
        help='Frequency for time-series analysis (default: M)',
    )
    # Free-form filter options; the tuple order is the order they appear in --help.
    filter_options = (
        ('--start-date', 'Start date for filtering (YYYY-MM-DD)'),
        ('--end-date', 'End date for filtering (YYYY-MM-DD)'),
        ('--category', 'Filter by category'),
        ('--customer', 'Filter by customer ID'),
    )
    for flag, text in filter_options:
        analysis.add_argument(flag, help=text)

    # --- visualization -----------------------------------------------------
    viz = cli.add_argument_group('Visualization Options')
    viz.add_argument(
        '--plot', '-p',
        default='bar',
        choices=['bar', 'line', 'pie', 'heatmap', 'histogram', 'box', 'scatter'],
        help='Type of plot to create (default: bar)',
    )
    label_options = (
        ('--x-column', 'Column for X-axis in scatter plot'),
        ('--y-column', 'Column for Y-axis in scatter plot'),
        ('--title', 'Title for the plot'),
        ('--xlabel', 'Label for X-axis'),
        ('--ylabel', 'Label for Y-axis'),
    )
    for flag, text in label_options:
        viz.add_argument(flag, help=text)
    viz.add_argument('--color', default='steelblue', help='Color for the plot')
    viz.add_argument('--horizontal', action='store_true',
                     help='Create horizontal bar chart')
    viz.add_argument('--figsize', default='10,6',
                     help='Figure size in inches (width,height)')

    # --- output --------------------------------------------------------------
    out = cli.add_argument_group('Output Options')
    out.add_argument('--output', '-o', help='Directory to save output files')
    out.add_argument(
        '--format',
        default='png',
        choices=['png', 'jpg', 'svg', 'pdf'],
        help='Format for saving plots (default: png)',
    )
    out.add_argument('--dpi', type=int, default=300,
                     help='DPI for saving plots (default: 300)')
    out.add_argument('--no-display', action='store_true',
                     help='Do not display plots, only save them')

    return cli.parse_args()
def _parse_figsize(text):
    """Convert a ``"width,height"`` string (inches) into a ``(width, height)`` tuple.

    Falls back to ``(10, 6)``, printing a warning to stderr, when *text* is not
    exactly two comma-separated, finite, positive numbers.
    """
    try:
        width, height = map(float, text.split(','))
        # The chained comparisons also reject nan and inf, which are not
        # usable figure dimensions.
        if not (0 < width < float('inf') and 0 < height < float('inf')):
            raise ValueError(f"figure size must be finite and positive: {text!r}")
    except (AttributeError, ValueError):
        # AttributeError: text is None; ValueError: a non-numeric part, the
        # wrong number of parts, or a non-positive/non-finite size.
        print("Warning: Invalid figsize format. Using default (10,6).", file=sys.stderr)
        return (10, 6)
    return (width, height)


def _run_analysis(analyzer, args):
    """Run the analysis named by ``args.analysis``, print it, and return its result.

    Raises:
        ValueError: if ``args.analysis`` is not a known analysis type.
    """
    by = args.groupby or 'category'
    # Each entry maps an analysis name to (heading printed before the result,
    # zero-argument callable). The lambdas keep the calls lazy, so only the
    # selected analysis runs.
    analyses = {
        'summary': (
            "Summary Statistics:",
            lambda: analyzer.get_summary_statistics(groupby=args.groupby),
        ),
        'time-series': (
            f"Time Series Analysis (Frequency: {args.frequency}):",
            lambda: analyzer.analyze_time_series(freq=args.frequency, groupby=args.groupby),
        ),
        'distribution': (
            f"Spending Distribution by {by}:",
            lambda: analyzer.get_spending_distribution(by=by),
        ),
        'top-categories': (
            f"Top {args.n_top} Spending Categories:",
            lambda: analyzer.get_top_spending_categories(n=args.n_top),
        ),
        'customer-segments': ("Customer Segments:", lambda: analyzer.segment_customers()),
        'customer-metrics': ("Customer Metrics:", lambda: analyzer.calculate_customer_metrics()),
        'correlation': ("Category Correlation:", lambda: analyzer.get_category_correlation()),
    }
    try:
        heading, run = analyses[args.analysis]
    except KeyError:
        raise ValueError(f"Unknown analysis type: {args.analysis}") from None
    results = run()
    print(f"\n{heading}")
    print(results)
    return results


def _build_save_path(visualizer, args):
    """Return the file path to save the plot to, or ``None`` when ``--output`` is unset.

    The output directory is obtained from ``visualizer.create_output_dir``. The
    file name is ``<analysis>_<YYYYmmdd_HHMMSS>.<format>``.
    """
    if not args.output:
        return None
    output_dir = visualizer.create_output_dir(args.output)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    return os.path.join(output_dir, f"{args.analysis}_{timestamp}.{args.format}")


def _create_plot(visualizer, results, args, title, save_path):
    """Draw *results* as the plot type selected by ``args.plot``.

    Axis labels get per-plot defaults when ``--xlabel``/``--ylabel`` are not
    given. A confirmation line is printed once the plot has been created.

    Returns:
        True if the plot was created. False if *results* cannot be drawn as the
        requested plot type; the reason is printed to stderr.

    Raises:
        ValueError: if ``args.plot`` is not a known plot type.
    """
    kind = args.plot
    if kind == 'bar':
        visualizer.bar_chart(
            data=results,
            title=title,
            xlabel=args.xlabel or args.groupby or 'Category',
            ylabel=args.ylabel or 'Amount',
            color=args.color,
            save_path=save_path,
            horizontal=args.horizontal,
        )
        label = 'bar chart'
    elif kind == 'line':
        visualizer.line_chart(
            data=results,
            title=title,
            xlabel=args.xlabel or 'Date',
            ylabel=args.ylabel or 'Amount',
            color=args.color,
            save_path=save_path,
        )
        label = 'line chart'
    elif kind == 'pie':
        visualizer.pie_chart(data=results, title=title, save_path=save_path)
        label = 'pie chart'
    elif kind == 'heatmap':
        visualizer.heatmap(data=results, title=title, save_path=save_path)
        label = 'heatmap'
    elif kind == 'histogram':
        # A histogram takes a single Series. Use --x-column if it names a
        # column of a DataFrame result; otherwise use that DataFrame's first column.
        if isinstance(results, pd.DataFrame):
            if args.x_column and args.x_column in results.columns:
                series = results[args.x_column]
            else:
                series = results.iloc[:, 0]
        else:
            series = results
        visualizer.histogram(
            data=series,
            title=title,
            xlabel=args.xlabel or 'Value',
            ylabel=args.ylabel or 'Frequency',
            color=args.color,
            save_path=save_path,
        )
        label = 'histogram'
    elif kind == 'box':
        visualizer.box_plot(
            data=results,
            title=title,
            xlabel=args.xlabel or 'Category',
            ylabel=args.ylabel or 'Value',
            color=args.color,
            save_path=save_path,
        )
        label = 'box plot'
    elif kind == 'scatter':
        if args.x_column and args.y_column:
            x_column, y_column = args.x_column, args.y_column
        elif isinstance(results, pd.DataFrame) and len(results.columns) >= 2:
            # Without both explicit columns, plot the first two result columns.
            x_column, y_column = results.columns[0], results.columns[1]
        else:
            print("Error: Scatter plot requires x_column and y_column arguments "
                  "or a DataFrame with at least 2 columns.", file=sys.stderr)
            return False
        if not isinstance(results, pd.DataFrame):
            print("Error: Scatter plot requires a DataFrame.", file=sys.stderr)
            return False
        if x_column not in results.columns or y_column not in results.columns:
            print(f"Error: Columns '{x_column}' and/or '{y_column}' not found in results.",
                  file=sys.stderr)
            return False
        visualizer.scatter_plot(
            x=results[x_column],
            y=results[y_column],
            title=title,
            xlabel=args.xlabel or x_column,
            ylabel=args.ylabel or y_column,
            color=args.color,
            save_path=save_path,
        )
        label = 'scatter plot'
    else:
        raise ValueError(f"Unknown plot type: {kind}")

    suffix = f" and saved to {save_path}" if save_path else ''
    print(f"\nCreated {label}{suffix}.")
    return True


def main():
    """Run the data-analysis command-line tool.

    Loads the CSV named on the command line and applies the optional
    date-range, category and customer filters. It then runs the selected
    analysis and prints the result. Finally it draws the selected plot,
    saving it when ``--output`` is given and showing it unless
    ``--no-display`` is set.

    Returns:
        Process exit status: 0 on success, 1 on any error. Errors include
        plot inputs that cannot be drawn, e.g. a scatter plot without two
        usable columns.
    """
    args = parse_args()
    try:
        loader = DataLoader()
        print(f"Loading data from {args.file}...")
        data = loader.load_data(args.file)
        print(f"Loaded {len(data)} rows and {len(data.columns)} columns.")

        # Filters are called on the loader, and their return value replaces `data`.
        # NOTE(review): whether successive filters compound depends on
        # DataLoader's internal state; confirm.
        if args.start_date or args.end_date:
            data = loader.filter_by_date_range(args.start_date, args.end_date)
            print(f"Filtered by date range: {len(data)} rows remaining.")
        if args.category:
            data = loader.filter_by_category(args.category)
            print(f"Filtered by category '{args.category}': {len(data)} rows remaining.")
        if args.customer:
            data = loader.filter_by_customer(args.customer)
            print(f"Filtered by customer '{args.customer}': {len(data)} rows remaining.")

        results = _run_analysis(DataAnalyzer(data), args)
        if results is None:
            return 0
        if isinstance(results, (pd.DataFrame, pd.Series)) and results.empty:
            # Nothing to draw (e.g. the filters removed every row).
            print("\nNo results to visualize.")
            return 0

        visualizer = DataVisualizer(figsize=_parse_figsize(args.figsize))
        title = args.title or (
            f"{args.analysis.replace('-', ' ').title()} - {os.path.basename(args.file)}"
        )
        # NOTE(review): --dpi is parsed but never passed to the visualizer, so it
        # appears to have no effect on saved plots; confirm against DataVisualizer.
        save_path = _build_save_path(visualizer, args)
        if not _create_plot(visualizer, results, args, title, save_path):
            return 1

        if not args.no_display:
            import matplotlib.pyplot as plt
            plt.show()
    except Exception as e:  # top-level boundary: report any failure as exit status 1
        print(f"Error: {e}", file=sys.stderr)
        return 1
    return 0
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == '__main__':
    raise SystemExit(main())