-
Notifications
You must be signed in to change notification settings - Fork 12
Expand file tree
/
Copy pathtest_notebooks.py
More file actions
139 lines (100 loc) · 4.44 KB
/
test_notebooks.py
File metadata and controls
139 lines (100 loc) · 4.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import pathlib
import signal
import sys
from datetime import datetime
from typing import Any, Callable, NamedTuple
import nbformat
import pytest
from nbclient.exceptions import CellExecutionError
from nbconvert.preprocessors.execute import ExecutePreprocessor
TEARDOWN_CELL_TAG = "teardown"
class IndexedCell(NamedTuple):
cell: Any
index: int # type: ignore
class TeardownExecutePreprocessor(ExecutePreprocessor):
def __init__(self, **kw: Any):
super().__init__(**kw) # type: ignore
def init_notebook(self, tear_down_cells: list[IndexedCell]) -> None:
self.tear_down_cells = tear_down_cells
self._skip_rest = False
# run the cell of a notebook
def preprocess_cell(self, cell: Any, resources: Any, index: int) -> None:
if index == 0:
def handle_signal(sig, frame): # type: ignore
print("Received SIGNAL, running tear down cells")
self.teardown(resources)
sys.exit(1)
signal.signal(signal.SIGINT, handle_signal)
signal.signal(signal.SIGTERM, handle_signal)
try:
if not self._skip_rest:
super().preprocess_cell(cell, resources, index) # type: ignore
except CellExecutionError as e:
if self.tear_down_cells:
print(f"Running tear down cells due to error in notebook execution: {e}")
self.teardown(resources)
raise e
def teardown(self, resources: Any) -> None:
for td_cell, td_idx in self.tear_down_cells:
try:
super().preprocess_cell(td_cell, resources, td_idx) # type: ignore
except CellExecutionError as td_e:
print(f"Error running tear down cell {td_idx}: {td_e}")
class TearDownCollector(ExecutePreprocessor):
def __init__(self, **kw: Any):
super().__init__(**kw) # type: ignore
def init_notebook(self) -> None:
self._tear_down_cells: list[IndexedCell] = []
def preprocess_cell(self, cell: Any, resources: Any, index: int) -> None:
if TEARDOWN_CELL_TAG in cell["metadata"].get("tags", []):
self._tear_down_cells.append(IndexedCell(cell, index))
def tear_down_cells(self) -> list[IndexedCell]:
return self._tear_down_cells
def run_notebooks(filter_func: Callable[[str], bool]) -> None:
current_dir = pathlib.Path(__file__).parent.resolve()
examples_path = current_dir.parent.parent / "examples"
notebook_files = [
f for f in examples_path.iterdir() if f.is_file() and f.suffix == ".ipynb" and filter_func(f.name)
]
ep = TeardownExecutePreprocessor(kernel_name="python3")
td_collector = TearDownCollector(kernel_name="python3")
exceptions: list[RuntimeError] = []
for notebook_filename in notebook_files:
now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
print(f"{now}: Executing notebook {notebook_filename}", flush=True)
with open(notebook_filename) as f:
nb = nbformat.read(f, as_version=4) # type: ignore
# Collect tear down cells
td_collector.init_notebook()
td_collector.preprocess(nb)
ep.init_notebook(tear_down_cells=td_collector.tear_down_cells())
# run the notebook
try:
ep.preprocess(nb)
print(f"Finished executing notebook {notebook_filename}")
except CellExecutionError as e:
exceptions.append(RuntimeError(f"Error executing notebook {notebook_filename}", e))
continue
if exceptions:
for nb_ex in exceptions:
print(nb_ex)
raise RuntimeError(f"{len(exceptions)} Errors occurred while executing notebooks")
else:
print("Finished executing notebooks")
@pytest.mark.requires_neo4j_and_gds
def test_neo4j(gds: Any) -> None:
neo4j_notebooks = ["neo4j-nvl-example.ipynb", "gds-nvl-example.ipynb"]
def filter_func(notebook: str) -> bool:
return notebook in neo4j_notebooks
run_notebooks(filter_func)
@pytest.mark.requires_snowflake
def test_snowflake() -> None:
snowflake_notebooks = ["snowpark-nvl-example.ipynb"]
def filter_func(notebook: str) -> bool:
return notebook in snowflake_notebooks
run_notebooks(filter_func)
def test_simple() -> None:
simple_notebooks = ["simple-nvl-example.ipynb"]
def filter_func(notebook: str) -> bool:
return notebook in simple_notebooks
run_notebooks(filter_func)