Skip to content

Commit 216c9aa

Browse files
committed
Model Trainer part completed
1 parent 2740053 commit 216c9aa

14 files changed

Lines changed: 311 additions & 10 deletions

File tree

final_model/model.pkl

12.6 MB
Binary file not shown.

final_model/preprocessor.pkl

0 Bytes
Binary file not shown.
Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
import sys
2+
from pathlib import Path
3+
4+
from sklearn.ensemble import (
5+
AdaBoostClassifier,
6+
GradientBoostingClassifier,
7+
RandomForestClassifier,
8+
)
9+
from sklearn.linear_model import LogisticRegression
10+
from sklearn.tree import DecisionTreeClassifier
11+
12+
from network_security.entity.artifact_entity import (
13+
DataTransformationArtifact,
14+
ModelTrainerArtifact,
15+
)
16+
from network_security.entity.config_entity import ModelTrainerConfig
17+
from network_security.exception.exception import NetworkSecurityException
18+
from network_security.logging.logger import logging
19+
from network_security.utils.main_utils.utils import (
20+
load_numpy_array_data,
21+
load_object,
22+
save_object,
23+
)
24+
from network_security.utils.ml_utils.evaluation.evaluation import evaluate_models
25+
from network_security.utils.ml_utils.metric.classification_metric import (
26+
get_classification_score,
27+
)
28+
from network_security.utils.ml_utils.model.estimator import NetworkModel
29+
30+
31+
class ModelTrainer:
    """Train several candidate classifiers, keep the best one and persist it.

    The trainer reads the transformed train/test arrays referenced by the
    data-transformation artifact, tunes every candidate through
    ``evaluate_models``, and stores the winning estimator together with the
    fitted preprocessor as a ``NetworkModel``.
    """

    def __init__(
        self,
        model_trainer_config: ModelTrainerConfig,
        data_transformation_artifact: DataTransformationArtifact,
    ) -> None:
        """Remember the trainer configuration and the upstream artifact.

        Raises:
            NetworkSecurityException: if storing the arguments fails.
        """
        try:
            self.model_trainer_config = model_trainer_config
            self.data_transformation_artifact = data_transformation_artifact
        except Exception as e:
            raise NetworkSecurityException(e, sys) from e

    def train_model(
        self,
        X_train: object,
        y_train: object,
        X_test: object,
        y_test: object,
    ) -> ModelTrainerArtifact:
        """Tune the candidate models, select the best and save it.

        Args:
            X_train: training features.
            y_train: training targets.
            X_test: test features.
            y_test: test targets.

        Returns:
            ModelTrainerArtifact holding the saved model path and the
            train/test classification metrics of the selected model.
        """
        models = {
            "Random Forest": RandomForestClassifier(verbose=1),
            "Decision Tree": DecisionTreeClassifier(),
            "Gradient Boosting": GradientBoostingClassifier(verbose=1),
            "Logistic Regression": LogisticRegression(verbose=1),
            "AdaBoost": AdaBoostClassifier(),
        }
        params = {
            "Decision Tree": {
                "criterion": ["gini", "entropy", "log_loss"],
                "splitter": ["best", "random"],
                "max_features": ["sqrt", "log2"],
            },
            "Random Forest": {
                "criterion": ["gini", "entropy", "log_loss"],
                "max_features": ["sqrt", "log2", None],
                "n_estimators": [8, 16, 32, 128, 256],
            },
            "Gradient Boosting": {
                "loss": ["log_loss", "exponential"],
                "learning_rate": [0.1, 0.01, 0.05, 0.001],
                "subsample": [0.6, 0.7, 0.75, 0.85, 0.9],
                "criterion": ["squared_error", "friedman_mse"],
                # NOTE(review): "auto" was dropped from this grid. For this
                # classifier it was a deprecated alias of "sqrt" and recent
                # scikit-learn releases reject it, so every grid point using
                # it would be a wasted (failed) fit. Confirm against the
                # pinned scikit-learn version.
                "max_features": ["sqrt", "log2"],
                "n_estimators": [8, 16, 32, 64, 128, 256],
            },
            "Logistic Regression": {},
            "AdaBoost": {
                "learning_rate": [0.1, 0.01, 0.001],
                "n_estimators": [8, 16, 32, 64, 128, 256],
            },
        }
        model_report: dict = evaluate_models(
            X_train=X_train,
            y_train=y_train,
            X_test=X_test,
            y_test=y_test,
            models=models,
            param=params,
        )

        # Pick the entry with the highest reported score. On ties the first
        # inserted name wins, which matches a first-index lookup on the values.
        best_model_name = max(model_report, key=model_report.get)
        best_model_score = model_report[best_model_name]
        logging.info(f"Best model: {best_model_name} (score: {best_model_score})")

        # evaluate_models fits the estimator instances of ``models`` in place,
        # so the selected entry can be used for prediction directly.
        best_model = models[best_model_name]

        y_train_pred = best_model.predict(X_train)
        classification_train_metric = get_classification_score(
            y_true=y_train,
            y_pred=y_train_pred,
        )

        y_test_pred = best_model.predict(X_test)
        classification_test_metric = get_classification_score(
            y_true=y_test,
            y_pred=y_test_pred,
        )

        preprocessor = load_object(
            file_path=self.data_transformation_artifact.transformed_object_file_path,
        )
        model_dir_path = Path(self.model_trainer_config.trained_model_file_path).parent
        model_dir_path.mkdir(parents=True, exist_ok=True)

        # Persist the *instance* that bundles preprocessor and model; pickling
        # the NetworkModel class itself would store no fitted state at all.
        network_model = NetworkModel(preprocessor=preprocessor, model=best_model)
        save_object(
            self.model_trainer_config.trained_model_file_path,
            obj=network_model,
        )

        ## Model pusher
        save_object("final_model/model.pkl", best_model)

        ## Model Trainer Artifact
        model_trainer_artifact = ModelTrainerArtifact(
            trained_model_file_path=self.model_trainer_config.trained_model_file_path,
            train_metric_artifact=classification_train_metric,
            test_metric_artifact=classification_test_metric,
        )
        logging.info(f"Model trainer artifact: {model_trainer_artifact}")
        return model_trainer_artifact

    def initiate_model_trainer(self) -> ModelTrainerArtifact:
        """Load the transformed arrays and run :meth:`train_model` on them.

        The last column of each array is used as the target and all preceding
        columns as features.

        Returns:
            The artifact produced by :meth:`train_model`.

        Raises:
            NetworkSecurityException: wraps any error raised while loading
                the arrays or training.
        """
        try:
            train_file_path = (
                self.data_transformation_artifact.transformed_train_file_path
            )
            test_file_path = (
                self.data_transformation_artifact.transformed_test_file_path
            )

            # Loading training array and testing array
            train_arr = load_numpy_array_data(train_file_path)
            test_arr = load_numpy_array_data(test_file_path)

            x_train, y_train = train_arr[:, :-1], train_arr[:, -1]
            x_test, y_test = test_arr[:, :-1], test_arr[:, -1]

            return self.train_model(x_train, y_train, x_test, y_test)

        except Exception as e:
            raise NetworkSecurityException(e, sys) from e

