-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcreate_model.py
More file actions
75 lines (59 loc) · 2.64 KB
/
create_model.py
File metadata and controls
75 lines (59 loc) · 2.64 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import json
import pathlib
import pickle
from typing import List
from typing import Tuple
import pandas
from sklearn import model_selection
from sklearn import neighbors
from sklearn import pipeline
from sklearn import preprocessing
SALES_PATH = "data/kc_house_data.csv"  # path to CSV with home sale data
# path to CSV with per-zipcode demographics (joined to sales on 'zipcode')
DEMOGRAPHICS_PATH = "data/zipcode_demographics.csv"
# List of columns (subset) that will be taken from home sale data
SALES_COLUMN_SELECTION = [
    'price', 'bedrooms', 'bathrooms', 'sqft_living', 'sqft_lot', 'floors',
    'sqft_above', 'sqft_basement', 'zipcode'
]
OUTPUT_DIR = "model"  # Directory where output artifacts will be saved


def load_data(
    sales_path: str, demographics_path: str, sales_column_selection: List[str]
) -> Tuple[pandas.DataFrame, pandas.Series]:
    """Load the target and feature data by merging sales and demographics.

    Sales rows are left-joined to the demographics table on ``zipcode``.
    The ``zipcode`` column is then dropped, so it is not itself a feature.
    Both files read ``zipcode`` as a string, so the join compares the raw
    text (e.g. leading zeros are preserved).

    Args:
        sales_path: path to CSV file with home sale data
        demographics_path: path to CSV file with demographics; it must
            contain a ``zipcode`` column
        sales_column_selection: list of columns from sales data to be used as
            features; must include ``'price'`` (the target) and
            ``'zipcode'`` (the join key)

    Returns:
        Tuple containing two elements: a DataFrame and a Series of the same
        length. The DataFrame contains features for machine learning, the
        series contains the target variable (home sale price).
    """
    data = pandas.read_csv(sales_path,
                           usecols=sales_column_selection,
                           dtype={'zipcode': str})
    # Honour the caller-supplied path instead of a hard-coded file name.
    demographics = pandas.read_csv(demographics_path,
                                   dtype={'zipcode': str})

    # A left join keeps every sale even when its zipcode has no demographics
    # row; such rows get NaN demographic features.
    merged_data = data.merge(demographics, how="left",
                             on="zipcode").drop(columns="zipcode")
    # Remove the target variable from the dataframe, features will remain
    y = merged_data.pop('price')
    x = merged_data
    return x, y
def main():
    """Load data, train model, and export artifacts.

    Fits a RobustScaler + KNeighborsRegressor pipeline on the training
    split; the held-out test split is not used here. It then writes two
    files into ``OUTPUT_DIR``:

    * ``model.pkl``: the fitted pipeline, pickled.
    * ``model_features.json``: the ordered list of feature column names
      the pipeline was trained on.
    """
    x, y = load_data(SALES_PATH, DEMOGRAPHICS_PATH, SALES_COLUMN_SELECTION)
    # Fixed random_state keeps the train/test split reproducible across runs.
    x_train, _x_test, y_train, _y_test = model_selection.train_test_split(
        x, y, random_state=42)

    model = pipeline.make_pipeline(preprocessing.RobustScaler(),
                                   neighbors.KNeighborsRegressor()).fit(
                                       x_train, y_train)

    output_dir = pathlib.Path(OUTPUT_DIR)
    # parents=True so a nested OUTPUT_DIR (e.g. "artifacts/model") also works.
    output_dir.mkdir(parents=True, exist_ok=True)

    # Output model artifacts: pickled model and JSON list of features.
    # Context managers guarantee each file is flushed and closed before the
    # next step runs, even if serialization raises.
    with open(output_dir / "model.pkl", 'wb') as model_file:
        pickle.dump(model, model_file)
    with open(output_dir / "model_features.json", 'w',
              encoding="utf-8") as features_file:
        json.dump(list(x_train.columns), features_file)


if __name__ == "__main__":
    main()