-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathonline_plot.py
More file actions
59 lines (53 loc) · 2.09 KB
/
online_plot.py
File metadata and controls
59 lines (53 loc) · 2.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import pandas as pd
from lecilab_behavior_analysis.plots import (correct_left_and_right_plot,
side_correct_performance_plot)
from matplotlib import gridspec
from matplotlib import pyplot as plt
from village.custom_classes.online_plot_base import OnlinePlotBase
class OnlinePlot(OnlinePlotBase):
    """Live figure redrawn from the session's trial DataFrame.

    Layout: one wide panel on top (``ax1``) and a bottom row split into three
    equal-width slots, of which the first two are used (``ax2``, ``ax3``).
    Each :meth:`update_plot` call redraws every panel; a panel that fails to
    draw shows a "Could not create plot" placeholder instead of aborting the
    remaining panels.
    """

    # TODO: make this nice and add something informative for habituation, like the side chosen

    # Third positional argument passed to ``side_correct_performance_plot``
    # (previously a bare literal). Presumably a rolling-window size in trials
    # -- TODO confirm against the helper's signature.
    _PERFORMANCE_WINDOW = 50

    def __init__(self) -> None:
        super().__init__()
        self.fig = plt.figure(figsize=(20, 5))
        # Two stacked rows; the bottom row is twice as tall as the top one.
        rows_gs = gridspec.GridSpec(2, 1, height_ratios=[1, 2])
        # Create separate inner grids for each row with different width ratios
        top_gs = gridspec.GridSpecFromSubplotSpec(
            1, 1, subplot_spec=rows_gs[0]
        )
        bot_gs = gridspec.GridSpecFromSubplotSpec(
            1, 3, subplot_spec=rows_gs[1], width_ratios=[1, 1, 1]
        )
        self.ax1 = self.fig.add_subplot(top_gs[0, 0])  # side_correct_performance_plot
        self.ax2 = self.fig.add_subplot(bot_gs[0, 0])  # correct_left_and_right_plot
        self.ax3 = self.fig.add_subplot(bot_gs[0, 1])  # make_timing_plot
        # NOTE(review): bot_gs[0, 2] is allocated but never given an axes, so
        # the right third of the bottom row stays empty -- intended?

    def update_plot(self, df: pd.DataFrame) -> None:
        """Redraw every panel from ``df`` and re-tighten the layout.

        Args:
            df: Trial table for the session so far. The timing panel reads the
                ``TRIAL_START`` and ``trial`` columns; the other two panels
                delegate to ``lecilab_behavior_analysis.plots`` helpers.
        """
        # Same panel order as before: timing, then performance, then left/right.
        self.ax3 = self._draw_panel(
            "timing", self.ax3, lambda ax: self.make_timing_plot(df, ax)
        )
        self.ax1 = self._draw_panel(
            "side/correct performance",
            self.ax1,
            lambda ax: side_correct_performance_plot(
                df, ax, self._PERFORMANCE_WINDOW
            ),
        )
        self.ax2 = self._draw_panel(
            "correct left/right",
            self.ax2,
            lambda ax: correct_left_and_right_plot(df, ax),
        )
        self.fig.tight_layout()

    def _draw_panel(self, name: str, ax: plt.Axes, draw) -> plt.Axes:
        """Clear ``ax`` and run ``draw(ax)``, falling back to an error placeholder.

        Args:
            name: Panel label used in the printed error message.
            ax: Axes to draw on.
            draw: Callable taking the axes; it may return the axes it drew on.

        Returns:
            The axes to keep for this panel: ``draw``'s return value when it is
            an ``Axes``, otherwise ``ax``. This never replaces a live axes with
            ``None`` (or anything else), which would otherwise break every
            later update and the error fallback itself.
        """
        try:
            ax.clear()
            result = draw(ax)
        except Exception as exc:  # UI boundary: one bad panel must not stop the rest
            print(f"OnlinePlot: could not draw {name} panel: {exc!r}")
            self.make_error_plot(ax)
            return ax
        return result if isinstance(result, plt.Axes) else ax

    def make_timing_plot(self, df: pd.DataFrame, ax: plt.Axes) -> None:
        """Scatter ``trial`` against ``TRIAL_START`` on a cleared ``ax``.

        Raises whatever pandas/matplotlib raise, e.g. ``KeyError`` when a
        column is missing; callers are expected to handle it.
        """
        ax.clear()
        df.plot(kind="scatter", x="TRIAL_START", y="trial", ax=ax)

    def make_error_plot(self, ax: plt.Axes) -> None:
        """Clear ``ax`` and show a centred "Could not create plot" message."""
        ax.clear()
        ax.text(
            0.5,
            0.5,
            "Could not create plot",
            horizontalalignment="center",
            verticalalignment="center",
            transform=ax.transAxes,  # axes-fraction coords, so (0.5, 0.5) is the centre
        )