-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathrandomForestIncome.py
More file actions
27 lines (17 loc) · 872 Bytes
/
randomForestIncome.py
File metadata and controls
27 lines (17 loc) · 872 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
def warn(*args, **kwargs):
    """Accept any arguments, do nothing, and return ``None``.

    This is a no-op stand-in that is assigned over ``warnings.warn`` further
    down in this file, so code that calls ``warnings.warn(...)`` through the
    module attribute produces no output while the script runs.
    """
    pass
import warnings
warnings.warn = warn
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn import tree
from sklearn.ensemble import RandomForestClassifier
# Load the income table. Fields in the file are separated by a comma followed
# by a space; splitting on "," and skipping the space after each delimiter
# yields the same columns without a multi-character separator, which pandas
# would treat as a regular expression and parse with its slower Python engine.
income_data = pd.read_csv(
    "income.csv", header=0, sep=",", skipinitialspace=True
)

# Target as a 1-D Series: scikit-learn expects y with shape (n_samples,), and
# a single-column DataFrame makes fit() emit a DataConversionWarning.
labels = income_data["income"]

# Binary encodings of the two categorical features fed to the model:
# 0 for "Male" / United States, 1 for every other value (missing included).
income_data["sex-int"] = (income_data["sex"] != "Male").astype(int)

# NOTE(review): the hyphenated column names suggest the UCI Adult data, where
# this value is spelled "United-States"; comparing only against
# "United States" would then mark every row as 1. Both spellings are accepted
# here -- confirm against income.csv and drop the one that does not occur.
us_spellings = ["United-States", "United States"]
income_data["country-int"] = (
    ~income_data["native-country"].isin(us_spellings)
).astype(int)

# Feature matrix: four numeric columns plus the two encoded columns above.
data = income_data[
    [
        "age",
        "capital-gain",
        "capital-loss",
        "hours-per-week",
        "sex-int",
        "country-int",
    ]
]

# Fixed random_state on both the split and the forest keeps runs repeatable.
train_data, test_data, train_labels, test_labels = train_test_split(
    data, labels, random_state=1
)
forest = RandomForestClassifier(random_state=1)
forest.fit(train_data, train_labels)

# Mean accuracy of the fitted forest on the held-out split.
print(forest.score(test_data, test_labels))