33from pyspark .ml .classification import LogisticRegression
44from pyspark .ml .feature import StandardScaler , VectorAssembler
55from pyspark .sql import SparkSession
6+ from pyspark import keyword_only
7+ from pyspark .ml .param .shared import HasInputCols , HasOutputCol
8+ from pyspark .ml .util import DefaultParamsReadable , DefaultParamsWritable
9+ from pyspark .sql .functions import col , concat , lit , split , when
10+ from pyspark .ml import Pipeline , Transformer
11+
12+
class SplitColumnTransform(
    Transformer,
    HasInputCols,
    HasOutputCol,
    DefaultParamsReadable,
    DefaultParamsWritable,
):
    """Custom Spark ML Transformer exposing ``inputCols``/``outputCol`` params.

    NOTE(review): despite the name, the visible ``_transform`` never splits
    or derives any column — it only reads the two params and displays the
    dataset. Confirm the intended column-splitting behavior with the author.
    """

    @keyword_only
    def __init__(self, inputCols=None, outputCol=None):
        """Create the transform.

        :param inputCols: optional list of source column names.
        :param outputCol: optional name for the derived column.
        """
        super(SplitColumnTransform, self).__init__()
        # @keyword_only stores the call's kwargs on _input_kwargs;
        # forward them so Param defaults and explicit values agree.
        kwargs = self._input_kwargs
        self.setParams(**kwargs)

    @keyword_only
    def setParams(self, inputCols=None, outputCol=None):
        """Set ``inputCols``/``outputCol`` from the keyword arguments given."""
        kwargs = self._input_kwargs
        return self._set(**kwargs)

    def _transform(self, dataset):
        """Return the dataset unchanged (after printing it for debugging).

        Bug fix: the original ended with ``return dataset.show()`` —
        ``DataFrame.show()`` prints to stdout and returns ``None``, so the
        transformer produced ``None`` and broke ``PipelineModel.transform``
        for every caller. The debug print-out is kept, but the DataFrame
        itself is now returned as the Transformer contract requires.
        """
        # Read the params exactly as the original did (raises if unset).
        # They are currently unused — TODO confirm intended use.
        input_cols = self.getInputCols()  # noqa: F841
        output_col = self.getOutputCol()  # noqa: F841

        dataset.show()
        return dataset
637
738
839def main ():
@@ -18,19 +49,24 @@ def main():
1849 .enableHiveSupport ()
1950 .getOrCreate ()
2051 )
21-
52+ # "SELECT nb.batter,nb.Hit,nb.atBat,nb.game_id,nb.local_date,"
53+ # "SUM(nb.Hit) AS total_h,SUM(nb.atBat) as total_ab,(SUM(nb.Hit) / SUM(nb.atBat)) AS rolling_avg"
54+ # "FROM new_baseball nb"
55+ # "JOIN new_baseball nb2"
56+ # "on nb.batter"
57+ # "where nb.local_date between nb.local_date -100 and nb2.local_date"
58+ # "GROUP by nb.batter,nb.local_date"
2259 sql = (
23- "SELECT nb.batter,nb.Hit,nb.atBat,nb.game_id,nb.local_date,"
24- "SUM(nb.Hit) AS total_h,SUM(nb.atBat) as total_ab,(SUM(nb.Hit) / SUM(nb.atBat)) AS rolling_avg"
25- "FROM new_baseball nb"
26- "JOIN new_baseball nb2"
27- "on nb.batter"
28- "where nb.local_date between nb.local_date -100 and nb2.local_date"
29- "GROUP by nb.batter,nb.local_date"
60+ """SELECT bc.batter,bc.Hit, bc.atBat,g.game_id, g.local_dateFROM batter_counts bc,
61+ SUM(nb.Hit) AS total_h,SUM(nb.atBat) as total_ab,(SUM(nb.Hit) / SUM(nb.atBat)) AS rolling_avg
62+ JOIN game g
63+ ON g.game_id = bc.game_id
64+ order by bc.batter, bc.game_id"""
65+
3066 )
3167 database = "baseball"
3268 user = "tderig"
33- password = ""
69+ password = "password "
3470 server = "127.0.0.1"
3571 port = 3306
3672 jdbc_url = f"jdbc:mysql://{ server } :{ port } /{ database } ?permitMysqlScheme"
@@ -49,6 +85,21 @@ def main():
4985 df .show (5 )
5086 df .printSchema ()
5187
88+ df .createOrReplaceTempView ("rolling_avg" )
89+ df2 = spark .sql ("""select batter, game_id, SUM(Hit) AS total_h,SUM(nb.atBat)
90+ as total_ab,(SUM(nb.Hit) / SUM(nb.atBat)) AS rolling_avg
91+ where nb.local_date >= 2012-03-20 00:00:00.000 and nb2.local_date < 2012-06-28 22:15:00.000
92+ GROUP by nb.batter,nb.local_date"""
93+ )
94+
95+
96+
97+ new_transform = SplitColumnTransform ()
98+ pipeline = Pipeline (stages = [new_transform ])
99+ model = pipeline .fit (df2 )
100+ model .transform (df2 )
101+
52102
53103if __name__ == "__main__" :
54104 main ()
105+ #
0 commit comments