From 41dcd5410f8b4816175469c770a9bc2eec1c0de4 Mon Sep 17 00:00:00 2001 From: Luca Bertagna Date: Fri, 1 May 2026 17:34:25 -0600 Subject: [PATCH 1/6] Add a base class for expressions --- src/expression/ekat_expression_base.hpp | 41 +++++++++++++++ src/expression/ekat_expression_traits.hpp | 64 +++++++++++++++++++++++ 2 files changed, 105 insertions(+) create mode 100644 src/expression/ekat_expression_base.hpp create mode 100644 src/expression/ekat_expression_traits.hpp diff --git a/src/expression/ekat_expression_base.hpp b/src/expression/ekat_expression_base.hpp new file mode 100644 index 00000000..8d7beaa0 --- /dev/null +++ b/src/expression/ekat_expression_base.hpp @@ -0,0 +1,41 @@ +#ifndef EKAT_EXPRESSION_BASE_HPP +#define EKAT_EXPRESSION_BASE_HPP + +#include "ekat_expression_traits.hpp" + +#include + +namespace ekat { + +// A base class for expressions +template +class ExpressionBase { +public: + using expression_tag = void; // Add tag to be used for SFINAE and meta-utils + + ExpressionBase () { + static_assert(is_expr_v, + "Template arg is NOT an expression. Ensure Derived inherits from ExpressionBase."); + } + + Derived& cast () { return *static_cast(this); } + const Derived& cast () const { return *static_cast(this); } + + ExpressionBase& as_base () { return *this; } + const ExpressionBase& as_base () const { return *this; } + + static constexpr int rank() { return Derived::rank(); } + + int extent (int i) const { return cast().extent(i); } + + template + KOKKOS_INLINE_FUNCTION + auto eval (Args... args) const + { + return cast().eval(args...); + } +}; + +} // namespace ekat + +#endif // EKAT_EXPRESSION_BASE_HPP diff --git a/src/expression/ekat_expression_traits.hpp b/src/expression/ekat_expression_traits.hpp new file mode 100644 index 00000000..e89eb784 --- /dev/null +++ b/src/expression/ekat_expression_traits.hpp @@ -0,0 +1,64 @@ +#ifndef EKAT_EXPRESSION_TRAITS_HPP +#define EKAT_EXPRESSION_TRAITS_HPP + +#include + +namespace ekat { + +// TODO: C++20 comes with std::type_identity +template +struct identity { using type = T; }; + +// Meta-utilities for expressions +template +struct is_expr : std::false_type {}; + +template +struct is_expr> : std::true_type {}; + +template +inline constexpr bool is_expr_v = is_expr::value; + +template +inline constexpr bool is_any_expr_v = (is_expr_v || ...); + +// Primary template: Handle raw scalars (non-expressions) +template +struct eval_return : identity {}; + +// Specialization: For when IsExpr is 'true' +template +struct eval_return { + // Hard contract: T must have return_type or this will fail loudly + using type = typename T::return_type; +}; + +template +using eval_return_t = typename eval_return>::type; + +// If type is an expression, extract inner type, otherwise return input type +template +struct get_expr_node_trait : identity {}; + +template class Base, typename D> +struct get_expr_node_trait, true> : identity {}; + +template +using get_expr_node_t = typename get_expr_node_trait::type; + +// Sometimes we need to access the concrete derived expression, +// as they are the ONLY types that define return_type (we can't +// access derived's type from the base class, even in CRTP. +template +auto get_expr_node (const T& t) { +if constexpr (is_expr_v) { + return static_cast&>(t); + } else { + return t; + } +} + + +} // namespace ekat + +#endif // EKAT_EXPRESSION_TRAITS_HPP From 651ba98ff5d5d35af11de26b0c9b811251507d61 Mon Sep 17 00:00:00 2001 From: Luca Bertagna Date: Fri, 1 May 2026 17:34:58 -0600 Subject: [PATCH 2/6] Make existing expressions inherit from the base class --- src/expression/ekat_expression_binary_op.hpp | 86 +++++-- src/expression/ekat_expression_compare.hpp | 55 +++-- .../ekat_expression_conditional.hpp | 41 ++-- src/expression/ekat_expression_eval.hpp | 13 +- src/expression/ekat_expression_helpers.hpp | 109 --------- src/expression/ekat_expression_math.hpp | 210 ++++++++---------- src/expression/ekat_expression_meta.hpp | 19 -- src/expression/ekat_expression_view.hpp | 23 +- tests/expression/expressions.cpp | 46 +++- 9 files changed, 272 insertions(+), 330 deletions(-) delete mode 100644 src/expression/ekat_expression_helpers.hpp delete mode 100644 src/expression/ekat_expression_meta.hpp diff --git a/src/expression/ekat_expression_binary_op.hpp b/src/expression/ekat_expression_binary_op.hpp index 2f1f311f..ffc51bc0 100644 --- a/src/expression/ekat_expression_binary_op.hpp +++ b/src/expression/ekat_expression_binary_op.hpp @@ -1,9 +1,7 @@ #ifndef EKAT_EXPRESSION_BINARY_OP_HPP #define EKAT_EXPRESSION_BINARY_OP_HPP -#include "ekat_expression_meta.hpp" - -#include +#include "ekat_expression_base.hpp" namespace ekat { @@ -15,21 +13,19 @@ enum class BinOp { }; template -class BinaryExpression { +class BinaryExpression : public ExpressionBase> { public: static constexpr bool expr_l = is_expr_v; static constexpr bool expr_r = is_expr_v; + static_assert (expr_l or expr_r, "[BinaryExpression] Error! At least one operand must be an Expression type.\n"); - using eval_left_t = eval_return_t; - using eval_right_t = eval_return_t; - - using eval_t = std::common_type_t; + using return_left_t = eval_return_t; + using return_right_t = eval_return_t; - // Don't create an expression from builtin types, just combine them! - static_assert (expr_l or expr_r, - "[BinaryExpression] At least one between ELeft and ERight must be an Expression type.\n"); + using return_type = std::common_type_t; - BinaryExpression (const ELeft& left, const ERight& right) + BinaryExpression (const ELeft& left, + const ERight& right) : m_left(left) , m_right(right) { @@ -56,7 +52,7 @@ class BinaryExpression { template KOKKOS_INLINE_FUNCTION - eval_t eval(Args... args) const { + return_type eval(Args... args) const { if constexpr (not expr_l) { return eval_impl(m_left,m_right.eval(args...)); } else if constexpr (not expr_r) { @@ -69,16 +65,15 @@ class BinaryExpression { protected: KOKKOS_INLINE_FUNCTION - eval_t eval_impl (const eval_left_t& l, const eval_right_t& r) const { + return_type eval_impl (const return_left_t& l, const return_right_t& r) const { if constexpr (OP==BinOp::Plus) { return l+r; } else if constexpr (OP==BinOp::Minus) { return l-r; } else if constexpr (OP==BinOp::Mult) { return l*r; - } else if constexpr (OP==BinOp::Div) { + } else { return l/r; - return Kokkos::min(static_cast(l),static_cast(r)); } } @@ -86,14 +81,61 @@ class BinaryExpression { ERight m_right; }; -// Specialize meta utils -template -struct is_expr> : std::true_type {}; -template -struct eval_return> { - using type = typename BinaryExpression::eval_t; +// We could impl op- via BinaryOp (with -1*Expr), but a dedicated class is easier +template +class NegateExpression : public ExpressionBase> { +public: + + using return_type = eval_return_t; + + NegateExpression (const ExpressionBase& inner) + : m_inner(inner.cast()) + { + // Nothing to do here + } + + static constexpr int rank () { return EInner::rank(); } + int extent (int i) const { return m_inner.extent(i); } + + template + KOKKOS_INLINE_FUNCTION + return_type eval(Args... args) const { + return -m_inner.eval(args...); + } + +protected: + EInner m_inner; }; +// Unary minus implemented as -1*expr +template +KOKKOS_INLINE_FUNCTION +auto operator- (const ExpressionBase& r) +{ + return NegateExpression(r); +} + +// Overload arithmetic operators +#define EKAT_GEN_BIN_OP_EXPR(OP,ENUM) \ + template>> \ + KOKKOS_INLINE_FUNCTION \ + auto operator OP (const T1& l, const T2& r) \ + { \ + using ret_t = BinaryExpression, \ + get_expr_node_t, \ + BinOp::ENUM>; \ + \ + return ret_t(get_expr_node(l),get_expr_node(r)); \ + } + +EKAT_GEN_BIN_OP_EXPR(+,Plus); +EKAT_GEN_BIN_OP_EXPR(-,Minus); +EKAT_GEN_BIN_OP_EXPR(*,Mult); +EKAT_GEN_BIN_OP_EXPR(/,Div); + +#undef EKAT_GEN_BIN_OP_EXPR + } // namespace ekat #endif // EKAT_EXPRESSION_BINARY_OP_HPP diff --git a/src/expression/ekat_expression_compare.hpp b/src/expression/ekat_expression_compare.hpp index 2a833844..e27d7784 100644 --- a/src/expression/ekat_expression_compare.hpp +++ b/src/expression/ekat_expression_compare.hpp @@ -1,14 +1,12 @@ #ifndef EKAT_EXPRESSION_COMPARE_HPP #define EKAT_EXPRESSION_COMPARE_HPP -#include "ekat_expression_meta.hpp" +#include "ekat_expression_base.hpp" #include "ekat_std_utils.hpp" #include "ekat_kernel_assert.hpp" #include "ekat_assert.hpp" -#include - namespace ekat { enum class Comparison : int { @@ -21,20 +19,22 @@ enum class Comparison : int { }; template -class CmpExpression { +class CmpExpression : public ExpressionBase> { public: static constexpr bool expr_l = is_expr_v; static constexpr bool expr_r = is_expr_v; - using eval_left_t = eval_return_t; - using eval_right_t = eval_return_t; - using eval_t = decltype(std::declval()==std::declval()); + using return_left_t = eval_return_t; + using return_right_t = eval_return_t; + using return_type = decltype(std::declval()==std::declval()); // Don't create an expression from builtin types, just compare them! static_assert(expr_l or expr_r, "[CmpExpression] At least one between ELeft and ERight must be an Expression type.\n"); - CmpExpression (const ELeft& left, const ERight& right, Comparison CMP) + CmpExpression (const ELeft& left, + const ERight& right, + Comparison CMP) : m_left(left) , m_right(right) , m_cmp(CMP) @@ -48,7 +48,7 @@ class CmpExpression { static constexpr int rank() { if constexpr (expr_l) { if constexpr (expr_r) { - static_assert(ELeft::rank()==ERight::rank(), + static_assert(ELeft::rank()==ERight::rank() or ELeft::rank()==0 or ERight::rank()==0, "[CmpExpression] Error! ELeft and ERight are Expression types of different rank.\n"); } return ELeft::rank(); @@ -68,13 +68,12 @@ class CmpExpression { template KOKKOS_INLINE_FUNCTION - eval_t eval(Args... args) const { + return_type eval(Args... args) const { if constexpr (expr_l) { - if constexpr (expr_r) { + if constexpr (expr_r) return eval_impl(m_left.eval(args...), m_right.eval(args...)); - } else { + else return eval_impl(m_left.eval(args...), m_right); - } } else if constexpr (expr_r) { return eval_impl(m_left, m_right.eval(args...)); } else { @@ -86,7 +85,7 @@ class CmpExpression { template KOKKOS_INLINE_FUNCTION - eval_t eval_impl(const eval_left_t& l, const eval_right_t& r) const { + return_type eval_impl(const return_left_t& l, const return_right_t& r) const { switch (m_cmp) { case Comparison::EQ: return l==r; case Comparison::NE: return l!=r; @@ -105,13 +104,27 @@ class CmpExpression { Comparison m_cmp; }; -// Specialize meta utils -template -struct is_expr> : std::true_type {}; -template -struct eval_return> { - using type = typename CmpExpression::eval_t; -}; +// Overload comparison operators +#define EKAT_GEN_CMP_OP_EXPR(OP,ENUM) \ + template>> \ + KOKKOS_INLINE_FUNCTION \ + auto operator OP (const T1& l, const T2& r) \ + { \ + using ret_t = CmpExpression, \ + get_expr_node_t>; \ + \ + return ret_t(get_expr_node(l),get_expr_node(r),Comparison::ENUM); \ + } + +EKAT_GEN_CMP_OP_EXPR(==,EQ); +EKAT_GEN_CMP_OP_EXPR(!=,NE); +EKAT_GEN_CMP_OP_EXPR(> ,GT); +EKAT_GEN_CMP_OP_EXPR(>=,GE); +EKAT_GEN_CMP_OP_EXPR(< ,LT); +EKAT_GEN_CMP_OP_EXPR(<=,LE); + +#undef EKAT_GEN_CMP_OP_EXPR } // namespace ekat diff --git a/src/expression/ekat_expression_conditional.hpp b/src/expression/ekat_expression_conditional.hpp index 05f9fc3c..076adb43 100644 --- a/src/expression/ekat_expression_conditional.hpp +++ b/src/expression/ekat_expression_conditional.hpp @@ -1,28 +1,26 @@ #ifndef EKAT_EXPRESSION_CONDITIONAL_HPP #define EKAT_EXPRESSION_CONDITIONAL_HPP -#include "ekat_expression_meta.hpp" - -#include +#include "ekat_expression_base.hpp" namespace ekat { template -class ConditionalExpression { +class ConditionalExpression : public ExpressionBase> { public: static constexpr bool expr_c = is_expr_v; static constexpr bool expr_l = is_expr_v; static constexpr bool expr_r = is_expr_v; - using eval_cond_t = eval_return_t; - using eval_left_t = eval_return_t; - using eval_right_t = eval_return_t; + using return_cond_t = eval_return_t; + using return_left_t = eval_return_t; + using return_right_t = eval_return_t; - using eval_t = std::common_type_t; + using return_type = std::common_type_t; // Don't create an expression from builtin types, just use a ternary op! static_assert(expr_c or expr_l or expr_r, - "[CmpExpression] At least one between ECond, ELeft, and ERight must be an Expression type.\n"); + "[ConditionalExpression] At least one between ECond, ELeft, and ERight must be an Expression type.\n"); ConditionalExpression (const ECond& cmp, const ELeft& left, const ERight& right) : m_cmp(cmp) @@ -64,7 +62,7 @@ class ConditionalExpression { template KOKKOS_INLINE_FUNCTION - eval_t eval (Args... args) const + return_type eval (Args... args) const { if constexpr (expr_c) { if (m_cmp.eval(args...)) @@ -78,16 +76,17 @@ class ConditionalExpression { else return m_right; } else { - if (m_cmp) + if (m_cmp) { if constexpr (expr_l) return m_left.eval(args...); else return m_left; - else + } else { if constexpr (expr_r) return m_right.eval(args...); else return m_right; + } } } @@ -98,13 +97,17 @@ class ConditionalExpression { ERight m_right; }; -// Specialize meta utils -template -struct is_expr> : std::true_type {}; -template -struct eval_return> { - using type = typename ConditionalExpression::eval_t; -}; +// Free fcn to construct a ConditionalExpression +template>> +auto if_then_else(const TC& c, const T1& l, const T2& r) +{ + using ret_t = ConditionalExpression, + get_expr_node_t, + get_expr_node_t>; + + return ret_t(get_expr_node(c),get_expr_node(l),get_expr_node(r)); +} } // namespace ekat diff --git a/src/expression/ekat_expression_eval.hpp b/src/expression/ekat_expression_eval.hpp index 77adee75..6a4ca6b5 100644 --- a/src/expression/ekat_expression_eval.hpp +++ b/src/expression/ekat_expression_eval.hpp @@ -1,23 +1,23 @@ #ifndef EKAT_EXPRESSION_EVAL_HPP #define EKAT_EXPRESSION_EVAL_HPP -#include "ekat_expression_meta.hpp" +#include "ekat_expression_base.hpp" #include "ekat_assert.hpp" #include namespace ekat { -template -std::enable_if_t> -evaluate (const Expression& e, const ViewT& result) +template +void evaluate (const ExpressionBase& base, const ViewT& result) { + using expr_t = ExpressionBase; constexpr int N = ViewT::rank; - EKAT_REQUIRE_MSG (N==Expression::rank(), + EKAT_REQUIRE_MSG (N==expr_t::rank(), "[evaluate] Error! Input expression and result view have different ranks.\n" " - view rank: " + std::to_string(N) + "\n" - " - expression rank: " + std::to_string(Expression::rank()) + "\n"); + " - expression rank: " + std::to_string(expr_t::rank()) + "\n"); // Kokkos views don't go higher than rank 8, but just in case... static_assert(N<=8, "[evaluate] Unsupported expression rank.\n"); @@ -29,6 +29,7 @@ evaluate (const Expression& e, const ViewT& result) // Ensure the beg/end array size is > 0. While compilers may allow size-0 arrays as an extension, // it is not standard compliant. For N=0, we won't use these anyways... + const auto& e = base.cast(); int beg[N==0 ? 1 : N] = {}; int end[N==0 ? 1 : N] = {}; for (int i=0; i -std::enable_if_t,BinaryExpression> -operator- (const ERight& r) -{ - return BinaryExpression(-1,r); -} - -// Overload arithmetic operators -template -std::enable_if_t or is_expr_v,BinaryExpression> -operator+ (const ELeft& l, const ERight& r) -{ - return BinaryExpression(l,r); -} - -template -std::enable_if_t or is_expr_v,BinaryExpression> -operator- (const ELeft& l, const ERight& r) -{ - return BinaryExpression(l,r); -} - -template -std::enable_if_t or is_expr_v,BinaryExpression> -operator* (const ELeft& l, const ERight& r) -{ - return BinaryExpression(l,r); -} - -template -std::enable_if_t or is_expr_v,BinaryExpression> -operator/ (const ELeft& l, const ERight& r) -{ - return BinaryExpression(l,r); -} - -// Overload cmp operators for Expression types -template -std::enable_if_t or is_expr_v,CmpExpression> -operator== (const ELeft& l, const ERight& r) -{ - return CmpExpression(l,r,Comparison::EQ); -} - -template -std::enable_if_t or is_expr_v,CmpExpression> -operator!= (const ELeft& l, const ERight& r) -{ - return CmpExpression(l,r,Comparison::NE); -} - -template -std::enable_if_t or is_expr_v,CmpExpression> -operator> (const ELeft& l, const ERight& r) -{ - return CmpExpression(l,r,Comparison::GT); -} - -template -std::enable_if_t or is_expr_v,CmpExpression> -operator>= (const ELeft& l, const ERight& r) -{ - return CmpExpression(l,r,Comparison::GE); -} - -template -std::enable_if_t or is_expr_v,CmpExpression> -operator< (const ELeft& l, const ERight& r) -{ - return CmpExpression(l,r,Comparison::LT); -} - -template -std::enable_if_t or is_expr_v,CmpExpression> -operator<= (const ELeft& l, const ERight& r) -{ - return CmpExpression(l,r,Comparison::LE); -} - -// Free fcn to construct a ConditionalExpression -template -std::enable_if_t or is_expr_v or is_expr_v,ConditionalExpression> -conditional(const ECond& c, const ELeft& l, const ERight& r) -{ - return ConditionalExpression(c,l,r); -} - -// Free fcn to construct a ViewExpression -template -std::enable_if_t,ViewExpression> -view_expression(const ViewT& v) -{ - return ViewExpression(v); -} - -} // namespace ekat - -#endif // EKAT_EXPRESSION_HELPERS_HPP diff --git a/src/expression/ekat_expression_math.hpp b/src/expression/ekat_expression_math.hpp index ddb29c2a..6280a27a 100644 --- a/src/expression/ekat_expression_math.hpp +++ b/src/expression/ekat_expression_math.hpp @@ -1,88 +1,78 @@ #ifndef EKAT_EXPRESSION_MATH_HPP #define EKAT_EXPRESSION_MATH_HPP -#include "ekat_expression_meta.hpp" - -#include +#include "ekat_expression_base.hpp" namespace ekat { // ----------------- Binary math fcns ------------------- // #define EKAT_BINARY_MATH_EXPRESSION(impl,name) \ - template \ - class name##Expression { \ - public: \ - static constexpr bool expr_l = is_expr_v; \ - static constexpr bool expr_r = is_expr_v; \ - \ - /* Don't create an expression from builtin types, just call the math fcn! */ \ - static_assert (expr_l or expr_r, \ - "At least one between EArg1 and EArg2 must be an Expression type.\n"); \ - \ - using eval_arg1_t = eval_return_t; \ - using eval_arg2_t = eval_return_t; \ - using eval_t = std::common_type_t; \ - \ - name##Expression (const EArg1& arg1, const EArg2& arg2) \ - : m_arg1(arg1) \ - , m_arg2(arg2) \ - {} \ - \ - static constexpr int rank () { \ - if constexpr (expr_l) { \ - if constexpr (expr_r) { \ - static_assert(EArg1::rank()==EArg2::rank(), \ - "[BinaryExpression] Error! EArg1 and EArg2 have different rank.\n"); \ - } \ - return EArg1::rank(); \ - } else { \ - return EArg2::rank(); \ - } \ - } \ - int extent (int i) const { \ - if constexpr (expr_l) \ - return m_arg1.extent(i); \ - else \ - return m_arg2.extent(i); \ - } \ - \ - template \ - KOKKOS_INLINE_FUNCTION \ - eval_t eval(Args... args) const { \ - if constexpr (not expr_l) \ - return eval_impl(m_arg1,m_arg2.eval(args...)); \ - else if constexpr (not expr_r) \ - return eval_impl(m_arg1.eval(args...),m_arg2); \ - else \ - return eval_impl(m_arg1.eval(args...),m_arg2.eval(args...)); \ - } \ - protected: \ - KOKKOS_INLINE_FUNCTION \ - eval_t eval_impl(const eval_arg1_t& arg1, const eval_arg2_t& arg2) const { \ - return Kokkos::impl(static_cast(arg1), \ - static_cast(arg2)); \ - } \ - \ - EArg1 m_arg1; \ - EArg2 m_arg2; \ - }; \ - \ - /* Free function to create a ##nameExpression */ \ - template \ - std::enable_if_t or is_expr_v,name##Expression> \ - impl (const EArg1& arg1, const EArg2& arg2) \ - { \ - return name##Expression(arg1,arg2); \ - } \ - \ - /* Specialize meta util */ \ - template \ - struct is_expr> : std::true_type {}; \ - template \ - struct eval_return> { \ - using type = typename name##Expression::eval_t; \ - }; + template \ + class name##Expression : public ExpressionBase> \ + { \ + public: \ + static constexpr bool expr_l = is_expr_v; \ + static constexpr bool expr_r = is_expr_v; \ + \ + using return_arg1_t = eval_return_t; \ + using return_arg2_t = eval_return_t; \ + using return_type = std::common_type_t; \ + \ + name##Expression (const EArg1& arg1, const EArg2& arg2) \ + : m_arg1(arg1) \ + , m_arg2(arg2) \ + {} \ + \ + static constexpr int rank () { \ + if constexpr (expr_l) { \ + if constexpr (expr_r) { \ + static_assert(EArg1::rank()==EArg2::rank(), \ + "[" #name "Expression] Error! EArg1 and EArg2 have different rank.\n"); \ + } \ + return EArg1::rank(); \ + } else { \ + return EArg2::rank(); \ + } \ + } \ + int extent (int i) const { \ + if constexpr (expr_l) \ + return m_arg1.extent(i); \ + else \ + return m_arg2.extent(i); \ + } \ + \ + template \ + KOKKOS_INLINE_FUNCTION \ + return_type eval(Args... args) const { \ + if constexpr (not expr_l) \ + return eval_impl(m_arg1,m_arg2.eval(args...)); \ + else if constexpr (not expr_r) \ + return eval_impl(m_arg1.eval(args...),m_arg2); \ + else \ + return eval_impl(m_arg1.eval(args...),m_arg2.eval(args...)); \ + } \ + protected: \ + KOKKOS_INLINE_FUNCTION \ + return_type eval_impl(const return_arg1_t& arg1, \ + const return_arg2_t& arg2) const { \ + return Kokkos::impl(static_cast(arg1), \ + static_cast(arg2)); \ + } \ + \ + EArg1 m_arg1; \ + EArg2 m_arg2; \ + }; \ + \ + /* Free function to create a ##nameExpression */ \ + template>> \ + auto impl (const T1& arg1, const T2& arg2) \ + { \ + using ret_t = name##Expression, \ + get_expr_node_t>; \ + return ret_t(get_expr_node(arg1),get_expr_node(arg2)); \ + } EKAT_BINARY_MATH_EXPRESSION(pow,Pow); EKAT_BINARY_MATH_EXPRESSION(max,Max); @@ -93,43 +83,37 @@ EKAT_BINARY_MATH_EXPRESSION(min,Min); // ----------------- Unary math fcns ------------------- // #define EKAT_UNARY_MATH_EXPRESSION(impl,name) \ - template \ - class name##Expression { \ - public: \ - using arg_eval_t = eval_return_t; \ - using eval_t = decltype(Kokkos::impl(std::declval())); \ - \ - name##Expression (const EArg& arg) \ - : m_arg(arg) \ - {} \ - \ - static constexpr int rank() { return EArg::rank(); } \ - int extent (int i) const { return m_arg.extent(i); } \ - \ - template \ - KOKKOS_INLINE_FUNCTION \ - eval_t eval(Args... args) const { \ - return Kokkos::impl(m_arg.eval(args...)); \ - } \ - protected: \ - EArg m_arg; \ - }; \ - \ - /* Free function to create a ##nameExpression */ \ - template \ - std::enable_if_t,name##Expression> \ - impl (const EArg& arg) \ - { \ - return name##Expression(arg); \ - } \ - \ - /* Specialize meta util */ \ - template \ - struct is_expr> : std::true_type {}; \ - template \ - struct eval_return> { \ - using type = typename name##Expression::eval_t; \ - }; + template \ + class name##Expression : public ExpressionBase> \ + { \ + public: \ + using return_arg_type = eval_return_t; \ + using return_type = decltype(Kokkos::impl(std::declval())); \ + \ + name##Expression (const EArg& arg) \ + : m_arg(arg) \ + {} \ + \ + static constexpr int rank() { return EArg::rank(); } \ + int extent (int i) const { return m_arg.extent(i); } \ + \ + template \ + KOKKOS_INLINE_FUNCTION \ + return_type eval(Args... args) const { \ + return Kokkos::impl(m_arg.eval(args...)); \ + } \ + protected: \ + EArg m_arg; \ + }; \ + \ + /* Free function to create a ##nameExpression */ \ + template \ + std::enable_if_t, \ + name##Expression> \ + impl (const EArg& arg) \ + { \ + return name##Expression(arg); \ + } EKAT_UNARY_MATH_EXPRESSION (sqrt,Sqrt) EKAT_UNARY_MATH_EXPRESSION (exp,Exp) diff --git a/src/expression/ekat_expression_meta.hpp b/src/expression/ekat_expression_meta.hpp deleted file mode 100644 index 4001962e..00000000 --- a/src/expression/ekat_expression_meta.hpp +++ /dev/null @@ -1,19 +0,0 @@ -#ifndef EKAT_EXPRESSION_META_HPP -#define EKAT_EXPRESSION_META_HPP - -namespace ekat { - -// Detect if a type is an Expression -template -struct is_expr : std::false_type {}; -template -constexpr bool is_expr_v = is_expr::value; - -template -struct eval_return { using type = T; }; -template -using eval_return_t = typename eval_return::type; - -} // namespace ekat - -#endif // EKAT_EXPRESSION_META_HPP diff --git a/src/expression/ekat_expression_view.hpp b/src/expression/ekat_expression_view.hpp index 6715ef69..c10b7cfb 100644 --- a/src/expression/ekat_expression_view.hpp +++ b/src/expression/ekat_expression_view.hpp @@ -1,17 +1,17 @@ #ifndef EKAT_VIEW_EXPRESSION_HPP #define EKAT_VIEW_EXPRESSION_HPP -#include "ekat_expression_meta.hpp" +#include "ekat_expression_base.hpp" #include namespace ekat { template -class ViewExpression { +class ViewExpression : public ExpressionBase> { public: using view_t = ViewT; - using value_t = typename ViewT::element_type; + using return_type = typename ViewT::element_type; ViewExpression (const view_t& v) : m_view(v) @@ -24,23 +24,22 @@ class ViewExpression { template KOKKOS_INLINE_FUNCTION - const value_t& eval(Args... args) const { + const return_type& eval(Args... args) const { static_assert(sizeof...(Args)==ViewT::rank, "Something is off...\n"); return m_view(args...); } protected: - view_t m_view; }; -// Specialize meta utils -template -struct is_expr> : std::true_type {}; -template -struct eval_return> { - using type = typename ViewExpression::value_t; -}; +// Free fcn to construct a ViewExpression +template>> +auto expression(const ViewT& v) +{ + return ViewExpression(v); +} } // namespace ekat diff --git a/tests/expression/expressions.cpp b/tests/expression/expressions.cpp index e9f2b2db..d484006f 100644 --- a/tests/expression/expressions.cpp +++ b/tests/expression/expressions.cpp @@ -1,8 +1,13 @@ #include -#include "ekat_expression_helpers.hpp" #include "ekat_expression_eval.hpp" +#include "ekat_expression_binary_op.hpp" +#include "ekat_expression_compare.hpp" +#include "ekat_expression_conditional.hpp" +#include "ekat_expression_math.hpp" +#include "ekat_expression_view.hpp" + #include "ekat_view_utils.hpp" #include "ekat_kokkos_types.hpp" @@ -13,8 +18,8 @@ namespace ekat { template void bin_ops (const ViewT& x, const ViewT& y, const ViewT& z) { - auto xe = view_expression(x); - auto ye = view_expression(y); + auto xe = expression(x); + auto ye = expression(y); auto expression = xe*ye - 1/ye + 2*xe; evaluate(expression,z); @@ -27,15 +32,15 @@ void bin_ops (const ViewT& x, const ViewT& y, const ViewT& z) auto y_val = yh.data()[i]; auto z_val = zh.data()[i]; auto tgt = x_val*y_val-1/y_val+2*x_val; - REQUIRE (z_val==Approx(tgt).epsilon(1e-10)); + REQUIRE (z_val==tgt); } } template void math_fcns (const ViewT& x, const ViewT& y, const ViewT& z) { - auto xe = view_expression(x); - auto ye = view_expression(y); + auto xe = expression(x); + auto ye = expression(y); auto expression = 2*exp(-xe)*sin(xe)*log(ye)-sqrt(xe)+pow(ye,2)+pow(3,xe); evaluate(expression,z); @@ -53,12 +58,33 @@ void math_fcns (const ViewT& x, const ViewT& y, const ViewT& z) } } +template +void compare (const ViewT& x, const ViewT& y, const BViewT& z) +{ + auto xe = expression(x); + auto ye = expression(y); + auto expression = xe==ye; + + evaluate(expression,z); + + auto xh = create_host_mirror_and_copy(x); + auto yh = create_host_mirror_and_copy(y); + auto zh = create_host_mirror_and_copy(z); + for (size_t i=0; i void conditionals (const ViewT& x, const ViewT& y, const ViewT& z) { - auto xe = view_expression(x); - auto ye = view_expression(y); - auto expression = conditional(sqrt(xe)>=0.5,xe+ye,xe-ye); + auto xe = expression(x); + auto ye = expression(y); + auto expression = if_then_else(sqrt(xe)>=0.5,xe+ye,xe-ye); evaluate(expression,z); @@ -90,12 +116,14 @@ TEST_CASE("expressions", "") { kk_t::view_ND x("x"); kk_t::view_ND y("y"); kk_t::view_ND z("z"); + kk_t::view_ND zb("z"); genRandArray(x,engine,pdf); genRandArray(y,engine,pdf); bin_ops(x,y,z); math_fcns(x,y,z); + compare(x,y,zb); conditionals(x,y,z); } From 65bbc9a91e9b918bd6844d87e461b73eebf3e956 Mon Sep 17 00:00:00 2001 From: Luca Bertagna Date: Fri, 1 May 2026 18:05:03 -0600 Subject: [PATCH 3/6] Beef up expressions unit testing --- tests/expression/expressions.cpp | 204 +++++++++++++++++++++++++------ 1 file changed, 167 insertions(+), 37 deletions(-) diff --git a/tests/expression/expressions.cpp b/tests/expression/expressions.cpp index d484006f..5ef547d5 100644 --- a/tests/expression/expressions.cpp +++ b/tests/expression/expressions.cpp @@ -11,6 +11,8 @@ #include "ekat_view_utils.hpp" #include "ekat_kokkos_types.hpp" +#include "ekat_test_config.h" + #include namespace ekat { @@ -18,6 +20,8 @@ namespace ekat { template void bin_ops (const ViewT& x, const ViewT& y, const ViewT& z) { + auto eps = std::numeric_limits::epsilon(); + auto tol = 1e5*eps; auto xe = expression(x); auto ye = expression(y); auto expression = xe*ye - 1/ye + 2*xe; @@ -32,7 +36,7 @@ void bin_ops (const ViewT& x, const ViewT& y, const ViewT& z) auto y_val = yh.data()[i]; auto z_val = zh.data()[i]; auto tgt = x_val*y_val-1/y_val+2*x_val; - REQUIRE (z_val==tgt); + REQUIRE_THAT (z_val, Catch::Matchers::WithinRel(tgt,tol)); } } @@ -63,19 +67,92 @@ void compare (const ViewT& x, const ViewT& y, const BViewT& z) { auto xe = expression(x); auto ye = expression(y); - auto expression = xe==ye; - - evaluate(expression,z); - auto xh = create_host_mirror_and_copy(x); auto yh = create_host_mirror_and_copy(y); - auto zh = create_host_mirror_and_copy(z); - for (size_t i=0; iye; + evaluate(expression,z); + + auto zh = create_host_mirror_and_copy(z); + for (size_t i=0; iy_val; + REQUIRE (z_val==tgt); + } + } + // GE + { + auto expression = xe>=ye; + evaluate(expression,z); + + auto zh = create_host_mirror_and_copy(z); + for (size_t i=0; i=y_val; + REQUIRE (z_val==tgt); + } + } + // LT + { + auto expression = xe=0.5,xe+ye,xe-ye); - - evaluate(expression,z); - auto xh = create_host_mirror_and_copy(x); auto yh = create_host_mirror_and_copy(y); - auto zh = create_host_mirror_and_copy(z); - for (size_t i=0; i=0.5 ? x_val+y_val : x_val-y_val; - REQUIRE (z_val==tgt); + + // All expressions + { + auto expression = if_then_else(sqrt(xe)>=0.5,xe+ye,xe-ye); + evaluate(expression,z); + auto zh = create_host_mirror_and_copy(z); + for (size_t i=0; i=0.5 ? x_val+y_val : x_val-y_val; + REQUIRE (z_val==tgt); + } + } + // cond(expr,expr,scalar) + { + auto expression = if_then_else(sqrt(xe)>=0.5,xe+ye,-3); + evaluate(expression,z); + auto zh = create_host_mirror_and_copy(z); + for (size_t i=0; i=0.5 ? x_val+y_val : -3; + REQUIRE (z_val==tgt); + } + } + // cond(bool,expr,expr) + { + auto expression = if_then_else(false,xe+ye,xe-ye); + evaluate(expression,z); + auto zh = create_host_mirror_and_copy(z); + for (size_t i=0; i; SECTION ("0d") { - kk_t::view_ND x("x"); - kk_t::view_ND y("y"); - kk_t::view_ND z("z"); + printf("running od tests with rng seed: %d\n",seed); + + kk_t::view_ND x ("x"); + kk_t::view_ND y ("y"); + kk_t::view_ND z ("z"); kk_t::view_ND zb("z"); genRandArray(x,engine,pdf); @@ -128,41 +246,53 @@ TEST_CASE("expressions", "") { } SECTION ("1d") { - kk_t::view_1d x("x",1000); - kk_t::view_1d y("y",1000); - kk_t::view_1d z("z",1000); + printf("running 1d tests with rng seed: %d\n",seed); + + kk_t::view_1d x ("x",1000); + kk_t::view_1d y ("y",1000); + kk_t::view_1d z ("z",1000); + kk_t::view_1d zb("zb",1000); genRandArray(x,engine,pdf); genRandArray(y,engine,pdf); bin_ops(x,y,z); math_fcns(x,y,z); + compare(x,y,zb); conditionals(x,y,z); } SECTION ("2d") { - kk_t::view_2d x("x",100,32); - kk_t::view_2d y("y",100,32); - kk_t::view_2d z("z",100,32); + printf("running 2d tests with rng seed: %d\n",seed); + + kk_t::view_2d x ("x",100,32); + kk_t::view_2d y ("y",100,32); + kk_t::view_2d z ("z",100,32); + kk_t::view_2d zb("z",100,32); genRandArray(x,engine,pdf); genRandArray(y,engine,pdf); bin_ops(x,y,z); math_fcns(x,y,z); + compare(x,y,zb); conditionals(x,y,z); } SECTION ("3d") { - kk_t::view_3d x("x",100,4,32); - kk_t::view_3d y("y",100,4,32); - kk_t::view_3d z("z",100,4,32); + printf("running 3d tests with rng seed: %d\n",seed); + + kk_t::view_3d x ("x",100,4,32); + kk_t::view_3d y ("y",100,4,32); + kk_t::view_3d z ("z",100,4,32); + kk_t::view_3d zb("z",100,4,32); genRandArray(x,engine,pdf); genRandArray(y,engine,pdf); bin_ops(x,y,z); math_fcns(x,y,z); + compare(x,y,zb); conditionals(x,y,z); } } From fc37881fffc916c02f38fe4079b4b5ab0ffa3e77 Mon Sep 17 00:00:00 2001 From: Luca Bertagna Date: Fri, 1 May 2026 18:37:16 -0600 Subject: [PATCH 4/6] Upgrade (and rename) CmpExpression to support logical and/or --- .../ekat_expression_binary_predicate.hpp | 132 ++++++++++++++++++ src/expression/ekat_expression_compare.hpp | 131 ----------------- tests/expression/expressions.cpp | 39 ++++-- 3 files changed, 161 insertions(+), 141 deletions(-) create mode 100644 src/expression/ekat_expression_binary_predicate.hpp delete mode 100644 src/expression/ekat_expression_compare.hpp diff --git a/src/expression/ekat_expression_binary_predicate.hpp b/src/expression/ekat_expression_binary_predicate.hpp new file mode 100644 index 00000000..57a58861 --- /dev/null +++ b/src/expression/ekat_expression_binary_predicate.hpp @@ -0,0 +1,132 @@ +#ifndef EKAT_EXPRESSION_BINARY_PREDICATE_HPP +#define EKAT_EXPRESSION_BINARY_PREDICATE_HPP + +#include "ekat_expression_base.hpp" + +namespace ekat { + +enum class BinaryPredicateOp : int { + EQ, // == + NE, // != + GT, // > + GE, // >= + LT, // < + LE, // <= + AND, // logical and + OR // logical or +}; + +template +class BinaryPredicateExpression : public ExpressionBase> { +public: + static constexpr bool expr_l = is_expr_v; + static constexpr bool expr_r = is_expr_v; + + using return_left_t = eval_return_t; + using return_right_t = eval_return_t; + // The return type is a logical-like, but same for all Op's, so just use one + using return_type = decltype(std::declval()==std::declval()); + + // Don't create an expression from builtin types, just compare them! + static_assert(expr_l or expr_r, + "[BinaryPredicateExpression] At least one between ELeft and ERight must be an Expression type.\n"); + + BinaryPredicateExpression (const ELeft& left, + const ERight& right) + : m_left(left) + , m_right(right) + { + // Nothing to do here + } + + static constexpr int rank() { + if constexpr (expr_l) { + if constexpr (expr_r) { + static_assert(ELeft::rank()==ERight::rank(), + "[BinaryPredicateExpression] Error! ELeft and ERight are Expression types of different rank.\n"); + } + return ELeft::rank(); + } else if constexpr (expr_r) { + return ERight::rank(); + } else { + return 0; + } + } + + int extent (int i) const { + if constexpr (expr_l) + return m_left.extent(i); + else + return m_right.extent(i); + } + + template + KOKKOS_INLINE_FUNCTION + return_type eval(Args... args) const { + if constexpr (expr_l) { + if constexpr (expr_r) + return eval_impl(m_left.eval(args...), m_right.eval(args...)); + else + return eval_impl(m_left.eval(args...), m_right); + } else if constexpr (expr_r) { + return eval_impl(m_left, m_right.eval(args...)); + } else { + return eval_impl(m_left, m_right); + } + } + +protected: + + template + KOKKOS_INLINE_FUNCTION + return_type eval_impl(const return_left_t& l, const return_right_t& r) const { + if constexpr (Op==BinaryPredicateOp::EQ) + return l==r; + else if constexpr (Op==BinaryPredicateOp::NE) + return l!=r; + else if constexpr (Op==BinaryPredicateOp::GT) + return l>r; + else if constexpr (Op==BinaryPredicateOp::GE) + return l>=r; + else if constexpr (Op==BinaryPredicateOp::LT) + return l>> \ + KOKKOS_INLINE_FUNCTION \ + auto operator OP (const T1& l, const T2& r) \ + { \ + using ret_t = BinaryPredicateExpression, \ + get_expr_node_t, \ + BinaryPredicateOp::ENUM>; \ + \ + return ret_t(get_expr_node(l),get_expr_node(r)); \ + } + +EKAT_GEN_BIN_PREDICATE_EXPR(==,EQ); +EKAT_GEN_BIN_PREDICATE_EXPR(!=,NE); +EKAT_GEN_BIN_PREDICATE_EXPR(> ,GT); +EKAT_GEN_BIN_PREDICATE_EXPR(>=,GE); +EKAT_GEN_BIN_PREDICATE_EXPR(< ,LT); +EKAT_GEN_BIN_PREDICATE_EXPR(<=,LE); +EKAT_GEN_BIN_PREDICATE_EXPR(&&,AND); +EKAT_GEN_BIN_PREDICATE_EXPR(||,OR); + +#undef EKAT_GEN_BIN_PREDICATE_EXPR + +} // namespace ekat + +#endif // EKAT_EXPRESSION_BINARY_PREDICATE_HPP diff --git a/src/expression/ekat_expression_compare.hpp b/src/expression/ekat_expression_compare.hpp deleted file mode 100644 index e27d7784..00000000 --- a/src/expression/ekat_expression_compare.hpp +++ /dev/null @@ -1,131 +0,0 @@ -#ifndef EKAT_EXPRESSION_COMPARE_HPP -#define EKAT_EXPRESSION_COMPARE_HPP - -#include "ekat_expression_base.hpp" - -#include "ekat_std_utils.hpp" -#include "ekat_kernel_assert.hpp" -#include "ekat_assert.hpp" - -namespace ekat { - -enum class Comparison : int { - EQ, // == - NE, // != - GT, // > - GE, // >= - LT, // < - LE // <= -}; - -template -class CmpExpression : public ExpressionBase> { -public: - static constexpr bool expr_l = is_expr_v; - static constexpr bool expr_r = is_expr_v; - - using return_left_t = eval_return_t; - using return_right_t = eval_return_t; - using return_type = decltype(std::declval()==std::declval()); - - // Don't create an expression from builtin types, just compare them! - static_assert(expr_l or expr_r, - "[CmpExpression] At least one between ELeft and ERight must be an Expression type.\n"); - - CmpExpression (const ELeft& left, - const ERight& right, - Comparison CMP) - : m_left(left) - , m_right(right) - , m_cmp(CMP) - { - auto valid = {Comparison::EQ,Comparison::NE,Comparison::GT, - Comparison::GE,Comparison::LT,Comparison::LE}; - EKAT_REQUIRE_MSG (ekat::contains(valid,CMP), - "[CmpExpression] Error! Unrecognized/unsupported Comparison value.\n"); - } - - static constexpr int rank() { - if constexpr (expr_l) { - if constexpr (expr_r) { - static_assert(ELeft::rank()==ERight::rank() or ELeft::rank()==0 or ERight::rank()==0, - "[CmpExpression] Error! ELeft and ERight are Expression types of different rank.\n"); - } - return ELeft::rank(); - } else if constexpr (expr_r) { - return ERight::rank(); - } else { - return 0; - } - } - - int extent (int i) const { - if constexpr (expr_l) - return m_left.extent(i); - else - return m_right.extent(i); - } - - template - KOKKOS_INLINE_FUNCTION - return_type eval(Args... args) const { - if constexpr (expr_l) { - if constexpr (expr_r) - return eval_impl(m_left.eval(args...), m_right.eval(args...)); - else - return eval_impl(m_left.eval(args...), m_right); - } else if constexpr (expr_r) { - return eval_impl(m_left, m_right.eval(args...)); - } else { - return eval_impl(m_left, m_right); - } - } - -protected: - - template - KOKKOS_INLINE_FUNCTION - return_type eval_impl(const return_left_t& l, const return_right_t& r) const { - switch (m_cmp) { - case Comparison::EQ: return l==r; - case Comparison::NE: return l!=r; - case Comparison::GT: return l>r; - case Comparison::GE: return l>=r; - case Comparison::LT: return l>> \ - KOKKOS_INLINE_FUNCTION \ - auto operator OP (const T1& l, const T2& r) \ - { \ - using ret_t = CmpExpression, \ - get_expr_node_t>; \ - \ - return ret_t(get_expr_node(l),get_expr_node(r),Comparison::ENUM); \ - } - -EKAT_GEN_CMP_OP_EXPR(==,EQ); -EKAT_GEN_CMP_OP_EXPR(!=,NE); -EKAT_GEN_CMP_OP_EXPR(> ,GT); -EKAT_GEN_CMP_OP_EXPR(>=,GE); -EKAT_GEN_CMP_OP_EXPR(< ,LT); -EKAT_GEN_CMP_OP_EXPR(<=,LE); - -#undef EKAT_GEN_CMP_OP_EXPR - -} // namespace ekat - -#endif // EKAT_EXPRESSION_COMPARE_HPP diff --git a/tests/expression/expressions.cpp b/tests/expression/expressions.cpp index 5ef547d5..969b1e9e 100644 --- a/tests/expression/expressions.cpp +++ b/tests/expression/expressions.cpp @@ -3,7 +3,7 @@ #include "ekat_expression_eval.hpp" #include "ekat_expression_binary_op.hpp" -#include "ekat_expression_compare.hpp" +#include "ekat_expression_binary_predicate.hpp" #include "ekat_expression_conditional.hpp" #include "ekat_expression_math.hpp" #include "ekat_expression_view.hpp" @@ -63,7 +63,7 @@ void math_fcns (const ViewT& x, const ViewT& y, const ViewT& z) } template -void compare (const ViewT& x, const ViewT& y, const BViewT& z) +void predicate (const ViewT& x, const ViewT& y, const BViewT& z) { auto xe = expression(x); auto ye = expression(y); @@ -86,7 +86,7 @@ void compare (const ViewT& x, const ViewT& y, const BViewT& z) } // NE { - auto expression = xe==ye; + auto expression = xe!=ye; evaluate(expression,z); auto zh = create_host_mirror_and_copy(z); @@ -154,6 +154,27 @@ void compare (const ViewT& x, const ViewT& y, const BViewT& z) REQUIRE (z_val==tgt); } } + // AND/OR + { + // These two should eval to exactly opposite values (due to De Morgan's laws) + auto expr_and = xe>=0.5 && ye<=0.5; + auto expr_or = xe<0.5 || ye>0.5; + auto z2 = Kokkos::create_mirror(typename BViewT::memory_space{}, z); + evaluate(expr_and,z); + evaluate(expr_or,z2); + + auto zh = create_host_mirror_and_copy(z); + auto z2h = create_host_mirror_and_copy(z2); + for (size_t i=0; i=0.5 and y_val<=0.5; + REQUIRE (z_val==tgt); + REQUIRE (z2_val==(not tgt)); + } + } } template @@ -209,8 +230,6 @@ void conditionals (const ViewT& x, const ViewT& y, const ViewT& z) evaluate(expression,z); auto zh = create_host_mirror_and_copy(z); for (size_t i=0; i; SECTION ("0d") { - printf("running od tests with rng seed: %d\n",seed); + printf("running 0d tests with rng seed: %d\n",seed); kk_t::view_ND x ("x"); kk_t::view_ND y ("y"); @@ -241,7 +260,7 @@ TEST_CASE("expressions", "") { bin_ops(x,y,z); math_fcns(x,y,z); - compare(x,y,zb); + predicate(x,y,zb); conditionals(x,y,z); } @@ -258,7 +277,7 @@ TEST_CASE("expressions", "") { bin_ops(x,y,z); math_fcns(x,y,z); - compare(x,y,zb); + predicate(x,y,zb); conditionals(x,y,z); } @@ -275,7 +294,7 @@ TEST_CASE("expressions", "") { bin_ops(x,y,z); math_fcns(x,y,z); - compare(x,y,zb); + predicate(x,y,zb); conditionals(x,y,z); } @@ -292,7 +311,7 @@ TEST_CASE("expressions", "") { bin_ops(x,y,z); math_fcns(x,y,z); - compare(x,y,zb); + predicate(x,y,zb); conditionals(x,y,z); } } From d33d2aaca3f813ddd2b143a585b00bc38827e2d6 Mon Sep 17 00:00:00 2001 From: Luca Bertagna Date: Fri, 1 May 2026 20:39:25 -0600 Subject: [PATCH 5/6] Fix missing pfor for rank-4 expression evaluation --- src/expression/ekat_expression_eval.hpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/expression/ekat_expression_eval.hpp b/src/expression/ekat_expression_eval.hpp index 6a4ca6b5..3d76aa6b 100644 --- a/src/expression/ekat_expression_eval.hpp +++ b/src/expression/ekat_expression_eval.hpp @@ -70,6 +70,7 @@ void evaluate (const ExpressionBase& base, const ViewT& result) auto eval = KOKKOS_LAMBDA (int i,int j,int k,int l) { result(i,j,k,l) = e.eval(i,j,k,l); }; + Kokkos::parallel_for(p,eval); } else if constexpr (N==5) { PolicyMD p(beg,end); auto eval = KOKKOS_LAMBDA (int i,int j,int k,int l,int m) { From 97324178c5db87a945dd080b5aef51a7d6a580b1 Mon Sep 17 00:00:00 2001 From: Luca Bertagna Date: Mon, 4 May 2026 13:39:51 -0600 Subject: [PATCH 6/6] Fix ekat_test_config.h --- tests/ekat_test_config.h.in | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/ekat_test_config.h.in b/tests/ekat_test_config.h.in index 1007101b..15f837cc 100644 --- a/tests/ekat_test_config.h.in +++ b/tests/ekat_test_config.h.in @@ -9,6 +9,9 @@ // If your test does not use Real, then you don't need any compile definition // (you don't even need to include this file). +#cmakedefine EKAT_TEST_DOUBLE_PRECISION +#cmakedefine EKAT_TEST_SINGLE_PRECISION + #ifdef EKAT_TEST_DOUBLE_PRECISION using Real = double; #elif defined(EKAT_TEST_SINGLE_PRECISION)