network_security/constant/training_pipeline/__init__.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,3 +55,14 @@
5555
DATA_TRANSFORMATION_TRAIN_FILE_PATH: str = "train.npy"
5656

5757
DATA_TRANSFORMATION_TEST_FILE_PATH: str = "test.npy"
58+
59+
60+
"""
61+
Model Trainer related constants start with MODEL_TRAINER VAR NAME
62+
"""
63+
64+
MODEL_TRAINER_DIR_NAME: str = "model_trainer"
65+
MODEL_TRAINER_TRAINED_MODEL_DIR: str = "trained_model"
66+
MODEL_TRAINER_TRAINED_MODEL_NAME: str = "model.pkl"
67+
MODEL_TRAINER_EXPECTED_SCORE: float = 0.6
68+
MODEL_TRAINER_OVER_FIITING_UNDER_FITTING_THRESHOLD: float = 0.05

network_security/entity/artifact_entity.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,17 @@ class DataTransformationArtifact:
2222
transformed_object_file_path: str
2323
transformed_train_file_path: str
2424
transformed_test_file_path: str
25+
26+
27+
@dataclass
class ClassificationMetricArtifact:
    """Bundle of the F1, precision and recall values computed for one set of predictions."""

    # Each field holds a single float score; how the scores are computed is
    # decided by whoever constructs the artifact.
    f1_score: float
    precision_score: float
    recall_score: float
32+
33+
34+
@dataclass
class ModelTrainerArtifact:
    """Result of the model-training stage: saved model location plus train/test metrics."""

    # Path at which the trained model was written.
    trained_model_file_path: str
    # Classification metrics of the selected model on the training data.
    train_metric_artifact: ClassificationMetricArtifact
    # Classification metrics of the selected model on the test data.
    test_metric_artifact: ClassificationMetricArtifact

network_security/entity/config_entity.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,3 +99,20 @@ def __init__(self, training_pipeline_config: TrainingPipelineConfig) -> None:
9999
/ training_pipeline.DATA_TRANSFORMATION_TRANSFORMED_OBJECT_DIR
100100
/ training_pipeline.PREPROCESSING_OBJECT_FILE_NAME
101101
)
102+
103+
104+
class ModelTrainerConfig:
    """Filesystem locations and score thresholds for the model-training stage."""

    def __init__(self, training_pipeline_config: TrainingPipelineConfig) -> None:
        """Derive all trainer paths from the pipeline's artifact directory."""
        artifact_root = Path(training_pipeline_config.artifact_dir)
        trainer_dir = artifact_root / training_pipeline.MODEL_TRAINER_DIR_NAME
        trained_model_dir = (
            trainer_dir / training_pipeline.MODEL_TRAINER_TRAINED_MODEL_DIR
        )

        self.model_trainer_dir: Path = trainer_dir
        # NOTE(review): the file name comes from MODEL_FILE_NAME, not from
        # MODEL_TRAINER_TRAINED_MODEL_NAME - confirm which one is intended.
        self.trained_model_file_path: Path = (
            trained_model_dir / training_pipeline.MODEL_FILE_NAME
        )
        self.expected_accuracy: float = training_pipeline.MODEL_TRAINER_EXPECTED_SCORE
        self.overfitting_underfitting_threshold = (
            training_pipeline.MODEL_TRAINER_OVER_FIITING_UNDER_FITTING_THRESHOLD
        )

