forked from dafrenchyman/PythonProjectTemplate
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbda602_hw3.py
More file actions
94 lines (78 loc) · 2.7 KB
/
bda602_hw3.py
File metadata and controls
94 lines (78 loc) · 2.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
from pyspark.sql import SparkSession
from pyspark import keyword_only
from pyspark.ml.param.shared import HasInputCols, HasOutputCol
from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable
from pyspark.ml import Pipeline, Transformer
class SplitColumnTransform(
    Transformer,
    HasInputCols,
    HasOutputCol,
    DefaultParamsReadable,
    DefaultParamsWritable,
):
    """Pipeline transformer stage parameterized by ``inputCols``/``outputCol``.

    NOTE(review): despite the name, no column-splitting logic is implemented
    yet — ``_transform`` only displays the dataset and passes it through.
    The param plumbing (``keyword_only`` + ``setParams``) follows the
    standard PySpark custom-transformer pattern.
    """

    @keyword_only
    def __init__(self, inputCols=None, outputCol=None):
        """Store ``inputCols``/``outputCol`` as ML params via setParams."""
        super(SplitColumnTransform, self).__init__()
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    def setParams(self, inputCols=None, outputCol=None):
        """Set whichever keyword params were passed at call time."""
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def _transform(self, dataset):
        """Display the dataset and return it unchanged.

        Bug fix: the original did ``return dataset.show()`` — ``show()``
        returns ``None``, so ``PipelineModel.transform(...)`` produced
        ``None`` instead of a DataFrame. A Spark ``Transformer`` must
        return a DataFrame.
        """
        # Retrieved but currently unused; kept because getInputCols/
        # getOutputCol raise if the params were never set, which acts as
        # minimal validation of the stage's configuration.
        input_cols = self.getInputCols()  # noqa: F841
        output_col = self.getOutputCol()  # noqa: F841
        dataset.show()
        return dataset
def main():
    """Load batter/game stats over JDBC, compute per-batter rolling batting
    averages with Spark SQL, and run the SplitColumnTransform pipeline stage.

    Side effects: creates a local SparkSession, reads from a MariaDB
    ``baseball`` database, prints sample rows and schemas to stdout.

    NOTE(review): ``user``/``password`` are blank placeholders — fill in
    real credentials before running.
    """
    app_name = "App"
    master = "local[*]"
    spark = (
        SparkSession.builder.appName(app_name)
        .master(master)
        .config(
            "spark.jars",
            "/mnt/c/Users/thoma/scr/PythonProjectTemplate/mariadb-java-client-3.0.8.jar",
        )
        .enableHiveSupport()
        .getOrCreate()
    )

    # Bug fix: the original query string was not valid SQL — "local_dateFROM"
    # was fused together, and aggregate columns referencing an undefined
    # alias "nb" were spliced into the middle of the FROM/JOIN clauses.
    # Reconstructed as the plain per-game detail query; the aggregation is
    # done in the second query below.
    sql = """
        SELECT bc.batter, bc.Hit, bc.atBat, g.game_id, g.local_date
        FROM batter_counts bc
        JOIN game g
          ON g.game_id = bc.game_id
        ORDER BY bc.batter, bc.game_id
    """

    database = "baseball"
    user = ""
    password = ""
    server = "127.0.0.1"
    port = 3306
    jdbc_url = f"jdbc:mysql://{server}:{port}/{database}?permitMysqlScheme"
    jdbc_driver = "org.mariadb.jdbc.Driver"

    df = (
        spark.read.format("jdbc")
        .option("url", jdbc_url)
        .option("query", sql)
        .option("user", user)
        .option("password", password)
        .option("trustServerCertificate", True)
        .option("driver", jdbc_driver)
        .load()
    )
    df.show(5)
    df.printSchema()
    df.createOrReplaceTempView("rolling_avg")

    # Bug fix: the original statement had no FROM clause, referenced
    # undefined aliases (nb, nb2), left the datetime literals unquoted, and
    # selected game_id while grouping on other columns. It now reads from
    # the temp view registered above and groups by the selected
    # non-aggregate columns.
    # NOTE(review): date window 2012-03-20 .. 2012-06-28 kept from the
    # original — confirm these bounds against the assignment spec.
    df2 = spark.sql(
        """
        SELECT batter,
               game_id,
               SUM(Hit) AS total_h,
               SUM(atBat) AS total_ab,
               SUM(Hit) / SUM(atBat) AS rolling_avg
        FROM rolling_avg
        WHERE local_date >= '2012-03-20 00:00:00.000'
          AND local_date < '2012-06-28 22:15:00.000'
        GROUP BY batter, game_id
        """
    )

    new_transform = SplitColumnTransform()
    pipeline = Pipeline(stages=[new_transform])
    model = pipeline.fit(df2)
    model.transform(df2)
# Standard entry guard: run the pipeline only when executed as a script,
# not when imported as a module.
if __name__ == "__main__":
    main()