diff --git a/.gitignore b/.gitignore index 15a05753a0..57fa41912d 100644 --- a/.gitignore +++ b/.gitignore @@ -13,6 +13,12 @@ project/plugins/lib_managed/ project/plugins/src_managed/ /.idea/ /.idea_modules/ +.project +.classpath +.cache-main +.cache-tests +.tmpBin +bin *.iml sonatype.sbt tutorial/data/cofollows.tsv diff --git a/.travis.yml b/.travis.yml index 2a8c092edd..006975af74 100644 --- a/.travis.yml +++ b/.travis.yml @@ -27,63 +27,63 @@ addons: matrix: include: #BASE TESTS - - scala: 2.11.8 - env: BUILD="base" TEST_TARGET="scalding-args scalding-date maple" + - scala: 2.11.11 + env: BUILD="base" TEST_TARGET="scalding-args scalding-date maple scalding-quotation" script: "scripts/run_test.sh" - - scala: 2.12.1 - env: BUILD="base" TEST_TARGET="scalding-args scalding-date maple" + - scala: 2.12.3 + env: BUILD="base" TEST_TARGET="scalding-args scalding-date maple scalding-quotation" script: "scripts/run_test.sh" - - scala: 2.11.8 + - scala: 2.11.11 env: BUILD="base" TEST_TARGET="scalding-avro scalding-hraven scalding-commons scalding-parquet scalding-parquet-cascading scalding-parquet-scrooge scalding-parquet-scrooge-cascading" script: "scripts/run_test.sh" - - scala: 2.12.1 + - scala: 2.12.3 env: BUILD="base" TEST_TARGET="scalding-avro scalding-hraven scalding-commons scalding-parquet scalding-parquet-cascading scalding-parquet-scrooge scalding-parquet-scrooge-cascading" script: "scripts/run_test.sh" - - scala: 2.11.8 + - scala: 2.11.11 env: BUILD="base" TEST_TARGET="scalding-core scalding-jdbc scalding-json scalding-db" script: "scripts/run_test.sh" - - scala: 2.12.1 + - scala: 2.12.3 env: BUILD="base" TEST_TARGET="scalding-core scalding-jdbc scalding-json scalding-db" script: "scripts/run_test.sh" - - scala: 2.11.8 + - scala: 2.11.11 env: BUILD="base" TEST_TARGET="scalding-hadoop-test" script: "scripts/run_test.sh" - - scala: 2.12.1 + - scala: 2.12.3 env: BUILD="base" TEST_TARGET="scalding-hadoop-test" script: "scripts/run_test.sh" - - scala: 2.11.8 + - scala: 2.11.11 env: BUILD="base" TEST_TARGET="scalding-estimators-test" script: "scripts/run_test.sh" - - scala: 2.12.1 + - scala: 2.12.3 env: BUILD="base" TEST_TARGET="scalding-estimators-test" script: "scripts/run_test.sh" - - scala: 2.11.8 + - scala: 2.11.11 env: BUILD="base" TEST_TARGET="scalding-serialization" script: "scripts/run_test.sh" - - scala: 2.12.1 + - scala: 2.12.3 env: BUILD="base" TEST_TARGET="scalding-serialization" script: "scripts/run_test.sh" - - scala: 2.11.8 + - scala: 2.11.11 env: BUILD="base" TEST_TARGET="scalding-thrift-macros" script: "scripts/run_test.sh" - - scala: 2.12.1 + - scala: 2.12.3 env: BUILD="base" TEST_TARGET="scalding-thrift-macros" script: "scripts/run_test.sh" - - scala: 2.11.8 + - scala: 2.11.11 env: BUILD="test tutorials and matrix tutorials and repl" TEST_TARGET="scalding-repl" script: - "scripts/run_test.sh" @@ -92,7 +92,7 @@ matrix: - "scripts/build_assembly_no_test.sh scalding-assembly" - "scripts/test_matrix_tutorials.sh" - - scala: 2.12.1 + - scala: 2.12.3 env: BUILD="test tutorials and matrix tutorials and repl" TEST_TARGET="scalding-repl" script: - "scripts/run_test.sh" @@ -101,7 +101,7 @@ matrix: - "scripts/build_assembly_no_test.sh scalding-assembly" - "scripts/test_matrix_tutorials.sh" - - scala: 2.11.8 + - scala: 2.11.11 env: BUILD="test repl and typed tutorials and microsite" script: - ./sbt ++$TRAVIS_SCALA_VERSION clean docs/makeMicrosite @@ -112,7 +112,7 @@ matrix: - "scripts/build_assembly_no_test.sh execution-tutorial" - "scripts/test_execution_tutorial.sh" - - scala: 2.12.1 + - scala: 2.12.3 env: BUILD="test repl and typed tutorials" script: - "scripts/build_assembly_no_test.sh scalding-repl" diff --git a/build.sbt b/build.sbt index 3b27f140e5..9747676faf 100644 --- a/build.sbt +++ b/build.sbt @@ -21,6 +21,7 @@ val avroVersion = "1.7.4" val bijectionVersion = "0.9.5" val cascadingAvroVersion = "2.1.2" val chillVersion = "0.8.4" +val dagonVersion = "0.2.2" val elephantbirdVersion = "4.15" val hadoopLzoVersion = "0.4.19" val hadoopVersion = "2.6.0" @@ -47,9 +48,9 @@ val printDependencyClasspath = taskKey[Unit]("Prints location of the dependencie val sharedSettings = assemblySettings ++ scalariformSettings ++ Seq( organization := "com.twitter", - scalaVersion := "2.11.8", + scalaVersion := "2.11.11", - crossScalaVersions := Seq(scalaVersion.value, "2.12.1"), + crossScalaVersions := Seq(scalaVersion.value, "2.12.3"), ScalariformKeys.preferences := formattingPreferences, @@ -57,7 +58,9 @@ val sharedSettings = assemblySettings ++ scalariformSettings ++ Seq( javacOptions in doc := Seq("-source", "1.6"), - wartremoverErrors in (Compile, compile) += Wart.OptionPartial, + wartremoverErrors in (Compile, compile) ++= Seq( + Wart.OptionPartial, Wart.ExplicitImplicitTypes, Wart.LeakingSealed, + Wart.Return, Wart.EitherProjectionPartial), libraryDependencies ++= Seq( "org.mockito" % "mockito-all" % "1.8.5" % "test", @@ -214,6 +217,7 @@ lazy val scalding = Project( .aggregate( scaldingArgs, scaldingDate, + scaldingQuotation, scaldingCore, scaldingCommons, scaldingAvro, @@ -242,6 +246,7 @@ lazy val scaldingAssembly = Project( .aggregate( scaldingArgs, scaldingDate, + scaldingQuotation, scaldingCore, scaldingCommons, scaldingAvro, @@ -298,10 +303,8 @@ lazy val scaldingArgs = module("args") lazy val scaldingDate = module("date") -lazy val scaldingGraph = module("graph") - lazy val cascadingVersion = - System.getenv.asScala.getOrElse("SCALDING_CASCADING_VERSION", "3.2.1") + System.getenv.asScala.getOrElse("SCALDING_CASCADING_VERSION", "3.3.0-wip-18") lazy val cascadingJDBCVersion = System.getenv.asScala.getOrElse("SCALDING_CASCADING_JDBC_VERSION", "3.0.0-wip-127") @@ -316,11 +319,19 @@ lazy val scaldingBenchmarks = module("benchmarks") parallelExecution in Test := false ).dependsOn(scaldingCore) +lazy val scaldingQuotation = module("quotation").settings( + libraryDependencies ++= Seq( + "org.scala-lang" % "scala-reflect" % scalaVersion.value % "provided", + "org.scala-lang" % "scala-compiler" % scalaVersion.value % "provided" + ) +) + lazy val scaldingCore = module("core").settings( libraryDependencies ++= Seq( "cascading" % "cascading-core" % cascadingVersion, "cascading" % "cascading-hadoop" % cascadingVersion, "cascading" % "cascading-local" % cascadingVersion, + "com.stripe" %% "dagon-core" % dagonVersion, "com.twitter" % "chill-hadoop" % chillVersion, "com.twitter" % "chill-java" % chillVersion, "com.twitter" %% "chill-bijection" % chillVersion, @@ -337,7 +348,7 @@ lazy val scaldingCore = module("core").settings( "org.slf4j" % "slf4j-api" % slf4jVersion, "org.slf4j" % "slf4j-log4j12" % slf4jVersion % "provided"), addCompilerPlugin("org.scalamacros" % "paradise" % paradiseVersion cross CrossVersion.full) -).dependsOn(scaldingArgs, scaldingDate, scaldingSerialization, maple) +).dependsOn(scaldingArgs, scaldingDate, scaldingSerialization, maple, scaldingQuotation) lazy val scaldingCommons = module("commons").settings( libraryDependencies ++= Seq( diff --git a/project/plugins.sbt b/project/plugins.sbt index be4e74c664..5131951d08 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -18,4 +18,4 @@ addSbtPlugin("com.typesafe.sbt" % "sbt-git" % "0.6.2") addSbtPlugin("com.typesafe.sbt" % "sbt-scalariform" % "1.3.0") addSbtPlugin("org.scoverage" % "sbt-scoverage" % "1.5.0") addSbtPlugin("org.xerial.sbt" % "sbt-sonatype" % "1.0") -addSbtPlugin("org.wartremover" % "sbt-wartremover" % "2.0.2") +addSbtPlugin("org.wartremover" % "sbt-wartremover" % "2.1.1") diff --git a/scalding-commons/src/main/scala/com/twitter/scalding/commons/source/DailySources.scala b/scalding-commons/src/main/scala/com/twitter/scalding/commons/source/DailySources.scala index 21fa7b7e35..2f2dc068b5 100644 --- a/scalding-commons/src/main/scala/com/twitter/scalding/commons/source/DailySources.scala +++ b/scalding-commons/src/main/scala/com/twitter/scalding/commons/source/DailySources.scala @@ -42,7 +42,7 @@ abstract class DailySuffixLzoProtobuf[T <: Message: Manifest](prefix: String, da abstract class DailySuffixMostRecentLzoProtobuf[T <: Message: Manifest](prefix: String, dateRange: DateRange) extends DailySuffixMostRecentSource(prefix, dateRange) with LzoProtobuf[T] { - override def column = manifest[T].erasure + override def column = manifest[T].runtimeClass } abstract class DailySuffixLzoThrift[T <: TBase[_, _]: Manifest](prefix: String, dateRange: DateRange) diff --git a/scalding-commons/src/main/scala/com/twitter/scalding/commons/source/VersionedKeyValSource.scala b/scalding-commons/src/main/scala/com/twitter/scalding/commons/source/VersionedKeyValSource.scala index 28a9b474a4..dd9e81966f 100644 --- a/scalding-commons/src/main/scala/com/twitter/scalding/commons/source/VersionedKeyValSource.scala +++ b/scalding-commons/src/main/scala/com/twitter/scalding/commons/source/VersionedKeyValSource.scala @@ -127,22 +127,20 @@ class VersionedKeyValSource[K, V](val path: String, val sourceVersion: Option[Lo } def sinkExists(mode: Mode): Boolean = - sinkVersion match { - case Some(version) => - mode match { - case Test(buffers) => - buffers(this) map { !_.isEmpty } getOrElse false + sinkVersion.exists { version => + mode match { + case Test(buffers) => + buffers(this) map { !_.isEmpty } getOrElse false - case HadoopTest(conf, buffers) => - buffers(this) map { !_.isEmpty } getOrElse false + case HadoopTest(conf, buffers) => + buffers(this) map { !_.isEmpty } getOrElse false - case m: HadoopMode => - val conf = new JobConf(m.jobConf) - val store = sink.getStore(conf) - store.hasVersion(version) - case _ => sys.error(s"Unknown mode $mode") - } - case None => false + case m: HadoopMode => + val conf = new JobConf(m.jobConf) + val store = sink.getStore(conf) + store.hasVersion(version) + case _ => sys.error(s"Unknown mode $mode") + } } override def createTap(readOrWrite: AccessMode)(implicit mode: Mode): Tap[_, _, _] = { diff --git a/scalding-core/src/main/scala/com/twitter/package.scala b/scalding-core/src/main/scala/com/twitter/package.scala index eae1c53cb3..18ddfcacac 100644 --- a/scalding-core/src/main/scala/com/twitter/package.scala +++ b/scalding-core/src/main/scala/com/twitter/package.scala @@ -37,7 +37,7 @@ package object scalding { val scaldingVersion: String = "0.17.2" object RichPathFilter { - implicit def toRichPathFilter(f: PathFilter) = new RichPathFilter(f) + implicit def toRichPathFilter(f: PathFilter): RichPathFilter = new RichPathFilter(f) } class RichPathFilter(f: PathFilter) { diff --git a/scalding-core/src/main/scala/com/twitter/scalding/ArgHelp.scala b/scalding-core/src/main/scala/com/twitter/scalding/ArgHelp.scala index 5e259bd755..94c21a0ae4 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/ArgHelp.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/ArgHelp.scala @@ -5,10 +5,10 @@ sealed trait DescribedArg { def description: String } -case class RequiredArg(key: String, description: String) extends DescribedArg -case class OptionalArg(key: String, description: String) extends DescribedArg -case class ListArg(key: String, description: String) extends DescribedArg -case class BooleanArg(key: String, description: String) extends DescribedArg +final case class RequiredArg(key: String, description: String) extends DescribedArg +final case class OptionalArg(key: String, description: String) extends DescribedArg +final case class ListArg(key: String, description: String) extends DescribedArg +final case class BooleanArg(key: String, description: String) extends DescribedArg class HelpException extends RuntimeException("User asked for help") class DescriptionValidationException(msg: String) extends RuntimeException(msg) @@ -119,4 +119,4 @@ trait ArgHelper { } } -object ArgHelp extends ArgHelper \ No newline at end of file +object ArgHelp extends ArgHelper diff --git a/scalding-core/src/main/scala/com/twitter/scalding/Config.scala b/scalding-core/src/main/scala/com/twitter/scalding/Config.scala index 0274d24b6b..571d279007 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/Config.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/Config.scala @@ -34,11 +34,12 @@ import java.net.URI import scala.collection.JavaConverters._ import scala.util.{ Failure, Success, Try } +import com.twitter.scalding.serialization.RequireOrderedSerializationMode /** * This is a wrapper class on top of Map[String, String] */ -trait Config extends Serializable { +abstract class Config extends Serializable { import Config._ // get the constants def toMap: Map[String, String] @@ -101,8 +102,10 @@ trait Config extends Serializable { * is used to create the Class.forName */ def getCascadingAppJar: Option[Try[Class[_]]] = - get(AppProps.APP_JAR_CLASS).map { str => - // The Class[_] messes up using Try(Class.forName(str)) on scala 2.9.3 + getClassForKey(AppProps.APP_JAR_CLASS) + + def getClassForKey(k: String): Option[Try[Class[_]]] = + get(k).map { str => try { Success( // Make sure we are using the class-loader for the current thread @@ -110,6 +113,7 @@ trait Config extends Serializable { } catch { case err: Throwable => Failure(err) } } + /* * Used in joins to determine how much of the "right hand side" of * the join to keep in memory @@ -135,17 +139,30 @@ trait Config extends Serializable { def setMapSideAggregationThreshold(count: Int): Config = this + (AggregateByProps.AGGREGATE_BY_CAPACITY -> count.toString) + @deprecated("Use setRequireOrderedSerializationMode", "12/14/17") + def setRequireOrderedSerialization(b: Boolean): Config = + this + (ScaldingRequireOrderedSerialization -> (b.toString)) + + @deprecated("Use getRequireOrderedSerializationMode", "12/14/17") + def getRequireOrderedSerialization: Boolean = + getRequireOrderedSerializationMode == Some(RequireOrderedSerializationMode.Fail) + /** * Set this configuration option to require all grouping/cogrouping * to use OrderedSerialization */ - def setRequireOrderedSerialization(b: Boolean): Config = - this + (ScaldingRequireOrderedSerialization -> (b.toString)) + def setRequireOrderedSerializationMode(r: Option[RequireOrderedSerializationMode]): Config = + r.map { + v => this + (ScaldingRequireOrderedSerialization -> (v.toString)) + }.getOrElse(this) - def getRequireOrderedSerialization: Boolean = + def getRequireOrderedSerializationMode: Option[RequireOrderedSerializationMode] = get(ScaldingRequireOrderedSerialization) - .map(_.toBoolean) - .getOrElse(false) + .map(_.toLowerCase()).collect { + case "true" => RequireOrderedSerializationMode.Fail // backwards compatibility + case "fail" => RequireOrderedSerializationMode.Fail + case "log" => RequireOrderedSerializationMode.Log + } def getCascadingSerializationTokens: Map[Int, String] = get(Config.CascadingSerializationTokens) @@ -239,6 +256,19 @@ trait Config extends Serializable { def setDefaultComparator(clazz: Class[_ <: java.util.Comparator[_]]): Config = this + (FlowProps.DEFAULT_ELEMENT_COMPARATOR -> clazz.getName) + def getOptimizationPhases: Option[Try[typed.OptimizationPhases]] = + getClassForKey(Config.OptimizationPhases).map { tryClass => + tryClass.flatMap { clazz => + Try(clazz.newInstance().asInstanceOf[typed.OptimizationPhases]) + } + } + + def setOptimizationPhases(clazz: Class[_ <: typed.OptimizationPhases]): Config = + setOptimizationPhasesFromName(clazz.getName) + + def setOptimizationPhasesFromName(className: String): Config = + this + (Config.OptimizationPhases -> className) + def getScaldingVersion: Option[String] = get(Config.ScaldingVersion) def setScaldingVersion: Config = (this.+(Config.ScaldingVersion -> scaldingVersion)).+( @@ -391,6 +421,18 @@ trait Config extends Serializable { def getHashJoinAutoForceRight: Boolean = get(HashJoinAutoForceRight) + .map(_.toBoolean) + .getOrElse(true) // cascading3 seems to currently require this + + def setConvertHashJoinToShuffleJoin(b: Boolean): Config = + this + (Config.HashToShuffleJoin -> (b.toString)) + + /** + * Cascading 3 has in the past had issues with hashJoins. + * If your plan fails, you may try with this option set. + */ + def getConvertHashJoinToShuffleJoin: Boolean = + get(Config.HashToShuffleJoin) .map(_.toBoolean) .getOrElse(false) @@ -402,6 +444,19 @@ trait Config extends Serializable { def setVerboseFileSourceLogging(b: Boolean): Config = this + (VerboseFileSourceLoggingKey -> b.toString) + def getSkipNullCounters: Boolean = + get(SkipNullCounters) + .map(_.toBoolean) + .getOrElse(false) + + /** + * If this is true, on hadoop, when we get a null Counter + * for a given name, we just ignore the counter instead + * of NPE + */ + def setSkipNullCounters(boolean: Boolean): Config = + this + (SkipNullCounters -> boolean.toString) + override def hashCode = toMap.hashCode override def equals(that: Any) = that match { case thatConf: Config => toMap == thatConf.toMap @@ -424,12 +479,14 @@ object Config { val ScaldingJobArgs: String = "scalding.job.args" val ScaldingJobArgsSerialized: String = "scalding.job.argsserialized" val ScaldingVersion: String = "scalding.version" + val SkipNullCounters: String = "scalding.counters.skipnull" val HRavenHistoryUserName: String = "hraven.history.user.name" val ScaldingRequireOrderedSerialization: String = "scalding.require.orderedserialization" val FlowListeners: String = "scalding.observability.flowlisteners" val FlowStepListeners: String = "scalding.observability.flowsteplisteners" val FlowStepStrategies: String = "scalding.strategies.flowstepstrategies" - val VerboseFileSourceLoggingKey = "scalding.filesource.verbose.logging" + val VerboseFileSourceLoggingKey: String = "scalding.filesource.verbose.logging" + val OptimizationPhases: String = "scalding.optimization.phases" /** * Parameter that actually controls the number of reduce tasks. @@ -467,12 +524,17 @@ object Config { /** * Parameter that can be used to determine behavior on the rhs of a hashJoin. - * If true, we try to guess when to auto force to disk before a hashJoin - * else (the default) we don't try to infer this and the behavior can be dictated by the user manually + * If true (the default), we try to guess when to auto force to disk before a hashJoin + * else we don't try to infer this and the behavior can be dictated by the user manually * calling forceToDisk on the rhs or not as they wish. + * + * Note, cascading3 seems to currently require this behavior, so disable at your own + * risk */ val HashJoinAutoForceRight: String = "scalding.hashjoin.autoforceright" + val HashToShuffleJoin: String = "scalding.hashjoin.convertshuffle" + val empty: Config = Config(Map.empty) /* @@ -553,7 +615,7 @@ object Config { * Either union these two, or return the keys that overlap */ def disjointUnion[K >: String, V >: String](m: Map[K, V], conf: Config): Either[Set[String], Map[K, V]] = { - val asMap = conf.toMap.toMap[K, V] // linter:ignore we are upcasting K, V + val asMap = conf.toMap.toMap[K, V] // linter:disable:TypeToType // we are upcasting K, V val duplicateKeys = (m.keySet & asMap.keySet) if (duplicateKeys.isEmpty) Right(m ++ asMap) else Left(conf.toMap.keySet.filter(duplicateKeys(_))) // make sure to return Set[String], and not cast @@ -562,7 +624,7 @@ object Config { * This overwrites any keys in m that exist in config. */ def overwrite[K >: String, V >: String](m: Map[K, V], conf: Config): Map[K, V] = - m ++ (conf.toMap.toMap[K, V]) // linter:ignore we are upcasting K, V + m ++ (conf.toMap.toMap[K, V]) // linter:disable:TypeToType // we are upcasting K, V /* * Note that Hadoop Configuration is mutable, but Config is not. So a COPY is diff --git a/scalding-core/src/main/scala/com/twitter/scalding/Execution.scala b/scalding-core/src/main/scala/com/twitter/scalding/Execution.scala index e83badbcb9..d98e745d9a 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/Execution.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/Execution.scala @@ -16,10 +16,12 @@ limitations under the License. package com.twitter.scalding import cascading.flow.{ FlowDef, Flow } +import com.stripe.dagon.{ Dag, Id, Rule, HMap } import com.twitter.algebird.monad.Trampoline import com.twitter.algebird.{ Monoid, Monad, Semigroup } import com.twitter.scalding.cascading_interop.FlowListenerPromise import com.twitter.scalding.filecache.{CachedFile, DistributedCacheFile} +import com.twitter.scalding.typed.functions.{ ConsList, ReverseList } import com.twitter.scalding.typed.cascading_backend.AsyncFlowDefRunner import java.util.UUID import scala.collection.mutable @@ -334,7 +336,7 @@ object Execution { getOrElseInsertWithFeedback(cfg, ex, res)._2 } - private case class FutureConst[T](get: ConcurrentExecutionContext => Future[T]) extends Execution[T] { + private final case class FutureConst[T](get: ConcurrentExecutionContext => Future[T]) extends Execution[T] { protected def runStats(conf: Config, mode: Mode, cache: EvalCache)(implicit cec: ConcurrentExecutionContext) = Trampoline(cache.getOrElseInsert(conf, this, for { @@ -345,7 +347,7 @@ object Execution { // Note that unit is not optimized away, since Futures are often used with side-effects, so, // we ensure that get is always called in contrast to Mapped, which assumes that fn is pure. } - private case class FlatMapped[S, T](prev: Execution[S], fn: S => Execution[T]) extends Execution[T] { + private final case class FlatMapped[S, T](prev: Execution[S], fn: S => Execution[T]) extends Execution[T] { protected def runStats(conf: Config, mode: Mode, cache: EvalCache)(implicit cec: ConcurrentExecutionContext) = Trampoline.call(prev.runStats(conf, mode, cache)).map { fut1 => cache.getOrElseInsert(conf, this, @@ -357,7 +359,7 @@ object Execution { } } - private case class Mapped[S, T](prev: Execution[S], fn: S => T) extends Execution[T] { + private final case class Mapped[S, T](prev: Execution[S], fn: S => T) extends Execution[T] { protected def runStats(conf: Config, mode: Mode, cache: EvalCache)(implicit cec: ConcurrentExecutionContext) = Trampoline.call(prev.runStats(conf, mode, cache)).map { fut => cache.getOrElseInsert(conf, this, @@ -365,7 +367,7 @@ object Execution { } } - private case class GetCounters[T](prev: Execution[T]) extends Execution[(T, ExecutionCounters)] { + private final case class GetCounters[T](prev: Execution[T]) extends Execution[(T, ExecutionCounters)] { protected def runStats(conf: Config, mode: Mode, cache: EvalCache)(implicit cec: ConcurrentExecutionContext) = Trampoline.call(prev.runStats(conf, mode, cache)).map { fut => cache.getOrElseInsert(conf, this, @@ -376,7 +378,7 @@ object Execution { }) } } - private case class ResetCounters[T](prev: Execution[T]) extends Execution[T] { + private final case class ResetCounters[T](prev: Execution[T]) extends Execution[T] { protected def runStats(conf: Config, mode: Mode, cache: EvalCache)(implicit cec: ConcurrentExecutionContext) = Trampoline.call(prev.runStats(conf, mode, cache)).map { fut => cache.getOrElseInsert(conf, this, @@ -384,7 +386,7 @@ object Execution { } } - private case class TransformedConfig[T](prev: Execution[T], fn: Config => Config) extends Execution[T] { + private final case class TransformedConfig[T](prev: Execution[T], fn: Config => Config) extends Execution[T] { protected def runStats(conf: Config, mode: Mode, cache: EvalCache)(implicit cec: ConcurrentExecutionContext) = { val mutatedConfig = fn(conf) Trampoline.call(prev.runStats(mutatedConfig, mode, cache)) @@ -403,14 +405,14 @@ object Execution { * We operate here by getting a copy of the super EvalCache, without its cache's. * This is so we can share the singleton thread for scheduling jobs against Cascading. */ - private case class WithNewCache[T](prev: Execution[T]) extends Execution[T] { + private final case class WithNewCache[T](prev: Execution[T]) extends Execution[T] { protected def runStats(conf: Config, mode: Mode, cache: EvalCache)(implicit cec: ConcurrentExecutionContext) = { val ec = cache.cleanCache Trampoline.call(prev.runStats(conf, mode, ec)) } } - private case class OnComplete[T](prev: Execution[T], fn: Try[T] => Unit) extends Execution[T] { + private final case class OnComplete[T](prev: Execution[T], fn: Try[T] => Unit) extends Execution[T] { protected def runStats(conf: Config, mode: Mode, cache: EvalCache)(implicit cec: ConcurrentExecutionContext) = Trampoline.call(prev.runStats(conf, mode, cache)).map { res => cache.getOrElseInsert(conf, this, { @@ -432,7 +434,7 @@ object Execution { } } - private case class RecoverWith[T](prev: Execution[T], fn: PartialFunction[Throwable, Execution[T]]) extends Execution[T] { + private final case class RecoverWith[T](prev: Execution[T], fn: PartialFunction[Throwable, Execution[T]]) extends Execution[T] { protected def runStats(conf: Config, mode: Mode, cache: EvalCache)(implicit cec: ConcurrentExecutionContext) = Trampoline.call(prev.runStats(conf, mode, cache)).map { fut => cache.getOrElseInsert(conf, this, @@ -507,7 +509,7 @@ object Execution { } } - private case class Zipped[S, T](one: Execution[S], two: Execution[T]) extends Execution[(S, T)] { + private final case class Zipped[S, T](one: Execution[S], two: Execution[T]) extends Execution[(S, T)] { protected def runStats(conf: Config, mode: Mode, cache: EvalCache)(implicit cec: ConcurrentExecutionContext) = for { f1 <- Trampoline.call(one.runStats(conf, mode, cache)) @@ -518,7 +520,7 @@ object Execution { .map { case ((s, ss), (t, st)) => ((s, t), ss ++ st) }) } } - private case class UniqueIdExecution[T](fn: UniqueID => Execution[T]) extends Execution[T] { + private final case class UniqueIdExecution[T](fn: UniqueID => Execution[T]) extends Execution[T] { protected def runStats(conf: Config, mode: Mode, cache: EvalCache)(implicit cec: ConcurrentExecutionContext) = { Trampoline(cache.getOrElseInsert(conf, this, { val (uid, nextConf) = conf.ensureUniqueId @@ -529,7 +531,7 @@ object Execution { /* * This allows you to run any cascading flowDef as an Execution. */ - private case class FlowDefExecution(result: (Config, Mode) => FlowDef) extends Execution[Unit] { + private final case class FlowDefExecution(result: (Config, Mode) => FlowDef) extends Execution[Unit] { protected def runStats(conf: Config, mode: Mode, cache: EvalCache)(implicit cec: ConcurrentExecutionContext) = { lazy val future = { cache.writer match { @@ -553,9 +555,35 @@ object Execution { sealed trait ToWrite object ToWrite { - case class Force[T](pipe: TypedPipe[T]) extends ToWrite - case class ToIterable[T](pipe: TypedPipe[T]) extends ToWrite - case class SimpleWrite[T](pipe: TypedPipe[T], sink: TypedSink[T]) extends ToWrite + final case class Force[T](pipe: TypedPipe[T]) extends ToWrite + final case class ToIterable[T](pipe: TypedPipe[T]) extends ToWrite + final case class SimpleWrite[T](pipe: TypedPipe[T], sink: TypedSink[T]) extends ToWrite + + /** + * Optimize these writes into new writes and provide a mapping from + * the original TypedPipe to the new TypedPipe + */ + def optimizeWriteBatch(writes: List[ToWrite], rules: Seq[Rule[TypedPipe]]): HMap[TypedPipe, TypedPipe] = { + val dag = Dag.empty(typed.OptimizationRules.toLiteral) + val (d1, ws) = writes.foldLeft((dag, List.empty[Id[_]])) { + case ((dag, ws), Force(p)) => + val (d1, id) = dag.addRoot(p) + (d1, id :: ws) + case ((dag, ws), ToIterable(p)) => + val (d1, id) = dag.addRoot(p) + (d1, id :: ws) + case ((dag, ws), SimpleWrite(p, sink)) => + val (d1, id) = dag.addRoot(p) + (d1, id :: ws) + } + // now we optimize the graph + val d2 = d1.applySeq(rules) + // convert back to TypedPipe: + ws.foldLeft(HMap.empty[TypedPipe, TypedPipe]) { + case (cache, id) => + cache + (d1.evaluate(id) -> d2.evaluate(id)) + } + } } /** @@ -608,7 +636,7 @@ object Execution { * are based on on this one. By keeping the Pipe and the Sink, can inspect the Execution * DAG and optimize it later (a goal, but not done yet). */ - private case class WriteExecution[T]( + private final case class WriteExecution[T]( head: ToWrite, tail: List[ToWrite], result: ((Config, Mode, Writer, ConcurrentExecutionContext)) => Future[T]) extends Execution[T] { @@ -860,6 +888,9 @@ object Execution { ex: Execution[E]): Execution[(A, B, C, D, E)] = ax.zip(bx).zip(cx).zip(dx).zip(ex).map { case ((((a, b), c), d), e) => (a, b, c, d, e) } + // Avoid recreating the empty Execution each time + private val nil = from(Nil) + /* * If you have many Executions, it is better to combine them with * zip than flatMap (which is sequential). sequence just calls @@ -869,24 +900,14 @@ object Execution { * these executions are executed in parallel: run is called on all at the * same time, not one after the other. */ - private case class SequencingFn[T]() extends Function1[(T, List[T]), List[T]] { - def apply(results: (T, List[T])) = results match { - case (y, ys) => y :: ys - } - } - private case class ReversingFn[T]() extends Function1[List[T], List[T]] { - def apply(results: List[T]) = results.reverse - } - // Avoid recreating the empty Execution each time - private val nil = from(Nil) def sequence[T](exs: Seq[Execution[T]]): Execution[Seq[T]] = { @annotation.tailrec def go(xs: List[Execution[T]], acc: Execution[List[T]]): Execution[List[T]] = xs match { case Nil => acc - case h :: tail => go(tail, h.zip(acc).map(SequencingFn())) + case h :: tail => go(tail, h.zip(acc).map(ConsList())) } // This pushes all of them onto a list, and then reverse to keep order - go(exs.toList, nil).map(ReversingFn()) + go(exs.toList, nil).map(ReverseList()) } /** diff --git a/scalding-core/src/main/scala/com/twitter/scalding/ExecutionContext.scala b/scalding-core/src/main/scala/com/twitter/scalding/ExecutionContext.scala index 2022613d65..68a66f4d8a 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/ExecutionContext.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/ExecutionContext.scala @@ -73,10 +73,11 @@ trait ExecutionContext { // identify the flowDef val configWithId = config.addUniqueId(UniqueID.getIDFor(flowDef)) val flow = mode.newFlowConnector(configWithId).connect(flowDef) - if (config.getRequireOrderedSerialization) { + + config.getRequireOrderedSerializationMode.map { mode => // This will throw, but be caught by the outer try if // we have groupby/cogroupby not using OrderedSerializations - CascadingBinaryComparator.checkForOrderedSerialization(flow).get + CascadingBinaryComparator.checkForOrderedSerialization(flow, mode).get } flow match { diff --git a/scalding-core/src/main/scala/com/twitter/scalding/FieldConversions.scala b/scalding-core/src/main/scala/com/twitter/scalding/FieldConversions.scala index 9b91675160..bcc02e4ce1 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/FieldConversions.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/FieldConversions.scala @@ -251,10 +251,10 @@ sealed trait Field[T] extends java.io.Serializable { } @DefaultSerializer(classOf[serialization.IntFieldSerializer]) -case class IntField[T](override val id: java.lang.Integer)(implicit override val ord: Ordering[T], override val mf: Option[Manifest[T]]) extends Field[T] +final case class IntField[T](override val id: java.lang.Integer)(implicit override val ord: Ordering[T], override val mf: Option[Manifest[T]]) extends Field[T] @DefaultSerializer(classOf[serialization.StringFieldSerializer]) -case class StringField[T](override val id: String)(implicit override val ord: Ordering[T], override val mf: Option[Manifest[T]]) extends Field[T] +final case class StringField[T](override val id: String)(implicit override val ord: Ordering[T], override val mf: Option[Manifest[T]]) extends Field[T] object Field { def apply[T](index: Int)(implicit ord: Ordering[T], mf: Manifest[T]) = IntField[T](index)(ord, Some(mf)) diff --git a/scalding-core/src/main/scala/com/twitter/scalding/FileSource.scala b/scalding-core/src/main/scala/com/twitter/scalding/FileSource.scala index a4cde45bcc..1aff689718 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/FileSource.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/FileSource.scala @@ -231,7 +231,13 @@ abstract class FileSource extends SchemedSource with LocalSourceOverride with Hf * TODO: consider writing a more in-depth version of this method in [[TimePathedSource]] that looks for * TODO: missing days / hours etc. */ - protected def pathIsGood(globPattern: String, conf: Configuration) = FileSource.globHasNonHiddenPaths(globPattern, conf) + protected def pathIsGood(globPattern: String, conf: Configuration) = { + if (conf.getBoolean("scalding.require_success_file", false)) { + FileSource.allGlobFilesWithSuccess(globPattern, conf, true) + } else { + FileSource.globHasNonHiddenPaths(globPattern, conf) + } + } def hdfsPaths: Iterable[String] // By default, we write to the LAST path returned by hdfsPaths @@ -525,9 +531,12 @@ object TextLine { new TextLine(p, sm, textEncoding) } -class TextLine(p: String, override val sinkMode: SinkMode, override val textEncoding: String) extends FixedPathSource(p) with TextLineScheme { +class TextLine(p: String, override val sinkMode: SinkMode, override val textEncoding: String) extends FixedPathSource(p) with TextLineScheme with TypedSink[String] { // For some Java interop + def this(p: String) = this(p, TextLine.defaultSinkMode, TextLine.defaultTextEncoding) + + override def setter[U <: String] = TupleSetter.asSubSetter[String, U](TupleSetter.of[String]) } /** diff --git a/scalding-core/src/main/scala/com/twitter/scalding/GroupBuilder.scala b/scalding-core/src/main/scala/com/twitter/scalding/GroupBuilder.scala index 818e5f07b0..9677cad7c7 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/GroupBuilder.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/GroupBuilder.scala @@ -54,14 +54,14 @@ class GroupBuilder(val groupFields: Fields) extends FoldOperations[GroupBuilder] private def getNextMiddlefield: String = { val out = "__middlefield__" + maxMF.toString maxMF += 1 - return out + out } private def tryAggregateBy(ab: AggregateBy, ev: Pipe => Every): Boolean = { // Concat if there if not none reds = reds.map(rl => ab :: rl) evs = ev :: evs - return !reds.isEmpty + reds.nonEmpty } /** diff --git a/scalding-core/src/main/scala/com/twitter/scalding/HfsConfPropertySetter.scala b/scalding-core/src/main/scala/com/twitter/scalding/HfsConfPropertySetter.scala index f3fa41f7e6..6c547cadfb 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/HfsConfPropertySetter.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/HfsConfPropertySetter.scala @@ -55,7 +55,7 @@ private[scalding] class ConfPropertiesHfsTap( * Changes here however will not show up in the hadoop UI */ trait HfsConfPropertySetter extends HfsTapProvider { - @deprecated("Tap config is deprecated, use sourceConfig or sinkConfig directly. In cascading configs applied to sinks can leak to sources in the step writing to the sink.") + @deprecated("Tap config is deprecated, use sourceConfig or sinkConfig directly. In cascading configs applied to sinks can leak to sources in the step writing to the sink.", "0.17.0") def tapConfig: Config = Config.empty def sourceConfig: Config = Config.empty diff --git a/scalding-core/src/main/scala/com/twitter/scalding/Job.scala b/scalding-core/src/main/scala/com/twitter/scalding/Job.scala index aa39842a75..d427093269 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/Job.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/Job.scala @@ -194,7 +194,7 @@ class Job(val args: Args) extends FieldConversions with java.io.Serializable { .setScaldingFlowClass(getClass) .setArgs(args) .maybeSetSubmittedTimestamp()._2 - .toMap.toMap[AnyRef, AnyRef] // linter:ignore the second one is to lift from String -> AnyRef + .toMap.toMap[AnyRef, AnyRef] // linter:disable:TypeToType // the second one is to lift from String -> AnyRef } private def reflectedClasses: Set[Class[_]] = @@ -271,8 +271,7 @@ class Job(val args: Args) extends FieldConversions with java.io.Serializable { } // Print custom counters unless --scalding.nocounters is used or there are no custom stats if (!args.boolean("scalding.nocounters")) { - implicit val statProvider = statsData - val jobStats = Stats.getAllCustomCounters + val jobStats = Stats.getAllCustomCounters()(statsData) if (!jobStats.isEmpty) { println("Dumping custom counters:") jobStats.foreach { diff --git a/scalding-core/src/main/scala/com/twitter/scalding/Mode.scala b/scalding-core/src/main/scala/com/twitter/scalding/Mode.scala index 0b89259db0..51736f1c12 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/Mode.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/Mode.scala @@ -214,6 +214,13 @@ case class Hdfs(strict: Boolean, @transient conf: Configuration) extends HadoopM } } +object Hdfs { + /** + * Make an Hdfs instance in strict mode with new Configuration + */ + def default: Hdfs = Hdfs(true, new Configuration) +} + case class HadoopTest(@transient conf: Configuration, @transient buffers: Source => Option[Buffer[Tuple]]) extends HadoopMode with TestMode { diff --git a/scalding-core/src/main/scala/com/twitter/scalding/Operations.scala b/scalding-core/src/main/scala/com/twitter/scalding/Operations.scala index 52829654a5..5ba6f893e0 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/Operations.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/Operations.scala @@ -316,7 +316,7 @@ package com.twitter.scalding { } } - class SummingMapsideCache[K, V](flowProcess: FlowProcess[_], summingCache: SummingWithHitsCache[K, V]) + final class SummingMapsideCache[K, V](flowProcess: FlowProcess[_], summingCache: SummingWithHitsCache[K, V]) extends MapsideCache[K, V] { private[this] val misses = CounterImpl(flowProcess, StatKey(MapsideReduce.COUNTER_GROUP, "misses")) private[this] val hits = CounterImpl(flowProcess, StatKey(MapsideReduce.COUNTER_GROUP, "hits")) @@ -349,7 +349,7 @@ package com.twitter.scalding { } } - class AdaptiveMapsideCache[K, V](flowProcess: FlowProcess[_], adaptiveCache: AdaptiveCache[K, V]) + final class AdaptiveMapsideCache[K, V](flowProcess: FlowProcess[_], adaptiveCache: AdaptiveCache[K, V]) extends MapsideCache[K, V] { private[this] val misses = CounterImpl(flowProcess, StatKey(MapsideReduce.COUNTER_GROUP, "misses")) private[this] val hits = CounterImpl(flowProcess, StatKey(MapsideReduce.COUNTER_GROUP, "hits")) @@ -698,4 +698,25 @@ package com.twitter.scalding { } } } + + /** + * This gets a pair out of a tuple, incruments the counters with the left, and passes the value + * on + */ + class IncrementCounters[A](pass: Fields, conv: TupleConverter[(A, Iterable[((String, String), Long)])]) + extends BaseOperation[Any](pass) + with Function[Any] { + + override def operate(flowProcess: FlowProcess[_], functionCall: FunctionCall[Any]): Unit = { + val (a, inc) = conv(functionCall.getArguments) + val iter = inc.iterator + while (iter.hasNext) { + val ((k1, k2), amt) = iter.next + flowProcess.increment(k1, k2, amt) + } + val tup = Tuple.size(1) + tup.set(0, a) + functionCall.getOutputCollector.add(tup) + } + } } diff --git a/scalding-core/src/main/scala/com/twitter/scalding/SkewReplication.scala b/scalding-core/src/main/scala/com/twitter/scalding/SkewReplication.scala index 530a3f0313..3f745ff190 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/SkewReplication.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/SkewReplication.scala @@ -35,7 +35,7 @@ sealed abstract class SkewReplication { /** * See https://github.com/twitter/scalding/pull/229#issuecomment-10773810 */ -case class SkewReplicationA(replicationFactor: Int = 1) extends SkewReplication { +final case class SkewReplicationA(replicationFactor: Int = 1) extends SkewReplication { override def getReplications(leftCount: Int, rightCount: Int, reducers: Int) = { val numReducers = if (reducers <= 0) DEFAULT_NUM_REDUCERS else reducers @@ -52,7 +52,7 @@ case class SkewReplicationA(replicationFactor: Int = 1) extends SkewReplication /** * See https://github.com/twitter/scalding/pull/229#issuecomment-10792296 */ -case class SkewReplicationB(maxKeysInMemory: Int = 1E6.toInt, maxReducerOutput: Int = 1E7.toInt) +final case class SkewReplicationB(maxKeysInMemory: Int = 1E6.toInt, maxReducerOutput: Int = 1E7.toInt) extends SkewReplication { override def getReplications(leftCount: Int, rightCount: Int, reducers: Int) = { @@ -64,4 +64,4 @@ case class SkewReplicationB(maxKeysInMemory: Int = 1E6.toInt, maxReducerOutput: (left, if (right == 0) 1 else right) } -} \ No newline at end of file +} diff --git a/scalding-core/src/main/scala/com/twitter/scalding/Stats.scala b/scalding-core/src/main/scala/com/twitter/scalding/Stats.scala index 6893e9533c..b212a8182b 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/Stats.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/Stats.scala @@ -4,6 +4,7 @@ import cascading.flow.{ Flow, FlowListener, FlowDef, FlowProcess } import cascading.flow.hadoop.HadoopFlowProcess import cascading.stats.CascadingStats import java.util.concurrent.ConcurrentHashMap +import org.apache.hadoop.mapreduce.Counter import org.slf4j.{ Logger, LoggerFactory } import scala.collection.JavaConverters._ import scala.collection.mutable @@ -59,13 +60,27 @@ sealed private[scalding] trait CounterImpl { def increment(amount: Long): Unit } -private[scalding] case class GenericFlowPCounterImpl(fp: FlowProcess[_], statKey: StatKey) extends CounterImpl { +private[scalding] final case class GenericFlowPCounterImpl(fp: FlowProcess[_], statKey: StatKey) extends CounterImpl { override def increment(amount: Long): Unit = fp.increment(statKey.group, statKey.counter, amount) } -private[scalding] case class HadoopFlowPCounterImpl(fp: HadoopFlowProcess, statKey: StatKey) extends CounterImpl { - private[this] val cntr = fp.getReporter().getCounter(statKey.group, statKey.counter) - override def increment(amount: Long): Unit = cntr.increment(amount) +private[scalding] final case class HadoopFlowPCounterImpl(fp: HadoopFlowProcess, statKey: StatKey) extends CounterImpl { + // we use a nullable type here for efficiency + private[this] val counter: Counter = (for { + r <- Option(fp.getReporter) + c <- Option(r.getCounter(statKey.group, statKey.counter)) + } yield c).orNull + + def skipNull: Boolean = + fp.getProperty(Config.SkipNullCounters) match { + case null => false // by default don't skip + case isset => isset.toString.toBoolean + } + + require((counter != null) || skipNull, s"counter for $statKey is null and ${Config.SkipNullCounters} is not set to true") + + override def increment(amount: Long): Unit = + if (counter != null) counter.increment(amount) else () } object Stat { diff --git a/scalding-core/src/main/scala/com/twitter/scalding/Tool.scala b/scalding-core/src/main/scala/com/twitter/scalding/Tool.scala index e54fad92ba..947e8f1613 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/Tool.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/Tool.scala @@ -124,7 +124,8 @@ class Tool extends Configured with HTool { j.clear() //When we get here, the job is finished if (successful) { - j.next match { + // we need to use match not foreach to get tail recursion + j.next match { // linter:disable:UseOptionForeachNotPatMatch case Some(nextj) => start(nextj, cnt + 1) case None => () } diff --git a/scalding-core/src/main/scala/com/twitter/scalding/bdd/PipeOperationsConversions.scala b/scalding-core/src/main/scala/com/twitter/scalding/bdd/PipeOperationsConversions.scala index ce7d064572..bb37ee95b0 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/bdd/PipeOperationsConversions.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/bdd/PipeOperationsConversions.scala @@ -22,13 +22,13 @@ trait PipeOperationsConversions { class TwoPipesOperation(op: (RichPipe, Pipe) => RichPipe) extends PipeOperation { def apply(pipes: List[RichPipe]): Pipe = { - assertPipeSize(pipes, 2); op(pipes(0), pipes(1)) // linter:ignore + assertPipeSize(pipes, 2); op(pipes(0), pipes(1)) // linter:disable } } class ThreePipesOperation(op: (RichPipe, RichPipe, RichPipe) => Pipe) extends PipeOperation { def apply(pipes: List[RichPipe]): Pipe = { - assertPipeSize(pipes, 3); op(pipes(0), pipes(1), pipes(2)) // linter:ignore + assertPipeSize(pipes, 3); op(pipes(0), pipes(1), pipes(2)) // linter:disable } } @@ -37,7 +37,7 @@ trait PipeOperationsConversions { } class ListPipesOperation(op: List[Pipe] => Pipe) extends PipeOperation { - def apply(pipes: List[RichPipe]): Pipe = op(pipes.map(_.pipe).toList) + def apply(pipes: List[RichPipe]): Pipe = op(pipes.map(_.pipe)) } implicit val fromSingleRichPipeFunctionToOperation: (RichPipe => RichPipe) => OnePipeOperation = (op: RichPipe => RichPipe) => new OnePipeOperation(op(_).pipe) diff --git a/scalding-core/src/main/scala/com/twitter/scalding/bdd/TypedPipeOperationsConversions.scala b/scalding-core/src/main/scala/com/twitter/scalding/bdd/TypedPipeOperationsConversions.scala index 409019f1fb..50926261c1 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/bdd/TypedPipeOperationsConversions.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/bdd/TypedPipeOperationsConversions.scala @@ -26,7 +26,7 @@ trait TypedPipeOperationsConversions { override def apply(pipes: List[TypedPipe[_]]): TypedPipe[TypeOut] = { assertPipeSize(pipes, 2) op( - pipes(0).asInstanceOf[TypedPipe[TypeIn1]], // linter:ignore + pipes(0).asInstanceOf[TypedPipe[TypeIn1]], // linter:disable pipes(1).asInstanceOf[TypedPipe[TypeIn2]]) } } @@ -35,7 +35,7 @@ trait TypedPipeOperationsConversions { override def apply(pipes: List[TypedPipe[_]]): TypedPipe[TypeOut] = { assertPipeSize(pipes, 3) op( - pipes(0).asInstanceOf[TypedPipe[TypeIn1]], // linter:ignore + pipes(0).asInstanceOf[TypedPipe[TypeIn1]], // linter:disable pipes(1).asInstanceOf[TypedPipe[TypeIn2]], pipes(2).asInstanceOf[TypedPipe[TypeIn3]]) } diff --git a/scalding-core/src/main/scala/com/twitter/scalding/estimation/memory/MemoryEstimatorStepStrategy.scala b/scalding-core/src/main/scala/com/twitter/scalding/estimation/memory/MemoryEstimatorStepStrategy.scala index 6ce1af2c61..d6ab7deec3 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/estimation/memory/MemoryEstimatorStepStrategy.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/estimation/memory/MemoryEstimatorStepStrategy.scala @@ -14,7 +14,8 @@ object MemoryEstimatorStepStrategy extends FlowStepStrategy[JobConf] { private val LOG = LoggerFactory.getLogger(this.getClass) - implicit val estimatorMonoid = new FallbackEstimatorMonoid[MemoryEstimate] + implicit val estimatorMonoid: Monoid[Estimator[MemoryEstimate]] = + new FallbackEstimatorMonoid[MemoryEstimate] /** * Make memory estimate, possibly overriding explicitly-set memory settings, diff --git a/scalding-core/src/main/scala/com/twitter/scalding/macros/impl/CaseClassBasedSetterImpl.scala b/scalding-core/src/main/scala/com/twitter/scalding/macros/impl/CaseClassBasedSetterImpl.scala index 60f9065057..8bbdc29884 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/macros/impl/CaseClassBasedSetterImpl.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/macros/impl/CaseClassBasedSetterImpl.scala @@ -39,7 +39,7 @@ object CaseClassBasedSetterImpl { */ def setTree(value: Tree, offset: Int): Tree } - case class PrimitiveSetter(tpe: Type) extends SetterBuilder { + final case class PrimitiveSetter(tpe: Type) extends SetterBuilder { def columns = 1 def setTree(value: Tree, offset: Int) = fsetter.from(c)(tpe, offset, container, value) match { case Success(tree) => tree @@ -51,7 +51,7 @@ object CaseClassBasedSetterImpl { def columns = 1 def setTree(value: Tree, offset: Int) = fsetter.default(c)(offset, container, value) } - case class OptionSetter(inner: SetterBuilder) extends SetterBuilder { + final case class OptionSetter(inner: SetterBuilder) extends SetterBuilder { def columns = inner.columns def setTree(value: Tree, offset: Int) = { val someVal = newTermName(c.fresh("someVal")) @@ -64,7 +64,7 @@ object CaseClassBasedSetterImpl { }""" } } - case class CaseClassSetter(members: Vector[(Tree => Tree, SetterBuilder)]) extends SetterBuilder { + final case class CaseClassSetter(members: Vector[(Tree => Tree, SetterBuilder)]) extends SetterBuilder { val columns = members.map(_._2.columns).sum def setTree(value: Tree, offset: Int) = { val setters = members.scanLeft((offset, Option.empty[Tree])) { diff --git a/scalding-core/src/main/scala/com/twitter/scalding/macros/impl/FieldsProviderImpl.scala b/scalding-core/src/main/scala/com/twitter/scalding/macros/impl/FieldsProviderImpl.scala index ca037641d7..a5c1af7597 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/macros/impl/FieldsProviderImpl.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/macros/impl/FieldsProviderImpl.scala @@ -93,11 +93,12 @@ object FieldsProviderImpl { case tpe if tpe =:= typeOf[Float] => true case tpe if tpe =:= typeOf[Double] => true case tpe if tpe =:= typeOf[String] => true - case tpe => - optionInner(c)(tpe) match { - case Some(t) => isNumbered(t) - case None => false - } + case tpe => optionInner(c)(tpe) match { // linter:disable:UseOptionExistsNotPatMatch + case Some(t) => + // we need this match style to do tailrec + isNumbered(t) + case None => false + } } object FieldBuilder { @@ -119,16 +120,16 @@ object FieldsProviderImpl { def columnTypes: Vector[Tree] def names: Vector[String] } - case class Primitive(name: String, tpe: Type) extends FieldBuilder { + final case class Primitive(name: String, tpe: Type) extends FieldBuilder { def columnTypes = Vector(q"""_root_.scala.Predef.classOf[$tpe]""") def names = Vector(name) } - case class OptionBuilder(of: FieldBuilder) extends FieldBuilder { + final case class OptionBuilder(of: FieldBuilder) extends FieldBuilder { // Options just use Object as the type, due to the way cascading works on number types def columnTypes = of.columnTypes.map(_ => q"""_root_.scala.Predef.classOf[_root_.java.lang.Object]""") def names = of.names } - case class CaseClassBuilder(prefix: String, members: Vector[FieldBuilder]) extends FieldBuilder { + final case class CaseClassBuilder(prefix: String, members: Vector[FieldBuilder]) extends FieldBuilder { def columnTypes = members.flatMap(_.columnTypes) def names = for { member <- members @@ -163,7 +164,7 @@ object FieldsProviderImpl { .declarations .collect { case m: MethodSymbol if m.isCaseAccessor => m } .map { accessorMethod => - val fieldName = accessorMethod.name.toTermName.toString + val fieldName = accessorMethod.name.toString val fieldType = accessorMethod.returnType.asSeenFrom(outerTpe, outerTpe.typeSymbol.asClass) (fieldType, fieldName) }.toVector diff --git a/scalding-core/src/main/scala/com/twitter/scalding/macros/impl/TupleConverterImpl.scala b/scalding-core/src/main/scala/com/twitter/scalding/macros/impl/TupleConverterImpl.scala index f8ea670d2b..648ada6529 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/macros/impl/TupleConverterImpl.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/macros/impl/TupleConverterImpl.scala @@ -52,11 +52,11 @@ object TupleConverterImpl { def columns: Int def applyTree(offset: Int): Tree } - case class PrimitiveBuilder(primitiveGetter: Int => Tree) extends ConverterBuilder { + final case class PrimitiveBuilder(primitiveGetter: Int => Tree) extends ConverterBuilder { def columns = 1 def applyTree(offset: Int) = primitiveGetter(offset) } - case class OptionBuilder(evidentCol: Int, of: ConverterBuilder) extends ConverterBuilder { + final case class OptionBuilder(evidentCol: Int, of: ConverterBuilder) extends ConverterBuilder { def columns = of.columns def applyTree(offset: Int) = { val testIdx = offset + evidentCol @@ -64,7 +64,7 @@ object TupleConverterImpl { else Some(${of.applyTree(offset)})""" } } - case class CaseClassBuilder(tpe: Type, members: Vector[ConverterBuilder]) extends ConverterBuilder { + final case class CaseClassBuilder(tpe: Type, members: Vector[ConverterBuilder]) extends ConverterBuilder { val columns = members.map(_.columns).sum def applyTree(offset: Int) = { val trees = members.scanLeft((offset, Option.empty[Tree])) { diff --git a/scalding-core/src/main/scala/com/twitter/scalding/mathematics/Matrix2.scala b/scalding-core/src/main/scala/com/twitter/scalding/mathematics/Matrix2.scala index afb3aa1d0c..8195ba22ac 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/mathematics/Matrix2.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/mathematics/Matrix2.scala @@ -242,7 +242,7 @@ class DefaultMatrixJoiner(sizeRatioThreshold: Long) extends MatrixJoiner2 { /** * Infinite column vector - only for intermediate computations */ -case class OneC[R, V](implicit override val rowOrd: Ordering[R]) extends Matrix2[R, Unit, V] { +final case class OneC[R, V](implicit override val rowOrd: Ordering[R]) extends Matrix2[R, Unit, V] { override val sizeHint: SizeHint = FiniteHint(Long.MaxValue, 1) override def colOrd = Ordering[Unit] def transpose = OneR() @@ -253,7 +253,7 @@ case class OneC[R, V](implicit override val rowOrd: Ordering[R]) extends Matrix2 /** * Infinite row vector - only for intermediate computations */ -case class OneR[C, V](implicit override val colOrd: Ordering[C]) extends Matrix2[Unit, C, V] { +final case class OneR[C, V](implicit override val colOrd: Ordering[C]) extends Matrix2[Unit, C, V] { override val sizeHint: SizeHint = FiniteHint(1, Long.MaxValue) override def rowOrd = Ordering[Unit] def transpose = OneC() @@ -269,7 +269,7 @@ case class OneR[C, V](implicit override val colOrd: Ordering[C]) extends Matrix2 * @param ring * @param expressions a HashMap of common subtrees; None if possibly not optimal (did not go through optimize), Some(...) with a HashMap that was created in optimize */ -case class Product[R, C, C2, V](left: Matrix2[R, C, V], +final case class Product[R, C, C2, V](left: Matrix2[R, C, V], right: Matrix2[C, C2, V], ring: Ring[V], expressions: Option[Map[Matrix2[R, C2, V], TypedPipe[(R, C2, V)]]] = None)(implicit val joiner: MatrixJoiner2) extends Matrix2[R, C2, V] { @@ -342,13 +342,10 @@ case class Product[R, C, C2, V](left: Matrix2[R, C, V], override lazy val toTypedPipe: TypedPipe[(R, C2, V)] = { expressions match { - case Some(m) => m.get(this) match { - case Some(pipe) => pipe - case None => { - val result = computePipe() - m.put(this, result) - result - } + case Some(m) => m.get(this).getOrElse { + val result = computePipe() + m.put(this, result) + result } case None => optimizedSelf.toTypedPipe } @@ -393,7 +390,7 @@ case class Product[R, C, C2, V](left: Matrix2[R, C, V], } } -case class Sum[R, C, V](left: Matrix2[R, C, V], right: Matrix2[R, C, V], mon: Monoid[V]) extends Matrix2[R, C, V] { +final case class Sum[R, C, V](left: Matrix2[R, C, V], right: Matrix2[R, C, V], mon: Monoid[V]) extends Matrix2[R, C, V] { def collectAddends(sum: Sum[R, C, V]): List[TypedPipe[(R, C, V)]] = { def getLiteral(mat: Matrix2[R, C, V]): TypedPipe[(R, C, V)] = { mat match { @@ -452,7 +449,7 @@ case class Sum[R, C, V](left: Matrix2[R, C, V], right: Matrix2[R, C, V], mon: Mo }.reduce(_ ++ _).sum) } -case class HadamardProduct[R, C, V](left: Matrix2[R, C, V], +final case class HadamardProduct[R, C, V](left: Matrix2[R, C, V], right: Matrix2[R, C, V], ring: Ring[V]) extends Matrix2[R, C, V] { @@ -485,7 +482,7 @@ case class HadamardProduct[R, C, V](left: Matrix2[R, C, V], implicit def withOrderedSerialization: Ordering[(R, C)] = OrderedSerialization2.maybeOrderedSerialization2(rowOrd, colOrd) } -case class MatrixLiteral[R, C, V](override val toTypedPipe: TypedPipe[(R, C, V)], +final case class MatrixLiteral[R, C, V](override val toTypedPipe: TypedPipe[(R, C, V)], override val sizeHint: SizeHint)(implicit override val rowOrd: Ordering[R], override val colOrd: Ordering[C]) extends Matrix2[R, C, V] { @@ -566,7 +563,7 @@ trait Scalar2[V] extends Serializable { // TODO: FunctionMatrix[R,C,V](fn: (R,C) => V) and a Literal scalar is just: FuctionMatrix[Unit, Unit, V]({ (_, _) => v }) } -case class ValuePipeScalar[V](override val value: ValuePipe[V]) extends Scalar2[V] +final case class ValuePipeScalar[V](override val value: ValuePipe[V]) extends Scalar2[V] object Scalar2 { // implicits cannot share names diff --git a/scalding-core/src/main/scala/com/twitter/scalding/mathematics/SizeHint.scala b/scalding-core/src/main/scala/com/twitter/scalding/mathematics/SizeHint.scala index a164be7650..8fc4e8fcae 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/mathematics/SizeHint.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/mathematics/SizeHint.scala @@ -60,7 +60,7 @@ case object NoClue extends SizeHint { def transpose = NoClue } -case class FiniteHint(rows: BigInt = -1L, cols: BigInt = -1L) extends SizeHint { +final case class FiniteHint(rows: BigInt = -1L, cols: BigInt = -1L) extends SizeHint { def *(other: SizeHint) = { other match { case NoClue => NoClue @@ -93,7 +93,7 @@ case class FiniteHint(rows: BigInt = -1L, cols: BigInt = -1L) extends SizeHint { } // sparsity is the fraction of the rows and columns that are expected to be present -case class SparseHint(sparsity: Double, rows: BigInt, cols: BigInt) extends SizeHint { +final case class SparseHint(sparsity: Double, rows: BigInt, cols: BigInt) extends SizeHint { def *(other: SizeHint): SizeHint = { other match { case NoClue => NoClue diff --git a/scalding-core/src/main/scala/com/twitter/scalding/mathematics/TypedSimilarity.scala b/scalding-core/src/main/scala/com/twitter/scalding/mathematics/TypedSimilarity.scala index eb3ae6c14e..593ae42aaf 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/mathematics/TypedSimilarity.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/mathematics/TypedSimilarity.scala @@ -34,10 +34,10 @@ case class Edge[+N, +E](from: N, to: N, data: E) { } abstract sealed trait Degree { val degree: Int } -case class InDegree(override val degree: Int) extends Degree -case class OutDegree(override val degree: Int) extends Degree -case class Weight(weight: Double) -case class L2Norm(norm: Double) +final case class InDegree(override val degree: Int) extends Degree +final case class OutDegree(override val degree: Int) extends Degree +final case class Weight(weight: Double) +final case class L2Norm(norm: Double) object GraphOperations extends Serializable { /** diff --git a/scalding-core/src/main/scala/com/twitter/scalding/reducer_estimation/ReducerEstimatorStepStrategy.scala b/scalding-core/src/main/scala/com/twitter/scalding/reducer_estimation/ReducerEstimatorStepStrategy.scala index 5135ac5fd6..4ac7139dc2 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/reducer_estimation/ReducerEstimatorStepStrategy.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/reducer_estimation/ReducerEstimatorStepStrategy.scala @@ -13,7 +13,8 @@ object ReducerEstimatorStepStrategy extends FlowStepStrategy[JobConf] { private val LOG = LoggerFactory.getLogger(this.getClass) - implicit val estimatorMonoid = new FallbackEstimatorMonoid[Int] + implicit val estimatorMonoid: Monoid[Estimator[Int]] = + new FallbackEstimatorMonoid[Int] /** * Make reducer estimate, possibly overriding explicitly-set numReducers, @@ -104,4 +105,4 @@ object ReducerEstimatorStepStrategy extends FlowStepStrategy[JobConf] { } } } -} \ No newline at end of file +} diff --git a/scalding-core/src/main/scala/com/twitter/scalding/serialization/CascadingBinaryComparator.scala b/scalding-core/src/main/scala/com/twitter/scalding/serialization/CascadingBinaryComparator.scala index 2a543110f7..d4ea2bb3ee 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/serialization/CascadingBinaryComparator.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/serialization/CascadingBinaryComparator.scala @@ -23,6 +23,7 @@ import com.twitter.scalding.ExecutionContext.getDesc import java.io.InputStream import java.util.Comparator import scala.util.{ Failure, Success, Try } +import org.slf4j.LoggerFactory /** * This is the type that should be fed to cascading to enable binary comparators @@ -40,11 +41,13 @@ class CascadingBinaryComparator[T](ob: OrderedSerialization[T]) extends Comparat object CascadingBinaryComparator { + private val LOG = LoggerFactory.getLogger(this.getClass) + /** * This method will walk the flowDef and make sure all the * groupBy/cogroups are using a CascadingBinaryComparator */ - private[scalding] def checkForOrderedSerialization[T](flow: Flow[T]): Try[Unit] = { + private[scalding] def checkForOrderedSerialization[T](flow: Flow[T], mode: RequireOrderedSerializationMode): Try[Unit] = { import collection.JavaConverters._ import cascading.pipe._ import com.twitter.scalding.RichPipe @@ -53,8 +56,17 @@ object CascadingBinaryComparator { def reduce(it: TraversableOnce[Try[Unit]]): Try[Unit] = it.find(_.isFailure).getOrElse(Success(())) - def failure(s: String): Try[Unit] = - Failure(new RuntimeException("Cannot verify OrderedSerialization: " + s)) + def failure(s: String): Try[Unit] = { + val message = + s"Cannot verify OrderedSerialization: $s. Add `import com.twitter.scalding.serialization.RequiredBinaryComparators._`" + mode match { + case RequireOrderedSerializationMode.Fail => + Failure(new RuntimeException(message)) + case RequireOrderedSerializationMode.Log => + LOG.warn(message) + Try(()) + } + } def check(s: Splice): Try[Unit] = { val m = s.getKeySelectors.asScala diff --git a/scalding-core/src/main/scala/com/twitter/scalding/serialization/RequiredBinaryComparators.scala b/scalding-core/src/main/scala/com/twitter/scalding/serialization/RequiredBinaryComparators.scala index f757c84ec6..b188c982f8 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/serialization/RequiredBinaryComparators.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/serialization/RequiredBinaryComparators.scala @@ -26,9 +26,9 @@ object RequiredBinaryComparators { */ trait RequiredBinaryComparatorsExecutionApp extends ExecutionApp { implicit def ordSer[T]: OrderedSerialization[T] = macro com.twitter.scalding.serialization.macros.impl.OrderedSerializationProviderImpl[T] - + def requireOrderedSerializationMode: RequireOrderedSerializationMode = RequireOrderedSerializationMode.Fail override def config(inputArgs: Array[String]): (Config, Mode) = { val (conf, m) = super.config(inputArgs) - (conf.setRequireOrderedSerialization(true), m) + (conf.setRequireOrderedSerializationMode(Some(requireOrderedSerializationMode)), m) } } diff --git a/scalding-core/src/main/scala/com/twitter/scalding/serialization/RequiredBinaryComparatorsConfig.scala b/scalding-core/src/main/scala/com/twitter/scalding/serialization/RequiredBinaryComparatorsConfig.scala index 46c30a203d..d14872d6cc 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/serialization/RequiredBinaryComparatorsConfig.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/serialization/RequiredBinaryComparatorsConfig.scala @@ -2,6 +2,13 @@ package com.twitter.scalding.serialization import com.twitter.scalding.{ Config, Job } +sealed trait RequireOrderedSerializationMode +object RequireOrderedSerializationMode { + case object Fail extends RequireOrderedSerializationMode + case object Log extends RequireOrderedSerializationMode +} + trait RequiredBinaryComparatorsConfig extends Job { - override def config = super.config + (Config.ScaldingRequireOrderedSerialization -> "true") + def requireOrderedSerializationMode: RequireOrderedSerializationMode = RequireOrderedSerializationMode.Fail + override def config = super.config + (Config.ScaldingRequireOrderedSerialization -> requireOrderedSerializationMode.toString) } diff --git a/scalding-core/src/main/scala/com/twitter/scalding/typed/FlatMappedFn.scala b/scalding-core/src/main/scala/com/twitter/scalding/typed/FlatMappedFn.scala deleted file mode 100644 index 100e71d7e6..0000000000 --- a/scalding-core/src/main/scala/com/twitter/scalding/typed/FlatMappedFn.scala +++ /dev/null @@ -1,106 +0,0 @@ -/* -Copyright 2013 Twitter, Inc. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ -package com.twitter.scalding.typed - -import java.io.Serializable - -import com.twitter.scalding.TupleConverter -import cascading.tuple.TupleEntry - -/** - * This is one of 4 core, non composed operations: - * identity - * filter - * map - * flatMap - */ -sealed trait FlatMapping[-A, +B] extends java.io.Serializable -object FlatMapping { - def filter[A](fn: A => Boolean): FlatMapping[A, A] = - Filter[A, A](fn, implicitly) - - def filterKeys[K, V](fn: K => Boolean): FlatMapping[(K, V), (K, V)] = - filter { kv => fn(kv._1) } - - case class Identity[A, B](ev: A =:= B) extends FlatMapping[A, B] - case class Filter[A, B](fn: A => Boolean, ev: A =:= B) extends FlatMapping[A, B] - case class Map[A, B](fn: A => B) extends FlatMapping[A, B] - case class FlatM[A, B](fn: A => TraversableOnce[B]) extends FlatMapping[A, B] -} - -/** - * This is a composition of one or more FlatMappings - */ -sealed trait FlatMappedFn[-A, +B] extends (A => TraversableOnce[B]) with java.io.Serializable { - import FlatMappedFn._ - - final def runAfter[Z](fn: FlatMapping[Z, A]): FlatMappedFn[Z, B] = this match { - case Single(FlatMapping.Identity(_)) => Single(fn.asInstanceOf[FlatMapping[Z, B]]) // since we have A =:= B, we know this cast is safe - case notId => fn match { - case FlatMapping.Identity(ev) => this.asInstanceOf[FlatMappedFn[Z, B]] // we have Z =:= A we know this cast is safe - case notIdFn => Series(notIdFn, notId) // only make a Series without either side being identity - } - } - - /** - * We interpret this composition once to minimize pattern matching when we execute - */ - private[this] val toFn: A => TraversableOnce[B] = { - import FlatMapping._ - - def loop[A1, B1](fn: FlatMappedFn[A1, B1]): A1 => TraversableOnce[B1] = fn match { - case Single(Identity(ev)) => - { (t: A1) => Iterator.single(t.asInstanceOf[B1]) } // A1 =:= B1 - case Single(Filter(f, ev)) => - { (t: A1) => if (f(t)) Iterator.single(t.asInstanceOf[B1]) else Iterator.empty } // A1 =:= B1 - case Single(Map(f)) => f.andThen(Iterator.single) - case Single(FlatM(f)) => f - case Series(Identity(ev), rest) => loop(rest).asInstanceOf[A1 => TraversableOnce[B1]] // we know that A1 =:= C - case Series(Filter(f, ev), rest) => - val next = loop(rest).asInstanceOf[A1 => TraversableOnce[B1]] // A1 =:= C - - { (t: A1) => if (f(t)) next(t) else Iterator.empty } - case Series(Map(f), rest) => - val next = loop(rest) - f.andThen(next) - case Series(FlatM(f), rest) => - val next = loop(rest) - f.andThen(_.flatMap(next)) - } - - loop(this) - } - - def apply(a: A): TraversableOnce[B] = toFn(a) -} - -object FlatMappedFn { - import FlatMapping._ - - def asId[A, B](f: FlatMappedFn[A, B]): Option[(_ >: A) =:= (_ <: B)] = f match { - case Single(i@Identity(_)) => Some(i.ev) - case _ => None - } - - def asFilter[A, B](f: FlatMappedFn[A, B]): Option[(A => Boolean, (_ >: A) =:= (_ <: B))] = f match { - case Single(filter@Filter(_, _)) => Some((filter.fn, filter.ev)) - case _ => None - } - - def identity[T]: FlatMappedFn[T, T] = Single(FlatMapping.Identity[T, T](implicitly[T =:= T])) - case class Single[A, B](fn: FlatMapping[A, B]) extends FlatMappedFn[A, B] - case class Series[A, B, C](first: FlatMapping[A, B], next: FlatMappedFn[B, C]) extends FlatMappedFn[A, C] -} diff --git a/scalding-core/src/main/scala/com/twitter/scalding/typed/Grouped.scala b/scalding-core/src/main/scala/com/twitter/scalding/typed/Grouped.scala index 83979a91dd..b355522b19 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/typed/Grouped.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/typed/Grouped.scala @@ -17,6 +17,8 @@ package com.twitter.scalding.typed import com.twitter.algebird.Semigroup import com.twitter.algebird.mutable.PriorityQueueMonoid +import com.twitter.scalding.typed.functions.{ Constant, EmptyGuard } +import com.twitter.scalding.typed.functions.ComposedFunctions.ComposedMapGroup import scala.collection.JavaConverters._ object CoGroupable { @@ -24,10 +26,7 @@ object CoGroupable { * This is the default empty join function needed for CoGroupable and HashJoinable */ def castingJoinFunction[V]: (Any, Iterator[Any], Seq[Iterable[Any]]) => Iterator[V] = - { (k, iter, empties) => - assert(empties.isEmpty, "this join function should never be called with non-empty right-most") - iter.asInstanceOf[Iterator[V]] - } + Joiner.CastingWideJoin[V]() } /** @@ -90,7 +89,7 @@ object CoGrouped { go(list) } - case class Pair[K, A, B, C]( + final case class Pair[K, A, B, C]( larger: CoGroupable[K, A], smaller: CoGroupable[K, B], fn: (K, Iterator[A], Iterable[B]) => Iterator[C]) extends CoGrouped[K, C] { @@ -114,7 +113,7 @@ object CoGrouped { val joinedLeft = jf(k, leftMost, leftSeq) // Only do this once, for all calls to iterator below - val smallerHead = rightSeq.head + val smallerHead = rightSeq.head // linter:disable:UndesirableTypeInference val smallerTail = rightSeq.tail // TODO: it might make sense to cache this in memory as an IndexedSeq and not // recompute it on every value for the left if the smallerJf is non-trivial @@ -129,7 +128,7 @@ object CoGrouped { } } - case class WithReducers[K, V](on: CoGrouped[K, V], reds: Int) extends CoGrouped[K, V] { + final case class WithReducers[K, V](on: CoGrouped[K, V], reds: Int) extends CoGrouped[K, V] { def inputs = on.inputs def reducers = Some(reds) def keyOrdering = on.keyOrdering @@ -137,7 +136,7 @@ object CoGrouped { def descriptions: Seq[String] = on.descriptions } - case class WithDescription[K, V]( + final case class WithDescription[K, V]( on: CoGrouped[K, V], description: String) extends CoGrouped[K, V] { @@ -148,7 +147,7 @@ object CoGrouped { def descriptions: Seq[String] = on.descriptions :+ description } - case class FilterKeys[K, V](on: CoGrouped[K, V], fn: K => Boolean) extends CoGrouped[K, V] { + final case class FilterKeys[K, V](on: CoGrouped[K, V], fn: K => Boolean) extends CoGrouped[K, V] { val inputs = on.inputs.map(_.filterKeys(fn)) def reducers = on.reducers def keyOrdering = on.keyOrdering @@ -156,7 +155,7 @@ object CoGrouped { def descriptions: Seq[String] = on.descriptions } - case class MapGroup[K, V1, V2](on: CoGrouped[K, V1], fn: (K, Iterator[V1]) => Iterator[V2]) extends CoGrouped[K, V2] { + final case class MapGroup[K, V1, V2](on: CoGrouped[K, V1], fn: (K, Iterator[V1]) => Iterator[V2]) extends CoGrouped[K, V2] { def inputs = on.inputs def reducers = on.reducers def descriptions: Seq[String] = on.descriptions @@ -263,9 +262,8 @@ object Grouped { def apply[K, V](pipe: TypedPipe[(K, V)])(implicit ordering: Ordering[K]): Grouped[K, V] = IdentityReduce(ordering, pipe, None, Nil) - def addEmptyGuard[K, V1, V2](fn: (K, Iterator[V1]) => Iterator[V2]): (K, Iterator[V1]) => Iterator[V2] = { - (key: K, iter: Iterator[V1]) => if (iter.nonEmpty) fn(key, iter) else Iterator.empty - } + def addEmptyGuard[K, V1, V2](fn: (K, Iterator[V1]) => Iterator[V2]): (K, Iterator[V1]) => Iterator[V2] = + EmptyGuard(fn) } /** @@ -306,7 +304,7 @@ sealed trait ReduceStep[K, V1, V2] extends KeyedPipe[K] { def toTypedPipe: TypedPipe[(K, V2)] = TypedPipe.ReduceStepPipe(this) } -case class IdentityReduce[K, V1]( +final case class IdentityReduce[K, V1]( override val keyOrdering: Ordering[K], override val mapped: TypedPipe[(K, V1)], override val reducers: Option[Int], @@ -364,7 +362,7 @@ case class IdentityReduce[K, V1]( override def joinFunction = CoGroupable.castingJoinFunction[V1] } -case class UnsortedIdentityReduce[K, V1]( +final case class UnsortedIdentityReduce[K, V1]( override val keyOrdering: Ordering[K], override val mapped: TypedPipe[(K, V1)], override val reducers: Option[Int], @@ -380,7 +378,7 @@ case class UnsortedIdentityReduce[K, V1]( override def bufferedTake(n: Int) = if (n < 1) { // This means don't take anything, which is legal, but strange - filterKeys(_ => false) + filterKeys(Constant(false)) } else if (n == 1) { head } else { @@ -431,7 +429,7 @@ case class UnsortedIdentityReduce[K, V1]( override def joinFunction = CoGroupable.castingJoinFunction[V1] } -case class IdentityValueSortedReduce[K, V1]( +final case class IdentityValueSortedReduce[K, V1]( override val keyOrdering: Ordering[K], override val mapped: TypedPipe[(K, V1)], valueSort: Ordering[_ >: V1], @@ -454,10 +452,9 @@ case class IdentityValueSortedReduce[K, V1]( // copy fails to get the types right, :/ IdentityValueSortedReduce[K, V1](keyOrdering, mapped.filterKeys(fn), valueSort, reducers, descriptions) - override def mapGroup[V3](fn: (K, Iterator[V1]) => Iterator[V3]) = { + override def mapGroup[V3](fn: (K, Iterator[V1]) => Iterator[V3]) = // Only pass non-Empty iterators to subsequent functions ValueSortedReduce[K, V1, V3](keyOrdering, mapped, valueSort, Grouped.addEmptyGuard(fn), reducers, descriptions) - } /** * This does the partial heap sort followed by take in memory on the mappers @@ -467,7 +464,7 @@ case class IdentityValueSortedReduce[K, V1]( override def bufferedTake(n: Int): SortedGrouped[K, V1] = if (n <= 0) { // This means don't take anything, which is legal, but strange - filterKeys(_ => false) + filterKeys(Constant(false)) } else { implicit val mon: PriorityQueueMonoid[V1] = new PriorityQueueMonoid[V1](n)(valueSort.asInstanceOf[Ordering[V1]]) // Do the heap-sort on the mappers: @@ -491,7 +488,7 @@ case class IdentityValueSortedReduce[K, V1]( else mapValueStream(_.take(n)) } -case class ValueSortedReduce[K, V1, V2]( +final case class ValueSortedReduce[K, V1, V2]( override val keyOrdering: Ordering[K], override val mapped: TypedPipe[(K, V1)], valueSort: Ordering[_ >: V1], @@ -520,19 +517,13 @@ case class ValueSortedReduce[K, V1, V2]( ValueSortedReduce[K, V1, V2](keyOrdering, mapped.filterKeys(fn), valueSort, reduceFn, reducers, descriptions) override def mapGroup[V3](fn: (K, Iterator[V2]) => Iterator[V3]) = { - // don't make a closure - val localRed = reduceFn - val newReduce = { (k: K, iter: Iterator[V1]) => - val step1 = localRed(k, iter) - // Only pass non-Empty iterators to subsequent functions - Grouped.addEmptyGuard(fn)(k, step1) - } + val newReduce = ComposedMapGroup(reduceFn, fn) ValueSortedReduce[K, V1, V3]( keyOrdering, mapped, valueSort, newReduce, reducers, descriptions) } } -case class IteratorMappedReduce[K, V1, V2]( +final case class IteratorMappedReduce[K, V1, V2]( override val keyOrdering: Ordering[K], override val mapped: TypedPipe[(K, V1)], reduceFn: (K, Iterator[V1]) => Iterator[V2], @@ -557,12 +548,7 @@ case class IteratorMappedReduce[K, V1, V2]( override def mapGroup[V3](fn: (K, Iterator[V2]) => Iterator[V3]) = { // don't make a closure - val localRed = reduceFn - val newReduce = { (k: K, iter: Iterator[V1]) => - val step1 = localRed(k, iter) - // Only pass non-Empty iterators to subsequent functions - Grouped.addEmptyGuard(fn)(k, step1) - } + val newReduce = ComposedMapGroup(reduceFn, fn) copy(reduceFn = newReduce) } diff --git a/scalding-core/src/main/scala/com/twitter/scalding/typed/Joiner.scala b/scalding-core/src/main/scala/com/twitter/scalding/typed/Joiner.scala index f0ee79e8db..823dd16cde 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/typed/Joiner.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/typed/Joiner.scala @@ -18,36 +18,170 @@ package com.twitter.scalding.typed import com.twitter.scalding._ object Joiner extends java.io.Serializable { - def toCogroupJoiner2[K, V, U, R](hashJoiner: (K, V, Iterable[U]) => Iterator[R]): (K, Iterator[V], Iterable[U]) => Iterator[R] = { - (k: K, itv: Iterator[V], itu: Iterable[U]) => - itv.flatMap { hashJoiner(k, _, itu) } + + type JoinFn[K, V, U, R] = (K, Iterator[V], Iterable[U]) => Iterator[R] + type HashJoinFn[K, V, U, R] = (K, V, Iterable[U]) => Iterator[R] + + def toCogroupJoiner2[K, V, U, R](hashJoiner: (K, V, Iterable[U]) => Iterator[R]): JoinFn[K, V, U, R] = + JoinFromHashJoin(hashJoiner) + + def hashInner2[K, V, U]: HashJoinFn[K, V, U, (V, U)] = + HashInner() + + def hashLeft2[K, V, U]: HashJoinFn[K, V, U, (V, Option[U])] = + HashLeft() + + def inner2[K, V, U]: JoinFn[K, V, U, (V, U)] = + InnerJoin() + + def asOuter[U](it: Iterator[U]): Iterator[Option[U]] = + if (it.isEmpty) Iterator.single(None) + else it.map(Some(_)) + + def outer2[K, V, U]: JoinFn[K, V, U, (Option[V], Option[U])] = + OuterJoin() + + def left2[K, V, U]: JoinFn[K, V, U, (V, Option[U])] = + LeftJoin() + + def right2[K, V, U]: JoinFn[K, V, U, (Option[V], U)] = + RightJoin() + + /** + * Optimizers want to match on the kinds of joins we are doing. + * This gives them that ability + */ + sealed abstract class HashJoinFunction[K, V, U, R] extends Function3[K, V, Iterable[U], Iterator[R]] + + final case class HashInner[K, V, U]() extends HashJoinFunction[K, V, U, (V, U)] { + def apply(k: K, v: V, u: Iterable[U]) = u.iterator.map((v, _)) + } + final case class HashLeft[K, V, U]() extends HashJoinFunction[K, V, U, (V, Option[U])] { + def apply(k: K, v: V, u: Iterable[U]) = asOuter(u.iterator).map((v, _)) + } + final case class FilteredHashJoin[K, V1, V2, R](jf: HashJoinFunction[K, V1, V2, R], fn: ((K, R)) => Boolean) extends HashJoinFunction[K, V1, V2, R] { + def apply(k: K, left: V1, right: Iterable[V2]) = + jf.apply(k, left, right).filter { r => fn((k, r)) } + } + final case class MappedHashJoin[K, V1, V2, R, R1](jf: HashJoinFunction[K, V1, V2, R], fn: R => R1) extends HashJoinFunction[K, V1, V2, R1] { + def apply(k: K, left: V1, right: Iterable[V2]) = + jf.apply(k, left, right).map(fn) + } + final case class FlatMappedHashJoin[K, V1, V2, R, R1](jf: HashJoinFunction[K, V1, V2, R], fn: R => TraversableOnce[R1]) extends HashJoinFunction[K, V1, V2, R1] { + def apply(k: K, left: V1, right: Iterable[V2]) = + jf.apply(k, left, right).flatMap(fn) } - def hashInner2[K, V, U] = { (key: K, v: V, itu: Iterable[U]) => itu.iterator.map { (v, _) } } - def hashLeft2[K, V, U] = { (key: K, v: V, itu: Iterable[U]) => asOuter(itu.iterator).map { (v, _) } } + sealed abstract class JoinFunction[K, V1, V2, R] extends Function3[K, Iterator[V1], Iterable[V2], Iterator[R]] - def inner2[K, V, U] = { (key: K, itv: Iterator[V], itu: Iterable[U]) => - itv.flatMap { v => itu.map { u => (v, u) } } + final case class InnerJoin[K, V1, V2]() extends JoinFunction[K, V1, V2, (V1, V2)] { + def apply(k: K, left: Iterator[V1], right: Iterable[V2]): Iterator[(V1, V2)] = + left.flatMap { v1 => right.iterator.map((v1, _)) } } - def asOuter[U](it: Iterator[U]): Iterator[Option[U]] = { - if (it.isEmpty) { - Iterator(None) - } else { - it.map { Some(_) } - } + final case class LeftJoin[K, V1, V2]() extends JoinFunction[K, V1, V2, (V1, Option[V2])] { + def apply(k: K, left: Iterator[V1], right: Iterable[V2]): Iterator[(V1, Option[V2])] = + left.flatMap { v1 => asOuter(right.iterator).map((v1, _)) } + } + final case class RightJoin[K, V1, V2]() extends JoinFunction[K, V1, V2, (Option[V1], V2)] { + def apply(k: K, left: Iterator[V1], right: Iterable[V2]): Iterator[(Option[V1], V2)] = + asOuter(left).flatMap { v1 => right.iterator.map((v1, _)) } + } + final case class OuterJoin[K, V1, V2]() extends JoinFunction[K, V1, V2, (Option[V1], Option[V2])] { + def apply(k: K, left: Iterator[V1], right: Iterable[V2]): Iterator[(Option[V1], Option[V2])] = + if (left.isEmpty && right.isEmpty) Iterator.empty + else asOuter(left).flatMap { v1 => asOuter(right.iterator).map((v1, _)) } + } + final case class FilteredJoin[K, V1, V2, R](jf: JoinFunction[K, V1, V2, R], fn: ((K, R)) => Boolean) extends JoinFunction[K, V1, V2, R] { + def apply(k: K, left: Iterator[V1], right: Iterable[V2]) = + jf.apply(k, left, right).filter { r => fn((k, r)) } } - def outer2[K, V, U] = { (key: K, itv: Iterator[V], itu: Iterable[U]) => - if (itv.isEmpty && itu.isEmpty) { - Iterator.empty - } else { - asOuter(itv).flatMap { v => asOuter(itu.iterator).map { u => (v, u) } } + final case class MappedJoin[K, V1, V2, R, R1](jf: JoinFunction[K, V1, V2, R], fn: R => R1) extends JoinFunction[K, V1, V2, R1] { + def apply(k: K, left: Iterator[V1], right: Iterable[V2]) = + jf.apply(k, left, right).map(fn) + } + final case class FlatMappedJoin[K, V1, V2, R, R1](jf: JoinFunction[K, V1, V2, R], fn: R => TraversableOnce[R1]) extends JoinFunction[K, V1, V2, R1] { + def apply(k: K, left: Iterator[V1], right: Iterable[V2]) = + jf.apply(k, left, right).flatMap(fn) + } + final case class MappedGroupJoin[K, V1, V2, R, R1](jf: JoinFunction[K, V1, V2, R], fn: (K, Iterator[R]) => Iterator[R1]) extends JoinFunction[K, V1, V2, R1] { + def apply(k: K, left: Iterator[V1], right: Iterable[V2]) = { + val iterr = jf.apply(k, left, right) + if (iterr.isEmpty) Iterator.empty // mapGroup operates on non-empty groups + else fn(k, iterr) } } - def left2[K, V, U] = { (key: K, itv: Iterator[V], itu: Iterable[U]) => - itv.flatMap { v => asOuter(itu.iterator).map { u => (v, u) } } + final case class JoinFromHashJoin[K, V1, V2, R](hj: (K, V1, Iterable[V2]) => Iterator[R]) extends JoinFunction[K, V1, V2, R] { + def apply(k: K, itv: Iterator[V1], itu: Iterable[V2]) = + itv.flatMap(hj(k, _, itu)) } - def right2[K, V, U] = { (key: K, itv: Iterator[V], itu: Iterable[U]) => - asOuter(itv).flatMap { v => itu.map { u => (v, u) } } + + /** + * an inner-like join function is empty definitely if either side is empty + */ + final def isInnerJoinLike[K, V1, V2, R](jf: (K, Iterator[V1], Iterable[V2]) => Iterator[R]): Option[Boolean] = + jf match { + case InnerJoin() => Some(true) + case LeftJoin() => Some(false) + case RightJoin() => Some(false) + case OuterJoin() => Some(false) + case JoinFromHashJoin(hj) => isInnerHashJoinLike(hj) + case FilteredJoin(jf, _) => isInnerJoinLike(jf) + case MappedJoin(jf, _) => isInnerJoinLike(jf) + case FlatMappedJoin(jf, _) => isInnerJoinLike(jf) + case MappedGroupJoin(jf, _) => isInnerJoinLike(jf) + case _ => None + } + /** + * a left-like join function is empty definitely if the left side is empty + */ + final def isLeftJoinLike[K, V1, V2, R](jf: (K, Iterator[V1], Iterable[V2]) => Iterator[R]): Option[Boolean] = + jf match { + case InnerJoin() => Some(true) + case JoinFromHashJoin(hj) => isInnerHashJoinLike(hj) + case LeftJoin() => Some(true) + case RightJoin() => Some(false) + case OuterJoin() => Some(false) + case FilteredJoin(jf, _) => isLeftJoinLike(jf) + case MappedJoin(jf, _) => isLeftJoinLike(jf) + case FlatMappedJoin(jf, _) => isLeftJoinLike(jf) + case MappedGroupJoin(jf, _) => isLeftJoinLike(jf) + case _ => None + } + /** + * a right-like join function is empty definitely if the right side is empty + */ + final def isRightJoinLike[K, V1, V2, R](jf: (K, Iterator[V1], Iterable[V2]) => Iterator[R]): Option[Boolean] = + jf match { + case InnerJoin() => Some(true) + case JoinFromHashJoin(hj) => isInnerHashJoinLike(hj) + case LeftJoin() => Some(false) + case RightJoin() => Some(true) + case OuterJoin() => Some(false) + case FilteredJoin(jf, _) => isRightJoinLike(jf) + case MappedJoin(jf, _) => isRightJoinLike(jf) + case FlatMappedJoin(jf, _) => isRightJoinLike(jf) + case MappedGroupJoin(jf, _) => isRightJoinLike(jf) + case _ => None + } + + /** + * a inner-like hash-join function is empty definitely if either side is empty + */ + final def isInnerHashJoinLike[K, V1, V2, R](jf: (K, V1, Iterable[V2]) => Iterator[R]): Option[Boolean] = + jf match { + case HashInner() => Some(true) + case HashLeft() => Some(false) + case FilteredHashJoin(jf, _) => isInnerHashJoinLike(jf) + case MappedHashJoin(jf, _) => isInnerHashJoinLike(jf) + case FlatMappedHashJoin(jf, _) => isInnerHashJoinLike(jf) + case _ => None + } + + final case class CastingWideJoin[A]() extends Function3[Any, Iterator[Any], Seq[Iterable[Any]], Iterator[A]] { + def apply(k: Any, iter: Iterator[Any], empties: Seq[Iterable[Any]]) = { + assert(empties.isEmpty, "this join function should never be called with non-empty right-most") + iter.asInstanceOf[Iterator[A]] + } } } diff --git a/scalding-core/src/main/scala/com/twitter/scalding/typed/KeyedList.scala b/scalding-core/src/main/scala/com/twitter/scalding/typed/KeyedList.scala index 5ce74280d7..2cdb32225c 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/typed/KeyedList.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/typed/KeyedList.scala @@ -23,6 +23,7 @@ import com.twitter.algebird.{ Fold, Semigroup, Ring, Aggregator } import com.twitter.algebird.mutable.PriorityQueueMonoid import com.twitter.scalding._ +import com.twitter.scalding.typed.functions._ object KeyedListLike { /** KeyedListLike items are implicitly convertable to TypedPipe */ @@ -68,7 +69,7 @@ trait KeyedListLike[K, +T, +This[K, +T] <: KeyedListLike[K, T, This]] if (n < 1) { // This means don't take anything, which is legal, but strange - filterKeys(_ => false) + filterKeys(Constant(false)) } else if (n == 1) { head } else { @@ -123,9 +124,9 @@ trait KeyedListLike[K, +T, +This[K, +T] <: KeyedListLike[K, T, This]] * Use Algebird Aggregator to do the reduction */ def aggregate[B, C](agg: Aggregator[T, B, C]): This[K, C] = - mapValues[B](agg.prepare(_)) + mapValues[B](AggPrepare(agg)) .sum[B](agg.semigroup) - .mapValues[C](agg.present(_)) + .mapValues[C](AggPresent(agg)) /** * .filter(fn).toTypedPipe == .toTypedPipe.filter(fn) @@ -134,35 +135,27 @@ trait KeyedListLike[K, +T, +This[K, +T] <: KeyedListLike[K, T, This]] * and out of cascading/hadoop types. */ def filter(fn: ((K, T)) => Boolean): This[K, T] = - mapGroup { (k: K, items: Iterator[T]) => items.filter { t => fn((k, t)) } } + mapGroup(FilterGroup(fn)) /** * flatten the values * Useful after sortedTake, for instance */ def flattenValues[U](implicit ev: T <:< TraversableOnce[U]): This[K, U] = - mapValueStream(_.flatMap { us => us.asInstanceOf[TraversableOnce[U]] }) + flatMapValues(Widen(SubTypes.fromEv(ev))) /** * This is just short hand for mapValueStream(identity), it makes sure the * planner sees that you want to force a shuffle. For expert tuning */ def forceToReducers: This[K, T] = - mapValueStream(identity) + mapValueStream(Identity()) /** * Use this to get the first value encountered. * prefer this to take(1). */ - def head: This[K, T] = sum { - new Semigroup[T] { - override def plus(left: T, right: T) = left - // Don't enumerate every item, just take the first - override def sumOption(to: TraversableOnce[T]): Option[T] = - if (to.isEmpty) None - else Some(to.toIterator.next) - } - } + def head: This[K, T] = sum(HeadSemigroup[T]()) /** * This is a special case of mapValueStream, but can be optimized because it doesn't need @@ -171,21 +164,21 @@ trait KeyedListLike[K, +T, +This[K, +T] <: KeyedListLike[K, T, This]] * but for Grouped we can avoid resorting to mapValueStream */ def mapValues[V](fn: T => V): This[K, V] = - mapGroup { (_, iter) => iter.map(fn) } + mapGroup(MapGroupMapValues(fn)) /** * Similar to mapValues, but works like flatMap, returning a collection of outputs * for each value input. */ def flatMapValues[V](fn: T => TraversableOnce[V]): This[K, V] = - mapGroup { (_, iter) => iter.flatMap(fn) } + mapGroup(MapGroupFlatMapValues(fn)) /** * Use this when you don't care about the key for the group, * otherwise use mapGroup */ def mapValueStream[V](smfn: Iterator[T] => Iterator[V]): This[K, V] = - mapGroup { (k: K, items: Iterator[T]) => smfn(items) } + mapGroup(MapValueStream(smfn)) /** * Add all items according to the implicit Semigroup @@ -203,7 +196,8 @@ trait KeyedListLike[K, +T, +This[K, +T] <: KeyedListLike[K, T, This]] * Like the above this can be optimized in some Grouped cases. * If you don't have a commutative operator, use reduceLeft */ - def reduce[U >: T](fn: (U, U) => U): This[K, U] = sum(Semigroup.from(fn)) + def reduce[U >: T](fn: (U, U) => U): This[K, U] = + sum(SemigroupFromFn(fn)) /** * Take the largest k things according to the implicit ordering. @@ -233,41 +227,42 @@ trait KeyedListLike[K, +T, +This[K, +T] <: KeyedListLike[K, T, This]] sortedTake(k)(Ordering.fromLessThan(lessThan)) /** For each key, Return the product of all the values */ - def product[U >: T](implicit ring: Ring[U]): This[K, U] = reduce(ring.times) + def product[U >: T](implicit ring: Ring[U]): This[K, U] = + sum(SemigroupFromProduct(ring)) /** For each key, count the number of values that satisfy a predicate */ def count(fn: T => Boolean): This[K, Long] = - mapValues { t => if (fn(t)) 1L else 0L }.sum + mapValues(Count(fn)).sum /** For each key, check to see if a predicate is true for all Values*/ def forall(fn: T => Boolean): This[K, Boolean] = - mapValues { fn(_) }.product + mapValues(fn).product /** * For each key, selects all elements except first n ones. */ def drop(n: Int): This[K, T] = - mapValueStream { _.drop(n) } + mapValueStream(Drop(n)) /** * For each key, Drops longest prefix of elements that satisfy the given predicate. */ - def dropWhile(p: (T) => Boolean): This[K, T] = - mapValueStream { _.dropWhile(p) } + def dropWhile(p: T => Boolean): This[K, T] = + mapValueStream(DropWhile(p)) /** * For each key, Selects first n elements. Don't use this if n == 1, head is faster in that case. */ def take(n: Int): This[K, T] = - if (n < 1) filterKeys(_ => false) // just don't keep anything + if (n < 1) filterKeys(Constant(false)) // just don't keep anything else if (n == 1) head - else mapValueStream { _.take(n) } + else mapValueStream(Take(n)) /** * For each key, Takes longest prefix of elements that satisfy the given predicate. */ - def takeWhile(p: (T) => Boolean): This[K, T] = - mapValueStream { _.takeWhile(p) } + def takeWhile(p: T => Boolean): This[K, T] = + mapValueStream(TakeWhile(p)) /** * Folds are composable aggregations that make one pass over the data. @@ -275,22 +270,22 @@ trait KeyedListLike[K, +T, +This[K, +T] <: KeyedListLike[K, T, This]] * and this method */ def fold[V](f: Fold[T, V]): This[K, V] = - mapValueStream(it => Iterator(f.overTraversable(it))) + mapValueStream(FoldIterator(f)) /** * If the fold depends on the key, use this method to construct * the fold for each key */ def foldWithKey[V](fn: K => Fold[T, V]): This[K, V] = - mapGroup { (k, vs) => Iterator(fn(k).overTraversable(vs)) } + mapGroup(FoldWithKeyIterator(fn)) /** For each key, fold the values. see scala.collection.Iterable.foldLeft */ def foldLeft[B](z: B)(fn: (B, T) => B): This[K, B] = - mapValueStream { stream => Iterator(stream.foldLeft(z)(fn)) } + mapValueStream(FoldLeftIterator(z, fn)) /** For each key, scanLeft the values. see scala.collection.Iterable.scanLeft */ def scanLeft[B](z: B)(fn: (B, T) => B): This[K, B] = - mapValueStream { _.scanLeft(z)(fn) } + mapValueStream(ScanLeftIterator(z, fn)) /** * Similar to reduce but always on the reduce-side (never optimized to mapside), @@ -299,23 +294,24 @@ trait KeyedListLike[K, +T, +This[K, +T] <: KeyedListLike[K, T, This]] * the old value comes in on the left. */ def reduceLeft[U >: T](fn: (U, U) => U): This[K, U] = - sumLeft[U](Semigroup.from(fn)) + sumLeft[U](SemigroupFromFn(fn)) /** * Semigroups MAY have a faster implementation of sum for iterators, * so prefer using sum/sumLeft to reduce/reduceLeft */ def sumLeft[U >: T](implicit sg: Semigroup[U]): This[K, U] = - mapValueStream[U](Semigroup.sumOption[U](_).iterator) + mapValueStream[U](SumAll(sg)) /** For each key, give the number of values */ - def size: This[K, Long] = mapValues { x => 1L }.sum + def size: This[K, Long] = mapValues(Constant(1L)).sum /** * For each key, give the number of unique values. WARNING: May OOM. * This assumes the values for each key can fit in memory. */ - def distinctSize: This[K, Long] = toSet[T].mapValues(_.size) + def distinctSize: This[K, Long] = + toSet[T].mapValues(SizeOfSet()) /** * For each key, remove duplicate values. WARNING: May OOM. @@ -330,7 +326,7 @@ trait KeyedListLike[K, +T, +This[K, +T] <: KeyedListLike[K, T, This]] * You really should try to ask why you need all the values, and if you * want to do some custom reduction, do it in mapGroup or mapValueStream */ - def toList: This[K, List[T]] = mapValues { List(_) }.sum + def toList: This[K, List[T]] = mapValues(ToList[T]()).sum /** * AVOID THIS IF POSSIBLE * Same risks apply here as to toList: you may OOM. See toList. @@ -339,23 +335,23 @@ trait KeyedListLike[K, +T, +This[K, +T] <: KeyedListLike[K, T, This]] * but Set is invariant. See: * http://stackoverflow.com/questions/676615/why-is-scalas-immutable-set-not-covariant-in-its-type */ - def toSet[U >: T]: This[K, Set[U]] = mapValues { Set[U](_) }.sum + def toSet[U >: T]: This[K, Set[U]] = mapValues(ToSet[U]()).sum /** For each key, give the maximum value*/ def max[B >: T](implicit cmp: Ordering[B]): This[K, T] = - reduce(cmp.max).asInstanceOf[This[K, T]] + reduce(MaxOrd[T, B](cmp)) /** For each key, give the maximum value by some function*/ def maxBy[B](fn: T => B)(implicit cmp: Ordering[B]): This[K, T] = - reduce(Ordering.by(fn).max) + reduce(MaxOrdBy(fn, cmp)) /** For each key, give the minimum value*/ def min[B >: T](implicit cmp: Ordering[B]): This[K, T] = - reduce(cmp.min).asInstanceOf[This[K, T]] + reduce(MinOrd[T, B](cmp)) /** For each key, give the minimum value by some function*/ def minBy[B](fn: T => B)(implicit cmp: Ordering[B]): This[K, T] = - reduce(Ordering.by(fn).min) + reduce(MinOrdBy(fn, cmp)) /** Convert to a TypedPipe and only keep the keys */ def keys: TypedPipe[K] = toTypedPipe.keys diff --git a/scalding-core/src/main/scala/com/twitter/scalding/typed/NoStackAndThen.scala b/scalding-core/src/main/scala/com/twitter/scalding/typed/NoStackAndThen.scala index 3b96d669ee..b935e1a91d 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/typed/NoStackAndThen.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/typed/NoStackAndThen.scala @@ -53,10 +53,10 @@ object NoStackAndThen { def apply[A, B](fn: A => B): NoStackAndThen[A, B] = WithStackTrace(NoStackWrap(fn), buildStackEntry) private sealed trait ReversedStack[-A, +B] - private case class EmptyStack[-A, +B](fn: A => B) extends ReversedStack[A, B] - private case class NonEmpty[-A, B, +C](head: A => B, rest: ReversedStack[B, C]) extends ReversedStack[A, C] + private final case class EmptyStack[-A, +B](fn: A => B) extends ReversedStack[A, B] + private final case class NonEmpty[-A, B, +C](head: A => B, rest: ReversedStack[B, C]) extends ReversedStack[A, C] - private[scalding] case class WithStackTrace[A, B](inner: NoStackAndThen[A, B], stackEntry: Array[StackTraceElement]) extends NoStackAndThen[A, B] { + private[scalding] final case class WithStackTrace[A, B](inner: NoStackAndThen[A, B], stackEntry: Array[StackTraceElement]) extends NoStackAndThen[A, B] { override def apply(a: A): B = inner(a) override def andThen[C](fn: B => C): NoStackAndThen[A, C] = @@ -67,11 +67,11 @@ object NoStackAndThen { } // Just wraps a function - private case class NoStackWrap[A, B](fn: A => B) extends NoStackAndThen[A, B] { + private final case class NoStackWrap[A, B](fn: A => B) extends NoStackAndThen[A, B] { def apply(a: A) = fn(a) } // This is the defunctionalized andThen - private case class NoStackMore[A, B, C](first: NoStackAndThen[A, B], andThenFn: (B) => C) extends NoStackAndThen[A, C] { + private final case class NoStackMore[A, B, C](first: NoStackAndThen[A, B], andThenFn: (B) => C) extends NoStackAndThen[A, C] { /* * scala cannot optimize tail calls if the types change. * Any call that changes types, we replace that type with Any. These casts diff --git a/scalding-core/src/main/scala/com/twitter/scalding/typed/OptimizationPhases.scala b/scalding-core/src/main/scala/com/twitter/scalding/typed/OptimizationPhases.scala new file mode 100644 index 0000000000..6518fceb32 --- /dev/null +++ b/scalding-core/src/main/scala/com/twitter/scalding/typed/OptimizationPhases.scala @@ -0,0 +1,15 @@ +package com.twitter.scalding.typed + +import com.stripe.dagon.Rule + +/** + * This is a class to allow customization + * of how we plan typed pipes + */ +abstract class OptimizationPhases { + def phases: Seq[Rule[TypedPipe]] +} + +final class EmptyOptimizationPhases extends OptimizationPhases { + def phases = Nil +} diff --git a/scalding-core/src/main/scala/com/twitter/scalding/typed/OptimizationRules.scala b/scalding-core/src/main/scala/com/twitter/scalding/typed/OptimizationRules.scala new file mode 100644 index 0000000000..ac6974cef8 --- /dev/null +++ b/scalding-core/src/main/scala/com/twitter/scalding/typed/OptimizationRules.scala @@ -0,0 +1,1000 @@ +package com.twitter.scalding.typed + +import com.stripe.dagon.{ FunctionK, Memoize, Rule, PartialRule, Dag, Literal } +import com.twitter.scalding.typed.functions.{ FlatMapping, FlatMappedFn, FilterKeysToFilter } +import com.twitter.scalding.typed.functions.ComposedFunctions.{ ComposedMapFn, ComposedFilterFn, ComposedOnComplete } + +object OptimizationRules { + type LiteralPipe[T] = Literal[TypedPipe, T] + + import Literal.{ Unary, Binary } + import TypedPipe._ + + /** + * Since our TypedPipe is covariant, but the Literal is not + * this is actually safe in this context, but not in general + */ + def widen[T](l: LiteralPipe[_ <: T]): LiteralPipe[T] = { + // to prove this is safe, see that if you have + // LiteralPipe[_ <: T] we can call .evaluate to get + // TypedPipe[_ <: T] which due to covariance is + // TypedPipe[T], and then using toLiteral we can get + // LiteralPipe[T] + // + // that would be wasteful to apply since the final + // result is identity. + l.asInstanceOf[LiteralPipe[T]] + } + + /** + * Convert a TypedPipe[T] to a Literal[TypedPipe, T] for + * use with Dagon + */ + def toLiteral: FunctionK[TypedPipe, LiteralPipe] = + Memoize.functionK[TypedPipe, LiteralPipe]( + new Memoize.RecursiveK[TypedPipe, LiteralPipe] { + + def toFunction[A] = { + case (cp: CounterPipe[a], f) => + Unary(f(cp.pipe), CounterPipe(_: TypedPipe[(a, Iterable[((String, String), Long)])])) + case (c: CrossPipe[a, b], f) => + Binary(f(c.left), f(c.right), CrossPipe(_: TypedPipe[a], _: TypedPipe[b])) + case (cv@CrossValue(_, _), f) => + def go[A, B](cv: CrossValue[A, B]): LiteralPipe[(A, B)] = + cv match { + case CrossValue(a, ComputedValue(v)) => + Binary(f(a), f(v), { (a: TypedPipe[A], b: TypedPipe[B]) => + CrossValue(a, ComputedValue(b)) + }) + case CrossValue(a, v) => + Unary(f(a), CrossValue(_: TypedPipe[A], v)) + } + widen(go(cv)) + case (p: DebugPipe[a], f) => + Unary(f(p.input), DebugPipe(_: TypedPipe[a])) + case (p: FilterKeys[a, b], f) => + widen(Unary(f(p.input), FilterKeys(_: TypedPipe[(a, b)], p.fn))) + case (p: Filter[a], f) => + Unary(f(p.input), Filter(_: TypedPipe[a], p.fn)) + case (p: Fork[a], f) => + Unary(f(p.input), Fork(_: TypedPipe[a])) + case (p: FlatMapValues[a, b, c], f) => + widen(Unary(f(p.input), FlatMapValues(_: TypedPipe[(a, b)], p.fn))) + case (p: FlatMapped[a, b], f) => + Unary(f(p.input), FlatMapped(_: TypedPipe[a], p.fn)) + case (p: ForceToDisk[a], f) => + Unary(f(p.input), ForceToDisk(_: TypedPipe[a])) + case (it@IterablePipe(_), _) => + Literal.Const(it) + case (p: MapValues[a, b, c], f) => + widen(Unary(f(p.input), MapValues(_: TypedPipe[(a, b)], p.fn))) + case (p: Mapped[a, b], f) => + Unary(f(p.input), Mapped(_: TypedPipe[a], p.fn)) + case (p: MergedTypedPipe[a], f) => + Binary(f(p.left), f(p.right), MergedTypedPipe(_: TypedPipe[a], _: TypedPipe[a])) + case (src@SourcePipe(_), _) => + Literal.Const(src) + case (p: SumByLocalKeys[a, b], f) => + widen(Unary(f(p.input), SumByLocalKeys(_: TypedPipe[(a, b)], p.semigroup))) + case (p: TrappedPipe[a], f) => + Unary(f(p.input), TrappedPipe[a](_: TypedPipe[a], p.sink, p.conv)) + case (p: WithDescriptionTypedPipe[a], f) => + Unary(f(p.input), WithDescriptionTypedPipe(_: TypedPipe[a], p.description, p.deduplicate)) + case (p: WithOnComplete[a], f) => + Unary(f(p.input), WithOnComplete(_: TypedPipe[a], p.fn)) + case (EmptyTypedPipe, _) => + Literal.Const(EmptyTypedPipe) + case (hg: HashCoGroup[a, b, c, d], f) => + widen(handleHashCoGroup(hg, f)) + case (CoGroupedPipe(cg), f) => + widen(handleCoGrouped(cg, f)) + case (ReduceStepPipe(rs), f) => + widen(handleReduceStep(rs, f)) + } + }) + + private def handleReduceStep[K, V1, V2](rs: ReduceStep[K, V1, V2], recurse: FunctionK[TypedPipe, LiteralPipe]): LiteralPipe[(K, V2)] = + rs match { + case step@IdentityReduce(_, _, _, _) => + Unary(widen[(K, V2)](recurse(step.mapped)), { (tp: TypedPipe[(K, V2)]) => ReduceStepPipe(step.copy(mapped = tp)) }) + case step@UnsortedIdentityReduce(_, _, _, _) => + Unary(widen[(K, V2)](recurse(step.mapped)), { (tp: TypedPipe[(K, V2)]) => ReduceStepPipe(step.copy(mapped = tp)) }) + case step@IdentityValueSortedReduce(_, _, _, _, _) => + def go[A, B](ivsr: IdentityValueSortedReduce[A, B]): LiteralPipe[(A, B)] = + Unary(widen[(A, B)](recurse(ivsr.mapped)), { (tp: TypedPipe[(A, B)]) => + ReduceStepPipe[A, B, B](IdentityValueSortedReduce[A, B]( + ivsr.keyOrdering, + tp, + ivsr.valueSort, + ivsr.reducers, + ivsr.descriptions)) + }) + widen[(K, V2)](go(step)) + case step@ValueSortedReduce(_, _, _, _, _, _) => + def go[A, B, C](vsr: ValueSortedReduce[A, B, C]): LiteralPipe[(A, C)] = + Unary(recurse(vsr.mapped), { (tp: TypedPipe[(A, B)]) => + ReduceStepPipe[A, B, C](ValueSortedReduce[A, B, C]( + vsr.keyOrdering, + tp, + vsr.valueSort, + vsr.reduceFn, + vsr.reducers, + vsr.descriptions)) + }) + go(step) + case step@IteratorMappedReduce(_, _, _, _, _) => + def go[A, B, C](imr: IteratorMappedReduce[A, B, C]): LiteralPipe[(A, C)] = + Unary(recurse(imr.mapped), { (tp: TypedPipe[(A, B)]) => ReduceStepPipe[A, B, C](imr.copy(mapped = tp)) }) + + go(step) + } + + private def handleCoGrouped[K, V](cg: CoGroupable[K, V], recurse: FunctionK[TypedPipe, LiteralPipe]): LiteralPipe[(K, V)] = { + import CoGrouped._ + + def pipeToCG[V1](t: TypedPipe[(K, V1)]): CoGroupable[K, V1] = + t match { + case ReduceStepPipe(cg: CoGroupable[K @unchecked, V1 @unchecked]) => + // we are relying on the fact that we use Ordering[K] + // as a contravariant type, despite it not being defined + // that way. + cg + case CoGroupedPipe(cg) => + // we are relying on the fact that we use Ordering[K] + // as a contravariant type, despite it not being defined + // that way. + cg.asInstanceOf[CoGroupable[K, V1]] + case kvPipe => IdentityReduce(cg.keyOrdering, kvPipe, None, Nil) + } + + cg match { + case p@Pair(_, _, _) => + def go[A, B, C](pair: Pair[K, A, B, C]): LiteralPipe[(K, C)] = { + val llit = handleCoGrouped(pair.larger, recurse) + val rlit = handleCoGrouped(pair.smaller, recurse) + val fn = pair.fn + Binary(llit, rlit, { (l: TypedPipe[(K, A)], r: TypedPipe[(K, B)]) => + Pair(pipeToCG(l), pipeToCG(r), fn) + }) + } + widen(go(p)) + case wr@WithReducers(_, _) => + def go[V1 <: V](wr: WithReducers[K, V1]): LiteralPipe[(K, V)] = { + val reds = wr.reds + Unary[TypedPipe, (K, V1), (K, V)](handleCoGrouped(wr.on, recurse), { (tp: TypedPipe[(K, V1)]) => + tp match { + case ReduceStepPipe(rs) => + withReducers(rs, reds) + case CoGroupedPipe(cg) => + CoGroupedPipe(WithReducers(cg, reds)) + case kvPipe => + ReduceStepPipe(IdentityReduce(cg.keyOrdering, kvPipe, None, Nil) + .withReducers(reds)) + } + }) + } + go(wr) + case wd@WithDescription(_, _) => + def go[V1 <: V](wd: WithDescription[K, V1]): LiteralPipe[(K, V)] = { + val desc = wd.description + Unary[TypedPipe, (K, V1), (K, V)](handleCoGrouped(wd.on, recurse), { (tp: TypedPipe[(K, V1)]) => + tp match { + case ReduceStepPipe(rs) => + withDescription(rs, desc) + case CoGroupedPipe(cg) => + CoGroupedPipe(WithDescription(cg, desc)) + case kvPipe => + kvPipe.withDescription(desc) + } + }) + } + go(wd) + case fk@FilterKeys(_, _) => + def go[V1 <: V](fk: FilterKeys[K, V1]): LiteralPipe[(K, V)] = { + val fn = fk.fn + Unary[TypedPipe, (K, V1), (K, V)](handleCoGrouped(fk.on, recurse), { (tp: TypedPipe[(K, V1)]) => + tp match { + case ReduceStepPipe(rs) => + filterKeys(rs, fn) + case CoGroupedPipe(cg) => + CoGroupedPipe(FilterKeys(cg, fn)) + case kvPipe => + kvPipe.filterKeys(fn) + } + }) + } + go(fk) + case mg@MapGroup(_, _) => + def go[V1, V2 <: V](mg: MapGroup[K, V1, V2]): LiteralPipe[(K, V)] = { + val fn = mg.fn + Unary[TypedPipe, (K, V1), (K, V)](handleCoGrouped(mg.on, recurse), { (tp: TypedPipe[(K, V1)]) => + tp match { + case ReduceStepPipe(rs) => + mapGroup(rs, fn) + case CoGroupedPipe(cg) => + CoGroupedPipe(MapGroup(cg, fn)) + case kvPipe => + ReduceStepPipe( + IdentityReduce(cg.keyOrdering, kvPipe, None, Nil) + .mapGroup(fn)) + } + }) + } + go(mg) + case step@IdentityReduce(_, _, _, _) => + widen(handleReduceStep(step, recurse)) + case step@UnsortedIdentityReduce(_, _, _, _) => + widen(handleReduceStep(step, recurse)) + case step@IteratorMappedReduce(_, _, _, _, _) => + widen(handleReduceStep(step, recurse)) + } + } + + /** + * This can't really usefully be on ReduceStep since users never want to use it + * as an ADT, as the planner does. + */ + private def withReducers[K, V1, V2](rs: ReduceStep[K, V1, V2], reds: Int): TypedPipe[(K, V2)] = + rs match { + case step@IdentityReduce(_, _, _, _) => + ReduceStepPipe(step.withReducers(reds)) + case step@UnsortedIdentityReduce(_, _, _, _) => + ReduceStepPipe(step.withReducers(reds)) + case step@IdentityValueSortedReduce(_, _, _, _, _) => + ReduceStepPipe(step.withReducers(reds)) + case step@ValueSortedReduce(_, _, _, _, _, _) => + ReduceStepPipe(step.withReducers(reds)) + case step@IteratorMappedReduce(_, _, _, _, _) => + ReduceStepPipe(step.withReducers(reds)) + } + + private def withDescription[K, V1, V2](rs: ReduceStep[K, V1, V2], descr: String): TypedPipe[(K, V2)] = + rs match { + case step@IdentityReduce(_, _, _, _) => + ReduceStepPipe(step.withDescription(descr)) + case step@UnsortedIdentityReduce(_, _, _, _) => + ReduceStepPipe(step.withDescription(descr)) + case step@IdentityValueSortedReduce(_, _, _, _, _) => + ReduceStepPipe(step.withDescription(descr)) + case step@ValueSortedReduce(_, _, _, _, _, _) => + ReduceStepPipe(step.withDescription(descr)) + case step@IteratorMappedReduce(_, _, _, _, _) => + ReduceStepPipe(step.withDescription(descr)) + } + + private def filterKeys[K, V1, V2](rs: ReduceStep[K, V1, V2], fn: K => Boolean): TypedPipe[(K, V2)] = + rs match { + case IdentityReduce(ord, p, r, d) => + ReduceStepPipe(IdentityReduce(ord, FilterKeys(p, fn), r, d)) + case UnsortedIdentityReduce(ord, p, r, d) => + ReduceStepPipe(UnsortedIdentityReduce(ord, FilterKeys(p, fn), r, d)) + case ivsr@IdentityValueSortedReduce(_, _, _, _, _) => + def go[V](ivsr: IdentityValueSortedReduce[K, V]): TypedPipe[(K, V)] = { + val IdentityValueSortedReduce(ord, p, v, r, d) = ivsr + ReduceStepPipe(IdentityValueSortedReduce[K, V](ord, FilterKeys(p, fn), v, r, d)) + } + go(ivsr) + case vsr@ValueSortedReduce(_, _, _, _, _, _) => + def go(vsr: ValueSortedReduce[K, V1, V2]): TypedPipe[(K, V2)] = { + val ValueSortedReduce(ord, p, v, redfn, r, d) = vsr + ReduceStepPipe(ValueSortedReduce[K, V1, V2](ord, FilterKeys(p, fn), v, redfn, r, d)) + } + go(vsr) + case imr@IteratorMappedReduce(_, _, _, _, _) => + def go(imr: IteratorMappedReduce[K, V1, V2]): TypedPipe[(K, V2)] = { + val IteratorMappedReduce(ord, p, redfn, r, d) = imr + ReduceStepPipe(IteratorMappedReduce[K, V1, V2](ord, FilterKeys(p, fn), redfn, r, d)) + } + go(imr) + } + + private def mapGroup[K, V1, V2, V3](rs: ReduceStep[K, V1, V2], fn: (K, Iterator[V2]) => Iterator[V3]): TypedPipe[(K, V3)] = + rs match { + case step@IdentityReduce(_, _, _, _) => + ReduceStepPipe(step.mapGroup(fn)) + case step@UnsortedIdentityReduce(_, _, _, _) => + ReduceStepPipe(step.mapGroup(fn)) + case step@IdentityValueSortedReduce(_, _, _, _, _) => + ReduceStepPipe(step.mapGroup(fn)) + case step@ValueSortedReduce(_, _, _, _, _, _) => + ReduceStepPipe(step.mapGroup(fn)) + case step@IteratorMappedReduce(_, _, _, _, _) => + ReduceStepPipe(step.mapGroup(fn)) + } + + private def handleHashCoGroup[K, V, V2, R](hj: HashCoGroup[K, V, V2, R], recurse: FunctionK[TypedPipe, LiteralPipe]): LiteralPipe[(K, R)] = { + val rightLit: LiteralPipe[(K, V2)] = hj.right match { + case step@IdentityReduce(_, _, _, _) => + Unary(widen[(K, V2)](recurse(step.mapped)), { (tp: TypedPipe[(K, V2)]) => ReduceStepPipe(step.copy(mapped = tp)) }) + case step@UnsortedIdentityReduce(_, _, _, _) => + Unary(widen[(K, V2)](recurse(step.mapped)), { (tp: TypedPipe[(K, V2)]) => ReduceStepPipe(step.copy(mapped = tp)) }) + case step@IteratorMappedReduce(_, _, _, _, _) => + def go[A, B, C](imr: IteratorMappedReduce[A, B, C]): LiteralPipe[(A, C)] = + Unary(recurse(imr.mapped), { (tp: TypedPipe[(A, B)]) => ReduceStepPipe[A, B, C](imr.copy(mapped = tp)) }) + + widen(go(step)) + } + + val ordK: Ordering[K] = hj.right match { + case step@IdentityReduce(_, _, _, _) => step.keyOrdering + case step@UnsortedIdentityReduce(_, _, _, _) => step.keyOrdering + case step@IteratorMappedReduce(_, _, _, _, _) => step.keyOrdering + } + + val joiner = hj.joiner + + Binary(recurse(hj.left), rightLit, + { (ltp: TypedPipe[(K, V)], rtp: TypedPipe[(K, V2)]) => + rtp match { + case ReduceStepPipe(hg: HashJoinable[K @unchecked, V2 @unchecked]) => + HashCoGroup(ltp, hg, joiner) + case otherwise => + HashCoGroup(ltp, IdentityReduce(ordK, otherwise, None, Nil), joiner) + } + }) + } + + /** + * Unroll a set of merges up to the first non-merge node, dropping + * an EmptyTypedPipe from the list + */ + def unrollMerge[A](t: TypedPipe[A]): List[TypedPipe[A]] = { + @annotation.tailrec + def loop(first: TypedPipe[A], todo: List[TypedPipe[A]], acc: List[TypedPipe[A]]): List[TypedPipe[A]] = + first match { + case MergedTypedPipe(l, r) => loop(l, r :: todo, acc) + case EmptyTypedPipe => + todo match { + case Nil => acc.reverse + case h :: tail => loop(h, tail, acc) + } + case notMerge => + val acc1 = notMerge :: acc + todo match { + case Nil => acc1.reverse + case h :: tail => loop(h, tail, acc1) + } + } + + loop(t, Nil, Nil) + } + + ///////////////////////////// + // + // Here are some actual rules for simplifying TypedPipes + // + ///////////////////////////// + + /** + * It is easier for planning if all fanouts are made explicit. + * This rule adds a Fork node every time there is a fanout + * + * This rule applied first makes it easier to match in subsequent + * rules without constantly checking for fanout nodes. + * + * This can increase the number of map-reduce steps compared + * to simply recomputing on both sides of a fork + */ + object AddExplicitForks extends Rule[TypedPipe] { + + def maybeFork[A](on: Dag[TypedPipe], t: TypedPipe[A]): Option[TypedPipe[A]] = + t match { + case ForceToDisk(_) => None + case Fork(t) if on.contains(ForceToDisk(t)) => Some(ForceToDisk(t)) + case Fork(_) => None + case EmptyTypedPipe | IterablePipe(_) | SourcePipe(_) => None + case other if !on.hasSingleDependent(other) => + Some { + // if we are already forcing this, use it + if (on.contains(ForceToDisk(other))) ForceToDisk(other) + else Fork(other) + } + case _ => None + } + + def needsFork[A](on: Dag[TypedPipe], t: TypedPipe[A]): Boolean = + maybeFork(on, t).isDefined + + private def forkCoGroup[K, V](on: Dag[TypedPipe], cg: CoGrouped[K, V]): Option[CoGrouped[K, V]] = { + import CoGrouped._ + + cg match { + case Pair(left: HashJoinable[K, v], right, jf) if forkHashJoinable(on, left).isDefined => + forkHashJoinable(on, left).map { + Pair(_, right, jf) + } + case Pair(left: CoGrouped[K, v], right, jf) if forkCoGroup(on, left).isDefined => + forkCoGroup(on, left).map { + Pair(_, right, jf) + } + case Pair(left, right: HashJoinable[K, v], jf) if forkHashJoinable(on, right).isDefined => + forkHashJoinable(on, right).map { + Pair(left, _, jf) + } + case Pair(left, right: CoGrouped[K, v], jf) if forkCoGroup(on, right).isDefined => + forkCoGroup(on, right).map { + Pair(left, _, jf) + } + case Pair(_, _, _) => None // neither side needs a fork + case WithDescription(cg, d) => forkCoGroup(on, cg).map(WithDescription(_, d)) + case WithReducers(cg, r) => forkCoGroup(on, cg).map(WithReducers(_, r)) + case MapGroup(cg, fn) => forkCoGroup(on, cg).map(MapGroup(_, fn)) + case FilterKeys(cg, fn) => forkCoGroup(on, cg).map(FilterKeys(_, fn)) + } + } + + /** + * The casts in here are safe, but scala loses track of the types in these kinds of + * pattern matches. + * We can fix it by changing the types on the identity reduces to use EqTypes[V1, V2] + * in case class and leaving the V2 parameter. + */ + private def forkReduceStep[A, B, C](on: Dag[TypedPipe], rs: ReduceStep[A, B, C]): Option[ReduceStep[A, B, C]] = rs match { + case step@IdentityReduce(_, _, _, _) => + maybeFork(on, step.mapped).map { p => step.copy(mapped = p) }.asInstanceOf[Option[ReduceStep[A, B, C]]] + case step@UnsortedIdentityReduce(_, _, _, _) => + maybeFork(on, step.mapped).map { p => step.copy(mapped = p) }.asInstanceOf[Option[ReduceStep[A, B, C]]] + case step@IdentityValueSortedReduce(_, _, _, _, _) => + maybeFork(on, step.mapped).map { p => step.copy(mapped = p) }.asInstanceOf[Option[ReduceStep[A, B, C]]] + case step@ValueSortedReduce(_, _, _, _, _, _) => + def go(vsr: ValueSortedReduce[A, B, C]): Option[ValueSortedReduce[A, B, C]] = + maybeFork(on, step.mapped).map { p => + ValueSortedReduce[A, B, C](vsr.keyOrdering, + p, vsr.valueSort, vsr.reduceFn, vsr.reducers, vsr.descriptions) + } + go(step) + case step@IteratorMappedReduce(_, _, _, _, _) => + def go(imr: IteratorMappedReduce[A, B, C]): Option[IteratorMappedReduce[A, B, C]] = + maybeFork(on, step.mapped).map { p => imr.copy(mapped = p) } + go(step) + } + + private def forkHashJoinable[K, V](on: Dag[TypedPipe], hj: HashJoinable[K, V]): Option[HashJoinable[K, V]] = + hj match { + case step@IdentityReduce(_, _, _, _) => + maybeFork(on, step.mapped).map { p => step.copy(mapped = p) } + case step@UnsortedIdentityReduce(_, _, _, _) => + maybeFork(on, step.mapped).map { p => step.copy(mapped = p) } + case step@IteratorMappedReduce(_, _, _, _, _) => + maybeFork(on, step.mapped).map { p => step.copy(mapped = p) } + } + + def apply[T](on: Dag[TypedPipe]) = { + case CounterPipe(a) if needsFork(on, a) => maybeFork(on, a).map(CounterPipe(_)) + case CrossPipe(a, b) if needsFork(on, a) => maybeFork(on, a).map(CrossPipe(_, b)) + case CrossPipe(a, b) if needsFork(on, b) => maybeFork(on, b).map(CrossPipe(a, _)) + case CrossValue(a, b) if needsFork(on, a) => maybeFork(on, a).map(CrossValue(_, b)) + case CrossValue(a, ComputedValue(b)) if needsFork(on, b) => maybeFork(on, b).map { fb => CrossValue(a, ComputedValue(fb)) } + case DebugPipe(p) => maybeFork(on, p).map(DebugPipe(_)) + case FilterKeys(p, fn) => maybeFork(on, p).map(FilterKeys(_, fn)) + case f@Filter(_, _) => + def go[A](f: Filter[A]): Option[TypedPipe[A]] = { + val Filter(p, fn) = f + maybeFork(on, p).map(Filter(_, fn)) + } + go(f) + case FlatMapValues(p, fn) => maybeFork(on, p).map(FlatMapValues(_, fn)) + case FlatMapped(p, fn) => maybeFork(on, p).map(FlatMapped(_, fn)) + case ForceToDisk(_) | Fork(_) => None // already has a barrier + case HashCoGroup(left, right, jf) if needsFork(on, left) => maybeFork(on, left).map(HashCoGroup(_, right, jf)) + case HashCoGroup(left, right, jf) => forkHashJoinable(on, right).map(HashCoGroup(left, _, jf)) + case MapValues(p, fn) => maybeFork(on, p).map(MapValues(_, fn)) + case Mapped(p, fn) => maybeFork(on, p).map(Mapped(_, fn)) + case MergedTypedPipe(a, b) if needsFork(on, a) => maybeFork(on, a).map(MergedTypedPipe(_, b)) + case MergedTypedPipe(a, b) if needsFork(on, b) => maybeFork(on, b).map(MergedTypedPipe(a, _)) + case ReduceStepPipe(rs) => forkReduceStep(on, rs).map(ReduceStepPipe(_)) + case SumByLocalKeys(p, sg) => maybeFork(on, p).map(SumByLocalKeys(_, sg)) + case t@TrappedPipe(_, _, _) => + def go[A](t: TrappedPipe[A]): Option[TypedPipe[A]] = { + val TrappedPipe(p, sink, conv) = t + maybeFork(on, p).map(TrappedPipe(_, sink, conv)) + } + go(t) + case CoGroupedPipe(cgp) => forkCoGroup(on, cgp).map(CoGroupedPipe(_)) + case WithOnComplete(p, fn) => maybeFork(on, p).map(WithOnComplete(_, fn)) + case WithDescriptionTypedPipe(p, d1, d2) => maybeFork(on, p).map(WithDescriptionTypedPipe(_, d1, d2)) + case _ => None + } + } + + /** + * a.flatMap(f).flatMap(g) == a.flatMap { x => f(x).flatMap(g) } + */ + object ComposeFlatMap extends PartialRule[TypedPipe] { + def applyWhere[T](on: Dag[TypedPipe]) = { + case FlatMapped(FlatMapped(in, fn0), fn1) => + FlatMapped(in, FlatMappedFn(fn1).runAfter(FlatMapping.FlatM(fn0))) + case FlatMapValues(FlatMapValues(in, fn0), fn1) => + FlatMapValues(in, FlatMappedFn(fn1).runAfter(FlatMapping.FlatM(fn0))) + } + } + + /** + * a.map(f).map(g) == a.map { x => f(x).map(g) } + */ + object ComposeMap extends PartialRule[TypedPipe] { + def applyWhere[T](on: Dag[TypedPipe]) = { + case Mapped(Mapped(in, fn0), fn1) => + Mapped(in, ComposedMapFn(fn0, fn1)) + case MapValues(MapValues(in, fn0), fn1) => + MapValues(in, ComposedMapFn(fn0, fn1)) + } + } + + /** + * a.filter(f).filter(g) == a.filter { x => f(x) && g(x) } + * + * also if a filterKeys follows a filter, we might as well + * compose because we can't push the filterKeys up higher + */ + object ComposeFilter extends Rule[TypedPipe] { + def apply[T](on: Dag[TypedPipe]) = { + // scala can't type check this, so we hold its hand: + // case Filter(Filter(in, fn0), fn1) => + // Some(Filter(in, ComposedFilterFn(fn0, fn1))) + case f@Filter(_, _) => + def go[A](f: Filter[A]): Option[TypedPipe[A]] = + f.input match { + case f1: Filter[a] => + // We have to be really careful here because f.fn and f1.fn + // have the same type. Type checking won't save you here + // we do have a test that exercises this, however + Some(Filter[a](f1.input, ComposedFilterFn(f1.fn, f.fn))) + case _ => None + } + go(f) + case FilterKeys(FilterKeys(in, fn0), fn1) => + Some(FilterKeys(in, ComposedFilterFn(fn0, fn1))) + case FilterKeys(Filter(in, fn0), fn1) => + Some(Filter(in, ComposedFilterFn(fn0, FilterKeysToFilter(fn1)))) + case _ => None + } + } + + /** + * a.onComplete(f).onComplete(g) == a.onComplete { () => f(); g() } + */ + object ComposeWithOnComplete extends PartialRule[TypedPipe] { + def applyWhere[T](on: Dag[TypedPipe]) = { + case WithOnComplete(WithOnComplete(pipe, fn0), fn1) => + WithOnComplete(pipe, ComposedOnComplete(fn0, fn1)) + } + } + /** + * a.map(f).flatMap(g) == a.flatMap { x => g(f(x)) } + * a.flatMap(f).map(g) == a.flatMap { x => f(x).map(g) } + * + * This is a rule you may want to apply after having + * composed all the maps first + */ + object ComposeMapFlatMap extends PartialRule[TypedPipe] { + def applyWhere[T](on: Dag[TypedPipe]) = { + case FlatMapped(Mapped(in, f), g) => + FlatMapped(in, FlatMappedFn(g).runAfter(FlatMapping.Map(f))) + case FlatMapValues(MapValues(in, f), g) => + FlatMapValues(in, FlatMappedFn(g).runAfter(FlatMapping.Map(f))) + case Mapped(FlatMapped(in, f), g) => + FlatMapped(in, FlatMappedFn(f).combine(FlatMappedFn.fromMap(g))) + case MapValues(FlatMapValues(in, f), g) => + FlatMapValues(in, FlatMappedFn(f).combine(FlatMappedFn.fromMap(g))) + } + } + + + /** + * a.filter(f).flatMap(g) == a.flatMap { x => if (f(x)) g(x) else Iterator.empty } + * a.flatMap(f).filter(g) == a.flatMap { x => f(x).filter(g) } + * + * This is a rule you may want to apply after having + * composed all the filters first + */ + object ComposeFilterFlatMap extends Rule[TypedPipe] { + def apply[T](on: Dag[TypedPipe]) = { + case FlatMapped(Filter(in, f), g) => + Some(FlatMapped(in, FlatMappedFn(g).runAfter(FlatMapping.filter(f)))) + case filter: Filter[b] => + filter.input match { + case fm: FlatMapped[a, b] => + Some(FlatMapped[a, b](fm.input, FlatMappedFn(fm.fn).combine(FlatMappedFn.fromFilter(filter.fn)))) + case _ => None + } + case _ => + None + } + } + /** + * a.filter(f).map(g) == a.flatMap { x => if (f(x)) Iterator.single(g(x)) else Iterator.empty } + * a.map(f).filter(g) == a.flatMap { x => val y = f(x); if (g(y)) Iterator.single(y) else Iterator.empty } + * + * This is a rule you may want to apply after having + * composed all the filters first + * + * This may be a deoptimization on some platforms that have native filters since + * you could avoid the Iterator boxing in that case. + */ + object ComposeFilterMap extends Rule[TypedPipe] { + def apply[T](on: Dag[TypedPipe]) = { + case Mapped(Filter(in, f), g) => + Some(FlatMapped(in, FlatMappedFn.fromFilter(f).combine(FlatMappedFn.fromMap(g)))) + case filter: Filter[b] => + filter.input match { + case fm: Mapped[a, b] => + Some(FlatMapped[a, b](fm.input, FlatMappedFn.fromMap(fm.fn).combine(FlatMappedFn.fromFilter(filter.fn)))) + case _ => None + } + case _ => + None + } + } + + /** + * In scalding 0.17 and earlier, descriptions were automatically pushdown below + * merges and flatMaps/map/etc.. + */ + object DescribeLater extends PartialRule[TypedPipe] { + def applyWhere[T](on: Dag[TypedPipe]) = { + case Mapped(WithDescriptionTypedPipe(in, desc, dedup), fn) => + WithDescriptionTypedPipe(Mapped(in, fn), desc, dedup) + case MapValues(WithDescriptionTypedPipe(in, desc, dedup), fn) => + WithDescriptionTypedPipe(MapValues(in, fn), desc, dedup) + case FlatMapped(WithDescriptionTypedPipe(in, desc, dedup), fn) => + WithDescriptionTypedPipe(FlatMapped(in, fn), desc, dedup) + case FlatMapValues(WithDescriptionTypedPipe(in, desc, dedup), fn) => + WithDescriptionTypedPipe(FlatMapValues(in, fn), desc, dedup) + case f@Filter(WithDescriptionTypedPipe(_, _, _), _) => + def go[A](f: Filter[A]): TypedPipe[A] = + f match { + case Filter(WithDescriptionTypedPipe(in, desc, dedup), fn) => + WithDescriptionTypedPipe(Filter(in, fn), desc, dedup) + case unreachable => unreachable + } + go(f) + case FilterKeys(WithDescriptionTypedPipe(in, desc, dedup), fn) => + WithDescriptionTypedPipe(FilterKeys(in, fn), desc, dedup) + } + } + + /** + * (a ++ a) == a.flatMap { t => List(t, t) } + */ + object DiamondToFlatMap extends Rule[TypedPipe] { + def apply[T](on: Dag[TypedPipe]) = { + case m@MergedTypedPipe(_, _) => + val pipes = unrollMerge(m) + val flatMapped = + pipes.groupBy { tp => tp: TypedPipe[T] } + .iterator + .map { + case (p, Nil) => sys.error(s"unreachable: $p has no values") + case (p, _ :: Nil) => p // just once + case (p, repeated) => + val rsize = repeated.size + p.flatMap(Iterator.fill(rsize)(_)) + } + .toVector + + if (pipes.size == flatMapped.size) None // we didn't reduce the number of merges + else Some(TypedPipe.typedPipeMonoid.sum(flatMapped)) + case _ => None + } + } + + /** + * After a forceToDisk there is no need to immediately fork. + * Calling forceToDisk twice in a row is the same as once. + * Calling fork twice in a row is the same as once. + */ + object RemoveDuplicateForceFork extends PartialRule[TypedPipe] { + def applyWhere[T](on: Dag[TypedPipe]) = { + case ForceToDisk(ForceToDisk(t)) => ForceToDisk(t) + case ForceToDisk(Fork(t)) => ForceToDisk(t) + case Fork(Fork(t)) => Fork(t) + case Fork(ForceToDisk(t)) => ForceToDisk(t) + case Fork(t) if on.contains(ForceToDisk(t)) => ForceToDisk(t) + } + } + + /** + * We ignore .group if there are is no setting of reducers + * + * This is arguably not a great idea, but scalding has always + * done it to minimize accidental map-reduce steps + */ + object IgnoreNoOpGroup extends PartialRule[TypedPipe] { + def applyWhere[T](on: Dag[TypedPipe]) = { + case ReduceStepPipe(IdentityReduce(_, input, None, _)) => + input + } + } + + /** + * In map-reduce settings, Merge is almost free in two contexts: + * 1. the final write + * 2. at the point we are doing a shuffle anyway. + * + * By defering merge as long as possible, we hope to find more such + * cases + */ + object DeferMerge extends PartialRule[TypedPipe] { + private def handleFilter[A]: PartialFunction[Filter[A], TypedPipe[A]] = { + case Filter(MergedTypedPipe(a, b), fn) => MergedTypedPipe(Filter(a, fn), Filter(b, fn)) + } + + def applyWhere[T](on: Dag[TypedPipe]) = { + case Mapped(MergedTypedPipe(a, b), fn) => + MergedTypedPipe(Mapped(a, fn), Mapped(b, fn)) + case FlatMapped(MergedTypedPipe(a, b), fn) => + MergedTypedPipe(FlatMapped(a, fn), FlatMapped(b, fn)) + case MapValues(MergedTypedPipe(a, b), fn) => + MergedTypedPipe(MapValues(a, fn), MapValues(b, fn)) + case FlatMapValues(MergedTypedPipe(a, b), fn) => + MergedTypedPipe(FlatMapValues(a, fn), FlatMapValues(b, fn)) + case f@Filter(_, _) if handleFilter.isDefinedAt(f) => handleFilter(f) + case FilterKeys(MergedTypedPipe(a, b), fn) => + MergedTypedPipe(FilterKeys(a, fn), FilterKeys(b, fn)) + } + } + + /** + * Push filterKeys up as early as possible. This can happen before + * a shuffle, which can be a major win. This allows you to write + * generic methods that return all the data, but if downstream someone + * only wants certain keys they don't pay to compute everything. + * + * This is an optimization we didn't do in scalding 0.17 and earlier + * because .toTypedPipe on the group totally hid the structure from + * us + */ + object FilterKeysEarly extends Rule[TypedPipe] { + private def filterReduceStep[K, V1, V2](rs: ReduceStep[K, V1, V2], fn: K => Boolean): ReduceStep[K, _, _ <: V2] = + rs match { + case step@IdentityReduce(_, _, _, _) => step.filterKeys(fn) + case step@UnsortedIdentityReduce(_, _, _, _) => step.filterKeys(fn) + case step@IdentityValueSortedReduce(_, _, _, _, _) => step.filterKeys(fn) + case step@ValueSortedReduce(_, _, _, _, _, _) => step.filterKeys(fn) + case step@IteratorMappedReduce(_, _, _, _, _) => step.filterKeys(fn) + } + + private def filterCoGroupable[K, V](rs: CoGroupable[K, V], fn: K => Boolean): CoGroupable[K, V] = + rs match { + case step@IdentityReduce(_, _, _, _) => step.filterKeys(fn) + case step@UnsortedIdentityReduce(_, _, _, _) => step.filterKeys(fn) + case step@IteratorMappedReduce(_, _, _, _, _) => step.filterKeys(fn) + case cg: CoGrouped[K, V] => filterCoGroup(cg, fn) + } + + private def filterCoGroup[K, V](cg: CoGrouped[K, V], fn: K => Boolean): CoGrouped[K, V] = + cg match { + case CoGrouped.Pair(a, b, jf) => + CoGrouped.Pair(filterCoGroupable(a, fn), filterCoGroupable(b, fn), jf) + case CoGrouped.FilterKeys(cg, g) => + filterCoGroup(cg, ComposedFilterFn(g, fn)) + case CoGrouped.MapGroup(cg, g) => + CoGrouped.MapGroup(filterCoGroup(cg, fn), g) + case CoGrouped.WithDescription(cg, d) => + CoGrouped.WithDescription(filterCoGroup(cg, fn), d) + case CoGrouped.WithReducers(cg, r) => + CoGrouped.WithReducers(filterCoGroup(cg, fn), r) + } + + def apply[T](on: Dag[TypedPipe]) = { + case FilterKeys(ReduceStepPipe(rsp), fn) => + Some(ReduceStepPipe(filterReduceStep(rsp, fn))) + case FilterKeys(CoGroupedPipe(cg), fn) => + Some(CoGroupedPipe(filterCoGroup(cg, fn))) + case FilterKeys(HashCoGroup(left, right, joiner), fn) => + val newRight = right match { + case step@IdentityReduce(_, _, _, _) => step.filterKeys(fn) + case step@UnsortedIdentityReduce(_, _, _, _) => step.filterKeys(fn) + case step@IteratorMappedReduce(_, _, _, _, _) => step.filterKeys(fn) + } + Some(HashCoGroup(FilterKeys(left, fn), newRight, joiner)) + case FilterKeys(MapValues(pipe, mapFn), filterFn) => + Some(MapValues(FilterKeys(pipe, filterFn), mapFn)) + case FilterKeys(FlatMapValues(pipe, fmFn), filterFn) => + Some(FlatMapValues(FilterKeys(pipe, filterFn), fmFn)) + case _ => None + } + } + + /** + * EmptyTypedPipe is kind of zero of most of these operations + * We go ahead and simplify as much as possible if we see + * an EmptyTypedPipe + */ + object EmptyIsOftenNoOp extends PartialRule[TypedPipe] { + + private def emptyCogroup[K, V](cg: CoGrouped[K, V]): Boolean = { + import CoGrouped._ + + def empty(t: TypedPipe[Any]): Boolean = t match { + case EmptyTypedPipe => true + case _ => false + } + cg match { + case Pair(left, _, jf) if left.inputs.forall(empty) && (Joiner.isLeftJoinLike(jf) == Some(true)) => true + case Pair(_, right, jf) if right.inputs.forall(empty) && (Joiner.isRightJoinLike(jf) == Some(true)) => true + case Pair(left, right, _) if left.inputs.forall(empty) && right.inputs.forall(empty) => true + case Pair(_, _, _) => false + case WithDescription(cg, _) => emptyCogroup(cg) + case WithReducers(cg, _) => emptyCogroup(cg) + case MapGroup(cg, _) => emptyCogroup(cg) + case FilterKeys(cg, _) => emptyCogroup(cg) + } + } + + private def emptyHashJoinable[K, V](hj: HashJoinable[K, V]): Boolean = + hj match { + case step@IdentityReduce(_, _, _, _) => step.mapped == EmptyTypedPipe + case step@UnsortedIdentityReduce(_, _, _, _) => step.mapped == EmptyTypedPipe + case step@IteratorMappedReduce(_, _, _, _, _) => step.mapped == EmptyTypedPipe + } + + def applyWhere[T](on: Dag[TypedPipe]) = { + case CrossPipe(EmptyTypedPipe, _) => EmptyTypedPipe + case CrossPipe(_, EmptyTypedPipe) => EmptyTypedPipe + case CrossValue(EmptyTypedPipe, _) => EmptyTypedPipe + case CrossValue(_, ComputedValue(EmptyTypedPipe)) => EmptyTypedPipe + case CrossValue(_, EmptyValue) => EmptyTypedPipe + case DebugPipe(EmptyTypedPipe) => EmptyTypedPipe + case FilterKeys(EmptyTypedPipe, _) => EmptyTypedPipe + case Filter(EmptyTypedPipe, _) => EmptyTypedPipe + case FlatMapValues(EmptyTypedPipe, _) => EmptyTypedPipe + case FlatMapped(EmptyTypedPipe, _) => EmptyTypedPipe + case ForceToDisk(EmptyTypedPipe) => EmptyTypedPipe + case Fork(EmptyTypedPipe) => EmptyTypedPipe + case HashCoGroup(EmptyTypedPipe, _, _) => EmptyTypedPipe + case HashCoGroup(_, right, hjf) if emptyHashJoinable(right) && Joiner.isInnerHashJoinLike(hjf) == Some(true) => EmptyTypedPipe + case MapValues(EmptyTypedPipe, _) => EmptyTypedPipe + case Mapped(EmptyTypedPipe, _) => EmptyTypedPipe + case MergedTypedPipe(EmptyTypedPipe, a) => a + case MergedTypedPipe(a, EmptyTypedPipe) => a + case ReduceStepPipe(rs: ReduceStep[_, _, _]) if rs.mapped == EmptyTypedPipe => EmptyTypedPipe + case SumByLocalKeys(EmptyTypedPipe, _) => EmptyTypedPipe + case TrappedPipe(EmptyTypedPipe, _, _) => EmptyTypedPipe + case CoGroupedPipe(cgp) if emptyCogroup(cgp) => EmptyTypedPipe + case WithOnComplete(EmptyTypedPipe, _) => EmptyTypedPipe // there is nothing to do, so we never have workers complete + case WithDescriptionTypedPipe(EmptyTypedPipe, _, _) => EmptyTypedPipe // descriptions apply to tasks, but empty has no tasks + } + } + + /** + * If an Iterable is empty, it is the same as EmptyTypedPipe + */ + object EmptyIterableIsEmpty extends PartialRule[TypedPipe] { + def applyWhere[T](on: Dag[TypedPipe]) = { + case IterablePipe(it) if it.isEmpty => EmptyTypedPipe + } + } + + /** + * This is useful on map-reduce like systems to avoid + * serializing data into the system that you are going + * to then filter + */ + object FilterLocally extends Rule[TypedPipe] { + def apply[T](on: Dag[TypedPipe]) = { + case f@Filter(_, _) => + def go[T1 <: T](f: Filter[T1]): Option[TypedPipe[T]] = + f match { + case Filter(IterablePipe(iter), fn) => + Some(IterablePipe(iter.filter(fn))) + case _ => None + } + go(f) + case f@FilterKeys(_, _) => + def go[K, V, T >: (K, V)](f: FilterKeys[K, V]): Option[TypedPipe[T]] = + f match { + case FilterKeys(IterablePipe(iter), fn) => + Some(IterablePipe(iter.filter { case (k, _) => fn(k) })) + case _ => None + } + go(f) + case _ => None + } + } + /** + * ForceToDisk before hashJoin, this makes sure any filters + * have been applied + */ + object ForceToDiskBeforeHashJoin extends Rule[TypedPipe] { + // A set of operations naturally have barriers after them, + // there is no need to add an explicit force after a reduce + // step or after a source, since both will already have been + // checkpointed + final def maybeForce[T](t: TypedPipe[T]): TypedPipe[T] = + t match { + case ReduceStepPipe(IdentityReduce(_, input, None, _)) => + // this is a no-op reduce that will be removed, so we may need to add a force + maybeForce(input) + case SourcePipe(_) | IterablePipe(_) | CoGroupedPipe(_) | ReduceStepPipe(_) | ForceToDisk(_) => t + case WithOnComplete(pipe, fn) => // TODO it is not clear this is safe in cascading 3, since oncomplete is an each + WithOnComplete(maybeForce(pipe), fn) + case WithDescriptionTypedPipe(pipe, desc, dedup) => + WithDescriptionTypedPipe(maybeForce(pipe), desc, dedup) + case pipe => ForceToDisk(pipe) + } + + def apply[T](on: Dag[TypedPipe]) = { + case HashCoGroup(left, right: HashJoinable[a, b], joiner) => + val newRight: HashJoinable[a, b] = right match { + case step@IdentityReduce(_, _, _, _) => + step.copy(mapped = maybeForce(step.mapped)) + case step@UnsortedIdentityReduce(_, _, _, _) => + step.copy(mapped = maybeForce(step.mapped)) + case step@IteratorMappedReduce(_, _, _, _, _) => + step.copy(mapped = maybeForce(step.mapped)) + } + if (newRight != right) Some(HashCoGroup(left, newRight, joiner)) + else None + case (cp@CrossPipe(_, _)) => Some(cp.viaHashJoin) + case (cv@CrossValue(_, _)) => Some(cv.viaHashJoin) + case _ => None + } + } + + /** + * Convert all HashCoGroup to CoGroupedPipe + */ + object HashToShuffleCoGroup extends Rule[TypedPipe] { + def apply[T](on: Dag[TypedPipe]) = { + case HashCoGroup(left, right: HashJoinable[a, b], joiner) => + val leftg = Grouped(left)(right.keyOrdering) + val joiner2 = Joiner.toCogroupJoiner2(joiner) + Some(CoGroupedPipe(CoGrouped.Pair(leftg, right, joiner2))) + case (cp@CrossPipe(_, _)) => Some(cp.viaHashJoin) + case (cv@CrossValue(_, _)) => Some(cv.viaHashJoin) + case _ => None + } + } + + /////// + // These are composed rules that are related + ////// + + /** + * Like kinds can be composed .map(f).map(g), + * filter(f).filter(g) etc... + */ + val composeSame: Rule[TypedPipe] = + Rule.orElse( + List( + ComposeMap, + ComposeFilter, + ComposeFlatMap, + ComposeWithOnComplete)) + /** + * If you are going to do a flatMap, following it or preceding it with map/filter + * you might as well compose into the flatMap + */ + val composeIntoFlatMap: Rule[TypedPipe] = + Rule.orElse( + List( + ComposeMapFlatMap, + ComposeFilterFlatMap, + ComposeFlatMap)) + + val simplifyEmpty: Rule[TypedPipe] = + EmptyIsOftenNoOp.orElse( + EmptyIterableIsEmpty) + + /** + * These are a list of rules to be applied in order (Dag.applySeq) + * that should generally always improve things on Map/Reduce-like + * platforms. + * + * These are rules we should apply to any TypedPipe before handing + * to cascading. These should be a bit conservative in that they + * should be highly likely to improve the graph. + */ + val standardMapReduceRules: List[Rule[TypedPipe]] = + List( + // phase 0, add explicit forks to not duplicate pipes on fanout below + AddExplicitForks, + // phase 1, compose flatMap/map, move descriptions down, defer merge, filter pushup etc... + composeSame.orElse(DescribeLater).orElse(FilterKeysEarly).orElse(DeferMerge), + // phase 2, combine different kinds of mapping operations into flatMaps, including redundant merges + composeIntoFlatMap.orElse(simplifyEmpty).orElse(DiamondToFlatMap), + // phase 3, remove duplicates forces/forks (e.g. .fork.fork or .forceToDisk.fork, ....) + RemoveDuplicateForceFork) +} diff --git a/scalding-core/src/main/scala/com/twitter/scalding/typed/TypedPipe.scala b/scalding-core/src/main/scala/com/twitter/scalding/typed/TypedPipe.scala index 1b5a8641d8..9101c7d5fa 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/typed/TypedPipe.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/typed/TypedPipe.scala @@ -14,7 +14,7 @@ limitations under the License. package com.twitter.scalding.typed import java.io.{ OutputStream, InputStream, Serializable } -import java.util.{ Random, UUID } +import java.util.UUID import cascading.flow.FlowDef import cascading.pipe.Pipe @@ -22,7 +22,8 @@ import cascading.tuple.Fields import com.twitter.algebird.{ Aggregator, Batched, Monoid, Semigroup } import com.twitter.scalding.TupleConverter.singleConverter import com.twitter.scalding._ -import com.twitter.scalding.serialization.OrderedSerialization +import com.twitter.scalding.typed.functions.{ AsLeft, AsRight, Constant, DropValue1, EqTypes, Identity, MakeKey, GetKey, GetValue, RandomFilter, RandomNextInt, Swap, TuplizeFunction, WithConstant, PartialFunctionToFilter, SubTypes } +import com.twitter.scalding.serialization.{ OrderedSerialization, UnitOrderedSerialization } import com.twitter.scalding.serialization.OrderedSerialization.Result import com.twitter.scalding.serialization.macros.impl.BinaryOrdering import com.twitter.scalding.serialization.macros.impl.BinaryOrdering._ @@ -137,41 +138,75 @@ object TypedPipe extends Serializable { } } - case class CrossPipe[T, U](left: TypedPipe[T], right: TypedPipe[U]) extends TypedPipe[(T, U)] { + final case class CoGroupedPipe[K, V](cogrouped: CoGrouped[K, V]) extends TypedPipe[(K, V)] + final case class CounterPipe[A](pipe: TypedPipe[(A, Iterable[((String, String), Long)])]) extends TypedPipe[A] + final case class CrossPipe[T, U](left: TypedPipe[T], right: TypedPipe[U]) extends TypedPipe[(T, U)] { def viaHashJoin: TypedPipe[(T, U)] = left.groupAll.hashJoin(right.groupAll).values } - case class CrossValue[T, U](left: TypedPipe[T], right: ValuePipe[U]) extends TypedPipe[(T, U)] { + final case class CrossValue[T, U](left: TypedPipe[T], right: ValuePipe[U]) extends TypedPipe[(T, U)] { def viaHashJoin: TypedPipe[(T, U)] = right match { - case EmptyValue => EmptyTypedPipe - case LiteralValue(v) => left.map { (_, v) } - case ComputedValue(pipe) => CrossPipe(left, pipe) + case EmptyValue => + EmptyTypedPipe + case LiteralValue(v) => + left.map(WithConstant(v)) + case ComputedValue(pipe) => + CrossPipe(left, pipe) } } - case class DebugPipe[T](pipe: TypedPipe[T]) extends TypedPipe[T] - case class FilterKeys[K, V](input: TypedPipe[(K, V)], fn: K => Boolean) extends TypedPipe[(K, V)] - case class Filter[T](input: TypedPipe[T], fn: T => Boolean) extends TypedPipe[T] - case class Fork[T](input: TypedPipe[T]) extends TypedPipe[T] - case class FlatMapValues[K, V, U](input: TypedPipe[(K, V)], fn: V => TraversableOnce[U]) extends TypedPipe[(K, U)] - case class FlatMapped[T, U](input: TypedPipe[T], fn: T => TraversableOnce[U]) extends TypedPipe[U] - case class ForceToDisk[T](pipe: TypedPipe[T]) extends TypedPipe[T] - case class IterablePipe[T](iterable: Iterable[T]) extends TypedPipe[T] - case class MapValues[K, V, U](input: TypedPipe[(K, V)], fn: V => U) extends TypedPipe[(K, U)] - case class Mapped[T, U](input: TypedPipe[T], fn: T => U) extends TypedPipe[U] - case class MergedTypedPipe[T](left: TypedPipe[T], right: TypedPipe[T]) extends TypedPipe[T] - case class SourcePipe[T](source: TypedSource[T]) extends TypedPipe[T] - case class SumByLocalKeys[K, V](input: TypedPipe[(K, V)], semigroup: Semigroup[V]) extends TypedPipe[(K, V)] - case class TrappedPipe[T, U >: T](input: TypedPipe[T], sink: Source with TypedSink[T], conv: TupleConverter[U]) extends TypedPipe[U] - case class WithDescriptionTypedPipe[T](input: TypedPipe[T], description: String, deduplicate: Boolean) extends TypedPipe[T] - case class WithOnComplete[T](input: TypedPipe[T], fn: () => Unit) extends TypedPipe[T] + final case class DebugPipe[T](input: TypedPipe[T]) extends TypedPipe[T] + final case class FilterKeys[K, V](input: TypedPipe[(K, V)], fn: K => Boolean) extends TypedPipe[(K, V)] + final case class Filter[T](input: TypedPipe[T], fn: T => Boolean) extends TypedPipe[T] + final case class FlatMapValues[K, V, U](input: TypedPipe[(K, V)], fn: V => TraversableOnce[U]) extends TypedPipe[(K, U)] + final case class FlatMapped[T, U](input: TypedPipe[T], fn: T => TraversableOnce[U]) extends TypedPipe[U] + final case class ForceToDisk[T](input: TypedPipe[T]) extends TypedPipe[T] + final case class Fork[T](input: TypedPipe[T]) extends TypedPipe[T] + final case class HashCoGroup[K, V, W, R](left: TypedPipe[(K, V)], right: HashJoinable[K, W], joiner: (K, V, Iterable[W]) => Iterator[R]) extends TypedPipe[(K, R)] + final case class IterablePipe[T](iterable: Iterable[T]) extends TypedPipe[T] + final case class MapValues[K, V, U](input: TypedPipe[(K, V)], fn: V => U) extends TypedPipe[(K, U)] + final case class Mapped[T, U](input: TypedPipe[T], fn: T => U) extends TypedPipe[U] + final case class MergedTypedPipe[T](left: TypedPipe[T], right: TypedPipe[T]) extends TypedPipe[T] + final case class ReduceStepPipe[K, V1, V2](reduce: ReduceStep[K, V1, V2]) extends TypedPipe[(K, V2)] + final case class SourcePipe[T](source: TypedSource[T]) extends TypedPipe[T] + final case class SumByLocalKeys[K, V](input: TypedPipe[(K, V)], semigroup: Semigroup[V]) extends TypedPipe[(K, V)] + final case class TrappedPipe[T](input: TypedPipe[T], sink: Source with TypedSink[T], conv: TupleConverter[T]) extends TypedPipe[T] + final case class WithDescriptionTypedPipe[T](input: TypedPipe[T], description: String, deduplicate: Boolean) extends TypedPipe[T] + final case class WithOnComplete[T](input: TypedPipe[T], fn: () => Unit) extends TypedPipe[T] + case object EmptyTypedPipe extends TypedPipe[Nothing] - case class HashCoGroup[K, V, W, R](left: TypedPipe[(K, V)], - right: HashJoinable[K, W], - joiner: (K, V, Iterable[W]) => Iterator[R]) extends TypedPipe[(K, R)] - case class CoGroupedPipe[K, V](cogrouped: CoGrouped[K, V]) extends TypedPipe[(K, V)] - case class ReduceStepPipe[K, V1, V2](reduce: ReduceStep[K, V1, V2]) extends TypedPipe[(K, V2)] + implicit class InvariantTypedPipe[T](val pipe: TypedPipe[T]) extends AnyVal { + /** + * If any errors happen below this line, but before a groupBy, write to a TypedSink + */ + def addTrap(trapSink: Source with TypedSink[T])(implicit conv: TupleConverter[T]): TypedPipe[T] = + TypedPipe.TrappedPipe[T](pipe, trapSink, conv).withLine + } + + + private case class TallyByFn[A](group: String, fn: A => String) extends Function1[A, (A, Iterable[((String, String), Long)])] { + def apply(a: A) = (a, (((group, fn(a)), 1L)) :: Nil) + } + private case class TallyFn[A](group: String, counter: String) extends Function1[A, (A, Iterable[((String, String), Long)])] { + private[this] val inc = ((group, counter), 1L) :: Nil + def apply(a: A) = (a, inc) + } + private case class TallyLeft[A, B](group: String, fn: A => Either[String, B]) extends Function1[A, (List[B], Iterable[((String, String), Long)])] { + def apply(a: A) = fn(a) match { + case Right(b) => (b :: Nil, Nil) + case Left(cnt) => (Nil, ((group, cnt), 1L) :: Nil) + } + } + + implicit class TallyEnrichment[A, B <: Iterable[((String, String), Long)]](val pipe: TypedPipe[(A, B)]) extends AnyVal { + /** + * Increment hadoop counters with a (group, counter) by the amount in the second + * part of the tuple, and remove that second part + */ + def tally: TypedPipe[A] = + CounterPipe(pipe) + } } /** @@ -181,14 +216,41 @@ object TypedPipe extends Serializable { * Represents a phase in a distributed computation on an input data source * Wraps a cascading Pipe object, and holds the transformation done up until that point */ -sealed trait TypedPipe[+T] extends Serializable { +sealed abstract class TypedPipe[+T] extends Serializable { protected def withLine: TypedPipe[T] = LineNumber.tryNonScaldingCaller.map(_.toString) match { - case None => this - case Some(desc) => TypedPipe.WithDescriptionTypedPipe(this, desc, true) // deduplicate line numbers + case None => + this + case Some(desc) => + TypedPipe.WithDescriptionTypedPipe(this, desc, true) // deduplicate line numbers } + /** + * Increment diagnostic counters by 1 for each item in the pipe. + * The counter group will be the same for each item, the counter name + * is determined by the result of the `fn` passed in. + */ + def tallyBy(group: String)(fn: T => String): TypedPipe[T] = + map(TypedPipe.TallyByFn(group, fn)).tally + + /** + * Increment a specific diagnostic counter by 1 for each item in the pipe. + * + * this is the same as tallyBy(group)(_ => counter) + */ + def tallyAll(group: String, counter: String): TypedPipe[T] = + map(TypedPipe.TallyFn(group, counter)).tally + + /** + * Increment a diagnostic counter for each failure. This is like map, + * where the `fn` should return a `Right[U]` for each successful transformation + * and a `Left[String]` for each failure, with the String describing the failure. + * Each failure will be counted, and the result is just the successes. + */ + def tallyLeft[B](group: String)(fn: T => Either[String, B]): TypedPipe[B] = + map(TypedPipe.TallyLeft(group, fn)).tally.flatten + /** * Implements a cross product. The right side should be tiny * This gives the same results as @@ -248,13 +310,13 @@ sealed trait TypedPipe[+T] extends Serializable { */ @annotation.implicitNotFound(msg = "For asKeys method to work, the type in TypedPipe must have an Ordering.") def asKeys[U >: T](implicit ord: Ordering[U]): Grouped[U, Unit] = - map((_, ())).group + map(WithConstant(())).group /** * If T <:< U, then this is safe to treat as TypedPipe[U] due to covariance */ protected def raiseTo[U](implicit ev: T <:< U): TypedPipe[U] = - this.asInstanceOf[TypedPipe[U]] + SubTypes.fromEv(ev).liftCo[TypedPipe](this) /** * Filter and map. See scala.collection.List.collect. @@ -263,7 +325,7 @@ sealed trait TypedPipe[+T] extends Serializable { * } */ def collect[U](fn: PartialFunction[T, U]): TypedPipe[U] = - filter(fn.isDefinedAt(_)).map(fn) + filter(PartialFunctionToFilter(fn)).map(fn) /** * Attach a ValuePipe to each element this TypedPipe @@ -305,29 +367,24 @@ sealed trait TypedPipe[+T] extends Serializable { // cast because Ordering is not contravariant, but should be (and this cast is safe) implicit val ordT: Ordering[U] = ord.asInstanceOf[Ordering[U]] - // Semigroup to handle duplicates for a given key might have different values. - implicit val sg: Semigroup[T] = new Semigroup[T] { - def plus(a: T, b: T) = b - } - - val op = map { tup => (fn(tup), tup) }.sumByKey + val op = groupBy(fn).head val reduced = numReducers match { case Some(red) => op.withReducers(red) case None => op } - reduced.map(_._2) + reduced.map(GetValue()) } /** Merge two TypedPipes of different types by using Either */ def either[R](that: TypedPipe[R]): TypedPipe[Either[T, R]] = - map(Left(_)) ++ (that.map(Right(_))) + map(AsLeft()) ++ (that.map(AsRight())) /** * Sometimes useful for implementing custom joins with groupBy + mapValueStream when you know * that the value/key can fit in memory. Beware. */ def eitherValues[K, V, R](that: TypedPipe[(K, R)])(implicit ev: T <:< (K, V)): TypedPipe[(K, Either[V, R])] = - mapValues { (v: V) => Left(v) } ++ (that.mapValues { (r: R) => Right(r) }) + mapValues(AsLeft[V, R]()) ++ (that.mapValues(AsRight[V, R]())) /** * If you are going to create two branches or forks, @@ -386,14 +443,14 @@ sealed trait TypedPipe[+T] extends Serializable { /** flatten an Iterable */ def flatten[U](implicit ev: T <:< TraversableOnce[U]): TypedPipe[U] = - flatMap(_.asInstanceOf[TraversableOnce[U]]) // don't use ev which may not be serializable + raiseTo[TraversableOnce[U]].flatMap(Identity[TraversableOnce[U]]()) /** * flatten just the values * This is more useful on KeyedListLike, but added here to reduce assymmetry in the APIs */ def flattenValues[K, U](implicit ev: T <:< (K, TraversableOnce[U])): TypedPipe[(K, U)] = - flatMapValues[K, TraversableOnce[U], U] { us => us } + flatMapValues[K, TraversableOnce[U], U](Identity[TraversableOnce[U]]()) /** * Force a materialization of this pipe prior to the next operation. @@ -417,11 +474,12 @@ sealed trait TypedPipe[+T] extends Serializable { Grouped(raiseTo[(K, V)].withLine) /** Send all items to a single reducer */ - def groupAll: Grouped[Unit, T] = groupBy(x => ())(ordSer[Unit]).withReducers(1) + def groupAll: Grouped[Unit, T] = + groupBy(Constant(()))(UnitOrderedSerialization).withReducers(1) /** Given a key function, add the key, then call .group */ def groupBy[K](g: T => K)(implicit ord: Ordering[K]): Grouped[K, T] = - map { t => (g(t), t) }.group + map(MakeKey(g)).group /** Group using an explicit Ordering on the key. */ def groupWith[K, V](ord: Ordering[K])(implicit ev: <:<[T, (K, V)]): Grouped[K, V] = group(ev, ord) @@ -435,12 +493,9 @@ sealed trait TypedPipe[+T] extends Serializable { * * You probably want shard if you are just forcing a shuffle. */ - def groupRandomly(partitions: Int): Grouped[Int, T] = { - // Make it lazy so all mappers get their own: - lazy val rng = new java.util.Random(123) // seed this so it is repeatable - groupBy { _ => rng.nextInt(partitions) }(TypedPipe.identityOrdering) + def groupRandomly(partitions: Int): Grouped[Int, T] = + groupBy(RandomNextInt(123, partitions))(TypedPipe.identityOrdering) .withReducers(partitions) - } /** * Partitions this into two pipes according to a predicate. @@ -456,8 +511,10 @@ sealed trait TypedPipe[+T] extends Serializable { /** * Sample a fraction (between 0 and 1) uniformly independently at random each element of the pipe * does not require a reduce step. + * This method makes sure to fix the seed, otherwise restarts cause subtle errors. */ def sample(fraction: Double): TypedPipe[T] = sample(fraction, defaultSeed) + /** * Sample a fraction (between 0 and 1) uniformly independently at random each element of the pipe with * a given seed. @@ -465,10 +522,7 @@ sealed trait TypedPipe[+T] extends Serializable { */ def sample(fraction: Double, seed: Long): TypedPipe[T] = { require(0.0 <= fraction && fraction <= 1.0, s"got $fraction which is an invalid fraction") - - // Make sure to fix the seed, otherwise restarts cause subtle errors - lazy val rand = new Random(seed) - filter(_ => rand.nextDouble < fraction) + filter(RandomFilter(seed, fraction)) } /** @@ -502,9 +556,11 @@ sealed trait TypedPipe[+T] extends Serializable { def sum[U >: T](implicit plus: Semigroup[U]): ValuePipe[U] = { // every 1000 items, compact. lazy implicit val batchedSG: Semigroup[Batched[U]] = Batched.compactingSemigroup[U](1000) + // TODO: literals like this defeat caching in the planner ComputedValue(map { t => ((), Batched[U](t)) } .sumByLocalKeys // remove the Batched before going to the reducers + // TODO: literals like this defeat caching in the planner .map { case (_, batched) => batched.sum } .groupAll .forceToReducers @@ -563,10 +619,9 @@ sealed trait TypedPipe[+T] extends Serializable { * @return a pipe equivalent to the current pipe. */ def write(dest: TypedSink[T])(implicit flowDef: FlowDef, mode: Mode): TypedPipe[T] = { - // Make sure that we don't render the whole pipeline twice: - val res = fork - dest.writeFrom(res.toPipe[T](dest.sinkFields)(flowDef, mode, dest.setter)) - res + dest.writeFrom(toPipe[T](dest.sinkFields)(flowDef, mode, dest.setter)) + // We want to fork after this point + fork } /** @@ -600,16 +655,15 @@ sealed trait TypedPipe[+T] extends Serializable { /** Just keep the keys, or ._1 (if this type is a Tuple2) */ def keys[K](implicit ev: <:<[T, (K, Any)]): TypedPipe[K] = - // avoid capturing ev in the closure: - raiseTo[(K, Any)].map(_._1) + raiseTo[(K, Any)].map(GetKey()) /** swap the keys with the values */ def swap[K, V](implicit ev: <:<[T, (K, V)]): TypedPipe[(V, K)] = - raiseTo[(K, V)].map(_.swap) + raiseTo[(K, V)].map(Swap()) /** Just keep the values, or ._2 (if this type is a Tuple2) */ def values[V](implicit ev: <:<[T, (Any, V)]): TypedPipe[V] = - raiseTo[(Any, V)].map(_._2) + raiseTo[(Any, V)].map(GetValue()) /** * ValuePipe may be empty, so, this attaches it as an Option @@ -617,8 +671,8 @@ sealed trait TypedPipe[+T] extends Serializable { */ def leftCross[V](p: ValuePipe[V]): TypedPipe[(T, Option[V])] = p match { - case EmptyValue => map { (_, None) } - case LiteralValue(v) => map { (_, Some(v)) } + case EmptyValue => map(WithConstant(None)) + case LiteralValue(v) => map(WithConstant(Some(v))) case ComputedValue(pipe) => leftCross(pipe) } @@ -638,7 +692,7 @@ sealed trait TypedPipe[+T] extends Serializable { * } */ def mapWithValue[U, V](value: ValuePipe[U])(f: (T, Option[U]) => V): TypedPipe[V] = - leftCross(value).map(t => f(t._1, t._2)) + leftCross(value).map(TuplizeFunction(f)) /** * common pattern of attaching a value and then flatMap @@ -652,7 +706,7 @@ sealed trait TypedPipe[+T] extends Serializable { * } */ def flatMapWithValue[U, V](value: ValuePipe[U])(f: (T, Option[U]) => TraversableOnce[V]): TypedPipe[V] = - leftCross(value).flatMap(t => f(t._1, t._2)) + leftCross(value).flatMap(TuplizeFunction(f)) /** * common pattern of attaching a value and then filter @@ -666,7 +720,7 @@ sealed trait TypedPipe[+T] extends Serializable { * } */ def filterWithValue[U](value: ValuePipe[U])(f: (T, Option[U]) => Boolean): TypedPipe[T] = - leftCross(value).filter(t => f(t._1, t._2)).map(_._1) + leftCross(value).filter(TuplizeFunction(f)).map(GetKey()) /** * These operations look like joins, but they do not force any communication @@ -693,9 +747,9 @@ sealed trait TypedPipe[+T] extends Serializable { * For each element, do a map-side (hash) left join to look up a value */ def hashLookup[K >: T, V](grouped: HashJoinable[K, V]): TypedPipe[(K, Option[V])] = - map((_, ())) + map(WithConstant(())) .hashLeftJoin(grouped) - .map { case (t, (_, optV)) => (t, optV) } + .map(DropValue1()) /** * Enables joining when this TypedPipe has some keys with many many values and @@ -719,12 +773,6 @@ sealed trait TypedPipe[+T] extends Serializable { serialization: K => Array[Byte], ordering: Ordering[K]): Sketched[K, V] = Sketched(ev(this), reducers, delta, eps, seed) - - /** - * If any errors happen below this line, but before a groupBy, write to a TypedSink - */ - def addTrap[U >: T](trapSink: Source with TypedSink[T])(implicit conv: TupleConverter[U]): TypedPipe[U] = - TypedPipe.TrappedPipe[T, U](this, trapSink, conv).withLine } /** diff --git a/scalding-core/src/main/scala/com/twitter/scalding/typed/ValuePipe.scala b/scalding-core/src/main/scala/com/twitter/scalding/typed/ValuePipe.scala index 1fe1a8de0e..5a0ec8f860 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/typed/ValuePipe.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/typed/ValuePipe.scala @@ -98,7 +98,7 @@ case object EmptyValue extends ValuePipe[Nothing] { this } } -case class LiteralValue[T](value: T) extends ValuePipe[T] { +final case class LiteralValue[T](value: T) extends ValuePipe[T] { override def map[U](fn: T => U) = LiteralValue(fn(value)) override def filter(fn: T => Boolean) = if (fn(value)) this else EmptyValue override def toTypedPipe = TypedPipe.from(Iterable(value)) @@ -109,7 +109,7 @@ case class LiteralValue[T](value: T) extends ValuePipe[T] { v } } -case class ComputedValue[T](override val toTypedPipe: TypedPipe[T]) extends ValuePipe[T] { +final case class ComputedValue[T](override val toTypedPipe: TypedPipe[T]) extends ValuePipe[T] { override def map[U](fn: T => U) = ComputedValue(toTypedPipe.map(fn)) override def filter(fn: T => Boolean) = ComputedValue(toTypedPipe.filter(fn)) diff --git a/scalding-core/src/main/scala/com/twitter/scalding/typed/cascading_backend/AsyncFlowDefRunner.scala b/scalding-core/src/main/scala/com/twitter/scalding/typed/cascading_backend/AsyncFlowDefRunner.scala index c301dfbc98..f80abd6a85 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/typed/cascading_backend/AsyncFlowDefRunner.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/typed/cascading_backend/AsyncFlowDefRunner.scala @@ -17,7 +17,9 @@ import com.twitter.scalding.{ Mode, TypedPipe } +import com.twitter.scalding.typed.TypedSink import com.twitter.scalding.cascading_interop.FlowListenerPromise +import com.stripe.dagon.{ Dag, Rule, HMap } import java.util.UUID import java.util.concurrent.LinkedBlockingQueue import org.apache.hadoop.conf.Configuration @@ -35,7 +37,7 @@ object AsyncFlowDefRunner { * We send messages from other threads into the submit thread here */ private sealed trait FlowDefAction - private case class RunFlowDef(conf: Config, + private final case class RunFlowDef(conf: Config, mode: Mode, fd: FlowDef, result: Promise[(Long, JobStats)]) extends FlowDefAction @@ -83,9 +85,20 @@ class AsyncFlowDefRunner extends Writer { self => private[this] val mutex = new AnyRef + type StateKey[T] = (Config, Mode, TypedPipe[T]) + type WorkVal[T] = Future[TypedPipe[T]] + + /** + * @param filesToCleanup temporary files created by forceToDiskExecution + * @param initToOpt this is the mapping between user's TypedPipes and their optimized versions + * which are actually run. + * @param forcedPipes these are all the side effecting forcing of TypedPipes into simple + * SourcePipes or IterablePipes. These are for both toIterableExecution and forceToDiskExecution + */ private case class State( filesToCleanup: Map[Mode, Set[String]], - forcedPipes: Map[(Config, Mode, TypedPipe[Any]), Future[TypedPipe[Any]]]) { + initToOpt: HMap[TypedPipe, TypedPipe], + forcedPipes: HMap[StateKey, WorkVal]) { def addFilesToCleanup(m: Mode, s: Option[String]): State = s match { @@ -95,19 +108,38 @@ class AsyncFlowDefRunner extends Writer { self => case None => this } - def addPipe[T](c: Config, + /** + * Returns true if we actually add this optimized pipe. We do this + * because we don't want to take the side effect twice. + */ + def addForce[T](c: Config, m: Mode, init: TypedPipe[T], - p: Future[TypedPipe[T]]): Option[State] = + opt: TypedPipe[T], + p: Future[TypedPipe[T]]): (State, Boolean) = - forcedPipes.get((c, m, init)) match { + forcedPipes.get((c, m, opt)) match { case None => - Some(copy(forcedPipes = forcedPipes + ((c, m, init) -> p))) - case Some(exists) => None + (copy(forcedPipes = forcedPipes + ((c, m, opt) -> p), + initToOpt = initToOpt + (init -> opt)), true) + case Some(_) => + (copy(initToOpt = initToOpt + (init -> opt)), false) + } + + def getForce[T](c: Config, + m: Mode, + init: TypedPipe[T]): Option[Future[TypedPipe[T]]] = + + initToOpt.get(init).map { opt => + forcedPipes.get((c, m, opt)) match { + case None => + sys.error(s"invariant violation: initToOpt mapping exists for $init, but no forcedPipe") + case Some(p) => p + } } } - private[this] var state: State = State(Map.empty, Map.empty) + private[this] var state: State = State(Map.empty, HMap.empty, HMap.empty) private def updateState[S](fn: State => (State, S)): S = mutex.synchronized { @@ -222,50 +254,74 @@ class AsyncFlowDefRunner extends Writer { self => val done = Promise[Unit]() + val phases: Seq[Rule[TypedPipe]] = + CascadingBackend.defaultOptimizationRules(conf) + + val toOptimized = ToWrite.optimizeWriteBatch(writes, phases) + def prepareFD(c: Config, m: Mode): FlowDef = { val fd = new FlowDef - def force[A](t: TypedPipe[A]): Unit = { + def write[A](tpipe: TypedPipe[A], dest: TypedSink[A]): Unit = { + // We have already applied the optimizations to the batch of writes above + val pipe = CascadingBackend.toPipeUnoptimized(tpipe, dest.sinkFields)(fd, mode, dest.setter) + dest.writeFrom(pipe)(fd, mode) + } + + def force[A](init: TypedPipe[A], opt: TypedPipe[A]): Unit = { val pipePromise = Promise[TypedPipe[A]]() val fut = pipePromise.future // This updates mutable state val sinkOpt = updateState { s => - s.addPipe(conf, mode, t, fut) - .map { nextState => - val uuid = UUID.randomUUID - val (sink, forcedPipe, clean) = forceToDisk(uuid, c, m, t) - (nextState.addFilesToCleanup(m, clean), Some((sink, forcedPipe))) - } - .getOrElse((s, None)) + val (nextState, added) = s.addForce(conf, mode, init, opt, fut) + if (added) { + val uuid = UUID.randomUUID + val (sink, forcedPipe, clean) = forceToDisk(uuid, c, m, opt) + (nextState.addFilesToCleanup(m, clean), Some((sink, forcedPipe))) + } else { + (nextState, None) + } } sinkOpt.foreach { case (sink, fp) => - t.write(sink)(fd, m) + // We write the optimized pipe + write(opt, sink) val pipeFut = done.future.map(_ => fp()) pipePromise.completeWith(pipeFut) } } + def addIter[A](init: TypedPipe[A], optimized: Either[Iterable[A], Mappable[A]]): Unit = { + val result = optimized match { + case Left(iter) if iter.isEmpty => TypedPipe.EmptyTypedPipe + case Left(iter) => TypedPipe.IterablePipe(iter) + case Right(mappable) => TypedPipe.SourcePipe(mappable) + } + val fut = Future.successful(result) + updateState(_.addForce(conf, mode, init, result, fut)) + } writes.foreach { - case Force(pipe) => force(pipe) - case ToIterable(pipe) => - def step[A](t: TypedPipe[A]): Unit = { - t match { - case TypedPipe.EmptyTypedPipe => () - case TypedPipe.IterablePipe(_) => () - case TypedPipe.SourcePipe(src: Mappable[A]) => () + case Force(init) => + val opt = toOptimized(init) + force(init, opt) + case ToIterable(init) => + def step[A](opt: TypedPipe[A]): Unit = { + opt match { + case TypedPipe.EmptyTypedPipe => addIter(init, Left(Nil)) + case TypedPipe.IterablePipe(as) => addIter(init, Left(as)) + case TypedPipe.SourcePipe(src: Mappable[A]) => addIter(init, Right(src)) case other => // we need to write the pipe out first. - force(other) + force(init, opt) // now, when we go to check for the pipe later, it // will be a SourcePipe of a Mappable by construction } } - step(pipe) + step(toOptimized(init)) case SimpleWrite(pipe, sink) => - pipe.write(sink)(fd, m) + write(toOptimized(pipe), sink) } fd @@ -283,31 +339,34 @@ class AsyncFlowDefRunner extends Writer { self => m: Mode, initial: TypedPipe[T])(implicit cec: ConcurrentExecutionContext): Future[TypedPipe[T]] = - getState.forcedPipes.get((conf, m, initial)) match { - case Some(fut) => fut.asInstanceOf[Future[TypedPipe[T]]] + getState.getForce(conf, m, initial) match { + case Some(fut) => fut case None => val msg = - s"logic error: getForced($conf, $m, $initial) does not have a forced pipe" + s"logic error: getForced($conf, $m, $initial) does not have a forced pipe." Future.failed(new Exception(msg)) } def getIterable[T]( conf: Config, m: Mode, - initial: TypedPipe[T])(implicit cec: ConcurrentExecutionContext): Future[Iterable[T]] = initial match { - case TypedPipe.EmptyTypedPipe => Future.successful(Nil) - case TypedPipe.IterablePipe(iter) => Future.successful(iter) - case TypedPipe.SourcePipe(src: Mappable[T]) => - Future.successful( - new Iterable[T] { - def iterator = src.toIterator(conf, m) - }) - case other => - // this should have been forced: - getForced(conf, m, initial).flatMap(getIterable(conf, m, _)) - } + initial: TypedPipe[T])(implicit cec: ConcurrentExecutionContext): Future[Iterable[T]] = + + getForced(conf, m, initial).flatMap { + case TypedPipe.EmptyTypedPipe => Future.successful(Nil) + case TypedPipe.IterablePipe(iter) => Future.successful(iter) + case TypedPipe.SourcePipe(src: Mappable[T]) => + Future.successful( + new Iterable[T] { + def iterator = src.toIterator(conf, m) + }) + case other => + val msg = + s"logic error: expected an Iterable pipe. ($conf, $m, $initial) -> $other is not iterable" + Future.failed(new Exception(msg)) + } - private def forceToDisk[T]( + private def forceToDisk[T]( // linter:disable:UnusedParameter uuid: UUID, conf: Config, mode: Mode, diff --git a/scalding-core/src/main/scala/com/twitter/scalding/typed/cascading_backend/CascadingBackend.scala b/scalding-core/src/main/scala/com/twitter/scalding/typed/cascading_backend/CascadingBackend.scala index 0a24f2d840..4c444f81fb 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/typed/cascading_backend/CascadingBackend.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/typed/cascading_backend/CascadingBackend.scala @@ -1,14 +1,28 @@ package com.twitter.scalding.typed.cascading_backend + import cascading.flow.FlowDef -import cascading.operation.Operation -import cascading.pipe.{CoGroup, Each, HashJoin, Pipe} -import cascading.tuple.{Fields, TupleEntry, Tuple => CTuple} -import com.twitter.scalding.TupleConverter.{singleConverter, tuple2Converter} -import com.twitter.scalding.TupleSetter.{singleSetter, tup2Setter} -import com.twitter.scalding.{CleanupIdentityFunction, Config, Dsl, Field, FlatMapFunction, FlowStateMap, GroupBuilder, HadoopMode, IterableSource, LineNumber, MapsideReduce, Mode, RichPipe, TupleConverter, TupleGetter, TupleSetter, TypedBufferOp, WrappedJoiner, Write} +import cascading.operation.{ Debug, Operation } +import cascading.pipe.{ CoGroup, Each, Pipe, HashJoin } +import cascading.tuple.{ Fields, Tuple => CTuple, TupleEntry } +import com.stripe.dagon.{ FunctionK, HCache, Memoize, Rule, Dag } +import com.twitter.scalding.TupleConverter.{ singleConverter, tuple2Converter } +import com.twitter.scalding.TupleSetter.{ singleSetter, tup2Setter } +import com.twitter.scalding.{ + CleanupIdentityFunction, Config, Dsl, Field, FlatMapFunction, FlowStateMap, GroupBuilder, + HadoopMode, IncrementCounters, LineNumber, IterableSource, MapsideReduce, Mode, RichFlowDef, + RichPipe, TupleConverter, TupleGetter, TupleSetter, TypedBufferOp, WrappedJoiner, Write +} import com.twitter.scalding.typed._ -import com.twitter.scalding.serialization.{Boxed, BoxedOrderedSerialization, CascadingBinaryComparator, EquivSerialization, OrderedSerialization, WrappedSerialization} +import com.twitter.scalding.typed.functions.{ FilterKeysToFilter, MapValuesToMap, FlatMapValuesToFlatMap, FlatMappedFn } +import com.twitter.scalding.serialization.{ + Boxed, + BoxedOrderedSerialization, + CascadingBinaryComparator, + EquivSerialization, + OrderedSerialization, + WrappedSerialization +} import java.util.WeakHashMap import scala.collection.immutable @@ -19,6 +33,7 @@ object CascadingBackend { private val valueField: Fields = new Fields("value") private val kvFields: Fields = new Fields("key", "value") + private val f0: Fields = new Fields(java.lang.Integer.valueOf(0)) private def tuple2Conv[K, V](ord: Ordering[K]): TupleConverter[(K, V)] = ord match { @@ -91,10 +106,28 @@ object CascadingBackend { op(ts, keyF) } - private case class CascadingPipe[T](pipe: Pipe, + private case class CascadingPipe[+T](pipe: Pipe, fields: Fields, @transient localFlowDef: FlowDef, // not serializable. - converter: TupleConverter[T]) + converter: TupleConverter[_ <: T]) { + + /** + * merge the flowDef into this new flowdef an make sure the tuples + * have the structure defined by setter + */ + def toPipe[U >: T](f: Fields, fd: FlowDef, setter: TupleSetter[U]): Pipe = { + // TODO, this may be identity if the setter is the inverse of the + // converter. If we can identify this we will save allocations + val resFd = new RichFlowDef(fd) + resFd.mergeFrom(localFlowDef) + RichPipe(pipe).mapTo[T, U](fields -> f)(t => t)(TupleConverter.asSuperConverter(converter), setter) + } + } + + private object CascadingPipe { + def single[T](pipe: Pipe, fd: FlowDef): CascadingPipe[T] = + CascadingPipe(pipe, f0, fd, singleConverter[T]) + } /** * we want to cache renderings of some TypedPipe to Pipe so cascading @@ -103,266 +136,281 @@ object CascadingBackend { * at once, and not need a static cache here, but currently we still * plan one TypedPipe at a time. */ - private class PipeCache { - private[this] val pipeCache = new WeakHashMap[TypedPipe[Any], Map[Mode, CascadingPipe[Any]]]() - - def cacheGet[T](t: TypedPipe[T], m: Mode)(p: FlowDef => CascadingPipe[T]): CascadingPipe[T] = { - def add(mmc: Map[Mode, CascadingPipe[Any]]): CascadingPipe[T] = { - val emptyFD = new FlowDef - val res = p(emptyFD) - pipeCache.put(t, mmc + (m -> res.asInstanceOf[CascadingPipe[Any]])) - res - } - - pipeCache.synchronized { - pipeCache.get(t) match { - case null => add(Map.empty) - case somemap if somemap.contains(m) => somemap(m).asInstanceOf[CascadingPipe[T]] - case missing => add(missing) + private class CompilerCache { + + private[this] val cache = new WeakHashMap[FlowDef, FunctionK[TypedPipe, CascadingPipe]]() + + def get(fd: FlowDef, m: Mode): FunctionK[TypedPipe, CascadingPipe] = + cache.synchronized { + cache.get(fd) match { + case null => + val c = compile(m) + cache.put(fd, c) + c + case nonNull => nonNull } } - } } - private[this] val pipeCache = new PipeCache - - final def toPipe[U](p: TypedPipe[U], fieldNames: Fields)(implicit flowDef: FlowDef, mode: Mode, setter: TupleSetter[U]): Pipe = { - - import pipeCache.cacheGet - - val f0 = new Fields(java.lang.Integer.valueOf(0)) - - def singlePipe[T](t: TypedPipe[T], force: Boolean = false): CascadingPipe[T] = - cacheGet(t, mode) { localFD => - val pipe = toPipe(t, f0)(localFD, mode, singleSetter) - val p = if (force) RichPipe(pipe).forceToDisk else pipe - CascadingPipe[T](p, f0, localFD, singleConverter) - } - - def applyDescriptions(p: Pipe, descriptions: List[(String, Boolean)]): Pipe = { - val ordered = descriptions.collect { case (d, false) => d }.reverse - val unordered = descriptions.collect { case (d, true) => d }.distinct.sorted - - RichPipe.setPipeDescriptions(p, ordered ::: unordered) - } + private[this] val cache = new CompilerCache + + private def compile[T](mode: Mode): FunctionK[TypedPipe, CascadingPipe] = + Memoize.functionK[TypedPipe, CascadingPipe]( + new Memoize.RecursiveK[TypedPipe, CascadingPipe] { + def toFunction[T] = { + case (cp@CounterPipe(_), rec) => + def go[A](cp: CounterPipe[A]): CascadingPipe[A] = { + val CascadingPipe(pipe0, initF, fd, conv) = rec(cp.pipe) + val cpipe = RichPipe(pipe0) + .eachTo(initF -> f0)(new IncrementCounters[A](_, TupleConverter.asSuperConverter(conv))) + CascadingPipe.single[A](cpipe, fd) + } + go(cp) + case (cp@CrossPipe(_, _), rec) => + rec(cp.viaHashJoin) + case (cv@CrossValue(_, _), rec) => + rec(cv.viaHashJoin) + case (DebugPipe(p), rec) => + val inner = rec(p) + inner.copy(pipe = new Each(inner.pipe, new Debug)) + case (EmptyTypedPipe, rec) => + // just use an empty iterable pipe. + rec(IterablePipe(List.empty[T])) + case (fk@FilterKeys(_, _), rec) => + def go[K, V](node: FilterKeys[K, V]): CascadingPipe[(K, V)] = { + val rewrite = Filter[(K, V)](node.input, FilterKeysToFilter(node.fn)) + rec(rewrite) + } + go(fk) + case (f@Filter(_, _), rec) => + // hand holding for type inference + def go[T1 <: T](f: Filter[T1]): CascadingPipe[T] = { + val Filter(input, fn) = f + val CascadingPipe(pipe, initF, fd, conv) = rec(input) + // This does not need a setter, which is nice. + val fpipe = RichPipe(pipe).filter[T1](initF)(fn)(TupleConverter.asSuperConverter(conv)) + CascadingPipe[T](fpipe, initF, fd, conv) + } - /* - * This creates a mapping operation on a Pipe. It does so - * by merging the local FlowDef of the CascadingPipe into - * the one passed to this method, then running the FlatMappedFn - * and finally applying the descriptions. - */ - def finish[T](cp: CascadingPipe[T], - rest: FlatMappedFn[T, U], - descriptions: List[(String, Boolean)]): Pipe = { + go(f) + case (f@FlatMapValues(_, _), rec) => + def go[K, V, U](node: FlatMapValues[K, V, U]): CascadingPipe[T] = + rec(FlatMapped[(K, V), (K, U)](node.input, FlatMapValuesToFlatMap(node.fn))) + + go(f) + case (fm@FlatMapped(_, _), rec) => + // TODO we can optimize a flatmapped input directly and skip some tupleconverters + def go[A, B <: T](fm: FlatMapped[A, B]): CascadingPipe[T] = { + val CascadingPipe(pipe, initF, fd, conv) = rec(fm.input) + val fmpipe = RichPipe(pipe).flatMapTo[A, T](initF -> f0)(fm.fn)(TupleConverter.asSuperConverter(conv), singleSetter) + CascadingPipe.single[B](fmpipe, fd) + } - Dsl.flowDefToRichFlowDef(flowDef).mergeFrom(cp.localFlowDef) - val withRest = RichPipe(cp.pipe) - .flatMapTo[T, U](cp.fields -> fieldNames)(rest)(cp.converter, setter) + go(fm) + case (ForceToDisk(input), rec) => + val cp = rec(input) + cp.copy(pipe = RichPipe(cp.pipe).forceToDisk) + case (Fork(input), rec) => + // fork doesn't mean anything here since we are already planning each TypedPipe to + // something in cascading. Fork is an optimizer level operation + rec(input) + case (IterablePipe(iter), _) => + val fd = new FlowDef + val pipe = IterableSource[T](iter, f0)(singleSetter, singleConverter).read(fd, mode) + CascadingPipe.single[T](pipe, fd) + case (f@MapValues(_, _), rec) => + def go[K, A, B](fn: MapValues[K, A, B]): CascadingPipe[_ <: (K, B)] = + rec(Mapped[(K, A), (K, B)](fn.input, MapValuesToMap(fn.fn))) + + go(f) + case (m@Mapped(_, _), rec) => + def go[A, B <: T](m: Mapped[A, B]): CascadingPipe[T] = { + val Mapped(input, fn) = m + val CascadingPipe(pipe, initF, fd, conv) = rec(input) + val fmpipe = RichPipe(pipe).mapTo[A, T](initF -> f0)(fn)(TupleConverter.asSuperConverter(conv), singleSetter) + CascadingPipe.single[B](fmpipe, fd) + } - applyDescriptions(withRest, descriptions) - } + go(m) + + case (m@MergedTypedPipe(_, _), rec) => + OptimizationRules.unrollMerge(m) match { + case Nil => rec(EmptyTypedPipe) + case h :: Nil => rec(h) + case h :: tail => + // TODO: a better optimization is to not materialize this + // node at all if there is no fan out since groupBy and cogroupby + // can accept multiple inputs + // + // (a ++ a) == a.flatMap { x => List(x, x) } is an optimization we used to + // have + + val flowDef = new FlowDef + // if all of the converters are the same, we could skip some work + // here, but need to be able to see that correctly + val headPipe = rec(h).toPipe(f0, flowDef, singleSetter) + val tailPipes = tail.map { p => rec(p).toPipe(f0, flowDef, singleSetter) } + val merged = RichPipe.mergeAvoidingHashes(headPipe, tailPipes) + // push all the remaining flatmaps up: + // TODO: a better optimization is to not materialize this + // node at all if there is no fan out since groupBy and cogroupby + // can accept multiple inputs + CascadingPipe.single[T](merged, flowDef) + } + case (SourcePipe(typedSrc), _) => + val fd = new FlowDef + val pipe = typedSrc.read(fd, mode) + CascadingPipe[T](pipe, typedSrc.sourceFields, fd, typedSrc.converter[T]) + case (sblk@SumByLocalKeys(_, _), rec) => + def go[K, V](sblk: SumByLocalKeys[K, V]): CascadingPipe[(K, V)] = { + val cp = rec(sblk.input) + val localFD = new FlowDef + val cpKV: Pipe = cp.toPipe(kvFields, localFD, tup2Setter) + val msr = new MapsideReduce(sblk.semigroup, new Fields("key"), valueField, None)(singleConverter[V], singleSetter[V]) + val kvpipe = RichPipe(cpKV).eachTo(kvFields -> kvFields) { _ => msr } + CascadingPipe(kvpipe, kvFields, localFD, tuple2Converter[K, V]) + } + go(sblk) + case (trapped: TrappedPipe[u], rec) => + val cp = rec(trapped.input) + import trapped._ + // TODO: with diamonds in the graph, this might not be correct + // it seems cascading requires puts the immediate tuple that + // caused the exception, so if you addTrap( ).map(f).map(g) + // and f changes the tuple structure, if we don't collapse the + // maps into 1 operation, cascading can write two different + // schemas into the trap, making it unreadable. + // this basically means there can only be one operation in between + // a trap and a forceToDisk or a groupBy/cogroupBy (any barrier). + val fd = new FlowDef + val pp: Pipe = cp.toPipe[u](sink.sinkFields, fd, TupleSetter.asSubSetter(sink.setter)) + val pipe = RichPipe.assignName(pp) + fd.addTrap(pipe, sink.createTap(Write)(mode)) + CascadingPipe[u](pipe, sink.sinkFields, fd, conv) + case (WithDescriptionTypedPipe(input, descr, dedup), rec) => + + @annotation.tailrec + def loop[A](t: TypedPipe[A], acc: List[(String, Boolean)]): (TypedPipe[A], List[(String, Boolean)]) = + t match { + case WithDescriptionTypedPipe(i, desc, ded) => + loop(i, (desc, ded) :: acc) + case notDescr => (notDescr, acc) + } - def loop[T](t: TypedPipe[T], rest: FlatMappedFn[T, U], descriptions: List[(String, Boolean)]): Pipe = t match { - case cp@CrossPipe(_, _) => loop(cp.viaHashJoin, rest, descriptions) + val (root, descrs) = loop(input, (descr, dedup) :: Nil) + val cp = rec(root) + cp.copy(pipe = applyDescriptions(cp.pipe, descrs)) - case cv@CrossValue(_, _) => loop(cv.viaHashJoin, rest, descriptions) + case (WithOnComplete(input, fn), rec) => + val cp = rec(input) + val next = new Each(cp.pipe, Fields.ALL, new CleanupIdentityFunction(fn), Fields.REPLACE) + cp.copy(pipe = next) - case DebugPipe(p) => - // There is really little that can be done here but println - loop(p.map { t => println(t); t }, rest, descriptions) + case (hcg@HashCoGroup(_, _, _), rec) => + def go[K, V1, V2, R](hcg: HashCoGroup[K, V1, V2, R]): CascadingPipe[(K, R)] = + planHashJoin(hcg.left, + hcg.right, + hcg.joiner, + rec) - case EmptyTypedPipe => - // just use an empty iterable pipe. - // Note, rest is irrelevant - val empty = IterableSource(Iterable.empty, fieldNames)(setter, singleConverter[U]).read(flowDef, mode) - applyDescriptions(empty, descriptions) + go(hcg) + case (ReduceStepPipe(rs), rec) => + planReduceStep(rs, rec) - case fk@FilterKeys(_, _) => - def go[K, V](node: FilterKeys[K, V]): Pipe = node match { - case FilterKeys(IterablePipe(iter), fn) => - loop[(K, V)](IterablePipe(iter.filter { case (k, v) => fn(k) }), rest, descriptions) - case _ => - loop[(K, V)](node.input, rest.runAfter(FlatMapping.filterKeys(node.fn)), descriptions) - } - go(fk) - - case f@Filter(_, _) => - // hand holding for type inference - def go[T1 <: T](f: Filter[T1]) = f match { - case Filter(IterablePipe(iter), fn) => loop(IterablePipe(iter.filter(fn)), rest, descriptions) - case _ => - loop[T1](f.input, rest.runAfter(FlatMapping.filter(f.fn)), descriptions) + case (CoGroupedPipe(cg), rec) => + planCoGroup(cg, rec) } - go(f) - - case f@FlatMapValues(_, _) => - def go[K, V, U](node: FlatMapValues[K, V, U]): Pipe = { - // don't capture node, which is a TypedPipe, which we avoid serializing - val fn = node.fn - loop(node.input, rest.runAfter( - FlatMapping.FlatM[(K, V), (K, U)] { case (k, v) => - fn(v).map((k, _)) - }), descriptions) - } - - go(f) + }) - case FlatMapped(prev, fn) => - loop(prev, rest.runAfter(FlatMapping.FlatM(fn)), descriptions) + private def applyDescriptions(p: Pipe, descriptions: List[(String, Boolean)]): Pipe = { + val ordered = descriptions.collect { case (d, false) => d }.reverse + val unordered = descriptions.collect { case (d, true) => d }.distinct.sorted - case ForceToDisk(EmptyTypedPipe) => loop(EmptyTypedPipe, rest, descriptions) - case ForceToDisk(i@IterablePipe(iter)) => loop(i, rest, descriptions) - case ForceToDisk(pipe) => finish(singlePipe(pipe, force = true), rest, descriptions) + RichPipe.setPipeDescriptions(p, ordered ::: unordered) + } - case Fork(EmptyTypedPipe) => loop(EmptyTypedPipe, rest, descriptions) - case Fork(i@IterablePipe(iter)) => loop(i, rest, descriptions) - case Fork(pipe) => finish(singlePipe(pipe), rest, descriptions) + /** + * These are rules we should apply to any TypedPipe before handing + * to cascading. These should be a bit conservative in that they + * should be highly likely to improve the graph. + */ + def defaultOptimizationRules(config: Config): Seq[Rule[TypedPipe]] = { + + def std(forceHash: Rule[TypedPipe]) = + OptimizationRules.IgnoreNoOpGroup :: + (OptimizationRules.standardMapReduceRules ::: + List( + OptimizationRules.FilterLocally, // after filtering, we may have filtered to nothing, lets see + OptimizationRules.simplifyEmpty, + // add any explicit forces to the optimized graph + Rule.orElse(List( + forceHash, // do this only on the maximally optimized graph + OptimizationRules.RemoveDuplicateForceFork) + ))) + + config.getOptimizationPhases match { + case Some(tryPhases) => tryPhases.get.phases + case None => + val force = + if (config.getHashJoinAutoForceRight) OptimizationRules.ForceToDiskBeforeHashJoin + else Rule.empty[TypedPipe] + val hashToCogroup = + if (config.getConvertHashJoinToShuffleJoin) OptimizationRules.HashToShuffleCoGroup + else Rule.empty[TypedPipe] + std(force.orElse(hashToCogroup)) + } + } - case IterablePipe(iterable) => - val toSrc = IterableSource(iterable, f0)(singleSetter[T], singleConverter[T]) - loop(SourcePipe(toSrc), rest, descriptions) + final def toPipe[U](p: TypedPipe[U], fieldNames: Fields)(implicit flowDef: FlowDef, mode: Mode, setter: TupleSetter[U]): Pipe = { - case f@MapValues(_, _) => - def go[K, V, U](node: MapValues[K, V, U]): Pipe = { - // don't capture node, which is a TypedPipe, which we avoid serializing - val mvfn = node.fn - loop(node.input, rest.runAfter( - FlatMapping.Map[(K, V), (K, U)] { case (k, v) => (k, mvfn(v)) }), descriptions) - } + val phases = defaultOptimizationRules( + mode match { + case h: HadoopMode => Config.fromHadoop(h.jobConf) + case _ => Config.empty + }) + val (d, id) = Dag(p, OptimizationRules.toLiteral) + val d1 = d.applySeq(phases) + val p1 = d1.evaluate(id) - go(f) - - case Mapped(input, fn) => loop(input, rest.runAfter(FlatMapping.Map(fn)), descriptions) - - case MergedTypedPipe(left, right) => - @annotation.tailrec - def allMerged[A](m: TypedPipe[A], - stack: List[TypedPipe[A]], - acc: List[TypedPipe[A]], - ds: List[(String, Boolean)]): (List[TypedPipe[A]], List[(String, Boolean)]) = m match { - case MergedTypedPipe(left, right) => - allMerged(left, right :: stack, acc, ds) - case EmptyTypedPipe => stack match { - case Nil => (acc, ds) - case h :: t => allMerged(h, t, acc, ds) - } - case WithDescriptionTypedPipe(p, desc, dedup) => - allMerged(p, stack, acc, (desc, dedup) :: ds) - case notMerged => - allMerged(EmptyTypedPipe, stack, notMerged :: acc, ds) - } - val (unmerged, ds) = allMerged(left, right :: Nil, Nil, Nil) - // check for repeated pipes - val uniquePipes: List[TypedPipe[T]] = unmerged - .groupBy(identity) - .mapValues(_.size) - .map { - case (pipe, 1) => pipe - case (pipe, cnt) => pipe.flatMap(List.fill(cnt)(_).iterator) - } - .toList - - uniquePipes match { - case Nil => loop(EmptyTypedPipe, rest, ds ::: descriptions) - case h :: Nil => loop(h, rest, ds ::: descriptions) - case h :: tail => - // push all the remaining flatmaps up: - // TODO: a better optimization is to not materialize this - // node at all if there is no fan out since groupBy and cogroupby - // can accept multiple inputs - val headPipe = loop(h, rest, Nil) - val tailPipes = tail.map(loop(_, rest, Nil)) - val merged = RichPipe.mergeAvoidingHashes(headPipe, tailPipes) - applyDescriptions(merged, ds ::: descriptions) - } - case src@SourcePipe(_) => - def go[A](sp: SourcePipe[A]): CascadingPipe[A] = - cacheGet[A](sp, mode) { implicit localFD => - val source = sp.source - val pipe = source.read(localFD, mode) - CascadingPipe[A](pipe, source.sourceFields, localFD, source.converter[A]) - } - finish(go(src), rest, descriptions) - - case slk@SumByLocalKeys(_, _) => - def sum[K, V](sblk: SumByLocalKeys[K, V]): CascadingPipe[(K, V)] = - cacheGet(sblk, mode) { implicit localFD => - val pairPipe = toPipe(sblk.input, kvFields)(localFD, mode, tup2Setter) - val msr = new MapsideReduce(sblk.semigroup, new Fields("key"), valueField, None)(singleConverter[V], singleSetter[V]) - val kvpipe = RichPipe(pairPipe).eachTo(kvFields -> kvFields) { _ => msr } - CascadingPipe(kvpipe, kvFields, localFD, tuple2Converter) - } - finish(sum(slk), rest, descriptions) + // Now that we have an optimized pipe, convert it to a Pipe + toPipeUnoptimized(p1, fieldNames) + } - case tp@TrappedPipe(_, _, _) => - def go[T0, T1 >: T0](tp: TrappedPipe[T0, T1], r: FlatMappedFn[T1, U]): Pipe = { - val cp = cacheGet(tp, mode) { implicit fd => - val sfields = tp.sink.sinkFields - // TODO: with diamonds in the graph, this might not be correct - val pp = toPipe[T0](tp.input, sfields)(fd, mode, tp.sink.setter) - val pipe = RichPipe.assignName(pp) - flowDef.addTrap(pipe, tp.sink.createTap(Write)(mode)) - CascadingPipe[T1](pipe, sfields, fd, tp.conv) - } - finish(cp, r, descriptions) - } - go(tp, rest) - - case WithDescriptionTypedPipe(pipe, description, dedup) => - loop(pipe, rest, (description, dedup) :: descriptions) - - case WithOnComplete(pipe, fn) => - val planned = loop(pipe, rest, descriptions) - new Each(planned, Fields.ALL, new CleanupIdentityFunction(fn), Fields.REPLACE) - - case hcg@HashCoGroup(_, _, _) => - def go[K, V1, V2, R](hcg: HashCoGroup[K, V1, V2, R]): Pipe = { - // TODO we can push up filterKeys on both the left and right - // and mapValues/flatMapValues on the result - val cp = cacheGet(hcg, mode) { implicit fd => - val kvPipe = planHashJoin(hcg.left, - hcg.right, - hcg.joiner, - hcg.right.keyOrdering, - fd, - mode) - CascadingPipe(kvPipe, kvFields, fd, tuple2Converter[K, R]) - } - finish(cp, rest, descriptions) - } - go(hcg) - - case cgp@CoGroupedPipe(_) => - def go[K, V](cgp: CoGroupedPipe[K, V]): Pipe = { - // TODO we can push up filterKeys on both the left and right - // and mapValues/flatMapValues on the result - val cp = cacheGet(cgp, mode) { implicit fd => - val kvPipe = planCoGroup(cgp.cogrouped, fd, mode) - CascadingPipe(kvPipe, kvFields, fd, tuple2Converter[K, V]) - } - finish(cp, rest, descriptions) - } - go(cgp) + /** + * This converts the TypedPipe to a cascading Pipe doing the most direct + * possible translation we can. This is useful for testing or for expert + * cases where you want more direct control of the TypedPipe than + * the default method gives you. + */ + final def toPipeUnoptimized[U](p: TypedPipe[U], + fieldNames: Fields)(implicit flowDef: FlowDef, mode: Mode, setter: TupleSetter[U]): Pipe = { - case r@ReduceStepPipe(_) => - planReduceStep(r, mode) match { - case Right(cp) => finish(cp, rest, descriptions) - case Left(tp) => loop(tp, rest, descriptions) - } - } + val compiler = cache.get(flowDef, mode) + val cp: CascadingPipe[U] = compiler(p) - RichPipe(loop(p, FlatMappedFn.identity[U], Nil)).applyFlowConfigProperties(flowDef) + RichPipe(cp.toPipe(fieldNames, flowDef, TupleSetter.asSubSetter(setter))) + // TODO: this indirection may not be needed anymore, we could directly track config changes + // rather than using FlowStateMap. This is the only call of this method, so maybe we can + // remove it. + .applyFlowConfigProperties(flowDef) } - private def planCoGroup[K, R](cg: CoGrouped[K, R], flowDef: FlowDef, mode: Mode): Pipe = { - import cg._ + private def planCoGroup[K, R](cg: CoGrouped[K, R], rec: FunctionK[TypedPipe, CascadingPipe]): CascadingPipe[(K, R)] = { + + // This has the side effect of planning all inputs now + // before we need to call them below + val inputsCR = cg.inputs.map(rec(_)) + import cg.{inputs, joinFunction} // Cascading handles the first item in join differently, we have to see if it is repeated val firstCount = inputs.count(_ == inputs.head) import Dsl._ import RichPipe.assignName + val flowDef = new FlowDef + + def toPipe[A, B](t: TypedPipe[(A, B)], f: Fields, setter: TupleSetter[(A, B)]): Pipe = + rec(t).toPipe(f, flowDef, TupleSetter.asSubSetter(setter)) /* * we only want key and value. * Cascading requires you have the same number coming in as out. @@ -373,7 +421,7 @@ object CascadingBackend { List("key", "value") ++ (0 until (2 * (inCount - 1))).map("null%d".format(_)) // Make this stable so the compiler does not make a closure - val ord = keyOrdering + val ord = cg.keyOrdering val newPipe = maybeBox[K, Any](ord, flowDef) { (tupset, ordKeyField) => if (firstCount == inputs.size) { @@ -385,8 +433,7 @@ object CascadingBackend { * not repeated. That case is below */ val NUM_OF_SELF_JOINS = firstCount - 1 - new CoGroup(assignName(toPipe[(K, Any)](inputs.head, kvFields)(flowDef, mode, - tupset)), + new CoGroup(assignName(toPipe[K, Any](inputs.head, kvFields, tupset)), ordKeyField, NUM_OF_SELF_JOINS, outFields(firstCount), @@ -400,8 +447,7 @@ object CascadingBackend { * This is handled by a different CoGroup constructor than the above case. */ def renamePipe(idx: Int, p: TypedPipe[(K, Any)]): Pipe = - toPipe[(K, Any)](p, List(keyId(idx), "value%d".format(idx)))(flowDef, mode, - tupset) + toPipe[K, Any](p, List(keyId(idx), "value%d".format(idx)), tupset) // This is tested for the properties we need (non-reordering) val distincts = CoGrouped.distinctBy(inputs)(identity) @@ -456,108 +502,65 @@ object CascadingBackend { * are null. We then project out at the end of the method. */ val pipeWithRedAndDescriptions = { - RichPipe.setReducers(newPipe, reducers.getOrElse(-1)) - RichPipe.setPipeDescriptions(newPipe, descriptions) + RichPipe.setReducers(newPipe, cg.reducers.getOrElse(-1)) + RichPipe.setPipeDescriptions(newPipe, cg.descriptions) newPipe.project(kvFields) } - pipeWithRedAndDescriptions + + CascadingPipe[(K, R)]( + pipeWithRedAndDescriptions, + kvFields, + flowDef, + tuple2Converter[K, R]) } + /** + * TODO: most of the complexity of this method should be rewritten + * as an optimization rule that works on the scalding typed AST. + * the code in here gets pretty complex and depending on the details + * of cascading and also how we compile to cascading. + * + * But the optimization is somewhat general: we often want a checkpoint + * before a hashjoin is replicated + */ private def planHashJoin[K, V1, V2, R](left: TypedPipe[(K, V1)], right: HashJoinable[K, V2], joiner: (K, V1, Iterable[V2]) => Iterator[R], - keyOrdering: Ordering[K], - fd: FlowDef, - mode: Mode): Pipe = { - - val getHashJoinAutoForceRight: Boolean = - mode match { - case h: HadoopMode => - val config = Config.fromHadoop(h.jobConf) - config.getHashJoinAutoForceRight - case _ => false //default to false - } + rec: FunctionK[TypedPipe, CascadingPipe]): CascadingPipe[(K, R)] = { - /** - * Checks the transform to deduce if it is safe to skip the force to disk. - * If the FlatMappedFn is an identity operation then we can skip - * For map and flatMap we can't definitively infer if it is OK to skip the forceToDisk. - * Thus we just go ahead and forceToDisk in those two cases - users can opt out if needed. - */ - def canSkipEachOperation(eachOperation: Operation[_]): Boolean = - eachOperation match { - case f: FlatMapFunction[_, _] => - f.getFunction match { - case fmp: FlatMappedFn[_, _] if (FlatMappedFn.asId(fmp).isDefined) => - // This is an operation that is doing nothing - true - case _ => - false - } - case _: CleanupIdentityFunction => true - case _ => false - } + val fd = new FlowDef + val leftPipe = rec(left).toPipe(kvFields, fd, tup2Setter) + val mappedPipe = rec(right.mapped).toPipe(new Fields("key1", "value1"), fd, tup2Setter) - /** - * Computes if it is safe to skip a force to disk (only if the user hasn't turned this off using - * Config.HashJoinAutoForceRight). - * If we know the pipe is persisted,we can safely skip. If the Pipe is an Each operator, we check - * if the function it holds can be skipped and we recurse to check its parent pipe. - * Recursion handles situations where we have a couple of Each ops in a row. - * For example: pipe.forceToDisk.onComplete results in: Each -> Each -> Checkpoint - */ - def isSafeToSkipForceToDisk(pipe: Pipe): Boolean = { - import cascading.pipe._ - - pipe match { - case eachPipe: Each => - if (canSkipEachOperation(eachPipe.getOperation)) { - //need to recurse down to see if parent pipe is ok - RichPipe.getSinglePreviousPipe(eachPipe).exists(prevPipe => isSafeToSkipForceToDisk(prevPipe)) - } else false - case _: Checkpoint => true - case _: GroupBy => true - case _: CoGroup => true - case _: Every => true - case p if RichPipe.isSourcePipe(p) => true - case _ => false - } - } - /** - * Returns a Pipe for the mapped (rhs) pipe with checkpointing (forceToDisk) applied if needed. - * Currently we skip checkpointing if we're confident that the underlying rhs Pipe is persisted - * (e.g. a source / Checkpoint / GroupBy / CoGroup / Every) and we have 0 or more Each operator Fns - * that are not doing any real work (e.g. Converter, CleanupIdentityFunction) - */ - val getForceToDiskPipeIfNecessary: Pipe = { - val mappedPipe = toPipe(right.mapped, new Fields("key1", "value1"))(fd, mode, tup2Setter) - - // if the user has turned off auto force right, we fall back to the old behavior and - //just return the mapped pipe - if (!getHashJoinAutoForceRight || isSafeToSkipForceToDisk(mappedPipe)) mappedPipe - else RichPipe(mappedPipe).forceToDisk - } - - new HashJoin( - RichPipe.assignName(toPipe(left, kvFields)(fd, mode, tup2Setter)), + val keyOrdering = right.keyOrdering + val hashPipe = new HashJoin( + RichPipe.assignName(leftPipe), Field.singleOrdered("key")(keyOrdering), - getForceToDiskPipeIfNecessary, + mappedPipe, Field.singleOrdered("key1")(keyOrdering), WrappedJoiner(new HashJoiner(right.joinFunction, joiner))) + + CascadingPipe[(K, R)]( + hashPipe, + kvFields, + fd, + tuple2Converter[K, R]) } - private def planReduceStep[K, V1, V2](rsp: ReduceStepPipe[K, V1, V2], - mode: Mode): Either[TypedPipe[(K, V2)], CascadingPipe[(K, V2)]] = { + private def planReduceStep[K, V1, V2]( + rs: ReduceStep[K, V1, V2], + rec: FunctionK[TypedPipe, CascadingPipe]): CascadingPipe[(K, V2)] = { - val rs = rsp.reduce + val mapped = rec(rs.mapped) - def groupOp(gb: GroupBuilder => GroupBuilder): CascadingPipe[(K, V2)] = + def groupOp(gb: GroupBuilder => GroupBuilder): CascadingPipe[_ <: (K, V2)] = groupOpWithValueSort(None)(gb) - def groupOpWithValueSort(valueSort: Option[Ordering[_ >: V1]])(gb: GroupBuilder => GroupBuilder): CascadingPipe[(K, V2)] = { - def pipe(flowDef: FlowDef) = maybeBox[K, V1](rs.keyOrdering, flowDef) { (tupleSetter, fields) => + def groupOpWithValueSort(valueSort: Option[Ordering[_ >: V1]])(gb: GroupBuilder => GroupBuilder): CascadingPipe[_ <: (K, V2)] = { + val flowDef = new FlowDef + val pipe = maybeBox[K, V1](rs.keyOrdering, flowDef) { (tupleSetter, fields) => val (sortOpt, ts) = valueSort.map { - case ordser: OrderedSerialization[V1] => + case ordser: OrderedSerialization[V1 @unchecked] => // We get in here when we do a secondary sort // and that sort is an ordered serialization // We now need a boxed serializer for this type @@ -572,7 +575,7 @@ object CascadingBackend { (Some(vord), tupleSetter) }.getOrElse((None, tupleSetter)) - val p = toPipe(rs.mapped, kvFields)(flowDef, mode, TupleSetter.asSubSetter(ts)) + val p = mapped.toPipe(kvFields, flowDef, TupleSetter.asSubSetter(ts)) RichPipe(p).groupBy(fields) { inGb => val withSort = sortOpt.fold(inGb)(inGb.sortBy) @@ -580,27 +583,25 @@ object CascadingBackend { } } - pipeCache.cacheGet(rsp, mode) { implicit fd => - val tupConv = tuple2Conv[K, V2](rs.keyOrdering) - CascadingPipe(pipe(fd), kvFields, fd, tupConv) - } + val tupConv = tuple2Conv[K, V2](rs.keyOrdering) + CascadingPipe(pipe, kvFields, flowDef, tupConv) } rs match { - case IdentityReduce(_, inp, None, descriptions) => + case IdentityReduce(_, _, None, descriptions) => // Not doing anything - Left(descriptions.foldLeft(inp)(_.withDescription(_))) - case UnsortedIdentityReduce(_, inp, None, descriptions) => + mapped.copy(pipe = RichPipe.setPipeDescriptions(mapped.pipe, descriptions)).asInstanceOf[CascadingPipe[_ <: (K, V2)]] + case UnsortedIdentityReduce(_, _, None, descriptions) => // Not doing anything - Left(descriptions.foldLeft(inp)(_.withDescription(_))) + mapped.copy(pipe = RichPipe.setPipeDescriptions(mapped.pipe, descriptions)).asInstanceOf[CascadingPipe[_ <: (K, V2)]] case IdentityReduce(_, _, Some(reds), descriptions) => - Right(groupOp { _.reducers(reds).setDescriptions(descriptions) }) + groupOp { _.reducers(reds).setDescriptions(descriptions) } case UnsortedIdentityReduce(_, _, Some(reds), descriptions) => // This is weird, but it is sometimes used to force a partition - Right(groupOp { _.reducers(reds).setDescriptions(descriptions) }) + groupOp { _.reducers(reds).setDescriptions(descriptions) } case ivsr@IdentityValueSortedReduce(_, _, _, _, _) => // in this case we know that V1 =:= V2 - Right(groupOpWithValueSort(Some(ivsr.valueSort.asInstanceOf[Ordering[_ >: V1]])) { gb => + groupOpWithValueSort(Some(ivsr.valueSort.asInstanceOf[Ordering[_ >: V1]])) { gb => // If its an ordered serialization we need to unbox val mappedGB = if (ivsr.valueSort.isInstanceOf[OrderedSerialization[_]]) @@ -613,10 +614,10 @@ object CascadingBackend { mappedGB .reducers(ivsr.reducers.getOrElse(-1)) .setDescriptions(ivsr.descriptions) - }) + } case vsr@ValueSortedReduce(_, _, _, _, _, _) => val optVOrdering = Some(vsr.valueSort) - Right(groupOpWithValueSort(optVOrdering) { + groupOpWithValueSort(optVOrdering) { // If its an ordered serialization we need to unbox // the value before handing it to the users operation _.every(new cascading.pipe.Every(_, valueField, @@ -626,14 +627,14 @@ object CascadingBackend { valueField), Fields.REPLACE)) .reducers(vsr.reducers.getOrElse(-1)) .setDescriptions(vsr.descriptions) - }) + } case imr@IteratorMappedReduce(_, _, _, _, _) => - Right(groupOp { + groupOp { _.every(new cascading.pipe.Every(_, valueField, new TypedBufferOp(keyConverter(imr.keyOrdering), TupleConverter.singleConverter[V1], imr.reduceFn, valueField), Fields.REPLACE)) .reducers(imr.reducers.getOrElse(-1)) .setDescriptions(imr.descriptions) - }) + } } } } diff --git a/scalding-core/src/main/scala/com/twitter/scalding/typed/cascading_backend/CoGroupJoiner.scala b/scalding-core/src/main/scala/com/twitter/scalding/typed/cascading_backend/CoGroupJoiner.scala index 37bc55c7d4..a9dc954423 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/typed/cascading_backend/CoGroupJoiner.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/typed/cascading_backend/CoGroupJoiner.scala @@ -37,7 +37,7 @@ abstract class CoGroupedJoiner[K](inputSize: Int, def unbox(it: Iterator[CTuple]): Iterator[Any] = it.map(_.getObject(1): Any) - val leftMost = unbox(iters.head) + val leftMost = unbox(iters.head) // linter:disable:UndesirableTypeInference def toIterable(didx: Int) = new Iterable[Any] { diff --git a/scalding-core/src/main/scala/com/twitter/scalding/typed/functions/EqTypes.scala b/scalding-core/src/main/scala/com/twitter/scalding/typed/functions/EqTypes.scala new file mode 100644 index 0000000000..11c54fe7b5 --- /dev/null +++ b/scalding-core/src/main/scala/com/twitter/scalding/typed/functions/EqTypes.scala @@ -0,0 +1,37 @@ +package com.twitter.scalding.typed.functions + +/** + * This is a more powerful version of =:= that can allow + * us to remove casts and also not have any runtime cost + * for our function calls in some cases of trivial functions + */ +sealed abstract class EqTypes[A, B] extends java.io.Serializable { + def apply(a: A): B + def subst[F[_]](f: F[A]): F[B] + + final def reverse: EqTypes[B, A] = { + val aa = EqTypes.reflexive[A] + type F[T] = EqTypes[T, A] + subst[F](aa) + } + + def toEv: A =:= B = { + val aa = implicitly[A =:= A] + type F[T] = A =:= T + subst[F](aa) + } +} + +object EqTypes extends java.io.Serializable { + private[this] final case class ReflexiveEquality[A]() extends EqTypes[A, A] { + def apply(a: A): A = a + def subst[F[_]](f: F[A]): F[A] = f + } + + implicit def reflexive[A]: EqTypes[A, A] = ReflexiveEquality() + + def fromEv[A, B](ev: A =:= B): EqTypes[A, B] = // linter:disable:UnusedParameter + // in scala 2.13, this won't need a cast, but the cast is safe + reflexive[A].asInstanceOf[EqTypes[A, B]] +} + diff --git a/scalding-core/src/main/scala/com/twitter/scalding/typed/functions/FlatMappedFn.scala b/scalding-core/src/main/scala/com/twitter/scalding/typed/functions/FlatMappedFn.scala new file mode 100644 index 0000000000..ef1bb6c6da --- /dev/null +++ b/scalding-core/src/main/scala/com/twitter/scalding/typed/functions/FlatMappedFn.scala @@ -0,0 +1,126 @@ +/* +Copyright 2013 Twitter, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +package com.twitter.scalding.typed.functions + +import java.io.Serializable + +/** + * This is a composition of one or more FlatMappings + * + * For some reason, this fails in scala 2.12 if this is an abstract class + */ +sealed trait FlatMappedFn[-A, +B] extends (A => TraversableOnce[B]) with java.io.Serializable { + import FlatMappedFn._ + + final def runAfter[Z](fn: FlatMapping[Z, A]): FlatMappedFn[Z, B] = this match { + case Single(FlatMapping.Identity(ev)) => + type F[T] = FlatMapping[Z, T] + Single(ev.subst[F](fn)) + case notId => fn match { + case FlatMapping.Identity(ev) => + type F[T] = FlatMappedFn[T, B] + ev.reverse.subst[F](this) + case notIdFn => Series(notIdFn, notId) // only make a Series without either side being identity + } + } + + final def combine[C](next: FlatMappedFn[B, C]): FlatMappedFn[A, C] = { + /* + * We have to reassociate so the front of the series has the + * first flatmap, so we can bail out early when there are no more + * items in any flatMap result. + */ + def loop[X, Y, Z](fn0: FlatMappedFn[X, Y], fn1: FlatMappedFn[Y, Z]): FlatMappedFn[X, Z] = + fn0 match { + case Single(FlatMapping.Identity(ev)) => + type F[T] = FlatMappedFn[T, Z] + ev.reverse.subst[F](fn1) + case Single(f0) => + Series(f0, fn1) + case Series(f0f, f1f) => + Series(f0f, loop(f1f, fn1)) + } + loop(this, next) + } + + /** + * We interpret this composition once to minimize pattern matching when we execute + */ + private[this] val toFn: A => TraversableOnce[B] = { + import FlatMapping._ + + def loop[A1, B1](fn: FlatMappedFn[A1, B1]): A1 => TraversableOnce[B1] = fn match { + case Single(Identity(ev)) => + val const: A1 => TraversableOnce[A1] = FlatMapFunctions.FromIdentity[A1]() + type F[T] = A1 => TraversableOnce[T] + ev.subst[F](const) + case Single(Filter(f, ev)) => + val filter: A1 => TraversableOnce[A1] = FlatMapFunctions.FromFilter(f) + type F[T] = A1 => TraversableOnce[T] + ev.subst[F](filter) + case Single(Map(f)) => FlatMapFunctions.FromMap(f) + case Single(FlatM(f)) => f + case Series(Identity(ev), rest) => + type F[T] = T => TraversableOnce[B1] + ev.subst[F](loop(rest)) + case Series(Filter(f, ev), rest) => + type F[T] = T => TraversableOnce[B1] + val next = ev.subst[F](loop(rest)) // linter:disable:UndesirableTypeInference + + FlatMapFunctions.FromFilterCompose(f, next) + case Series(Map(f), rest) => + val next = loop(rest) // linter:disable:UndesirableTypeInference + FlatMapFunctions.FromMapCompose(f, next) + case Series(FlatM(f), rest) => + val next = loop(rest) // linter:disable:UndesirableTypeInference + FlatMapFunctions.FromFlatMapCompose(f, next) + } + + loop(this) + } + + def apply(a: A): TraversableOnce[B] = toFn(a) +} + +object FlatMappedFn { + + def asId[A, B](f: FlatMappedFn[A, B]): Option[EqTypes[_ >: A, _ <: B]] = f match { + case Single(FlatMapping.Identity(ev)) => Some(ev) + case _ => None + } + + def asFilter[A, B](f: FlatMappedFn[A, B]): Option[(A => Boolean, EqTypes[(_ >: A), (_ <: B)])] = f match { + case Single(filter@FlatMapping.Filter(_, _)) => Some((filter.fn, filter.ev)) + case _ => None + } + + def apply[A, B](fn: A => TraversableOnce[B]): FlatMappedFn[A, B] = + fn match { + case fmf: FlatMappedFn[A, B] => fmf + case rawfn => Single(FlatMapping.FlatM(rawfn)) + } + + def identity[T]: FlatMappedFn[T, T] = Single(FlatMapping.Identity[T, T](EqTypes.reflexive[T])) + + def fromFilter[A](fn: A => Boolean): FlatMappedFn[A, A] = + Single(FlatMapping.Filter[A, A](fn, EqTypes.reflexive)) + + def fromMap[A, B](fn: A => B): FlatMappedFn[A, B] = + Single(FlatMapping.Map(fn)) + + final case class Single[A, B](fn: FlatMapping[A, B]) extends FlatMappedFn[A, B] + final case class Series[A, B, C](first: FlatMapping[A, B], next: FlatMappedFn[B, C]) extends FlatMappedFn[A, C] +} diff --git a/scalding-core/src/main/scala/com/twitter/scalding/typed/functions/FlatMapping.scala b/scalding-core/src/main/scala/com/twitter/scalding/typed/functions/FlatMapping.scala new file mode 100644 index 0000000000..a322ef832a --- /dev/null +++ b/scalding-core/src/main/scala/com/twitter/scalding/typed/functions/FlatMapping.scala @@ -0,0 +1,25 @@ +package com.twitter.scalding.typed.functions + +import java.io.Serializable + +/** + * This is one of 4 core, non composed operations: + * identity + * filter + * map + * flatMap + */ +sealed abstract class FlatMapping[-A, +B] extends java.io.Serializable +object FlatMapping { + def filter[A](fn: A => Boolean): FlatMapping[A, A] = + Filter[A, A](fn, implicitly) + + def filterKeys[K, V](fn: K => Boolean): FlatMapping[(K, V), (K, V)] = + filter[(K, V)](FilterKeysToFilter(fn)) + + final case class Identity[A, B](ev: EqTypes[A, B]) extends FlatMapping[A, B] + final case class Filter[A, B](fn: A => Boolean, ev: EqTypes[A, B]) extends FlatMapping[A, B] + final case class Map[A, B](fn: A => B) extends FlatMapping[A, B] + final case class FlatM[A, B](fn: A => TraversableOnce[B]) extends FlatMapping[A, B] +} + diff --git a/scalding-core/src/main/scala/com/twitter/scalding/typed/functions/Functions.scala b/scalding-core/src/main/scala/com/twitter/scalding/typed/functions/Functions.scala new file mode 100644 index 0000000000..ad7d8dcc17 --- /dev/null +++ b/scalding-core/src/main/scala/com/twitter/scalding/typed/functions/Functions.scala @@ -0,0 +1,274 @@ +package com.twitter.scalding.typed.functions + +import com.twitter.algebird.{ Aggregator, Ring, Semigroup, Fold } +import java.util.Random + +case class Constant[T](result: T) extends Function1[Any, T] { + def apply(a: Any) = result +} + +case class WithConstant[A, B](constant: B) extends Function1[A, (A, B)] { + def apply(a: A) = (a, constant) +} + +case class MakeKey[K, V](fn: V => K) extends Function1[V, (K, V)] { + def apply(v: V) = (fn(v), v) +} + +case class PartialFunctionToFilter[A, B](fn: PartialFunction[A, B]) extends Function1[A, Boolean] { + def apply(a: A) = fn.isDefinedAt(a) +} + +case class MapValueStream[A, B](fn: Iterator[A] => Iterator[B]) extends Function2[Any, Iterator[A], Iterator[B]] { + def apply(k: Any, vs: Iterator[A]) = fn(vs) +} + +case class Drop[A](count: Int) extends Function1[Iterator[A], Iterator[A]] { + def apply(as: Iterator[A]) = as.drop(count) +} +case class DropWhile[A](fn: A => Boolean) extends Function1[Iterator[A], Iterator[A]] { + def apply(as: Iterator[A]) = as.dropWhile(fn) +} + +case class Take[A](count: Int) extends Function1[Iterator[A], Iterator[A]] { + def apply(as: Iterator[A]) = as.take(count) +} + +case class TakeWhile[A](fn: A => Boolean) extends Function1[Iterator[A], Iterator[A]] { + def apply(as: Iterator[A]) = as.takeWhile(fn) +} + +case class Identity[A, B](eqTypes: EqTypes[A, B]) extends Function1[A, B] { + def apply(a: A) = eqTypes(a) +} + +object Identity { + def apply[A](): Identity[A, A] = Identity[A, A](EqTypes.reflexive[A]) +} + +case class Widen[A, B](subTypes: SubTypes[A, B]) extends Function1[A, B] { + def apply(a: A) = subTypes(a) +} + +case class GetKey[K]() extends Function1[(K, Any), K] { + def apply(kv: (K, Any)) = kv._1 +} + +case class GetValue[V]() extends Function1[(Any, V), V] { + def apply(kv: (Any, V)) = kv._2 +} + +case class Swap[A, B]() extends Function1[(A, B), (B, A)] { + def apply(ab: (A, B)) = (ab._2, ab._1) +} + +case class SumAll[T](sg: Semigroup[T]) extends Function1[TraversableOnce[T], Iterator[T]] { + def apply(ts: TraversableOnce[T]) = sg.sumOption(ts).iterator +} + +case class AggPrepare[A, B, C](agg: Aggregator[A, B, C]) extends Function1[A, B] { + def apply(a: A) = agg.prepare(a) +} + +case class AggPresent[A, B, C](agg: Aggregator[A, B, C]) extends Function1[B, C] { + def apply(a: B) = agg.present(a) +} + +case class FoldLeftIterator[A, B](init: B, fold: (B, A) => B) extends Function1[Iterator[A], Iterator[B]] { + def apply(as: Iterator[A]) = Iterator.single(as.foldLeft(init)(fold)) +} + +case class ScanLeftIterator[A, B](init: B, fold: (B, A) => B) extends Function1[Iterator[A], Iterator[B]] { + def apply(as: Iterator[A]) = as.scanLeft(init)(fold) +} + +case class FoldIterator[A, B](fold: Fold[A, B]) extends Function1[Iterator[A], Iterator[B]] { + def apply(as: Iterator[A]) = Iterator.single(fold.overTraversable(as)) +} + +case class FoldWithKeyIterator[K, A, B](foldfn: K => Fold[A, B]) extends Function2[K, Iterator[A], Iterator[B]] { + def apply(k: K, as: Iterator[A]) = Iterator.single(foldfn(k).overTraversable(as)) +} + +case class AsRight[A, B]() extends Function1[B, Either[A, B]] { + def apply(b: B) = Right(b) +} + +case class AsLeft[A, B]() extends Function1[A, Either[A, B]] { + def apply(b: A) = Left(b) +} + +case class TuplizeFunction[A, B, C](fn: (A, B) => C) extends Function1[(A, B), C] { + def apply(ab: (A, B)) = fn(ab._1, ab._2) +} + +case class DropValue1[A, B, C]() extends Function1[(A, (B, C)), (A, C)] { + def apply(abc: (A, (B, C))) = (abc._1, abc._2._2) +} + +case class RandomNextInt(seed: Long, modulus: Int) extends Function1[Any, Int] { + private[this] lazy val rng = new Random(seed) + def apply(a: Any) = rng.nextInt(modulus) +} + +case class RandomFilter(seed: Long, fraction: Double) extends Function1[Any, Boolean] { + private[this] lazy val rng = new Random(seed) + def apply(a: Any) = rng.nextDouble < fraction +} + +case class Count[T](fn: T => Boolean) extends Function1[T, Long] { + def apply(t: T) = if (fn(t)) 1L else 0L +} + +case class SizeOfSet[T]() extends Function1[Set[T], Long] { + def apply(s: Set[T]) = s.size.toLong +} + +case class HeadSemigroup[T]() extends Semigroup[T] { + def plus(a: T, b: T) = a + // Don't enumerate every item, just take the first + override def sumOption(to: TraversableOnce[T]): Option[T] = + if (to.isEmpty) None + else Some(to.toIterator.next) +} + +case class SemigroupFromFn[T](fn: (T, T) => T) extends Semigroup[T] { + def plus(a: T, b: T) = fn(a, b) +} + +case class SemigroupFromProduct[T](ring: Ring[T]) extends Semigroup[T] { + def plus(a: T, b: T) = ring.times(a, b) +} + +case class ConsList[T]() extends Function1[(T, List[T]), List[T]] { + def apply(results: (T, List[T])) = results._1 :: results._2 +} + +case class ReverseList[T]() extends Function1[List[T], List[T]] { + def apply(results: List[T]) = results.reverse +} + +case class ToList[A]() extends Function1[A, List[A]] { + def apply(a: A) = a :: Nil +} + +case class ToSet[A]() extends Function1[A, Set[A]] { + // this allows us to access Set1 without boxing into varargs + private[this] val empty = Set.empty[A] + def apply(a: A) = empty + a +} + +case class MaxOrd[A, B >: A](ord: Ordering[B]) extends Function2[A, A, A] { + def apply(a1: A, a2: A) = + if (ord.lt(a1, a2)) a2 else a1 +} + +case class MaxOrdBy[A, B](fn: A => B, ord: Ordering[B]) extends Function2[A, A, A] { + def apply(a1: A, a2: A) = + if (ord.lt(fn(a1), fn(a2))) a2 else a1 +} + +case class MinOrd[A, B >: A](ord: Ordering[B]) extends Function2[A, A, A] { + def apply(a1: A, a2: A) = + if (ord.lt(a1, a2)) a1 else a2 +} + +case class MinOrdBy[A, B](fn: A => B, ord: Ordering[B]) extends Function2[A, A, A] { + def apply(a1: A, a2: A) = + if (ord.lt(fn(a1), fn(a2))) a1 else a2 +} + +case class FilterKeysToFilter[K](fn: K => Boolean) extends Function1[(K, Any), Boolean] { + def apply(kv: (K, Any)) = fn(kv._1) +} + +case class FlatMapValuesToFlatMap[K, A, B](fn: A => TraversableOnce[B]) extends Function1[(K, A), TraversableOnce[(K, B)]] { + def apply(ka: (K, A)) = { + val k = ka._1 + fn(ka._2).map((k, _)) + } +} + +case class MapValuesToMap[K, A, B](fn: A => B) extends Function1[(K, A), (K, B)] { + def apply(ka: (K, A)) = (ka._1, fn(ka._2)) +} + +case class EmptyGuard[K, A, B](fn: (K, Iterator[A]) => Iterator[B]) extends Function2[K, Iterator[A], Iterator[B]] { + def apply(k: K, as: Iterator[A]) = + if (as.nonEmpty) fn(k, as) else Iterator.empty +} + +case class FilterGroup[A, B](fn: ((A, B)) => Boolean) extends Function2[A, Iterator[B], Iterator[B]] { + def apply(a: A, bs: Iterator[B]) = bs.filter(fn(a, _)) +} + +case class MapGroupMapValues[A, B, C](fn: B => C) extends Function2[A, Iterator[B], Iterator[C]] { + def apply(a: A, bs: Iterator[B]) = bs.map(fn) +} + +case class MapGroupFlatMapValues[A, B, C](fn: B => TraversableOnce[C]) extends Function2[A, Iterator[B], Iterator[C]] { + def apply(a: A, bs: Iterator[B]) = bs.flatMap(fn) +} + +object FlatMapFunctions { + case class FromIdentity[A]() extends Function1[A, Iterator[A]] { + def apply(a: A) = Iterator.single(a) + } + case class FromFilter[A](fn: A => Boolean) extends Function1[A, Iterator[A]] { + def apply(a: A) = if (fn(a)) Iterator.single(a) else Iterator.empty + } + case class FromMap[A, B](fn: A => B) extends Function1[A, Iterator[B]] { + def apply(a: A) = Iterator.single(fn(a)) + } + case class FromFilterCompose[A, B](fn: A => Boolean, next: A => TraversableOnce[B]) extends Function1[A, TraversableOnce[B]] { + def apply(a: A) = if (fn(a)) next(a) else Iterator.empty + } + case class FromMapCompose[A, B, C](fn: A => B, next: B => TraversableOnce[C]) extends Function1[A, TraversableOnce[C]] { + def apply(a: A) = next(fn(a)) + } + case class FromFlatMapCompose[A, B, C](fn: A => TraversableOnce[B], next: B => TraversableOnce[C]) extends Function1[A, TraversableOnce[C]] { + def apply(a: A) = fn(a).flatMap(next) + } +} + +object ComposedFunctions { + + case class ComposedMapFn[A, B, C](fn0: A => B, fn1: B => C) extends Function1[A, C] { + def apply(a: A) = fn1(fn0(a)) + } + case class ComposedFilterFn[-A](fn0: A => Boolean, fn1: A => Boolean) extends Function1[A, Boolean] { + def apply(a: A) = fn0(a) && fn1(a) + } + /** + * This is only called at the end of a task, so might as well make it stack safe since a little + * extra runtime cost won't matter + */ + case class ComposedOnComplete(fn0: () => Unit, fn1: () => Unit) extends Function0[Unit] { + def apply(): Unit = { + @annotation.tailrec + def loop(fn: () => Unit, stack: List[() => Unit]): Unit = + fn match { + case ComposedOnComplete(left, right) => loop(left, right :: stack) + case notComposed => + notComposed() + stack match { + case h :: tail => loop(h, tail) + case Nil => () + } + } + + loop(fn0, List(fn1)) + } + } + + case class ComposedMapGroup[A, B, C, D]( + f: (A, Iterator[B]) => Iterator[C], + g: (A, Iterator[C]) => Iterator[D]) extends Function2[A, Iterator[B], Iterator[D]] { + + def apply(a: A, bs: Iterator[B]) = { + val cs = f(a, bs) + if (cs.nonEmpty) g(a, cs) + else Iterator.empty + } + } +} diff --git a/scalding-core/src/main/scala/com/twitter/scalding/typed/functions/SubTypes.scala b/scalding-core/src/main/scala/com/twitter/scalding/typed/functions/SubTypes.scala new file mode 100644 index 0000000000..c81deed3b2 --- /dev/null +++ b/scalding-core/src/main/scala/com/twitter/scalding/typed/functions/SubTypes.scala @@ -0,0 +1,42 @@ +package com.twitter.scalding.typed.functions + +/** + * This is a more powerful version of <:< that can allow + * us to remove casts and also not have any runtime cost + * for our function calls in some cases of trivial functions + */ +sealed abstract class SubTypes[-A, +B] extends java.io.Serializable { + def apply(a: A): B + def subst[F[-_]](f: F[B]): F[A] + + def toEv: A <:< B = { + val aa = implicitly[B <:< B] + type F[-T] = T <:< B + subst[F](aa) + } + + def liftCo[F[+_]]: SubTypes[F[A], F[B]] = { + type G[-T] = SubTypes[F[T], F[B]] + subst[G](SubTypes.fromSubType[F[B], F[B]]) + } + /** create a new evidence for a contravariant type F[_] + */ + def liftContra[F[-_]]: SubTypes[F[B], F[A]] = { + type G[-T] = SubTypes[F[B], F[T]] + subst[G](SubTypes.fromSubType[F[B], F[B]]) + } +} + +object SubTypes extends java.io.Serializable { + private[this] final case class ReflexiveSubTypes[A]() extends SubTypes[A, A] { + def apply(a: A): A = a + def subst[F[-_]](f: F[A]): F[A] = f + } + + implicit def fromSubType[A, B >: A]: SubTypes[A, B] = ReflexiveSubTypes[A]() + + def fromEv[A, B](ev: A <:< B): SubTypes[A, B] = // linter:disable:UnusedParameter + // in scala 2.13, this won't need a cast, but the cast is safe + fromSubType[A, A].asInstanceOf[SubTypes[A, B]] +} + diff --git a/scalding-core/src/main/scala/com/twitter/scalding/typed/memory_backend/MemoryBackend.scala b/scalding-core/src/main/scala/com/twitter/scalding/typed/memory_backend/MemoryBackend.scala index 60d5da3054..d128917df6 100644 --- a/scalding-core/src/main/scala/com/twitter/scalding/typed/memory_backend/MemoryBackend.scala +++ b/scalding-core/src/main/scala/com/twitter/scalding/typed/memory_backend/MemoryBackend.scala @@ -125,7 +125,7 @@ object MemoryPlanner { def source[I](i: Iterable[I]): Op[I] = Source(_ => Future.successful(i)) def empty[I]: Op[I] = source(Nil) - case class Source[I](input: ConcurrentExecutionContext => Future[Iterable[I]]) extends Op[I] { + final case class Source[I](input: ConcurrentExecutionContext => Future[Iterable[I]]) extends Op[I] { private[this] val promise: Promise[ArrayBuffer[I]] = Promise() def result(implicit cec: ConcurrentExecutionContext): Future[ArrayBuffer[I]] = { @@ -140,7 +140,7 @@ object MemoryPlanner { } } - case class Materialize[O](op: Op[O]) extends Op[O] { + final case class Materialize[O](op: Op[O]) extends Op[O] { private[this] val promise: Promise[ArrayBuffer[_ <: O]] = Promise() def result(implicit cec: ConcurrentExecutionContext) = { @@ -152,7 +152,7 @@ object MemoryPlanner { } } - case class Concat[O](left: Op[O], right: Op[O]) extends Op[O] { + final case class Concat[O](left: Op[O], right: Op[O]) extends Op[O] { def result(implicit cec: ConcurrentExecutionContext) = { val f1 = left.result val f2 = right.result @@ -160,7 +160,7 @@ object MemoryPlanner { } } - case class Map[I, O](input: Op[I], fn: I => TraversableOnce[O]) extends Op[O] { + final case class Map[I, O](input: Op[I], fn: I => TraversableOnce[O]) extends Op[O] { def result(implicit cec: ConcurrentExecutionContext): Future[ArrayBuffer[O]] = input.result.map { array => val res = ArrayBuffer[O]() @@ -173,7 +173,7 @@ object MemoryPlanner { } } - case class OnComplete[O](of: Op[O], fn: () => Unit) extends Op[O] { + final case class OnComplete[O](of: Op[O], fn: () => Unit) extends Op[O] { def result(implicit cec: ConcurrentExecutionContext) = { val res = of.result res.onComplete(_ => fn()) @@ -181,12 +181,12 @@ object MemoryPlanner { } } - case class Transform[I, O](input: Op[I], fn: IndexedSeq[I] => ArrayBuffer[O]) extends Op[O] { + final case class Transform[I, O](input: Op[I], fn: IndexedSeq[I] => ArrayBuffer[O]) extends Op[O] { def result(implicit cec: ConcurrentExecutionContext) = input.result.map(fn) } - case class Reduce[K, V1, V2]( + final case class Reduce[K, V1, V2]( input: Op[(K, V1)], fn: (K, Iterator[V1]) => Iterator[V2], ord: Option[Ordering[_ >: V1]] @@ -217,7 +217,7 @@ object MemoryPlanner { } } - case class Join[A, B, C]( + final case class Join[A, B, C]( opA: Op[A], opB: Op[B], fn: (IndexedSeq[A], IndexedSeq[B]) => ArrayBuffer[C]) extends Op[C] { @@ -302,12 +302,15 @@ class MemoryWriter(mem: MemoryMode) extends Writer { def plan[T](m: Memo, tp: TypedPipe[T]): (Memo, Op[T]) = m.plan(tp) { tp match { + case CounterPipe(pipe) => + // TODO: counters not yet supported, but can be with an concurrent hashmap + plan(m, pipe.map(_._1)) case cp@CrossPipe(_, _) => plan(m, cp.viaHashJoin) case CrossValue(left, EmptyValue) => (m, Op.empty) case CrossValue(left, LiteralValue(v)) => - val (m1, op) = plan(m, left) + val (m1, op) = plan(m, left) // linter:disable:UndesirableTypeInference (m1, op.concatMap { a => Iterator.single((a, v)) }) case CrossValue(left, ComputedValue(right)) => plan(m, CrossPipe(left, right)) @@ -346,7 +349,7 @@ class MemoryWriter(mem: MemoryMode) extends Writer { go(f) case FlatMapped(prev, fn) => - val (m1, op) = plan(m, prev) + val (m1, op) = plan(m, prev) // linter:disable:UndesirableTypeInference (m1, op.concatMap(fn)) case ForceToDisk(pipe) => @@ -370,7 +373,7 @@ class MemoryWriter(mem: MemoryMode) extends Writer { go(f) case Mapped(input, fn) => - val (m1, op) = plan(m, input) + val (m1, op) = plan(m, input) // linter:disable:UndesirableTypeInference (m1, op.map(fn)) case MergedTypedPipe(left, right) => @@ -545,7 +548,7 @@ class MemoryWriter(mem: MemoryMode) extends Writer { (st, a :: acts) } case ((oldState, acts), ToWrite.SimpleWrite(pipe, sink)) => - val (nextM, op) = plan(oldState.memo, pipe) + val (nextM, op) = plan(oldState.memo, pipe) // linter:disable:UndesirableTypeInference val action = () => { val arrayBufferF = op.result arrayBufferF.foreach { mem.writeSink(sink, _) } diff --git a/scalding-core/src/test/scala/com/twitter/scalding/ExecutionTest.scala b/scalding-core/src/test/scala/com/twitter/scalding/ExecutionTest.scala index de17574ae2..0ed0cc8703 100644 --- a/scalding-core/src/test/scala/com/twitter/scalding/ExecutionTest.scala +++ b/scalding-core/src/test/scala/com/twitter/scalding/ExecutionTest.scala @@ -357,7 +357,7 @@ class ExecutionTest extends WordSpec with Matchers { val files = cleanupHook.get.asInstanceOf[TempFileCleanup].filesToCleanup assert(files.size == 1) - assert(files(0).contains(tempFile)) + assert(files.head.contains(tempFile)) cleanupHook.get.run() // Remove the hook so it doesn't show up in the list of shutdown hooks for other tests Runtime.getRuntime.removeShutdownHook(cleanupHook.get) @@ -385,7 +385,7 @@ class ExecutionTest extends WordSpec with Matchers { val files = cleanupHook.get.asInstanceOf[TempFileCleanup].filesToCleanup assert(files.size == 2) - assert(files(0).contains(tempFileOne) || files(0).contains(tempFileTwo)) + assert(files.head.contains(tempFileOne) || files.head.contains(tempFileTwo)) assert(files(1).contains(tempFileOne) || files(1).contains(tempFileTwo)) cleanupHook.get.run() // Remove the hook so it doesn't show up in the list of shutdown hooks for other tests @@ -437,6 +437,34 @@ class ExecutionTest extends WordSpec with Matchers { c1.shouldSucceed() should ===(100) c2.shouldSucceed() should ===(100) } + "zip does not duplicate pure counters" in { + val c1 = { + val e1 = TypedPipe.from(0 until 100) + .tallyAll("scalding", "test") + .writeExecution(source.NullSink) + + e1.zip(e1) + .getCounters.map { case (_, c) => + println(c.toMap) + c(("test", "scalding")) + } + } + + val c2 = { + val e2 = TypedPipe.from(0 until 100) + .tallyAll("scalding", "test") + .writeExecution(source.NullSink) + + e2.flatMap(Execution.from(_)).zip(e2) + .getCounters.map { case (_, c) => + println(c.toMap) + c(("test", "scalding")) + } + } + + c1.shouldSucceed() should ===(100) + c2.shouldSucceed() should ===(100) + } "Running a large loop won't exhaust boxed instances" in { var timesEvaluated = 0 diff --git a/scalding-core/src/test/scala/com/twitter/scalding/LargePlanTest.scala b/scalding-core/src/test/scala/com/twitter/scalding/LargePlanTest.scala new file mode 100644 index 0000000000..930e56ee2a --- /dev/null +++ b/scalding-core/src/test/scala/com/twitter/scalding/LargePlanTest.scala @@ -0,0 +1,64 @@ +package com.twitter.scalding + +import org.scalatest.FunSuite + +import scala.concurrent.ExecutionContext.Implicits.global +import scala.concurrent.Await +import scala.concurrent.duration._ + +/** + * on branch 0.17.x: + * - size=2 took 0.5 seconds + * - size=4 took 0.2 seconds + * - size=8 took 0.3 seconds + * - size=16 took 0.4 seconds + * - size=32 took 0.7 seconds + * - size=64 took 18.9 seconds + * - size=128 timed out (after 60 seconds) + * + * on branch cascading3: + * - size=2 took 0.6 seconds + * - size=4 took 0.3 seconds + * - size=8 took 0.3 seconds + * - size=16 took 0.4 seconds + * - size=32 took 0.5 seconds + * - size=64 took 1.2 seconds + * - size=128 took 2.7 seconds + */ + +class LargePlanTest extends FunSuite { + + val ns = List((1, 100), (2, 200)) + + // build a small pipe (only 2 keys) composed of a potentially large + // number of joins. + def build(size: Int): TypedPipe[(Int, Int)] = { + val pipe = TypedPipe.from(ns) + if (size <= 0) pipe + else pipe.join(build(size - 1)).mapValues { case (x, y) => x + y } + } + + // each test might run for up to this long + val Timeout = 60.seconds // one minute + + // run a test at a particular size + def run(size: Int): Unit = { + val t0 = System.currentTimeMillis() + val pipe = build(size) + val exec = pipe.toIterableExecution + val fut = exec.run(Config.empty, Local(true)) + val values = Await.result(fut, Timeout) + val secs = "%.1f" format ((System.currentTimeMillis() - t0) / 1000.0) + assert(true) + println(s"size=$size took $secs seconds") + } + + test("size=2") { run(2) } + test("size=4") { run(4) } + test("size=8") { run(8) } + test("size=16") { run(16) } + test("size=32") { run(32) } + test("size=64") { run(64) } + // test("size=128") { run(128) } +} + diff --git a/scalding-core/src/test/scala/com/twitter/scalding/RegressionTests.scala b/scalding-core/src/test/scala/com/twitter/scalding/RegressionTests.scala new file mode 100644 index 0000000000..bd8116e96a --- /dev/null +++ b/scalding-core/src/test/scala/com/twitter/scalding/RegressionTests.scala @@ -0,0 +1,24 @@ +package com.twitter.scalding + +import org.scalatest.FunSuite + +class RegressionTests extends FunSuite { + test("hashJoins + merges that fail in cascading 3") { + val p1 = + TypedPipe.from(List(1, 2)) + .cross(TypedPipe.from(List(3, 4))) + + val p2 = + TypedPipe.from(List(5, 6)) + .cross(TypedPipe.from(List(8, 9))) + + val p3 = (p1 ++ p2) + val p4 = (TypedPipe.from(List((8, 1), (10, 2))) ++ p3) + + val expected = List((1, 3), (1, 4), (2, 3), (2, 4), (5, 8), (5, 9), (6, 8), (6, 9), (8, 1), (10, 2)) + val values = p4.toIterableExecution + .waitFor(Config.empty, Local(true)) + .get + assert(values.toList.sorted == expected) + } +} diff --git a/scalding-core/src/test/scala/com/twitter/scalding/TypedPipeTest.scala b/scalding-core/src/test/scala/com/twitter/scalding/TypedPipeTest.scala index 639f0fd3c6..0d6cbd2359 100644 --- a/scalding-core/src/test/scala/com/twitter/scalding/TypedPipeTest.scala +++ b/scalding-core/src/test/scala/com/twitter/scalding/TypedPipeTest.scala @@ -234,6 +234,7 @@ class TypedPipeJoinKryoTest extends WordSpec with Matchers { .finish() } } + class TypedPipeDistinctJob(args: Args) extends Job(args) { Tsv("inputFile").read.toTypedPipe[(Int, Int)](0, 1) .distinct @@ -256,6 +257,31 @@ class TypedPipeDistinctTest extends WordSpec with Matchers { } } +class TypedPipeDistinctWordsJob(args: Args) extends Job(args) { + TextLine("inputFile") + .flatMap(_.split("\\s+")) + .distinct + .write(TextLine("outputFile")) +} + +class TypedPipeDistinctWordsTest extends WordSpec with Matchers { + import Dsl._ + "A TypedPipeDistinctWordsJob" should { + var idx = 0 + JobTest(new TypedPipeDistinctWordsJob(_)) + .source(TextLine("inputFile"), List(1 -> "a b b c", 2 -> "c d e")) + .sink[String](TextLine("outputFile")){ outputBuffer => + s"$idx: correctly count unique item sizes" in { + outputBuffer.toSet should have size 5 + } + idx += 1 + } + .run + .runHadoop + .finish() + } +} + class TypedPipeDistinctByJob(args: Args) extends Job(args) { Tsv("inputFile").read.toTypedPipe[(Int, Int)](0, 1) .distinctBy(_._2) diff --git a/scalding-core/src/test/scala/com/twitter/scalding/typed/NoStackLineNumberTest.scala b/scalding-core/src/test/scala/com/twitter/scalding/typed/NoStackLineNumberTest.scala index a4a67be95a..6475d28fc0 100644 --- a/scalding-core/src/test/scala/com/twitter/scalding/typed/NoStackLineNumberTest.scala +++ b/scalding-core/src/test/scala/com/twitter/scalding/typed/NoStackLineNumberTest.scala @@ -37,9 +37,15 @@ class NoStackLineNumberTest extends WordSpec { tp.toPipe('a, 'b) } val pipe = Await.result(pipeFut, SDuration.Inf) - // We pick up line number info via the NoStackAndThenClass + // We pick up line number info via TypedPipe.withLine // So this should have some non-scalding info in it. - assert(RichPipe.getPipeDescriptions(pipe).size > 0) + val allDesc = RichPipe(pipe) + .upstreamPipes + .map(RichPipe.getPipeDescriptions(_).toSet) + .foldLeft(Set.empty[String])(_ | _) + + assert(allDesc.size > 0) + assert(allDesc.exists(_.contains("com.twitter.example.scalding.typed.InAnotherPackage"))) } } -} \ No newline at end of file +} diff --git a/scalding-core/src/test/scala/com/twitter/scalding/typed/OptimizationRulesTest.scala b/scalding-core/src/test/scala/com/twitter/scalding/typed/OptimizationRulesTest.scala new file mode 100644 index 0000000000..10462a6089 --- /dev/null +++ b/scalding-core/src/test/scala/com/twitter/scalding/typed/OptimizationRulesTest.scala @@ -0,0 +1,510 @@ +package com.twitter.scalding.typed + +import cascading.flow.FlowDef +import cascading.flow.planner.FlowPlanner +import cascading.tuple.Fields +import com.stripe.dagon.{ Dag, Rule } +import com.twitter.scalding.source.{ TypedText, NullSink } +import org.apache.hadoop.conf.Configuration +import com.twitter.scalding.{ Config, ExecutionContext, Local, Hdfs, FlowState, FlowStateMap, IterableSource, TupleConverter } +import com.twitter.scalding.typed.cascading_backend.CascadingBackend +import org.scalactic.anyvals.PosInt +import org.scalatest.FunSuite +import org.scalatest.prop.PropertyChecks +import org.scalatest.prop.GeneratorDrivenPropertyChecks.PropertyCheckConfiguration +import org.scalacheck.{ Arbitrary, Gen } +import scala.util.{ Failure, Success, Try } + +object TypedPipeGen { + val srcGen: Gen[TypedPipe[Int]] = { + val g1 = Gen.listOf(Arbitrary.arbitrary[Int]).map(TypedPipe.from(_)) + val src = Gen.identifier.map { f => TypedPipe.from(TypedText.tsv[Int](f)) } + Gen.oneOf(g1, src, Gen.const(TypedPipe.empty)) + } + + def mapped(srcGen: Gen[TypedPipe[Int]]): Gen[TypedPipe[Int]] = { + val next1: Gen[TypedPipe[Int] => TypedPipe[Int]] = + Gen.oneOf( + tpGen(srcGen).map { p: TypedPipe[Int] => + { x: TypedPipe[Int] => x.cross(p).keys } + }, + tpGen(srcGen).map { p: TypedPipe[Int] => + { x: TypedPipe[Int] => x.cross(ValuePipe(2)).values } + }, + //Gen.const({ t: TypedPipe[Int] => t.debug }), debug spews a lot to the terminal + Arbitrary.arbitrary[Int => Boolean].map { fn => + { t: TypedPipe[Int] => t.filter(fn) } + }, + Gen.const({ t: TypedPipe[Int] => t.forceToDisk }), + Gen.const({ t: TypedPipe[Int] => t.fork }), + tpGen(srcGen).map { p: TypedPipe[Int] => + { x: TypedPipe[Int] => x ++ p } + }, + Gen.identifier.map { id => + { t: TypedPipe[Int] => t.addTrap(TypedText.tsv[Int](id)) } + }, + Gen.identifier.map { id => + { t: TypedPipe[Int] => t.withDescription(id) } + }) + + val one = for { + n <- next1 + p <- tpGen(srcGen) + } yield n(p) + + val next2: Gen[TypedPipe[(Int, Int)] => TypedPipe[Int]] = + Gen.oneOf( + Gen.const({ p: TypedPipe[(Int, Int)] => p.values }), + Gen.const({ p: TypedPipe[(Int, Int)] => p.keys })) + + val two = for { + n <- next2 + p <- keyed(srcGen) + } yield n(p) + + Gen.frequency((4, one), (1, two)) + } + + def keyed(srcGen: Gen[TypedPipe[Int]]): Gen[TypedPipe[(Int, Int)]] = { + val keyRec = Gen.lzy(keyed(srcGen)) + val one = Gen.oneOf( + for { + single <- tpGen(srcGen) + fn <- Arbitrary.arbitrary[Int => (Int, Int)] + } yield single.map(fn), + for { + single <- tpGen(srcGen) + fn <- Arbitrary.arbitrary[Int => List[(Int, Int)]] + } yield single.flatMap(fn)) + + val two = Gen.oneOf( + for { + fn <- Arbitrary.arbitrary[Int => Boolean] + pair <- keyRec + } yield pair.filterKeys(fn), + for { + fn <- Arbitrary.arbitrary[Int => List[Int]] + pair <- keyRec + } yield pair.flatMapValues(fn), + for { + fn <- Arbitrary.arbitrary[Int => Int] + pair <- keyRec + } yield pair.mapValues(fn), + for { + pair <- keyRec + } yield pair.sumByKey.toTypedPipe, + for { + pair <- keyRec + } yield pair.sumByLocalKeys, + for { + pair <- keyRec + } yield pair.group.mapGroup { (k, its) => its }.toTypedPipe, + for { + pair <- keyRec + } yield pair.group.sorted.mapGroup { (k, its) => its }.toTypedPipe, + for { + pair <- keyRec + } yield pair.group.sorted.withReducers(2).mapGroup { (k, its) => its }.toTypedPipe, + for { + p1 <- keyRec + p2 <- keyRec + } yield p1.hashJoin(p2).values, + for { + p1 <- keyRec + p2 <- keyRec + } yield p1.join(p2).values, + for { + p1 <- keyRec + p2 <- keyRec + } yield p1.join(p2).mapValues { case (a, b) => a * b }.toTypedPipe) + + // bias to consuming Int, since the we can stack overflow with the (Int, Int) + // cases + Gen.frequency((2, one), (1, two)) + } + + def tpGen(srcGen: Gen[TypedPipe[Int]]): Gen[TypedPipe[Int]] = + Gen.lzy(Gen.frequency((1, srcGen), (1, mapped(srcGen)))) + + /** + * This generates a TypedPipe that can't neccesarily + * be run because it has fake sources + */ + val genWithFakeSources: Gen[TypedPipe[Int]] = tpGen(srcGen) + + /** + * This can always be run because all the sources are + * Iterable sources + */ + val genWithIterableSources: Gen[TypedPipe[Int]] = + Gen.choose(0, 20) // don't make giant lists which take too long to evaluate + .flatMap { sz => + tpGen(Gen.listOfN(sz, Arbitrary.arbitrary[Int]).map(TypedPipe.from(_))) + } + + val genKeyedWithFake: Gen[TypedPipe[(Int, Int)]] = + keyed(srcGen) + + import OptimizationRules._ + + val allRules = List( + AddExplicitForks, + ComposeFlatMap, + ComposeMap, + ComposeFilter, + ComposeWithOnComplete, + ComposeMapFlatMap, + ComposeFilterFlatMap, + ComposeFilterMap, + DescribeLater, + DiamondToFlatMap, + RemoveDuplicateForceFork, + IgnoreNoOpGroup, + DeferMerge, + FilterKeysEarly, + FilterLocally, + EmptyIsOftenNoOp, + EmptyIterableIsEmpty, + HashToShuffleCoGroup, + ForceToDiskBeforeHashJoin) + + def genRuleFrom(rs: List[Rule[TypedPipe]]): Gen[Rule[TypedPipe]] = + for { + c <- Gen.choose(1, rs.size) + rs <- Gen.pick(c, rs) + } yield rs.reduce(_.orElse(_)) + + val genRule: Gen[Rule[TypedPipe]] = genRuleFrom(allRules) +} + +/** + * Used to test that we call phases + */ +class ThrowingOptimizer extends OptimizationPhases { + def phases = sys.error("booom") +} + +/** + * Just convert everything to a constant + * so we can check that the optimization was applied + */ +class ConstantOptimizer extends OptimizationPhases { + def phases = List(new Rule[TypedPipe] { + def apply[T](on: Dag[TypedPipe]) = { t => + Some(TypedPipe.empty) + } + }) +} + +class JustHashJoinForce extends OptimizationPhases { + def phases = List(OptimizationRules.ForceToDiskBeforeHashJoin) +} + +// we need to extend PropertyChecks, it seems, to control the number of successful runs +// for optimization rules, we want to do many tests +class OptimizationRulesTest extends FunSuite with PropertyChecks { + import OptimizationRules.toLiteral + + def invert[T](t: TypedPipe[T]) = + assert(toLiteral(t).evaluate == t) + + test("randomly generated TypedPipe trees are invertible") { + forAll(TypedPipeGen.genWithFakeSources) { (t: TypedPipe[Int]) => + invert(t) + } + } + + def optimizationLaw[T: Ordering](t: TypedPipe[T], rule: Rule[TypedPipe]) = { + val optimized = Dag.applyRule(t, toLiteral, rule) + val optimized2 = Dag.applyRule(t, toLiteral, rule) + + // Optimization pure is function (wrt to universal equality) + assert(optimized == optimized2) + + // We don't want any further optimization on this job + //val conf = Config.empty.setOptimizationPhases(classOf[EmptyOptimizationPhases]) + // cascading3 needs this + val conf = Config.empty.setOptimizationPhases(classOf[JustHashJoinForce]) + assert(TypedPipeDiff.diff(t, optimized) + .toIterableExecution + .waitFor(conf, Local(true)).get.isEmpty) + } + + // How many steps would this be in Hadoop on Cascading + def steps[T](p0: TypedPipe[T]): Int = { + val mode = Hdfs.default + val fd = new FlowDef + // cascading3 requires this rule + val p = Dag.applyRule(p0, toLiteral, OptimizationRules.ForceToDiskBeforeHashJoin) + val pipe = CascadingBackend.toPipeUnoptimized(p, NullSink.sinkFields)(fd, mode, NullSink.setter) + NullSink.writeFrom(pipe)(fd, mode) + val conf = Config.defaultFrom(mode) ++ + Map.empty[String, String] + // turn on tracing with this, but you probably want to comment out almost all the tests + // Map(FlowPlanner.TRACE_PLAN_PATH -> "/tmp/scalding/cascading/trace/plan/", + // FlowPlanner.TRACE_PLAN_TRANSFORM_PATH -> "/tmp/scalding/cascading/trace/plan/", + // FlowPlanner.TRACE_STATS_PATH -> "/tmp/scalding/cascading/trace/plan/") + val ec = ExecutionContext.newContext(conf)(fd, mode) + val flow = ec.buildFlow.get + flow.getFlowSteps.size + } + + def optimizationReducesSteps[T](init: TypedPipe[T], rule: Rule[TypedPipe]) = { + val optimized = Dag.applyRule(init, toLiteral, rule) + assert(steps(init) >= steps(optimized)) + } + + test("test planning of some example graphs that have given us trouble in cascading3") { + /** + * This is a self hashJoin + */ + val p = TypedPipe.from(List(1, 2, 3)).map { k => (k.toString, k) } + val pSelfJoin = p.hashJoin(p) + + assert(steps(pSelfJoin) <= 2) + assert(steps(pSelfJoin.hashJoin(pSelfJoin)) <= 3) + + def intOrder: Ordering[Int] = implicitly[Ordering[Int]] + + { + import TypedPipe._ + import CoGrouped._ + + val fn11: Int => Int = { x => x } + val fn11s: Int => List[Int] = List(_) + val fn12s: Int => List[(Int, Int)] = { x => List((x, 1)) } + val fn21: ((Int, Int)) => Int = { case (a, b) => a * b } + val mg: (Int, Iterator[Int]) => Iterator[Int] = { (_, b) => b } + val mg21: (Int, Iterator[(Int, Int)]) => Iterator[Int] = { (_, b) => b.map(_._1) } + + val arg0 = WithDescriptionTypedPipe(Mapped(WithDescriptionTypedPipe(MapValues(CoGroupedPipe(MapGroup(Pair(IdentityReduce(intOrder, + WithDescriptionTypedPipe(WithDescriptionTypedPipe(FlatMapped[Int, (Int, Int)](EmptyTypedPipe, fn12s), "org.scalacheck.Gen$R$class.map(Gen.scala:237)", true), + "org.scalacheck.Gen$R$class.map(Gen.scala:237)", true), None, List()), + IdentityReduce(intOrder, + WithDescriptionTypedPipe(CoGroupedPipe(MapGroup(Pair(IdentityReduce(intOrder, + WithDescriptionTypedPipe(WithDescriptionTypedPipe(FlatMapped(WithDescriptionTypedPipe[Int](EmptyTypedPipe, "tvo3aakgrh9jrzxoyeuqnfawbmjnxhaixoNgomuxeg41zfcpu", false), + fn12s), "org.scalacheck.Gen$R$class.map(Gen.scala:237)", true), + "org.scalacheck.Gen$R$class.map(Gen.scala:237)", true), None, List()), + IdentityReduce(intOrder, WithDescriptionTypedPipe(WithDescriptionTypedPipe(FlatMapped(WithDescriptionTypedPipe(MergedTypedPipe(WithDescriptionTypedPipe( + WithDescriptionTypedPipe(WithDescriptionTypedPipe(Mapped(WithDescriptionTypedPipe(CrossPipe(WithDescriptionTypedPipe(Mapped(WithDescriptionTypedPipe(CrossValue(WithDescriptionTypedPipe( + TrappedPipe[Int](EmptyTypedPipe, TypedText.tsv[Int]("m8x5mxgwljgg4zWaq"), TupleConverter.singleConverter), + "org.scalacheck.Gen$R$class.map(Gen.scala:237)", true), LiteralValue(2)), + "org.scalacheck.Gen$R$class.map(Gen.scala:237)", true), fn21), + "org.scalacheck.Gen$R$class.map(Gen.scala:237)", true), EmptyTypedPipe), "org.scalacheck.Gen$R$class.map(Gen.scala:237)", true), fn21), + "org.scalacheck.Gen$R$class.map(Gen.scala:237)", true), "pqbttw", false), "rzeykwyetbqpay9k7kmyfqrihXolLbo1gkqhq", false), + EmptyTypedPipe), + "org.scalacheck.Gen$R$class.map(Gen.scala:237)", true), fn12s), "org.scalacheck.Gen$R$class.map(Gen.scala:237)", true), + "org.scalacheck.Gen$R$class.map(Gen.scala:237)", true), None, List()), Joiner.inner2[Int, Int, Int]), mg21)), + "org.scalacheck.Gen$R$class.map(Gen.scala:237)", true), None, List()), Joiner.inner2[Int, Int, Int]), mg21)), + fn11 /**/ ), + "org.scalacheck.Gen$R$class.map(Gen.scala:237)", true), fn21 /**/ ), + "org.scalacheck.Gen$R$class.map(Gen.scala:237)", true) + + // this is just a test that we can plan, which we can't + assert(steps(arg0) < 10) + } + + { + import TypedPipe._ + import CoGrouped._ + + val fn21: ((Int, Int)) => Int = { case (a, b) => a * b } + + val p1 = + TypedPipe.from(List(1, 2)) + .cross(TypedPipe.from(List(3, 4))) + + val p2 = + TypedPipe.from(List(5, 6)) + .cross(TypedPipe.from(List(8, 9))) + + val p3 = (p1 ++ p2) + val p4 = (TypedPipe.from(List((8, 1), (10, 2))) ++ p3) + + assert(steps(p3) < 10) // this passes + assert(steps(p4) < 10) // FAILS to plan, throwing + } + + } + + val TrialCount = PosInt(200) + + test("all optimization rules don't change results") { + import TypedPipeGen.{ genWithIterableSources, genRule } + implicit val generatorDrivenConfig: PropertyCheckConfiguration = PropertyCheckConfiguration(minSuccessful = TrialCount) + forAll(genWithIterableSources, genRule)(optimizationLaw[Int] _) + } + + test("all optimization rules do not increase steps") { + import TypedPipeGen.{ allRules, genWithIterableSources, genRuleFrom } + implicit val generatorDrivenConfig: PropertyCheckConfiguration = PropertyCheckConfiguration(minSuccessful = TrialCount) + + val possiblyIncreasesSteps: Set[Rule[TypedPipe]] = + Set(OptimizationRules.AddExplicitForks, // explicit forks can cause cascading to add steps instead of recomputing values + OptimizationRules.ForceToDiskBeforeHashJoin, // adding a forceToDisk can increase the number of steps + OptimizationRules.HashToShuffleCoGroup // obviously changing a hashjoin to a cogroup can increase steps + ) + + val gen = genRuleFrom(allRules.filterNot(possiblyIncreasesSteps)) + + forAll(genWithIterableSources, gen)(optimizationReducesSteps[Int] _) + } + + test("ThrowingOptimizer is triggered") { + forAll(TypedPipeGen.genWithFakeSources) { t => + val conf = new Configuration() + conf.set(Config.OptimizationPhases, classOf[ThrowingOptimizer].getName) + implicit val mode = Hdfs(true, conf) + implicit val fd = new FlowDef + Try(CascadingBackend.toPipe(t, new Fields("value"))) match { + case Failure(ex) => assert(ex.getMessage == "booom") + case Success(res) => fail(s"expected failure, got $res") + } + } + + forAll(TypedPipeGen.genWithFakeSources) { t => + val ex = t.toIterableExecution + + val config = Config.empty.setOptimizationPhases(classOf[ThrowingOptimizer]) + ex.waitFor(config, Local(true)) match { + case Failure(ex) => assert(ex.getMessage == "booom") + case Success(res) => fail(s"expected failure, got $res") + } + } + } + + test("ConstantOptimizer is triggered") { + forAll(TypedPipeGen.genWithFakeSources) { t => + val conf = new Configuration() + conf.set(Config.OptimizationPhases, classOf[ConstantOptimizer].getName) + implicit val mode = Hdfs(true, conf) + implicit val fd = new FlowDef + Try(CascadingBackend.toPipe(t, new Fields("value"))) match { + case Failure(ex) => fail(s"$ex") + case Success(pipe) => + FlowStateMap.get(fd) match { + case None => fail("expected a flow state") + case Some(FlowState(m, _)) => + assert(m.size == 1) + m.head._2 match { + case it: IterableSource[_] => + assert(it.iter == Nil) + case _ => + fail(s"$m") + } + } + } + } + + forAll(TypedPipeGen.genWithFakeSources) { t => + val ex = t.toIterableExecution + + val config = Config.empty.setOptimizationPhases(classOf[ConstantOptimizer]) + ex.waitFor(config, Local(true)) match { + case Failure(ex) => fail(s"$ex") + case Success(res) => assert(res.isEmpty) + } + } + } + + test("OptimizationRules.toLiteral is invertible on some specific instances") { + + invert(TypedPipe.from(TypedText.tsv[Int]("foo"))) + invert(TypedPipe.from(List(1, 2, 3))) + invert(TypedPipe.from(List(1, 2, 3)).map(_ * 2)) + invert { + TypedPipe.from(List(1, 2, 3)).map { i => (i, i) }.sumByKey.toTypedPipe + } + + invert { + val p = TypedPipe.from(List(1, 2, 3)).map { i => (i, i) }.sumByKey + + p.mapGroup { (k, its) => Iterator.single(its.sum * k) } + } + + invert { + val p = TypedPipe.from(List(1, 2, 3)).map { i => (i, i) }.sumByKey + p.cross(TypedPipe.from(List("a", "b", "c")).sum) + } + + invert { + val p = TypedPipe.from(List(1, 2, 3)).map { i => (i, i) }.sumByKey + p.cross(TypedPipe.from(List("a", "b", "c"))) + } + + invert { + val p = TypedPipe.from(List(1, 2, 3)).map { i => (i, i) }.sumByKey + p.forceToDisk + } + + invert { + val p = TypedPipe.from(List(1, 2, 3)).map { i => (i, i) }.sumByKey + p.fork + } + + invert { + val p1 = TypedPipe.from(List(1, 2, 3)).map { i => (i, i) } + val p2 = TypedPipe.from(TypedText.tsv[(Int, String)]("foo")) + + p1.join(p2).toTypedPipe + } + + invert { + val p1 = TypedPipe.from(List(1, 2, 3)).map { i => (i, i) } + val p2 = TypedPipe.from(TypedText.tsv[(Int, String)]("foo")) + + p1.hashJoin(p2) + } + + invert { + val p1 = TypedPipe.from(List(1, 2, 3)).map { i => (i, i) } + val p2 = TypedPipe.from(TypedText.tsv[(Int, String)]("foo")) + + p1.join(p2).filterKeys(_ % 2 == 0) + } + } + + test("all transforms preserve equality") { + + forAll(TypedPipeGen.genWithFakeSources, TypedPipeGen.genKeyedWithFake) { (tp, keyed) => + val fn0 = { i: Int => i * 2 } + val filterFn = { i: Int => i % 2 == 0 } + val fn1 = { i: Int => (0 to i) } + + def eqCheck[T](t: => T) = { + assert(t == t) + } + + eqCheck(tp.map(fn0)) + eqCheck(tp.filter(filterFn)) + eqCheck(tp.flatMap(fn1)) + + eqCheck(keyed.mapValues(fn0)) + eqCheck(keyed.flatMapValues(fn1)) + eqCheck(keyed.filterKeys(filterFn)) + + eqCheck(tp.groupAll) + eqCheck(tp.groupBy(fn0)) + eqCheck(tp.asKeys) + eqCheck(tp.either(keyed)) + eqCheck(keyed.eitherValues(keyed.mapValues(fn0))) + eqCheck(tp.map(fn1).flatten) + eqCheck(keyed.swap) + eqCheck(keyed.keys) + eqCheck(keyed.values) + + val valueFn: (Int, Option[Int]) => String = { (a, b) => a.toString + b.toString } + val valueFn2: (Int, Option[Int]) => List[Int] = { (a, b) => a :: (b.toList) } + val valueFn3: (Int, Option[Int]) => Boolean = { (a, b) => true } + + eqCheck(tp.mapWithValue(LiteralValue(1))(valueFn)) + eqCheck(tp.flatMapWithValue(LiteralValue(1))(valueFn2)) + eqCheck(tp.filterWithValue(LiteralValue(1))(valueFn3)) + + eqCheck(tp.hashLookup(keyed)) + eqCheck(tp.groupRandomly(100)) + val ordInt = implicitly[Ordering[Int]] + eqCheck(tp.distinctBy(fn0)(ordInt)) + } + } +} diff --git a/scalding-core/src/test/scala/com/twitter/scalding/typed/RequireOrderedSerializationTest.scala b/scalding-core/src/test/scala/com/twitter/scalding/typed/RequireOrderedSerializationTest.scala index 0baca83837..a54439a482 100644 --- a/scalding-core/src/test/scala/com/twitter/scalding/typed/RequireOrderedSerializationTest.scala +++ b/scalding-core/src/test/scala/com/twitter/scalding/typed/RequireOrderedSerializationTest.scala @@ -18,12 +18,13 @@ package com.twitter.scalding import com.twitter.scalding.serialization.CascadingBinaryComparator import com.twitter.scalding.serialization.OrderedSerialization import com.twitter.scalding.serialization.StringOrderedSerialization +import com.twitter.scalding.serialization.RequireOrderedSerializationMode import org.scalatest.{ Matchers, WordSpec } -class NoOrderdSerJob(args: Args) extends Job(args) { +class NoOrderdSerJob(args: Args, requireOrderedSerializationMode: String) extends Job(args) { - override def config = super.config + (Config.ScaldingRequireOrderedSerialization -> "true") + override def config = super.config + (Config.ScaldingRequireOrderedSerialization -> requireOrderedSerializationMode) TypedPipe.from(TypedTsv[(String, String)]("input")) .group @@ -31,11 +32,11 @@ class NoOrderdSerJob(args: Args) extends Job(args) { .write(TypedTsv[(String, String)]("output")) } -class OrderdSerJob(args: Args) extends Job(args) { +class OrderdSerJob(args: Args, requireOrderedSerializationMode: String) extends Job(args) { implicit def stringOS: OrderedSerialization[String] = new StringOrderedSerialization - override def config = super.config + (Config.ScaldingRequireOrderedSerialization -> "true") + override def config = super.config + (Config.ScaldingRequireOrderedSerialization -> requireOrderedSerializationMode) TypedPipe.from(TypedTsv[(String, String)]("input")) .group @@ -45,29 +46,64 @@ class OrderdSerJob(args: Args) extends Job(args) { } class RequireOrderedSerializationTest extends WordSpec with Matchers { + "A NoOrderedSerJob" should { - // throw if we try to run in: - "throw when run" in { + + def test(job: Args => Job) = + JobTest(job) + .source(TypedTsv[(String, String)]("input"), List(("a", "a"), ("b", "b"))) + .sink[(String, String)](TypedTsv[(String, String)]("output")) { outBuf => () } + .run + .finish() + + "throw when mode is Fail" in { + val ex = the[Exception] thrownBy { + test(new NoOrderdSerJob(_, RequireOrderedSerializationMode.Fail.toString)) + } + ex.getMessage should include("SerializationTest.scala:") + } + + "not throw when mode is Log" in { + test(new NoOrderdSerJob(_, RequireOrderedSerializationMode.Log.toString)) + } + + "throw when mode is true" in { val ex = the[Exception] thrownBy { - JobTest(new NoOrderdSerJob(_)) - .source(TypedTsv[(String, String)]("input"), List(("a", "a"), ("b", "b"))) - .sink[(String, String)](TypedTsv[(String, String)]("output")) { outBuf => () } - .run - .finish() + test(new NoOrderdSerJob(_, "true")) } ex.getMessage should include("SerializationTest.scala:") } + + "not throw when mode is false" in { + test(new NoOrderdSerJob(_, "false")) + } } + "A OrderedSerJob" should { - // throw if we try to run in: - "run" in { - JobTest(new OrderdSerJob(_)) + + def test(job: Args => Job) = + JobTest(job) .source(TypedTsv[(String, String)]("input"), List(("a", "a"), ("a", "b"), ("b", "b"))) .sink[(String, String)](TypedTsv[(String, String)]("output")) { outBuf => outBuf.toSet shouldBe Set(("a", "b"), ("b", "b")) } .run .finish() + + "run when mode is Fail" in { + test(new OrderdSerJob(_, RequireOrderedSerializationMode.Fail.toString)) + } + + "run when mode is Log" in { + test(new OrderdSerJob(_, RequireOrderedSerializationMode.Log.toString)) + } + + "run when mode is true" in { + test(new OrderdSerJob(_, "true")) + } + + "run when mode is false" in { + test(new OrderdSerJob(_, "false")) } } } diff --git a/scalding-date/src/main/scala/com/twitter/scalding/AbsoluteDuration.scala b/scalding-date/src/main/scala/com/twitter/scalding/AbsoluteDuration.scala index 59b42d837b..23fc883f3a 100644 --- a/scalding-date/src/main/scala/com/twitter/scalding/AbsoluteDuration.scala +++ b/scalding-date/src/main/scala/com/twitter/scalding/AbsoluteDuration.scala @@ -137,31 +137,31 @@ sealed trait AbsoluteDuration extends Duration with Ordered[AbsoluteDuration] { override def hashCode: Int = toMillisecs.hashCode } -case class Millisecs(cnt: Int) extends Duration(Calendar.MILLISECOND, cnt, DateOps.UTC) +final case class Millisecs(cnt: Int) extends Duration(Calendar.MILLISECOND, cnt, DateOps.UTC) with AbsoluteDuration { override def toSeconds = cnt / 1000.0 override def toMillisecs = cnt.toLong } -case class Seconds(cnt: Int) extends Duration(Calendar.SECOND, cnt, DateOps.UTC) +final case class Seconds(cnt: Int) extends Duration(Calendar.SECOND, cnt, DateOps.UTC) with AbsoluteDuration { override def toSeconds = cnt.toDouble override def toMillisecs = (cnt.toLong) * 1000L } -case class Minutes(cnt: Int) extends Duration(Calendar.MINUTE, cnt, DateOps.UTC) +final case class Minutes(cnt: Int) extends Duration(Calendar.MINUTE, cnt, DateOps.UTC) with AbsoluteDuration { override def toSeconds = cnt * 60.0 override def toMillisecs = cnt.toLong * 60L * 1000L } -case class Hours(cnt: Int) extends Duration(Calendar.HOUR, cnt, DateOps.UTC) +final case class Hours(cnt: Int) extends Duration(Calendar.HOUR, cnt, DateOps.UTC) with AbsoluteDuration { override def toSeconds = cnt * 60.0 * 60.0 override def toMillisecs = cnt.toLong * 60L * 60L * 1000L } -case class AbsoluteDurationList(parts: List[AbsoluteDuration]) +final case class AbsoluteDurationList(parts: List[AbsoluteDuration]) extends AbstractDurationList[AbsoluteDuration](parts) with AbsoluteDuration { override def toSeconds = parts.map{ _.toSeconds }.sum override def toMillisecs: Long = parts.map{ _.toMillisecs }.sum diff --git a/scalding-date/src/main/scala/com/twitter/scalding/CalendarOps.scala b/scalding-date/src/main/scala/com/twitter/scalding/CalendarOps.scala index d79cff3d8d..a0d43ab272 100644 --- a/scalding-date/src/main/scala/com/twitter/scalding/CalendarOps.scala +++ b/scalding-date/src/main/scala/com/twitter/scalding/CalendarOps.scala @@ -13,12 +13,12 @@ object CalendarOps { if (currentField > field) { currentField match { case Calendar.DAY_OF_MONTH => cal.set(currentField, 1) - case Calendar.DAY_OF_WEEK_IN_MONTH => () // Skip - case Calendar.DAY_OF_WEEK => () // Skip - case Calendar.DAY_OF_YEAR => () // Skip - case Calendar.WEEK_OF_MONTH => () // Skip - case Calendar.WEEK_OF_YEAR => () // Skip - case Calendar.HOUR_OF_DAY => () // Skip + case Calendar.DAY_OF_WEEK_IN_MONTH | + Calendar.DAY_OF_WEEK | + Calendar.DAY_OF_YEAR | + Calendar.WEEK_OF_MONTH | + Calendar.WEEK_OF_YEAR | + Calendar.HOUR_OF_DAY => () // Skip case _ => cal.set(currentField, 0) } diff --git a/scalding-date/src/main/scala/com/twitter/scalding/DateOps.scala b/scalding-date/src/main/scala/com/twitter/scalding/DateOps.scala index d3809fcdad..88ee18c489 100644 --- a/scalding-date/src/main/scala/com/twitter/scalding/DateOps.scala +++ b/scalding-date/src/main/scala/com/twitter/scalding/DateOps.scala @@ -67,7 +67,7 @@ object DateOps extends java.io.Serializable { * Return the guessed format for this datestring */ private[scalding] def getFormatObject(s: String): Option[Format] = { - val formats: List[Format] = List( + val formats: List[Format] = List[Format]( Format.DATE_WITH_DASH, Format.DATEHOUR_WITH_DASH, Format.DATETIME_WITH_DASH, diff --git a/scalding-date/src/main/scala/com/twitter/scalding/DateRange.scala b/scalding-date/src/main/scala/com/twitter/scalding/DateRange.scala index 9bd1025ce5..bb40869691 100644 --- a/scalding-date/src/main/scala/com/twitter/scalding/DateRange.scala +++ b/scalding-date/src/main/scala/com/twitter/scalding/DateRange.scala @@ -93,6 +93,12 @@ case class DateRange(val start: RichDate, val end: RichDate) { */ def extend(delta: Duration) = DateRange(start, end + delta) + /** + * Extend the length by moving the start. + * Turns out, we can start the party early. + */ + def prepend(delta: Duration) = DateRange(start - delta, end) + def contains(point: RichDate) = (start <= point) && (point <= end) /** * Is the given Date range a (non-strict) subset of the given range diff --git a/scalding-date/src/main/scala/com/twitter/scalding/Duration.scala b/scalding-date/src/main/scala/com/twitter/scalding/Duration.scala index 7e62d67c04..c1aa49bfd0 100644 --- a/scalding-date/src/main/scala/com/twitter/scalding/Duration.scala +++ b/scalding-date/src/main/scala/com/twitter/scalding/Duration.scala @@ -26,10 +26,11 @@ import scala.annotation.tailrec */ object Duration extends java.io.Serializable { // TODO: remove this in 0.9.0 - val SEC_IN_MS = 1000 - val MIN_IN_MS = 60 * SEC_IN_MS - val HOUR_IN_MS = 60 * MIN_IN_MS - val UTC_UNITS = List((Hours, HOUR_IN_MS), (Minutes, MIN_IN_MS), (Seconds, SEC_IN_MS), (Millisecs, 1)) + val SEC_IN_MS: Int = 1000 + val MIN_IN_MS: Int = 60 * SEC_IN_MS + val HOUR_IN_MS: Int = 60 * MIN_IN_MS + val UTC_UNITS: List[(Int => AbsoluteDuration, Int)] = + List[(Int => AbsoluteDuration, Int)]((Hours, HOUR_IN_MS), (Minutes, MIN_IN_MS), (Seconds, SEC_IN_MS), (Millisecs, 1)) } abstract class Duration(val calField: Int, val count: Int, val tz: TimeZone) diff --git a/scalding-date/src/test/scala/com/twitter/scalding/DateTest.scala b/scalding-date/src/test/scala/com/twitter/scalding/DateTest.scala index 2c33162ccd..f3500883a3 100644 --- a/scalding-date/src/test/scala/com/twitter/scalding/DateTest.scala +++ b/scalding-date/src/test/scala/com/twitter/scalding/DateTest.scala @@ -212,6 +212,13 @@ class DateTest extends WordSpec { "reject an end that is before its start" in { intercept[IllegalArgumentException] { DateRange("2010-10-02", "2010-10-01") } } + "correctly add time in either or both directions" in { + assert(DateRange("2010-10-01", "2010-10-02").extend(Days(3)).each(Days(1)).size === 5) + assert(DateRange("2010-10-01", "2010-10-02").prepend(Days(3)).each(Days(1)).size === 5) + assert(DateRange("2010-10-01", "2010-10-02").embiggen(Days(3)).each(Days(1)).size === 8) + assert(DateRange("2010-10-01", "2010-10-10").extend(Days(1)).prepend(Days(1)) == + DateRange("2010-10-01", "2010-10-10").embiggen(Days(1))) + } } "Time units" should { def isSame(d1: Duration, d2: Duration) = { diff --git a/scalding-db/src/main/scala/com/twitter/scalding/db/macros/DBMacro.scala b/scalding-db/src/main/scala/com/twitter/scalding/db/macros/DBMacro.scala index 0df3afefe5..9e1a8f1f8b 100644 --- a/scalding-db/src/main/scala/com/twitter/scalding/db/macros/DBMacro.scala +++ b/scalding-db/src/main/scala/com/twitter/scalding/db/macros/DBMacro.scala @@ -12,20 +12,20 @@ sealed trait ScaldingDBAnnotation // This is the size in characters for a char field // For integers its really for display purposes @scala.annotation.meta.getter -class size(val size: Int) extends annotation.StaticAnnotation with ScaldingDBAnnotation +final class size(val size: Int) extends annotation.StaticAnnotation with ScaldingDBAnnotation // JDBC TEXT type, this forces the String field in question to be a text type @scala.annotation.meta.getter -class text() extends annotation.StaticAnnotation with ScaldingDBAnnotation +final class text() extends annotation.StaticAnnotation with ScaldingDBAnnotation // JDBC VARCHAR type, this forces the String field in question to be a text type @scala.annotation.meta.getter -class varchar() extends annotation.StaticAnnotation with ScaldingDBAnnotation +final class varchar() extends annotation.StaticAnnotation with ScaldingDBAnnotation // JDBC DATE type, this toggles a java.util.Date field to be JDBC Date. // It will default to DATETIME to preserve the full resolution of java.util.Date @scala.annotation.meta.getter -class date() extends annotation.StaticAnnotation with ScaldingDBAnnotation +final class date() extends annotation.StaticAnnotation with ScaldingDBAnnotation // This is the entry point to explicitly calling the JDBC macros. // Most often the implicits will be used in the package however diff --git a/scalding-db/src/main/scala/com/twitter/scalding/db/macros/impl/ColumnDefinitionProviderImpl.scala b/scalding-db/src/main/scala/com/twitter/scalding/db/macros/impl/ColumnDefinitionProviderImpl.scala index 2fad43471c..3dfdefbaa0 100644 --- a/scalding-db/src/main/scala/com/twitter/scalding/db/macros/impl/ColumnDefinitionProviderImpl.scala +++ b/scalding-db/src/main/scala/com/twitter/scalding/db/macros/impl/ColumnDefinitionProviderImpl.scala @@ -107,7 +107,7 @@ object ColumnDefinitionProviderImpl { .declarations .collect { case m: MethodSymbol if m.isCaseAccessor => m } .map { m => - val fieldName = m.name.toTermName.toString.trim + val fieldName = m.name.toString.trim val defaultVal = defaultArgs.get(fieldName) val annotationInfo: List[(Type, Option[Int])] = annotationData.getOrElse(m.name.toString.trim, Nil) @@ -116,11 +116,8 @@ object ColumnDefinitionProviderImpl { case (tpe, _) if tpe =:= typeOf[com.twitter.scalding.db.macros.size] => c.abort(c.enclosingPosition, "Hit a size macro where we couldn't parse the value. Probably not a literal constant. Only literal constants are supported.") case (tpe, _) if tpe <:< typeOf[com.twitter.scalding.db.macros.ScaldingDBAnnotation] => (tpe, None) } - (m, fieldName, defaultVal, annotationInfo) - } - .map { - case (accessorMethod, fieldName, defaultVal, annotationInfo) => - matchField(outerAccessorTree :+ accessorMethod, accessorMethod.returnType, FieldName(fieldName), defaultVal, annotationInfo, false) + + matchField(outerAccessorTree :+ m, m.returnType, FieldName(fieldName), defaultVal, annotationInfo, false) } .toList // This algorithm returns the error from the first exception we run into. diff --git a/scalding-db/src/main/scala/com/twitter/scalding/db/macros/impl/handler/AnnotationHelper.scala b/scalding-db/src/main/scala/com/twitter/scalding/db/macros/impl/handler/AnnotationHelper.scala index 81fde78188..5dd48e5156 100644 --- a/scalding-db/src/main/scala/com/twitter/scalding/db/macros/impl/handler/AnnotationHelper.scala +++ b/scalding-db/src/main/scala/com/twitter/scalding/db/macros/impl/handler/AnnotationHelper.scala @@ -9,7 +9,7 @@ import com.twitter.scalding.db.ColumnDefinition import com.twitter.scalding.db.macros.impl.FieldName private[handler] sealed trait SizeAnno -private[handler] case class WithSize(v: Int) extends SizeAnno +private[handler] final case class WithSize(v: Int) extends SizeAnno private[handler] case object WithoutSize extends SizeAnno private[handler] sealed trait DateAnno diff --git a/scalding-db/src/main/scala/com/twitter/scalding/db/macros/impl/handler/DateTypeHandler.scala b/scalding-db/src/main/scala/com/twitter/scalding/db/macros/impl/handler/DateTypeHandler.scala index 965f781416..50b4ccc321 100644 --- a/scalding-db/src/main/scala/com/twitter/scalding/db/macros/impl/handler/DateTypeHandler.scala +++ b/scalding-db/src/main/scala/com/twitter/scalding/db/macros/impl/handler/DateTypeHandler.scala @@ -28,11 +28,9 @@ object DateTypeHandler { _ <- nextHelper.validateFinished } yield (dateAnno) - extracted.flatMap { t => - t match { - case WithDate => Success(List(ColumnFormat(c)(accessorTree, "DATE", None))) - case WithoutDate => Success(List(ColumnFormat(c)(accessorTree, "DATETIME", None))) - } + extracted.flatMap { + case WithDate => Success(List(ColumnFormat(c)(accessorTree, "DATE", None))) + case WithoutDate => Success(List(ColumnFormat(c)(accessorTree, "DATETIME", None))) } } } diff --git a/scalding-db/src/main/scala/com/twitter/scalding/db/macros/impl/handler/NumericTypeHandler.scala b/scalding-db/src/main/scala/com/twitter/scalding/db/macros/impl/handler/NumericTypeHandler.scala index 04146366ab..8e99576870 100644 --- a/scalding-db/src/main/scala/com/twitter/scalding/db/macros/impl/handler/NumericTypeHandler.scala +++ b/scalding-db/src/main/scala/com/twitter/scalding/db/macros/impl/handler/NumericTypeHandler.scala @@ -29,12 +29,10 @@ object NumericTypeHandler { _ <- nextHelper.validateFinished } yield (sizeAnno) - extracted.flatMap { t => - t match { - case WithSize(s) if s > 0 => Success(List(ColumnFormat(c)(accessorTree, numericType, Some(s)))) - case WithSize(s) => Failure(new Exception(s"Int field $fieldName, has a size defined that is <= 0.")) - case WithoutSize => Success(List(ColumnFormat(c)(accessorTree, numericType, None))) - } + extracted.flatMap { + case WithSize(s) if s > 0 => Success(List(ColumnFormat(c)(accessorTree, numericType, Some(s)))) + case WithSize(s) => Failure(new Exception(s"Int field $fieldName, has a size defined that is <= 0.")) + case WithoutSize => Success(List(ColumnFormat(c)(accessorTree, numericType, None))) } } } diff --git a/scalding-db/src/main/scala/com/twitter/scalding/db/macros/impl/handler/StringTypeHandler.scala b/scalding-db/src/main/scala/com/twitter/scalding/db/macros/impl/handler/StringTypeHandler.scala index f3bcd76f2c..54e594dc8f 100644 --- a/scalding-db/src/main/scala/com/twitter/scalding/db/macros/impl/handler/StringTypeHandler.scala +++ b/scalding-db/src/main/scala/com/twitter/scalding/db/macros/impl/handler/StringTypeHandler.scala @@ -29,17 +29,15 @@ object StringTypeHandler { _ <- nextHelper.validateFinished } yield (sizeAnno, varcharAnno, textAnno) - extracted.flatMap { t => - t match { - case (_, WithVarchar, WithText) => Failure(new Exception(s"String field $fieldName, has mutually exclusive annotations @text and @varchar")) - case (WithoutSize, WithVarchar, WithoutText) => Failure(new Exception(s"String field $fieldName, is forced varchar but has no size annotation. size is required in the presence of varchar.")) - case (WithoutSize, WithoutVarchar, WithoutText) => Failure(new Exception(s"String field $fieldName, at least one of size, varchar, text must be present.")) - case (WithSize(siz), _, _) if siz <= 0 => Failure(new Exception(s"String field $fieldName, has a size $siz which is <= 0. Doesn't make sense for a string.")) - case (WithSize(siz), WithoutVarchar, WithoutText) if siz <= 255 => Success(List(ColumnFormat(c)(accessorTree, "VARCHAR", Some(siz)))) - case (WithSize(siz), WithoutVarchar, WithoutText) if siz > 255 => Success(List(ColumnFormat(c)(accessorTree, "TEXT", None))) - case (WithSize(siz), WithVarchar, WithoutText) => Success(List(ColumnFormat(c)(accessorTree, "VARCHAR", Some(siz)))) - case (_, WithoutVarchar, WithText) => Success(List(ColumnFormat(c)(accessorTree, "TEXT", None))) - } + extracted.flatMap { + case (_, WithVarchar, WithText) => Failure(new Exception(s"String field $fieldName, has mutually exclusive annotations @text and @varchar")) + case (WithoutSize, WithVarchar, WithoutText) => Failure(new Exception(s"String field $fieldName, is forced varchar but has no size annotation. size is required in the presence of varchar.")) + case (WithoutSize, WithoutVarchar, WithoutText) => Failure(new Exception(s"String field $fieldName, at least one of size, varchar, text must be present.")) + case (WithSize(siz), _, _) if siz <= 0 => Failure(new Exception(s"String field $fieldName, has a size $siz which is <= 0. Doesn't make sense for a string.")) + case (WithSize(siz), WithoutVarchar, WithoutText) if siz <= 255 => Success(List(ColumnFormat(c)(accessorTree, "VARCHAR", Some(siz)))) + case (WithSize(siz), WithoutVarchar, WithoutText) if siz > 255 => Success(List(ColumnFormat(c)(accessorTree, "TEXT", None))) + case (WithSize(siz), WithVarchar, WithoutText) => Success(List(ColumnFormat(c)(accessorTree, "VARCHAR", Some(siz)))) + case (_, WithoutVarchar, WithText) => Success(List(ColumnFormat(c)(accessorTree, "TEXT", None))) } } } diff --git a/scalding-graph/src/main/scala/com/twitter/scalding/graph/DependantGraph.scala b/scalding-graph/src/main/scala/com/twitter/scalding/graph/DependantGraph.scala deleted file mode 100644 index 02268c4718..0000000000 --- a/scalding-graph/src/main/scala/com/twitter/scalding/graph/DependantGraph.scala +++ /dev/null @@ -1,58 +0,0 @@ -/* - Copyright 2013 Twitter, Inc. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - */ - -package com.twitter.scalding.graph - -/** - * Given Dag and a List of immutable nodes, and a function to get - * dependencies, compute the dependants (reverse the graph) - */ -abstract class DependantGraph[T] { - def nodes: List[T] - def dependenciesOf(t: T): Iterable[T] - - lazy val allTails: List[T] = nodes.filter { t => - fanOut(t) match { - case Some(n) => n == 0 - case None => false - } - } - private lazy val nodeSet: Set[T] = nodes.toSet - - /** - * This is the dependants graph. Each node knows who it depends on - * but not who depends on it without doing this computation - */ - private lazy val graph: NeighborFn[T] = reversed(nodes)(dependenciesOf(_)) - - private lazy val depths: Map[T, Int] = dagDepth(nodes)(dependenciesOf(_)) - - /** - * The max of zero and 1 + depth of all parents if the node is the graph - */ - def isNode(p: T): Boolean = nodeSet.contains(p) - def depth(p: T): Option[Int] = depths.get(p) - - def dependantsOf(p: T): Option[List[T]] = - if (isNode(p)) Some(graph(p).toList) else None - - def fanOut(p: T): Option[Int] = dependantsOf(p).map { _.size } - /** - * Return all dependendants of a given node. - * Does not include itself - */ - def transitiveDependantsOf(p: T): List[T] = depthFirstOf(p)(graph) -} diff --git a/scalding-graph/src/main/scala/com/twitter/scalding/graph/Expr.scala b/scalding-graph/src/main/scala/com/twitter/scalding/graph/Expr.scala deleted file mode 100644 index 64006b8c61..0000000000 --- a/scalding-graph/src/main/scala/com/twitter/scalding/graph/Expr.scala +++ /dev/null @@ -1,84 +0,0 @@ -/* - Copyright 2014 Twitter, Inc. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - */ - -package com.twitter.scalding.graph - -/** - * The Expressions are assigned Ids. Each Id is associated with - * an expression of inner type T. - * - * This is done to put an indirection in the ExpressionDag that - * allows us to rewrite nodes by simply replacing the expressions - * associated with given Ids. - * - * T is a phantom type used by the type system - */ -final case class Id[T](id: Int) - -/** - * Expr[T, N] is an expression of a graph of container nodes N[_] with - * result type N[T]. These expressions are like the Literal[T, N] graphs - * except that functions always operate with an indirection of a Id[T] - * where N[T] is the type of the input node. - * - * Nodes can be deleted from the graph by replacing an Expr at Id = idA - * with Var(idB) pointing to some upstream node. - * - * To add nodes to the graph, add depth to the final node returned in - * a Unary or Binary expression. - * - * TODO: see the approach here: https://gist.github.com/pchiusano/1369239 - * Which seems to show a way to do currying, so we can handle general - * arity - */ -sealed trait Expr[T, N[_]] { - def evaluate(idToExp: HMap[Id, ({ type E[t] = Expr[t, N] })#E]): N[T] = - Expr.evaluate(idToExp, this) -} -case class Const[T, N[_]](value: N[T]) extends Expr[T, N] { - override def evaluate(idToExp: HMap[Id, ({ type E[t] = Expr[t, N] })#E]): N[T] = value -} -case class Var[T, N[_]](name: Id[T]) extends Expr[T, N] -case class Unary[T1, T2, N[_]](arg: Id[T1], fn: N[T1] => N[T2]) extends Expr[T2, N] -case class Binary[T1, T2, T3, N[_]](arg1: Id[T1], - arg2: Id[T2], - fn: (N[T1], N[T2]) => N[T3]) extends Expr[T3, N] - -object Expr { - def evaluate[T, N[_]](idToExp: HMap[Id, ({ type E[t] = Expr[t, N] })#E], expr: Expr[T, N]): N[T] = - evaluate(idToExp, HMap.empty[({ type E[t] = Expr[t, N] })#E, N], expr)._2 - - private def evaluate[T, N[_]](idToExp: HMap[Id, ({ type E[t] = Expr[t, N] })#E], - cache: HMap[({ type E[t] = Expr[t, N] })#E, N], - expr: Expr[T, N]): (HMap[({ type E[t] = Expr[t, N] })#E, N], N[T]) = cache.get(expr) match { - case Some(node) => (cache, node) - case None => expr match { - case Const(n) => (cache + (expr -> n), n) - case Var(id) => - val (c1, n) = evaluate(idToExp, cache, idToExp(id)) - (c1 + (expr -> n), n) - case Unary(id, fn) => - val (c1, n1) = evaluate(idToExp, cache, idToExp(id)) - val n2 = fn(n1) - (c1 + (expr -> n2), n2) - case Binary(id1, id2, fn) => - val (c1, n1) = evaluate(idToExp, cache, idToExp(id1)) - val (c2, n2) = evaluate(idToExp, c1, idToExp(id2)) - val n3 = fn(n1, n2) - (c2 + (expr -> n3), n3) - } - } -} diff --git a/scalding-graph/src/main/scala/com/twitter/scalding/graph/ExpressionDag.scala b/scalding-graph/src/main/scala/com/twitter/scalding/graph/ExpressionDag.scala deleted file mode 100644 index 5ac3eb79ae..0000000000 --- a/scalding-graph/src/main/scala/com/twitter/scalding/graph/ExpressionDag.scala +++ /dev/null @@ -1,372 +0,0 @@ -/* - Copyright 2014 Twitter, Inc. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - */ - -package com.twitter.scalding.graph - -///////////////////// -// There is no logical reason for Literal[T, N] to be here, -// but the scala compiler crashes in 2.9.3 if it is not. -// with: -// java.lang.Error: typeConstructor inapplicable for -// at scala.tools.nsc.symtab.SymbolTable.abort(SymbolTable.scala:34) -// at scala.tools.nsc.symtab.Symbols$Symbol.typeConstructor(Symbols.scala:880) -//////////////////// - -/** - * This represents literal expressions (no variable redirection) - * of container nodes of type N[T] - */ -sealed trait Literal[T, N[_]] { - def evaluate: N[T] = Literal.evaluate(this) -} -case class ConstLit[T, N[_]](override val evaluate: N[T]) extends Literal[T, N] -case class UnaryLit[T1, T2, N[_]](arg: Literal[T1, N], - fn: N[T1] => N[T2]) extends Literal[T2, N] { -} -case class BinaryLit[T1, T2, T3, N[_]](arg1: Literal[T1, N], arg2: Literal[T2, N], - fn: (N[T1], N[T2]) => N[T3]) extends Literal[T3, N] { -} - -object Literal { - /** - * This evaluates a literal formula back to what it represents - * being careful to handle diamonds by creating referentially - * equivalent structures (not just structurally equivalent) - */ - def evaluate[T, N[_]](lit: Literal[T, N]): N[T] = - evaluate(HMap.empty[({ type L[T] = Literal[T, N] })#L, N], lit)._2 - - // Memoized version of the above to handle diamonds - private def evaluate[T, N[_]](hm: HMap[({ type L[T1] = Literal[T1, N] })#L, N], lit: Literal[T, N]): (HMap[({ type L[T1] = Literal[T1, N] })#L, N], N[T]) = - hm.get(lit) match { - case Some(prod) => (hm, prod) - case None => - lit match { - case ConstLit(prod) => (hm + (lit -> prod), prod) - case UnaryLit(in, fn) => - val (h1, p1) = evaluate(hm, in) - val p2 = fn(p1) - (h1 + (lit -> p2), p2) - case BinaryLit(in1, in2, fn) => - val (h1, p1) = evaluate(hm, in1) - val (h2, p2) = evaluate(h1, in2) - val p3 = fn(p1, p2) - (h2 + (lit -> p3), p3) - } - } -} - -sealed trait ExpressionDag[N[_]] { self => - // Once we fix N above, we can make E[T] = Expr[T, N] - type E[t] = Expr[t, N] - type Lit[t] = Literal[t, N] - - /** - * These have package visibility to test - * the law that for all Expr, the node they - * evaluate to is unique - */ - protected[graph] def idToExp: HMap[Id, E] - protected def nodeToLiteral: GenFunction[N, Lit] - protected def roots: Set[Id[_]] - protected def nextId: Int - - private def copy(id2Exp: HMap[Id, E] = self.idToExp, - node2Literal: GenFunction[N, Lit] = self.nodeToLiteral, - gcroots: Set[Id[_]] = self.roots, - id: Int = self.nextId): ExpressionDag[N] = new ExpressionDag[N] { - def idToExp = id2Exp - def roots = gcroots - def nodeToLiteral = node2Literal - def nextId = id - } - - override def toString: String = - "ExpressionDag(idToExp = %s)".format(idToExp) - - // This is a cache of Id[T] => Option[N[T]] - private val idToN = - new HCache[Id, ({ type ON[T] = Option[N[T]] })#ON]() - private val nodeToId = - new HCache[N, ({ type OID[T] = Option[Id[T]] })#OID]() - - /** - * Add a GC root, or tail in the DAG, that can never be deleted - * currently, we only support a single root - */ - private def addRoot[_](id: Id[_]) = copy(gcroots = roots + id) - - /** - * Which ids are reachable from the roots - */ - private def reachableIds: Set[Id[_]] = { - // We actually don't care about the return type of the Set - // This is a constant function at the type level - type IdSet[t] = Set[Id[_]] - def expand(s: Set[Id[_]]): Set[Id[_]] = { - val partial = new GenPartial[HMap[Id, E]#Pair, IdSet] { - def apply[T] = { - case (id, Const(_)) if s(id) => s - case (id, Var(v)) if s(id) => s + v - case (id, Unary(id0, _)) if s(id) => s + id0 - case (id, Binary(id0, id1, _)) if s(id) => (s + id0) + id1 - } - } - // Note this Stream must always be non-empty as long as roots are - idToExp.collect[IdSet](partial) - .reduce(_ ++ _) - } - // call expand while we are still growing - def go(s: Set[Id[_]]): Set[Id[_]] = { - val step = expand(s) - if (step == s) s - else go(step) - } - go(roots) - } - - private def gc: ExpressionDag[N] = { - val goodIds = reachableIds - type BoolT[t] = Boolean - val toKeepI2E = idToExp.filter(new GenFunction[HMap[Id, E]#Pair, BoolT] { - def apply[T] = { idExp => goodIds(idExp._1) } - }) - copy(id2Exp = toKeepI2E) - } - - /** - * Apply the given rule to the given dag until - * the graph no longer changes. - */ - def apply(rule: Rule[N]): ExpressionDag[N] = { - // for some reason, scala can't optimize this with tailrec - var prev: ExpressionDag[N] = null - var curr: ExpressionDag[N] = this - while (!(curr eq prev)) { - prev = curr - curr = curr.applyOnce(rule) - } - curr - } - - protected def toExpr[T](n: N[T]): (ExpressionDag[N], Expr[T, N]) = { - val (dag, id) = ensure(n) - val exp = dag.idToExp(id) - (dag, exp) - } - - /** - * Convert a N[T] to a Literal[T, N] - */ - def toLiteral[T](n: N[T]): Literal[T, N] = nodeToLiteral.apply[T](n) - - /** - * apply the rule at the first place that satisfies - * it, and return from there. - */ - def applyOnce(rule: Rule[N]): ExpressionDag[N] = { - val getN = new GenPartial[HMap[Id, E]#Pair, HMap[Id, N]#Pair] { - def apply[U] = { - val fn = rule.apply[U](self) - - { - case (id, exp) if fn(exp.evaluate(idToExp)).isDefined => - // Sucks to have to call fn, twice, but oh well - - fn(exp.evaluate(idToExp)) match { - case Some(n) => (id, n) - case None => sys.error("unreachable since isDefined checked above") - } - } - } - } - idToExp.collect[HMap[Id, N]#Pair](getN).headOption match { - case None => this - case Some(tup) => - // some type hand holding - def act[T](in: HMap[Id, N]#Pair[T]) = { - val (i, n) = in - val oldNode = evaluate(i) - val (dag, exp) = toExpr(n) - dag.copy(id2Exp = dag.idToExp + (i -> exp)) - } - // This cast should not be needed - act(tup.asInstanceOf[HMap[Id, N]#Pair[Any]]).gc - } - } - - // This is only called by ensure - private def addExp[T](node: N[T], exp: Expr[T, N]): (ExpressionDag[N], Id[T]) = { - val nodeId = Id[T](nextId) - (copy(id2Exp = idToExp + (nodeId -> exp), id = nextId + 1), nodeId) - } - - /** - * This finds the Id[T] in the current graph that is equivalent - * to the given N[T] - */ - def find[T](node: N[T]): Option[Id[T]] = nodeToId.getOrElseUpdate(node, { - val partial = new GenPartial[HMap[Id, E]#Pair, Id] { - def apply[T1] = { case (thisId, expr) if node == expr.evaluate(idToExp) => thisId } - } - idToExp.collect(partial).headOption.asInstanceOf[Option[Id[T]]] - }) - - /** - * This throws if the node is missing, use find if this is not - * a logic error in your programming. With dependent types we could - * possibly get this to not compile if it could throw. - */ - def idOf[T](node: N[T]): Id[T] = - find(node) - .getOrElse(sys.error("could not get node: %s\n from %s".format(node, this))) - - /** - * ensure the given literal node is present in the Dag - * Note: it is important that at each moment, each node has - * at most one id in the graph. Put another way, for all - * Id[T] in the graph evaluate(id) is distinct. - */ - protected def ensure[T](node: N[T]): (ExpressionDag[N], Id[T]) = - find(node) match { - case Some(id) => (this, id) - case None => { - val lit: Lit[T] = toLiteral(node) - lit match { - case ConstLit(n) => - /** - * Since the code is not performance critical, but correctness critical, and we can't - * check this property with the typesystem easily, check it here - */ - assert(n == node, - "Equality or nodeToLiteral is incorrect: nodeToLit(%s) = ConstLit(%s)".format(node, n)) - addExp(node, Const(n)) - case UnaryLit(prev, fn) => - val (exp1, idprev) = ensure(prev.evaluate) - exp1.addExp(node, Unary(idprev, fn)) - case BinaryLit(n1, n2, fn) => - val (exp1, id1) = ensure(n1.evaluate) - val (exp2, id2) = exp1.ensure(n2.evaluate) - exp2.addExp(node, Binary(id1, id2, fn)) - } - } - } - - /** - * After applying rules to your Dag, use this method - * to get the original node type. - * Only call this on an Id[T] that was generated by - * this dag or a parent. - */ - def evaluate[T](id: Id[T]): N[T] = - evaluateOption(id).getOrElse(sys.error("Could not evaluate: %s\nin %s".format(id, this))) - - def evaluateOption[T](id: Id[T]): Option[N[T]] = - idToN.getOrElseUpdate(id, { - val partial = new GenPartial[HMap[Id, E]#Pair, N] { - def apply[T1] = { case (thisId, expr) if (id == thisId) => expr.evaluate(idToExp) } - } - idToExp.collect(partial).headOption.asInstanceOf[Option[N[T]]] - }) - - /** - * Return the number of nodes that depend on the - * given Id, TODO we might want to cache these. - * We need to garbage collect nodes that are - * no longer reachable from the root - */ - def fanOut(id: Id[_]): Int = { - // We make a fake IntT[T] which is just Int - val partial = new GenPartial[E, ({ type IntT[T] = Int })#IntT] { - def apply[T] = { - case Var(id1) if (id1 == id) => 1 - case Unary(id1, fn) if (id1 == id) => 1 - case Binary(id1, id2, fn) if (id1 == id) && (id2 == id) => 2 - case Binary(id1, id2, fn) if (id1 == id) || (id2 == id) => 1 - case _ => 0 - } - } - idToExp.collectValues[({ type IntT[T] = Int })#IntT](partial).sum - } - - /** - * Returns 0 if the node is absent, which is true - * use .contains(n) to check for containment - */ - def fanOut(node: N[_]): Int = find(node).map(fanOut(_)).getOrElse(0) - def contains(node: N[_]): Boolean = find(node).isDefined -} - -object ExpressionDag { - private def empty[N[_]](n2l: GenFunction[N, ({ type L[t] = Literal[t, N] })#L]): ExpressionDag[N] = - new ExpressionDag[N] { - val idToExp = HMap.empty[Id, ({ type E[t] = Expr[t, N] })#E] - val nodeToLiteral = n2l - val roots = Set.empty[Id[_]] - val nextId = 0 - } - - /** - * This creates a new ExpressionDag rooted at the given tail node - */ - def apply[T, N[_]](n: N[T], - nodeToLit: GenFunction[N, ({ type L[t] = Literal[t, N] })#L]): (ExpressionDag[N], Id[T]) = { - val (dag, id) = empty(nodeToLit).ensure(n) - (dag.addRoot(id), id) - } - - /** - * This is the most useful function. Given a N[T] and a way to convert to Literal[T, N], - * apply the given rule until it no longer applies, and return the N[T] which is - * equivalent under the given rule - */ - def applyRule[T, N[_]](n: N[T], - nodeToLit: GenFunction[N, ({ type L[t] = Literal[t, N] })#L], - rule: Rule[N]): N[T] = { - val (dag, id) = apply(n, nodeToLit) - dag(rule).evaluate(id) - } -} - -/** - * This implements a simplification rule on ExpressionDags - */ -trait Rule[N[_]] { self => - /** - * If the given Id can be replaced with a simpler expression, - * return Some(expr) else None. - * - * If it is convenient, you might write a partial function - * and then call .lift to get the correct Function type - */ - def apply[T](on: ExpressionDag[N]): (N[T] => Option[N[T]]) - - // If the current rule cannot apply, then try the argument here - def orElse(that: Rule[N]): Rule[N] = new Rule[N] { - def apply[T](on: ExpressionDag[N]) = { n => - self.apply(on)(n).orElse(that.apply(on)(n)) - } - } -} - -/** - * Often a partial function is an easier way to express rules - */ -trait PartialRule[N[_]] extends Rule[N] { - final def apply[T](on: ExpressionDag[N]) = applyWhere[T](on).lift - def applyWhere[T](on: ExpressionDag[N]): PartialFunction[N[T], N[T]] -} - diff --git a/scalding-graph/src/main/scala/com/twitter/scalding/graph/HMap.scala b/scalding-graph/src/main/scala/com/twitter/scalding/graph/HMap.scala deleted file mode 100644 index d5e0f9fe48..0000000000 --- a/scalding-graph/src/main/scala/com/twitter/scalding/graph/HMap.scala +++ /dev/null @@ -1,125 +0,0 @@ -/* - Copyright 2014 Twitter, Inc. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - */ - -package com.twitter.scalding.graph - -/** - * This is a weak heterogenous map. It uses equals on the keys, - * so it is your responsibilty that if k: K[_] == k2: K[_] then - * the types are actually equal (either be careful or store a - * type identifier). - */ -sealed abstract class HMap[K[_], V[_]] { - type Pair[t] = (K[t], V[t]) - protected val map: Map[K[_], V[_]] - override def toString: String = - "H%s".format(map) - - override def equals(that: Any): Boolean = that match { - case null => false - case h: HMap[_, _] => map.equals(h.map) - case _ => false - } - override def hashCode = map.hashCode - - def +[T](kv: (K[T], V[T])): HMap[K, V] = - HMap.from[K, V](map + kv) - - def -(k: K[_]): HMap[K, V] = - HMap.from[K, V](map - k) - - def apply[T](id: K[T]): V[T] = get(id) match { - case Some(v) => v - case None => throw new java.util.NoSuchElementException(s"$id has no value") - } - - def contains[T](id: K[T]): Boolean = get(id).isDefined - - def filter(pred: GenFunction[Pair, ({ type BoolT[T] = Boolean })#BoolT]): HMap[K, V] = { - val filtered = map.asInstanceOf[Map[K[Any], V[Any]]].filter(pred.apply[Any]) - HMap.from[K, V](filtered.asInstanceOf[Map[K[_], V[_]]]) - } - - def get[T](id: K[T]): Option[V[T]] = - map.get(id).asInstanceOf[Option[V[T]]] - - def keysOf[T](v: V[T]): Set[K[T]] = map.collect { - case (k, w) if v == w => - k.asInstanceOf[K[T]] - }.toSet - - // go through all the keys, and find the first key that matches this - // function and apply - def updateFirst(p: GenPartial[K, V]): Option[(HMap[K, V], K[_])] = { - def collector[T]: PartialFunction[(K[T], V[T]), (K[T], V[T])] = { - val pf = p.apply[T] - - { - case (kv: (K[T], V[T])) if pf.isDefinedAt(kv._1) => - val v2 = pf(kv._1) - (kv._1, v2) - } - } - - map.asInstanceOf[Map[K[Any], V[Any]]].collectFirst(collector) - .map { kv => - (this + kv, kv._1) - } - } - - def collect[R[_]](p: GenPartial[Pair, R]): Stream[R[_]] = - map.toStream.asInstanceOf[Stream[(K[Any], V[Any])]].collect(p.apply) - - def collectValues[R[_]](p: GenPartial[V, R]): Stream[R[_]] = - map.values.toStream.asInstanceOf[Stream[V[Any]]].collect(p.apply) -} - -// This is a function that preserves the inner type -trait GenFunction[T[_], R[_]] { - def apply[U]: (T[U] => R[U]) -} - -trait GenPartial[T[_], R[_]] { - def apply[U]: PartialFunction[T[U], R[U]] -} - -object HMap { - def empty[K[_], V[_]]: HMap[K, V] = from[K, V](Map.empty[K[_], V[_]]) - private def from[K[_], V[_]](m: Map[K[_], V[_]]): HMap[K, V] = - new HMap[K, V] { override val map = m } -} - -/** - * This is a useful cache for memoizing heterogenously types functions - */ -class HCache[K[_], V[_]]() { - private var hmap: HMap[K, V] = HMap.empty[K, V] - - /** - * Get snapshot of the current state - */ - def snapshot: HMap[K, V] = hmap - - def getOrElseUpdate[T](k: K[T], v: => V[T]): V[T] = - hmap.get(k) match { - case Some(exists) => exists - case None => - val res = v - hmap = hmap + (k -> res) - res - } -} - diff --git a/scalding-graph/src/main/scala/com/twitter/scalding/graph/package.scala b/scalding-graph/src/main/scala/com/twitter/scalding/graph/package.scala deleted file mode 100644 index 5bb10f71d1..0000000000 --- a/scalding-graph/src/main/scala/com/twitter/scalding/graph/package.scala +++ /dev/null @@ -1,92 +0,0 @@ -/* - Copyright 2013 Twitter, Inc. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - */ - -package com.twitter.scalding - -import scala.collection.mutable.{ Map => MMap } - -/** Collection of graph algorithms */ -package object graph { - type NeighborFn[T] = (T => Iterable[T]) - - /** - * Return the depth first enumeration of reachable nodes, - * NOT INCLUDING INPUT, unless it can be reached via neighbors - */ - def depthFirstOf[T](t: T)(nf: NeighborFn[T]): List[T] = { - @annotation.tailrec - def loop(stack: List[T], deps: List[T], acc: Set[T]): List[T] = { - stack match { - case Nil => deps - case h :: tail => - val newStack = nf(h).filterNot(acc).foldLeft(tail) { (s, it) => it :: s } - val newDeps = if (acc(h)) deps else h :: deps - loop(newStack, newDeps, acc + h) - } - } - val start = nf(t).toList - loop(start, start.distinct, start.toSet).reverse - } - - /** - * Return a NeighborFn for the graph of reversed edges defined by - * this set of nodes and nf - * We avoid Sets which use hash-codes which may depend on addresses - * which are not stable from one run to the next. - */ - def reversed[T](nodes: Iterable[T])(nf: NeighborFn[T]): NeighborFn[T] = { - val graph: Map[T, List[T]] = nodes - .foldLeft(Map.empty[T, List[T]]) { (g, child) => - val gWithChild = g + (child -> g.getOrElse(child, Nil)) - nf(child).foldLeft(gWithChild) { (innerg, parent) => - innerg + (parent -> (child :: innerg.getOrElse(parent, Nil))) - } - } - // make sure the values are sets, not .mapValues is lazy in scala - .map { case (k, v) => (k, v.distinct) }; - graph.getOrElse(_, Nil) - } - - /** - * Return the depth of each node in the dag. - * a node that has no dependencies has depth == 0 - * else it is max of parent + 1 - * - * Behavior is not defined if the graph is not a DAG (for now, it runs forever, may throw later) - */ - def dagDepth[T](nodes: Iterable[T])(nf: NeighborFn[T]): Map[T, Int] = { - val acc = MMap[T, Int]() - @annotation.tailrec - def computeDepth(todo: Set[T]): Unit = - if (!todo.isEmpty) { - def withParents(n: T) = (n :: (nf(n).toList)).filterNot(acc.contains(_)).distinct - - val (doneThisStep, rest) = todo.map { withParents(_) }.partition { _.size == 1 } - - acc ++= (doneThisStep.flatten.map { n => - val depth = nf(n) //n is done now, so all it's neighbors must be too. - .map { acc(_) + 1 } - .reduceOption { _ max _ } - .getOrElse(0) - n -> depth - }) - computeDepth(rest.flatten) - } - - computeDepth(nodes.toSet) - acc.toMap - } -} diff --git a/scalding-graph/src/test/scala/com/twitter/scalding/graph/ExpressionDagTests.scala b/scalding-graph/src/test/scala/com/twitter/scalding/graph/ExpressionDagTests.scala deleted file mode 100644 index 326c0ce117..0000000000 --- a/scalding-graph/src/test/scala/com/twitter/scalding/graph/ExpressionDagTests.scala +++ /dev/null @@ -1,205 +0,0 @@ -/* - Copyright 2014 Twitter, Inc. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - */ - -package com.twitter.scalding.graph - -import org.scalacheck.Prop._ -import org.scalacheck.{ Gen, Properties } - -object ExpressionDagTests extends Properties("ExpressionDag") { - /* - * Here we test with a simple algebra optimizer - */ - - sealed trait Formula[T] { // we actually will ignore T - def evaluate: Int - def closure: Set[Formula[T]] - } - case class Constant[T](override val evaluate: Int) extends Formula[T] { - def closure = Set(this) - } - case class Inc[T](in: Formula[T], by: Int) extends Formula[T] { - def evaluate = in.evaluate + by - def closure = in.closure + this - } - case class Sum[T](left: Formula[T], right: Formula[T]) extends Formula[T] { - def evaluate = left.evaluate + right.evaluate - def closure = (left.closure ++ right.closure) + this - } - case class Product[T](left: Formula[T], right: Formula[T]) extends Formula[T] { - def evaluate = left.evaluate * right.evaluate - def closure = (left.closure ++ right.closure) + this - } - - def genForm: Gen[Formula[Int]] = Gen.frequency((1, genProd), - (1, genSum), - (4, genInc), - (4, genConst)) - - def genConst: Gen[Formula[Int]] = Gen.chooseNum(Int.MinValue, Int.MaxValue).map(Constant(_)) - def genInc: Gen[Formula[Int]] = for { - by <- Gen.chooseNum(Int.MinValue, Int.MaxValue) - f <- Gen.lzy(genForm) - } yield Inc(f, by) - - def genSum: Gen[Formula[Int]] = for { - left <- Gen.lzy(genForm) - // We have to make dags, so select from the closure of left sometimes - right <- Gen.oneOf(genForm, Gen.oneOf(left.closure.toSeq)) - } yield Sum(left, right) - def genProd: Gen[Formula[Int]] = for { - left <- Gen.lzy(genForm) - // We have to make dags, so select from the closure of left sometimes - right <- Gen.oneOf(genForm, Gen.oneOf(left.closure.toSeq)) - } yield Product(left, right) - - type L[T] = Literal[T, Formula] - - /** - * Here we convert our dag nodes into Literal[Formula, T] - */ - def toLiteral = new GenFunction[Formula, L] { - def apply[T] = { (form: Formula[T]) => - def recurse[T2](memo: HMap[Formula, L], f: Formula[T2]): (HMap[Formula, L], L[T2]) = memo.get(f) match { - case Some(l) => (memo, l) - case None => f match { - case c @ Constant(_) => - def makeLit[T1](c: Constant[T1]) = { - val lit: L[T1] = ConstLit(c) - (memo + (c -> lit), lit) - } - makeLit(c) - case inc @ Inc(_, _) => - def makeLit[T1](i: Inc[T1]) = { - val (m1, f1) = recurse(memo, i.in) - val lit = UnaryLit(f1, { f: Formula[T1] => Inc(f, i.by) }) - (m1 + (i -> lit), lit) - } - makeLit(inc) - case sum @ Sum(_, _) => - def makeLit[T1](s: Sum[T1]) = { - val (m1, fl) = recurse(memo, s.left) - val (m2, fr) = recurse(m1, s.right) - val lit = BinaryLit(fl, fr, { (f: Formula[T1], g: Formula[T1]) => Sum(f, g) }) - (m2 + (s -> lit), lit) - } - makeLit(sum) - case prod @ Product(_, _) => - def makeLit[T1](p: Product[T1]) = { - val (m1, fl) = recurse(memo, p.left) - val (m2, fr) = recurse(m1, p.right) - val lit = BinaryLit(fl, fr, { (f: Formula[T1], g: Formula[T1]) => Product(f, g) }) - (m2 + (p -> lit), lit) - } - makeLit(prod) - } - } - recurse(HMap.empty[Formula, L], form)._2 - } - } - - /** - * Inc(Inc(a, b), c) = Inc(a, b + c) - */ - object CombineInc extends Rule[Formula] { - def apply[T](on: ExpressionDag[Formula]) = { - case Inc(i @ Inc(a, b), c) if on.fanOut(i) == 1 => Some(Inc(a, b + c)) - case _ => None - } - } - - object RemoveInc extends PartialRule[Formula] { - def applyWhere[T](on: ExpressionDag[Formula]) = { - case Inc(f, by) => Sum(f, Constant(by)) - } - } - - //Check the Node[T] <=> Id[T] is an Injection for all nodes reachable from the root - - property("toLiteral/Literal.evaluate is a bijection") = forAll(genForm) { form => - toLiteral.apply(form).evaluate == form - } - - property("Going to ExpressionDag round trips") = forAll(genForm) { form => - val (dag, id) = ExpressionDag(form, toLiteral) - dag.evaluate(id) == form - } - - property("CombineInc does not change results") = forAll(genForm) { form => - val simplified = ExpressionDag.applyRule(form, toLiteral, CombineInc) - form.evaluate == simplified.evaluate - } - - property("RemoveInc removes all Inc") = forAll(genForm) { form => - val noIncForm = ExpressionDag.applyRule(form, toLiteral, RemoveInc) - def noInc(f: Formula[Int]): Boolean = f match { - case Constant(_) => true - case Inc(_, _) => false - case Sum(l, r) => noInc(l) && noInc(r) - case Product(l, r) => noInc(l) && noInc(r) - } - noInc(noIncForm) && (noIncForm.evaluate == form.evaluate) - } - - /** - * This law is important for the rules to work as expected, and not have equivalent - * nodes appearing more than once in the Dag - */ - property("Node structural equality implies Id equality") = forAll(genForm) { form => - val (dag, id) = ExpressionDag(form, toLiteral) - type BoolT[T] = Boolean // constant type function - dag.idToExp.collect(new GenPartial[HMap[Id, ExpressionDag[Formula]#E]#Pair, BoolT] { - def apply[T] = { - case (id, expr) => - val node = expr.evaluate(dag.idToExp) - dag.idOf(node) == id - } - }).forall(identity) - } - - // The normal Inc gen recursively calls the general dag Generator - def genChainInc: Gen[Formula[Int]] = for { - by <- Gen.chooseNum(Int.MinValue, Int.MaxValue) - chain <- genChain - } yield Inc(chain, by) - - def genChain: Gen[Formula[Int]] = Gen.frequency((1, genConst), (3, genChainInc)) - property("CombineInc compresses linear Inc chains") = forAll(genChain) { chain => - ExpressionDag.applyRule(chain, toLiteral, CombineInc) match { - case Constant(n) => true - case Inc(Constant(n), b) => true - case _ => false // All others should have been compressed - } - } - - /** - * We should be able to totally evaluate these formulas - */ - object EvaluationRule extends Rule[Formula] { - def apply[T](on: ExpressionDag[Formula]) = { - case Sum(Constant(a), Constant(b)) => Some(Constant(a + b)) - case Product(Constant(a), Constant(b)) => Some(Constant(a * b)) - case Inc(Constant(a), b) => Some(Constant(a + b)) - case _ => None - } - } - property("EvaluationRule totally evaluates") = forAll(genForm) { form => - ExpressionDag.applyRule(form, toLiteral, EvaluationRule) match { - case Constant(x) if x == form.evaluate => true - case _ => false - } - } -} diff --git a/scalding-graph/src/test/scala/com/twitter/scalding/graph/HMapTests.scala b/scalding-graph/src/test/scala/com/twitter/scalding/graph/HMapTests.scala deleted file mode 100644 index a1d8e8da6f..0000000000 --- a/scalding-graph/src/test/scala/com/twitter/scalding/graph/HMapTests.scala +++ /dev/null @@ -1,107 +0,0 @@ -/* - Copyright 2014 Twitter, Inc. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - */ - -package com.twitter.scalding.graph - -import org.scalacheck.Prop._ -import org.scalacheck.{ Arbitrary, Gen, Properties } - -/** - * This tests the HMap. We use the type system to - * prove the types are correct and don't (yet?) engage - * in the problem of higher kinded Arbitraries. - */ -object HMapTests extends Properties("HMap") { - case class Key[T](key: Int) - case class Value[T](value: Int) - - implicit def keyGen: Gen[Key[Int]] = Gen.choose(Int.MinValue, Int.MaxValue).map(Key(_)) - implicit def valGen: Gen[Value[Int]] = Gen.choose(Int.MinValue, Int.MaxValue).map(Value(_)) - - def zip[T, U](g: Gen[T], h: Gen[U]): Gen[(T, U)] = for { - a <- g - b <- h - } yield (a, b) - - implicit def hmapGen: Gen[HMap[Key, Value]] = - Gen.listOf(zip(keyGen, valGen)).map { list => - list.foldLeft(HMap.empty[Key, Value]) { (hm, kv) => - hm + kv - } - } - - implicit def arb[T](implicit g: Gen[T]): Arbitrary[T] = Arbitrary(g) - - property("adding a pair works") = forAll { (hmap: HMap[Key, Value], k: Key[Int], v: Value[Int]) => - val initContains = hmap.contains(k) - val added = hmap + (k -> v) - // Adding puts the item in, and does not change the initial - (added.get(k) == Some(v)) && - (initContains == hmap.contains(k)) && - (initContains == hmap.get(k).isDefined) - } - property("removing a key works") = forAll { (hmap: HMap[Key, Value], k: Key[Int]) => - val initContains = hmap.get(k).isDefined - val next = hmap - k - // Adding puts the item in, and does not change the initial - (!next.contains(k)) && - (initContains == hmap.contains(k)) && - (next.get(k) == None) - } - - property("keysOf works") = forAll { (hmap: HMap[Key, Value], k: Key[Int], v: Value[Int]) => - val initKeys = hmap.keysOf(v) - val added = hmap + (k -> v) - val finalKeys = added.keysOf(v) - val sizeIsConsistent = (finalKeys -- initKeys).size match { - case 0 => hmap.contains(k) // initially present - case 1 => !hmap.contains(k) // initially absent - case _ => false // we can't change the count by more than 1. - } - - sizeIsConsistent && added.contains(k) - } - - property("updateFirst works") = forAll { (hmap: HMap[Key, Value]) => - val partial = new GenPartial[Key, Value] { - def apply[T] = { case Key(id) if (id % 2 == 0) => Value(0) } - } - hmap.updateFirst(partial) match { - case Some((updated, k)) => updated.get(k) == Some(Value(0)) - case None => true - } - } - - property("collect works") = forAll { (map: Map[Key[Int], Value[Int]]) => - val hm = map.foldLeft(HMap.empty[Key, Value])(_ + _) - val partial = new GenPartial[HMap[Key, Value]#Pair, Value] { - def apply[T] = { case (Key(k), Value(v)) if k > v => Value(k * v) } - } - val collected = hm.collect(partial).map { case Value(v) => v }.toSet - val mapCollected = map.collect(partial.apply[Int]).map { case Value(v) => v }.toSet - collected == mapCollected - } - - property("collectValues works") = forAll { (map: Map[Key[Int], Value[Int]]) => - val hm = map.foldLeft(HMap.empty[Key, Value])(_ + _) - val partial = new GenPartial[Value, Value] { - def apply[T] = { case Value(v) if v < 0 => Value(v * v) } - } - val collected = hm.collectValues(partial).map { case Value(v) => v }.toSet - val mapCollected = map.values.collect(partial.apply[Int]).map { case Value(v) => v }.toSet - collected == mapCollected - } -} diff --git a/scalding-graph/src/test/scala/com/twitter/scalding/graph/LiteralTests.scala b/scalding-graph/src/test/scala/com/twitter/scalding/graph/LiteralTests.scala deleted file mode 100644 index c3944fce99..0000000000 --- a/scalding-graph/src/test/scala/com/twitter/scalding/graph/LiteralTests.scala +++ /dev/null @@ -1,68 +0,0 @@ -/* - Copyright 2014 Twitter, Inc. - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - */ - -package com.twitter.scalding.graph - -import org.scalacheck.Prop._ -import org.scalacheck.{ Arbitrary, Gen, Properties } - -object LiteralTests extends Properties("Literal") { - case class Box[T](get: T) - - def transitiveClosure[N[_]](l: Literal[_, N], acc: Set[Literal[_, N]] = Set.empty[Literal[_, N]]): Set[Literal[_, N]] = l match { - case c @ ConstLit(_) => acc + c - case u @ UnaryLit(prev, _) => if (acc(u)) acc else transitiveClosure(prev, acc + u) - case b @ BinaryLit(p1, p2, _) => if (acc(b)) acc else transitiveClosure(p2, transitiveClosure(p1, acc + b)) - } - - def genBox: Gen[Box[Int]] = Gen.chooseNum(Int.MinValue, Int.MaxValue).map(Box(_)) - - def genConst: Gen[Literal[Int, Box]] = genBox.map(ConstLit(_)) - def genUnary: Gen[Literal[Int, Box]] = for { - fn <- Arbitrary.arbitrary[(Int) => (Int)] - bfn = { case Box(b) => Box(fn(b)) }: Box[Int] => Box[Int] - input <- genLiteral - } yield UnaryLit(input, bfn) - - def genBinary: Gen[Literal[Int, Box]] = for { - fn <- Arbitrary.arbitrary[(Int, Int) => (Int)] - bfn = { case (Box(l), Box(r)) => Box(fn(l, r)) }: (Box[Int], Box[Int]) => Box[Int] - left <- genLiteral - // We have to make dags, so select from the closure of left sometimes - right <- Gen.oneOf(genLiteral, genChooseFrom(transitiveClosure[Box](left))) - } yield BinaryLit(left, right, bfn) - - def genChooseFrom[N[_]](s: Set[Literal[_, N]]): Gen[Literal[Int, N]] = - Gen.oneOf(s.toSeq.asInstanceOf[Seq[Literal[Int, N]]]) - - /* - * Create dags. Don't use binary too much as it can create exponentially growing dags - */ - def genLiteral: Gen[Literal[Int, Box]] = Gen.frequency((3, genConst), - (6, genUnary), (1, genBinary)) - - //This evaluates by recursively walking the tree without memoization - //as lit.evaluate should do - def slowEvaluate[T](lit: Literal[T, Box]): Box[T] = lit match { - case ConstLit(n) => n - case UnaryLit(in, fn) => fn(slowEvaluate(in)) - case BinaryLit(a, b, fn) => fn(slowEvaluate(a), slowEvaluate(b)) - } - - property("Literal.evaluate must match simple explanation") = forAll(genLiteral) { (l: Literal[Int, Box]) => - l.evaluate == slowEvaluate(l) - } -} diff --git a/scalding-hadoop-test/src/main/scala/com/twitter/scalding/platform/HadoopPlatformExecutionTest.scala b/scalding-hadoop-test/src/main/scala/com/twitter/scalding/platform/HadoopPlatformExecutionTest.scala index 04a81c11a9..e0254f84e9 100644 --- a/scalding-hadoop-test/src/main/scala/com/twitter/scalding/platform/HadoopPlatformExecutionTest.scala +++ b/scalding-hadoop-test/src/main/scala/com/twitter/scalding/platform/HadoopPlatformExecutionTest.scala @@ -33,7 +33,7 @@ case class HadoopPlatformExecutionTest( override def run(): Unit = { System.setProperty("cascading.update.skip", "true") - val execution = init(cons) + val execution: Execution[Any] = init(cons) cluster.addClassSourceToClassPath(cons.getClass) cluster.addClassSourceToClassPath(execution.getClass) createSources() @@ -45,7 +45,7 @@ case class HadoopPlatformExecutionTest( override def execute(unit: Execution[_]): Unit = unit.waitFor(config, cluster.mode) match { - case Success(s) => s + case Success(_) => () case Failure(e) => throw e } -} \ No newline at end of file +} diff --git a/scalding-hadoop-test/src/main/scala/com/twitter/scalding/platform/HadoopPlatformJobTest.scala b/scalding-hadoop-test/src/main/scala/com/twitter/scalding/platform/HadoopPlatformJobTest.scala index a7e5729ce4..84fad089fd 100644 --- a/scalding-hadoop-test/src/main/scala/com/twitter/scalding/platform/HadoopPlatformJobTest.scala +++ b/scalding-hadoop-test/src/main/scala/com/twitter/scalding/platform/HadoopPlatformJobTest.scala @@ -63,7 +63,7 @@ case class HadoopPlatformJobTest( checkSinks() flowCheckers.foreach { checker => job.completedFlow.collect { - case f: Flow[JobConf] => checker(f) + case f: Flow[JobConf @unchecked] => checker(f) } } } @@ -74,7 +74,7 @@ case class HadoopPlatformJobTest( override final def execute(job: Job): Unit = { job.run() job.clear() - job.next match { + job.next match { // linter:ignore:UseOptionForeachNotPatMatch case Some(nextJob) => execute(nextJob) case None => () } diff --git a/scalding-hadoop-test/src/main/scala/com/twitter/scalding/platform/MakeJar.scala b/scalding-hadoop-test/src/main/scala/com/twitter/scalding/platform/MakeJar.scala index 70e8f59b50..438efb6e33 100644 --- a/scalding-hadoop-test/src/main/scala/com/twitter/scalding/platform/MakeJar.scala +++ b/scalding-hadoop-test/src/main/scala/com/twitter/scalding/platform/MakeJar.scala @@ -66,7 +66,7 @@ object MakeJar { @annotation.tailrec private[this] def getRelativeFileBetween( parent: File, source: File, result: List[String] = List.empty): Option[File] = - Option(source) match { + Option(source) match { // linter:disable:UseOptionFlatMapNotPatMatch // need as is for tailrec case Some(src) => { if (parent == src) { result.foldLeft(None: Option[File]) { (cum, part) => diff --git a/scalding-hadoop-test/src/test/scala/com/twitter/scalding/platform/PlatformTest.scala b/scalding-hadoop-test/src/test/scala/com/twitter/scalding/platform/PlatformTest.scala index dff7cc316a..89d1d54ada 100644 --- a/scalding-hadoop-test/src/test/scala/com/twitter/scalding/platform/PlatformTest.scala +++ b/scalding-hadoop-test/src/test/scala/com/twitter/scalding/platform/PlatformTest.scala @@ -651,7 +651,7 @@ class PlatformTest extends WordSpec with Matchers with HadoopSharedPlatformTest .sink(output2) { _.toSet == outputData.toSet } .inspectCompletedFlow { flow => val steps = flow.getFlowSteps.asScala - steps should have size 3 + assert(steps.size <= 4) } .run() } @@ -738,7 +738,7 @@ class PlatformTest extends WordSpec with Matchers with HadoopSharedPlatformTest } .inspectCompletedFlow { flow => val steps = flow.getFlowSteps.asScala - steps should have size 2 + steps should have size 3 // TODO: this used to be 2, but we seem to be taking 3 steps on this now due to forcing hashJoins to disk // two steps given we auto checkpoint before the merge // user supplied forceToDisk should not add a third step } diff --git a/scalding-hraven/src/main/scala/com/twitter/scalding/hraven/estimation/HRavenHistoryService.scala b/scalding-hraven/src/main/scala/com/twitter/scalding/hraven/estimation/HRavenHistoryService.scala index a599e6c1e2..404975f949 100644 --- a/scalding-hraven/src/main/scala/com/twitter/scalding/hraven/estimation/HRavenHistoryService.scala +++ b/scalding-hraven/src/main/scala/com/twitter/scalding/hraven/estimation/HRavenHistoryService.scala @@ -197,7 +197,7 @@ trait HRavenHistoryService extends HistoryService { override def fetchHistory(info: FlowStrategyInfo, maxHistory: Int): Try[Seq[FlowStepHistory]] = fetchPastJobDetails(info.step, maxHistory).map { history => for { - step <- history + step <- history // linter:disable:MergeMaps keys = FlowStepKeys(step.getJobName, step.getUser, step.getPriority, step.getStatus, step.getVersion, "") // update HRavenHistoryService.TaskDetailFields when consuming additional task fields from hraven below tasks = step.getTasks.asScala.flatMap { taskDetails => @@ -246,4 +246,4 @@ trait HRavenHistoryService extends HistoryService { val counter = counters.getCounter(counterGroupName, counterName) if (counter != null) counter.getValue else 0L } -} \ No newline at end of file +} diff --git a/scalding-jdbc/src/main/scala/com/twitter/scalding/jdbc/DriverColumnDefiner.scala b/scalding-jdbc/src/main/scala/com/twitter/scalding/jdbc/DriverColumnDefiner.scala index 03942b531b..53195275fe 100644 --- a/scalding-jdbc/src/main/scala/com/twitter/scalding/jdbc/DriverColumnDefiner.scala +++ b/scalding-jdbc/src/main/scala/com/twitter/scalding/jdbc/DriverColumnDefiner.scala @@ -34,7 +34,7 @@ trait DriverColumnDefiner[Type <: JdbcType] { sizeOp: Option[Int] = None, defOp: Option[String]) = { val sizeStr = sizeOp.map { "(" + _.toString + ")" }.getOrElse("") - val defStr = defOp.map { " DEFAULT '" + _.toString + "' " }.getOrElse(" ") + val defStr = defOp.map { " DEFAULT '" + _ + "' " }.getOrElse(" ") ColumnDefinition(ColumnName(name), Definition(typeName + sizeStr + defStr + nullable.get)) } diff --git a/scalding-json/src/main/scala/com/twitter/scalding/JsonLine.scala b/scalding-json/src/main/scala/com/twitter/scalding/JsonLine.scala index f50850d647..e9676d399c 100644 --- a/scalding-json/src/main/scala/com/twitter/scalding/JsonLine.scala +++ b/scalding-json/src/main/scala/com/twitter/scalding/JsonLine.scala @@ -56,7 +56,7 @@ case class JsonLine(p: String, fields: Fields = Fields.ALL, case (_, None) => null case (h :: Nil, Some(fs)) => fs.get(h).orNull case (h :: tail, Some(fs)) => fs.get(h).orNull match { - case fs: Map[String, AnyRef] => nestedRetrieval(Option(fs), tail) + case fs: Map[String @unchecked, AnyRef @unchecked] => nestedRetrieval(Option(fs), tail) case _ => null } case (Nil, _) => null diff --git a/scalding-parquet/src/main/scala/com/twitter/scalding/parquet/HasColumnProjection.scala b/scalding-parquet/src/main/scala/com/twitter/scalding/parquet/HasColumnProjection.scala index 14cd243194..3a14111a48 100644 --- a/scalding-parquet/src/main/scala/com/twitter/scalding/parquet/HasColumnProjection.scala +++ b/scalding-parquet/src/main/scala/com/twitter/scalding/parquet/HasColumnProjection.scala @@ -63,5 +63,5 @@ sealed trait ColumnProjectionString { def globStrings: Set[String] def asSemicolonString: String = globStrings.mkString(";") } -case class DeprecatedColumnProjectionString(globStrings: Set[String]) extends ColumnProjectionString -case class StrictColumnProjectionString(globStrings: Set[String]) extends ColumnProjectionString +final case class DeprecatedColumnProjectionString(globStrings: Set[String]) extends ColumnProjectionString +final case class StrictColumnProjectionString(globStrings: Set[String]) extends ColumnProjectionString diff --git a/scalding-parquet/src/main/scala/com/twitter/scalding/parquet/tuple/macros/impl/ParquetReadSupportProvider.scala b/scalding-parquet/src/main/scala/com/twitter/scalding/parquet/tuple/macros/impl/ParquetReadSupportProvider.scala index d4ef737f8c..3c7cf49d97 100644 --- a/scalding-parquet/src/main/scala/com/twitter/scalding/parquet/tuple/macros/impl/ParquetReadSupportProvider.scala +++ b/scalding-parquet/src/main/scala/com/twitter/scalding/parquet/tuple/macros/impl/ParquetReadSupportProvider.scala @@ -106,20 +106,20 @@ class ParquetReadSupportProvider(schemaProvider: ParquetSchemaProvider) { } def matchPrimitiveField(converterType: Type): (Tree, Tree, Tree, Tree) = { - val converterName = newTermName(ctx.fresh(s"fieldConverter")) + val converterName = newTermName(ctx.fresh("fieldConverter")) val innerConverter: Tree = q"new $converterType()" val converter: Tree = fieldConverter(converterName, innerConverter, isPrimitive = true) createFieldMatchResult(converterName, converter) } def matchCaseClassField(groupConverter: Tree): (Tree, Tree, Tree, Tree) = { - val converterName = newTermName(ctx.fresh(s"fieldConverter")) + val converterName = newTermName(ctx.fresh("fieldConverter")) val converter: Tree = fieldConverter(converterName, groupConverter) createFieldMatchResult(converterName, converter) } def matchMapField(K: Type, V: Type, keyConverter: Tree, valueConverter: Tree): (Tree, Tree, Tree, Tree) = { - val converterName = newTermName(ctx.fresh(s"fieldConverter")) + val converterName = newTermName(ctx.fresh("fieldConverter")) val mapConverter = createMapFieldConverter(converterName, K, V, keyConverter, valueConverter) createFieldMatchResult(converterName, mapConverter) } diff --git a/scalding-parquet/src/main/scala/com/twitter/scalding/parquet/tuple/macros/impl/ParquetSchemaProvider.scala b/scalding-parquet/src/main/scala/com/twitter/scalding/parquet/tuple/macros/impl/ParquetSchemaProvider.scala index c2aa1ad1ad..78810dbb89 100644 --- a/scalding-parquet/src/main/scala/com/twitter/scalding/parquet/tuple/macros/impl/ParquetSchemaProvider.scala +++ b/scalding-parquet/src/main/scala/com/twitter/scalding/parquet/tuple/macros/impl/ParquetSchemaProvider.scala @@ -60,7 +60,7 @@ class ParquetSchemaProvider(fieldRenamer: (String => String)) { .declarations .collect { case m: MethodSymbol if m.isCaseAccessor => m } .map { accessorMethod => - val fieldName = accessorMethod.name.toTermName.toString + val fieldName = accessorMethod.name.toString val fieldType = accessorMethod.returnType matchField(fieldType, fieldName, isOption = false) }.toList diff --git a/scalding-parquet/src/main/scala/com/twitter/scalding/parquet/tuple/macros/impl/WriteSupportProvider.scala b/scalding-parquet/src/main/scala/com/twitter/scalding/parquet/tuple/macros/impl/WriteSupportProvider.scala index 6f5e1e6b39..2414907d35 100644 --- a/scalding-parquet/src/main/scala/com/twitter/scalding/parquet/tuple/macros/impl/WriteSupportProvider.scala +++ b/scalding-parquet/src/main/scala/com/twitter/scalding/parquet/tuple/macros/impl/WriteSupportProvider.scala @@ -54,7 +54,7 @@ class WriteSupportProvider(schemaProvider: ParquetSchemaProvider) { case tpe if tpe =:= typeOf[Byte] => writePrimitiveField(q"rc.addInteger($fValue.toInt)") case tpe if tpe.erasure =:= typeOf[Option[Any]] => - val cacheName = newTermName(ctx.fresh(s"optionIndex")) + val cacheName = newTermName(ctx.fresh("optionIndex")) val innerType = tpe.asInstanceOf[TypeRefApi].args.head val (_, subTree) = matchField(idx, innerType, q"$cacheName", groupName) (idx + 1, q"""if($fValue.isDefined) { diff --git a/scalding-parquet/src/main/scala/com/twitter/scalding/parquet/tuple/scheme/TypedParquetTupleScheme.scala b/scalding-parquet/src/main/scala/com/twitter/scalding/parquet/tuple/scheme/TypedParquetTupleScheme.scala index f604dfb421..e125c4603f 100644 --- a/scalding-parquet/src/main/scala/com/twitter/scalding/parquet/tuple/scheme/TypedParquetTupleScheme.scala +++ b/scalding-parquet/src/main/scala/com/twitter/scalding/parquet/tuple/scheme/TypedParquetTupleScheme.scala @@ -19,7 +19,7 @@ import org.apache.parquet.hadoop.mapred.{ Container, DeprecatedParquetOutputForm import org.apache.parquet.hadoop.{ ParquetInputFormat, ParquetOutputFormat } import org.apache.parquet.schema._ -import scala.util.{ Failure, Success } +import scala.util.{ Failure, Success, Try } /** * Parquet tuple materializer permits to create user defined type record from parquet tuple values @@ -58,7 +58,7 @@ class ReadSupportInstanceProxy[T] extends ReadSupport[T] { def getDelegateInstance(conf: Configuration): ReadSupport[T] = { val readSupport = conf.get(ParquetInputOutputFormat.READ_SUPPORT_INSTANCE) require(readSupport != null && !readSupport.isEmpty, "no read support instance is configured") - val readSupportInstance = ParquetInputOutputFormat.injection.invert(readSupport) + val readSupportInstance: Try[Any] = ParquetInputOutputFormat.injection.invert(readSupport) readSupportInstance match { case Success(obj) => obj.asInstanceOf[ReadSupport[T]] @@ -111,7 +111,7 @@ class ParquetOutputFormatFromWriteSupportInstance[T] extends ParquetOutputFormat override def getWriteSupport(conf: Configuration): WriteSupport[T] = { val writeSupport = conf.get(ParquetInputOutputFormat.WRITE_SUPPORT_INSTANCE) require(writeSupport != null && !writeSupport.isEmpty, "no write support instance is configured") - val writeSupportInstance = ParquetInputOutputFormat.injection.invert(writeSupport) + val writeSupportInstance: Try[Any] = ParquetInputOutputFormat.injection.invert(writeSupport) writeSupportInstance match { case Success(obj) => obj.asInstanceOf[WriteSupport[T]] case Failure(e) => throw e diff --git a/scalding-quotation/src/main/scala/com/twitter/scalding/quotation/Liftables.scala b/scalding-quotation/src/main/scala/com/twitter/scalding/quotation/Liftables.scala new file mode 100644 index 0000000000..032b2c3458 --- /dev/null +++ b/scalding-quotation/src/main/scala/com/twitter/scalding/quotation/Liftables.scala @@ -0,0 +1,49 @@ +package com.twitter.scalding.quotation + +import scala.reflect.macros.blackbox.Context + +/** + * These Liftables allows us to lift values into quasiquote trees. + * For example: + * + * def test(v: Source) => q"$v" + * + * uses `sourceLiftable` + */ +trait Liftables { + val c: Context + import c.universe.{ TypeName => _, _ } + + protected implicit val sourceLiftable: Liftable[Source] = Liftable { + case Source(path, line) => q"_root_.com.twitter.scalding.quotation.Source($path, $line)" + } + + protected implicit val projectionsLiftable: Liftable[Projections] = Liftable { + case p => q"_root_.com.twitter.scalding.quotation.Projections(${p.set})" + } + + protected implicit val typeNameLiftable: Liftable[TypeName] = Liftable { + case TypeName(name) => q"_root_.com.twitter.scalding.quotation.TypeName($name)" + } + + protected implicit val accessorLiftable: Liftable[Accessor] = Liftable { + case Accessor(name) => q"_root_.com.twitter.scalding.quotation.Accessor($name)" + } + + protected implicit val quotedLiftable: Liftable[Quoted] = Liftable { + case Quoted(source, call, fa) => q"_root_.com.twitter.scalding.quotation.Quoted($source, $call, $fa)" + } + + protected implicit val projectionLiftable: Liftable[Projection] = Liftable { + case p: Property => q"$p" + case p: TypeReference => q"$p" + } + + protected implicit val propertyLiftable: Liftable[Property] = Liftable { + case Property(path, accessor, tpe) => q"_root_.com.twitter.scalding.quotation.Property($path, $accessor, $tpe)" + } + + protected implicit val typeReferenceLiftable: Liftable[TypeReference] = Liftable { + case TypeReference(name) => q"_root_.com.twitter.scalding.quotation.TypeReference($name)" + } +} diff --git a/scalding-quotation/src/main/scala/com/twitter/scalding/quotation/Projection.scala b/scalding-quotation/src/main/scala/com/twitter/scalding/quotation/Projection.scala new file mode 100644 index 0000000000..fa39a391a4 --- /dev/null +++ b/scalding-quotation/src/main/scala/com/twitter/scalding/quotation/Projection.scala @@ -0,0 +1,160 @@ +package com.twitter.scalding.quotation + +import scala.annotation.tailrec + +case class Accessor(asString: String) extends AnyVal +case class TypeName(asString: String) extends AnyVal + +sealed trait Projection { + def andThen(accessor: Accessor, typeName: TypeName): Projection = + Property(this, accessor, typeName) + + def rootProjection: TypeReference = { + @tailrec def loop(p: Projection): TypeReference = + p match { + case p @ TypeReference(_) => p + case Property(p, _, _) => loop(p) + } + loop(this) + } + + /** + * Given a base projection, returns the projection based on it if applicable. + * + * For instance, given a quoted function + * `val contact = Quoted.function { (c: Contact) => c.contact }` + * and a call + * `(p: Person) => contact(p.name)` + * produces the projection + * `Person.name.contact` + */ + def basedOn(base: Projection): Option[Projection] = + this match { + case TypeReference(tpe) => + base match { + case TypeReference(`tpe`) => Some(base) + case Property(_, _, `tpe`) => Some(base) + case other => None + } + case Property(path, name, tpe) => + path.basedOn(base).map(Property(_, name, tpe)) + } + + /** + * Limits projections to only values of `superClass`. Example: + * + * case class Person(name: String, contact: Contact) extends ThriftObject + * case class Contact(phone: Phone) extends ThriftObject + * case class Phone(number: String) + * + * For the super class `ThriftObject`, it produces the transformations: + * + * Person.contact.phone => Some(Person.contact.phone) + * Person.contact.phone.number => Some(Person.contact.phone) + * Person.name.isEmpty => Some(Person.name) + * Phone.number => None + */ + def bySuperClass(superClass: Class[_]): Option[Projection] = { + + def isSubclass(c: TypeName) = + try + superClass.isAssignableFrom(Class.forName(c.asString)) + catch { + case _: ClassNotFoundException => + false + } + + def loop(p: Projection): Either[Projection, Option[Projection]] = + p match { + case TypeReference(typeName) => + Either.cond(!isSubclass(typeName), None, p) + case p @ Property(path, name, typeName) => + loop(path) match { + case Left(_) => + Either.cond(!isSubclass(typeName), Some(p), p) + case Right(path) => + Right(path) + } + } + + loop(this) match { + case Left(path) => Some(path) + case Right(opt) => opt + } + } +} + +/** + * A reference of a type. If not nested within a `Property`, it means that all fields are used. + */ +final case class TypeReference(typeName: TypeName) extends Projection { + override def toString = typeName.asString.split('.').last +} + +/** + * A projection property (e.g. `Person.name`) + */ +final case class Property(path: Projection, accessor: Accessor, typeName: TypeName) extends Projection { + override def toString = s"$path.${accessor.asString}" +} + +/** + * Utility class to deal with a collection of projections. + */ +final class Projections private (val set: Set[Projection]) extends Serializable { + + /** + * Returns the projections that are based on `typeName` and limits projections + * to only properties that extend from `superClass`. + */ + def of(typeName: TypeName, superClass: Class[_]): Projections = + Projections { + set.filter(_.rootProjection.typeName == typeName) + .flatMap(_.bySuperClass(superClass)) + } + + def basedOn(base: Set[Projection]): Projections = + Projections { + set.flatMap { p => + base.flatMap(p.basedOn) + } + } + + def ++(p: Projections) = + Projections(set ++ p.set) + + override def toString = + s"Projections(${set.mkString(", ")})" + + override def equals(other: Any) = + other match { + case other: Projections => set == other.set + case other => false + } + + override def hashCode = + 31 * set.hashCode +} + +object Projections { + val empty = apply(Set.empty) + + /** + * Creates a normalized projections collection. For instance, + * given two projections `Person.contact` and `Person.contact.phone`, + * creates a collection with only `Person.contact`. + */ + def apply(set: Set[Projection]) = { + @tailrec def isNested(p: Projection): Boolean = + p match { + case Property(path, acessor, property) => + set.contains(path) || isNested(path) + case _ => + false + } + new Projections(set.filter(!isNested(_))) + } + + def flatten(list: Iterable[Projections]): Projections = + list.foldLeft(empty)(_ ++ _) +} \ No newline at end of file diff --git a/scalding-quotation/src/main/scala/com/twitter/scalding/quotation/ProjectionMacro.scala b/scalding-quotation/src/main/scala/com/twitter/scalding/quotation/ProjectionMacro.scala new file mode 100644 index 0000000000..f4529ff2cb --- /dev/null +++ b/scalding-quotation/src/main/scala/com/twitter/scalding/quotation/ProjectionMacro.scala @@ -0,0 +1,118 @@ +package com.twitter.scalding.quotation + +import scala.reflect.macros.blackbox.Context + +trait ProjectionMacro extends TreeOps with Liftables { + val c: Context + import c.universe.{ TypeName => _, _ } + + def projections(params: List[Tree]): Tree = { + + def typeName(t: Tree) = + TypeName(t.symbol.typeSignature.typeSymbol.fullName) + + def accessor(m: TermName) = + Accessor(m.decodedName.toString) + + def typeReference(tpe: Type) = + TypeReference(TypeName(tpe.typeSymbol.fullName)) + + def isFunction(t: Tree) = + Option(t.symbol).map { + _.typeSignature + .erasure + .typeSymbol + .fullName + .contains("scala.Function") + }.getOrElse(false) + + def functionBodyProjections(param: Tree, inputs: List[Tree], body: Tree): List[Tree] = { + + val inputSymbols = inputs.map(_.symbol).toSet + + object ProjectionExtractor { + def unapply(t: Tree): Option[Tree] = + t match { + + case q"$v.$m(..$params)" => unapply(v) + + case q"$v.$m" if t.symbol.isMethod => + + if (inputSymbols.contains(v.symbol)) { + val p = + TypeReference(typeName(v)) + .andThen(accessor(m), typeName(t)) + Some(q"$p") + } else + unapply(v).map { n => + q"$n.andThen(${accessor(m)}, ${typeName(t)})" + } + + case t if inputSymbols.contains(t.symbol) => + Some(q"${TypeReference(typeName(t))}") + + case _ => None + } + } + + def functionCall(func: Tree, params: List[Tree]): Tree = { + val paramProjections = params.flatMap(ProjectionExtractor.unapply) + q""" + $func match { + case f: _root_.com.twitter.scalding.quotation.QuotedFunction => + f.quoted.projections.basedOn($paramProjections.toSet) + case _ => + _root_.com.twitter.scalding.quotation.Projections(Set(..$paramProjections)) + } + """ + } + + collect(body) { + case q"$func.apply[..$t](..$params)" => + functionCall(func, params) + case q"$func(..$params)" if isFunction(func) => + functionCall(func, params) + case t @ ProjectionExtractor(p) => + q"_root_.com.twitter.scalding.quotation.Projections(Set($p))" + } + } + + def functionInstanceProjections(func: Tree): List[Tree] = { + val paramProjections = + func.symbol.typeSignature.typeArgs.dropRight(1) + .map(typeReference) + q""" + $func match { + case f: _root_.com.twitter.scalding.quotation.QuotedFunction => + f.quoted.projections + case _ => + _root_.com.twitter.scalding.quotation.Projections(Set(..$paramProjections)) + } + """ :: Nil + } + + def methodProjections(method: Tree): List[Tree] = { + val paramRefs = + method.symbol.asMethod.paramLists.flatten + .map(param => typeReference(param.typeSignature)) + q"${Projections(paramRefs.toSet)}" :: Nil + } + + val nestedList = + params.flatMap { + case param @ q"(..$inputs) => $body" => + functionBodyProjections(param, inputs, body) + + case func if isFunction(func) => + functionInstanceProjections(func) + + case method if method.symbol != null && method.symbol.isMethod => + methodProjections(method) + + case other => + Nil + } + + q"_root_.com.twitter.scalding.quotation.Projections.flatten($nestedList)" + } +} diff --git a/scalding-quotation/src/main/scala/com/twitter/scalding/quotation/Quoted.scala b/scalding-quotation/src/main/scala/com/twitter/scalding/quotation/Quoted.scala new file mode 100644 index 0000000000..805c174b5f --- /dev/null +++ b/scalding-quotation/src/main/scala/com/twitter/scalding/quotation/Quoted.scala @@ -0,0 +1,32 @@ +package com.twitter.scalding.quotation + +import java.io.File + +/** + * Meta information about a method call. + */ +case class Quoted(position: Source, text: Option[String], projections: Projections) { + override def toString = s"$position ${text.getOrElse("")}" +} + +object Quoted { + import language.experimental.macros + implicit def method: Quoted = macro QuotedMacro.method + + private[scalding] def internal: Quoted = macro QuotedMacro.internal + + def function[T1, U](f: T1 => U): Function1[T1, U] with QuotedFunction = macro QuotedMacro.function + def function[T1, T2, U](f: (T1, T2) => U): Function2[T1, T2, U] with QuotedFunction = macro QuotedMacro.function + def function[T1, T2, T3, U](f: (T1, T2, T3) => U): Function3[T1, T2, T3, U] with QuotedFunction = macro QuotedMacro.function + def function[T1, T2, T3, T4, U](f: (T1, T2, T3, T4) => U): Function4[T1, T2, T3, T4, U] with QuotedFunction = macro QuotedMacro.function + def function[T1, T2, T3, T4, T5, U](f: (T1, T2, T3, T4, T5) => U): Function5[T1, T2, T3, T4, T5, U] with QuotedFunction = macro QuotedMacro.function +} + +case class Source(path: String, line: Int) { + def classFile = path.split(File.separator).last + override def toString = s"$classFile:$line" +} + +trait QuotedFunction { + def quoted: Quoted +} diff --git a/scalding-quotation/src/main/scala/com/twitter/scalding/quotation/QuotedMacro.scala b/scalding-quotation/src/main/scala/com/twitter/scalding/quotation/QuotedMacro.scala new file mode 100644 index 0000000000..04e9e001b6 --- /dev/null +++ b/scalding-quotation/src/main/scala/com/twitter/scalding/quotation/QuotedMacro.scala @@ -0,0 +1,111 @@ +package com.twitter.scalding.quotation + +import language.experimental.macros +import scala.reflect.macros.blackbox.Context +import scala.reflect.internal.util.RangePosition +import scala.reflect.internal.util.OffsetPosition +import scala.reflect.macros.runtime.{ Context => ReflectContext } +import java.io.File + +class QuotedMacro(val c: Context) + extends TreeOps + with TextMacro + with ProjectionMacro + with Liftables { + import c.universe._ + + def internal: Tree = quoted + + def method: Tree = { + rejectScaldingSources + quoted + } + + private def quoted: Tree = + quoted( + c.asInstanceOf[ReflectContext] + .callsiteTyper + .context + .tree + .asInstanceOf[Tree]) + + val QuotedCompanion = q"_root_.com.twitter.scalding.quotation.Quoted" + + private def quoted(tree: Tree): Tree = { + val source = Source(tree.pos.source.path, tree.pos.line) + + find(tree) { t => + t.pos != NoPosition && t.pos.start <= c.enclosingPosition.start + }.flatMap { t => + collect(t) { + + // the start position of vals is wrong, so we workaround + case q"val $name = $body" => quoted(body) + + case q"$m.method" if m.symbol.fullName == classOf[Quoted].getName => + c.abort( + c.enclosingPosition, + "Quoted.method can be invoked only as an implicit parameter") + + case tree @ q"$instance.$method[..$t]" => + q"${Quoted(source, Some(callText(method, t)), Projections.empty)}" + + case tree @ q"$instance.$method[..$t](...$params)" => + q""" + $QuotedCompanion( + $source, + Some(${callText(method, t ++ params.flatten)}), + ${projections(params.flatten)}) + """ + + }.headOption + }.getOrElse { + q"${Quoted(source, None, Projections.empty)}" + } + } + + def function(f: Tree): Tree = { + val source = Source(f.pos.source.path, f.pos.line) + val text = paramsText(TermName("function"), f) + f match { + case q"(..$params) => $body" => + c.untypecheck { + q""" + new ${f.tpe.finalResultType} with ${c.symbolOf[QuotedFunction]} { + override def apply(..$params) = $body + override def quoted = + $QuotedCompanion( + $source, + Some($text), + ${projections(f :: Nil)} + ) + } + """ + } + case _ => + c.abort(f.pos, "Expected a function") + } + } + + private def rejectScaldingSources = { + + def whitelist = + Set("test", "example", "tutorial") + .exists(c.enclosingPosition.source.path.contains) + + def isScalding(sym: Symbol): Boolean = + sym.fullName.startsWith("com.twitter.scalding") || { + sym.owner match { + case NoSymbol => false + case owner => isScalding(owner) + } + } + + if (!whitelist && isScalding(c.internal.enclosingOwner)) + c.abort( + c.enclosingPosition, + "The quotation must happen at the level of the user-facing API. Add an `implicit q: Quoted` to the enclosing method. " + + "If that's not possible and the transformation doesn't introduce projections, use Quoted.internal.") + } +} + diff --git a/scalding-quotation/src/main/scala/com/twitter/scalding/quotation/TextMacro.scala b/scalding-quotation/src/main/scala/com/twitter/scalding/quotation/TextMacro.scala new file mode 100644 index 0000000000..f5538c9969 --- /dev/null +++ b/scalding-quotation/src/main/scala/com/twitter/scalding/quotation/TextMacro.scala @@ -0,0 +1,88 @@ +package com.twitter.scalding.quotation + +import scala.reflect.macros.blackbox.Context + +trait TextMacro { + val c: Context + import c.universe._ + + def callText(method: TermName, params: List[Tree]): String = + params.headOption.map(callText(method, _)).getOrElse(s"$method") + + def callText(method: TermName, firstParam: Tree): String = + s"$method${paramsText(method, firstParam)}" + + /* + * This should be something simple since Scala trees have the start and + * end positions. However, there's a bug that makes the positions unreliable. + * This method uses an ad-hoc parsing to get the text from the source file. + */ + def paramsText(method: TermName, firstParam: Tree): String = { + import c.universe._ + + val fileContent = c.enclosingPosition.source.content.mkString + + /* + * The start position of a tree isn't its actual start. It's necessary + * to find the minimum start of the nested trees, which is reliable. + */ + def start(t: Tree) = { + def loop(t: List[Tree]): List[Position] = + t.map(_.pos) ++ t.flatMap(t => loop(t.children)) + + loop(List(t)).filter(_ != NoPosition).map(_.start).min + } + + /* + * From the first parameter start position, walk back until the method + * call start and return the position immediately after the method name. + */ + val content = { + val reverseMethodName = + method.decodedName.toString.reverse + + def paramsStartPosition(content: String, pos: Int): Int = + if (content.startsWith(reverseMethodName) || content.isEmpty) + pos + else + paramsStartPosition(content.drop(1), pos - 1) + + val firstParamStart = start(firstParam) + + val newStart = + paramsStartPosition( + fileContent.take(firstParamStart).reverse, + firstParamStart) + + fileContent.drop(newStart).toList + } + + val blockDelimiters = + Map( + '(' -> ')', + '{' -> '}', + '[' -> ']') + + /* + * Reads the parameters block. It takes in consideration nested blocks like `map(v => { ... })` + */ + def readParams(chars: List[Char], open: List[Char], acc: List[Char] = Nil): (List[Char], List[Char]) = + chars match { + case Nil => + (acc, Nil) + case head :: tail => + blockDelimiters.get(head) match { + case Some(closing) => + val (block, rest) = readParams(tail, open :+ closing) + readParams(rest, open, acc ++ (head +: block :+ closing)) + case None => + if (head != ' ' && (open.isEmpty || head == open.last)) + (acc, tail) + else + readParams(tail, open, acc :+ head) + } + } + + readParams(content, Nil)._1.mkString + } +} \ No newline at end of file diff --git a/scalding-quotation/src/main/scala/com/twitter/scalding/quotation/TreeOps.scala b/scalding-quotation/src/main/scala/com/twitter/scalding/quotation/TreeOps.scala new file mode 100644 index 0000000000..09c459e502 --- /dev/null +++ b/scalding-quotation/src/main/scala/com/twitter/scalding/quotation/TreeOps.scala @@ -0,0 +1,46 @@ +package com.twitter.scalding.quotation + +import scala.reflect.macros.blackbox.Context + +trait TreeOps { + val c: Context + import c.universe._ + + /** + * Finds the first tree that satisfies the condition. + */ + def find(tree: Tree)(f: Tree => Boolean): Option[Tree] = { + var res: Option[Tree] = None + val t = new Traverser { + override def traverse(t: Tree) = { + if (res.isEmpty) + if (f(t)) + res = Some(t) + else + super.traverse(t) + } + } + t.traverse(tree) + res + } + + /** + * Similar to tree.collect but it doesn't collect the children of a + * collected tree. + */ + def collect[T](tree: Tree)(f: PartialFunction[Tree, T]): List[T] = { + var res = List.newBuilder[T] + val t = new Traverser { + override def traverse(t: Tree) = { + f.lift(t) match { + case Some(v) => + res += v + case None => + super.traverse(t) + } + } + } + t.traverse(tree) + res.result() + } +} \ No newline at end of file diff --git a/scalding-quotation/src/test/scala/com/twitter/scalding/quotation/LimitationsTest.scala b/scalding-quotation/src/test/scala/com/twitter/scalding/quotation/LimitationsTest.scala new file mode 100644 index 0000000000..0320bf8f7f --- /dev/null +++ b/scalding-quotation/src/test/scala/com/twitter/scalding/quotation/LimitationsTest.scala @@ -0,0 +1,25 @@ +package com.twitter.scalding.quotation + +class LimitationsTest extends Test { + + class TestClass { + def function[T, U](f: T => U)(implicit q: Quoted) = (q, f) + } + + val test = new TestClass + + "nested transitive projection" in pendingUntilFixed { + test.function[Person, Option[String]](_.alternativeContact.map(_.phone))._1.projections.set mustEqual + Set(Person.typeReference.andThen(Accessor("alternativeContact"), typeName[Option[Contact]]).andThen(Accessor("phone"), typeName[String])) + } + + "nested quoted function projection" in pendingUntilFixed { + val contactFunction = Quoted.function { + (p: Person) => p.contact + } + val phoneFunction = Quoted.function { + (p: Person) => contactFunction(p).phone + } + phoneFunction.quoted.projections.set mustEqual Set(Person.phoneProjection) + } +} \ No newline at end of file diff --git a/scalding-quotation/src/test/scala/com/twitter/scalding/quotation/Person.scala b/scalding-quotation/src/test/scala/com/twitter/scalding/quotation/Person.scala new file mode 100644 index 0000000000..f578c407ec --- /dev/null +++ b/scalding-quotation/src/test/scala/com/twitter/scalding/quotation/Person.scala @@ -0,0 +1,11 @@ +package com.twitter.scalding.quotation + +case class Contact(phone: String) +case class Person(name: String, contact: Contact, alternativeContact: Option[Contact]) + +object Person { + val typeReference = TypeReference(typeName[Person]) + val nameProjection = typeReference.andThen(Accessor("name"), typeName[String]) + val contactProjection = typeReference.andThen(Accessor("contact"), typeName[Contact]) + val phoneProjection = contactProjection.andThen(Accessor("phone"), typeName[String]) +} \ No newline at end of file diff --git a/scalding-quotation/src/test/scala/com/twitter/scalding/quotation/ProjectionMacroTest.scala b/scalding-quotation/src/test/scala/com/twitter/scalding/quotation/ProjectionMacroTest.scala new file mode 100644 index 0000000000..fab42a2242 --- /dev/null +++ b/scalding-quotation/src/test/scala/com/twitter/scalding/quotation/ProjectionMacroTest.scala @@ -0,0 +1,105 @@ +package com.twitter.scalding.quotation + +import org.scalatest.Matchers +import org.scalatest.WordSpec +import org.scalatest.FreeSpec +import org.scalatest.MustMatchers + +class ProjectionMacroTest extends Test { + + class TestClass { + def function[T, U](f: T => U)(implicit m: Quoted) = (m, f) + def noProjection(i: Int)(implicit m: Quoted) = (m, i) + } + + val test = new TestClass + + "no projection" in { + test.noProjection(42)._1.projections.set mustEqual Set.empty + } + + "method with params isn't considered as projection" in { + test + .function[Person, String](_.name.substring(1))._1 + .projections.set mustEqual Set(Person.nameProjection) + } + + "simple" in { + test.function[Person, String](_.name)._1 + .projections.set mustEqual Set(Person.nameProjection) + } + + "nested" in { + test.function[Person, String](_.contact.phone)._1 + .projections.set mustEqual Set(Person.phoneProjection) + } + + "all properties" in { + test.function[Person, Person](p => p)._1 + .projections.set mustEqual Set(Person.typeReference) + } + + "empty projection" in { + test.function[Person, Int](p => 1)._1 + .projections.set mustEqual Set.empty + } + + "function call" - { + "implicit apply" - { + "non-quoted" in { + val function = (p: Person) => p.name + test.function[Person, String](p => function(p))._1 + .projections.set mustEqual Set(Person.typeReference) + } + "quoted" in { + val function = Quoted.function { + (p: Person) => p.name + } + test.function[Person, String](p => function(p))._1 + .projections.set mustEqual Set(Person.nameProjection) + } + } + "explicit apply" - { + "non-quoted" in { + val function = (p: Person) => p.name + test.function[Person, String](p => function.apply(p))._1 + .projections.set mustEqual Set(Person.typeReference) + } + "quoted" in { + val function = Quoted.function { + (p: Person) => p.name + } + test.function[Person, String](p => function.apply(p))._1 + .projections.set mustEqual Set(Person.nameProjection) + } + } + } + + "function instance" - { + "non-quoted" in { + val function = (p: Person) => p.name + test.function[Person, String](function)._1 + .projections.set mustEqual Set(Person.typeReference) + } + "quoted" in { + val function = Quoted.function { + (p: Person) => p.name + } + test.function[Person, String](function)._1 + .projections.set mustEqual Set(Person.nameProjection) + } + } + + "method call" - { + "in the function body" in { + def method(p: Person) = p.name + test.function[Person, String](p => method(p))._1 + .projections.set mustEqual Set(Person.typeReference) + } + "as function" in { + def method(p: Person) = p.name + test.function[Person, String](method)._1 + .projections.set mustEqual Set(Person.typeReference) + } + } +} diff --git a/scalding-quotation/src/test/scala/com/twitter/scalding/quotation/ProjectionTest.scala b/scalding-quotation/src/test/scala/com/twitter/scalding/quotation/ProjectionTest.scala new file mode 100644 index 0000000000..690da72ffa --- /dev/null +++ b/scalding-quotation/src/test/scala/com/twitter/scalding/quotation/ProjectionTest.scala @@ -0,0 +1,168 @@ +package com.twitter.scalding.quotation + +import org.scalatest.MustMatchers +import org.scalatest.FreeSpec + +trait S + +trait T1 extends S +trait T2 + +trait P1 extends S +trait P2 + +class ProjectionTest extends Test { + + val t1 = TypeReference(typeName[T1]) + val p1 = Property(t1, Accessor("p1"), typeName[P1]) + + val t2 = TypeReference(TypeName(classOf[T2].getName)) + val p2 = Property(t2, Accessor("p2"), typeName[P2]) + + "Projection" - { + "andThen" - { + "TypeReference" in { + t1.andThen(p1.accessor, p1.typeName) mustEqual p1 + } + "Property" in { + p1.andThen(Accessor("p2"), TypeName("p2t")) mustEqual + Property(p1, Accessor("p2"), TypeName("p2t")) + } + } + + "toString" - { + "TypeReference" - { + "simple" in { + t1.toString mustEqual "T1" + } + "ignores package" in { + TypeReference(TypeName("com.twitter.Test1")).toString mustEqual "Test1" + } + } + "Property" in { + p1.toString() mustEqual "T1.p1" + } + } + } + + "Projections" - { + "empty" in { + Projections.empty.set mustEqual Set() + } + "apply" - { + "simple" in { + val set = Set[Projection](p1) + Projections(set).set mustEqual set + } + "paths merge" - { + "simple" in { + val set = Set[Projection](p1, t1) + Projections(set).set mustEqual Set(t1) + } + "nested" in { + val px = p1.andThen(Accessor("x"), TypeName("X")) + val set = Set[Projection](px, t1) + Projections(set).set mustEqual Set(t1) + } + } + } + "flatten" - { + "empty" in { + Projections.flatten(Nil).set mustEqual Set() + } + "non-empty" in { + val list = List( + Projections(Set(p1)), + Projections(Set(p2))) + Projections.flatten(list).set mustEqual Set(p1, p2) + } + "non-empty with merge" in { + val list = List( + Projections(Set(t1)), + Projections(Set(p1))) + Projections.flatten(list).set mustEqual Set(t1) + } + } + + "++" - { + "simple" in { + val p = Projections(Set(p1)) ++ Projections(Set(p2)) + p.set mustEqual Set(p1, p2) + } + "with merge" in { + val list = List( + Projections(Set(p1)), + Projections(Set(t1))) + Projections.flatten(list).set mustEqual Set(t1) + } + } + + "toString" - { + "empty" in { + Projections.empty.toString mustEqual "Projections()" + } + "non-empty" in { + Projections(Set(p1, p2)).toString mustEqual "Projections(T1.p1, T2.p2)" + } + } + + "basedOn" - { + "empty base" in { + Projections(Set(p1, p2)).basedOn(Set.empty) mustEqual + Projections.empty + } + "no match" in { + Projections(Set(p1, p2)).basedOn(Set(TypeReference(TypeName("X")))) mustEqual + Projections.empty + } + "one match" in { + val px1 = Property(TypeReference(TypeName("X")), Accessor("px"), typeName[T1]) + Projections(Set(p1, p2)).basedOn(Set(px1)).set mustEqual + Set(p1.copy(path = px1)) + } + "multiple matches" in { + val px1 = Property(TypeReference(TypeName("X1")), Accessor("px1"), typeName[T1]) + val px2 = Property(TypeReference(TypeName("X1")), Accessor("px2"), typeName[T2]) + Projections(Set(p1, p2)).basedOn(Set(px1, px2)).set mustEqual + Set(p1.copy(path = px1), p2.copy(path = px2)) + } + "partial match" in { + val px1 = Property(TypeReference(TypeName("X1")), Accessor("px1"), typeName[T1]) + val px2 = Property(TypeReference(TypeName("X1")), Accessor("px2"), TypeName("TX")) + Projections(Set(p1, p2)).basedOn(Set(px1, px2)).set mustEqual + Set(p1.copy(path = px1)) + } + } + + "of" - { + "byType" - { + "matches" in { + Projections(Set(t1)).of(t1.typeName, classOf[Any]).set mustEqual + Set(t1) + } + "doesn't match" in { + Projections(Set(t1)).of(TypeName("X"), classOf[Any]).set mustEqual + Set.empty + } + "nested" in { + val px = Property(p1, Accessor("px"), TypeName("PX")) + Projections(Set(px)).of(typeName[T1], classOf[Any]).set mustEqual + Set(px) + } + } + "bySuperClass" - { + "filters only projections of the super class type" in { + val px = p1.andThen(Accessor("px"), typeName[String]) + val py = px.andThen(Accessor("isEmpty"), typeName[Boolean]) + Projections(Set(py)).of(t1.typeName, classOf[S]).set mustEqual Set(px) + } + "ignores if class can't be loaded" in { + val tx = TypeReference(TypeName("TX")) + Projections(Set(tx)).of(tx.typeName, classOf[Any]).set mustEqual + Set.empty + } + } + } + + } +} \ No newline at end of file diff --git a/scalding-quotation/src/test/scala/com/twitter/scalding/quotation/QuotedMacroTest.scala b/scalding-quotation/src/test/scala/com/twitter/scalding/quotation/QuotedMacroTest.scala new file mode 100644 index 0000000000..a6dcac7532 --- /dev/null +++ b/scalding-quotation/src/test/scala/com/twitter/scalding/quotation/QuotedMacroTest.scala @@ -0,0 +1,71 @@ +package com.twitter.scalding.quotation + +import org.scalatest.MustMatchers +import org.scalatest.FreeSpec + +class QuotedMacroTest extends Test { + + val test = new TestClass + + val nullary = test.nullary + val parametrizedNullary = test.parametrizedNullary[Int] + val withParam = test.withParam[Person, String](_.name)._1 + + val quotedFunction = + Quoted.function[Person, Contact](_.contact) + + val nestedQuotedFuction = + Quoted.function[Person, Contact](p => quotedFunction(p)) + + val person = Person("John", Contact("33223"), None) + + class TestClass { + def nullary(implicit q: Quoted) = q + def parametrizedNullary[T](implicit q: Quoted) = q + def withParam[T, U](f: T => U)(implicit q: Quoted) = (q, f) + } + + "quoted method" - { + + "nullary" in { + nullary.position.toString mustEqual "QuotedMacroTest.scala:10" + nullary.projections.set mustEqual Set.empty + nullary.text mustEqual Some("nullary") + } + + "parametrizedNullary" in { + parametrizedNullary.position.toString mustEqual "QuotedMacroTest.scala:11" + parametrizedNullary.projections.set mustEqual Set.empty + parametrizedNullary.text mustEqual Some("parametrizedNullary[Int]") + } + + "withParam" in { + withParam.position.toString mustEqual "QuotedMacroTest.scala:12" + withParam.projections.set mustEqual Set(Person.nameProjection) + withParam.text mustEqual Some("withParam[Person, String](_.name)") + } + } + + "quoted function" - { + "simple" in { + val q = quotedFunction.quoted + q.position.toString mustEqual "QuotedMacroTest.scala:15" + q.projections.set mustEqual Set(Person.contactProjection) + q.text mustEqual Some("[Person, Contact](_.contact)") + + quotedFunction(person) mustEqual person.contact + } + "nested" in { + val q = nestedQuotedFuction.quoted + q.position.toString mustEqual "QuotedMacroTest.scala:18" + q.projections.set mustEqual Set(Person.contactProjection) + q.text mustEqual Some("[Person, Contact](p => quotedFunction(p))") + + nestedQuotedFuction(person) mustEqual person.contact + } + } + + "invalid quoted method call" in { + "Quoted.method" mustNot compile + } +} \ No newline at end of file diff --git a/scalding-quotation/src/test/scala/com/twitter/scalding/quotation/TextMacroTest.scala b/scalding-quotation/src/test/scala/com/twitter/scalding/quotation/TextMacroTest.scala new file mode 100644 index 0000000000..eea5e65f89 --- /dev/null +++ b/scalding-quotation/src/test/scala/com/twitter/scalding/quotation/TextMacroTest.scala @@ -0,0 +1,114 @@ +package com.twitter.scalding.quotation + +import org.scalatest.Matchers +import org.scalatest.WordSpec +import org.scalatest.FreeSpec +import org.scalatest.MustMatchers + +class TextMacroTest extends Test { + + class TestClass { + def nullary(implicit m: Quoted) = m + def parametrizedNullary[T](implicit m: Quoted) = m + def primitiveParam(v: Int)(implicit m: Quoted) = (m, v) + def parametrized[T](v: T)(implicit m: Quoted) = (m, v) + def paramGroups(a: Int, b: Int)(c: Int)(implicit m: Quoted) = (m, a, b, c) + def parametrizedParamGroups[T](a: T, b: Int)(c: T)(implicit m: Quoted) = (m, a, b, c) + def paramGroupsWithFunction(a: Int)(b: Int => Int)(implicit m: Quoted) = (m, a, b) + def function(f: Int => Int)(implicit m: Quoted) = (m, f) + def multipleFunctions[T, U, V](f1: T => U, f2: U => V)(implicit m: Quoted) = (m, f1, f2) + def tupleParam(t: (Int, Int))(implicit m: Quoted) = (m, t) + } + + val test = new TestClass + + "nullary" in { + test.nullary.text mustEqual + Some("nullary") + } + + "parametrizedNullary" - { + "inferred type param" in { + test.parametrizedNullary.text mustEqual + Some("parametrizedNullary") + } + "explicit type param" in { + test.parametrizedNullary[Int].text mustEqual + Some("parametrizedNullary[Int]") + } + } + + "primitiveParam" in { + test.primitiveParam(22)._1.text mustEqual + Some("primitiveParam(22)") + } + + "parametrized" - { + "inferred type param" in { + test.parametrized(42)._1.text mustEqual + Some("parametrized(42)") + } + "explicit type param" in { + test.parametrized[Int](42)._1.text mustEqual + Some("parametrized[Int](42)") + } + } + + "paramGroups" - { + "primitives" in { + test.paramGroups(1, 2)(3)._1.text mustEqual + Some("paramGroups(1, 2)(3)") + } + "parametrized" - { + "explicit type param" in { + test.parametrizedParamGroups[Int](1, 2)(3)._1.text mustEqual + Some("parametrizedParamGroups[Int](1, 2)(3)") + } + "inferred type param" in { + test.parametrizedParamGroups(1, 2)(3)._1.text mustEqual + Some("parametrizedParamGroups(1, 2)(3)") + } + } + "with function" in { + (test.paramGroupsWithFunction(1) { + case 1 => 2 + case _ => 3 + })._1.text mustEqual + Some("""paramGroupsWithFunction(1) { + case 1 => 2 + case _ => 3 + }""") + } + } + + "function" - { + "underscore" in { + test.function(_ + 1)._1.text mustEqual + Some("function(_ + 1)") + } + "pattern matching" in { + test.function { case _ => 4 }._1.text mustEqual Some("function { case _ => 4 }") + } + "curly braces" in { + test.function { _ + 1 }._1.text mustEqual Some("function { _ + 1 }") + } + } + + "complex tree" in { + val c = test.function { + def test = 1 + _ + 1 + } + c._1.text mustEqual + Some( + """function { + def test = 1 + _ + 1 + }""") + } + + "tuple param" in { + test.tupleParam((1, 2))._1.text mustEqual + Some("tupleParam((1, 2))") + } +} \ No newline at end of file diff --git a/scalding-quotation/src/test/scala/com/twitter/scalding/quotation/package.scala b/scalding-quotation/src/test/scala/com/twitter/scalding/quotation/package.scala new file mode 100644 index 0000000000..401b523788 --- /dev/null +++ b/scalding-quotation/src/test/scala/com/twitter/scalding/quotation/package.scala @@ -0,0 +1,9 @@ +package com.twitter.scalding + +import org.scalatest.MustMatchers +import org.scalatest.FreeSpec + +package object quotation { + def typeName[T](implicit ct: reflect.ClassTag[T]) = TypeName(ct.runtimeClass.getName) + trait Test extends FreeSpec with MustMatchers +} diff --git a/scalding-repl/src/main/scala/com/twitter/scalding/ScaldingILoop.scala b/scalding-repl/src/main/scala/com/twitter/scalding/ScaldingILoop.scala index e0b9c0210e..ed6bd81d95 100644 --- a/scalding-repl/src/main/scala/com/twitter/scalding/ScaldingILoop.scala +++ b/scalding-repl/src/main/scala/com/twitter/scalding/ScaldingILoop.scala @@ -114,7 +114,8 @@ class ScaldingILoop(in: Option[BufferedReader], out: JPrintWriter) val cwd = System.getProperty("user.dir") ScaldingILoop.findAllUpPath(cwd)(".scalding_repl").reverse.foreach { f => - s.loadfiles.appendToValue(f.toString) + val fs = s.loadfiles.value + s.loadfiles.value = fs ::: List(f.toString) } case _ => () } diff --git a/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/Boxed.scala b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/Boxed.scala index 1a70fe1916..b386cbd780 100644 --- a/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/Boxed.scala +++ b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/Boxed.scala @@ -833,7 +833,7 @@ object Boxed { private[scalding] def nextCached[K](cacheKey: Option[AnyRef]): (K => Boxed[K], Class[Boxed[K]]) = cacheKey match { case Some(cls) => - val untypedRes = Option(boxedCache.get(cls)) match { + val untypedRes = Option(boxedCache.get(cls)) match { // linter:ignore case Some(r) => r case None => val r = next[Any]() diff --git a/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/JavaStreamEnrichments.scala b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/JavaStreamEnrichments.scala index d3a4bcda99..1405589620 100644 --- a/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/JavaStreamEnrichments.scala +++ b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/JavaStreamEnrichments.scala @@ -245,7 +245,7 @@ object JavaStreamEnrichments { s.write(i) } else { // the linter does not like us repeating ourselves here - s.write(-1) + s.write(-1) // linter:ignore s.write(-1) // linter:ignore writeInt(i) } diff --git a/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/Laws.scala b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/Laws.scala index a173171480..db4f6344bf 100644 --- a/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/Laws.scala +++ b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/Laws.scala @@ -22,6 +22,6 @@ package com.twitter.scalding.serialization sealed trait Law[T] { def name: String } -case class Law1[T](override val name: String, check: T => Boolean) extends Law[T] -case class Law2[T](override val name: String, check: (T, T) => Boolean) extends Law[T] -case class Law3[T](override val name: String, check: (T, T, T) => Boolean) extends Law[T] +final case class Law1[T](override val name: String, check: T => Boolean) extends Law[T] +final case class Law2[T](override val name: String, check: (T, T) => Boolean) extends Law[T] +final case class Law3[T](override val name: String, check: (T, T, T) => Boolean) extends Law[T] diff --git a/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/OrderedSerialization.scala b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/OrderedSerialization.scala index eff563a98b..829e8fd026 100644 --- a/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/OrderedSerialization.scala +++ b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/OrderedSerialization.scala @@ -183,7 +183,7 @@ object OrderedSerialization { Law2("totality", { (a: T, b: T) => (ordb.lteq(a, b) || ordb.lteq(b, a)) }) def allLaws[T: OrderedSerialization]: Iterable[Law[T]] = - Serialization.allLaws ++ List(compareBinaryMatchesCompare[T], + Serialization.allLaws ++ List[Law[T]](compareBinaryMatchesCompare[T], orderingTransitive[T], orderingAntisymmetry[T], orderingTotality[T]) @@ -214,3 +214,17 @@ final case class DeserializingOrderedSerialization[T](serialization: Serializati final override def staticSize = serialization.staticSize final override def dynamicSize(t: T) = serialization.dynamicSize(t) } + +object UnitOrderedSerialization extends OrderedSerialization[Unit] with EquivSerialization[Unit] { + private[this] val same = OrderedSerialization.Equal + private[this] val someZero = Some(0) + + final override def read(i: InputStream) = Serialization.successUnit + final override def write(o: OutputStream, t: Unit) = Serialization.successUnit + final override def hash(t: Unit) = 0 + final override def compare(a: Unit, b: Unit) = 0 + final override def compareBinary(a: InputStream, b: InputStream) = + same + final override def staticSize = someZero + final override def dynamicSize(t: Unit) = someZero +} diff --git a/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/Serialization.scala b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/Serialization.scala index 288df7dbdf..5930260999 100644 --- a/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/Serialization.scala +++ b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/Serialization.scala @@ -164,7 +164,7 @@ object Serialization { }) def allLaws[T: Serialization]: Iterable[Law[T]] = - List(roundTripLaw, + List[Law[T]](roundTripLaw, serializationIsEquivalence, hashCodeImpliesEquality, reflexivity, diff --git a/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/BinaryOrdering.scala b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/BinaryOrdering.scala index 631d99c86b..5afa9d88b7 100644 --- a/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/BinaryOrdering.scala +++ b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/BinaryOrdering.scala @@ -4,10 +4,8 @@ import com.twitter.scalding.serialization.OrderedSerialization import scala.language.experimental.macros -/** - * @author Mansur Ashraf. - */ -object BinaryOrdering { - +trait BinaryOrdering { implicit def ordSer[T]: OrderedSerialization[T] = macro com.twitter.scalding.serialization.macros.impl.OrderedSerializationProviderImpl[T] } + +object BinaryOrdering extends BinaryOrdering diff --git a/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/OrderedBufferableProviderImpl.scala b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/OrderedBufferableProviderImpl.scala index ff083f83aa..f86e6115fc 100644 --- a/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/OrderedBufferableProviderImpl.scala +++ b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/OrderedBufferableProviderImpl.scala @@ -24,11 +24,13 @@ import com.twitter.scalding.serialization.macros.impl.ordered_serialization._ import com.twitter.scalding.serialization.macros.impl.ordered_serialization.providers._ object OrderedSerializationProviderImpl { - def normalizedDispatcher(c: Context)(buildDispatcher: => PartialFunction[c.Type, TreeOrderedBuf[c.type]]): PartialFunction[c.Type, TreeOrderedBuf[c.type]] = { + def normalizedDispatcher(c: Context)( + buildDispatcher: => PartialFunction[c.Type, TreeOrderedBuf[c.type]]): PartialFunction[c.Type, TreeOrderedBuf[c.type]] = { case tpe if !(tpe.normalize == tpe) => buildDispatcher(tpe.normalize) } - def scaldingBasicDispatchers(c: Context)(buildDispatcher: => PartialFunction[c.Type, TreeOrderedBuf[c.type]]): PartialFunction[c.Type, TreeOrderedBuf[c.type]] = { + def scaldingBasicDispatchers(c: Context)( + buildDispatcher: => PartialFunction[c.Type, TreeOrderedBuf[c.type]]): PartialFunction[c.Type, TreeOrderedBuf[c.type]] = { val primitiveDispatcher = PrimitiveOrderedBuf.dispatch(c) val optionDispatcher = OptionOrderedBuf.dispatch(c)(buildDispatcher) @@ -42,7 +44,8 @@ object OrderedSerializationProviderImpl { val byteBufferDispatcher = ByteBufferOrderedBuf.dispatch(c) val sealedTraitDispatcher = SealedTraitOrderedBuf.dispatch(c)(buildDispatcher) - OrderedSerializationProviderImpl.normalizedDispatcher(c)(buildDispatcher) + OrderedSerializationProviderImpl + .normalizedDispatcher(c)(buildDispatcher) .orElse(primitiveDispatcher) .orElse(unitDispatcher) .orElse(optionDispatcher) @@ -63,7 +66,8 @@ object OrderedSerializationProviderImpl { private def outerDispatcher(c: Context): PartialFunction[c.Type, TreeOrderedBuf[c.type]] = { import c.universe._ scaldingBasicDispatchers(c)(OrderedSerializationProviderImpl.innerDispatcher(c)).orElse { - case tpe: Type => c.abort(c.enclosingPosition, s"""Unable to find OrderedSerialization for type ${tpe}""") + case tpe: Type => + c.abort(c.enclosingPosition, s"""Unable to find OrderedSerialization for type ${tpe}""") } } @@ -71,7 +75,14 @@ object OrderedSerializationProviderImpl { // So in essence it never fails to do a lookup private def innerDispatcher(c: Context): PartialFunction[c.Type, TreeOrderedBuf[c.type]] = { import c.universe._ - scaldingBasicDispatchers(c)(OrderedSerializationProviderImpl.innerDispatcher(c)).orElse(fallbackImplicitDispatcher(c)) + val innerF = scaldingBasicDispatchers(c)(OrderedSerializationProviderImpl.innerDispatcher(c)) + + val f: PartialFunction[c.Type, TreeOrderedBuf[c.type]] = { + case tpe if innerF.isDefinedAt(tpe) => + scala.util.Try(innerF(tpe)).getOrElse(fallbackImplicitDispatcher(c)(tpe)) + case tpe => fallbackImplicitDispatcher(c)(tpe) + } + f } def apply[T](c: Context)(implicit T: c.WeakTypeTag[T]): c.Expr[OrderedSerialization[T]] = { diff --git a/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/CompileTimeLengthTypes.scala b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/CompileTimeLengthTypes.scala index 7a2640c603..c677670fc3 100644 --- a/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/CompileTimeLengthTypes.scala +++ b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/CompileTimeLengthTypes.scala @@ -16,7 +16,7 @@ package com.twitter.scalding.serialization.macros.impl.ordered_serialization import scala.language.experimental.macros -import scala.reflect.macros.Context +import scala.reflect.macros.blackbox.Context sealed trait CompileTimeLengthTypes[C <: Context] { val ctx: C @@ -33,7 +33,7 @@ object CompileTimeLengthTypes { } } - trait FastLengthCalculation[C <: Context] extends CompileTimeLengthTypes[C] + sealed trait FastLengthCalculation[C <: Context] extends CompileTimeLengthTypes[C] object MaybeLengthCalculation { def apply(c: Context)(tree: c.Tree): MaybeLengthCalculation[c.type] = @@ -43,7 +43,7 @@ object CompileTimeLengthTypes { } } - trait MaybeLengthCalculation[C <: Context] extends CompileTimeLengthTypes[C] + sealed trait MaybeLengthCalculation[C <: Context] extends CompileTimeLengthTypes[C] object ConstantLengthCalculation { def apply(c: Context)(intArg: Int): ConstantLengthCalculation[c.type] = @@ -57,12 +57,12 @@ object CompileTimeLengthTypes { } } - trait ConstantLengthCalculation[C <: Context] extends CompileTimeLengthTypes[C] { + sealed trait ConstantLengthCalculation[C <: Context] extends CompileTimeLengthTypes[C] { def toInt: Int } object NoLengthCalculationAvailable { - def apply(c: Context): NoLengthCalculationAvailable[c.type] = { + def apply(c: Context): NoLengthCalculationAvailable[c.type] = new NoLengthCalculationAvailable[c.type] { override val ctx: c.type = c override def t = { @@ -70,8 +70,7 @@ object CompileTimeLengthTypes { q"""_root_.scala.sys.error("no length available")""" } } - } } - trait NoLengthCalculationAvailable[C <: Context] extends CompileTimeLengthTypes[C] + sealed trait NoLengthCalculationAvailable[C <: Context] extends CompileTimeLengthTypes[C] } diff --git a/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/ProductLike.scala b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/ProductLike.scala index 36828768b4..b9d58cba7f 100644 --- a/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/ProductLike.scala +++ b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/ProductLike.scala @@ -16,21 +16,23 @@ package com.twitter.scalding.serialization.macros.impl.ordered_serialization import scala.language.experimental.macros -import scala.reflect.macros.Context +import scala.reflect.macros.blackbox.Context import com.twitter.scalding._ object ProductLike { - def compareBinary(c: Context)(inputStreamA: c.TermName, inputStreamB: c.TermName)(elementData: List[(c.universe.Type, c.universe.TermName, TreeOrderedBuf[c.type])]): c.Tree = { + def compareBinary(c: Context)(inputStreamA: c.TermName, inputStreamB: c.TermName)( + elementData: List[(c.universe.Type, c.universe.TermName, TreeOrderedBuf[c.type])]): c.Tree = { import c.universe._ - def freshT(id: String) = newTermName(c.fresh(id)) - - elementData.foldLeft(Option.empty[Tree]) { - case (existingTreeOpt, (tpe, accessorSymbol, tBuf)) => - existingTreeOpt match { - case Some(t) => - val lastCmp = freshT("lastCmp") - Some(q""" + def freshT(id: String) = TermName(c.freshName(id)) + + elementData + .foldLeft(Option.empty[Tree]) { + case (existingTreeOpt, (tpe, accessorSymbol, tBuf)) => + existingTreeOpt match { + case Some(t) => + val lastCmp = freshT("lastCmp") + Some(q""" val $lastCmp = $t if($lastCmp != 0) { $lastCmp @@ -38,15 +40,17 @@ object ProductLike { ${tBuf.compareBinary(inputStreamA, inputStreamB)} } """) - case None => - Some(tBuf.compareBinary(inputStreamA, inputStreamB)) - } - }.getOrElse(q"0") + case None => + Some(tBuf.compareBinary(inputStreamA, inputStreamB)) + } + } + .getOrElse(q"0") } - def hash(c: Context)(element: c.TermName)(elementData: List[(c.universe.Type, c.universe.TermName, TreeOrderedBuf[c.type])]): c.Tree = { + def hash(c: Context)(element: c.TermName)( + elementData: List[(c.universe.Type, c.universe.TermName, TreeOrderedBuf[c.type])]): c.Tree = { import c.universe._ - def freshT(id: String) = newTermName(c.fresh(id)) + def freshT(id: String) = TermName(c.freshName(id)) val currentHash = freshT("last") @@ -55,7 +59,10 @@ object ProductLike { val target = freshT("target") q""" val $target = $element.$accessorSymbol - $currentHash = _root_.com.twitter.scalding.serialization.MurmurHashUtils.mixH1($currentHash, ${tBuf.hash(target)}) + $currentHash = _root_.com.twitter.scalding.serialization.MurmurHashUtils.mixH1($currentHash, ${ + tBuf + .hash(target) + }) """ } @@ -66,9 +73,10 @@ object ProductLike { """ } - def put(c: Context)(inputStream: c.TermName, element: c.TermName)(elementData: List[(c.universe.Type, c.universe.TermName, TreeOrderedBuf[c.type])]): c.Tree = { + def put(c: Context)(inputStream: c.TermName, element: c.TermName)( + elementData: List[(c.universe.Type, c.universe.TermName, TreeOrderedBuf[c.type])]): c.Tree = { import c.universe._ - def freshT(id: String) = newTermName(c.fresh(id)) + def freshT(id: String) = TermName(c.freshName(id)) val innerElement = freshT("innerElement") elementData.foldLeft(q"") { @@ -81,18 +89,32 @@ object ProductLike { } } - def length(c: Context)(element: c.Tree)(elementData: List[(c.universe.Type, c.universe.TermName, TreeOrderedBuf[c.type])]): CompileTimeLengthTypes[c.type] = { + def length(c: Context)(element: c.Tree)( + elementData: List[(c.universe.Type, c.universe.TermName, TreeOrderedBuf[c.type])]): CompileTimeLengthTypes[c.type] = { import c.universe._ import CompileTimeLengthTypes._ val (constSize, dynamicFunctions, maybeLength, noLength) = elementData.foldLeft((0, Vector[c.Tree](), Vector[c.Tree](), 0)) { - case ((constantLength, dynamicLength, maybeLength, noLength), (tpe, accessorSymbol, tBuf)) => - + case ((constantLength, dynamicLength, maybeLength, noLength), + (tpe, accessorSymbol, tBuf)) => tBuf.length(q"$element.$accessorSymbol") match { - case const: ConstantLengthCalculation[_] => (constantLength + const.asInstanceOf[ConstantLengthCalculation[c.type]].toInt, dynamicLength, maybeLength, noLength) - case f: FastLengthCalculation[_] => (constantLength, dynamicLength :+ f.asInstanceOf[FastLengthCalculation[c.type]].t, maybeLength, noLength) - case m: MaybeLengthCalculation[_] => (constantLength, dynamicLength, maybeLength :+ m.asInstanceOf[MaybeLengthCalculation[c.type]].t, noLength) - case _: NoLengthCalculationAvailable[_] => (constantLength, dynamicLength, maybeLength, noLength + 1) + case const: ConstantLengthCalculation[_] => + (constantLength + const.asInstanceOf[ConstantLengthCalculation[c.type]].toInt, + dynamicLength, + maybeLength, + noLength) + case f: FastLengthCalculation[_] => + (constantLength, + dynamicLength :+ f.asInstanceOf[FastLengthCalculation[c.type]].t, + maybeLength, + noLength) + case m: MaybeLengthCalculation[_] => + (constantLength, + dynamicLength, + maybeLength :+ m.asInstanceOf[MaybeLengthCalculation[c.type]].t, + noLength) + case _: NoLengthCalculationAvailable[_] => + (constantLength, dynamicLength, maybeLength, noLength + 1) } } @@ -111,13 +133,18 @@ object ProductLike { FastLengthCalculation(c)(combinedDynamic) } else { - val const = q"_root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.ConstLen" - val dyn = q"_root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.DynamicLen" - val noLen = q"_root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.NoLengthCalculation" + val const = + q"_root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.ConstLen" + val dyn = + q"_root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.DynamicLen" + val noLen = + q"_root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.NoLengthCalculation" // Contains an MaybeLength - val combinedMaybe: Tree = maybeLength.reduce { (hOpt, nxtOpt) => q"""$hOpt + $nxtOpt""" } + val combinedMaybe: Tree = maybeLength.reduce { (hOpt, nxtOpt) => + q"""$hOpt + $nxtOpt""" + } if (dynamicFunctions.nonEmpty || constSize != 0) { - MaybeLengthCalculation(c) (q""" + MaybeLengthCalculation(c)(q""" $combinedMaybe match { case $const(l) => $dyn(l + $combinedDynamic) case $dyn(l) => $dyn(l + $combinedDynamic) @@ -132,26 +159,28 @@ object ProductLike { } } - def compare(c: Context)(elementA: c.TermName, elementB: c.TermName)(elementData: List[(c.universe.Type, c.universe.TermName, TreeOrderedBuf[c.type])]): c.Tree = { + def compare(c: Context)(elementA: c.TermName, elementB: c.TermName)( + elementData: List[(c.universe.Type, c.universe.TermName, TreeOrderedBuf[c.type])]): c.Tree = { import c.universe._ - def freshT(id: String) = newTermName(c.fresh(id)) + def freshT(id: String) = TermName(c.freshName(id)) val innerElementA = freshT("innerElementA") val innerElementB = freshT("innerElementB") - elementData.map { - case (tpe, accessorSymbol, tBuf) => - val curCmp = freshT("curCmp") - val cmpTree = q""" + elementData + .map { + case (tpe, accessorSymbol, tBuf) => + val curCmp = freshT("curCmp") + val cmpTree = q""" val $curCmp: Int = { val $innerElementA = $elementA.$accessorSymbol val $innerElementB = $elementB.$accessorSymbol ${tBuf.compare(innerElementA, innerElementB)} } """ - (cmpTree, curCmp) - } + (cmpTree, curCmp) + } .reverse // go through last to first .foldLeft(None: Option[Tree]) { case (Some(rest), (tree, valname)) => diff --git a/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/SealedTraitLike.scala b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/SealedTraitLike.scala index f4f75e38f3..9565970cda 100644 --- a/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/SealedTraitLike.scala +++ b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/SealedTraitLike.scala @@ -16,7 +16,7 @@ package com.twitter.scalding.serialization.macros.impl.ordered_serialization import scala.language.experimental.macros -import scala.reflect.macros.Context +import scala.reflect.macros.blackbox.Context import com.twitter.scalding._ import com.twitter.scalding.serialization.OrderedSerialization @@ -35,36 +35,37 @@ object SealedTraitLike { */ // This `_.get` could be removed by switching `subData` to a non-empty list type @SuppressWarnings(Array("org.wartremover.warts.OptionPartial")) - def compareBinary(c: Context)(inputStreamA: c.TermName, inputStreamB: c.TermName)(subData: List[(Int, c.Type, TreeOrderedBuf[c.type])]): c.Tree = { + def compareBinary(c: Context)(inputStreamA: c.TermName, inputStreamB: c.TermName)( + subData: List[(Int, c.Type, TreeOrderedBuf[c.type])]): c.Tree = { import c.universe._ - def freshT(id: String) = newTermName(c.fresh(id)) + def freshT(id: String) = TermName(c.freshName(id)) val valueA = freshT("valueA") val valueB = freshT("valueB") val idxCmp = freshT("idxCmp") - val compareSameTypes: Tree = subData.foldLeft(Option.empty[Tree]) { - case (existing, (idx, tpe, tBuf)) => - - val commonCmp: Tree = tBuf.compareBinary(inputStreamA, inputStreamB) + val compareSameTypes: Tree = subData + .foldLeft(Option.empty[Tree]) { + case (existing, (idx, tpe, tBuf)) => + val commonCmp: Tree = tBuf.compareBinary(inputStreamA, inputStreamB) - existing match { - case Some(t) => - Some(q""" + existing match { + case Some(t) => + Some(q""" if($valueA == $idx) { $commonCmp } else { $t } """) - case None => - Some(q""" + case None => + Some(q""" if($valueA == $idx) { $commonCmp } else { sys.error("unreachable code -- this could only be reached by corruption in serialization.") }""") - } - }.get + } + }.get // linter:ignore:wartermover:OptionPartial q""" val $valueA: Int = $inputStreamA.readByte.toInt @@ -81,55 +82,60 @@ object SealedTraitLike { // This `_.get` could be removed by switching `subData` to a non-empty list type @SuppressWarnings(Array("org.wartremover.warts.OptionPartial")) - def hash(c: Context)(element: c.TermName)(subData: List[(Int, c.Type, TreeOrderedBuf[c.type])]): c.Tree = { + def hash(c: Context)(element: c.TermName)( + subData: List[(Int, c.Type, TreeOrderedBuf[c.type])]): c.Tree = { import c.universe._ - def freshT(id: String) = newTermName(c.fresh(id)) + def freshT(id: String) = TermName(c.freshName(id)) - subData.foldLeft(Option.empty[Tree]) { - case (optiExisting, (idx, tpe, tBuf)) => - val innerArg = freshT("innerArg") - val elementHash: Tree = q""" + subData + .foldLeft(Option.empty[Tree]) { + case (optiExisting, (idx, tpe, tBuf)) => + val innerArg = freshT("innerArg") + val elementHash: Tree = q""" val $innerArg: $tpe = $element.asInstanceOf[$tpe] ${tBuf.hash(innerArg)} """ - optiExisting match { - case Some(s) => - Some(q""" + optiExisting match { + case Some(s) => + Some(q""" if($element.isInstanceOf[$tpe]) { $elementHash ^ ${intHash(idx)} } else { $s } """) - case None => - Some(q""" + case None => + Some(q""" if($element.isInstanceOf[$tpe]) { $elementHash ^ ${intHash(idx)} } else { _root_.scala.Int.MaxValue } """) - } - }.get + } + } + .get } // This `_.get` could be removed by switching `subData` to a non-empty list type @SuppressWarnings(Array("org.wartremover.warts.OptionPartial")) - def put(c: Context)(inputStream: c.TermName, element: c.TermName)(subData: List[(Int, c.Type, TreeOrderedBuf[c.type])]): c.Tree = { + def put(c: Context)(inputStream: c.TermName, element: c.TermName)( + subData: List[(Int, c.Type, TreeOrderedBuf[c.type])]): c.Tree = { import c.universe._ - def freshT(id: String) = newTermName(c.fresh(id)) + def freshT(id: String) = TermName(c.freshName(id)) val innerArg = freshT("innerArg") - subData.foldLeft(Option.empty[Tree]) { - case (optiExisting, (idx, tpe, tBuf)) => - val commonPut: Tree = q"""val $innerArg: $tpe = $element.asInstanceOf[$tpe] + subData + .foldLeft(Option.empty[Tree]) { + case (optiExisting, (idx, tpe, tBuf)) => + val commonPut: Tree = q"""val $innerArg: $tpe = $element.asInstanceOf[$tpe] ${tBuf.put(inputStream, innerArg)} """ - optiExisting match { - case Some(s) => - Some(q""" + optiExisting match { + case Some(s) => + Some(q""" if($element.isInstanceOf[$tpe]) { $inputStream.writeByte($idx.toByte) $commonPut @@ -137,43 +143,49 @@ object SealedTraitLike { $s } """) - case None => - Some(q""" + case None => + Some(q""" if($element.isInstanceOf[$tpe]) { $inputStream.writeByte($idx.toByte) $commonPut } """) - } - }.get + } + } + .get } // This `_.get` could be removed by switching `subData` to a non-empty list type - @SuppressWarnings(Array("org.wartremover.warts.OptionPartial")) - def length(c: Context)(element: c.Tree)(subData: List[(Int, c.Type, TreeOrderedBuf[c.type])]): CompileTimeLengthTypes[c.type] = { + @SuppressWarnings(Array("org.wartremover.warts.OptionPartial", "org.wartremover.warts.Return")) + def length(c: Context)(element: c.Tree)( + subData: List[(Int, c.Type, TreeOrderedBuf[c.type])]): CompileTimeLengthTypes[c.type] = { import CompileTimeLengthTypes._ import c.universe._ - def freshT(id: String) = newTermName(c.fresh(id)) - - val prevSizeData = subData.foldLeft(Option.empty[Tree]) { - case (optiTree, (idx, tpe, tBuf)) => - - val baseLenT: Tree = tBuf.length(q"$element.asInstanceOf[$tpe]") match { - case m: MaybeLengthCalculation[_] => - m.asInstanceOf[MaybeLengthCalculation[c.type]].t - - case f: FastLengthCalculation[_] => - q"""_root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.DynamicLen(${f.asInstanceOf[FastLengthCalculation[c.type]].t})""" - - case _: NoLengthCalculationAvailable[_] => - return NoLengthCalculationAvailable(c) - case const: ConstantLengthCalculation[_] => - q"""_root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.ConstLen(${const.toInt})""" - case e => sys.error("unexpected input to union length code of " + e) - } - val tmpPreLen = freshT("tmpPreLen") + def freshT(id: String) = TermName(c.freshName(id)) + + val prevSizeData = subData + .foldLeft(Option.empty[Tree]) { + case (optiTree, (idx, tpe, tBuf)) => + val baseLenT: Tree = tBuf.length(q"$element.asInstanceOf[$tpe]") match { + case m: MaybeLengthCalculation[_] => + m.asInstanceOf[MaybeLengthCalculation[c.type]].t + + case f: FastLengthCalculation[_] => + q"""_root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.DynamicLen(${ + f + .asInstanceOf[FastLengthCalculation[c.type]] + .t + })""" + + case _: NoLengthCalculationAvailable[_] => + return NoLengthCalculationAvailable(c) + case const: ConstantLengthCalculation[_] => + q"""_root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.ConstLen(${const.toInt})""" + case e => sys.error("unexpected input to union length code of " + e) + } + val tmpPreLen = freshT("tmpPreLen") - val lenT = q""" + val lenT = q""" val $tmpPreLen: _root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.MaybeLength = $baseLenT ($tmpPreLen match { @@ -185,59 +197,63 @@ object SealedTraitLike { _root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.NoLengthCalculation }): _root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.MaybeLength """ - optiTree match { - case Some(t) => - Some(q""" + optiTree match { + case Some(t) => + Some(q""" if($element.isInstanceOf[$tpe]) { $lenT } else { $t } """) - case None => - Some(q""" + case None => + Some(q""" if($element.isInstanceOf[$tpe]) { $lenT } else { sys.error("Unreachable code, did not match sealed trait type") }""") - } - }.get + } + } + .get - MaybeLengthCalculation(c) (prevSizeData) + MaybeLengthCalculation(c)(prevSizeData) } // This `_.get` could be removed by switching `subData` to a non-empty list type @SuppressWarnings(Array("org.wartremover.warts.OptionPartial")) - def get(c: Context)(inputStream: c.TermName)(subData: List[(Int, c.Type, TreeOrderedBuf[c.type])]): c.Tree = { + def get(c: Context)(inputStream: c.TermName)( + subData: List[(Int, c.Type, TreeOrderedBuf[c.type])]): c.Tree = { import c.universe._ - def freshT(id: String) = newTermName(c.fresh(id)) + def freshT(id: String) = TermName(c.freshName(id)) val valueA = freshT("valueA") - val expandedOut = subData.foldLeft(Option.empty[Tree]) { - case (existing, (idx, tpe, tBuf)) => - val extract = q"${tBuf.get(inputStream)}" + val expandedOut = subData + .foldLeft(Option.empty[Tree]) { + case (existing, (idx, tpe, tBuf)) => + val extract = q"${tBuf.get(inputStream)}" - existing match { - case Some(t) => - Some(q""" + existing match { + case Some(t) => + Some(q""" if($valueA == $idx) { $extract : $tpe } else { $t } """) - case None => - Some(q""" + case None => + Some(q""" if($valueA == $idx) { $extract } else { sys.error("Did not understand sealed trait with idx: " + $valueA + ", this should only happen in a serialization failure.") } """) - } - }.get + } + } + .get q""" val $valueA: Int = $inputStream.readByte.toInt @@ -247,36 +263,39 @@ object SealedTraitLike { // This `_.get` could be removed by switching `subData` to a non-empty list type @SuppressWarnings(Array("org.wartremover.warts.OptionPartial")) - def compare(c: Context)(cmpType: c.Type, elementA: c.TermName, elementB: c.TermName)(subData: List[(Int, c.Type, TreeOrderedBuf[c.type])]): c.Tree = { + def compare(c: Context)(cmpType: c.Type, elementA: c.TermName, elementB: c.TermName)( + subData: List[(Int, c.Type, TreeOrderedBuf[c.type])]): c.Tree = { import c.universe._ - def freshT(id: String) = newTermName(c.fresh(id)) + def freshT(id: String) = TermName(c.freshName(id)) val arg = freshT("arg") val idxCmp = freshT("idxCmp") val idxA = freshT("idxA") val idxB = freshT("idxB") - val toIdOpt: Tree = subData.foldLeft(Option.empty[Tree]) { - case (existing, (idx, tpe, _)) => - existing match { - case Some(t) => - Some(q""" + val toIdOpt: Tree = subData + .foldLeft(Option.empty[Tree]) { + case (existing, (idx, tpe, _)) => + existing match { + case Some(t) => + Some(q""" if($arg.isInstanceOf[$tpe]) { $idx } else { $t } """) - case None => - Some(q""" + case None => + Some(q""" if($arg.isInstanceOf[$tpe]) { $idx } else { sys.error("This should be unreachable code, failure in serializer or deserializer to reach here.") }""") - } - }.get + } + } + .get val compareSameTypes: Option[Tree] = subData.foldLeft(Option.empty[Tree]) { case (existing, (idx, tpe, tBuf)) => @@ -328,4 +347,3 @@ object SealedTraitLike { compareFn } } - diff --git a/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/TreeOrderedBuf.scala b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/TreeOrderedBuf.scala index b8c89e5f19..024fcbf26c 100644 --- a/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/TreeOrderedBuf.scala +++ b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/TreeOrderedBuf.scala @@ -19,7 +19,7 @@ import com.twitter.scalding._ import com.twitter.scalding.serialization.OrderedSerialization import com.twitter.scalding.serialization.JavaStreamEnrichments import java.io.InputStream -import scala.reflect.macros.Context +import scala.reflect.macros.blackbox.Context import scala.language.experimental.macros import scala.util.control.NonFatal @@ -37,6 +37,7 @@ object CommonCompareBinary { * check if they are byte-for-byte identical, which is a cheap way to avoid doing * potentially complex logic in binary comparators */ + @SuppressWarnings(Array("org.wartremover.warts.Return")) final def earlyEqual(inputStreamA: InputStream, lenA: Int, inputStreamB: InputStream, @@ -59,8 +60,8 @@ object CommonCompareBinary { // yeah, return sucks, but trying to optimize here return false } + else if (a < 0) return JavaStreamEnrichments.eof // a == b, but may be eof - if (a < 0) return JavaStreamEnrichments.eof } // we consumed all the bytes, and they were all equal true @@ -70,7 +71,7 @@ object TreeOrderedBuf { import CompileTimeLengthTypes._ def toOrderedSerialization[T](c: Context)(t: TreeOrderedBuf[c.type])(implicit T: t.ctx.WeakTypeTag[T]): t.ctx.Expr[OrderedSerialization[T]] = { import t.ctx.universe._ - def freshT(id: String) = newTermName(c.fresh(s"fresh_$id")) + def freshT(id: String) = TermName(c.freshName(s"fresh_$id")) val outputLength = freshT("outputLength") val innerLengthFn: Tree = { @@ -213,7 +214,7 @@ object TreeOrderedBuf { val lazyVariables = t.lazyOuterVariables.map { case (n, t) => - val termName = newTermName(n) + val termName = TermName(n) q"""lazy val $termName = $t""" } @@ -258,7 +259,7 @@ object TreeOrderedBuf { } override def hash(passedInObjectToHash: $T): Int = { - ${t.hash(newTermName("passedInObjectToHash"))} + ${t.hash(TermName("passedInObjectToHash"))} } private[this] def failedLengthCalc(): Unit = { @@ -282,8 +283,8 @@ object TreeOrderedBuf { override def read(from: _root_.java.io.InputStream): _root_.scala.util.Try[$T] = { try { - ${discardLength(newTermName("from"))} - _root_.scala.util.Success(${t.get(newTermName("from"))}) + ${discardLength(TermName("from"))} + _root_.scala.util.Success(${t.get(TermName("from"))}) } catch { case _root_.scala.util.control.NonFatal(e) => _root_.scala.util.Failure(e) } @@ -291,7 +292,7 @@ object TreeOrderedBuf { override def write(into: _root_.java.io.OutputStream, e: $T): _root_.scala.util.Try[Unit] = { try { - ${putFnGen(newTermName("into"), newTermName("e"))} + ${putFnGen(TermName("into"), TermName("e"))} _root_.com.twitter.scalding.serialization.Serialization.successUnit } catch { case _root_.scala.util.control.NonFatal(e) => _root_.scala.util.Failure(e) @@ -299,7 +300,7 @@ object TreeOrderedBuf { } override def compare(x: $T, y: $T): Int = { - ${t.compare(newTermName("x"), newTermName("y"))} + ${t.compare(TermName("x"), TermName("y"))} } } """) diff --git a/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/ByteBufferOrderedBuf.scala b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/ByteBufferOrderedBuf.scala index af26712f42..3b9ab67574 100644 --- a/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/ByteBufferOrderedBuf.scala +++ b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/ByteBufferOrderedBuf.scala @@ -16,10 +16,14 @@ package com.twitter.scalding.serialization.macros.impl.ordered_serialization.providers import scala.language.experimental.macros -import scala.reflect.macros.Context +import scala.reflect.macros.blackbox.Context import com.twitter.scalding._ -import com.twitter.scalding.serialization.macros.impl.ordered_serialization.{ CompileTimeLengthTypes, ProductLike, TreeOrderedBuf } +import com.twitter.scalding.serialization.macros.impl.ordered_serialization.{ + CompileTimeLengthTypes, + ProductLike, + TreeOrderedBuf +} import CompileTimeLengthTypes._ import java.nio.ByteBuffer @@ -33,7 +37,7 @@ object ByteBufferOrderedBuf { def apply(c: Context)(outerType: c.Type): TreeOrderedBuf[c.type] = { import c.universe._ - def freshT(id: String) = newTermName(c.fresh(id)) + def freshT(id: String) = TermName(c.freshName(id)) new TreeOrderedBuf[c.type] { override val ctx: c.type = c @@ -96,4 +100,3 @@ object ByteBufferOrderedBuf { } } } - diff --git a/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/CaseClassOrderedBuf.scala b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/CaseClassOrderedBuf.scala index 8b0dcb5bfb..f911ac1f60 100644 --- a/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/CaseClassOrderedBuf.scala +++ b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/CaseClassOrderedBuf.scala @@ -16,33 +16,40 @@ package com.twitter.scalding.serialization.macros.impl.ordered_serialization.providers import scala.language.experimental.macros -import scala.reflect.macros.Context +import scala.reflect.macros.blackbox.Context import com.twitter.scalding._ -import com.twitter.scalding.serialization.macros.impl.ordered_serialization.{ CompileTimeLengthTypes, ProductLike, TreeOrderedBuf } +import com.twitter.scalding.serialization.macros.impl.ordered_serialization.{ + CompileTimeLengthTypes, + ProductLike, + TreeOrderedBuf +} import CompileTimeLengthTypes._ import com.twitter.scalding.serialization.OrderedSerialization +@SuppressWarnings(Array("org.wartremover.warts.MergeMaps")) object CaseClassOrderedBuf { def dispatch(c: Context)(buildDispatcher: => PartialFunction[c.Type, TreeOrderedBuf[c.type]]): PartialFunction[c.Type, TreeOrderedBuf[c.type]] = { case tpe if tpe.typeSymbol.isClass && tpe.typeSymbol.asClass.isCaseClass && !tpe.typeSymbol.asClass.isModuleClass => CaseClassOrderedBuf(c)(buildDispatcher, tpe) } - def apply(c: Context)(buildDispatcher: => PartialFunction[c.Type, TreeOrderedBuf[c.type]], outerType: c.Type): TreeOrderedBuf[c.type] = { + def apply(c: Context)(buildDispatcher: => PartialFunction[c.Type, TreeOrderedBuf[c.type]], + outerType: c.Type): TreeOrderedBuf[c.type] = { import c.universe._ - def freshT(id: String) = newTermName(c.fresh(id)) + def freshT(id: String) = TermName(c.freshName(id)) val dispatcher = buildDispatcher val elementData: List[(c.universe.Type, TermName, TreeOrderedBuf[c.type])] = - outerType - .declarations + outerType.decls .collect { case m: MethodSymbol if m.isCaseAccessor => m } .map { accessorMethod => - val fieldType = accessorMethod.returnType.asSeenFrom(outerType, outerType.typeSymbol.asClass) + val fieldType = + accessorMethod.returnType.asSeenFrom(outerType, outerType.typeSymbol.asClass) val b: TreeOrderedBuf[c.type] = dispatcher(fieldType) - (fieldType, accessorMethod.name.toTermName, b) - }.toList + (fieldType, accessorMethod.name, b) + } + .toList new TreeOrderedBuf[c.type] { override val ctx: c.type = c @@ -50,7 +57,8 @@ object CaseClassOrderedBuf { override def compareBinary(inputStreamA: ctx.TermName, inputStreamB: ctx.TermName) = ProductLike.compareBinary(c)(inputStreamA, inputStreamB)(elementData) - override def hash(element: ctx.TermName): ctx.Tree = ProductLike.hash(c)(element)(elementData) + override def hash(element: ctx.TermName): ctx.Tree = + ProductLike.hash(c)(element)(elementData) override def put(inputStream: ctx.TermName, element: ctx.TermName) = ProductLike.put(c)(inputStream, element)(elementData) @@ -67,6 +75,7 @@ object CaseClassOrderedBuf { """ (builderTree, curR) } + q""" ..${getValProcessor.map(_._1)} ${outerType.typeSymbol.companionSymbol}(..${getValProcessor.map(_._2)}) @@ -83,4 +92,3 @@ object CaseClassOrderedBuf { } } } - diff --git a/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/CaseObjectOrderedBuf.scala b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/CaseObjectOrderedBuf.scala index 39dc0b1f18..bbee543e0d 100644 --- a/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/CaseObjectOrderedBuf.scala +++ b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/CaseObjectOrderedBuf.scala @@ -16,10 +16,14 @@ package com.twitter.scalding.serialization.macros.impl.ordered_serialization.providers import scala.language.experimental.macros -import scala.reflect.macros.Context +import scala.reflect.macros.blackbox.Context import com.twitter.scalding._ -import com.twitter.scalding.serialization.macros.impl.ordered_serialization.{ CompileTimeLengthTypes, ProductLike, TreeOrderedBuf } +import com.twitter.scalding.serialization.macros.impl.ordered_serialization.{ + CompileTimeLengthTypes, + ProductLike, + TreeOrderedBuf +} import CompileTimeLengthTypes._ import com.twitter.scalding.serialization.OrderedSerialization @@ -40,7 +44,8 @@ object CaseObjectOrderedBuf { override def put(inputStream: ctx.TermName, element: ctx.TermName) = q"()" - override def get(inputStream: ctx.TermName): ctx.Tree = q"${outerType.typeSymbol.companionSymbol}" + override def get(inputStream: ctx.TermName): ctx.Tree = + q"${outerType.typeSymbol.companionSymbol}" override def compare(elementA: ctx.TermName, elementB: ctx.TermName): ctx.Tree = q"0" @@ -49,4 +54,3 @@ object CaseObjectOrderedBuf { } } } - diff --git a/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/EitherOrderedBuf.scala b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/EitherOrderedBuf.scala index 0c87848bbd..af564142c9 100644 --- a/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/EitherOrderedBuf.scala +++ b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/EitherOrderedBuf.scala @@ -16,21 +16,27 @@ package com.twitter.scalding.serialization.macros.impl.ordered_serialization.providers import scala.language.experimental.macros -import scala.reflect.macros.Context +import scala.reflect.macros.blackbox.Context import com.twitter.scalding._ -import com.twitter.scalding.serialization.macros.impl.ordered_serialization.{ CompileTimeLengthTypes, ProductLike, TreeOrderedBuf } +import com.twitter.scalding.serialization.macros.impl.ordered_serialization.{ + CompileTimeLengthTypes, + ProductLike, + TreeOrderedBuf +} import CompileTimeLengthTypes._ import com.twitter.scalding.serialization.OrderedSerialization object EitherOrderedBuf { def dispatch(c: Context)(buildDispatcher: => PartialFunction[c.Type, TreeOrderedBuf[c.type]]): PartialFunction[c.Type, TreeOrderedBuf[c.type]] = { - case tpe if tpe.erasure =:= c.universe.typeOf[Either[Any, Any]] => EitherOrderedBuf(c)(buildDispatcher, tpe) + case tpe if tpe.erasure =:= c.universe.typeOf[Either[Any, Any]] => + EitherOrderedBuf(c)(buildDispatcher, tpe) } - def apply(c: Context)(buildDispatcher: => PartialFunction[c.Type, TreeOrderedBuf[c.type]], outerType: c.Type): TreeOrderedBuf[c.type] = { + def apply(c: Context)(buildDispatcher: => PartialFunction[c.Type, TreeOrderedBuf[c.type]], + outerType: c.Type): TreeOrderedBuf[c.type] = { import c.universe._ - def freshT(id: String) = newTermName(c.fresh(id)) + def freshT(id: String) = TermName(c.freshName(id)) val dispatcher = buildDispatcher val leftType = outerType.asInstanceOf[TypeRefApi].args(0) // linter:ignore @@ -132,17 +138,21 @@ object EitherOrderedBuf { new TreeOrderedBuf[c.type] { override val ctx: c.type = c override val tpe = outerType - override def compareBinary(inputStreamA: TermName, inputStreamB: TermName) = genBinaryCompare(inputStreamA, inputStreamB) + override def compareBinary(inputStreamA: TermName, inputStreamB: TermName) = + genBinaryCompare(inputStreamA, inputStreamB) override def hash(element: TermName): ctx.Tree = genHashFn(element) override def put(inputStream: TermName, element: TermName) = genPutFn(inputStream, element) override def get(inputStreamA: TermName): ctx.Tree = genGetFn(inputStreamA) - override def compare(elementA: TermName, elementB: TermName): ctx.Tree = genCompareFn(elementA, elementB) + override def compare(elementA: TermName, elementB: TermName): ctx.Tree = + genCompareFn(elementA, elementB) override val lazyOuterVariables: Map[String, ctx.Tree] = rightBuf.lazyOuterVariables ++ leftBuf.lazyOuterVariables override def length(element: Tree): CompileTimeLengthTypes[c.type] = { - def tree(ctl: CompileTimeLengthTypes[_]): c.Tree = ctl.asInstanceOf[CompileTimeLengthTypes[c.type]].t - val dyn = q"""_root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.DynamicLen""" + def tree(ctl: CompileTimeLengthTypes[_]): c.Tree = + ctl.asInstanceOf[CompileTimeLengthTypes[c.type]].t + val dyn = + q"""_root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.DynamicLen""" (leftBuf.length(q"$element.left.get"), rightBuf.length(q"$element.right.get")) match { case (lconst: ConstantLengthCalculation[_], rconst: ConstantLengthCalculation[_]) if lconst.toInt == rconst.toInt => @@ -177,4 +187,3 @@ object EitherOrderedBuf { } } } - diff --git a/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/ImplicitOrderedBuf.scala b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/ImplicitOrderedBuf.scala index 32e4027135..e84d1586da 100644 --- a/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/ImplicitOrderedBuf.scala +++ b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/ImplicitOrderedBuf.scala @@ -16,7 +16,7 @@ package com.twitter.scalding.serialization.macros.impl.ordered_serialization.providers import scala.language.experimental.macros -import scala.reflect.macros.Context +import scala.reflect.macros.blackbox.Context import com.twitter.scalding._ import com.twitter.scalding.serialization.OrderedSerialization @@ -25,7 +25,7 @@ import com.twitter.scalding.serialization.macros.impl.ordered_serialization._ /* A fall back ordered bufferable to look for the user to have an implicit in scope to satisfy the missing type. This is for the case where its an opaque class to our macros where we can't figure out the fields -*/ + */ object ImplicitOrderedBuf { def dispatch(c: Context): PartialFunction[c.Type, TreeOrderedBuf[c.type]] = { @@ -38,11 +38,11 @@ object ImplicitOrderedBuf { def apply(c: Context)(outerType: c.Type): TreeOrderedBuf[c.type] = { import c.universe._ - def freshT(id: String) = newTermName(c.fresh(id)) + def freshT(id: String) = TermName(c.freshName(id)) val variableID = (outerType.typeSymbol.fullName.hashCode.toLong + Int.MaxValue.toLong).toString val variableNameStr = s"orderedSer_$variableID" - val variableName = newTermName(variableNameStr) + val variableName = TermName(variableNameStr) val implicitInstanciator = q""" implicitly[_root_.com.twitter.scalding.serialization.OrderedSerialization[${outerType}]]""" @@ -80,4 +80,3 @@ object ImplicitOrderedBuf { } } } - diff --git a/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/OptionOrderedBuf.scala b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/OptionOrderedBuf.scala index 7d2c1403b8..353345ff35 100644 --- a/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/OptionOrderedBuf.scala +++ b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/OptionOrderedBuf.scala @@ -16,21 +16,27 @@ package com.twitter.scalding.serialization.macros.impl.ordered_serialization.providers import scala.language.experimental.macros -import scala.reflect.macros.Context +import scala.reflect.macros.blackbox.Context import com.twitter.scalding._ -import com.twitter.scalding.serialization.macros.impl.ordered_serialization.{ CompileTimeLengthTypes, ProductLike, TreeOrderedBuf } +import com.twitter.scalding.serialization.macros.impl.ordered_serialization.{ + CompileTimeLengthTypes, + ProductLike, + TreeOrderedBuf +} import CompileTimeLengthTypes._ import com.twitter.scalding.serialization.OrderedSerialization object OptionOrderedBuf { def dispatch(c: Context)(buildDispatcher: => PartialFunction[c.Type, TreeOrderedBuf[c.type]]): PartialFunction[c.Type, TreeOrderedBuf[c.type]] = { - case tpe if tpe.erasure =:= c.universe.typeOf[Option[Any]] => OptionOrderedBuf(c)(buildDispatcher, tpe) + case tpe if tpe.erasure =:= c.universe.typeOf[Option[Any]] => + OptionOrderedBuf(c)(buildDispatcher, tpe) } - def apply(c: Context)(buildDispatcher: => PartialFunction[c.Type, TreeOrderedBuf[c.type]], outerType: c.Type): TreeOrderedBuf[c.type] = { + def apply(c: Context)(buildDispatcher: => PartialFunction[c.Type, TreeOrderedBuf[c.type]], + outerType: c.Type): TreeOrderedBuf[c.type] = { import c.universe._ - def freshT(id: String) = newTermName(c.fresh(id)) + def freshT(id: String) = TermName(c.freshName(id)) val dispatcher = buildDispatcher val innerType = outerType.asInstanceOf[TypeRefApi].args.head @@ -114,13 +120,15 @@ object OptionOrderedBuf { new TreeOrderedBuf[c.type] { override val ctx: c.type = c override val tpe = outerType - override def compareBinary(inputStreamA: TermName, inputStreamB: TermName) = genBinaryCompare(inputStreamA, inputStreamB) + override def compareBinary(inputStreamA: TermName, inputStreamB: TermName) = + genBinaryCompare(inputStreamA, inputStreamB) override def hash(element: TermName): ctx.Tree = genHashFn(element) override def put(inputStream: TermName, element: TermName) = genPutFn(inputStream, element) override def get(inputStreamA: TermName): ctx.Tree = genGetFn(inputStreamA) - override def compare(elementA: TermName, elementB: TermName): ctx.Tree = genCompareFn(elementA, elementB) + override def compare(elementA: TermName, elementB: TermName): ctx.Tree = + genCompareFn(elementA, elementB) override val lazyOuterVariables: Map[String, ctx.Tree] = innerBuf.lazyOuterVariables - override def length(element: Tree): CompileTimeLengthTypes[c.type] = { + override def length(element: Tree): CompileTimeLengthTypes[c.type] = innerBuf.length(q"$element.get") match { case const: ConstantLengthCalculation[_] => FastLengthCalculation(c)(q""" if($element.isDefined) { 1 + ${const.toInt} } @@ -134,15 +142,14 @@ object OptionOrderedBuf { """) case m: MaybeLengthCalculation[_] => val t = m.asInstanceOf[MaybeLengthCalculation[c.type]].t - val dynlen = q"""_root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.DynamicLen""" + val dynlen = + q"""_root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.DynamicLen""" MaybeLengthCalculation(c)(q""" if ($element.isDefined) { $t + $dynlen(1) } else { $dynlen(1) } """) case _ => NoLengthCalculationAvailable(c) } - } } } } - diff --git a/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/PrimitiveOrderedBuf.scala b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/PrimitiveOrderedBuf.scala index 41b7584763..b7e8618dcb 100644 --- a/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/PrimitiveOrderedBuf.scala +++ b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/PrimitiveOrderedBuf.scala @@ -16,10 +16,14 @@ package com.twitter.scalding.serialization.macros.impl.ordered_serialization.providers import scala.language.experimental.macros -import scala.reflect.macros.Context +import scala.reflect.macros.blackbox.Context import com.twitter.scalding._ -import com.twitter.scalding.serialization.macros.impl.ordered_serialization.{ CompileTimeLengthTypes, ProductLike, TreeOrderedBuf } +import com.twitter.scalding.serialization.macros.impl.ordered_serialization.{ + CompileTimeLengthTypes, + ProductLike, + TreeOrderedBuf +} import CompileTimeLengthTypes._ import java.nio.ByteBuffer import com.twitter.scalding.serialization.OrderedSerialization @@ -60,26 +64,27 @@ object PrimitiveOrderedBuf { PrimitiveOrderedBuf(c)(tpe, "Double", 8, true) } - def apply(c: Context)(outerType: c.Type, + def apply(c: Context)( + outerType: c.Type, javaTypeStr: String, lenInBytes: Int, boxed: Boolean): TreeOrderedBuf[c.type] = { import c.universe._ - val javaType = newTermName(javaTypeStr) + val javaType = TermName(javaTypeStr) - def freshT(id: String) = newTermName(c.fresh(s"fresh_$id")) + def freshT(id: String) = TermName(c.freshName(s"fresh_$id")) - val shortName: String = Map("Integer" -> "Int", "Character" -> "Char") - .getOrElse(javaTypeStr, javaTypeStr) + val shortName: String = + Map("Integer" -> "Int", "Character" -> "Char").getOrElse(javaTypeStr, javaTypeStr) - val bbGetter = newTermName("read" + shortName) - val bbPutter = newTermName("write" + shortName) + val bbGetter = TermName("read" + shortName) + val bbPutter = TermName("write" + shortName) def genBinaryCompare(inputStreamA: TermName, inputStreamB: TermName): Tree = q"""_root_.java.lang.$javaType.compare($inputStreamA.$bbGetter, $inputStreamB.$bbGetter)""" def accessor(e: c.TermName): c.Tree = { - val primitiveAccessor = newTermName(shortName.toLowerCase + "Value") + val primitiveAccessor = TermName(shortName.toLowerCase + "Value") if (boxed) q"$e.$primitiveAccessor" else q"$e" } @@ -91,7 +96,7 @@ object PrimitiveOrderedBuf { genBinaryCompare(inputStreamA, inputStreamB) override def hash(element: ctx.TermName): ctx.Tree = { // This calls out the correctly named item in Hasher - val typeLowerCase = newTermName(javaTypeStr.toLowerCase) + val typeLowerCase = TermName(javaTypeStr.toLowerCase) q"_root_.com.twitter.scalding.serialization.Hasher.$typeLowerCase.hash(${accessor(element)})" } override def put(inputStream: ctx.TermName, element: ctx.TermName) = @@ -113,4 +118,3 @@ object PrimitiveOrderedBuf { } } } - diff --git a/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/ProductOrderedBuf.scala b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/ProductOrderedBuf.scala index d97c7be71b..dae8f490ca 100644 --- a/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/ProductOrderedBuf.scala +++ b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/ProductOrderedBuf.scala @@ -16,18 +16,24 @@ package com.twitter.scalding.serialization.macros.impl.ordered_serialization.providers import scala.language.experimental.macros -import scala.reflect.macros.Context +import scala.reflect.macros.blackbox.Context import com.twitter.scalding._ -import com.twitter.scalding.serialization.macros.impl.ordered_serialization.{ CompileTimeLengthTypes, ProductLike, TreeOrderedBuf } +import com.twitter.scalding.serialization.macros.impl.ordered_serialization.{ + CompileTimeLengthTypes, + ProductLike, + TreeOrderedBuf +} import CompileTimeLengthTypes._ import java.nio.ByteBuffer import com.twitter.scalding.serialization.OrderedSerialization object ProductOrderedBuf { - def dispatch(c: Context)(buildDispatcher: => PartialFunction[c.Type, TreeOrderedBuf[c.type]]): PartialFunction[c.Type, TreeOrderedBuf[c.type]] = { + def dispatch(c: Context)(buildDispatcher: => PartialFunction[c.Type, TreeOrderedBuf[c.type]]) + : PartialFunction[c.Type, TreeOrderedBuf[c.type]] = { import c.universe._ - val validTypes: List[Type] = List(typeOf[Product1[Any]], + val validTypes: List[Type] = List( + typeOf[Product1[Any]], typeOf[Product2[Any, Any]], typeOf[Product3[Any, Any, Any]], typeOf[Product4[Any, Any, Any, Any]], @@ -42,46 +48,173 @@ object ProductOrderedBuf { typeOf[Product13[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]], typeOf[Product14[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]], typeOf[Product15[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]], - typeOf[Product16[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]], - typeOf[Product17[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]], - typeOf[Product18[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]], - typeOf[Product19[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]], - typeOf[Product20[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]], - typeOf[Product21[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]], - typeOf[Product22[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]]) + typeOf[ + Product16[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]], + typeOf[ + Product17[Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any]], + typeOf[ + Product18[Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any]], + typeOf[ + Product19[Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any]], + typeOf[ + Product20[Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any]], + typeOf[ + Product21[Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any]], + typeOf[ + Product22[Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any, + Any]] + ) def validType(curType: Type): Boolean = - validTypes.exists { t => curType <:< t } + validTypes.exists { t => + curType <:< t + } // The `_.get` is safe since it's always preceded by a matching // `_.isDefined` check in `validType` @SuppressWarnings(Array("org.wartremover.warts.OptionPartial")) def symbolFor(subType: Type): Type = { - val superType = validTypes.find{ t => subType.erasure <:< t }.get + val superType = validTypes.find { t => + subType.erasure <:< t + }.get subType.baseType(superType.typeSymbol) } val pf: PartialFunction[c.Type, TreeOrderedBuf[c.type]] = { - case tpe if validType(tpe.erasure) => ProductOrderedBuf(c)(buildDispatcher, tpe, symbolFor(tpe)) + case tpe if validType(tpe.erasure) => + ProductOrderedBuf(c)(buildDispatcher, tpe, symbolFor(tpe)) } pf } - def apply(c: Context)(buildDispatcher: => PartialFunction[c.Type, TreeOrderedBuf[c.type]], originalType: c.Type, outerType: c.Type): TreeOrderedBuf[c.type] = { + def apply(c: Context)(buildDispatcher: => PartialFunction[c.Type, TreeOrderedBuf[c.type]], + originalType: c.Type, + outerType: c.Type): TreeOrderedBuf[c.type] = { import c.universe._ - def freshT(id: String) = newTermName(c.fresh(id)) + def freshT(id: String) = TermName(c.freshName(id)) val dispatcher = buildDispatcher val elementData: List[(c.universe.Type, TermName, TreeOrderedBuf[c.type])] = - outerType - .declarations + outerType.decls .collect { case m: MethodSymbol => m } - .filter(m => m.name.toTermName.toString.startsWith("_")) + .filter(m => m.name.toString.startsWith("_")) .map { accessorMethod => - val fieldType = accessorMethod.returnType.asSeenFrom(outerType, outerType.typeSymbol.asClass) + val fieldType = + accessorMethod.returnType.asSeenFrom(outerType, outerType.typeSymbol.asClass) val b: TreeOrderedBuf[c.type] = dispatcher(fieldType) - (fieldType, accessorMethod.name.toTermName, b) - }.toList + (fieldType, accessorMethod.name, b) + } + .toList new TreeOrderedBuf[c.type] { override val ctx: c.type = c @@ -89,7 +222,8 @@ object ProductOrderedBuf { override def compareBinary(inputStreamA: ctx.TermName, inputStreamB: ctx.TermName) = ProductLike.compareBinary(c)(inputStreamA, inputStreamB)(elementData) - override def hash(element: ctx.TermName): ctx.Tree = ProductLike.hash(c)(element)(elementData) + override def hash(element: ctx.TermName): ctx.Tree = + ProductLike.hash(c)(element)(elementData) override def put(inputStream: ctx.TermName, element: ctx.TermName) = ProductLike.put(c)(inputStream, element)(elementData) @@ -122,4 +256,3 @@ object ProductOrderedBuf { } } } - diff --git a/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/SealedTraitOrderedBuf.scala b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/SealedTraitOrderedBuf.scala index be893aa774..42f589d7d5 100644 --- a/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/SealedTraitOrderedBuf.scala +++ b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/SealedTraitOrderedBuf.scala @@ -18,57 +18,76 @@ package com.twitter.scalding.serialization.macros.impl.ordered_serialization.pro import com.twitter.scalding.serialization.macros.impl.ordered_serialization._ import scala.language.experimental.macros -import scala.reflect.macros.Context +import scala.reflect.macros.blackbox.Context object SealedTraitOrderedBuf { def dispatch(c: Context)(buildDispatcher: => PartialFunction[c.Type, TreeOrderedBuf[c.type]]): PartialFunction[c.Type, TreeOrderedBuf[c.type]] = { import c.universe._ val pf: PartialFunction[c.Type, TreeOrderedBuf[c.type]] = { - case tpe if (tpe.typeSymbol.isClass && (tpe.typeSymbol.asClass.isAbstractClass || tpe.typeSymbol.asClass.isTrait)) => SealedTraitOrderedBuf(c)(buildDispatcher, tpe) + case tpe if (tpe.typeSymbol.isClass && (tpe.typeSymbol.asClass.isAbstractClass || tpe.typeSymbol.asClass.isTrait)) => + SealedTraitOrderedBuf(c)(buildDispatcher, tpe) } pf } - def apply(c: Context)(buildDispatcher: => PartialFunction[c.Type, TreeOrderedBuf[c.type]], outerType: c.Type): TreeOrderedBuf[c.type] = { + def apply(c: Context)(buildDispatcher: => PartialFunction[c.Type, TreeOrderedBuf[c.type]], + outerType: c.Type): TreeOrderedBuf[c.type] = { import c.universe._ - def freshT(id: String) = newTermName(c.fresh(s"$id")) + def freshT(id: String) = TermName(c.freshName(s"$id")) - val knownDirectSubclasses = StableKnownDirectSubclasses(c)(outerType) + val knownDirectSubclasses = outerType.typeSymbol.asClass.knownDirectSubclasses if (knownDirectSubclasses.isEmpty) - c.abort(c.enclosingPosition, s"Unable to access any knownDirectSubclasses for $outerType , a bug in scala 2.10/2.11 makes this unreliable.") + sys.error( + s"Unable to access any knownDirectSubclasses for $outerType , a bug in scala 2.10/2.11 makes this unreliable. -- ${c.enclosingPosition}") + + // 22 is a magic number, so pick it aligning with usual size for case class fields + // could be bumped, but the getLength method may get slow, or fail to compile at some point. + if (knownDirectSubclasses.size > 22) + sys.error( + s"More than 22 subclasses($outerType). This code is inefficient for this and may cause jvm errors. Supply code manually. -- ${c.enclosingPosition}") val subClassesValid = knownDirectSubclasses.forall { sc => scala.util.Try(sc.asType.asClass.isCaseClass).getOrElse(false) } if (!subClassesValid) - c.abort(c.enclosingPosition, "We only support the extension of a sealed trait with case classes.") + sys.error( + s"We only support the extension of a sealed trait with case classes, for type $outerType -- ${c.enclosingPosition}") val dispatcher = buildDispatcher - val subClasses: List[Type] = knownDirectSubclasses.map(_.asType.toType).toList + val subClasses: List[Type] = + knownDirectSubclasses.map(_.asType.toType).toList.sortBy(_.toString) - val subData: List[(Int, Type, TreeOrderedBuf[c.type])] = subClasses.map { t => - (t, dispatcher(t)) - }.zipWithIndex.map{ case ((tpe, tbuf), idx) => (idx, tpe, tbuf) }.toList + val subData: List[(Int, Type, TreeOrderedBuf[c.type])] = subClasses + .map { t => + (t, dispatcher(t)) + } + .zipWithIndex + .map { case ((tpe, tbuf), idx) => (idx, tpe, tbuf) } - require(subData.nonEmpty, "Unable to parse any subtypes for the sealed trait, error. This must be an error.") + require(subData.nonEmpty, + "Unable to parse any subtypes for the sealed trait, error. This must be an error.") new TreeOrderedBuf[c.type] { override val ctx: c.type = c override val tpe = outerType - override def compareBinary(inputStreamA: ctx.TermName, inputStreamB: ctx.TermName) = SealedTraitLike.compareBinary(c)(inputStreamA, inputStreamB)(subData) + override def compareBinary(inputStreamA: ctx.TermName, inputStreamB: ctx.TermName) = + SealedTraitLike.compareBinary(c)(inputStreamA, inputStreamB)(subData) override def hash(element: ctx.TermName): ctx.Tree = SealedTraitLike.hash(c)(element)(subData) - override def put(inputStream: ctx.TermName, element: ctx.TermName) = SealedTraitLike.put(c)(inputStream, element)(subData) - override def get(inputStream: ctx.TermName): ctx.Tree = SealedTraitLike.get(c)(inputStream)(subData) - override def compare(elementA: ctx.TermName, elementB: ctx.TermName): ctx.Tree = SealedTraitLike.compare(c)(outerType, elementA, elementB)(subData) - override def length(element: Tree): CompileTimeLengthTypes[c.type] = SealedTraitLike.length(c)(element)(subData) + override def put(inputStream: ctx.TermName, element: ctx.TermName) = + SealedTraitLike.put(c)(inputStream, element)(subData) + override def get(inputStream: ctx.TermName): ctx.Tree = + SealedTraitLike.get(c)(inputStream)(subData) + override def compare(elementA: ctx.TermName, elementB: ctx.TermName): ctx.Tree = + SealedTraitLike.compare(c)(outerType, elementA, elementB)(subData) + override def length(element: Tree): CompileTimeLengthTypes[c.type] = + SealedTraitLike.length(c)(element)(subData) override val lazyOuterVariables: Map[String, ctx.Tree] = subData.map(_._3.lazyOuterVariables).reduce(_ ++ _) } } } - diff --git a/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/StableKnownDirectSubclasses.scala b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/StableKnownDirectSubclasses.scala index 486c41349d..3c9cd39e62 100644 --- a/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/StableKnownDirectSubclasses.scala +++ b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/StableKnownDirectSubclasses.scala @@ -13,6 +13,6 @@ import scala.reflect.macros.whitebox.Context */ object StableKnownDirectSubclasses { - def apply(c: Context)(tpe: c.Type): List[c.universe.TypeSymbol] = + def apply(c: Context)(tpe: c.Type): List[c.universe.TypeSymbol] = // linter:ignore:UnusedParameter tpe.typeSymbol.asClass.knownDirectSubclasses.map(_.asType).toList.sortBy(_.fullName) } diff --git a/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/StringOrderedBuf.scala b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/StringOrderedBuf.scala index 2c1f6f846c..a85fa5374f 100644 --- a/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/StringOrderedBuf.scala +++ b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/StringOrderedBuf.scala @@ -16,10 +16,14 @@ package com.twitter.scalding.serialization.macros.impl.ordered_serialization.providers import scala.language.experimental.macros -import scala.reflect.macros.Context +import scala.reflect.macros.blackbox.Context import com.twitter.scalding._ -import com.twitter.scalding.serialization.macros.impl.ordered_serialization.{ CompileTimeLengthTypes, ProductLike, TreeOrderedBuf } +import com.twitter.scalding.serialization.macros.impl.ordered_serialization.{ + CompileTimeLengthTypes, + ProductLike, + TreeOrderedBuf +} import CompileTimeLengthTypes._ import java.nio.ByteBuffer import com.twitter.scalding.serialization.OrderedSerialization @@ -32,7 +36,7 @@ object StringOrderedBuf { def apply(c: Context)(outerType: c.Type): TreeOrderedBuf[c.type] = { import c.universe._ - def freshT(id: String) = newTermName(c.fresh(id)) + def freshT(id: String) = TermName(c.freshName(id)) new TreeOrderedBuf[c.type] { override val ctx: c.type = c @@ -51,7 +55,8 @@ object StringOrderedBuf { """ } - override def hash(element: ctx.TermName): ctx.Tree = q"_root_.com.twitter.scalding.serialization.Hasher.string.hash($element)" + override def hash(element: ctx.TermName): ctx.Tree = + q"_root_.com.twitter.scalding.serialization.Hasher.string.hash($element)" override def put(inputStream: ctx.TermName, element: ctx.TermName) = { val bytes = freshT("bytes") @@ -113,7 +118,8 @@ object StringOrderedBuf { q"""$elementA.compareTo($elementB)""" override val lazyOuterVariables: Map[String, ctx.Tree] = Map.empty - override def length(element: Tree): CompileTimeLengthTypes[c.type] = MaybeLengthCalculation(c)(q""" + override def length(element: Tree): CompileTimeLengthTypes[c.type] = + MaybeLengthCalculation(c)(q""" if($element.isEmpty) { _root_.com.twitter.scalding.serialization.macros.impl.ordered_serialization.runtime_helpers.DynamicLen(1) } else { @@ -123,4 +129,3 @@ object StringOrderedBuf { } } } - diff --git a/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/TraversablesOrderedBuf.scala b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/TraversablesOrderedBuf.scala index da3a3f7a44..df856a7e46 100644 --- a/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/TraversablesOrderedBuf.scala +++ b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/TraversablesOrderedBuf.scala @@ -15,16 +15,13 @@ */ package com.twitter.scalding.serialization.macros.impl.ordered_serialization.providers -import scala.language.experimental.macros -import scala.reflect.macros.Context -import java.io.InputStream - -import com.twitter.scalding._ -import com.twitter.scalding.serialization.macros.impl.ordered_serialization.{ CompileTimeLengthTypes, ProductLike, TreeOrderedBuf } +import scala.reflect.macros.blackbox.Context +import com.twitter.scalding.serialization.macros.impl.ordered_serialization.{ + CompileTimeLengthTypes, + ProductLike, + TreeOrderedBuf +} import CompileTimeLengthTypes._ -import com.twitter.scalding.serialization.OrderedSerialization -import scala.reflect.ClassTag - import scala.{ collection => sc } import scala.collection.{ immutable => sci } @@ -38,41 +35,65 @@ case object NotArray extends MaybeArray object TraversablesOrderedBuf { def dispatch(c: Context)(buildDispatcher: => PartialFunction[c.Type, TreeOrderedBuf[c.type]]): PartialFunction[c.Type, TreeOrderedBuf[c.type]] = { - case tpe if tpe.erasure =:= c.universe.typeOf[Iterable[Any]] => TraversablesOrderedBuf(c)(buildDispatcher, tpe, NoSort, NotArray) - case tpe if tpe.erasure =:= c.universe.typeOf[sci.Iterable[Any]] => TraversablesOrderedBuf(c)(buildDispatcher, tpe, NoSort, NotArray) - case tpe if tpe.erasure =:= c.universe.typeOf[List[Any]] => TraversablesOrderedBuf(c)(buildDispatcher, tpe, NoSort, NotArray) - case tpe if tpe.erasure =:= c.universe.typeOf[sci.List[Any]] => TraversablesOrderedBuf(c)(buildDispatcher, tpe, NoSort, NotArray) - case tpe if tpe.erasure =:= c.universe.typeOf[Seq[Any]] => TraversablesOrderedBuf(c)(buildDispatcher, tpe, NoSort, NotArray) - case tpe if tpe.erasure =:= c.universe.typeOf[sc.Seq[Any]] => TraversablesOrderedBuf(c)(buildDispatcher, tpe, NoSort, NotArray) - case tpe if tpe.erasure =:= c.universe.typeOf[sci.Seq[Any]] => TraversablesOrderedBuf(c)(buildDispatcher, tpe, NoSort, NotArray) - case tpe if tpe.erasure =:= c.universe.typeOf[Vector[Any]] => TraversablesOrderedBuf(c)(buildDispatcher, tpe, NoSort, NotArray) - case tpe if tpe.erasure =:= c.universe.typeOf[sci.Vector[Any]] => TraversablesOrderedBuf(c)(buildDispatcher, tpe, NoSort, NotArray) - case tpe if tpe.erasure =:= c.universe.typeOf[IndexedSeq[Any]] => TraversablesOrderedBuf(c)(buildDispatcher, tpe, NoSort, NotArray) - case tpe if tpe.erasure =:= c.universe.typeOf[sci.IndexedSeq[Any]] => TraversablesOrderedBuf(c)(buildDispatcher, tpe, NoSort, NotArray) - case tpe if tpe.erasure =:= c.universe.typeOf[sci.Queue[Any]] => TraversablesOrderedBuf(c)(buildDispatcher, tpe, NoSort, NotArray) + case tpe if tpe.erasure =:= c.universe.typeOf[Iterable[Any]] => + TraversablesOrderedBuf(c)(buildDispatcher, tpe, NoSort, NotArray) + case tpe if tpe.erasure =:= c.universe.typeOf[sci.Iterable[Any]] => + TraversablesOrderedBuf(c)(buildDispatcher, tpe, NoSort, NotArray) + case tpe if tpe.erasure =:= c.universe.typeOf[List[Any]] => + TraversablesOrderedBuf(c)(buildDispatcher, tpe, NoSort, NotArray) + case tpe if tpe.erasure =:= c.universe.typeOf[sci.List[Any]] => + TraversablesOrderedBuf(c)(buildDispatcher, tpe, NoSort, NotArray) + case tpe if tpe.erasure =:= c.universe.typeOf[Seq[Any]] => + TraversablesOrderedBuf(c)(buildDispatcher, tpe, NoSort, NotArray) + case tpe if tpe.erasure =:= c.universe.typeOf[sc.Seq[Any]] => + TraversablesOrderedBuf(c)(buildDispatcher, tpe, NoSort, NotArray) + case tpe if tpe.erasure =:= c.universe.typeOf[sci.Seq[Any]] => + TraversablesOrderedBuf(c)(buildDispatcher, tpe, NoSort, NotArray) + case tpe if tpe.erasure =:= c.universe.typeOf[Vector[Any]] => + TraversablesOrderedBuf(c)(buildDispatcher, tpe, NoSort, NotArray) + case tpe if tpe.erasure =:= c.universe.typeOf[sci.Vector[Any]] => + TraversablesOrderedBuf(c)(buildDispatcher, tpe, NoSort, NotArray) + case tpe if tpe.erasure =:= c.universe.typeOf[IndexedSeq[Any]] => + TraversablesOrderedBuf(c)(buildDispatcher, tpe, NoSort, NotArray) + case tpe if tpe.erasure =:= c.universe.typeOf[sci.IndexedSeq[Any]] => + TraversablesOrderedBuf(c)(buildDispatcher, tpe, NoSort, NotArray) + case tpe if tpe.erasure =:= c.universe.typeOf[sci.Queue[Any]] => + TraversablesOrderedBuf(c)(buildDispatcher, tpe, NoSort, NotArray) // Arrays are special in that the erasure doesn't do anything - case tpe if tpe.typeSymbol == c.universe.typeOf[Array[Any]].typeSymbol => TraversablesOrderedBuf(c)(buildDispatcher, tpe, NoSort, IsArray) + case tpe if tpe.typeSymbol == c.universe.typeOf[Array[Any]].typeSymbol => + TraversablesOrderedBuf(c)(buildDispatcher, tpe, NoSort, IsArray) // The erasure of a non-covariant is Set[_], so we need that here for sets - case tpe if tpe.erasure =:= c.universe.typeOf[Set[Any]].erasure => TraversablesOrderedBuf(c)(buildDispatcher, tpe, DoSort, NotArray) - case tpe if tpe.erasure =:= c.universe.typeOf[sc.Set[Any]].erasure => TraversablesOrderedBuf(c)(buildDispatcher, tpe, DoSort, NotArray) - case tpe if tpe.erasure =:= c.universe.typeOf[sci.Set[Any]].erasure => TraversablesOrderedBuf(c)(buildDispatcher, tpe, DoSort, NotArray) - case tpe if tpe.erasure =:= c.universe.typeOf[sci.HashSet[Any]].erasure => TraversablesOrderedBuf(c)(buildDispatcher, tpe, DoSort, NotArray) - case tpe if tpe.erasure =:= c.universe.typeOf[sci.ListSet[Any]].erasure => TraversablesOrderedBuf(c)(buildDispatcher, tpe, DoSort, NotArray) - - case tpe if tpe.erasure =:= c.universe.typeOf[Map[Any, Any]].erasure => TraversablesOrderedBuf(c)(buildDispatcher, tpe, DoSort, NotArray) - case tpe if tpe.erasure =:= c.universe.typeOf[sc.Map[Any, Any]].erasure => TraversablesOrderedBuf(c)(buildDispatcher, tpe, DoSort, NotArray) - case tpe if tpe.erasure =:= c.universe.typeOf[sci.Map[Any, Any]].erasure => TraversablesOrderedBuf(c)(buildDispatcher, tpe, DoSort, NotArray) - case tpe if tpe.erasure =:= c.universe.typeOf[sci.HashMap[Any, Any]].erasure => TraversablesOrderedBuf(c)(buildDispatcher, tpe, DoSort, NotArray) - case tpe if tpe.erasure =:= c.universe.typeOf[sci.ListMap[Any, Any]].erasure => TraversablesOrderedBuf(c)(buildDispatcher, tpe, DoSort, NotArray) + case tpe if tpe.erasure =:= c.universe.typeOf[Set[Any]].erasure => + TraversablesOrderedBuf(c)(buildDispatcher, tpe, DoSort, NotArray) + case tpe if tpe.erasure =:= c.universe.typeOf[sc.Set[Any]].erasure => + TraversablesOrderedBuf(c)(buildDispatcher, tpe, DoSort, NotArray) + case tpe if tpe.erasure =:= c.universe.typeOf[sci.Set[Any]].erasure => + TraversablesOrderedBuf(c)(buildDispatcher, tpe, DoSort, NotArray) + case tpe if tpe.erasure =:= c.universe.typeOf[sci.HashSet[Any]].erasure => + TraversablesOrderedBuf(c)(buildDispatcher, tpe, DoSort, NotArray) + case tpe if tpe.erasure =:= c.universe.typeOf[sci.ListSet[Any]].erasure => + TraversablesOrderedBuf(c)(buildDispatcher, tpe, DoSort, NotArray) + + case tpe if tpe.erasure =:= c.universe.typeOf[Map[Any, Any]].erasure => + TraversablesOrderedBuf(c)(buildDispatcher, tpe, DoSort, NotArray) + case tpe if tpe.erasure =:= c.universe.typeOf[sc.Map[Any, Any]].erasure => + TraversablesOrderedBuf(c)(buildDispatcher, tpe, DoSort, NotArray) + case tpe if tpe.erasure =:= c.universe.typeOf[sci.Map[Any, Any]].erasure => + TraversablesOrderedBuf(c)(buildDispatcher, tpe, DoSort, NotArray) + case tpe if tpe.erasure =:= c.universe.typeOf[sci.HashMap[Any, Any]].erasure => + TraversablesOrderedBuf(c)(buildDispatcher, tpe, DoSort, NotArray) + case tpe if tpe.erasure =:= c.universe.typeOf[sci.ListMap[Any, Any]].erasure => + TraversablesOrderedBuf(c)(buildDispatcher, tpe, DoSort, NotArray) } - def apply(c: Context)(buildDispatcher: => PartialFunction[c.Type, TreeOrderedBuf[c.type]], + def apply(c: Context)( + buildDispatcher: => PartialFunction[c.Type, TreeOrderedBuf[c.type]], outerType: c.Type, maybeSort: ShouldSort, maybeArray: MaybeArray): TreeOrderedBuf[c.type] = { import c.universe._ - def freshT(id: String) = newTermName(c.fresh(s"fresh_$id")) + def freshT(id: String) = TermName(c.freshName(s"fresh_$id")) val dispatcher = buildDispatcher @@ -81,7 +102,8 @@ object TraversablesOrderedBuf { // When dealing with a map we have 2 type args, and need to generate the tuple type // it would correspond to if we .toList the Map. val innerType = if (outerType.asInstanceOf[TypeRefApi].args.size == 2) { - val (tpe1, tpe2) = (outerType.asInstanceOf[TypeRefApi].args(0), outerType.asInstanceOf[TypeRefApi].args(1)) // linter:ignore + val (tpe1, tpe2) = (outerType.asInstanceOf[TypeRefApi].args.head, + outerType.asInstanceOf[TypeRefApi].args(1)) // linter:ignore val containerType = typeOf[Tuple2[Any, Any]].asInstanceOf[TypeRef] import compat._ TypeRef.apply(containerType.pre, containerType.sym, List(tpe1, tpe2)) @@ -172,7 +194,10 @@ object TraversablesOrderedBuf { $element.foreach { t => val $target = t $currentHash = - _root_.com.twitter.scalding.serialization.MurmurHashUtils.mixH1($currentHash, ${innerBuf.hash(target)}) + _root_.com.twitter.scalding.serialization.MurmurHashUtils.mixH1($currentHash, ${ + innerBuf + .hash(target) + }) // go ahead and compute the length so we don't traverse twice for lists $len += 1 } @@ -257,8 +282,7 @@ object TraversablesOrderedBuf { override val lazyOuterVariables: Map[String, ctx.Tree] = innerBuf.lazyOuterVariables - override def length(element: Tree): CompileTimeLengthTypes[c.type] = { - + override def length(element: Tree): CompileTimeLengthTypes[c.type] = innerBuf.length(q"$element.head") match { case const: ConstantLengthCalculation[_] => FastLengthCalculation(c)(q"""{ @@ -296,8 +320,6 @@ object TraversablesOrderedBuf { } """) } - } } } } - diff --git a/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/UnitOrderedBuf.scala b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/UnitOrderedBuf.scala index e0cedb05a9..3c5a7c116d 100644 --- a/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/UnitOrderedBuf.scala +++ b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/providers/UnitOrderedBuf.scala @@ -15,14 +15,12 @@ */ package com.twitter.scalding.serialization.macros.impl.ordered_serialization.providers -import scala.language.experimental.macros -import scala.reflect.macros.Context - -import com.twitter.scalding._ -import com.twitter.scalding.serialization.macros.impl.ordered_serialization.{ CompileTimeLengthTypes, ProductLike, TreeOrderedBuf } +import scala.reflect.macros.blackbox.Context +import com.twitter.scalding.serialization.macros.impl.ordered_serialization.{ + CompileTimeLengthTypes, + TreeOrderedBuf +} import CompileTimeLengthTypes._ -import java.nio.ByteBuffer -import com.twitter.scalding.serialization.OrderedSerialization object UnitOrderedBuf { def dispatch(c: Context): PartialFunction[c.Type, TreeOrderedBuf[c.type]] = { @@ -59,4 +57,3 @@ object UnitOrderedBuf { } } } - diff --git a/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/runtime_helpers/LengthCalculations.scala b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/runtime_helpers/LengthCalculations.scala index 137c82060b..f69f83dfa2 100644 --- a/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/runtime_helpers/LengthCalculations.scala +++ b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/runtime_helpers/LengthCalculations.scala @@ -26,14 +26,14 @@ sealed trait MaybeLength { case object NoLengthCalculation extends MaybeLength { def +(that: MaybeLength): MaybeLength = this } -case class ConstLen(toInt: Int) extends MaybeLength { +final case class ConstLen(toInt: Int) extends MaybeLength { def +(that: MaybeLength): MaybeLength = that match { case ConstLen(c) => ConstLen(toInt + c) case DynamicLen(d) => DynamicLen(toInt + d) case NoLengthCalculation => NoLengthCalculation } } -case class DynamicLen(toInt: Int) extends MaybeLength { +final case class DynamicLen(toInt: Int) extends MaybeLength { def +(that: MaybeLength): MaybeLength = that match { case ConstLen(c) => DynamicLen(toInt + c) case DynamicLen(d) => DynamicLen(toInt + d) diff --git a/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/runtime_helpers/MacroEqualityOrderedSerialization.scala b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/runtime_helpers/MacroEqualityOrderedSerialization.scala index 46290e1188..91c90e3c00 100644 --- a/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/runtime_helpers/MacroEqualityOrderedSerialization.scala +++ b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/runtime_helpers/MacroEqualityOrderedSerialization.scala @@ -21,11 +21,13 @@ object MacroEqualityOrderedSerialization { private val seed = "MacroEqualityOrderedSerialization".hashCode } -abstract class MacroEqualityOrderedSerialization[T] extends OrderedSerialization[T] with EquivSerialization[T] { +abstract class MacroEqualityOrderedSerialization[T] + extends OrderedSerialization[T] + with EquivSerialization[T] { def uniqueId: String override def hashCode = MacroEqualityOrderedSerialization.seed ^ uniqueId.hashCode override def equals(other: Any): Boolean = other match { case o: MacroEqualityOrderedSerialization[_] => o.uniqueId == uniqueId case _ => false } -} \ No newline at end of file +} diff --git a/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/runtime_helpers/TraversableHelpers.scala b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/runtime_helpers/TraversableHelpers.scala index 8e9961e9f4..a754de1f82 100644 --- a/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/runtime_helpers/TraversableHelpers.scala +++ b/scalding-serialization/src/main/scala/com/twitter/scalding/serialization/macros/impl/ordered_serialization/runtime_helpers/TraversableHelpers.scala @@ -21,7 +21,8 @@ import scala.collection.mutable.Buffer object TraversableHelpers { import com.twitter.scalding.serialization.JavaStreamEnrichments._ - final def rawCompare(inputStreamA: InputStream, inputStreamB: InputStream)(consume: (InputStream, InputStream) => Int): Int = { + final def rawCompare(inputStreamA: InputStream, inputStreamB: InputStream)( + consume: (InputStream, InputStream) => Int): Int = { val lenA = inputStreamA.readPosVarInt val lenB = inputStreamB.readPosVarInt @@ -37,7 +38,8 @@ object TraversableHelpers { else java.lang.Integer.compare(lenA, lenB) } - final def iteratorCompare[T](iteratorA: Iterator[T], iteratorB: Iterator[T])(implicit ord: Ordering[T]): Int = { + final def iteratorCompare[T](iteratorA: Iterator[T], iteratorB: Iterator[T])( + implicit ord: Ordering[T]): Int = { @annotation.tailrec def result: Int = if (iteratorA.isEmpty) { @@ -55,7 +57,8 @@ object TraversableHelpers { result } - final def iteratorEquiv[T](iteratorA: Iterator[T], iteratorB: Iterator[T])(implicit eq: Equiv[T]): Boolean = { + final def iteratorEquiv[T](iteratorA: Iterator[T], iteratorB: Iterator[T])( + implicit eq: Equiv[T]): Boolean = { @annotation.tailrec def result: Boolean = if (iteratorA.isEmpty) iteratorB.isEmpty @@ -64,6 +67,7 @@ object TraversableHelpers { result } + /** * This returns the same result as * @@ -74,7 +78,8 @@ object TraversableHelpers { * the complexity should be O(N + M) rather than O(N log N + M log M) for the full * sort case */ - final def sortedCompare[T](travA: Iterable[T], travB: Iterable[T])(implicit ord: Ordering[T]): Int = { + final def sortedCompare[T](travA: Iterable[T], travB: Iterable[T])( + implicit ord: Ordering[T]): Int = { def compare(startA: Int, endA: Int, a: Buffer[T], startB: Int, endB: Int, b: Buffer[T]): Int = if (startA == endA) { if (startB == endB) 0 // both empty @@ -82,7 +87,11 @@ object TraversableHelpers { } else if (startB == endB) 1 // non-empty is bigger than empty else { @annotation.tailrec - def partition(pivot: T, pivotStart: Int, pivotEnd: Int, endX: Int, x: Buffer[T]): (Int, Int) = { + def partition(pivot: T, + pivotStart: Int, + pivotEnd: Int, + endX: Int, + x: Buffer[T]): (Int, Int) = if (pivotEnd >= endX) (pivotStart, pivotEnd) else { val t = x(pivotEnd) @@ -106,7 +115,6 @@ object TraversableHelpers { partition(pivot, pivotStart + 1, pivotEnd + 1, endX, x) } } - } val pivot = a(startA) val (aps, ape) = partition(pivot, startA, startA + 1, endA, a) val (bps, bpe) = partition(pivot, startB, startB, endB, b) diff --git a/scalding-serialization/src/test/scala/com/twitter/scalding/serialization/macros/MacroOrderingProperties.scala b/scalding-serialization/src/test/scala/com/twitter/scalding/serialization/macros/MacroOrderingProperties.scala index 37be34a32b..c8ae779082 100644 --- a/scalding-serialization/src/test/scala/com/twitter/scalding/serialization/macros/MacroOrderingProperties.scala +++ b/scalding-serialization/src/test/scala/com/twitter/scalding/serialization/macros/MacroOrderingProperties.scala @@ -15,22 +15,27 @@ limitations under the License. */ package com.twitter.scalding.serialization.macros - +import scala.language.higherKinds import java.io.{ ByteArrayOutputStream, InputStream } import java.nio.ByteBuffer -import com.twitter.scalding.serialization.{ JavaStreamEnrichments, Law, Law1, Law2, Law3, OrderedSerialization, Serialization } +import com.twitter.scalding.serialization.{ + JavaStreamEnrichments, + Law, + Law1, + Law2, + Law3, + OrderedSerialization, + Serialization +} import org.scalacheck.Arbitrary.{ arbitrary => arb } import org.scalacheck.{ Arbitrary, Gen, Prop } import org.scalatest.prop.{ Checkers, PropertyChecks } -import org.scalatest.{ FunSuite, Matchers } - +import org.scalatest.FunSuite //, ShouldMatchers } +import com.twitter.scalding.some.other.space.space._ import scala.collection.immutable.Queue import scala.language.experimental.macros - -trait LowerPriorityImplicit { - implicit def primitiveOrderedBufferSupplier[T]: OrderedSerialization[T] = macro impl.OrderedSerializationProviderImpl[T] -} +import com.twitter.scalding.serialization.macros.impl.BinaryOrdering object LawTester { def apply[T: Arbitrary](laws: Iterable[Law[T]]): Prop = @@ -108,6 +113,34 @@ object TestCC { } yield testSealedAbstractClass } + implicit def arbitraryElementY: Arbitrary[ContainerX.ElementY] = Arbitrary { + for { + v <- arb[String] + } yield ContainerX.ElementY(v) + } + + implicit def arbitraryElementZ: Arbitrary[ContainerX.ElementZ] = Arbitrary { + for { + v <- arb[String] + } yield ContainerX.ElementZ(v) + } + + implicit def arbitraryTestCaseHardA: Arbitrary[TestCaseHardA] = Arbitrary { + for { + cc <- arb[ContainerX.ElementY] + bb <- arb[ContainerX.ElementZ] + o <- arb[String] + t <- Gen.oneOf(cc, bb) + } yield TestCaseHardA(t, o) + } + + implicit def arbitraryTestCaseHardB: Arbitrary[TestCaseHardB] = Arbitrary { + for { + o <- arb[String] + t <- Gen.oneOf(ContainerP.ElementA, ContainerP.ElementB) + } yield TestCaseHardB(t, o) + } + } sealed abstract class TestSealedAbstractClass(val name: Option[String]) @@ -115,9 +148,17 @@ case object A extends TestSealedAbstractClass(None) case object B extends TestSealedAbstractClass(Some("b")) sealed trait SealedTraitTest -case class TestCC(a: Int, b: Long, c: Option[Int], d: Double, e: Option[String], f: Option[List[String]], aBB: ByteBuffer) extends SealedTraitTest - -case class TestCaseClassB(a: Int, b: Long, c: Option[Int], d: Double, e: Option[String]) extends SealedTraitTest +case class TestCC(a: Int, + b: Long, + c: Option[Int], + d: Double, + e: Option[String], + f: Option[List[String]], + aBB: ByteBuffer) + extends SealedTraitTest + +case class TestCaseClassB(a: Int, b: Long, c: Option[Int], d: Double, e: Option[String]) + extends SealedTraitTest case class TestCaseClassD(a: Int) extends SealedTraitTest @@ -127,6 +168,33 @@ case object TestObjectE extends SealedTraitTest case class TypedParameterCaseClass[A](v: A) +sealed trait BigTrait +case class BigTraitA(a: Int) extends BigTrait +case class BigTraitC(a: Int) extends BigTrait +case class BigTraitD(a: Int) extends BigTrait +case class BigTraitE(a: Int) extends BigTrait +case class BigTraitF(a: Int) extends BigTrait +case class BigTraitG(a: Int) extends BigTrait +case class BigTraitH(a: Int) extends BigTrait +case class BigTraitI(a: Int) extends BigTrait +case class BigTraitJ(a: Int) extends BigTrait +case class BigTraitK(a: Int) extends BigTrait +case class BigTraitL(a: Int) extends BigTrait +case class BigTraitM(a: Int) extends BigTrait +case class BigTraitN(a: Int) extends BigTrait +case class BigTraitO(a: Int) extends BigTrait +case class BigTraitP(a: Int) extends BigTrait +case class BigTraitQ(a: Int) extends BigTrait +case class BigTraitR(a: Int) extends BigTrait +case class BigTraitS(a: Int) extends BigTrait +case class BigTraitT(a: Int) extends BigTrait +case class BigTraitU(a: Int) extends BigTrait +case class BigTraitV(a: Int) extends BigTrait +case class BigTraitW(a: Int) extends BigTrait +case class BigTraitX(a: Int) extends BigTrait +case class BigTraitY(a: Int) extends BigTrait +case class BigTraitZ(a: Int) extends BigTrait + object MyData { implicit def arbitraryTestCC: Arbitrary[MyData] = Arbitrary { for { @@ -136,18 +204,20 @@ object MyData { } } -class MyData(override val _1: Int, override val _2: Option[Long]) extends Product2[Int, Option[Long]] { +class MyData(override val _1: Int, override val _2: Option[Long]) + extends Product2[Int, Option[Long]] { override def canEqual(that: Any): Boolean = that match { case o: MyData => true case _ => false } override def equals(obj: scala.Any): Boolean = obj match { - case o: MyData => (o._2, _2) match { - case (Some(l), Some(r)) => r == l && _1 == o._1 - case (None, None) => _1 == o._1 - case _ => false - } + case o: MyData => + (o._2, _2) match { + case (Some(l), Some(r)) => r == l && _1 == o._1 + case (None, None) => _1 == o._1 + case _ => false + } case _ => false } @@ -156,19 +226,23 @@ class MyData(override val _1: Int, override val _2: Option[Long]) extends Produc } object MacroOpaqueContainer { - def getOrdSer[T]: OrderedSerialization[T] = macro impl.OrderedSerializationProviderImpl[T] import java.io._ implicit val myContainerOrderedSerializer = new OrderedSerialization[MacroOpaqueContainer] { - val intOrderedSerialization = getOrdSer[Int] + val intOrderedSerialization = BinaryOrdering.ordSer[Int] - override def hash(s: MacroOpaqueContainer) = intOrderedSerialization.hash(s.myField) ^ Int.MaxValue - override def compare(a: MacroOpaqueContainer, b: MacroOpaqueContainer) = intOrderedSerialization.compare(a.myField, b.myField) + override def hash(s: MacroOpaqueContainer) = + intOrderedSerialization.hash(s.myField) ^ Int.MaxValue + override def compare(a: MacroOpaqueContainer, b: MacroOpaqueContainer) = + intOrderedSerialization.compare(a.myField, b.myField) - override def read(in: InputStream) = intOrderedSerialization.read(in).map(MacroOpaqueContainer(_)) + override def read(in: InputStream) = + intOrderedSerialization.read(in).map(MacroOpaqueContainer(_)) - override def write(b: OutputStream, s: MacroOpaqueContainer) = intOrderedSerialization.write(b, s.myField) + override def write(b: OutputStream, s: MacroOpaqueContainer) = + intOrderedSerialization.write(b, s.myField) - override def compareBinary(lhs: InputStream, rhs: InputStream) = intOrderedSerialization.compareBinary(lhs, rhs) + override def compareBinary(lhs: InputStream, rhs: InputStream) = + intOrderedSerialization.compareBinary(lhs, rhs) override val staticSize = Some(4) override def dynamicSize(i: MacroOpaqueContainer) = staticSize @@ -202,7 +276,10 @@ object Container { type SetAlias = Set[Double] case class InnerCaseClass(e: SetAlias) } -class MacroOrderingProperties extends FunSuite with PropertyChecks with Matchers with LowerPriorityImplicit { +class MacroOrderingProperties + extends FunSuite + with PropertyChecks + with BinaryOrdering { type SetAlias = Set[Double] import ByteBufferArb._ @@ -213,13 +290,15 @@ class MacroOrderingProperties extends FunSuite with PropertyChecks with Matchers def arbMap[T: Arbitrary, U](fn: T => U): Arbitrary[U] = Arbitrary(gen[T].map(fn)) - def collectionArb[C[_], T: Arbitrary](implicit cbf: collection.generic.CanBuildFrom[Nothing, T, C[T]]): Arbitrary[C[T]] = Arbitrary { - gen[List[T]].map { l => - val builder = cbf() - l.foreach { builder += _ } - builder.result + def collectionArb[C[_], T: Arbitrary]( + implicit cbf: collection.generic.CanBuildFrom[Nothing, T, C[T]]): Arbitrary[C[T]] = + Arbitrary { + gen[List[T]].map { l => + val builder = cbf() + l.foreach { builder += _ } + builder.result + } } - } def serialize[T](t: T)(implicit orderedBuffer: OrderedSerialization[T]): InputStream = serializeSeq(List(t)) @@ -250,9 +329,13 @@ class MacroOrderingProperties extends FunSuite with PropertyChecks with Matchers val compareBinary = obuf.compareBinary(serializedA, serializedB).unsafeToInt val compareMem = obuf.compare(a, b) if (compareBinary < 0) { - assert(compareMem < 0, s"Compare binary: $compareBinary, and compareMem : $compareMem must have the same sign") + assert( + compareMem < 0, + s"Compare binary: $compareBinary, and compareMem : $compareMem must have the same sign") } else if (compareBinary > 0) { - assert(compareMem > 0, s"Compare binary: $compareBinary, and compareMem : $compareMem must have the same sign") + assert( + compareMem > 0, + s"Compare binary: $compareBinary, and compareMem : $compareMem must have the same sign") } } } @@ -275,25 +358,32 @@ class MacroOrderingProperties extends FunSuite with PropertyChecks with Matchers assert(oBufCompare(rta, a) === 0, s"A should be equal to itself after an RT -- ${rt(a)}") assert(oBufCompare(rtb, b) === 0, s"B should be equal to itself after an RT-- ${rt(b)}") assert(oBufCompare(a, b) + oBufCompare(b, a) === 0, "In memory comparasons make sense") - assert(rawCompare(a, b) + rawCompare(b, a) === 0, "When adding the raw compares in inverse order they should sum to 0") - assert(oBufCompare(rta, rtb) === oBufCompare(a, b), "Comparing a and b with ordered bufferables compare after a serialization RT") + assert(rawCompare(a, b) + rawCompare(b, a) === 0, + "When adding the raw compares in inverse order they should sum to 0") + assert(oBufCompare(rta, rtb) === oBufCompare(a, b), + "Comparing a and b with ordered bufferables compare after a serialization RT") } def checkAreSame[T](a: T, b: T)(implicit obuf: OrderedSerialization[T]): Unit = { val rta = rt(a) // before we do anything ensure these don't throw val rtb = rt(b) // before we do anything ensure these don't throw assert(oBufCompare(rta, a) === 0, s"A should be equal to itself after an RT -- ${rt(a)}") - assert(oBufCompare(rtb, b) === 0, "B should be equal to itself after an RT-- ${rt(b)}") + assert(oBufCompare(rtb, b) === 0, s"B should be equal to itself after an RT-- ${rt(b)}") assert(oBufCompare(a, b) === 0, "In memory comparasons make sense") assert(oBufCompare(b, a) === 0, "In memory comparasons make sense") - assert(rawCompare(a, b) === 0, "When adding the raw compares in inverse order they should sum to 0") - assert(rawCompare(b, a) === 0, "When adding the raw compares in inverse order they should sum to 0") - assert(oBufCompare(rta, rtb) === 0, "Comparing a and b with ordered bufferables compare after a serialization RT") + assert(rawCompare(a, b) === 0, + "When adding the raw compares in inverse order they should sum to 0") + assert(rawCompare(b, a) === 0, + "When adding the raw compares in inverse order they should sum to 0") + assert(oBufCompare(rta, rtb) === 0, + "Comparing a and b with ordered bufferables compare after a serialization RT") } def check[T: Arbitrary](implicit obuf: OrderedSerialization[T]) = { Checkers.check(LawTester(OrderedSerialization.allLaws)) - forAll(minSuccessful(500)) { (a: T, b: T) => checkWithInputs(a, b) } + forAll(minSuccessful(500)) { (a: T, b: T) => + checkWithInputs(a, b) + } } def checkCollisions[T: Arbitrary: OrderedSerialization] = { @@ -306,39 +396,50 @@ class MacroOrderingProperties extends FunSuite with PropertyChecks with Matchers assert(input.distinct.size - hashes.distinct.size <= 3) //generously allow upto 3 collision } + def noOrderedSerialization[T](implicit ev: OrderedSerialization[T] = null) = + assert(ev === null, "Expected unable to produce OrderedSerialization") + test("Test out Unit") { - primitiveOrderedBufferSupplier[Unit] + BinaryOrdering.ordSer[Unit] check[Unit] checkMany[Unit] } test("Test out Boolean") { - primitiveOrderedBufferSupplier[Boolean] + BinaryOrdering.ordSer[Boolean] check[Boolean] } test("Test out jl.Boolean") { - implicit val a: Arbitrary[java.lang.Boolean] = arbMap { b: Boolean => java.lang.Boolean.valueOf(b) } + implicit val a = arbMap { b: Boolean => + java.lang.Boolean.valueOf(b) + } check[java.lang.Boolean] } test("Test out Byte") { check[Byte] } test("Test out jl.Byte") { - implicit val a: Arbitrary[java.lang.Byte] = arbMap { b: Byte => java.lang.Byte.valueOf(b) } + implicit val a = arbMap { b: Byte => + java.lang.Byte.valueOf(b) + } check[java.lang.Byte] checkCollisions[java.lang.Byte] } test("Test out Short") { check[Short] } test("Test out jl.Short") { - implicit val a: Arbitrary[java.lang.Short] = arbMap { b: Short => java.lang.Short.valueOf(b) } + implicit val a = arbMap { b: Short => + java.lang.Short.valueOf(b) + } check[java.lang.Short] checkCollisions[java.lang.Short] } test("Test out Char") { check[Char] } test("Test out jl.Char") { - implicit val a: Arbitrary[java.lang.Character] = arbMap { b: Char => java.lang.Character.valueOf(b) } + implicit val a = arbMap { b: Char => + java.lang.Character.valueOf(b) + } check[java.lang.Character] checkCollisions[java.lang.Character] } test("Test out Int") { - primitiveOrderedBufferSupplier[Int] + BinaryOrdering.ordSer[Int] check[Int] checkMany[Int] checkCollisions[Int] @@ -353,7 +454,7 @@ class MacroOrderingProperties extends FunSuite with PropertyChecks with Matchers test("Test out Tuple of AnyVal's of String") { import TestCC._ - primitiveOrderedBufferSupplier[(TestCaseClassE, TestCaseClassE)] + BinaryOrdering.ordSer[(TestCaseClassE, TestCaseClassE)] check[(TestCaseClassE, TestCaseClassE)] checkMany[(TestCaseClassE, TestCaseClassE)] checkCollisions[(TestCaseClassE, TestCaseClassE)] @@ -361,39 +462,47 @@ class MacroOrderingProperties extends FunSuite with PropertyChecks with Matchers test("Test out Tuple of TestSealedAbstractClass") { import TestCC._ - primitiveOrderedBufferSupplier[TestSealedAbstractClass] + BinaryOrdering.ordSer[TestSealedAbstractClass] check[TestSealedAbstractClass] checkMany[TestSealedAbstractClass] checkCollisions[TestSealedAbstractClass] } test("Test out jl.Integer") { - implicit val a: Arbitrary[java.lang.Integer] = arbMap { b: Int => java.lang.Integer.valueOf(b) } + implicit val a = arbMap { b: Int => + java.lang.Integer.valueOf(b) + } check[java.lang.Integer] checkCollisions[java.lang.Integer] } test("Test out Float") { check[Float] } test("Test out jl.Float") { - implicit val a: Arbitrary[java.lang.Float] = arbMap { b: Float => java.lang.Float.valueOf(b) } + implicit val a = arbMap { b: Float => + java.lang.Float.valueOf(b) + } check[java.lang.Float] checkCollisions[java.lang.Float] } test("Test out Long") { check[Long] } test("Test out jl.Long") { - implicit val a: Arbitrary[java.lang.Long] = arbMap { b: Long => java.lang.Long.valueOf(b) } + implicit val a = arbMap { b: Long => + java.lang.Long.valueOf(b) + } check[java.lang.Long] checkCollisions[java.lang.Long] } test("Test out Double") { check[Double] } test("Test out jl.Double") { - implicit val a: Arbitrary[java.lang.Double] = arbMap { b: Double => java.lang.Double.valueOf(b) } + implicit val a = arbMap { b: Double => + java.lang.Double.valueOf(b) + } check[java.lang.Double] checkCollisions[java.lang.Double] } test("Test out String") { - primitiveOrderedBufferSupplier[String] + BinaryOrdering.ordSer[String] check[String] checkMany[String] @@ -401,140 +510,144 @@ class MacroOrderingProperties extends FunSuite with PropertyChecks with Matchers } test("Test out ByteBuffer") { - primitiveOrderedBufferSupplier[ByteBuffer] + BinaryOrdering.ordSer[ByteBuffer] check[ByteBuffer] checkCollisions[ByteBuffer] } test("Test out List[Float]") { - primitiveOrderedBufferSupplier[List[Float]] + BinaryOrdering.ordSer[List[Float]] check[List[Float]] checkCollisions[List[Float]] } test("Test out Queue[Int]") { - implicit val isa: Arbitrary[Queue[Int]] = collectionArb[Queue, Int] - primitiveOrderedBufferSupplier[Queue[Int]] + implicit val isa = collectionArb[Queue, Int] + BinaryOrdering.ordSer[Queue[Int]] check[Queue[Int]] checkCollisions[Queue[Int]] } test("Test out IndexedSeq[Int]") { - implicit val isa: Arbitrary[IndexedSeq[Int]] = collectionArb[IndexedSeq, Int] - primitiveOrderedBufferSupplier[IndexedSeq[Int]] + implicit val isa = collectionArb[IndexedSeq, Int] + BinaryOrdering.ordSer[IndexedSeq[Int]] check[IndexedSeq[Int]] checkCollisions[IndexedSeq[Int]] } test("Test out HashSet[Int]") { import scala.collection.immutable.HashSet - implicit val isa: Arbitrary[HashSet[Int]] = collectionArb[HashSet, Int] - primitiveOrderedBufferSupplier[HashSet[Int]] + implicit val isa = collectionArb[HashSet, Int] + BinaryOrdering.ordSer[HashSet[Int]] check[HashSet[Int]] checkCollisions[HashSet[Int]] } test("Test out ListSet[Int]") { import scala.collection.immutable.ListSet - implicit val isa: Arbitrary[ListSet[Int]] = collectionArb[ListSet, Int] - primitiveOrderedBufferSupplier[ListSet[Int]] + implicit val isa = collectionArb[ListSet, Int] + BinaryOrdering.ordSer[ListSet[Int]] check[ListSet[Int]] checkCollisions[ListSet[Int]] } test("Test out List[String]") { - primitiveOrderedBufferSupplier[List[String]] + BinaryOrdering.ordSer[List[String]] check[List[String]] checkCollisions[List[String]] } test("Test out List[List[String]]") { - val oBuf = primitiveOrderedBufferSupplier[List[List[String]]] + val oBuf = BinaryOrdering.ordSer[List[List[String]]] assert(oBuf.dynamicSize(List(List("sdf"))) === None) check[List[List[String]]] checkCollisions[List[List[String]]] } test("Test out List[Int]") { - primitiveOrderedBufferSupplier[List[Int]] + BinaryOrdering.ordSer[List[Int]] check[List[Int]] checkCollisions[List[Int]] } test("Test out SetAlias") { - primitiveOrderedBufferSupplier[SetAlias] + BinaryOrdering.ordSer[SetAlias] check[SetAlias] checkCollisions[SetAlias] } test("Container.InnerCaseClass") { - primitiveOrderedBufferSupplier[Container.InnerCaseClass] + BinaryOrdering.ordSer[Container.InnerCaseClass] check[Container.InnerCaseClass] checkCollisions[Container.InnerCaseClass] } test("Test out Seq[Int]") { - primitiveOrderedBufferSupplier[Seq[Int]] + BinaryOrdering.ordSer[Seq[Int]] check[Seq[Int]] checkCollisions[Seq[Int]] } test("Test out scala.collection.Seq[Int]") { - primitiveOrderedBufferSupplier[scala.collection.Seq[Int]] + BinaryOrdering.ordSer[scala.collection.Seq[Int]] check[scala.collection.Seq[Int]] checkCollisions[scala.collection.Seq[Int]] } test("Test out Array[Byte]") { - primitiveOrderedBufferSupplier[Array[Byte]] + BinaryOrdering.ordSer[Array[Byte]] check[Array[Byte]] checkCollisions[Array[Byte]] } test("Test out Vector[Int]") { - primitiveOrderedBufferSupplier[Vector[Int]] + BinaryOrdering.ordSer[Vector[Int]] check[Vector[Int]] checkCollisions[Vector[Int]] } test("Test out Iterable[Int]") { - primitiveOrderedBufferSupplier[Iterable[Int]] + BinaryOrdering.ordSer[Iterable[Int]] check[Iterable[Int]] checkCollisions[Iterable[Int]] } test("Test out Set[Int]") { - primitiveOrderedBufferSupplier[Set[Int]] + BinaryOrdering.ordSer[Set[Int]] check[Set[Int]] checkCollisions[Set[Int]] } test("Test out Set[Double]") { - primitiveOrderedBufferSupplier[Set[Double]] + BinaryOrdering.ordSer[Set[Double]] check[Set[Double]] checkCollisions[Set[Double]] } test("Test out Map[Long, Set[Int]]") { - primitiveOrderedBufferSupplier[Map[Long, Set[Int]]] + BinaryOrdering.ordSer[Map[Long, Set[Int]]] check[Map[Long, Set[Int]]] val c = List(Map(9223372036854775807L -> Set[Int]()), Map(-1L -> Set[Int](-2043106012))) - checkManyExplicit(c.map { i => (i, i) }) + checkManyExplicit(c.map { i => + (i, i) + }) checkMany[Map[Long, Set[Int]]] checkCollisions[Map[Long, Set[Int]]] } test("Test out Map[Long, Long]") { - primitiveOrderedBufferSupplier[Map[Long, Long]] + BinaryOrdering.ordSer[Map[Long, Long]] check[Map[Long, Long]] checkCollisions[Map[Long, Long]] } test("Test out HashMap[Long, Long]") { import scala.collection.immutable.HashMap - implicit val isa: Arbitrary[HashMap[Long, Long]] = Arbitrary(implicitly[Arbitrary[List[(Long, Long)]]].arbitrary.map(HashMap(_: _*))) - primitiveOrderedBufferSupplier[HashMap[Long, Long]] + implicit val isa = + Arbitrary(implicitly[Arbitrary[List[(Long, Long)]]].arbitrary.map(HashMap(_: _*))) + BinaryOrdering.ordSer[HashMap[Long, Long]] check[HashMap[Long, Long]] checkCollisions[HashMap[Long, Long]] } test("Test out ListMap[Long, Long]") { import scala.collection.immutable.ListMap - implicit val isa: Arbitrary[ListMap[Long, Long]] = Arbitrary(implicitly[Arbitrary[List[(Long, Long)]]].arbitrary.map(ListMap(_: _*))) - primitiveOrderedBufferSupplier[ListMap[Long, Long]] + implicit val isa = + Arbitrary(implicitly[Arbitrary[List[(Long, Long)]]].arbitrary.map(ListMap(_: _*))) + BinaryOrdering.ordSer[ListMap[Long, Long]] check[ListMap[Long, Long]] checkCollisions[ListMap[Long, Long]] } @@ -559,19 +672,24 @@ class MacroOrderingProperties extends FunSuite with PropertyChecks with Matchers val ord = Ordering.String assert(rawCompare(a, b) === ord.compare(a, b).signum, "Raw and in memory compares match.") - val c = List("榴㉕⊟풠湜ᙬ覹ꜻ裧뚐⠂覝쫨塢䇺楠谭픚ᐌ轮뺷Ⱟ洦擄黏著탅ﮓꆋ숷梸傠ァ蹵窥轲闇涡飽ꌳ䝞慙擃", + val c = List( + "榴㉕⊟풠湜ᙬ覹ꜻ裧뚐⠂覝쫨塢䇺楠谭픚ᐌ轮뺷Ⱟ洦擄黏著탅ﮓꆋ숷梸傠ァ蹵窥轲闇涡飽ꌳ䝞慙擃", "堒凳媨쉏떽㶥⾽샣井ㆠᇗ裉깴辫࠷᤭塈䎙寫㸉ᶴ䰄똇䡷䥞㷗䷱赫懓䷏剆祲ᝯ졑쐯헢鷴ӕ秔㽰ퟡ㏉鶖奚㙰银䮌ᕗ膾买씋썴행䣈丶偝쾕鐗쇊ኋ넥︇瞤䋗噯邧⹆♣ἷ铆玼⪷沕辤ᠥ⥰箼䔄◗", "騰쓢堷뛭ᣣﰩ嚲ﲯ㤑ᐜ檊೦⠩奯ᓩ윇롇러ᕰెꡩ璞﫼᭵礀閮䈦椄뾪ɔ믻䖔᪆嬽フ鶬曭꣍ᆏ灖㐸뗋ㆃ녵ퟸ겵晬礙㇩䫓ᘞ昑싨", "좃ఱ䨻綛糔唄࿁劸酊᫵橻쩳괊筆ݓ淤숪輡斋靑耜঄骐冠㝑⧠떅漫곡祈䵾ᳺ줵됵↲搸虂㔢Ꝅ芆٠풐쮋炞哙⨗쾄톄멛癔짍避쇜畾㣕剼⫁়╢ꅢ澛氌ᄚ㍠ꃫᛔ匙㜗詇閦單錖⒅瘧崥", "獌癚畇") - checkManyExplicit(c.map { i => (i, i) }) + checkManyExplicit(c.map { i => + (i, i) + }) val c2 = List("聸", "") - checkManyExplicit(c2.map { i => (i, i) }) + checkManyExplicit(c2.map { i => + (i, i) + }) } test("Test out Option[Int]") { - val oser = primitiveOrderedBufferSupplier[Option[Int]] + val oser = BinaryOrdering.ordSer[Option[Int]] assert(oser.staticSize === None, "can't get the size statically") check[Option[Int]] @@ -580,7 +698,7 @@ class MacroOrderingProperties extends FunSuite with PropertyChecks with Matchers } test("Test out Option[String]") { - primitiveOrderedBufferSupplier[Option[String]] + BinaryOrdering.ordSer[Option[String]] check[Option[String]] checkMany[Option[String]] @@ -588,38 +706,40 @@ class MacroOrderingProperties extends FunSuite with PropertyChecks with Matchers } test("Test Either[Int, Option[Int]]") { - val oser = primitiveOrderedBufferSupplier[Either[Int, Option[Int]]] + val oser = BinaryOrdering.ordSer[Either[Int, Option[Int]]] assert(oser.staticSize === None, "can't get the size statically") check[Either[Int, Option[Int]]] checkCollisions[Either[Int, Option[Int]]] } test("Test Either[Int, String]") { - val oser = primitiveOrderedBufferSupplier[Either[Int, String]] + val oser = BinaryOrdering.ordSer[Either[Int, String]] assert(oser.staticSize === None, "can't get the size statically") - assert(Some(Serialization.toBytes[Either[Int, String]](Left(1)).length) === oser.dynamicSize(Left(1)), + assert( + Some(Serialization.toBytes[Either[Int, String]](Left(1)).length) === oser.dynamicSize( + Left(1)), "serialization size matches dynamic size") check[Either[Int, String]] checkCollisions[Either[Int, String]] } test("Test Either[Int, Int]") { - val oser = primitiveOrderedBufferSupplier[Either[Int, Int]] + val oser = BinaryOrdering.ordSer[Either[Int, Int]] assert(oser.staticSize === Some(5), "can get the size statically") check[Either[Int, Int]] checkCollisions[Either[Int, Int]] } test("Test Either[String, Int]") { - primitiveOrderedBufferSupplier[Either[String, Int]] + BinaryOrdering.ordSer[Either[String, Int]] check[Either[String, Int]] checkCollisions[Either[String, Int]] } test("Test Either[String, String]") { - primitiveOrderedBufferSupplier[Either[String, String]] + BinaryOrdering.ordSer[Either[String, String]] check[Either[String, String]] checkCollisions[Either[String, String]] } test("Test out Option[Option[Int]]") { - primitiveOrderedBufferSupplier[Option[Option[Int]]] + BinaryOrdering.ordSer[Option[Option[Int]]] check[Option[Option[Int]]] checkCollisions[Option[Option[Int]]] @@ -631,7 +751,7 @@ class MacroOrderingProperties extends FunSuite with PropertyChecks with Matchers } test("test specific tuple aa1") { - primitiveOrderedBufferSupplier[(String, Option[Int], String)] + BinaryOrdering.ordSer[(String, Option[Int], String)] checkMany[(String, Option[Int], String)] checkCollisions[(String, Option[Int], String)] @@ -643,14 +763,17 @@ class MacroOrderingProperties extends FunSuite with PropertyChecks with Matchers } test("test specific tuple 3") { - val c = List(("", None, ""), + val c = List( + ("", None, ""), ("a", Some(1), "b")) - checkManyExplicit(c.map { i => (i, i) }) + checkManyExplicit(c.map { i => + (i, i) + }) } test("Test out TestCC") { import TestCC._ - primitiveOrderedBufferSupplier[TestCC] + BinaryOrdering.ordSer[TestCC] check[TestCC] checkMany[TestCC] checkCollisions[TestCC] @@ -658,34 +781,54 @@ class MacroOrderingProperties extends FunSuite with PropertyChecks with Matchers test("Test out Sealed Trait") { import TestCC._ - primitiveOrderedBufferSupplier[SealedTraitTest] + BinaryOrdering.ordSer[SealedTraitTest] check[SealedTraitTest] checkMany[SealedTraitTest] checkCollisions[SealedTraitTest] } + test("Test out Sealed TestCaseHardA") { + import TestCC._ + BinaryOrdering.ordSer[TestCaseHardA] + check[TestCaseHardA] + checkMany[TestCaseHardA] + checkCollisions[TestCaseHardA] + } + + test("Test out Sealed TestCaseHardB") { + import TestCC._ + + implicit val v: OrderedSerialization[ContainerP] = + OrderedSerialization.viaTransform(_.id, ContainerP.fromId) + + BinaryOrdering.ordSer[TestCaseHardB] + check[TestCaseHardB] + checkMany[TestCaseHardB] + checkCollisions[TestCaseHardB] + } + test("Test out CaseObject") { import TestCC._ - primitiveOrderedBufferSupplier[TestObjectE.type] + BinaryOrdering.ordSer[TestObjectE.type] check[TestObjectE.type] checkMany[TestObjectE.type] } test("Test out (Int, Int)") { - primitiveOrderedBufferSupplier[(Int, Int)] + BinaryOrdering.ordSer[(Int, Int)] check[(Int, Int)] checkCollisions[(Int, Int)] } test("Test out (String, Option[Int], String)") { - primitiveOrderedBufferSupplier[(String, Option[Int], String)] + BinaryOrdering.ordSer[(String, Option[Int], String)] check[(String, Option[Int], String)] checkCollisions[(String, Option[Int], String)] } test("Test out MyData") { import MyData._ - primitiveOrderedBufferSupplier[MyData] + BinaryOrdering.ordSer[MyData] check[MyData] checkCollisions[MyData] } @@ -696,7 +839,7 @@ class MacroOrderingProperties extends FunSuite with PropertyChecks with Matchers implicitly[OrderedSerialization[MacroOpaqueContainer]] // Put inside a tuple2 to test that - primitiveOrderedBufferSupplier[(MacroOpaqueContainer, MacroOpaqueContainer)] + BinaryOrdering.ordSer[(MacroOpaqueContainer, MacroOpaqueContainer)] check[(MacroOpaqueContainer, MacroOpaqueContainer)] checkCollisions[(MacroOpaqueContainer, MacroOpaqueContainer)] check[Option[MacroOpaqueContainer]] @@ -705,13 +848,17 @@ class MacroOrderingProperties extends FunSuite with PropertyChecks with Matchers checkCollisions[List[MacroOpaqueContainer]] } - def fn[A](implicit or: OrderedSerialization[A]): OrderedSerialization[TypedParameterCaseClass[A]] = { - primitiveOrderedBufferSupplier[TypedParameterCaseClass[A]] + test("Does not produce ordering for large sealed trait") { + noOrderedSerialization[BigTrait] } + def fn[A]( + implicit or: OrderedSerialization[A]): OrderedSerialization[TypedParameterCaseClass[A]] = + BinaryOrdering.ordSer[TypedParameterCaseClass[A]] + test("Test out MacroOpaqueContainer inside a case class as an abstract type") { fn[MacroOpaqueContainer] - primitiveOrderedBufferSupplier[(MacroOpaqueContainer, MacroOpaqueContainer)] + BinaryOrdering.ordSer[(MacroOpaqueContainer, MacroOpaqueContainer)] + () } } - diff --git a/scalding-serialization/src/test/scala/com/twitter/scalding/serialization/macros/ZDifficultTypes.scala b/scalding-serialization/src/test/scala/com/twitter/scalding/serialization/macros/ZDifficultTypes.scala new file mode 100644 index 0000000000..0d7ffeaeb3 --- /dev/null +++ b/scalding-serialization/src/test/scala/com/twitter/scalding/serialization/macros/ZDifficultTypes.scala @@ -0,0 +1,27 @@ +package com.twitter.scalding.some.other.space.space + +sealed trait ContainerX +object ContainerX { + case class ElementY(x: String) extends ContainerX + case class ElementZ(x: String) extends ContainerX +} + +// This is intentionally not sealed. User can supply their own +trait ContainerP { + def id: String +} +object ContainerP { + case object ElementA extends ContainerP { + def id: String = "a" + } + case object ElementB extends ContainerP { + def id: String = "b" + } + def fromId(id: String): ContainerP = id match { + case _ if id == ElementA.id => ElementA + case _ if id == ElementB.id => ElementB + } +} + +case class TestCaseHardA(e: ContainerX, y: String) +case class TestCaseHardB(e: ContainerP, y: String) diff --git a/scalding-thrift-macros/src/main/scala/com/twitter/scalding/thrift/macros/impl/ordered_serialization/ScroogeOrderedBuf.scala b/scalding-thrift-macros/src/main/scala/com/twitter/scalding/thrift/macros/impl/ordered_serialization/ScroogeOrderedBuf.scala index f74218e6e1..2a979aa99b 100644 --- a/scalding-thrift-macros/src/main/scala/com/twitter/scalding/thrift/macros/impl/ordered_serialization/ScroogeOrderedBuf.scala +++ b/scalding-thrift-macros/src/main/scala/com/twitter/scalding/thrift/macros/impl/ordered_serialization/ScroogeOrderedBuf.scala @@ -61,11 +61,11 @@ object ScroogeOrderedBuf { outerType .declarations .collect { case m: MethodSymbol => m } - .filter(m => fieldNames.contains(m.name.toTermName.toString.toLowerCase)) + .filter(m => fieldNames.contains(m.name.toString.toLowerCase)) .map { accessorMethod => val fieldType = accessorMethod.returnType.asSeenFrom(outerType, outerType.typeSymbol.asClass) val b: TreeOrderedBuf[c.type] = dispatcher(fieldType) - (fieldType, accessorMethod.name.toTermName, b) + (fieldType, accessorMethod.name, b) }.toList new TreeOrderedBuf[c.type] { diff --git a/scalding-thrift-macros/src/main/scala/com/twitter/scalding/thrift/macros/impl/ordered_serialization/ScroogeUnionOrderedBuf.scala b/scalding-thrift-macros/src/main/scala/com/twitter/scalding/thrift/macros/impl/ordered_serialization/ScroogeUnionOrderedBuf.scala index 6399e35b3f..77520c0b0d 100644 --- a/scalding-thrift-macros/src/main/scala/com/twitter/scalding/thrift/macros/impl/ordered_serialization/ScroogeUnionOrderedBuf.scala +++ b/scalding-thrift-macros/src/main/scala/com/twitter/scalding/thrift/macros/impl/ordered_serialization/ScroogeUnionOrderedBuf.scala @@ -41,16 +41,16 @@ object ScroogeUnionOrderedBuf { val dispatcher = buildDispatcher val subClasses: List[Type] = StableKnownDirectSubclasses(c)(outerType).map(_.toType) - + val subData: List[(Int, Type, Option[TreeOrderedBuf[c.type]])] = subClasses.map { t => if (t.typeSymbol.name.toString == "UnknownUnionField") { (t, None) } else { (t, Some(dispatcher(t))) } - }.zipWithIndex.map{ case ((tpe, tbuf), idx) => (idx, tpe, tbuf) }.toList + }.zipWithIndex.map { case ((tpe, tbuf), idx) => (idx, tpe, tbuf) } - require(subData.size > 0, "Must have some sub types on a union?") + require(subData.nonEmpty, "Must have some sub types on a union?") new TreeOrderedBuf[c.type] { override val ctx: c.type = c diff --git a/scalding-thrift-macros/src/main/scala/com/twitter/scalding/thrift/macros/impl/ordered_serialization/UnionLike.scala b/scalding-thrift-macros/src/main/scala/com/twitter/scalding/thrift/macros/impl/ordered_serialization/UnionLike.scala index 55c0ac7a57..9882dd901c 100644 --- a/scalding-thrift-macros/src/main/scala/com/twitter/scalding/thrift/macros/impl/ordered_serialization/UnionLike.scala +++ b/scalding-thrift-macros/src/main/scala/com/twitter/scalding/thrift/macros/impl/ordered_serialization/UnionLike.scala @@ -146,7 +146,7 @@ object UnionLike { } // This `_.get` could be removed by switching `subData` to a non-empty list type - @SuppressWarnings(Array("org.wartremover.warts.OptionPartial")) + @SuppressWarnings(Array("org.wartremover.warts.OptionPartial", "org.wartremover.warts.Return")) def length(c: Context)(element: c.Tree)(subData: List[(Int, c.Type, Option[TreeOrderedBuf[c.type]])]): CompileTimeLengthTypes[c.type] = { import CompileTimeLengthTypes._ import c.universe._ diff --git a/tutorial/execution-tutorial/ExecutionTutorial.scala b/tutorial/execution-tutorial/ExecutionTutorial.scala index 45d676590b..d3fe2133c2 100644 --- a/tutorial/execution-tutorial/ExecutionTutorial.scala +++ b/tutorial/execution-tutorial/ExecutionTutorial.scala @@ -37,10 +37,10 @@ Run: **/ object MyExecJob extends ExecutionApp { - + override def job = Execution.getConfig.flatMap { config => val args = config.getArgs - + TypedPipe.from(TextLine(args("input"))) .flatMap(_.split("\\s+")) .map((_, 1L)) @@ -48,15 +48,14 @@ object MyExecJob extends ExecutionApp { .toIterableExecution // toIterableExecution will materialize the outputs to submitter node when finish. // We can also write the outputs on HDFS via .writeExecution(TypedTsv(args("output"))) - .onComplete { t => t match { - case Success(iter) => + .onComplete { + case Success(iter) => val file = new PrintWriter(new File(args("output"))) iter.foreach { case (k, v) => file.write(s"$k\t$v\n") } file.close case Failure(e) => println("Error: " + e.toString) - } } // use the result and map it to a Unit. Otherwise the onComplete call won't happen .unit