Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions src/expression/ekat_expression_base.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#ifndef EKAT_EXPRESSION_BASE_HPP
#define EKAT_EXPRESSION_BASE_HPP

#include "ekat_expression_traits.hpp"

#include <Kokkos_Core.hpp>

namespace ekat {

// A base class for expressions
template<typename Derived>
class ExpressionBase {
public:
using expression_tag = void; // Add tag to be used for SFINAE and meta-utils

ExpressionBase () {
static_assert(is_expr_v<Derived>,
"Template arg is NOT an expression. Ensure Derived inherits from ExpressionBase.");
}

Derived& cast () { return *static_cast<Derived*>(this); }
const Derived& cast () const { return *static_cast<const Derived*>(this); }

ExpressionBase<Derived>& as_base () { return *this; }
const ExpressionBase<Derived>& as_base () const { return *this; }

static constexpr int rank() { return Derived::rank(); }

int extent (int i) const { return cast().extent(i); }

template<typename... Args>
KOKKOS_INLINE_FUNCTION
auto eval (Args... args) const
{
return cast().eval(args...);
}
};

} // namespace ekat

#endif // EKAT_EXPRESSION_BASE_HPP
86 changes: 64 additions & 22 deletions src/expression/ekat_expression_binary_op.hpp
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
#ifndef EKAT_EXPRESSION_BINARY_OP_HPP
#define EKAT_EXPRESSION_BINARY_OP_HPP

#include "ekat_expression_meta.hpp"

#include <Kokkos_Core.hpp>
#include "ekat_expression_base.hpp"