network_security/utils/main_utils/utils.py

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
import os
2-
3-
# import dill
41
import pickle
52
import sys
63
from pathlib import Path
@@ -31,16 +28,16 @@ def write_yaml_file(file_path: str, content: object, replace: bool = False) -> N
3128
raise NetworkSecurityException(e, sys)
3229

3330

34-
def save_numpy_array_data(file_path: str, array: np.ndarray) -> None:
    """
    Save numpy array data to file
    file_path: str location of file to save (any value accepted by ``Path``)
    array: np.ndarray data to save.

    Missing parent directories are created. The data is written through an
    open binary file handle, so the file name is used exactly as given.

    Raises NetworkSecurityException wrapping any error raised while writing.
    """
    try:
        target = Path(file_path)
        target.parent.mkdir(parents=True, exist_ok=True)
        with target.open("wb") as file_obj:
            np.save(file_obj, array)
    except Exception as e:
        raise NetworkSecurityException(e, sys) from e
@@ -49,9 +46,34 @@ def save_numpy_array_data(file_path: str, array: np.array):
4946
def save_object(file_path: str, obj: object) -> None:
    """Pickle ``obj`` to ``file_path``, creating missing parent directories first.

    Raises NetworkSecurityException wrapping any error raised on the way.
    """
    try:
        logging.info("Entered the save_object method of MainUtils class")
        destination = Path(file_path)
        destination.parent.mkdir(parents=True, exist_ok=True)
        with destination.open("wb") as sink:
            pickle.dump(obj, sink)
        logging.info("Exited the save_object method of MainUtils class")
    except Exception as e:
        raise NetworkSecurityException(e, sys) from e
55+
56+
57+
def load_object(file_path: str) -> object:
    """Load and return the pickled object stored at ``file_path``.

    file_path: str location of the pickle file to read.
    return: the unpickled object.

    Raises NetworkSecurityException wrapping a FileNotFoundError when the
    file is missing, or any other error raised while reading/unpickling.
    """
    try:
        source = Path(file_path)
        if not source.exists():
            raise FileNotFoundError(f"The file: {file_path} is not exists")
        with source.open("rb") as file_obj:
            # SECURITY: unpickling executes code embedded in the file; only
            # load artifacts produced by this pipeline, never untrusted input.
            return pickle.load(file_obj)
    except Exception as e:
        raise NetworkSecurityException(e, sys) from e
66+
67+
68+
def load_numpy_array_data(file_path: str) -> np.ndarray:
    """
    Load numpy array data from file
    file_path: str location of file to load
    return: np.ndarray data loaded.

    Raises NetworkSecurityException wrapping any error raised while reading.
    """
    try:
        with Path(file_path).open("rb") as file_obj:
            return np.load(file_obj)
    except Exception as e:
        raise NetworkSecurityException(e, sys) from e
79+

network_security/utils/ml_utils/__init__.py

Whitespace-only changes.

network_security/utils/ml_utils/evaluation/__init__.py

Whitespace-only changes.
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import sys
2+
3+
from sklearn.metrics import r2_score
4+
from sklearn.model_selection import GridSearchCV
5+
6+
from network_security.exception.exception import NetworkSecurityException
7+
8+
9+
def evaluate_models(X_train: object, y_train: object, X_test: object, y_test: object, models: dict, param: dict) -> dict:
    """Tune every model with a 3-fold grid search and score it on the test set.

    Args:
        X_train: training features.
        y_train: training targets.
        X_test: test features.
        y_test: test targets.
        models: mapping of display name -> unfitted estimator. Each estimator
            is updated with its best parameters and fitted IN PLACE, so the
            caller can use the same instances for prediction afterwards.
        param: mapping of display name -> parameter grid. A model without an
            entry is fitted with its current parameters.

    Returns:
        Mapping of display name -> test score as reported by the estimator's
        own ``score`` method (mean accuracy for scikit-learn classifiers, R^2
        for regressors).

    Raises:
        NetworkSecurityException: wraps any error raised during tuning,
            fitting or scoring.
    """
    try:
        report = {}

        for name, model in models.items():
            # refit=False: only the best parameters are needed here because
            # the caller's own instance is fitted right below; refitting the
            # search's internal clone would be a wasted full training run.
            gs = GridSearchCV(model, param.get(name, {}), cv=3, refit=False)
            gs.fit(X_train, y_train)

            model.set_params(**gs.best_params_)
            model.fit(X_train, y_train)

            # Use the estimator's native metric instead of r2_score, which is
            # a regression metric and not meaningful for class labels.
            report[name] = model.score(X_test, y_test)

        return report

    except Exception as e:
        raise NetworkSecurityException(e, sys) from e

0 commit comments

Comments
 (0)