-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathscript.py
More file actions
209 lines (186 loc) · 8.77 KB
/
script.py
File metadata and controls
209 lines (186 loc) · 8.77 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
"""
This module contains a dashboard application for visualizing COVID-19 data.
It includes the dashboard itself and methods for visualizing the data using Dash.
"""
import pandas as pd
import dash
from dash import dcc, html
from dash.dependencies import Input, Output
import plotly.graph_objs as go
import data_prep
class Visualisation:
    """
    Build the Plotly figures shown on the COVID-19 dashboard.

    All methods take a pandas DataFrame with a ``Country_region`` column plus
    whichever data column is being plotted (line plots also read
    ``Date_reported`` for the x axis).
    """

    # Styling shared by every figure so the four charts look alike.
    _COMMON_LAYOUT = {
        'font': {'family': 'Arial', 'size': 16, 'color': '#7f7f7f'},
        'plot_bgcolor': 'white',
        'paper_bgcolor': 'white',
    }

    def __init__(self):
        # Provides ``calculate_deaths_per_cases``, used by ``compute_data``.
        self.processor = data_prep.DataProcessor()

    @staticmethod
    def _as_country_list(selected_countries):
        """
        Return the selection unchanged, or an empty list when it is ``None``.

        A Dash dropdown value may be ``None`` when nothing is selected;
        iterating ``None`` would raise ``TypeError``.
        """
        return [] if selected_countries is None else selected_countries

    def generate_plots(self, df_filtered, selected):
        """
        Build the four dashboard figures for the selected countries.

        Returns ``(cases_line, deaths_line, rt_bar, deaths_per_cases_bar)``.
        """
        fig_cases = self.generate_line_plot(df_filtered, selected, "New_cases")
        fig_deaths = self.generate_line_plot(df_filtered, selected, "New_deaths")
        rt_plot = self.plot_bar(df_filtered, selected, "Rt")
        death_cases_plot = self.plot_bar(df_filtered, selected, "deaths_per_cases")
        return fig_cases, fig_deaths, rt_plot, death_cases_plot

    def generate_line_plot(self, df, selected_countries, data_col):
        """
        Build a line chart of ``data_col`` over ``Date_reported``, one trace per country.
        """
        fig = go.Figure()
        # e.g. "New_cases" -> "Cases", "New_deaths" -> "Deaths"
        name_of_graph = data_col.replace('_', ' ').replace('New ', '').capitalize()
        for country in self._as_country_list(selected_countries):
            country_data = df[df['Country_region'] == country]
            fig.add_trace(go.Scatter(x=country_data['Date_reported'],
                                     y=country_data[data_col],
                                     mode='lines', name=country))
        fig.update_layout(title=f'{name_of_graph} by Country',
                          yaxis_title=f'# of {name_of_graph}',
                          **self._COMMON_LAYOUT)
        return fig

    def plot_bar(self, df, selected_countries, data_col):
        """
        Build a bar chart of ``data_col`` with one bar (trace) per selected country.

        The data is first aggregated per country by ``compute_data``, which
        also supplies the chart title.
        """
        fig = go.Figure()
        df, title = self.compute_data(df, data_col)
        for country_region in self._as_country_list(selected_countries):
            country_data = df[df['Country_region'] == country_region]
            fig.add_trace(go.Bar(x=country_data['Country_region'],
                                 y=country_data[data_col],
                                 name=country_region))
        fig.update_layout(title=title, **self._COMMON_LAYOUT)
        return fig

    def compute_data(self, df, data_col):
        """
        Aggregate ``df`` per country for a bar chart of ``data_col``.

        Returns ``(data, title)``:
        - ``'Rt'``: mean ``Rt`` per country.
        - ``'deaths_per_cases'``: ``processor.calculate_deaths_per_cases`` applied
          per country, keeping the last resulting row of each country.
        - anything else: a message is printed and ``df`` is returned unchanged
          with an empty title.
        """
        if data_col == 'Rt':
            title = 'Average Rt number (Transmission rate)'
            df = df.groupby('Country_region').agg({data_col: 'mean'}).reset_index()
        elif data_col == 'deaths_per_cases':
            title = 'Deaths per cases'
            df = (df.groupby('Country_region')
                  .apply(self.processor.calculate_deaths_per_cases)
                  .reset_index(drop=True))
            # NOTE(review): assumes rows per country stay in date order so tail(1)
            # is the most recent value — confirm against calculate_deaths_per_cases.
            df = df.groupby('Country_region').tail(1)
        else:
            title = ""
            print(f"The column '{data_col}' either doesn't exist or we don't compute it yet.")
        return df, title
