Skip to content

Commit c03daf2

Browse files
Support ir.Expr/ir.Var openmp strings by extending constant inference (#57)
* Add tests for more openmp string patterns
1 parent daf0b06 commit c03daf2

2 files changed

Lines changed: 307 additions & 4 deletions

File tree

src/numba/openmp/omp_lower.py

Lines changed: 175 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3084,6 +3084,165 @@ def remove_ssa_from_func_ir(func_ir):
30843084
func_ir._definitions = build_definitions(func_ir.blocks)
30853085

30863086

3087+
class _ExtendedConstantInference:
3088+
"""
3089+
Extended ConstantInference that supports binop for string concatenation.
3090+
"""
3091+
3092+
def __init__(self, func_ir):
3093+
from numba.core.consts import ConstantInference
3094+
3095+
self._base_inference = ConstantInference(func_ir)
3096+
self._func_ir = func_ir
3097+
3098+
def infer_constant(self, name, loc=None):
3099+
"""Infer a constant value, delegating to base inference first."""
3100+
from numba.core.errors import ConstantInferenceError
3101+
3102+
try:
3103+
return self._base_inference.infer_constant(name, loc=loc)
3104+
except ConstantInferenceError:
3105+
# If base inference fails, check if the variable's definition is
3106+
# an expression we can handle (like binop or call)
3107+
try:
3108+
defn = self._func_ir.get_definition(name)
3109+
if isinstance(defn, ir.Expr):
3110+
if defn.op == "binop":
3111+
return self.infer_expr(defn, loc=loc)
3112+
elif defn.op == "call":
3113+
return self.infer_expr(defn, loc=loc)
3114+
except (KeyError, AttributeError):
3115+
pass
3116+
raise
3117+
3118+
def _infer_value(self, val, loc=None):
3119+
"""
3120+
Infer a constant from a value which might be a variable name or an expression.
3121+
"""
3122+
from numba.core.errors import ConstantInferenceError
3123+
3124+
if isinstance(val, ir.Var):
3125+
return self.infer_constant(val.name, loc=val.loc)
3126+
elif isinstance(val, ir.Expr):
3127+
return self.infer_expr(val, loc=loc)
3128+
elif isinstance(val, str):
3129+
# Direct variable name
3130+
return self.infer_constant(val, loc=loc)
3131+
else:
3132+
raise ConstantInferenceError(f"Cannot infer value for {val}", loc=loc)
3133+
3134+
def infer_expr(self, expr, loc=None):
3135+
"""
3136+
Infer an expression, with added support for binop (string concatenation)
3137+
and format_value() calls.
3138+
"""
3139+
from numba.core.errors import ConstantInferenceError
3140+
3141+
if expr.op == "binop":
3142+
# Support binary operations for string concatenation
3143+
try:
3144+
lhs = self._infer_value(expr.lhs, loc=expr.loc)
3145+
rhs = self._infer_value(expr.rhs, loc=expr.loc)
3146+
# String concatenation
3147+
if isinstance(lhs, str) and isinstance(rhs, str):
3148+
return lhs + rhs
3149+
except ConstantInferenceError:
3150+
raise
3151+
# If it's not string concatenation
3152+
raise ConstantInferenceError(
3153+
f"Cannot infer binop: {lhs!r} + {rhs!r}", loc=expr.loc
3154+
)
3155+
elif expr.op == "call":
3156+
# Handle str() and format_value() calls
3157+
try:
3158+
func = expr.func
3159+
3160+
# Try to infer what function is being called
3161+
func_name = None
3162+
if isinstance(func, ir.Global):
3163+
if func.value is str:
3164+
func_name = "str"
3165+
elif isinstance(func, ir.Var):
3166+
# Try to resolve the variable to see what function it points to
3167+
try:
3168+
func_defn = self._func_ir.get_definition(func.name)
3169+
# Handle ir.Global directly (Python 3.13+ for format_simple)
3170+
if isinstance(func_defn, ir.Global):
3171+
if func_defn.value is str:
3172+
func_name = "str"
3173+
elif (
3174+
isinstance(func_defn, ir.Expr) and func_defn.op == "global"
3175+
):
3176+
if func_defn.value is str:
3177+
func_name = "str"
3178+
elif (
3179+
hasattr(func_defn.value, "__name__")
3180+
and "format_value" in func_defn.value.__name__
3181+
):
3182+
func_name = "format_value"
3183+
else:
3184+
# Check if the variable name itself suggests what it is
3185+
if "format_value" in func.name:
3186+
func_name = "format_value"
3187+
except (KeyError, AttributeError):
3188+
if "format_value" in func.name:
3189+
func_name = "format_value"
3190+
3191+
# Handle str() calls
3192+
if func_name == "str" or (
3193+
isinstance(func, ir.Global) and func.value is str
3194+
):
3195+
if len(expr.args) >= 1:
3196+
arg_val = self._infer_value(expr.args[0], loc=expr.loc)
3197+
return str(arg_val)
3198+
3199+
# Handle format_value calls (used in f-strings)
3200+
if func_name == "format_value":
3201+
if len(expr.args) >= 1:
3202+
arg_val = self._infer_value(expr.args[0], loc=expr.loc)
3203+
return str(arg_val)
3204+
3205+
# If we don't recognize the function, don't try base inference
3206+
raise ConstantInferenceError(
3207+
f"Cannot infer call to unknown function: {func}", loc=expr.loc
3208+
)
3209+
3210+
except ConstantInferenceError:
3211+
raise
3212+
else:
3213+
# Delegate to base inference for other operations
3214+
return self._base_inference._infer_expr(expr)
3215+
3216+
3217+
def _try_infer_string_constant(arg, func_ir):
3218+
"""
3219+
Try to infer a constant string value from an IR node.
3220+
Uses extended ConstantInference that supports binop for string concatenation
3221+
and format_value calls used in f-strings.
3222+
3223+
Returns the string value if resolvable, None otherwise.
3224+
"""
3225+
from numba.core.errors import ConstantInferenceError
3226+
3227+
try:
3228+
inference = _ExtendedConstantInference(func_ir)
3229+
3230+
# For variables, use ConstantInference to resolve them
3231+
if isinstance(arg, ir.Var):
3232+
value = inference.infer_constant(arg.name, loc=arg.loc)
3233+
if isinstance(value, str):
3234+
return value
3235+
# For expressions, try using the extended inference
3236+
elif isinstance(arg, ir.Expr):
3237+
value = inference.infer_expr(arg)
3238+
if isinstance(value, str):
3239+
return value
3240+
except (ConstantInferenceError, AttributeError, NotImplementedError):
3241+
pass
3242+
3243+
return None
3244+
3245+
30873246
def _add_openmp_ir_nodes(func_ir, blocks, blk_start, blk_end, body_blocks, extra):
30883247
"""Given the starting and ending block of the with-context,
30893248
replaces the head block with a new block that has the starting
@@ -3097,14 +3256,16 @@ def _add_openmp_ir_nodes(func_ir, blocks, blk_start, blk_end, body_blocks, extra
30973256
args = extra["args"]
30983257
arg = args[0]
30993258

3259+
pragma_value = None
3260+
31003261
# If OpenMP argument is not a constant or not a string then raise exception
3101-
# Accept ir.Const, ir.FreeVar, and ir.Global (closure variables)
3102-
if not isinstance(arg, (ir.Const, ir.FreeVar, ir.Global)):
3262+
# Accept ir.Const, ir.FreeVar, ir.Global, ir.Expr, and ir.Var
3263+
if not isinstance(arg, (ir.Const, ir.FreeVar, ir.Global, ir.Expr, ir.Var)):
31033264
raise NonconstantOpenmpSpecification(
31043265
f"Non-constant OpenMP specification at line {arg.loc}"
31053266
)
31063267

3107-
# Extract the actual string value from Const, FreeVar, or Global
3268+
# Extract the actual string value from various IR types
31083269
if isinstance(arg, ir.Const):
31093270
pragma_value = arg.value
31103271
if not isinstance(pragma_value, str):
@@ -3125,6 +3286,17 @@ def _add_openmp_ir_nodes(func_ir, blocks, blk_start, blk_end, body_blocks, extra
31253286
raise NonStringOpenmpSpecification(
31263287
f"Non-string OpenMP specification at line {arg.loc}"
31273288
)
3289+
else:
3290+
# Handle ir.Var and ir.Expr using ConstantInference
3291+
pragma_value = _try_infer_string_constant(arg, func_ir)
3292+
if pragma_value is None:
3293+
raise NonconstantOpenmpSpecification(
3294+
f"Cannot infer constant OpenMP specification at line {arg.loc}"
3295+
)
3296+
if not isinstance(pragma_value, str):
3297+
raise NonStringOpenmpSpecification(
3298+
f"Non-string OpenMP specification at line {arg.loc}"
3299+
)
31283300

31293301
if DEBUG_OPENMP >= 1:
31303302
print("args:", args, type(args))

src/numba/openmp/tests/test_openmp.py

Lines changed: 132 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2032,7 +2032,7 @@ def test_impl(v):
20322032
with self.assertRaises(NonconstantOpenmpSpecification) as raises:
20332033
test_impl(np.zeros(100))
20342034
self.assertIn(
2035-
"Non-constant OpenMP specification at line", str(raises.exception)
2035+
"Cannot infer constant OpenMP specification", str(raises.exception)
20362036
)
20372037

20382038
# def test_parallel_for_blocking_if(self):
@@ -5037,5 +5037,136 @@ def test_impl():
50375037
self.assertGreater(jit_num_devices, 0)
50385038

50395039

5040+
class TestOpenmpStringPatterns(TestOpenmpBase):
5041+
def __init__(self, *args):
5042+
TestOpenmpBase.__init__(self, *args)
5043+
5044+
def test_omp_jit_const_string(self):
5045+
@njit
5046+
def test_impl(x):
5047+
with openmp("parallel num_threads(4)"):
5048+
tid = omp_get_thread_num()
5049+
x[tid] = x[tid] + 1
5050+
return x
5051+
5052+
x = np.zeros(4)
5053+
x = test_impl(x)
5054+
np.testing.assert_array_equal(x, np.ones(4))
5055+
5056+
def test_omp_py_const_string(self):
5057+
omp_string = "parallel num_threads(4)"
5058+
5059+
@njit
5060+
def test_impl(x):
5061+
with openmp(omp_string):
5062+
tid = omp_get_thread_num()
5063+
x[tid] = x[tid] + 1
5064+
return x
5065+
5066+
x = np.zeros(4)
5067+
x = test_impl(x)
5068+
np.testing.assert_array_equal(x, np.ones(4))
5069+
5070+
def test_omp_jit_fstring(self):
5071+
@njit
5072+
def test_impl(x):
5073+
num_threads = 4
5074+
with openmp(f"parallel num_threads({num_threads})"):
5075+
tid = omp_get_thread_num()
5076+
x[tid] = x[tid] + 1
5077+
return x
5078+
5079+
x = np.zeros(4)
5080+
x = test_impl(x)
5081+
np.testing.assert_array_equal(x, np.ones(4))
5082+
5083+
def test_omp_py_fstring(self):
5084+
num_threads = 4
5085+
omp_string = f"parallel num_threads({num_threads})"
5086+
5087+
@njit
5088+
def test_impl(x):
5089+
with openmp(omp_string):
5090+
tid = omp_get_thread_num()
5091+
x[tid] = x[tid] + 1
5092+
return x
5093+
5094+
x = np.zeros(4)
5095+
x = test_impl(x)
5096+
np.testing.assert_array_equal(x, np.ones(4))
5097+
5098+
def test_omp_string_concat_literals(self):
5099+
@njit
5100+
def test_impl(x):
5101+
with openmp("parallel " + "num_threads(4)"):
5102+
tid = omp_get_thread_num()
5103+
x[tid] = x[tid] + 1
5104+
return x
5105+
5106+
x = np.zeros(4)
5107+
x = test_impl(x)
5108+
np.testing.assert_array_equal(x, np.ones(4))
5109+
5110+
def test_omp_string_concat_jit_variables(self):
5111+
@njit
5112+
def test_impl(x):
5113+
prefix = "parallel "
5114+
suffix = "num_threads(4)"
5115+
with openmp(prefix + suffix):
5116+
tid = omp_get_thread_num()
5117+
x[tid] = x[tid] + 1
5118+
return x
5119+
5120+
x = np.zeros(4)
5121+
x = test_impl(x)
5122+
np.testing.assert_array_equal(x, np.ones(4))
5123+
5124+
def test_omp_string_concat_variables(self):
5125+
num_threads = 4
5126+
omp_string = "parallel num_threads(" + str(num_threads) + ")"
5127+
5128+
@njit
5129+
def test_impl(x):
5130+
with openmp(omp_string):
5131+
tid = omp_get_thread_num()
5132+
x[tid] = x[tid] + 1
5133+
return x
5134+
5135+
x = np.zeros(4)
5136+
x = test_impl(x)
5137+
np.testing.assert_array_equal(x, np.ones(4))
5138+
5139+
def test_omp_nested_concat(self):
5140+
prefix = "parallel "
5141+
suffix = "num_threads(4)"
5142+
omp_string = prefix + suffix
5143+
5144+
@njit
5145+
def test_impl(x):
5146+
with openmp(omp_string):
5147+
tid = omp_get_thread_num()
5148+
x[tid] = x[tid] + 1
5149+
return x
5150+
5151+
x = np.zeros(4)
5152+
x = test_impl(x)
5153+
np.testing.assert_array_equal(x, np.ones(4))
5154+
5155+
def test_omp_explicit_str_call(self):
5156+
num_threads = 4
5157+
omp_string = "parallel " + "num_threads(" + str(num_threads) + ")"
5158+
5159+
@njit
5160+
def test_impl(x):
5161+
with openmp(omp_string):
5162+
tid = omp_get_thread_num()
5163+
x[tid] = x[tid] + 1
5164+
return x
5165+
5166+
x = np.zeros(4)
5167+
x = test_impl(x)
5168+
np.testing.assert_array_equal(x, np.ones(4))
5169+
5170+
50405171
if __name__ == "__main__":
50415172
unittest.main()

0 commit comments

Comments
 (0)