-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathclassification.py
More file actions
28 lines (24 loc) · 1.02 KB
/
classification.py
File metadata and controls
28 lines (24 loc) · 1.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
# classification.py
# Classification algorithms for the DewCrypt model.
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import SGDClassifier
def mini_batch_sgd_classifier(input_data, labels):
"""
Mini-Batch SGD Classifier algorithm as described in the paper.
"""
clf = SGDClassifier(max_iter=1000, tol=1e-3)
clf.fit(input_data, labels)
return clf.predict(input_data)
def random_forest_classifier(input_data, labels):
"""
Random Forest Classifier algorithm as described in the paper.
"""
clf = RandomForestClassifier(n_estimators=100)
clf.fit(input_data, labels)
return clf.predict(input_data)
if __name__ == "__main__":
# Test the classification functions
sample_data = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
sample_labels = np.array([0, 1, 0])
print("SGD Classifier Predictions:", mini_batch_sgd_classifier(sample_data, sample_labels))
print("Random Forest Classifier Predictions:", random_forest_classifier(sample_data, sample_labels))