diff --git a/scalding-quotation/src/main/scala/com/twitter/scalding/quotation/Liftables.scala b/scalding-quotation/src/main/scala/com/twitter/scalding/quotation/Liftables.scala index 38f340426e..032b2c3458 100644 --- a/scalding-quotation/src/main/scala/com/twitter/scalding/quotation/Liftables.scala +++ b/scalding-quotation/src/main/scala/com/twitter/scalding/quotation/Liftables.scala @@ -2,6 +2,14 @@ package com.twitter.scalding.quotation import scala.reflect.macros.blackbox.Context +/** + * These Liftables allows us to lift values into quasiquote trees. + * For example: + * + * def test(v: Source) => q"$v" + * + * uses `sourceLiftable` + */ trait Liftables { val c: Context import c.universe.{ TypeName => _, _ } diff --git a/scalding-quotation/src/main/scala/com/twitter/scalding/quotation/Projection.scala b/scalding-quotation/src/main/scala/com/twitter/scalding/quotation/Projection.scala index 5737e7dd1f..fa39a391a4 100644 --- a/scalding-quotation/src/main/scala/com/twitter/scalding/quotation/Projection.scala +++ b/scalding-quotation/src/main/scala/com/twitter/scalding/quotation/Projection.scala @@ -8,6 +8,80 @@ case class TypeName(asString: String) extends AnyVal sealed trait Projection { def andThen(accessor: Accessor, typeName: TypeName): Projection = Property(this, accessor, typeName) + + def rootProjection: TypeReference = { + @tailrec def loop(p: Projection): TypeReference = + p match { + case p @ TypeReference(_) => p + case Property(p, _, _) => loop(p) + } + loop(this) + } + + /** + * Given a base projection, returns the projection based on it if applicable. + * + * For instance, given a quoted function + * `val contact = Quoted.function { (c: Contact) => c.contact }` + * and a call + * `(p: Person) => contact(p.name)` + * produces the projection + * `Person.name.contact` + */ + def basedOn(base: Projection): Option[Projection] = + this match { + case TypeReference(tpe) => + base match { + case TypeReference(`tpe`) => Some(base) + case Property(_, _, `tpe`) => Some(base) + case other => None + } + case Property(path, name, tpe) => + path.basedOn(base).map(Property(_, name, tpe)) + } + + /** + * Limits projections to only values of `superClass`. Example: + * + * case class Person(name: String, contact: Contact) extends ThriftObject + * case class Contact(phone: Phone) extends ThriftObject + * case class Phone(number: String) + * + * For the super class `ThriftObject`, it produces the transformations: + * + * Person.contact.phone => Some(Person.contact.phone) + * Person.contact.phone.number => Some(Person.contact.phone) + * Person.name.isEmpty => Some(Person.name) + * Phone.number => None + */ + def bySuperClass(superClass: Class[_]): Option[Projection] = { + + def isSubclass(c: TypeName) = + try + superClass.isAssignableFrom(Class.forName(c.asString)) + catch { + case _: ClassNotFoundException => + false + } + + def loop(p: Projection): Either[Projection, Option[Projection]] = + p match { + case TypeReference(typeName) => + Either.cond(!isSubclass(typeName), None, p) + case p @ Property(path, name, typeName) => + loop(path) match { + case Left(_) => + Either.cond(!isSubclass(typeName), Some(p), p) + case Right(path) => + Right(path) + } + } + + loop(this) match { + case Left(path) => Some(path) + case Right(opt) => opt + } + } } /** @@ -30,81 +104,21 @@ final case class Property(path: Projection, accessor: Accessor, typeName: TypeNa final class Projections private (val set: Set[Projection]) extends Serializable { /** - * Returns the projections that are based on `tpe` and limits projections + * Returns the projections that are based on `typeName` and limits projections * to only properties that extend from `superClass`. */ - def of(typeName: TypeName, superClass: Class[_]): Projections = { - - def byType(p: Projection) = { - @tailrec def loop(p: Projection): Boolean = - p match { - case TypeReference(`typeName`) => true - case TypeReference(_) => false - case Property(p, _, _) => loop(p) - } - loop(p) - } - - def bySuperClass(p: Projection): Option[Projection] = { - - def isSubclass(c: TypeName) = - try - superClass.isAssignableFrom(Class.forName(c.asString)) - catch { - case _: ClassNotFoundException => - false - } - - def loop(p: Projection): Either[Projection, Option[Projection]] = - p match { - case TypeReference(tpe) => - Either.cond(!isSubclass(tpe), None, p) - case p @ Property(path, name, tpe) => - loop(path) match { - case Left(_) => - Either.cond(!isSubclass(tpe), Some(p), p) - case Right(path) => - Right(path) - } - } - - loop(p) match { - case Left(path) => Some(path) - case Right(opt) => opt - } + def of(typeName: TypeName, superClass: Class[_]): Projections = + Projections { + set.filter(_.rootProjection.typeName == typeName) + .flatMap(_.bySuperClass(superClass)) } - Projections(set.filter(byType).flatMap(bySuperClass)) - } - - /** - * Given a set of base projections, returns the projections based on them. - * - * For instance, given a quoted function - * `val contact = Quoted.function { (c: Contact) => c.contact }` - * and a call - * `(p: Person) => contact(p.name)` - * returns the projection - * `Person.name.contact` - */ - def basedOn(base: Set[Projection]): Projections = { - def loop(base: Projection, p: Projection): Option[Projection] = - p match { - case TypeReference(tpe) => - base match { - case TypeReference(`tpe`) => Some(base) - case Property(_, _, `tpe`) => Some(base) - case other => None - } - case Property(path, name, tpe) => - loop(base, path).map(Property(_, name, tpe)) - } + def basedOn(base: Set[Projection]): Projections = Projections { set.flatMap { p => - base.flatMap(loop(_, p)) + base.flatMap(p.basedOn) } } - } def ++(p: Projections) = Projections(set ++ p.set) @@ -115,7 +129,7 @@ final class Projections private (val set: Set[Projection]) extends Serializable override def equals(other: Any) = other match { case other: Projections => set == other.set - case other => false + case other => false } override def hashCode = @@ -135,7 +149,7 @@ object Projections { p match { case Property(path, acessor, property) => set.contains(path) || isNested(path) - case _ => + case _ => false } new Projections(set.filter(!isNested(_))) diff --git a/scalding-quotation/src/main/scala/com/twitter/scalding/quotation/ProjectionMacro.scala b/scalding-quotation/src/main/scala/com/twitter/scalding/quotation/ProjectionMacro.scala index 9984007659..f4529ff2cb 100644 --- a/scalding-quotation/src/main/scala/com/twitter/scalding/quotation/ProjectionMacro.scala +++ b/scalding-quotation/src/main/scala/com/twitter/scalding/quotation/ProjectionMacro.scala @@ -26,76 +26,88 @@ trait ProjectionMacro extends TreeOps with Liftables { .contains("scala.Function") }.getOrElse(false) - val nestedList = - params.flatMap { - case param @ q"(..$inputs) => $body" => + def functionBodyProjections(param: Tree, inputs: List[Tree], body: Tree): List[Tree] = { - val inputSymbols = inputs.map(_.symbol).toSet + val inputSymbols = inputs.map(_.symbol).toSet - object ProjectionExtractor { - def unapply(t: Tree): Option[Tree] = - t match { + object ProjectionExtractor { + def unapply(t: Tree): Option[Tree] = + t match { - case q"$v.$m(..$params)" => unapply(v) + case q"$v.$m(..$params)" => unapply(v) - case q"$v.$m" if t.symbol.isMethod => + case q"$v.$m" if t.symbol.isMethod => - if (inputSymbols.contains(v.symbol)) { - val p = - TypeReference(typeName(v)) - .andThen(accessor(m), typeName(t)) - Some(q"$p") - } else - unapply(v).map { n => - q"$n.andThen(${accessor(m)}, ${typeName(t)})" - } + if (inputSymbols.contains(v.symbol)) { + val p = + TypeReference(typeName(v)) + .andThen(accessor(m), typeName(t)) + Some(q"$p") + } else + unapply(v).map { n => + q"$n.andThen(${accessor(m)}, ${typeName(t)})" + } - case t if inputSymbols.contains(t.symbol) => - Some(q"${TypeReference(typeName(t))}") + case t if inputSymbols.contains(t.symbol) => + Some(q"${TypeReference(typeName(t))}") - case _ => None - } + case _ => None } + } - def functionCall(func: Tree, params: List[Tree]) = { - val paramProjections = params.flatMap(ProjectionExtractor.unapply) - q""" - $func match { - case f: _root_.com.twitter.scalding.quotation.QuotedFunction => - f.quoted.projections.basedOn($paramProjections.toSet) - case _ => - _root_.com.twitter.scalding.quotation.Projections(Set(..$paramProjections)) - } - """ + def functionCall(func: Tree, params: List[Tree]): Tree = { + val paramProjections = params.flatMap(ProjectionExtractor.unapply) + q""" + $func match { + case f: _root_.com.twitter.scalding.quotation.QuotedFunction => + f.quoted.projections.basedOn($paramProjections.toSet) + case _ => + _root_.com.twitter.scalding.quotation.Projections(Set(..$paramProjections)) } + """ + } - collect(body) { - case q"$func.apply[..$t](..$params)" => - functionCall(func, params) - case q"$func(..$params)" if isFunction(func) => - functionCall(func, params) - case t @ ProjectionExtractor(p) => - q"_root_.com.twitter.scalding.quotation.Projections(Set($p))" - } + collect(body) { + case q"$func.apply[..$t](..$params)" => + functionCall(func, params) + case q"$func(..$params)" if isFunction(func) => + functionCall(func, params) + case t @ ProjectionExtractor(p) => + q"_root_.com.twitter.scalding.quotation.Projections(Set($p))" + } + } + + def functionInstanceProjections(func: Tree): List[Tree] = { + val paramProjections = + func.symbol.typeSignature.typeArgs.dropRight(1) + .map(typeReference) + q""" + $func match { + case f: _root_.com.twitter.scalding.quotation.QuotedFunction => + f.quoted.projections + case _ => + _root_.com.twitter.scalding.quotation.Projections(Set(..$paramProjections)) + } + """ :: Nil + } + + def methodProjections(method: Tree): List[Tree] = { + val paramRefs = + method.symbol.asMethod.paramLists.flatten + .map(param => typeReference(param.typeSignature)) + q"${Projections(paramRefs.toSet)}" :: Nil + } + + val nestedList = + params.flatMap { + case param @ q"(..$inputs) => $body" => + functionBodyProjections(param, inputs, body) case func if isFunction(func) => - val paramProjections = - func.symbol.typeSignature.typeArgs.dropRight(1) - .map(typeReference) - q""" - $func match { - case f: _root_.com.twitter.scalding.quotation.QuotedFunction => - f.quoted.projections - case _ => - _root_.com.twitter.scalding.quotation.Projections(Set(..$paramProjections)) - } - """ :: Nil + functionInstanceProjections(func) case method if method.symbol != null && method.symbol.isMethod => - val paramRefs = - method.symbol.asMethod.paramLists.flatten - .map(param => typeReference(param.typeSignature)) - q"${Projections(paramRefs.toSet)}" :: Nil + methodProjections(method) case other => Nil