diff --git a/epios/post_process.py b/epios/post_process.py index b256973..772278b 100644 --- a/epios/post_process.py +++ b/epios/post_process.py @@ -428,29 +428,35 @@ def _compare(self, time_sample, gen_plot=False, scale_method: str = 'proportiona ''' # Based on the input, use different scale method to estimate the true infection number if scale_method == 'proportional': - result_scaled = np.array(self.result) * len(self.demo_data) + result_scaled = np.round(np.array(self.result) * len(self.demo_data)) # Get the true result from self.time_data - true_result = [] - for t in time_sample: + true_result_plot = [] + for t in range(max(time_sample) + 1): num = self.time_data.iloc[t, 1:].value_counts().get(3, 0) num += self.time_data.iloc[t, 1:].value_counts().get(4, 0) num += self.time_data.iloc[t, 1:].value_counts().get(5, 0) num += self.time_data.iloc[t, 1:].value_counts().get(6, 0) num += self.time_data.iloc[t, 1:].value_counts().get(7, 0) num += self.time_data.iloc[t, 1:].value_counts().get(8, 0) - true_result.append(num) + true_result_plot.append(num) + + true_result = [] + for t in time_sample: + true_result.append(true_result_plot[t]) # Find the difference between estimated infection level and the real one diff = np.array(true_result) - result_scaled if gen_plot: plt.figure() plt.plot(time_sample, result_scaled, label='Predicted result', linestyle='--') - plt.plot(time_sample, true_result, label='True result') + plt.plot(range(max(time_sample) + 1), true_result_plot, label='True result') plt.plot(time_sample, np.abs(diff), label='Absolute difference') plt.legend() plt.xlabel('Time') plt.ylabel('Population') + plt.xlim(0, max(time_sample)) + plt.ylim(0, len(self.demo_data)) plt.title('Number of infection in the population') if saving_path_compare: plt.savefig(saving_path_compare) @@ -661,7 +667,7 @@ def _wrapper_Region_AgeRegion(self, sampling_method, sample_size, time_sample, n # Plot the figure if gen_plot: plt.figure() - infected_population = np.array(infected_rate) * len(self.demo_data) + infected_population = np.round(np.array(infected_rate) * len(self.demo_data)) plt.plot(time_sample, infected_population) plt.xlabel('Time') plt.ylabel('Population') @@ -748,7 +754,7 @@ def _wrapper_Age_Base(self, sampling_method, sample_size, time_sample, # Plot the figure if gen_plot: plt.figure() - infected_population = np.array(infected_rate) * len(self.demo_data) + infected_population = np.round(np.array(infected_rate) * len(self.demo_data)) plt.plot(time_sample, infected_population) plt.xlabel('Time') plt.ylabel('Population')