class CovidDashboard:
    """
    Set up and run a Dash application that visualizes COVID-19 data.

    Two CSV files (absolute and per-million normalized figures) are loaded at
    start-up. The user picks countries, a date range and whether to use the
    normalized data; a single callback redraws the four graphs.
    """

    # Input files read at start-up (relative to the working directory).
    ABSOLUTE_CSV = './processed_data/df_absolute.csv'
    NORMALIZED_CSV = './processed_data/df_normalized.csv'
    # Columns ``create_layout`` reads; the empty fallback frames carry them so the
    # layout can still be built when loading fails.
    REQUIRED_COLUMNS = ('Country_region', 'Date_reported')
    # Countries pre-selected in the dropdown (only those present in the data are used).
    DEFAULT_COUNTRIES = ('Switzerland', 'EURO', 'Türkiye')

    def __init__(self):
        self.df_absolute, self.df_normalized = self._load_data(self.ABSOLUTE_CSV,
                                                               self.NORMALIZED_CSV)
        self.visualize = Visualisation()
        self.app = dash.Dash(__name__)
        self.app.layout = self.create_layout()

    @classmethod
    def _empty_frame(cls):
        """Return an empty DataFrame that still has the columns the layout reads."""
        return pd.DataFrame(columns=list(cls.REQUIRED_COLUMNS))

    @classmethod
    def _load_data(cls, absolute_path, normalized_path):
        """
        Read the absolute and normalized CSVs and return them as a pair.

        If either read fails, an error message is printed and BOTH frames are
        replaced by empty frames (with ``REQUIRED_COLUMNS``) so the app still starts.
        """
        try:
            return pd.read_csv(absolute_path), pd.read_csv(normalized_path)
        except FileNotFoundError as e:
            print(f"Error: {e}. Ensure that the required CSV files are in the 'processed_data' directory.")
        except pd.errors.EmptyDataError as e:
            print(f"Error: {e}. The CSV files are empty. Please provide valid data.")
        except Exception as e:  # top-level start-up boundary: report and fall back
            print(f"An unexpected error occurred: {e}")
        return cls._empty_frame(), cls._empty_frame()

    @staticmethod
    def _filter_by_date(df, start_date, end_date):
        """
        Keep rows whose ``Date_reported`` lies in ``[start_date, end_date]``.

        A ``None`` bound leaves that side open. Comparison is done on the raw
        column values; NOTE(review): with CSV input these are strings, so this
        assumes ISO ``YYYY-MM-DD`` dates (which sort lexicographically) — confirm.
        """
        mask = pd.Series(True, index=df.index)
        if start_date is not None:
            mask &= df['Date_reported'] >= start_date
        if end_date is not None:
            mask &= df['Date_reported'] <= end_date
        return df[mask]

    def create_layout(self):
        """
        Build the component tree of the Dash application.

        Dropdown options and the initial date range come from ``self.df_absolute``;
        when no data was loaded they are empty / unset instead of failing.
        """
        countries = self.df_absolute['Country_region'].unique()
        present = set(countries)
        default_selection = [c for c in self.DEFAULT_COUNTRIES if c in present]
        dates = self.df_absolute['Date_reported']
        start_date = None if dates.empty else dates.min()
        end_date = None if dates.empty else dates.max()

        # Dash/React style keys are camelCase.
        graph_style = {'width': '100%', 'display': 'inline-block'}
        column_style = {'width': '48%', 'display': 'inline-block', 'verticalAlign': 'top'}

        return html.Div([
            html.H1('COVID-19 Dashboard', style={'textAlign': 'center', 'color': '#003366'}),
            html.Div(
                [
                    html.Div(
                        [
                            html.Label('Select Countries/Regions', style={'fontWeight': 'bold'}),
                            dcc.Dropdown(
                                id='country-dropdown',
                                options=[{'label': country, 'value': country}
                                         for country in countries],
                                value=default_selection,
                                multi=True,
                                style={'width': '100%'},
                            ),
                        ],
                        style={'flex': '1', 'marginRight': '10px'},
                    ),
                    html.Div(
                        [
                            html.Label('Select Timeframe', style={'fontWeight': 'bold'}),
                            dcc.DatePickerRange(
                                id='date-picker',
                                start_date=start_date,
                                end_date=end_date,
                                display_format='MM/YYYY',
                                style={'width': '100%'},
                            ),
                        ],
                        style={'flex': '1', 'marginRight': '10px'},
                    ),
                    html.Div(
                        [
                            html.Label('Normalize Data', style={'fontWeight': 'bold'}),
                            dcc.Checklist(
                                id='normalize-checklist',
                                options=[{'label': 'Normalize per one million inhabitants',
                                          'value': 'normalize'}],
                                value=['normalize'],
                                style={'marginTop': '10px'},
                            ),
                        ],
                        style={'flex': '1'},
                    ),
                ],
                id='control-container',
                style={'display': 'flex', 'flexDirection': 'row', 'padding': '10px'},
            ),
            html.Div(
                [
                    html.Div(
                        [
                            dcc.Graph(id='cases-graph', style=graph_style),
                            dcc.Graph(id='deaths-graph', style=graph_style),
                        ],
                        style=column_style,
                    ),
                    html.Div(
                        [
                            dcc.Graph(id='rt-graph', style=graph_style),
                            dcc.Graph(id='deaths-per-cases-graph', style=graph_style),
                        ],
                        style=column_style,
                    ),
                ],
                style={'display': 'flex', 'justifyContent': 'space-between'},
            ),
        ], style={'fontFamily': 'Arial, sans-serif', 'backgroundColor': '#f9f9f9', 'padding': '20px'})

    def register_callbacks(self):
        """
        Register the callback that redraws all four graphs from the controls.
        """
        @self.app.callback(
            [Output('cases-graph', 'figure'),
             Output('deaths-graph', 'figure'),
             Output('rt-graph', 'figure'),
             Output('deaths-per-cases-graph', 'figure')],
            [Input('country-dropdown', 'value'),
             Input('date-picker', 'start_date'),
             Input('date-picker', 'end_date'),
             Input('normalize-checklist', 'value')]
        )
        def update_graphs(selected, start_date, end_date, normalize_value):
            """Return the four figures for the current control values."""
            use_normalized = 'normalize' in (normalize_value or [])
            source = self.df_normalized if use_normalized else self.df_absolute
            if source.empty:
                # No data was loaded (see _load_data): show blank figures
                # rather than failing on missing data columns.
                return tuple(go.Figure() for _ in range(4))
            df_filtered = self._filter_by_date(source, start_date, end_date)
            return self.visualize.generate_plots(df_filtered, selected)

    def run(self):
        """
        Register the callbacks and start the Dash development server (debug mode).
        """
        self.register_callbacks()
        # ``run_server`` is deprecated and was removed in Dash 3; prefer ``run``
        # and fall back to ``run_server`` on Dash versions that lack it.
        run_app = getattr(self.app, 'run', None) or self.app.run_server
        run_app(debug=True)
if __name__ == '__main__':
    # Script entry point: build the dashboard (loads data, builds the layout)
    # and start serving it.
    CovidDashboard().run()