
Commit 480e204

hw3_final
1 parent d8e5123 commit 480e204

4 files changed: 74 additions & 9 deletions


README.md

Lines changed: 4 additions & 0 deletions
@@ -30,3 +30,7 @@
 # a yellow window will appear for the manual download.
 
 # use pre-commit run --all-files on ubuntu terminal to format.
+
+# use with subtable as (select * from some table), subtable 2 as ( )
+
+# select * from subtable
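
The two comments added here gesture at SQL common table expressions (WITH ... AS). A minimal, hypothetical Spark SQL rendering of that pattern; some_table is a placeholder and must already be registered as a table or view:

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Sketch of the WITH ... AS (CTE) pattern the README comment describes.
# "some_table" and the second CTE's body are hypothetical placeholders.
df = spark.sql(
    """
    WITH subtable AS (SELECT * FROM some_table),
         subtable2 AS (SELECT * FROM subtable)
    SELECT * FROM subtable
    """
)
df.show()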

bda602_hw3.py

Lines changed: 60 additions & 9 deletions
@@ -3,6 +3,37 @@
 from pyspark.ml.classification import LogisticRegression
 from pyspark.ml.feature import StandardScaler, VectorAssembler
 from pyspark.sql import SparkSession
+from pyspark import keyword_only
+from pyspark.ml.param.shared import HasInputCols, HasOutputCol
+from pyspark.ml.util import DefaultParamsReadable, DefaultParamsWritable
+from pyspark.sql.functions import col, concat, lit, split, when
+from pyspark.ml import Pipeline, Transformer
+
+
+class SplitColumnTransform(
+    Transformer,
+    HasInputCols,
+    HasOutputCol,
+    DefaultParamsReadable,
+    DefaultParamsWritable,
+):
+    @keyword_only
+    def __init__(self, inputCols=None, outputCol=None):
+        super(SplitColumnTransform, self).__init__()
+        kwargs = self._input_kwargs
+        self.setParams(**kwargs)
+        return
+
+    @keyword_only
+    def setParams(self, inputCols=None, outputCol=None):
+        kwargs = self._input_kwargs
+        return self._set(**kwargs)
+
+    def _transform(self, dataset):
+        input_cols = self.getInputCols()
+        output_col = self.getOutputCol()
+
+        return dataset.show()
 
 
 def main():
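
A note on the new transformer: _transform computes input_cols and output_col but never uses them, and it returns dataset.show(), which prints the frame and evaluates to None, so any later pipeline stage would receive nothing. A hedged sketch of a body that returns a DataFrame, written as a standalone function for brevity and using concat_ws in place of the imported concat/lit; the space-separated join is an assumption about intent:

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, concat_ws

spark = SparkSession.builder.getOrCreate()

def transform_sketch(dataset, input_cols, output_col):
    # Join the input columns into one string column and RETURN the new
    # DataFrame; dataset.show() only prints and evaluates to None.
    joined = concat_ws(" ", *[col(c).cast("string") for c in input_cols])
    return dataset.withColumn(output_col, joined)

# Hypothetical usage on a toy frame:
toy = spark.createDataFrame([("abc", 1)], ["batter", "game_id"])
transform_sketch(toy, ["batter", "game_id"], "combined").show()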
@@ -18,19 +49,24 @@ def main():
         .enableHiveSupport()
         .getOrCreate()
     )
-
+    # "SELECT nb.batter,nb.Hit,nb.atBat,nb.game_id,nb.local_date,"
+    # "SUM(nb.Hit) AS total_h,SUM(nb.atBat) as total_ab,(SUM(nb.Hit) / SUM(nb.atBat)) AS rolling_avg"
+    # "FROM new_baseball nb"
+    # "JOIN new_baseball nb2"
+    # "on nb.batter"
+    # "where nb.local_date between nb.local_date -100 and nb2.local_date"
+    # "GROUP by nb.batter,nb.local_date"
     sql = (
-        "SELECT nb.batter,nb.Hit,nb.atBat,nb.game_id,nb.local_date,"
-        "SUM(nb.Hit) AS total_h,SUM(nb.atBat) as total_ab,(SUM(nb.Hit) / SUM(nb.atBat)) AS rolling_avg"
-        "FROM new_baseball nb"
-        "JOIN new_baseball nb2"
-        "on nb.batter"
-        "where nb.local_date between nb.local_date -100 and nb2.local_date"
-        "GROUP by nb.batter,nb.local_date"
+        """SELECT bc.batter,bc.Hit, bc.atBat,g.game_id, g.local_dateFROM batter_counts bc,
+        SUM(nb.Hit) AS total_h,SUM(nb.atBat) as total_ab,(SUM(nb.Hit) / SUM(nb.atBat)) AS rolling_avg
+        JOIN game g
+        ON g.game_id = bc.game_id
+        order by bc.batter, bc.game_id"""
+
     )
     database = "baseball"
     user = "tderig"
-    password = ""
+    password = "password"
     server = "127.0.0.1"
     port = 3306
     jdbc_url = f"jdbc:mysql://{server}:{port}/{database}?permitMysqlScheme"
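
Two pitfalls in this hunk. The deleted query built SQL from adjacent string literals with no trailing spaces, so Python's implicit concatenation fuses tokens (for example, "rolling_avg" and "FROM" become rolling_avgFROM). The replacement triple-quoted query inherits one such fusion (g.local_dateFROM), lists aggregate expressions inside the FROM clause, and has no GROUP BY, so MariaDB will reject it. A hedged sketch of a self-joined 100-day rolling average, reusing the names from the commented-out query; the window logic is an assumption about intent, not the author's confirmed answer:

# Hedged sketch of a 100-day rolling batting average over new_baseball;
# table and column names come from the commit, the join is an assumption.
sql = """
SELECT nb.batter,
       nb.local_date,
       SUM(nb2.Hit)   AS total_h,
       SUM(nb2.atBat) AS total_ab,
       SUM(nb2.Hit) / NULLIF(SUM(nb2.atBat), 0) AS rolling_avg
FROM new_baseball nb
JOIN new_baseball nb2
  ON nb2.batter = nb.batter
 AND nb2.local_date BETWEEN nb.local_date - INTERVAL 100 DAY
                        AND nb.local_date
GROUP BY nb.batter, nb.local_date
"""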
@@ -49,6 +85,21 @@ def main():
     df.show(5)
     df.printSchema()
 
+    df.createOrReplaceTempView("rolling_avg")
+    df2 = spark.sql("""select batter, game_id, SUM(Hit) AS total_h,SUM(nb.atBat)
+        as total_ab,(SUM(nb.Hit) / SUM(nb.atBat)) AS rolling_avg
+        where nb.local_date >= 2012-03-20 00:00:00.000 and nb2.local_date < 2012-06-28 22:15:00.000
+        GROUP by nb.batter,nb.local_date"""
+    )
+
+
+
+    new_transform = SplitColumnTransform()
+    pipeline = Pipeline(stages=[new_transform])
+    model = pipeline.fit(df2)
+    model.transform(df2)
+
 
 if __name__ == "__main__":
     main()
+#
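
The spark.sql call added here will not parse: it has no FROM clause, the nb/nb2 aliases are never defined, and the timestamp literals are unquoted. A hedged sketch against the rolling_avg temp view registered just above, reusing the spark session from main(); the quoted date filter is an assumption about the intended range:

# Hedged sketch only; assumes the aggregation targets the "rolling_avg"
# temp view created above and that timestamps are quoted string literals.
df2 = spark.sql(
    """
    SELECT batter,
           local_date,
           SUM(Hit) AS total_h,
           SUM(atBat) AS total_ab,
           SUM(Hit) / SUM(atBat) AS rolling_avg
    FROM rolling_avg
    WHERE local_date >= '2012-03-20 00:00:00.000'
      AND local_date < '2012-06-28 22:15:00.000'
    GROUP BY batter, local_date
    """
)

Separately, model.transform(df2) returns a transformed DataFrame that is currently discarded; assigning it, e.g. transformed = model.transform(df2), keeps the result for later use.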

requirements.in

Lines changed: 2 additions & 0 deletions
@@ -2,5 +2,7 @@ mariadb
 numpy
 pandas
 plotly
+pyspark
+pyspark-stubs
 scikit-learn
 sqlalchemy

requirements.txt

Lines changed: 8 additions & 0 deletions
@@ -20,6 +20,14 @@ pandas==1.5.0
     # via -r requirements.in
 plotly==5.10.0
     # via -r requirements.in
+py4j==0.10.9
+    # via pyspark
+pyspark==3.0.3
+    # via
+    #   -r requirements.in
+    #   pyspark-stubs
+pyspark-stubs==3.0.0.post3
+    # via -r requirements.in
 python-dateutil==2.8.2
     # via pandas
 pytz==2022.2.1