namespace ekat {

Expand All @@ -15,21 +13,19 @@ enum class BinOp {
};

template<typename ELeft, typename ERight, BinOp OP>
class BinaryExpression {
class BinaryExpression : public ExpressionBase<BinaryExpression<ELeft,ERight,OP>> {
public:
static constexpr bool expr_l = is_expr_v<ELeft>;
static constexpr bool expr_r = is_expr_v<ERight>;
static_assert (expr_l or expr_r, "[BinaryExpression] Error! At least one operand must be an Expression type.\n");

Comment thread
bartgol marked this conversation as resolved.
using eval_left_t = eval_return_t<ELeft>;
using eval_right_t = eval_return_t<ERight>;

using eval_t = std::common_type_t<eval_left_t,eval_right_t>;
using return_left_t = eval_return_t<ELeft>;
using return_right_t = eval_return_t<ERight>;

// Don't create an expression from builtin types, just combine them!
static_assert (expr_l or expr_r,
"[BinaryExpression] At least one between ELeft and ERight must be an Expression type.\n");
using return_type = std::common_type_t<return_left_t,return_right_t>;

BinaryExpression (const ELeft& left, const ERight& right)
BinaryExpression (const ELeft& left,
const ERight& right)
: m_left(left)
, m_right(right)
{
Expand All @@ -56,7 +52,7 @@ class BinaryExpression {

template<typename... Args>
KOKKOS_INLINE_FUNCTION
eval_t eval(Args... args) const {
return_type eval(Args... args) const {
if constexpr (not expr_l) {
return eval_impl(m_left,m_right.eval(args...));
} else if constexpr (not expr_r) {
Expand All @@ -69,31 +65,77 @@ class BinaryExpression {
protected:

KOKKOS_INLINE_FUNCTION
eval_t eval_impl (const eval_left_t& l, const eval_right_t& r) const {
return_type eval_impl (const return_left_t& l, const return_right_t& r) const {
if constexpr (OP==BinOp::Plus) {
return l+r;
} else if constexpr (OP==BinOp::Minus) {
return l-r;
} else if constexpr (OP==BinOp::Mult) {
return l*r;
} else if constexpr (OP==BinOp::Div) {
} else {
return l/r;
return Kokkos::min(static_cast<const eval_t&>(l),static_cast<const eval_t&>(r));
}
}

ELeft m_left;
ERight m_right;
};

// Specialize meta utils
template<typename ELeft, typename ERight, BinOp OP>
struct is_expr<BinaryExpression<ELeft,ERight,OP>> : std::true_type {};
template<typename ELeft, typename ERight, BinOp OP>
struct eval_return<BinaryExpression<ELeft,ERight,OP>> {
using type = typename BinaryExpression<ELeft,ERight,OP>::eval_t;
// We could impl op- via BinaryOp (with -1*Expr), but a dedicated class is easier
template<typename EInner>
class NegateExpression : public ExpressionBase<NegateExpression<EInner>> {
public:

using return_type = eval_return_t<EInner>;

NegateExpression (const ExpressionBase<EInner>& inner)
: m_inner(inner.cast())
{
// Nothing to do here
}

static constexpr int rank () { return EInner::rank(); }
int extent (int i) const { return m_inner.extent(i); }

template<typename... Args>
KOKKOS_INLINE_FUNCTION
return_type eval(Args... args) const {
return -m_inner.eval(args...);
}

protected:
EInner m_inner;
};

// Unary minus implemented as -1*expr
template<typename ERight>
KOKKOS_INLINE_FUNCTION
auto operator- (const ExpressionBase<ERight>& r)
{
return NegateExpression(r);
}

// Overload arithmetic operators
#define EKAT_GEN_BIN_OP_EXPR(OP,ENUM) \
template<typename T1, typename T2, \
typename = std::enable_if_t<is_any_expr_v<T1,T2>>> \
KOKKOS_INLINE_FUNCTION \
auto operator OP (const T1& l, const T2& r) \
{ \
using ret_t = BinaryExpression<get_expr_node_t<T1>, \
get_expr_node_t<T2>, \
BinOp::ENUM>; \
\
return ret_t(get_expr_node(l),get_expr_node(r)); \
}

EKAT_GEN_BIN_OP_EXPR(+,Plus);
EKAT_GEN_BIN_OP_EXPR(-,Minus);
EKAT_GEN_BIN_OP_EXPR(*,Mult);
EKAT_GEN_BIN_OP_EXPR(/,Div);

#undef EKAT_GEN_BIN_OP_EXPR

} // namespace ekat

#endif // EKAT_EXPRESSION_BINARY_OP_HPP
132 changes: 132 additions & 0 deletions src/expression/ekat_expression_binary_predicate.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
#ifndef EKAT_EXPRESSION_BINARY_PREDICATE_HPP
#define EKAT_EXPRESSION_BINARY_PREDICATE_HPP

#include "ekat_expression_base.hpp"

namespace ekat {

enum class BinaryPredicateOp : int {
EQ, // ==
NE, // !=
GT, // >
GE, // >=
LT, // <
LE, // <=
AND, // logical and
OR // logical or
};

template<typename ELeft, typename ERight, BinaryPredicateOp Op>
class BinaryPredicateExpression : public ExpressionBase<BinaryPredicateExpression<ELeft,ERight,Op>> {
public:
Comment thread
bartgol marked this conversation as resolved.
static constexpr bool expr_l = is_expr_v<ELeft>;
static constexpr bool expr_r = is_expr_v<ERight>;

using return_left_t = eval_return_t<ELeft>;
using return_right_t = eval_return_t<ERight>;
// The return type is a logical-like, but same for all Op's, so just use one
using return_type = decltype(std::declval<return_left_t>()==std::declval<return_right_t>());

// Don't create an expression from builtin types, just compare them!
static_assert(expr_l or expr_r,
"[BinaryPredicateExpression] At least one between ELeft and ERight must be an Expression type.\n");

BinaryPredicateExpression (const ELeft& left,
const ERight& right)
: m_left(left)
, m_right(right)
{
// Nothing to do here
}

static constexpr int rank() {
if constexpr (expr_l) {
if constexpr (expr_r) {
static_assert(ELeft::rank()==ERight::rank(),
"[BinaryPredicateExpression] Error! ELeft and ERight are Expression types of different rank.\n");
}
return ELeft::rank();
} else if constexpr (expr_r) {
return ERight::rank();
} else {
return 0;
}
}

int extent (int i) const {
if constexpr (expr_l)
return m_left.extent(i);
else
return m_right.extent(i);
}

template<typename... Args>
KOKKOS_INLINE_FUNCTION
return_type eval(Args... args) const {
if constexpr (expr_l) {
if constexpr (expr_r)
return eval_impl(m_left.eval(args...), m_right.eval(args...));
else
return eval_impl(m_left.eval(args...), m_right);
} else if constexpr (expr_r) {
return eval_impl(m_left, m_right.eval(args...));
} else {
return eval_impl(m_left, m_right);
}
}

protected:

template<typename... Args>
KOKKOS_INLINE_FUNCTION
return_type eval_impl(const return_left_t& l, const return_right_t& r) const {
if constexpr (Op==BinaryPredicateOp::EQ)
return l==r;
else if constexpr (Op==BinaryPredicateOp::NE)
return l!=r;
else if constexpr (Op==BinaryPredicateOp::GT)
return l>r;
else if constexpr (Op==BinaryPredicateOp::GE)
return l>=r;
else if constexpr (Op==BinaryPredicateOp::LT)
return l<r;
else if constexpr (Op==BinaryPredicateOp::LE)
return l<=r;
else if constexpr (Op==BinaryPredicateOp::AND)
return l and r;
else
return l or r;
}

ELeft m_left;
ERight m_right;
};

// Overload comparison operators
#define EKAT_GEN_BIN_PREDICATE_EXPR(OP,ENUM) \
template<typename T1, typename T2, \
typename = std::enable_if_t<is_any_expr_v<T1,T2>>> \
KOKKOS_INLINE_FUNCTION \
auto operator OP (const T1& l, const T2& r) \
{ \
using ret_t = BinaryPredicateExpression<get_expr_node_t<T1>, \
get_expr_node_t<T2>, \
BinaryPredicateOp::ENUM>; \
\
return ret_t(get_expr_node(l),get_expr_node(r)); \
}

EKAT_GEN_BIN_PREDICATE_EXPR(==,EQ);
EKAT_GEN_BIN_PREDICATE_EXPR(!=,NE);
EKAT_GEN_BIN_PREDICATE_EXPR(> ,GT);
EKAT_GEN_BIN_PREDICATE_EXPR(>=,GE);
EKAT_GEN_BIN_PREDICATE_EXPR(< ,LT);
EKAT_GEN_BIN_PREDICATE_EXPR(<=,LE);
EKAT_GEN_BIN_PREDICATE_EXPR(&&,AND);
EKAT_GEN_BIN_PREDICATE_EXPR(||,OR);

#undef EKAT_GEN_BIN_PREDICATE_EXPR

} // namespace ekat

#endif // EKAT_EXPRESSION_BINARY_PREDICATE_HPP
Loading
Loading