Skip to content

Commit 97d5aaa

Browse files
committed
Updating tests
1 parent b47148d commit 97d5aaa

5 files changed

Lines changed: 73 additions & 2 deletions

File tree

Chapter02/testing/helper.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,24 @@
1010

1111
# Define simulate ride data function
1212
def simulate_ride_data():
13+
"""
14+
Simulates ride data.
15+
16+
Simulates 370 ride distances with normal distribution around 10,
17+
10 ride distances with normal distribution around 30 (long distances),
18+
10 ride distances with normal distribution around 10 (same distance),
19+
and 10 ride distances with normal distribution around 10 (same distance).
20+
21+
Simulates 370 ride speeds with normal distribution around 30,
22+
10 ride speeds with normal distribution around 30 (same speed),
23+
10 ride speeds with normal distribution around 50 (high speed),
24+
and 10 ride speeds with normal distribution around 15 (low speed).
25+
26+
Assembles them into a Data Frame with ride_id as the index.
27+
28+
Returns:
29+
df_sim (pandas.DataFrame): A DataFrame containing simulated ride data.
30+
"""
1331
# Simulate some ride data ...
1432
ride_dists = np.concatenate(
1533
(
@@ -44,9 +62,19 @@ def simulate_ride_data():
4462

4563

4664
def get_taxi_data():
65+
66+
"""
67+
Reads in taxi ride data from a csv file or simulates it if not present.
68+
69+
Args:
70+
None
71+
72+
Returns:
73+
df (pandas.DataFrame): A DataFrame containing taxi ride data.
74+
"""
4775
# If data present, read it in
48-
#file_path = f'''../../chapter1/batch-anomaly/data/taxi-rides.csv''' #relative
49-
file_path = f'''chapter1/batch-anomaly/data/taxi-rides.csv''' #from top dir
76+
file_path = '../../Chapter01/clustering/taxi-rides.csv' #relative
77+
#file_path = f'''chapter1/batch-anomaly/data/taxi-rides.csv''' #from top dir
5078
if os.path.exists(file_path):
5179
df = pd.read_csv(file_path)
5280
else:

Chapter02/testing/model.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,21 @@
55

66

77
def cluster_and_label(X):
8+
"""
9+
Clusters the given data with DBSCAN algorithm and returns the results.
10+
11+
Parameters
12+
----------
13+
X : numpy.ndarray
14+
Array of data points to be clustered.
15+
16+
Returns
17+
-------
18+
run_metadata : dict
19+
A dictionary containing the results of the clustering algorithm.
20+
It includes the estimated number of clusters, the estimated number of noise points,
21+
the silhouette coefficient, and the labels of the data points.
22+
"""
823
X = StandardScaler().fit_transform(X)
924
db = DBSCAN(eps=0.3, min_samples=10).fit(X)
1025

Chapter02/testing/test_basic.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,5 @@
11
def test_example():
    """Trivial placeholder confirming the test suite is wired up and runs."""
    pass

Chapter02/testing/test_model_performance.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,11 @@
1212

1313
@pytest.fixture
1414
def test_dataset() -> Union[np.array, np.array]:
15+
"""
16+
Returns a tuple containing the test dataset and the corresponding labels.
17+
The dataset is the wine dataset, with the label being True for class 2 and False otherwise.
18+
The dataset is split into a training and test set using `train_test_split` with a random state of 42.
19+
"""
1520
# Load the dataset
1621
X, y = load_wine(return_X_y=True)
1722
# create an array of True for 2 and False otherwise
@@ -22,18 +27,31 @@ def test_dataset() -> Union[np.array, np.array]:
2227

2328
@pytest.fixture
def model() -> sklearn.ensemble._forest.RandomForestClassifier:
    """Provide the pre-trained wine-classification random forest.

    Downloads the serialized classifier from the Hugging Face Hub and
    deserializes it with joblib so the performance tests can run
    inference against a known, fixed model artifact.

    Returns:
        sklearn.ensemble._forest.RandomForestClassifier: the trained model.
    """
    repo_id = "electricweegie/mlewp-sklearn-wine"
    filename = "rfc.joblib"
    # hf_hub_download caches the artifact locally; joblib rebuilds the estimator.
    return joblib.load(hf_hub_download(repo_id, filename))
2938

3039

3140
def test_model_inference_types(model, test_dataset):
41+
"""
42+
Tests that the model's predict method returns a numpy array and that the test dataset is composed of numpy arrays.
43+
"""
44+
3245
assert isinstance(model.predict(test_dataset[0]), np.ndarray)
3346
assert isinstance(test_dataset[0], np.ndarray)
3447
assert isinstance(test_dataset[1], np.ndarray)
3548

3649
def test_model_performance(model, test_dataset):
50+
"""
51+
Tests the performance of the model on the test dataset.
52+
The performance is measured using the F1-score and precision metrics.
53+
The model is expected to achieve an F1-score greater than 0.95 and a precision greater than 0.9 for class False, and an F1-score greater than 0.8 and a precision greater than 0.8 for class True.
54+
"""
3755
metrics = classification_report(y_true=test_dataset[1], y_pred=model.predict(test_dataset[0]), output_dict=True)
3856
assert metrics['False']['f1-score'] > 0.95
3957
assert metrics['False']['precision'] > 0.9

Chapter02/testing/test_taxi_cluster_basic.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,13 @@
44

55
@pytest.mark.skip(reason="From edition 1, does not work due to not uploading taxi data in repo")
def test_cluster_and_label():
    """Verify the clustering pipeline yields a metadata dictionary.

    Loads the taxi ride data and runs cluster_and_label over it; the
    result is expected to be a dict. Currently skipped because the taxi
    data file is not present in the repository.
    """
    taxi_df = get_taxi_data()
    assert isinstance(cluster_and_label(taxi_df), dict)

0 commit comments

Comments
 (0)