From 88ead3786e933add2a4dcb6adc971523704b2c67 Mon Sep 17 00:00:00 2001 From: Michael Wang Date: Wed, 12 Jan 2022 11:24:32 -0800 Subject: [PATCH 01/25] Adding missing stubs --- examples/triangle.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/examples/triangle.py b/examples/triangle.py index 8b4d872..7d8a361 100755 --- a/examples/triangle.py +++ b/examples/triangle.py @@ -676,17 +676,18 @@ def lower_optix_Trace(context, builder, sig, args): dx, dy, dz = rayDirection.x, rayDirection.y, rayDirection.z n_stub_output_operands = 32 - 3 - outputs = [builder.alloca(ir.IntType(32)) for _ in range(n_stub_output_operands)] - - asm = ir.InlineAsm(ir.FunctionType(ir.VoidType(), []), - "call " - "(%0,%1,%2,%3,%4,%5,%6,%7,%8,%9,%10,%11,%12,%13,%14,%15,%16,%17,%18,%19,%20,%21,%22,%23,%24,%25,%26,%27,%28,%" - "29,%30,%31)," - "_optix_trace_typed_32," - "(%32,%33,%34,%35,%36,%37,%38,%39,%40,%41,%42,%43,%44,%45,%46,%47,%48,%49,%50,%51,%52,%53,%54,%55,%56,%57,%58,%" - "59,%60,%61,%62,%63,%64,%65,%66,%67,%68,%69,%70,%71,%72,%73,%74,%75,%76,%77,%78,%79,%80);" - "=r," * 32 + "r,l,f,f,f,f,f,f,f,f,f,r,r,r,r,r,r," + "r," * 31 + "r", - outputs + [0, handle, ox, oy, oz, dx, dy, dz, tmin, tmax, rayTime, visibilityMask, rayFlags, SBToffset, SBTstride, missSBTIndex, 0, p0, p1, p2] + outputs + output_stubs = [builder.alloca(ir.IntType(32)) for _ in range(n_stub_output_operands)] + + asm = ir.InlineAsm( + ir.FunctionType(ir.VoidType(), []), + "call " + "(%0,%1,%2,%3,%4,%5,%6,%7,%8,%9,%10,%11,%12,%13,%14,%15,%16,%17,%18,%19,%20,%21,%22,%23,%24,%25,%26,%27,%28,%" + "29,%30,%31)," + "_optix_trace_typed_32," + "(%32,%33,%34,%35,%36,%37,%38,%39,%40,%41,%42,%43,%44,%45,%46,%47,%48,%49,%50,%51,%52,%53,%54,%55,%56,%57,%58,%" + "59,%60,%61,%62,%63,%64,%65,%66,%67,%68,%69,%70,%71,%72,%73,%74,%75,%76,%77,%78,%79,%80);", + "=r," * 32 + "r,l,f,f,f,f,f,f,f,f,f,r,r,r,r,r,r," + "r," * 31 + "r", + [p0, p1, p2] + output_stubs + [0, handle, ox, oy, oz, dx, dy, dz, tmin, tmax, rayTime, visibilityMask, rayFlags, SBToffset, SBTstride, missSBTIndex, 3, p0, p1, p2] + output_stubs ) return builder.call(asm, []) From 9fd65c322ac5616fb89d1119ad73f8c35e4e9248 Mon Sep 17 00:00:00 2001 From: Michael Wang Date: Wed, 12 Jan 2022 13:12:14 -0800 Subject: [PATCH 02/25] optix trace is working...? --- examples/triangle.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/examples/triangle.py b/examples/triangle.py index 4b8cde9..a1603f6 100755 --- a/examples/triangle.py +++ b/examples/triangle.py @@ -678,11 +678,12 @@ def lower_optix_Trace(context, builder, sig, args): n_payload_registers = 3 n_stub_output_operands = 32 - n_payload_registers outputs = ([p0, p1, p2] + - [builder.alloca(ir.IntType(32)) + [builder.load(builder.alloca(ir.IntType(32))) for _ in range(n_stub_output_operands)]) - asm = ir.InlineAsm(ir.FunctionType(ir.VoidType(), []), + retty = ir.LiteralStructType([ir.IntType(32)] * 32) + asm = ir.InlineAsm(ir.FunctionType(retty, []), "call " "(%0,%1,%2,%3,%4,%5,%6,%7,%8,%9,%10,%11,%12,%13,%14,%15,%16,%17,%18,%19,%20,%21,%22,%23,%24,%25,%26,%27,%28,%" "29,%30,%31)," @@ -690,11 +691,12 @@ def lower_optix_Trace(context, builder, sig, args): "(%32,%33,%34,%35,%36,%37,%38,%39,%40,%41,%42,%43,%44,%45,%46,%47,%48,%49,%50,%51,%52,%53,%54,%55,%56,%57,%58,%" "59,%60,%61,%62,%63,%64,%65,%66,%67,%68,%69,%70,%71,%72,%73,%74,%75,%76,%77,%78,%79,%80);", "=r," * 32 + "r,l,f,f,f,f,f,f,f,f,f,r,r,r,r,r,r," + "r," * 31 + "r", + side_effect=True ) zero = context.get_constant(types.int32, 0) c_payload_registers = context.get_constant(types.int32, n_payload_registers) - args = outputs + [zero, handle, ox, oy, oz, dx, dy, dz, tmin, tmax, rayTime, + args = [zero, handle, ox, oy, oz, dx, dy, dz, tmin, tmax, rayTime, visibilityMask, rayFlags, SBToffset, SBTstride, missSBTIndex, c_payload_registers] + outputs return builder.call(asm, args) From ac4e52a3b7fd0870f5e387963a0868193f04f165 Mon Sep 17 00:00:00 2001 From: Michael Wang Date: Mon, 17 Jan 2022 16:22:10 -0800 Subject: [PATCH 03/25] Implemented __miss_ms --- examples/triangle.py | 231 ++++++++++++++++++++++++++++++------------- 1 file changed, 163 insertions(+), 68 deletions(-) diff --git a/examples/triangle.py b/examples/triangle.py index a1603f6..f6baea5 100755 --- a/examples/triangle.py +++ b/examples/triangle.py @@ -15,7 +15,7 @@ from numba import cuda, float32, types, uint8, uint32 from numba.core import cgutils from numba.core.extending import (make_attribute_wrapper, models, overload, - register_model, typeof_impl) + register_model, typeof_impl, type_callable) from numba.core.imputils import lower_constant from numba.core.typing.templates import (AttributeTemplate, ConcreteTemplate, signature) @@ -413,6 +413,26 @@ def lower_make_uint3(context, builder, sig, args): # OPTIX_RAY_FLAG_CULL_ENFORCED_ANYHIT = 1 << 7 +# OptiX types +# ----------- + +# Typing for OptiX types + +class SbtDataPointer(types.RawPointer): + def __init__(self): + super().__init__(name="SbtDataPointer") + + +sbt_data_pointer = SbtDataPointer() + + +# Models for OptiX types + +@register_model(SbtDataPointer) +class SbtDataPointerModel(models.OpaqueModel): + pass + + # Params # ------------ @@ -431,17 +451,8 @@ class ParamsStruct: ) -class MissDataStruct: - fields = { - ('bg_color', 'float3') - } - - # "Declare" a global called params - params = ParamsStruct() -MissData = MissDataStruct() - class Params(types.Type): def __init__(self): @@ -485,6 +496,8 @@ def typeof_params(val, c): # ParamsStruct lowering +# The below makes 'param' a global variable, accessible from any user defined +# kernels. @lower_constant(Params) def constant_params(context, builder, ty, pyval): @@ -500,24 +513,61 @@ def constant_params(context, builder, ty, pyval): return builder.load(gvar) -# OptiX types -# ----------- - -# Typing for OptiX types +# MissData +# ------------ -class SbtDataPointer(types.RawPointer): - def __init__(self): - super().__init__(name="SbtDataPointer") +# Structures as declared in triangle.h +class MissDataStruct: + fields = ( + ('bg_color', 'float3') + ) +MissData = MissDataStruct() -sbt_data_pointer = SbtDataPointer() +class MissData(types.Type): + def __init__(self): + super().__init__(name='MissDataType') +miss_data_type = MissData() -# Models for OptiX types +@register_model(MissData) +class MissDataModel(models.StructModel): + def __init__(self, dmm, fe_type): + members = [ + ('bg_color', float3), + ] + super().__init__(dmm, fe_type, members) -@register_model(SbtDataPointer) -class SbtDataPointerModel(models.OpaqueModel): - pass +make_attribute_wrapper(MissData, 'bg_color', 'bg_color') + +@typeof_impl.register(MissDataStruct) +def typeof_miss_data(val, c): + return miss_data_type + +# MissData Constructor +@type_callable(MissDataStruct) +def type_miss_data_struct(context): + def typer(sbt_data_pointer): + if isinstance(sbt_data_pointer, SbtDataPointer): + return miss_data_type + return typer + +@lower(MissDataStruct, sbt_data_pointer) +def lower_miss_data_ctor(context, builder, sig, args): + # Anyway to err if this ctor is not called inside __miss__* program? + ptr = args[0] + ptr = builder.bitcast(ptr, + context.get_value_type(miss_data_type).as_pointer()) + miss_data = cgutils.create_struct_proxy(miss_data_type)(context, builder) + bg_color_ptr = cgutils.gep_inbounds(builder, ptr, 0, 0) + + xptr = cgutils.gep_inbounds(builder, bg_color_ptr, 0, 0) + yptr = cgutils.gep_inbounds(builder, bg_color_ptr, 0, 1) + zptr = cgutils.gep_inbounds(builder, bg_color_ptr, 0, 2) + miss_data.bg_color.x = builder.load(xptr) + miss_data.bg_color.y = builder.load(yptr) + miss_data.bg_color.z = builder.load(zptr) + return miss_data._getvalue() # OptiX functions @@ -537,6 +587,14 @@ def _optix_GetLaunchDimensions(): def _optix_GetSbtDataPointer(): pass +def _optix_SetPayload_0(): + pass + +def _optix_SetPayload_1(): + pass + +def _optix_SetPayload_2(): + pass def _optix_Trace(): pass @@ -549,6 +607,10 @@ def _optix_Trace(): optix.GetLaunchIndex = _optix_GetLaunchIndex optix.GetLaunchDimensions = _optix_GetLaunchDimensions optix.GetSbtDataPointer = _optix_GetSbtDataPointer +optix.SetPayload_0 = _optix_SetPayload_0 +optix.SetPayload_1 = _optix_SetPayload_1 +optix.SetPayload_2 = _optix_SetPayload_2 + optix.Trace = _optix_Trace @@ -566,6 +628,22 @@ class OptixGetLaunchDimensions(ConcreteTemplate): cases = [signature(dim3)] +@register +class OptixGetSbtDataPointer(ConcreteTemplate): + key = optix.GetSbtDataPointer + cases = [signature(sbt_data_pointer)] + +def registerSetPayload(reg): + class OptixSetPayloadReg(ConcreteTemplate): + key = getattr(optix, 'SetPayload_' + str(reg)) + cases = [signature(types.void, uint32)] + register(OptixSetPayloadReg) + return OptixSetPayloadReg + +OptixSetPayload_0 = registerSetPayload(0) +OptixSetPayload_1 = registerSetPayload(1) +OptixSetPayload_2 = registerSetPayload(2) + @register class OptixTrace(ConcreteTemplate): key = optix.Trace @@ -601,6 +679,15 @@ def resolve_GetLaunchDimensions(self, mod): def resolve_GetSbtDataPointer(self, mod): return types.Function(OptixGetSbtDataPointer) + def resolve_SetPayload_0(self, mod): + return types.Function(OptixSetPayload_0) + + def resolve_SetPayload_1(self, mod): + return types.Function(OptixSetPayload_1) + + def resolve_SetPayload_2(self, mod): + return types.Function(OptixSetPayload_2) + def resolve_Trace(self, mod): return types.Function(OptixTrace) @@ -647,6 +734,18 @@ def lower_optix_getSbtDataPointer(context, builder, sig, args): return ptr +def lower_optix_SetPayloadReg(reg): + def lower_optix_SetPayload_impl(context, builder, sig, args): + asm = ir.InlineAsm(ir.FunctionType(ir.VoidType(), [ir.IntType(32), ir.IntType(32)]), + f"call _optix_set_payload_{reg};", + "r, r") + builder.call(asm, [context.get_constant(types.int32, reg), args[0]]) + lower(getattr(optix, f"SetPayload_{reg}"), uint32)(lower_optix_SetPayload_impl) + +lower_optix_SetPayloadReg(0) +lower_optix_SetPayloadReg(1) +lower_optix_SetPayloadReg(2) + @lower(optix.Trace, OptixTraversableHandle, float3, @@ -862,7 +961,7 @@ def set_pipeline_options(): ) -def create_module( ctx, pipeline_options, triangle_ptx ): +def create_module( ctx, pipeline_options, ptx ): print( "Creating optix module ..." ) @@ -875,47 +974,39 @@ def create_module( ctx, pipeline_options, triangle_ptx ): module, log = ctx.moduleCreateFromPTX( module_options, pipeline_options, - triangle_ptx - ) + ptx + ) print( "\tModule create log: <<<{}>>>".format( log ) ) return module -def create_program_groups( ctx, module ): +def create_program_groups( ctx, raygen_module, miss_prog_module, hitgroup_module ): print( "Creating program groups ... " ) program_group_options = optix.ProgramGroupOptions() raygen_prog_group_desc = optix.ProgramGroupDesc() + raygen_prog_group_desc.kind = optix.OPTIX_PROGRAM_GROUP_KIND_RAYGEN raygen_prog_group_desc.raygenModule = module raygen_prog_group_desc.raygenEntryFunctionName = "__raygen__rg" - raygen_prog_group, log = ctx.programGroupCreate( - [ raygen_prog_group_desc ], - program_group_options, - ) - print( "\tProgramGroup raygen create log: <<<{}>>>".format( log ) ) - + miss_prog_group_desc = optix.ProgramGroupDesc() + miss_prog_group_desc.kind = optix.OPTIX_PROGRAM_GROUP_KIND_MISS miss_prog_group_desc.missModule = module miss_prog_group_desc.missEntryFunctionName = "__miss__ms" - miss_prog_group, log = ctx.programGroupCreate( - [ miss_prog_group_desc ], - program_group_options, - ) - print( "\tProgramGroup miss create log: <<<{}>>>".format( log ) ) - hitgroup_prog_group_desc = optix.ProgramGroupDesc() + hitgroup_prog_group_desc.kind = optix.OPTIX_PROGRAM_GROUP_KIND_HITGROUP hitgroup_prog_group_desc.hitgroupModuleCH = module hitgroup_prog_group_desc.hitgroupEntryFunctionNameCH = "__closesthit__ch" - hitgroup_prog_group, log = ctx.programGroupCreate( - [ hitgroup_prog_group_desc ], + + prog_group, log = ctx.programGroupCreate( + [ raygen_prog_group_desc, miss_prog_group_desc, hitgroup_prog_group_desc ], program_group_options, ) - print( "\tProgramGroup hitgroup create log: <<<{}>>>".format( log ) ) - + print( "\tProgramGroup create log: <<<{}>>>".format( log ) ) - return [ raygen_prog_group[0], miss_prog_group[0], hitgroup_prog_group[0] ] + return prog_group def create_pipeline( ctx, program_groups, pipeline_compile_options ): @@ -1076,7 +1167,7 @@ def launch( pipeline, sbt, trav_handle ): pix_width, pix_height, 1 # depth - ) + ) stream.synchronize() @@ -1193,9 +1284,9 @@ def make_color(c): @cuda.jit(device=True) def setPayload(p): - optix.SetPayload_0(float_as_int(p.x)) - optix.SetPayload_1(float_as_int(p.y)) - optix.SetPayload_2(float_as_int(p.z)) + optix.SetPayload_0(uint32(p.x)) + optix.SetPayload_1(uint32(p.y)) + optix.SetPayload_2(uint32(p.z)) @cuda.jit(device=True) def computeRay(idx, dim, origin, direction): @@ -1226,33 +1317,31 @@ def __raygen__rg(): p0 = cuda.local.array(1, types.int32) p1 = cuda.local.array(1, types.int32) p2 = cuda.local.array(1, types.int32) - optix.Trace( - params.handle, - ray_origin, - ray_direction, - types.float32(0.0), # Min intersection distance - types.float32(1e16), # Max intersection distance - types.float32(0.0), # rayTime -- used for motion blur - OptixVisibilityMask(255), # Specify always visible - # OptixRayFlags.OPTIX_RAY_FLAG_NONE, - uint32(OPTIX_RAY_FLAG_NONE), - uint32(0), # SBT offset -- See SBT discussion - uint32(1), # SBT stride -- See SBT discussion - uint32(0), # missSBTIndex -- See SBT discussion - p0[0], p1[0], p2[0]) + # optix.Trace( + # params.handle, + # ray_origin, + # ray_direction, + # types.float32(0.0), # Min intersection distance + # types.float32(1e16), # Max intersection distance + # types.float32(0.0), # rayTime -- used for motion blur + # OptixVisibilityMask(255), # Specify always visible + # # OptixRayFlags.OPTIX_RAY_FLAG_NONE, + # uint32(OPTIX_RAY_FLAG_NONE), + # uint32(0), # SBT offset -- See SBT discussion + # uint32(1), # SBT stride -- See SBT discussion + # uint32(0), # missSBTIndex -- See SBT discussion + # p0[0], p1[0], p2[0]) result = make_float3(p0[0], p1[0], p2[0]) # Record results in our output raster params.image[idx.y * params.image_width + idx.x] = make_color( result ) -@cuda.jit def __miss__ms(): - miss_data = MissData(optix.GetSbtDataPointer()) + miss_data = MissDataStruct(optix.GetSbtDataPointer()) setPayload(miss_data.bg_color) -@cuda.jit def __closesthit__ch(): # When built-in triangle intersection is used, a number of fundamental # attributes are provided by the OptiX API, indlucing barycentric coordinates. @@ -1269,17 +1358,23 @@ def __closesthit__ch(): def main(): - triangle_ptx = compile_numba(__raygen__rg) + raygen_ptx = compile_numba(__raygen__rg) + miss_ptx = compile_numba(__miss__ms) + hitgroup_ptx = compile_numba(__closesthit__ch) + # triangle_ptx = compile_cuda( "examples/triangle.cu" ) - # print(triangle_ptx) init_optix() ctx = create_ctx() gas_handle, d_gas_output_buffer = create_accel(ctx) pipeline_options = set_pipeline_options() - module = create_module( ctx, pipeline_options, triangle_ptx ) - prog_groups = create_program_groups( ctx, module ) + + raygen_module = create_module( ctx, pipeline_options, raygen_ptx ) + miss_module = create_module( ctx, pipeline_options, miss_ptx ) + hitgroup_module = create_module( ctx, pipeline_options, hitgroup_ptx ) + + prog_groups = create_program_groups( ctx, raygen_module, miss_module, hitgroup_module ) pipeline = create_pipeline( ctx, prog_groups, pipeline_options ) sbt = create_sbt( prog_groups ) pix = launch( pipeline, sbt, gas_handle ) From 3e19bf8488d10372ad3da23cea0e2c3c7069fd31 Mon Sep 17 00:00:00 2001 From: Michael Wang Date: Tue, 18 Jan 2022 11:38:10 -0800 Subject: [PATCH 04/25] get_triangle_barycentrics --- examples/triangle.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/examples/triangle.py b/examples/triangle.py index f6baea5..7bcb420 100755 --- a/examples/triangle.py +++ b/examples/triangle.py @@ -596,6 +596,9 @@ def _optix_SetPayload_1(): def _optix_SetPayload_2(): pass +def _optix_GetTriangleBarycentrics(): + pass + def _optix_Trace(): pass @@ -607,6 +610,7 @@ def _optix_Trace(): optix.GetLaunchIndex = _optix_GetLaunchIndex optix.GetLaunchDimensions = _optix_GetLaunchDimensions optix.GetSbtDataPointer = _optix_GetSbtDataPointer +optix.GetTriangleBarycentrics = _optix_GetTriangleBarycentrics optix.SetPayload_0 = _optix_SetPayload_0 optix.SetPayload_1 = _optix_SetPayload_1 optix.SetPayload_2 = _optix_SetPayload_2 @@ -644,6 +648,11 @@ class OptixSetPayloadReg(ConcreteTemplate): OptixSetPayload_1 = registerSetPayload(1) OptixSetPayload_2 = registerSetPayload(2) +@register +class OptixGetTriangleBarycentrics(ConcreteTemplate): + key = optix.GetTriangleBarycentrics + cases = [signature(float2)] + @register class OptixTrace(ConcreteTemplate): key = optix.Trace @@ -688,6 +697,9 @@ def resolve_SetPayload_1(self, mod): def resolve_SetPayload_2(self, mod): return types.Function(OptixSetPayload_2) + def resolve_GetTriangleBarycentrics(self, mod): + return types.Function(OptixGetTriangleBarycentrics) + def resolve_Trace(self, mod): return types.Function(OptixTrace) @@ -746,6 +758,21 @@ def lower_optix_SetPayload_impl(context, builder, sig, args): lower_optix_SetPayloadReg(1) lower_optix_SetPayloadReg(2) +@lower(optix.GetTriangleBarycentrics) +def lower_optix_getTriangleBarycentrics(context, builder, sig, args): + f2 = cgutils.create_struct_proxy(float2)(context, builder) + retty = ir.LiteralStructType([ir.FloatType(), ir.FloatType()]) + asm = ir.InlineAsm( + ir.FunctionType(retty, []), + "call (%0, %1), _optix_get_triangle_barycentrics, ();", + "=f, =f" + ) + ret = builder.call(asm, []) + f2.x = builder.extract_value(ret, 0) + f2.y = builder.extract_value(ret, 1) + return f2._getvalue() + + @lower(optix.Trace, OptixTraversableHandle, float3, From bf1f6bf7d6f0a597e8d197f73b1a34a4d85d8e00 Mon Sep 17 00:00:00 2001 From: Michael Wang Date: Tue, 18 Jan 2022 13:37:32 -0800 Subject: [PATCH 05/25] Rendered a black image --- examples/triangle.py | 84 +++++++++++++++++++++++++------------------- 1 file changed, 47 insertions(+), 37 deletions(-) diff --git a/examples/triangle.py b/examples/triangle.py index 7bcb420..2361024 100755 --- a/examples/triangle.py +++ b/examples/triangle.py @@ -112,6 +112,16 @@ def __init__(self): float3 = Float3() +# Float2 typing (forward declaration) + +class Float2(types.Type): + def __init__(self): + super().__init__(name="Float2") + + +float2 = Float2() + + # Float3 data model @register_model(Float3) @@ -213,7 +223,8 @@ def make_float3(x, y, z): @register class MakeFloat3(ConcreteTemplate): key = make_float3 - cases = [signature(float3, types.float32, types.float32, types.float32)] + cases = [signature(float3, types.float32, types.float32, types.float32), + signature(float3, float2, types.float32)] register_global(make_float3, types.Function(MakeFloat3)) @@ -230,17 +241,19 @@ def lower_make_float3(context, builder, sig, args): return f3._getvalue() -# float2 -# ------ - -# Float2 typing +@lower(make_float3, float2, types.float32) +def lower_make_float3(context, builder, sig, args): + f2 = cgutils.create_struct_proxy(float2)(context, builder, args[0]) + f3 = cgutils.create_struct_proxy(float3)(context, builder) + f3.x = f2.x + f3.y = f2.y + f3.z = args[1] + return f3._getvalue() -class Float2(types.Type): - def __init__(self): - super().__init__(name="Float2") +# float2 +# ------ -float2 = Float2() # Float2 data model @@ -749,8 +762,8 @@ def lower_optix_getSbtDataPointer(context, builder, sig, args): def lower_optix_SetPayloadReg(reg): def lower_optix_SetPayload_impl(context, builder, sig, args): asm = ir.InlineAsm(ir.FunctionType(ir.VoidType(), [ir.IntType(32), ir.IntType(32)]), - f"call _optix_set_payload_{reg};", - "r, r") + f"call _optix_set_payload, ($0, $1);", + "r,r") builder.call(asm, [context.get_constant(types.int32, reg), args[0]]) lower(getattr(optix, f"SetPayload_{reg}"), uint32)(lower_optix_SetPayload_impl) @@ -764,8 +777,8 @@ def lower_optix_getTriangleBarycentrics(context, builder, sig, args): retty = ir.LiteralStructType([ir.FloatType(), ir.FloatType()]) asm = ir.InlineAsm( ir.FunctionType(retty, []), - "call (%0, %1), _optix_get_triangle_barycentrics, ();", - "=f, =f" + "call ($0, $1), _optix_get_triangle_barycentrics, ();", + "=f,=f" ) ret = builder.call(asm, []) f2.x = builder.extract_value(ret, 0) @@ -811,11 +824,11 @@ def lower_optix_Trace(context, builder, sig, args): retty = ir.LiteralStructType([ir.IntType(32)] * 32) asm = ir.InlineAsm(ir.FunctionType(retty, []), "call " - "(%0,%1,%2,%3,%4,%5,%6,%7,%8,%9,%10,%11,%12,%13,%14,%15,%16,%17,%18,%19,%20,%21,%22,%23,%24,%25,%26,%27,%28,%" - "29,%30,%31)," + "($0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29," + "$30,$31)," "_optix_trace_typed_32," - "(%32,%33,%34,%35,%36,%37,%38,%39,%40,%41,%42,%43,%44,%45,%46,%47,%48,%49,%50,%51,%52,%53,%54,%55,%56,%57,%58,%" - "59,%60,%61,%62,%63,%64,%65,%66,%67,%68,%69,%70,%71,%72,%73,%74,%75,%76,%77,%78,%79,%80);", + "($32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59," + "$60,$61,$62,$63,$64,$65,$66,$67,$68,$69,$70,$71,$72,$73,$74,$75,$76,$77,$78,$79,$80);", "=r," * 32 + "r,l,f,f,f,f,f,f,f,f,f,r,r,r,r,r,r," + "r," * 31 + "r", side_effect=True ) @@ -1013,18 +1026,15 @@ def create_program_groups( ctx, raygen_module, miss_prog_module, hitgroup_module program_group_options = optix.ProgramGroupOptions() raygen_prog_group_desc = optix.ProgramGroupDesc() - raygen_prog_group_desc.kind = optix.OPTIX_PROGRAM_GROUP_KIND_RAYGEN - raygen_prog_group_desc.raygenModule = module + raygen_prog_group_desc.raygenModule = raygen_module raygen_prog_group_desc.raygenEntryFunctionName = "__raygen__rg" miss_prog_group_desc = optix.ProgramGroupDesc() - miss_prog_group_desc.kind = optix.OPTIX_PROGRAM_GROUP_KIND_MISS - miss_prog_group_desc.missModule = module + miss_prog_group_desc.missModule = miss_prog_module miss_prog_group_desc.missEntryFunctionName = "__miss__ms" hitgroup_prog_group_desc = optix.ProgramGroupDesc() - hitgroup_prog_group_desc.kind = optix.OPTIX_PROGRAM_GROUP_KIND_HITGROUP - hitgroup_prog_group_desc.hitgroupModuleCH = module + hitgroup_prog_group_desc.hitgroupModuleCH = hitgroup_module hitgroup_prog_group_desc.hitgroupEntryFunctionNameCH = "__closesthit__ch" prog_group, log = ctx.programGroupCreate( @@ -1344,20 +1354,20 @@ def __raygen__rg(): p0 = cuda.local.array(1, types.int32) p1 = cuda.local.array(1, types.int32) p2 = cuda.local.array(1, types.int32) - # optix.Trace( - # params.handle, - # ray_origin, - # ray_direction, - # types.float32(0.0), # Min intersection distance - # types.float32(1e16), # Max intersection distance - # types.float32(0.0), # rayTime -- used for motion blur - # OptixVisibilityMask(255), # Specify always visible - # # OptixRayFlags.OPTIX_RAY_FLAG_NONE, - # uint32(OPTIX_RAY_FLAG_NONE), - # uint32(0), # SBT offset -- See SBT discussion - # uint32(1), # SBT stride -- See SBT discussion - # uint32(0), # missSBTIndex -- See SBT discussion - # p0[0], p1[0], p2[0]) + optix.Trace( + params.handle, + ray_origin, + ray_direction, + types.float32(0.0), # Min intersection distance + types.float32(1e16), # Max intersection distance + types.float32(0.0), # rayTime -- used for motion blur + OptixVisibilityMask(255), # Specify always visible + # OptixRayFlags.OPTIX_RAY_FLAG_NONE, + uint32(OPTIX_RAY_FLAG_NONE), + uint32(0), # SBT offset -- See SBT discussion + uint32(1), # SBT stride -- See SBT discussion + uint32(0), # missSBTIndex -- See SBT discussion + p0[0], p1[0], p2[0]) result = make_float3(p0[0], p1[0], p2[0]) # Record results in our output raster From 83524e78eac7816ee016ed88f3c3036b253811d6 Mon Sep 17 00:00:00 2001 From: Michael Wang Date: Tue, 18 Jan 2022 15:09:12 -0800 Subject: [PATCH 06/25] minor cleanups and typo(bug) fixes --- examples/triangle.py | 60 ++++++++++++++++++++++---------------------- 1 file changed, 30 insertions(+), 30 deletions(-) diff --git a/examples/triangle.py b/examples/triangle.py index 2361024..fbb7dbd 100755 --- a/examples/triangle.py +++ b/examples/triangle.py @@ -12,7 +12,7 @@ from llvmlite import ir -from numba import cuda, float32, types, uint8, uint32 +from numba import cuda, float32, types, uint8, uint32, int32 from numba.core import cgutils from numba.core.extending import (make_attribute_wrapper, models, overload, register_model, typeof_impl, type_callable) @@ -223,8 +223,10 @@ def make_float3(x, y, z): @register class MakeFloat3(ConcreteTemplate): key = make_float3 - cases = [signature(float3, types.float32, types.float32, types.float32), - signature(float3, float2, types.float32)] + cases = [ + signature(float3, types.float32, types.float32, types.float32), + signature(float3, float2, types.float32) + ] register_global(make_float3, types.Function(MakeFloat3)) @@ -1262,29 +1264,25 @@ def clamp(x, a, b): @overload(clamp, target="cuda") def jit_clamp(x, a, b): - if isinstance(x, types.Float): + if isinstance(x, types.Float) and isinstance(a, types.Float) and isinstance(b, types.Float): def clamp_float_impl(x, a, b): return max(a, min(x, b)) return clamp_float_impl - elif isinstance(x, Float3): + elif isinstance(x, Float3) and isinstance(a, types.Float) and isinstance(b, types.Float): def clamp_float3_impl(x, a, b): return make_float3(clamp(x.x, a, b), clamp(x.y, a, b), clamp(x.z, a, b)) return clamp_float3_impl -# def dot(a, b): -# pass - -# @overload(dot, target="cuda") -# def jit_dot(a, b): -# if isinstance(a, Float3) and isinstance(b, Float3): -# def dot_float3_impl(a, b): -# return a.x * b.x + a.y * b.y + a.z * b.z -# return dot_float3_impl - -@cuda.jit(device=True) def dot(a, b): - return a.x * b.x + a.y * b.y + a.z * b.z + pass + +@overload(dot, target="cuda") +def jit_dot(a, b): + if isinstance(a, Float3) and isinstance(b, Float3): + def dot_float3_impl(a, b): + return a.x * b.x + a.y * b.y + a.z * b.z + return dot_float3_impl @cuda.jit(device=True) @@ -1299,7 +1297,7 @@ def normalize(v): def toSRGB(c): # Use float32 for constants invGamma = float32(1.0) / float32(2.4) - powed = make_float3(math.pow(c.x, invGamma), math.pow(c.x, invGamma), math.pow(c.x, invGamma)) + powed = make_float3(math.pow(c.x, invGamma), math.pow(c.y, invGamma), math.pow(c.z, invGamma)) return make_float3( float32(12.92) * c.x if c.x < float32(0.0031308) else float32(1.055) * powed.x - float32(0.055), float32(12.92) * c.y if c.y < float32(0.0031308) else float32(1.055) * powed.y - float32(0.055), @@ -1309,8 +1307,8 @@ def toSRGB(c): @cuda.jit(device=True) def quantizeUnsigned8Bits(x): x = clamp( x, float32(0.0), float32(1.0) ) - N, Np1 = 1 << 8 - 1, 1 << 8 - return uint8(min(uint8(x * float32(Np1)), uint8(N))) + N, Np1 = (1 << 8) - 1, 1 << 8 + return uint8(min(uint32(x * float32(Np1)), uint32(N))) @cuda.jit(device=True) def make_color(c): @@ -1330,10 +1328,10 @@ def computeRay(idx, dim, origin, direction): U = params.cam_u V = params.cam_v W = params.cam_w - d = types.float32(2.0) * make_float2( - types.float32(idx.x) / types.float32(dim.x), - types.float32(idx.y) / types.float32(dim.y) - ) - types.float32(1.0) + d = float32(2.0) * make_float2( + float32(idx.x) / float32(dim.x), + float32(idx.y) / float32(dim.y) + ) - float32(1.0) origin = params.cam_eye direction = normalize(d.x * U + d.y * V + W) @@ -1351,16 +1349,16 @@ def __raygen__rg(): computeRay(make_uint3(idx.x, idx.y, 0), dim, ray_origin, ray_direction) # Trace the ray against our scene hierarchy - p0 = cuda.local.array(1, types.int32) - p1 = cuda.local.array(1, types.int32) - p2 = cuda.local.array(1, types.int32) + p0 = cuda.local.array(1, int32) + p1 = cuda.local.array(1, int32) + p2 = cuda.local.array(1, int32) optix.Trace( params.handle, ray_origin, ray_direction, - types.float32(0.0), # Min intersection distance - types.float32(1e16), # Max intersection distance - types.float32(0.0), # rayTime -- used for motion blur + float32(0.0), # Min intersection distance + float32(1e16), # Max intersection distance + float32(0.0), # rayTime -- used for motion blur OptixVisibilityMask(255), # Specify always visible # OptixRayFlags.OPTIX_RAY_FLAG_NONE, uint32(OPTIX_RAY_FLAG_NONE), @@ -1401,6 +1399,8 @@ def main(): # triangle_ptx = compile_cuda( "examples/triangle.cu" ) + print(raygen_ptx) + init_optix() ctx = create_ctx() From 7332313d5abf15dd95c7b27a8cfb0cadbf8d295e Mon Sep 17 00:00:00 2001 From: Michael Wang Date: Tue, 18 Jan 2022 17:18:20 -0800 Subject: [PATCH 07/25] triangle! --- examples/triangle.py | 72 ++++++++++++++++++++++++++++++-------------- 1 file changed, 49 insertions(+), 23 deletions(-) diff --git a/examples/triangle.py b/examples/triangle.py index fbb7dbd..8c0576a 100755 --- a/examples/triangle.py +++ b/examples/triangle.py @@ -406,6 +406,32 @@ def lower_make_uint3(context, builder, sig, args): return u4_3._getvalue() +# Temporary Payload Parameter Pack +class PayloadPack(types.Type): + def __init__(self): + super().__init__(name="PayloadPack") + + +payload_pack = PayloadPack() + + +# UInt3 data model + +@register_model(PayloadPack) +class PayloadPackModel(models.StructModel): + def __init__(self, dmm, fe_type): + members = [ + ('p0', types.uint32), + ('p1', types.uint32), + ('p2', types.uint32), + ] + super().__init__(dmm, fe_type, members) + + +make_attribute_wrapper(PayloadPack, 'p0', 'p0') +make_attribute_wrapper(PayloadPack, 'p1', 'p1') +make_attribute_wrapper(PayloadPack, 'p2', 'p2') + # OptiX typedefs and enums # ----------- @@ -672,7 +698,7 @@ class OptixGetTriangleBarycentrics(ConcreteTemplate): class OptixTrace(ConcreteTemplate): key = optix.Trace cases = [signature( - types.void, + payload_pack, OptixTraversableHandle, float3, float3, @@ -684,9 +710,6 @@ class OptixTrace(ConcreteTemplate): uint32, uint32, uint32, - uint32, # payload register 0 - uint32, # payload register 1 - uint32, # payload register 2 )] @@ -800,25 +823,23 @@ def lower_optix_getTriangleBarycentrics(context, builder, sig, args): uint32, uint32, uint32, - uint32, # payload register 0 - uint32, # payload register 1 - uint32, # payload register 2 ) def lower_optix_Trace(context, builder, sig, args): # Only implements the version that accepts 3 payload registers (handle, rayOrigin, rayDirection, tmin, tmax, rayTime, visibilityMask, - rayFlags, SBToffset, SBTstride, missSBTIndex, p0, p1, p2) = args + rayFlags, SBToffset, SBTstride, missSBTIndex) = args rayOrigin = cgutils.create_struct_proxy(float3)(context, builder, rayOrigin) rayDirection = cgutils.create_struct_proxy(float3)(context, builder, rayDirection) + output = cgutils.create_struct_proxy(payload_pack)(context, builder) ox, oy, oz = rayOrigin.x, rayOrigin.y, rayOrigin.z dx, dy, dz = rayDirection.x, rayDirection.y, rayDirection.z n_payload_registers = 3 n_stub_output_operands = 32 - n_payload_registers - outputs = ([p0, p1, p2] + + outputs = ([output.p0, output.p1, output.p2] + [builder.load(builder.alloca(ir.IntType(32))) for _ in range(n_stub_output_operands)]) @@ -840,7 +861,11 @@ def lower_optix_Trace(context, builder, sig, args): args = [zero, handle, ox, oy, oz, dx, dy, dz, tmin, tmax, rayTime, visibilityMask, rayFlags, SBToffset, SBTstride, missSBTIndex, c_payload_registers] + outputs - return builder.call(asm, args) + ret = builder.call(asm, args) + output.p0 = builder.extract_value(ret, 0) + output.p1 = builder.extract_value(ret, 1) + output.p2 = builder.extract_value(ret, 2) + return output._getvalue() #------------------------------------------------------------------------------- @@ -1319,15 +1344,16 @@ def make_color(c): @cuda.jit(device=True) def setPayload(p): - optix.SetPayload_0(uint32(p.x)) - optix.SetPayload_1(uint32(p.y)) - optix.SetPayload_2(uint32(p.z)) + optix.SetPayload_0(cuda.libdevice.float_as_int(p.x)) + optix.SetPayload_1(cuda.libdevice.float_as_int(p.y)) + optix.SetPayload_2(cuda.libdevice.float_as_int(p.z)) @cuda.jit(device=True) -def computeRay(idx, dim, origin, direction): +def computeRay(idx, dim): U = params.cam_u V = params.cam_v W = params.cam_w + # Normalizing coordinates to [-1.0, 1.0] d = float32(2.0) * make_float2( float32(idx.x) / float32(dim.x), float32(idx.y) / float32(dim.y) @@ -1335,6 +1361,7 @@ def computeRay(idx, dim, origin, direction): origin = params.cam_eye direction = normalize(d.x * U + d.y * V + W) + return origin, direction def __raygen__rg(): @@ -1344,15 +1371,10 @@ def __raygen__rg(): # Map our launch idx to a screen location and create a ray from the camera # location through the screen - ray_origin = make_float3(float32(0.0), float32(0.0), float32(0.0)) - ray_direction = make_float3(float32(0.0), float32(0.0), float32(0.0)) - computeRay(make_uint3(idx.x, idx.y, 0), dim, ray_origin, ray_direction) + ray_origin, ray_direction = computeRay(make_uint3(idx.x, idx.y, 0), dim) # Trace the ray against our scene hierarchy - p0 = cuda.local.array(1, int32) - p1 = cuda.local.array(1, int32) - p2 = cuda.local.array(1, int32) - optix.Trace( + payload_pack = optix.Trace( params.handle, ray_origin, ray_direction, @@ -1365,8 +1387,12 @@ def __raygen__rg(): uint32(0), # SBT offset -- See SBT discussion uint32(1), # SBT stride -- See SBT discussion uint32(0), # missSBTIndex -- See SBT discussion - p0[0], p1[0], p2[0]) - result = make_float3(p0[0], p1[0], p2[0]) + ) + result = make_float3( + cuda.libdevice.int_as_float(payload_pack.p0), + cuda.libdevice.int_as_float(payload_pack.p1), + cuda.libdevice.int_as_float(payload_pack.p2) + ) # Record results in our output raster params.image[idx.y * params.image_width + idx.x] = make_color( result ) From cad88efb1a4f8c990f2a5d0ee5a21aa2cc973174 Mon Sep 17 00:00:00 2001 From: Michael Wang Date: Tue, 18 Jan 2022 18:33:51 -0800 Subject: [PATCH 08/25] Background! --- examples/triangle.py | 31 +++++++++++++++++++++++++------ 1 file changed, 25 insertions(+), 6 deletions(-) diff --git a/examples/triangle.py b/examples/triangle.py index 8c0576a..feb78c7 100755 --- a/examples/triangle.py +++ b/examples/triangle.py @@ -596,19 +596,37 @@ def typer(sbt_data_pointer): @lower(MissDataStruct, sbt_data_pointer) def lower_miss_data_ctor(context, builder, sig, args): # Anyway to err if this ctor is not called inside __miss__* program? + # TODO: Optimize ptr = args[0] ptr = builder.bitcast(ptr, context.get_value_type(miss_data_type).as_pointer()) - miss_data = cgutils.create_struct_proxy(miss_data_type)(context, builder) + bg_color_ptr = cgutils.gep_inbounds(builder, ptr, 0, 0) xptr = cgutils.gep_inbounds(builder, bg_color_ptr, 0, 0) yptr = cgutils.gep_inbounds(builder, bg_color_ptr, 0, 1) zptr = cgutils.gep_inbounds(builder, bg_color_ptr, 0, 2) - miss_data.bg_color.x = builder.load(xptr) - miss_data.bg_color.y = builder.load(yptr) - miss_data.bg_color.z = builder.load(zptr) - return miss_data._getvalue() + + output_miss_data = cgutils.create_struct_proxy(miss_data_type)(context, builder) + output_bg_color_ptr = cgutils.gep_inbounds(builder, output_miss_data._getpointer(), 0, 0) + output_bg_color_x_ptr = cgutils.gep_inbounds(builder, output_bg_color_ptr, 0, 0) + output_bg_color_y_ptr = cgutils.gep_inbounds(builder, output_bg_color_ptr, 0, 1) + output_bg_color_z_ptr = cgutils.gep_inbounds(builder, output_bg_color_ptr, 0, 2) + + x = builder.load(xptr) + y = builder.load(yptr) + z = builder.load(zptr) + + builder.store(x, output_bg_color_x_ptr) + builder.store(y, output_bg_color_y_ptr) + builder.store(z, output_bg_color_z_ptr) + + + # Doesn't seem to do what's expected? + # miss_data.bg_color.x = builder.load(xptr) + # miss_data.bg_color.y = builder.load(yptr) + # miss_data.bg_color.z = builder.load(zptr) + return output_miss_data._getvalue() # OptiX functions @@ -826,6 +844,7 @@ def lower_optix_getTriangleBarycentrics(context, builder, sig, args): ) def lower_optix_Trace(context, builder, sig, args): # Only implements the version that accepts 3 payload registers + # TODO: Optimize returns (handle, rayOrigin, rayDirection, tmin, tmax, rayTime, visibilityMask, rayFlags, SBToffset, SBTstride, missSBTIndex) = args @@ -1425,7 +1444,7 @@ def main(): # triangle_ptx = compile_cuda( "examples/triangle.cu" ) - print(raygen_ptx) + print(miss_ptx) init_optix() From 7244cb98a83591bfec50118ed3c372b2703a5a18 Mon Sep 17 00:00:00 2001 From: Michael Wang Date: Wed, 19 Jan 2022 13:09:01 -0800 Subject: [PATCH 09/25] First pass style format --- examples/triangle.py | 1027 +++++++++++++++++++++++------------------- 1 file changed, 569 insertions(+), 458 deletions(-) diff --git a/examples/triangle.py b/examples/triangle.py index feb78c7..aa73976 100755 --- a/examples/triangle.py +++ b/examples/triangle.py @@ -3,23 +3,28 @@ import ctypes # C interop helpers import math -from operator import add, mul, sub from enum import Enum +from operator import add, mul, sub import cupy as cp # CUDA bindings import numpy as np # Packing of structures in C-compatible format -import optix - from llvmlite import ir - -from numba import cuda, float32, types, uint8, uint32, int32 +from numba import cuda, float32, int32, types, uint8, uint32 from numba.core import cgutils -from numba.core.extending import (make_attribute_wrapper, models, overload, - register_model, typeof_impl, type_callable) +from numba.core.extending import ( + make_attribute_wrapper, + models, + overload, + register_model, + type_callable, + typeof_impl, +) from numba.core.imputils import lower_constant -from numba.core.typing.templates import (AttributeTemplate, ConcreteTemplate, - signature) - +from numba.core.typing.templates import ( + AttributeTemplate, + ConcreteTemplate, + signature, +) from numba.cuda import get_current_device from numba.cuda.compiler import compile_cuda as numba_compile_cuda from numba.cuda.cudadecl import register, register_attr, register_global @@ -28,6 +33,8 @@ from numba.cuda.types import dim3 from PIL import Image, ImageOps # Image IO +import optix + # ------------------------------------------------------------------------------- # # Numba extensions for general CUDA / OptiX support @@ -43,12 +50,14 @@ # Prototype a function to construct a uchar4 + def make_uchar4(x, y, z, w): pass # UChar4 typing + class UChar4(types.Type): def __init__(self): super().__init__(name="UChar4") @@ -60,8 +69,7 @@ def __init__(self): @register class MakeUChar4(ConcreteTemplate): key = make_uchar4 - cases = [signature(uchar4, types.uchar, types.uchar, types.uchar, - types.uchar)] + cases = [signature(uchar4, types.uchar, types.uchar, types.uchar, types.uchar)] register_global(make_uchar4, types.Function(MakeUChar4)) @@ -69,26 +77,28 @@ class MakeUChar4(ConcreteTemplate): # UChar4 data model + @register_model(UChar4) class UChar4Model(models.StructModel): def __init__(self, dmm, fe_type): members = [ - ('x', types.uchar), - ('y', types.uchar), - ('z', types.uchar), - ('w', types.uchar), + ("x", types.uchar), + ("y", types.uchar), + ("z", types.uchar), + ("w", types.uchar), ] super().__init__(dmm, fe_type, members) -make_attribute_wrapper(UChar4, 'x', 'x') -make_attribute_wrapper(UChar4, 'y', 'y') -make_attribute_wrapper(UChar4, 'z', 'z') -make_attribute_wrapper(UChar4, 'w', 'w') +make_attribute_wrapper(UChar4, "x", "x") +make_attribute_wrapper(UChar4, "y", "y") +make_attribute_wrapper(UChar4, "z", "z") +make_attribute_wrapper(UChar4, "w", "w") # UChar4 lowering + @lower(make_uchar4, types.uchar, types.uchar, types.uchar, types.uchar) def lower_make_uchar4(context, builder, sig, args): uc4 = cgutils.create_struct_proxy(uchar4)(context, builder) @@ -104,6 +114,7 @@ def lower_make_uchar4(context, builder, sig, args): # Float3 typing + class Float3(types.Type): def __init__(self): super().__init__(name="Float3") @@ -114,6 +125,7 @@ def __init__(self): # Float2 typing (forward declaration) + class Float2(types.Type): def __init__(self): super().__init__(name="Float2") @@ -124,20 +136,21 @@ def __init__(self): # Float3 data model + @register_model(Float3) class Float3Model(models.StructModel): def __init__(self, dmm, fe_type): members = [ - ('x', types.float32), - ('y', types.float32), - ('z', types.float32), + ("x", types.float32), + ("y", types.float32), + ("z", types.float32), ] super().__init__(dmm, fe_type, members) -make_attribute_wrapper(Float3, 'x', 'x') -make_attribute_wrapper(Float3, 'y', 'y') -make_attribute_wrapper(Float3, 'z', 'z') +make_attribute_wrapper(Float3, "x", "x") +make_attribute_wrapper(Float3, "y", "y") +make_attribute_wrapper(Float3, "z", "z") def lower_float3_ops(op): @@ -146,17 +159,21 @@ class Float3_op_template(ConcreteTemplate): cases = [ signature(float3, float3, float3), signature(float3, types.float32, float3), - signature(float3, float3, types.float32) + signature(float3, float3, types.float32), ] def float3_op_impl(context, builder, sig, args): def op_attr(lhs, rhs, res, attr): - setattr(res, attr, context.compile_internal( - builder, - lambda x, y: op(x, y), - signature(types.float32, types.float32, types.float32), - (getattr(lhs, attr), getattr(rhs, attr)) - )) + setattr( + res, + attr, + context.compile_internal( + builder, + lambda x, y: op(x, y), + signature(types.float32, types.float32, types.float32), + (getattr(lhs, attr), getattr(rhs, attr)), + ), + ) arg0, arg1 = args @@ -166,8 +183,7 @@ def op_attr(lhs, rhs, res, attr): lf3.y = arg0 lf3.z = arg0 else: - lf3 = cgutils.create_struct_proxy(float3)(context, builder, - value=args[0]) + lf3 = cgutils.create_struct_proxy(float3)(context, builder, value=args[0]) if isinstance(sig.args[1], types.Float): rf3 = cgutils.create_struct_proxy(float3)(context, builder) @@ -175,13 +191,12 @@ def op_attr(lhs, rhs, res, attr): rf3.y = arg1 rf3.z = arg1 else: - rf3 = cgutils.create_struct_proxy(float3)(context, builder, - value=args[1]) + rf3 = cgutils.create_struct_proxy(float3)(context, builder, value=args[1]) res = cgutils.create_struct_proxy(float3)(context, builder) - op_attr(lf3, rf3, res, 'x') - op_attr(lf3, rf3, res, 'y') - op_attr(lf3, rf3, res, 'z') + op_attr(lf3, rf3, res, "x") + op_attr(lf3, rf3, res, "y") + op_attr(lf3, rf3, res, "z") return res._getvalue() register_global(op, types.Function(Float3_op_template)) @@ -204,6 +219,7 @@ def add_float32_float3_impl(context, builder, sig, args): res.z = builder.fadd(s, rhs.z) return res._getvalue() + @lower(add, float3, float32) def add_float3_float32_impl(context, builder, sig, args): lhs = cgutils.create_struct_proxy(float3)(context, builder, args[0]) @@ -214,8 +230,10 @@ def add_float3_float32_impl(context, builder, sig, args): res.z = builder.fadd(lhs.z, s) return res._getvalue() + # Prototype a function to construct a float3 + def make_float3(x, y, z): pass @@ -225,7 +243,7 @@ class MakeFloat3(ConcreteTemplate): key = make_float3 cases = [ signature(float3, types.float32, types.float32, types.float32), - signature(float3, float2, types.float32) + signature(float3, float2, types.float32), ] @@ -234,6 +252,7 @@ class MakeFloat3(ConcreteTemplate): # make_float3 lowering + @lower(make_float3, types.float32, types.float32, types.float32) def lower_make_float3(context, builder, sig, args): f3 = cgutils.create_struct_proxy(float3)(context, builder) @@ -257,21 +276,21 @@ def lower_make_float3(context, builder, sig, args): # ------ - # Float2 data model + @register_model(Float2) class Float2Model(models.StructModel): def __init__(self, dmm, fe_type): members = [ - ('x', types.float32), - ('y', types.float32), + ("x", types.float32), + ("y", types.float32), ] super().__init__(dmm, fe_type, members) -make_attribute_wrapper(Float2, 'x', 'x') -make_attribute_wrapper(Float2, 'y', 'y') +make_attribute_wrapper(Float2, "x", "x") +make_attribute_wrapper(Float2, "y", "y") def lower_float2_ops(op): @@ -280,17 +299,21 @@ class Float2_op_template(ConcreteTemplate): cases = [ signature(float2, float2, float2), signature(float2, types.float32, float2), - signature(float2, float2, types.float32) + signature(float2, float2, types.float32), ] def float2_op_impl(context, builder, sig, args): def op_attr(lhs, rhs, res, attr): - setattr(res, attr, context.compile_internal( - builder, - lambda x, y: op(x, y), - signature(types.float32, types.float32, types.float32), - (getattr(lhs, attr), getattr(rhs, attr)) - )) + setattr( + res, + attr, + context.compile_internal( + builder, + lambda x, y: op(x, y), + signature(types.float32, types.float32, types.float32), + (getattr(lhs, attr), getattr(rhs, attr)), + ), + ) arg0, arg1 = args @@ -299,20 +322,18 @@ def op_attr(lhs, rhs, res, attr): lf2.x = arg0 lf2.y = arg0 else: - lf2 = cgutils.create_struct_proxy(float2)(context, builder, - value=args[0]) + lf2 = cgutils.create_struct_proxy(float2)(context, builder, value=args[0]) if isinstance(sig.args[1], types.Float): rf2 = cgutils.create_struct_proxy(float2)(context, builder) rf2.x = arg1 rf2.y = arg1 else: - rf2 = cgutils.create_struct_proxy(float2)(context, builder, - value=args[1]) + rf2 = cgutils.create_struct_proxy(float2)(context, builder, value=args[1]) res = cgutils.create_struct_proxy(float2)(context, builder) - op_attr(lf2, rf2, res, 'x') - op_attr(lf2, rf2, res, 'y') + op_attr(lf2, rf2, res, "x") + op_attr(lf2, rf2, res, "y") return res._getvalue() register_global(op, types.Function(Float2_op_template)) @@ -327,6 +348,7 @@ def op_attr(lhs, rhs, res, attr): # Prototype a function to construct a float2 + def make_float2(x, y): pass @@ -342,6 +364,7 @@ class MakeFloat2(ConcreteTemplate): # make_float2 lowering + @lower(make_float2, types.float32, types.float32) def lower_make_float2(context, builder, sig, args): f2 = cgutils.create_struct_proxy(float2)(context, builder) @@ -353,6 +376,7 @@ def lower_make_float2(context, builder, sig, args): # uint3 # ------ + class UInt3(types.Type): def __init__(self): super().__init__(name="UInt3") @@ -363,24 +387,26 @@ def __init__(self): # UInt3 data model + @register_model(UInt3) class UInt3Model(models.StructModel): def __init__(self, dmm, fe_type): members = [ - ('x', types.uint32), - ('y', types.uint32), - ('z', types.uint32), + ("x", types.uint32), + ("y", types.uint32), + ("z", types.uint32), ] super().__init__(dmm, fe_type, members) -make_attribute_wrapper(UInt3, 'x', 'x') -make_attribute_wrapper(UInt3, 'y', 'y') -make_attribute_wrapper(UInt3, 'z', 'z') +make_attribute_wrapper(UInt3, "x", "x") +make_attribute_wrapper(UInt3, "y", "y") +make_attribute_wrapper(UInt3, "z", "z") # Prototype a function to construct a uint3 + def make_uint3(x, y, z): pass @@ -396,6 +422,7 @@ class MakeUInt3(ConcreteTemplate): # make_uint3 lowering + @lower(make_uint3, types.uint32, types.uint32, types.uint32) def lower_make_uint3(context, builder, sig, args): # u4 = uint32 @@ -417,28 +444,29 @@ def __init__(self): # UInt3 data model + @register_model(PayloadPack) class PayloadPackModel(models.StructModel): def __init__(self, dmm, fe_type): members = [ - ('p0', types.uint32), - ('p1', types.uint32), - ('p2', types.uint32), + ("p0", types.uint32), + ("p1", types.uint32), + ("p2", types.uint32), ] super().__init__(dmm, fe_type, members) -make_attribute_wrapper(PayloadPack, 'p0', 'p0') -make_attribute_wrapper(PayloadPack, 'p1', 'p1') -make_attribute_wrapper(PayloadPack, 'p2', 'p2') +make_attribute_wrapper(PayloadPack, "p0", "p0") +make_attribute_wrapper(PayloadPack, "p1", "p1") +make_attribute_wrapper(PayloadPack, "p2", "p2") # OptiX typedefs and enums # ----------- -OptixVisibilityMask = types.Integer('OptixVisibilityMask', bitwidth=32, - signed=False) -OptixTraversableHandle = types.Integer('OptixTraversableHandle', bitwidth=64, - signed=False) +OptixVisibilityMask = types.Integer("OptixVisibilityMask", bitwidth=32, signed=False) +OptixTraversableHandle = types.Integer( + "OptixTraversableHandle", bitwidth=64, signed=False +) OPTIX_RAY_FLAG_NONE = 0 @@ -459,6 +487,7 @@ def __init__(self, dmm, fe_type): # Typing for OptiX types + class SbtDataPointer(types.RawPointer): def __init__(self): super().__init__(name="SbtDataPointer") @@ -469,6 +498,7 @@ def __init__(self): # Models for OptiX types + @register_model(SbtDataPointer) class SbtDataPointerModel(models.OpaqueModel): pass @@ -479,25 +509,27 @@ class SbtDataPointerModel(models.OpaqueModel): # Structures as declared in triangle.h + class ParamsStruct: fields = ( - ('image', 'uchar4*'), - ('image_width', 'unsigned int'), - ('image_height', 'unsigned int'), - ('cam_eye', 'float3'), - ('cam_u', 'float3'), - ('cam_v', 'float3'), - ('cam_w', 'float3'), - ('handle', 'OptixTraversableHandle'), + ("image", "uchar4*"), + ("image_width", "unsigned int"), + ("image_height", "unsigned int"), + ("cam_eye", "float3"), + ("cam_u", "float3"), + ("cam_v", "float3"), + ("cam_w", "float3"), + ("handle", "OptixTraversableHandle"), ) # "Declare" a global called params params = ParamsStruct() + class Params(types.Type): def __init__(self): - super().__init__(name='ParamsType') + super().__init__(name="ParamsType") params_type = Params() @@ -505,30 +537,31 @@ def __init__(self): # ParamsStruct data model + @register_model(Params) class ParamsModel(models.StructModel): def __init__(self, dmm, fe_type): members = [ - ('image', types.CPointer(uchar4)), - ('image_width', types.uint32), - ('image_height', types.uint32), - ('cam_eye', float3), - ('cam_u', float3), - ('cam_v', float3), - ('cam_w', float3), - ('handle', OptixTraversableHandle), + ("image", types.CPointer(uchar4)), + ("image_width", types.uint32), + ("image_height", types.uint32), + ("cam_eye", float3), + ("cam_u", float3), + ("cam_v", float3), + ("cam_w", float3), + ("handle", OptixTraversableHandle), ] super().__init__(dmm, fe_type, members) -make_attribute_wrapper(Params, 'image', 'image') -make_attribute_wrapper(Params, 'image_width', 'image_width') -make_attribute_wrapper(Params, 'image_height', 'image_height') -make_attribute_wrapper(Params, 'cam_eye', 'cam_eye') -make_attribute_wrapper(Params, 'cam_u', 'cam_u') -make_attribute_wrapper(Params, 'cam_v', 'cam_v') -make_attribute_wrapper(Params, 'cam_w', 'cam_w') -make_attribute_wrapper(Params, 'handle', 'handle') +make_attribute_wrapper(Params, "image", "image") +make_attribute_wrapper(Params, "image_width", "image_width") +make_attribute_wrapper(Params, "image_height", "image_height") +make_attribute_wrapper(Params, "cam_eye", "cam_eye") +make_attribute_wrapper(Params, "cam_u", "cam_u") +make_attribute_wrapper(Params, "cam_v", "cam_v") +make_attribute_wrapper(Params, "cam_w", "cam_w") +make_attribute_wrapper(Params, "handle", "handle") @typeof_impl.register(ParamsStruct) @@ -540,15 +573,17 @@ def typeof_params(val, c): # The below makes 'param' a global variable, accessible from any user defined # kernels. + @lower_constant(Params) def constant_params(context, builder, ty, pyval): try: - gvar = builder.module.get_global('params') + gvar = builder.module.get_global("params") except KeyError: llty = context.get_value_type(ty) - gvar = cgutils.add_global_variable(builder.module, llty, 'params', - addrspace=nvvm.ADDRSPACE_CONSTANT) - gvar.linkage = 'external' + gvar = cgutils.add_global_variable( + builder.module, llty, "params", addrspace=nvvm.ADDRSPACE_CONSTANT + ) + gvar.linkage = "external" gvar.global_constant = True return builder.load(gvar) @@ -559,47 +594,53 @@ def constant_params(context, builder, ty, pyval): # Structures as declared in triangle.h class MissDataStruct: - fields = ( - ('bg_color', 'float3') - ) + fields = ("bg_color", "float3") + MissData = MissDataStruct() + class MissData(types.Type): def __init__(self): - super().__init__(name='MissDataType') + super().__init__(name="MissDataType") + miss_data_type = MissData() + @register_model(MissData) class MissDataModel(models.StructModel): def __init__(self, dmm, fe_type): members = [ - ('bg_color', float3), + ("bg_color", float3), ] super().__init__(dmm, fe_type, members) -make_attribute_wrapper(MissData, 'bg_color', 'bg_color') + +make_attribute_wrapper(MissData, "bg_color", "bg_color") + @typeof_impl.register(MissDataStruct) def typeof_miss_data(val, c): return miss_data_type + # MissData Constructor @type_callable(MissDataStruct) def type_miss_data_struct(context): def typer(sbt_data_pointer): if isinstance(sbt_data_pointer, SbtDataPointer): return miss_data_type + return typer + @lower(MissDataStruct, sbt_data_pointer) def lower_miss_data_ctor(context, builder, sig, args): # Anyway to err if this ctor is not called inside __miss__* program? # TODO: Optimize ptr = args[0] - ptr = builder.bitcast(ptr, - context.get_value_type(miss_data_type).as_pointer()) + ptr = builder.bitcast(ptr, context.get_value_type(miss_data_type).as_pointer()) bg_color_ptr = cgutils.gep_inbounds(builder, ptr, 0, 0) @@ -608,7 +649,9 @@ def lower_miss_data_ctor(context, builder, sig, args): zptr = cgutils.gep_inbounds(builder, bg_color_ptr, 0, 2) output_miss_data = cgutils.create_struct_proxy(miss_data_type)(context, builder) - output_bg_color_ptr = cgutils.gep_inbounds(builder, output_miss_data._getpointer(), 0, 0) + output_bg_color_ptr = cgutils.gep_inbounds( + builder, output_miss_data._getpointer(), 0, 0 + ) output_bg_color_x_ptr = cgutils.gep_inbounds(builder, output_bg_color_ptr, 0, 0) output_bg_color_y_ptr = cgutils.gep_inbounds(builder, output_bg_color_ptr, 0, 1) output_bg_color_z_ptr = cgutils.gep_inbounds(builder, output_bg_color_ptr, 0, 2) @@ -620,7 +663,6 @@ def lower_miss_data_ctor(context, builder, sig, args): builder.store(x, output_bg_color_x_ptr) builder.store(y, output_bg_color_y_ptr) builder.store(z, output_bg_color_z_ptr) - # Doesn't seem to do what's expected? # miss_data.bg_color.x = builder.load(xptr) @@ -635,6 +677,7 @@ def lower_miss_data_ctor(context, builder, sig, args): # Here we "prototype" the OptiX functions that the user will call in their # kernels, so that Numba has something to refer to when compiling the kernel. + def _optix_GetLaunchIndex(): pass @@ -646,18 +689,23 @@ def _optix_GetLaunchDimensions(): def _optix_GetSbtDataPointer(): pass + def _optix_SetPayload_0(): pass + def _optix_SetPayload_1(): pass + def _optix_SetPayload_2(): pass + def _optix_GetTriangleBarycentrics(): pass + def _optix_Trace(): pass @@ -679,6 +727,7 @@ def _optix_Trace(): # OptiX function typing + @register class OptixGetLaunchIndex(ConcreteTemplate): key = optix.GetLaunchIndex @@ -696,39 +745,46 @@ class OptixGetSbtDataPointer(ConcreteTemplate): key = optix.GetSbtDataPointer cases = [signature(sbt_data_pointer)] + def registerSetPayload(reg): class OptixSetPayloadReg(ConcreteTemplate): - key = getattr(optix, 'SetPayload_' + str(reg)) + key = getattr(optix, "SetPayload_" + str(reg)) cases = [signature(types.void, uint32)] + register(OptixSetPayloadReg) return OptixSetPayloadReg + OptixSetPayload_0 = registerSetPayload(0) OptixSetPayload_1 = registerSetPayload(1) OptixSetPayload_2 = registerSetPayload(2) + @register class OptixGetTriangleBarycentrics(ConcreteTemplate): key = optix.GetTriangleBarycentrics cases = [signature(float2)] + @register class OptixTrace(ConcreteTemplate): key = optix.Trace - cases = [signature( - payload_pack, - OptixTraversableHandle, - float3, - float3, - float32, - float32, - float32, - OptixVisibilityMask, - uint32, - uint32, - uint32, - uint32, - )] + cases = [ + signature( + payload_pack, + OptixTraversableHandle, + float3, + float3, + float32, + float32, + float32, + OptixVisibilityMask, + uint32, + uint32, + uint32, + uint32, + ) + ] @register_attr @@ -743,60 +799,67 @@ def resolve_GetLaunchDimensions(self, mod): def resolve_GetSbtDataPointer(self, mod): return types.Function(OptixGetSbtDataPointer) - + def resolve_SetPayload_0(self, mod): return types.Function(OptixSetPayload_0) - + def resolve_SetPayload_1(self, mod): return types.Function(OptixSetPayload_1) - + def resolve_SetPayload_2(self, mod): return types.Function(OptixSetPayload_2) - + def resolve_GetTriangleBarycentrics(self, mod): return types.Function(OptixGetTriangleBarycentrics) - + def resolve_Trace(self, mod): return types.Function(OptixTrace) # OptiX function lowering + @lower(optix.GetLaunchIndex) def lower_optix_getLaunchIndex(context, builder, sig, args): def get_launch_index(axis): - asm = ir.InlineAsm(ir.FunctionType(ir.IntType(32), []), - f"call ($0), _optix_get_launch_index_{axis}, ();", - "=r") + asm = ir.InlineAsm( + ir.FunctionType(ir.IntType(32), []), + f"call ($0), _optix_get_launch_index_{axis}, ();", + "=r", + ) return builder.call(asm, []) index = cgutils.create_struct_proxy(dim3)(context, builder) - index.x = get_launch_index('x') - index.y = get_launch_index('y') - index.z = get_launch_index('z') + index.x = get_launch_index("x") + index.y = get_launch_index("y") + index.z = get_launch_index("z") return index._getvalue() @lower(optix.GetLaunchDimensions) def lower_optix_getLaunchDimensions(context, builder, sig, args): def get_launch_dimensions(axis): - asm = ir.InlineAsm(ir.FunctionType(ir.IntType(32), []), - f"call ($0), _optix_get_launch_dimension_{axis}, ();", - "=r") + asm = ir.InlineAsm( + ir.FunctionType(ir.IntType(32), []), + f"call ($0), _optix_get_launch_dimension_{axis}, ();", + "=r", + ) return builder.call(asm, []) index = cgutils.create_struct_proxy(dim3)(context, builder) - index.x = get_launch_dimensions('x') - index.y = get_launch_dimensions('y') - index.z = get_launch_dimensions('z') + index.x = get_launch_dimensions("x") + index.y = get_launch_dimensions("y") + index.z = get_launch_dimensions("z") return index._getvalue() @lower(optix.GetSbtDataPointer) def lower_optix_getSbtDataPointer(context, builder, sig, args): - asm = ir.InlineAsm(ir.FunctionType(ir.IntType(64), []), - "call ($0), _optix_get_sbt_data_ptr_64, ();", - "=l") + asm = ir.InlineAsm( + ir.FunctionType(ir.IntType(64), []), + "call ($0), _optix_get_sbt_data_ptr_64, ();", + "=l", + ) ptr = builder.call(asm, []) ptr = builder.inttoptr(ptr, ir.IntType(8).as_pointer()) return ptr @@ -804,24 +867,29 @@ def lower_optix_getSbtDataPointer(context, builder, sig, args): def lower_optix_SetPayloadReg(reg): def lower_optix_SetPayload_impl(context, builder, sig, args): - asm = ir.InlineAsm(ir.FunctionType(ir.VoidType(), [ir.IntType(32), ir.IntType(32)]), + asm = ir.InlineAsm( + ir.FunctionType(ir.VoidType(), [ir.IntType(32), ir.IntType(32)]), f"call _optix_set_payload, ($0, $1);", - "r,r") + "r,r", + ) builder.call(asm, [context.get_constant(types.int32, reg), args[0]]) + lower(getattr(optix, f"SetPayload_{reg}"), uint32)(lower_optix_SetPayload_impl) + lower_optix_SetPayloadReg(0) lower_optix_SetPayloadReg(1) lower_optix_SetPayloadReg(2) + @lower(optix.GetTriangleBarycentrics) def lower_optix_getTriangleBarycentrics(context, builder, sig, args): f2 = cgutils.create_struct_proxy(float2)(context, builder) retty = ir.LiteralStructType([ir.FloatType(), ir.FloatType()]) asm = ir.InlineAsm( - ir.FunctionType(retty, []), + ir.FunctionType(retty, []), "call ($0, $1), _optix_get_triangle_barycentrics, ();", - "=f,=f" + "=f,=f", ) ret = builder.call(asm, []) f2.x = builder.extract_value(ret, 0) @@ -829,25 +897,37 @@ def lower_optix_getTriangleBarycentrics(context, builder, sig, args): return f2._getvalue() -@lower(optix.Trace, - OptixTraversableHandle, - float3, - float3, - float32, - float32, - float32, - OptixVisibilityMask, - uint32, - uint32, - uint32, - uint32, +@lower( + optix.Trace, + OptixTraversableHandle, + float3, + float3, + float32, + float32, + float32, + OptixVisibilityMask, + uint32, + uint32, + uint32, + uint32, ) def lower_optix_Trace(context, builder, sig, args): # Only implements the version that accepts 3 payload registers # TODO: Optimize returns - (handle, rayOrigin, rayDirection, tmin, tmax, rayTime, visibilityMask, - rayFlags, SBToffset, SBTstride, missSBTIndex) = args + ( + handle, + rayOrigin, + rayDirection, + tmin, + tmax, + rayTime, + visibilityMask, + rayFlags, + SBToffset, + SBTstride, + missSBTIndex, + ) = args rayOrigin = cgutils.create_struct_proxy(float3)(context, builder, rayOrigin) rayDirection = cgutils.create_struct_proxy(float3)(context, builder, rayDirection) @@ -858,28 +938,45 @@ def lower_optix_Trace(context, builder, sig, args): n_payload_registers = 3 n_stub_output_operands = 32 - n_payload_registers - outputs = ([output.p0, output.p1, output.p2] + - [builder.load(builder.alloca(ir.IntType(32))) - for _ in range(n_stub_output_operands)]) - + outputs = [output.p0, output.p1, output.p2] + [ + builder.load(builder.alloca(ir.IntType(32))) + for _ in range(n_stub_output_operands) + ] retty = ir.LiteralStructType([ir.IntType(32)] * 32) - asm = ir.InlineAsm(ir.FunctionType(retty, []), - "call " - "($0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29," - "$30,$31)," - "_optix_trace_typed_32," - "($32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59," - "$60,$61,$62,$63,$64,$65,$66,$67,$68,$69,$70,$71,$72,$73,$74,$75,$76,$77,$78,$79,$80);", - "=r," * 32 + "r,l,f,f,f,f,f,f,f,f,f,r,r,r,r,r,r," + "r," * 31 + "r", - side_effect=True + asm = ir.InlineAsm( + ir.FunctionType(retty, []), + "call " + "($0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29," + "$30,$31)," + "_optix_trace_typed_32," + "($32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59," + "$60,$61,$62,$63,$64,$65,$66,$67,$68,$69,$70,$71,$72,$73,$74,$75,$76,$77,$78,$79,$80);", + "=r," * 32 + "r,l,f,f,f,f,f,f,f,f,f,r,r,r,r,r,r," + "r," * 31 + "r", + side_effect=True, ) zero = context.get_constant(types.int32, 0) c_payload_registers = context.get_constant(types.int32, n_payload_registers) - args = [zero, handle, ox, oy, oz, dx, dy, dz, tmin, tmax, rayTime, - visibilityMask, rayFlags, SBToffset, SBTstride, - missSBTIndex, c_payload_registers] + outputs + args = [ + zero, + handle, + ox, + oy, + oz, + dx, + dy, + dz, + tmin, + tmax, + rayTime, + visibilityMask, + rayFlags, + SBToffset, + SBTstride, + missSBTIndex, + c_payload_registers, + ] + outputs ret = builder.call(asm, args) output.p0 = builder.extract_value(ret, 0) output.p1 = builder.extract_value(ret, 1) @@ -887,7 +984,7 @@ def lower_optix_Trace(context, builder, sig, args): return output._getvalue() -#------------------------------------------------------------------------------- +# ------------------------------------------------------------------------------- # # Util # @@ -916,43 +1013,42 @@ def round_up(val, mult_of): def get_aligned_itemsize(formats, alignment): names = [] for i in range(len(formats)): - names.append( 'x'+str(i) ) + names.append("x" + str(i)) - temp_dtype = np.dtype( { - 'names' : names, - 'formats' : formats, - 'align' : True - } ) - return round_up( temp_dtype.itemsize, alignment ) + temp_dtype = np.dtype({"names": names, "formats": formats, "align": True}) + return round_up(temp_dtype.itemsize, alignment) -def array_to_device_memory( numpy_array, stream=cp.cuda.Stream() ): +def array_to_device_memory(numpy_array, stream=cp.cuda.Stream()): - byte_size = numpy_array.size*numpy_array.dtype.itemsize + byte_size = numpy_array.size * numpy_array.dtype.itemsize - h_ptr = ctypes.c_void_p( numpy_array.ctypes.data ) - d_mem = cp.cuda.memory.alloc( byte_size ) - d_mem.copy_from_async( h_ptr, byte_size, stream ) + h_ptr = ctypes.c_void_p(numpy_array.ctypes.data) + d_mem = cp.cuda.memory.alloc(byte_size) + d_mem.copy_from_async(h_ptr, byte_size, stream) return d_mem -def compile_cuda( cuda_file ): - with open( cuda_file, 'rb' ) as f: +def compile_cuda(cuda_file): + with open(cuda_file, "rb") as f: src = f.read() from pynvrtc.compiler import Program - prog = Program( src.decode(), cuda_file ) - ptx = prog.compile( [ - '-use_fast_math', - '-lineinfo', - '-default-device', - '-std=c++11', - '-rdc', - 'true', - #'-IC:\\ProgramData\\NVIDIA Corporation\OptiX SDK 7.2.0\include', - #'-IC:\\Program Files\\NVIDIA GPU Computing Toolkit\CUDA\\v11.1\include' - '-I/usr/local/cuda/include', - f'-I{optix.include_path}' - ] ) + + prog = Program(src.decode(), cuda_file) + ptx = prog.compile( + [ + "-use_fast_math", + "-lineinfo", + "-default-device", + "-std=c++11", + "-rdc", + "true", + #'-IC:\\ProgramData\\NVIDIA Corporation\OptiX SDK 7.2.0\include', + #'-IC:\\Program Files\\NVIDIA GPU Computing Toolkit\CUDA\\v11.1\include' + "-I/usr/local/cuda/include", + f"-I{optix.include_path}", + ] + ) return ptx @@ -962,299 +1058,279 @@ def compile_cuda( cuda_file ): # # ------------------------------------------------------------------------------- + def init_optix(): - print( "Initializing cuda ..." ) - cp.cuda.runtime.free( 0 ) + print("Initializing cuda ...") + cp.cuda.runtime.free(0) - print( "Initializing optix ..." ) + print("Initializing optix ...") optix.init() def create_ctx(): - print( "Creating optix device context ..." ) + print("Creating optix device context ...") # Note that log callback data is no longer needed. We can # instead send a callable class instance as the log-function # which stores any data needed global logger logger = Logger() - + # OptiX param struct fields can be set with optional # keyword constructor arguments. - ctx_options = optix.DeviceContextOptions( - logCallbackFunction = logger, - logCallbackLevel = 4 - ) + ctx_options = optix.DeviceContextOptions( + logCallbackFunction=logger, logCallbackLevel=4 + ) # They can also be set and queried as properties on the struct - ctx_options.validationMode = optix.DEVICE_CONTEXT_VALIDATION_MODE_ALL + ctx_options.validationMode = optix.DEVICE_CONTEXT_VALIDATION_MODE_ALL - cu_ctx = 0 - return optix.deviceContextCreate( cu_ctx, ctx_options ) + cu_ctx = 0 + return optix.deviceContextCreate(cu_ctx, ctx_options) -def create_accel( ctx ): - +def create_accel(ctx): + accel_options = optix.AccelBuildOptions( - buildFlags = int( optix.BUILD_FLAG_ALLOW_RANDOM_VERTEX_ACCESS), - operation = optix.BUILD_OPERATION_BUILD - ) + buildFlags=int(optix.BUILD_FLAG_ALLOW_RANDOM_VERTEX_ACCESS), + operation=optix.BUILD_OPERATION_BUILD, + ) global vertices - vertices = cp.array( [ - -0.5, -0.5, 0.0, - 0.5, -0.5, 0.0, - 0.0, 0.5, 0.0 - ], dtype = 'f4') - - triangle_input_flags = [ optix.GEOMETRY_FLAG_NONE ] + vertices = cp.array([-0.5, -0.5, 0.0, 0.5, -0.5, 0.0, 0.0, 0.5, 0.0], dtype="f4") + + triangle_input_flags = [optix.GEOMETRY_FLAG_NONE] triangle_input = optix.BuildInputTriangleArray() - triangle_input.vertexFormat = optix.VERTEX_FORMAT_FLOAT3 - triangle_input.numVertices = len( vertices ) - triangle_input.vertexBuffers = [ vertices.data.ptr ] - triangle_input.flags = triangle_input_flags - triangle_input.numSbtRecords = 1; - - gas_buffer_sizes = ctx.accelComputeMemoryUsage( [accel_options], [triangle_input] ) - - d_temp_buffer_gas = cp.cuda.alloc( gas_buffer_sizes.tempSizeInBytes ) - d_gas_output_buffer = cp.cuda.alloc( gas_buffer_sizes.outputSizeInBytes) - - gas_handle = ctx.accelBuild( - 0, # CUDA stream - [ accel_options ], - [ triangle_input ], + triangle_input.vertexFormat = optix.VERTEX_FORMAT_FLOAT3 + triangle_input.numVertices = len(vertices) + triangle_input.vertexBuffers = [vertices.data.ptr] + triangle_input.flags = triangle_input_flags + triangle_input.numSbtRecords = 1 + + gas_buffer_sizes = ctx.accelComputeMemoryUsage([accel_options], [triangle_input]) + + d_temp_buffer_gas = cp.cuda.alloc(gas_buffer_sizes.tempSizeInBytes) + d_gas_output_buffer = cp.cuda.alloc(gas_buffer_sizes.outputSizeInBytes) + + gas_handle = ctx.accelBuild( + 0, # CUDA stream + [accel_options], + [triangle_input], d_temp_buffer_gas.ptr, gas_buffer_sizes.tempSizeInBytes, d_gas_output_buffer.ptr, gas_buffer_sizes.outputSizeInBytes, - [] # emitted properties - ) + [], # emitted properties + ) return (gas_handle, d_gas_output_buffer) def set_pipeline_options(): return optix.PipelineCompileOptions( - usesMotionBlur = False, - traversableGraphFlags = - int( optix.TRAVERSABLE_GRAPH_FLAG_ALLOW_SINGLE_GAS ), - numPayloadValues = 3, - numAttributeValues = 3, - exceptionFlags = int( optix.EXCEPTION_FLAG_NONE ), - pipelineLaunchParamsVariableName = "params", - usesPrimitiveTypeFlags = optix.PRIMITIVE_TYPE_FLAGS_TRIANGLE - ) + usesMotionBlur=False, + traversableGraphFlags=int(optix.TRAVERSABLE_GRAPH_FLAG_ALLOW_SINGLE_GAS), + numPayloadValues=3, + numAttributeValues=3, + exceptionFlags=int(optix.EXCEPTION_FLAG_NONE), + pipelineLaunchParamsVariableName="params", + usesPrimitiveTypeFlags=optix.PRIMITIVE_TYPE_FLAGS_TRIANGLE, + ) -def create_module( ctx, pipeline_options, ptx ): - print( "Creating optix module ..." ) - +def create_module(ctx, pipeline_options, ptx): + print("Creating optix module ...") module_options = optix.ModuleCompileOptions( - maxRegisterCount = optix.COMPILE_DEFAULT_MAX_REGISTER_COUNT, - optLevel = optix.COMPILE_OPTIMIZATION_DEFAULT, - debugLevel = optix.COMPILE_DEBUG_LEVEL_LINEINFO + maxRegisterCount=optix.COMPILE_DEFAULT_MAX_REGISTER_COUNT, + optLevel=optix.COMPILE_OPTIMIZATION_DEFAULT, + debugLevel=optix.COMPILE_DEBUG_LEVEL_LINEINFO, ) - module, log = ctx.moduleCreateFromPTX( - module_options, - pipeline_options, - ptx - ) - print( "\tModule create log: <<<{}>>>".format( log ) ) + module, log = ctx.moduleCreateFromPTX(module_options, pipeline_options, ptx) + print("\tModule create log: <<<{}>>>".format(log)) return module -def create_program_groups( ctx, raygen_module, miss_prog_module, hitgroup_module ): - print( "Creating program groups ... " ) +def create_program_groups(ctx, raygen_module, miss_prog_module, hitgroup_module): + print("Creating program groups ... ") program_group_options = optix.ProgramGroupOptions() - raygen_prog_group_desc = optix.ProgramGroupDesc() - raygen_prog_group_desc.raygenModule = raygen_module - raygen_prog_group_desc.raygenEntryFunctionName = "__raygen__rg" + raygen_prog_group_desc = optix.ProgramGroupDesc() + raygen_prog_group_desc.raygenModule = raygen_module + raygen_prog_group_desc.raygenEntryFunctionName = "__raygen__rg" - miss_prog_group_desc = optix.ProgramGroupDesc() - miss_prog_group_desc.missModule = miss_prog_module - miss_prog_group_desc.missEntryFunctionName = "__miss__ms" + miss_prog_group_desc = optix.ProgramGroupDesc() + miss_prog_group_desc.missModule = miss_prog_module + miss_prog_group_desc.missEntryFunctionName = "__miss__ms" - hitgroup_prog_group_desc = optix.ProgramGroupDesc() - hitgroup_prog_group_desc.hitgroupModuleCH = hitgroup_module + hitgroup_prog_group_desc = optix.ProgramGroupDesc() + hitgroup_prog_group_desc.hitgroupModuleCH = hitgroup_module hitgroup_prog_group_desc.hitgroupEntryFunctionNameCH = "__closesthit__ch" prog_group, log = ctx.programGroupCreate( - [ raygen_prog_group_desc, miss_prog_group_desc, hitgroup_prog_group_desc ], - program_group_options, - ) - print( "\tProgramGroup create log: <<<{}>>>".format( log ) ) + [raygen_prog_group_desc, miss_prog_group_desc, hitgroup_prog_group_desc], + program_group_options, + ) + print("\tProgramGroup create log: <<<{}>>>".format(log)) return prog_group -def create_pipeline( ctx, program_groups, pipeline_compile_options ): - print( "Creating pipeline ... " ) +def create_pipeline(ctx, program_groups, pipeline_compile_options): + print("Creating pipeline ... ") - max_trace_depth = 1 - pipeline_link_options = optix.PipelineLinkOptions() + max_trace_depth = 1 + pipeline_link_options = optix.PipelineLinkOptions() pipeline_link_options.maxTraceDepth = max_trace_depth - pipeline_link_options.debugLevel = optix.COMPILE_DEBUG_LEVEL_FULL + pipeline_link_options.debugLevel = optix.COMPILE_DEBUG_LEVEL_FULL log = "" pipeline = ctx.pipelineCreate( - pipeline_compile_options, - pipeline_link_options, - program_groups, - log) + pipeline_compile_options, pipeline_link_options, program_groups, log + ) stack_sizes = optix.StackSizes() for prog_group in program_groups: - optix.util.accumulateStackSizes( prog_group, stack_sizes ) - - (dc_stack_size_from_trav, dc_stack_size_from_state, cc_stack_size) = \ - optix.util.computeStackSizes( - stack_sizes, - max_trace_depth, - 0, # maxCCDepth - 0 # maxDCDepth - ) - - pipeline.setStackSize( - dc_stack_size_from_trav, - dc_stack_size_from_state, - cc_stack_size, - 1 # maxTraversableDepth - ) + optix.util.accumulateStackSizes(prog_group, stack_sizes) + + ( + dc_stack_size_from_trav, + dc_stack_size_from_state, + cc_stack_size, + ) = optix.util.computeStackSizes( + stack_sizes, max_trace_depth, 0, 0 # maxCCDepth # maxDCDepth + ) + + pipeline.setStackSize( + dc_stack_size_from_trav, + dc_stack_size_from_state, + cc_stack_size, + 1, # maxTraversableDepth + ) return pipeline -def create_sbt( prog_groups ): - print( "Creating sbt ... " ) +def create_sbt(prog_groups): + print("Creating sbt ... ") - (raygen_prog_group, miss_prog_group, hitgroup_prog_group ) = prog_groups + (raygen_prog_group, miss_prog_group, hitgroup_prog_group) = prog_groups global d_raygen_sbt global d_miss_sbt - header_format = '{}B'.format( optix.SBT_RECORD_HEADER_SIZE ) + header_format = "{}B".format(optix.SBT_RECORD_HEADER_SIZE) # # raygen record # - formats = [ header_format ] - itemsize = get_aligned_itemsize( formats, optix.SBT_RECORD_ALIGNMENT ) - dtype = np.dtype( { - 'names' : ['header' ], - 'formats' : formats, - 'itemsize': itemsize, - 'align' : True - } ) - h_raygen_sbt = np.array( [ 0 ], dtype=dtype ) - optix.sbtRecordPackHeader( raygen_prog_group, h_raygen_sbt ) - global d_raygen_sbt - d_raygen_sbt = array_to_device_memory( h_raygen_sbt ) - + formats = [header_format] + itemsize = get_aligned_itemsize(formats, optix.SBT_RECORD_ALIGNMENT) + dtype = np.dtype( + {"names": ["header"], "formats": formats, "itemsize": itemsize, "align": True} + ) + h_raygen_sbt = np.array([0], dtype=dtype) + optix.sbtRecordPackHeader(raygen_prog_group, h_raygen_sbt) + global d_raygen_sbt + d_raygen_sbt = array_to_device_memory(h_raygen_sbt) + # # miss record # - formats = [ header_format, 'f4', 'f4', 'f4'] - itemsize = get_aligned_itemsize( formats, optix.SBT_RECORD_ALIGNMENT ) - dtype = np.dtype( { - 'names' : ['header', 'r', 'g', 'b' ], - 'formats' : formats, - 'itemsize': itemsize, - 'align' : True - } ) - h_miss_sbt = np.array( [ (0, 0.3, 0.1, 0.2) ], dtype=dtype ) - optix.sbtRecordPackHeader( miss_prog_group, h_miss_sbt ) - global d_miss_sbt - d_miss_sbt = array_to_device_memory( h_miss_sbt ) - + formats = [header_format, "f4", "f4", "f4"] + itemsize = get_aligned_itemsize(formats, optix.SBT_RECORD_ALIGNMENT) + dtype = np.dtype( + { + "names": ["header", "r", "g", "b"], + "formats": formats, + "itemsize": itemsize, + "align": True, + } + ) + h_miss_sbt = np.array([(0, 0.3, 0.1, 0.2)], dtype=dtype) + optix.sbtRecordPackHeader(miss_prog_group, h_miss_sbt) + global d_miss_sbt + d_miss_sbt = array_to_device_memory(h_miss_sbt) + # # hitgroup record # - formats = [ header_format ] - itemsize = get_aligned_itemsize( formats, optix.SBT_RECORD_ALIGNMENT ) - dtype = np.dtype( { - 'names' : ['header' ], - 'formats' : formats, - 'itemsize': itemsize, - 'align' : True - } ) - h_hitgroup_sbt = np.array( [ (0) ], dtype=dtype ) - optix.sbtRecordPackHeader( hitgroup_prog_group, h_hitgroup_sbt ) + formats = [header_format] + itemsize = get_aligned_itemsize(formats, optix.SBT_RECORD_ALIGNMENT) + dtype = np.dtype( + {"names": ["header"], "formats": formats, "itemsize": itemsize, "align": True} + ) + h_hitgroup_sbt = np.array([(0)], dtype=dtype) + optix.sbtRecordPackHeader(hitgroup_prog_group, h_hitgroup_sbt) global d_hitgroup_sbt - d_hitgroup_sbt = array_to_device_memory( h_hitgroup_sbt ) - + d_hitgroup_sbt = array_to_device_memory(h_hitgroup_sbt) + sbt = optix.ShaderBindingTable() - sbt.raygenRecord = d_raygen_sbt.ptr - sbt.missRecordBase = d_miss_sbt.ptr - sbt.missRecordStrideInBytes = d_miss_sbt.mem.size - sbt.missRecordCount = 1 - sbt.hitgroupRecordBase = d_hitgroup_sbt.ptr + sbt.raygenRecord = d_raygen_sbt.ptr + sbt.missRecordBase = d_miss_sbt.ptr + sbt.missRecordStrideInBytes = d_miss_sbt.mem.size + sbt.missRecordCount = 1 + sbt.hitgroupRecordBase = d_hitgroup_sbt.ptr sbt.hitgroupRecordStrideInBytes = d_hitgroup_sbt.mem.size - sbt.hitgroupRecordCount = 1 + sbt.hitgroupRecordCount = 1 return sbt -def launch( pipeline, sbt, trav_handle ): - print( "Launching ... " ) +def launch(pipeline, sbt, trav_handle): + print("Launching ... ") - pix_bytes = pix_width*pix_height*4 - - h_pix = np.zeros( (pix_width,pix_height,4), 'B' ) - h_pix[0:pix_width, 0:pix_height] = [255, 128, 0, 255] - d_pix = cp.array( h_pix ) + pix_bytes = pix_width * pix_height * 4 + h_pix = np.zeros((pix_width, pix_height, 4), "B") + h_pix[0:pix_width, 0:pix_height] = [255, 128, 0, 255] + d_pix = cp.array(h_pix) params = [ - ( 'u8', 'image', d_pix.data.ptr ), - ( 'u4', 'image_width', pix_width ), - ( 'u4', 'image_height', pix_height ), - ( 'f4', 'cam_eye_x', 0 ), - ( 'f4', 'cam_eye_y', 0 ), - ( 'f4', 'cam_eye_z', 2.0 ), - ( 'f4', 'cam_U_x', 1.10457 ), - ( 'f4', 'cam_U_y', 0 ), - ( 'f4', 'cam_U_z', 0 ), - ( 'f4', 'cam_V_x', 0 ), - ( 'f4', 'cam_V_y', 0.828427 ), - ( 'f4', 'cam_V_z', 0 ), - ( 'f4', 'cam_W_x', 0 ), - ( 'f4', 'cam_W_y', 0 ), - ( 'f4', 'cam_W_z', -2.0 ), - ( 'u8', 'trav_handle', trav_handle ) + ("u8", "image", d_pix.data.ptr), + ("u4", "image_width", pix_width), + ("u4", "image_height", pix_height), + ("f4", "cam_eye_x", 0), + ("f4", "cam_eye_y", 0), + ("f4", "cam_eye_z", 2.0), + ("f4", "cam_U_x", 1.10457), + ("f4", "cam_U_y", 0), + ("f4", "cam_U_z", 0), + ("f4", "cam_V_x", 0), + ("f4", "cam_V_y", 0.828427), + ("f4", "cam_V_z", 0), + ("f4", "cam_W_x", 0), + ("f4", "cam_W_y", 0), + ("f4", "cam_W_z", -2.0), + ("u8", "trav_handle", trav_handle), ] - - formats = [ x[0] for x in params ] - names = [ x[1] for x in params ] - values = [ x[2] for x in params ] - itemsize = get_aligned_itemsize( formats, 8 ) - params_dtype = np.dtype( { - 'names' : names, - 'formats' : formats, - 'itemsize': itemsize, - 'align' : True - } ) - h_params = np.array( [ tuple(values) ], dtype=params_dtype ) - d_params = array_to_device_memory( h_params ) + + formats = [x[0] for x in params] + names = [x[1] for x in params] + values = [x[2] for x in params] + itemsize = get_aligned_itemsize(formats, 8) + params_dtype = np.dtype( + {"names": names, "formats": formats, "itemsize": itemsize, "align": True} + ) + h_params = np.array([tuple(values)], dtype=params_dtype) + d_params = array_to_device_memory(h_params) stream = cp.cuda.Stream() - optix.launch( - pipeline, - stream.ptr, - d_params.ptr, - h_params.dtype.itemsize, + optix.launch( + pipeline, + stream.ptr, + d_params.ptr, + h_params.dtype.itemsize, sbt, pix_width, pix_height, - 1 # depth + 1, # depth ) stream.synchronize() - h_pix = cp.asnumpy( d_pix ) + h_pix = cp.asnumpy(d_pix) return h_pix @@ -1264,6 +1340,7 @@ def launch( pipeline, sbt, trav_handle ): # An equivalent to the compile_cuda function for Python kernels. The types of # the arguments to the kernel must be provided, if there are any. + def compile_numba(f, sig=(), debug=False): # Based on numba.cuda.compile_ptx. We don't just use # compile_ptx_for_current_device because it generates a kernel with a @@ -1271,19 +1348,19 @@ def compile_numba(f, sig=(), debug=False): # added to compile_ptx in Numba to not mangle the function name. nvvm_options = { - 'debug': debug, - 'fastmath': False, - 'opt': 0 if debug else 3, + "debug": debug, + "fastmath": False, + "opt": 0 if debug else 3, } - cres = numba_compile_cuda(f, None, sig, debug=debug, - nvvm_options=nvvm_options) + cres = numba_compile_cuda(f, None, sig, debug=debug, nvvm_options=nvvm_options) fname = cres.fndesc.llvm_func_name tgt = cres.target_context filename = cres.type_annotation.filename linenum = int(cres.type_annotation.linenum) - lib, kernel = tgt.prepare_cuda_kernel(cres.library, cres.fndesc, debug, - nvvm_options, filename, linenum) + lib, kernel = tgt.prepare_cuda_kernel( + cres.library, cres.fndesc, debug, nvvm_options, filename, linenum + ) cc = get_current_device().compute_capability ptx = lib.get_asm_str(cc=cc) @@ -1293,12 +1370,12 @@ def compile_numba(f, sig=(), debug=False): return ptx.replace(mangled_name, original_name) -#------------------------------------------------------------------------------- +# ------------------------------------------------------------------------------- # # User code / kernel - the following section is what we'd expect a user of # PyOptiX to write. # -#------------------------------------------------------------------------------- +# ------------------------------------------------------------------------------- # vec_math @@ -1306,26 +1383,42 @@ def compile_numba(f, sig=(), debug=False): def clamp(x, a, b): pass + @overload(clamp, target="cuda") def jit_clamp(x, a, b): - if isinstance(x, types.Float) and isinstance(a, types.Float) and isinstance(b, types.Float): + if ( + isinstance(x, types.Float) + and isinstance(a, types.Float) + and isinstance(b, types.Float) + ): + def clamp_float_impl(x, a, b): return max(a, min(x, b)) + return clamp_float_impl - elif isinstance(x, Float3) and isinstance(a, types.Float) and isinstance(b, types.Float): + elif ( + isinstance(x, Float3) + and isinstance(a, types.Float) + and isinstance(b, types.Float) + ): + def clamp_float3_impl(x, a, b): return make_float3(clamp(x.x, a, b), clamp(x.y, a, b), clamp(x.z, a, b)) + return clamp_float3_impl def dot(a, b): pass + @overload(dot, target="cuda") def jit_dot(a, b): if isinstance(a, Float3) and isinstance(b, Float3): + def dot_float3_impl(a, b): return a.x * b.x + a.y * b.y + a.z * b.z + return dot_float3_impl @@ -1337,36 +1430,55 @@ def normalize(v): # Helpers + @cuda.jit(device=True) def toSRGB(c): # Use float32 for constants invGamma = float32(1.0) / float32(2.4) - powed = make_float3(math.pow(c.x, invGamma), math.pow(c.y, invGamma), math.pow(c.z, invGamma)) + powed = make_float3( + math.pow(c.x, invGamma), math.pow(c.y, invGamma), math.pow(c.z, invGamma) + ) return make_float3( - float32(12.92) * c.x if c.x < float32(0.0031308) else float32(1.055) * powed.x - float32(0.055), - float32(12.92) * c.y if c.y < float32(0.0031308) else float32(1.055) * powed.y - float32(0.055), - float32(12.92) * c.z if c.z < float32(0.0031308) else float32(1.055) * powed.z - float32(0.055)) + float32(12.92) * c.x + if c.x < float32(0.0031308) + else float32(1.055) * powed.x - float32(0.055), + float32(12.92) * c.y + if c.y < float32(0.0031308) + else float32(1.055) * powed.y - float32(0.055), + float32(12.92) * c.z + if c.z < float32(0.0031308) + else float32(1.055) * powed.z - float32(0.055), + ) @cuda.jit(device=True) def quantizeUnsigned8Bits(x): - x = clamp( x, float32(0.0), float32(1.0) ) + x = clamp(x, float32(0.0), float32(1.0)) N, Np1 = (1 << 8) - 1, 1 << 8 return uint8(min(uint32(x * float32(Np1)), uint32(N))) + @cuda.jit(device=True) def make_color(c): srgb = toSRGB(clamp(c, float32(0.0), float32(1.0))) - return make_uchar4(quantizeUnsigned8Bits(srgb.x), quantizeUnsigned8Bits(srgb.y), quantizeUnsigned8Bits(srgb.z), uint8(255)) + return make_uchar4( + quantizeUnsigned8Bits(srgb.x), + quantizeUnsigned8Bits(srgb.y), + quantizeUnsigned8Bits(srgb.z), + uint8(255), + ) + # ray functions + @cuda.jit(device=True) def setPayload(p): optix.SetPayload_0(cuda.libdevice.float_as_int(p.x)) optix.SetPayload_1(cuda.libdevice.float_as_int(p.y)) optix.SetPayload_2(cuda.libdevice.float_as_int(p.z)) + @cuda.jit(device=True) def computeRay(idx, dim): U = params.cam_u @@ -1374,9 +1486,8 @@ def computeRay(idx, dim): W = params.cam_w # Normalizing coordinates to [-1.0, 1.0] d = float32(2.0) * make_float2( - float32(idx.x) / float32(dim.x), - float32(idx.y) / float32(dim.y) - ) - float32(1.0) + float32(idx.x) / float32(dim.x), float32(idx.y) / float32(dim.y) + ) - float32(1.0) origin = params.cam_eye direction = normalize(d.x * U + d.y * V + W) @@ -1394,31 +1505,31 @@ def __raygen__rg(): # Trace the ray against our scene hierarchy payload_pack = optix.Trace( - params.handle, - ray_origin, - ray_direction, - float32(0.0), # Min intersection distance - float32(1e16), # Max intersection distance - float32(0.0), # rayTime -- used for motion blur - OptixVisibilityMask(255), # Specify always visible - # OptixRayFlags.OPTIX_RAY_FLAG_NONE, - uint32(OPTIX_RAY_FLAG_NONE), - uint32(0), # SBT offset -- See SBT discussion - uint32(1), # SBT stride -- See SBT discussion - uint32(0), # missSBTIndex -- See SBT discussion + params.handle, + ray_origin, + ray_direction, + float32(0.0), # Min intersection distance + float32(1e16), # Max intersection distance + float32(0.0), # rayTime -- used for motion blur + OptixVisibilityMask(255), # Specify always visible + # OptixRayFlags.OPTIX_RAY_FLAG_NONE, + uint32(OPTIX_RAY_FLAG_NONE), + uint32(0), # SBT offset -- See SBT discussion + uint32(1), # SBT stride -- See SBT discussion + uint32(0), # missSBTIndex -- See SBT discussion ) result = make_float3( - cuda.libdevice.int_as_float(payload_pack.p0), - cuda.libdevice.int_as_float(payload_pack.p1), - cuda.libdevice.int_as_float(payload_pack.p2) + cuda.libdevice.int_as_float(payload_pack.p0), + cuda.libdevice.int_as_float(payload_pack.p1), + cuda.libdevice.int_as_float(payload_pack.p2), ) # Record results in our output raster - params.image[idx.y * params.image_width + idx.x] = make_color( result ) + params.image[idx.y * params.image_width + idx.x] = make_color(result) def __miss__ms(): - miss_data = MissDataStruct(optix.GetSbtDataPointer()) + miss_data = MissDataStruct(optix.GetSbtDataPointer()) setPayload(miss_data.bg_color) @@ -1430,42 +1541,42 @@ def __closesthit__ch(): setPayload(make_float3(barycentrics, float32(1.0))) -#------------------------------------------------------------------------------- +# ------------------------------------------------------------------------------- # # main # -#------------------------------------------------------------------------------- +# ------------------------------------------------------------------------------- def main(): raygen_ptx = compile_numba(__raygen__rg) miss_ptx = compile_numba(__miss__ms) hitgroup_ptx = compile_numba(__closesthit__ch) - - # triangle_ptx = compile_cuda( "examples/triangle.cu" ) - print(miss_ptx) + # triangle_ptx = compile_cuda( "examples/triangle.cu" ) init_optix() - ctx = create_ctx() + ctx = create_ctx() gas_handle, d_gas_output_buffer = create_accel(ctx) pipeline_options = set_pipeline_options() - - raygen_module = create_module( ctx, pipeline_options, raygen_ptx ) - miss_module = create_module( ctx, pipeline_options, miss_ptx ) - hitgroup_module = create_module( ctx, pipeline_options, hitgroup_ptx ) - prog_groups = create_program_groups( ctx, raygen_module, miss_module, hitgroup_module ) - pipeline = create_pipeline( ctx, prog_groups, pipeline_options ) - sbt = create_sbt( prog_groups ) - pix = launch( pipeline, sbt, gas_handle ) + raygen_module = create_module(ctx, pipeline_options, raygen_ptx) + miss_module = create_module(ctx, pipeline_options, miss_ptx) + hitgroup_module = create_module(ctx, pipeline_options, hitgroup_ptx) + + prog_groups = create_program_groups( + ctx, raygen_module, miss_module, hitgroup_module + ) + pipeline = create_pipeline(ctx, prog_groups, pipeline_options) + sbt = create_sbt(prog_groups) + pix = launch(pipeline, sbt, gas_handle) - print( "Total number of log messages: {}".format( logger.num_mssgs ) ) + print("Total number of log messages: {}".format(logger.num_mssgs)) - pix = pix.reshape( ( pix_height, pix_width, 4 ) ) # PIL expects [ y, x ] resolution - img = ImageOps.flip( Image.fromarray( pix, 'RGBA' ) ) # PIL expects y = 0 at bottom - img.save( 'my.png' ) + pix = pix.reshape((pix_height, pix_width, 4)) # PIL expects [ y, x ] resolution + img = ImageOps.flip(Image.fromarray(pix, "RGBA")) # PIL expects y = 0 at bottom + img.save("my.png") img.show() From 52ec43cc129f6982a91088132e352d7a88a9fe64 Mon Sep 17 00:00:00 2001 From: Michael Wang Date: Wed, 19 Jan 2022 13:09:31 -0800 Subject: [PATCH 10/25] Add style format config files --- .flake8 | 19 +++++++++++++++++++ .pre-commit-config.yaml | 33 +++++++++++++++++++++++++++++++++ setup.cfg | 19 +++++++++++++++++++ 3 files changed, 71 insertions(+) create mode 100644 .flake8 create mode 100644 .pre-commit-config.yaml create mode 100644 setup.cfg diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..f2d388a --- /dev/null +++ b/.flake8 @@ -0,0 +1,19 @@ +[flake8] +exclude = __init__.py +ignore = + # line break before binary operator + W503, + # whitespace before : + E203 + +[pydocstyle] +match = .*\.py +match-dir = examples +# In addition to numpy style, we additionally ignore: +add-ignore = + # magic methods + D105, + # no docstring in __init__ + D107, + # newlines before docstrings + D204 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..15b1a1b --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,33 @@ +repos: + - repo: https://github.com/PyCQA/isort + rev: 5.6.4 + hooks: + - id: isort + alias: isort + name: isort + args: ["--settings-path=setup.cfg"] + files: examples/.* + exclude: __init__.py$ + types: [text] + types_or: [python] + - repo: https://github.com/psf/black + rev: 19.10b0 + hooks: + - id: black + files: examples/.* + - repo: https://github.com/PyCQA/flake8 + rev: 3.8.3 + hooks: + - id: flake8 + alias: flake8 + name: flake8 + args: ["--config=.flake8"] + files: python/.*\.py$ + # - repo: https://github.com/PyCQA/pydocstyle + # rev: 6.1.1 + # hooks: + # - id: pydocstyle + # args: ["--config=.flake8"] + +default_language_version: + python: python3 diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..d839db8 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,19 @@ +[tool.black] +line-length = 79 +target-version = ["py36"] + +[isort] +line_length=79 +multi_line_output=3 +include_trailing_comma=True +force_grid_wrap=0 +combine_as_imports=True +order_by_type=True +skip= + .eggs + .git + .hg + .mypy_cache + .tox + .venv + __init__.py \ No newline at end of file From 5418f8eb12a937cfa274135f68332eb24a7a3151 Mon Sep 17 00:00:00 2001 From: Michael Wang Date: Wed, 19 Jan 2022 13:15:17 -0800 Subject: [PATCH 11/25] Add gitignore --- .gitignore | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) create mode 100644 .gitignore diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..c958bba --- /dev/null +++ b/.gitignore @@ -0,0 +1,22 @@ +## Common +__pycache__ +*.py[cod] +*$py.class +*.a +*.o +*.so +*.dylib +.cache +.vscode +*.swp +*.pytest_cache +DartConfiguration.tcl +.DS_Store +*.manifest +*.spec +.nfs* +.clangd + +## build files +build/ +*.ptx From 289b9f7933acd6d48fbbcc80a59ea46fb7b500a3 Mon Sep 17 00:00:00 2001 From: Michael Wang Date: Wed, 19 Jan 2022 16:26:02 -0800 Subject: [PATCH 12/25] Initial modularize numba supports --- examples/numba_support.py | 973 ++++++++++++++++++++++++++++++++++++++ examples/pyramid.py | 616 ++++++++++++++++++++++++ 2 files changed, 1589 insertions(+) create mode 100644 examples/numba_support.py create mode 100755 examples/pyramid.py diff --git a/examples/numba_support.py b/examples/numba_support.py new file mode 100644 index 0000000..4d693b6 --- /dev/null +++ b/examples/numba_support.py @@ -0,0 +1,973 @@ +# ------------------------------------------------------------------------------- +# +# Numba extensions for general CUDA / OptiX support +# +# ------------------------------------------------------------------------------- + +from operator import add, mul, sub + +from llvmlite import ir +from numba import cuda, float32, int32, types, uint8, uint32 +from numba.core import cgutils +from numba.core.extending import ( + make_attribute_wrapper, + models, + overload, + register_model, + type_callable, + typeof_impl, +) +from numba.core.imputils import lower_constant +from numba.core.typing.templates import ( + AttributeTemplate, + ConcreteTemplate, + signature, +) +from numba.cuda.cudadecl import register, register_attr, register_global +from numba.cuda.cudadrv import nvvm +from numba.cuda.cudaimpl import lower +from numba.cuda.types import dim3 + +import optix + +# UChar4 +# ------ + +# Numba presently doesn't implement the UChar4 type (which is fairly standard +# CUDA) so we provide some minimal support for it here. + + +# Prototype a function to construct a uchar4 + + +def make_uchar4(x, y, z, w): + pass + + +# UChar4 typing + + +class UChar4(types.Type): + def __init__(self): + super().__init__(name="UChar4") + + +uchar4 = UChar4() + + +@register +class MakeUChar4(ConcreteTemplate): + key = make_uchar4 + cases = [signature(uchar4, types.uchar, types.uchar, types.uchar, types.uchar)] + + +register_global(make_uchar4, types.Function(MakeUChar4)) + + +# UChar4 data model + + +@register_model(UChar4) +class UChar4Model(models.StructModel): + def __init__(self, dmm, fe_type): + members = [ + ("x", types.uchar), + ("y", types.uchar), + ("z", types.uchar), + ("w", types.uchar), + ] + super().__init__(dmm, fe_type, members) + + +make_attribute_wrapper(UChar4, "x", "x") +make_attribute_wrapper(UChar4, "y", "y") +make_attribute_wrapper(UChar4, "z", "z") +make_attribute_wrapper(UChar4, "w", "w") + + +# UChar4 lowering + + +@lower(make_uchar4, types.uchar, types.uchar, types.uchar, types.uchar) +def lower_make_uchar4(context, builder, sig, args): + uc4 = cgutils.create_struct_proxy(uchar4)(context, builder) + uc4.x = args[0] + uc4.y = args[1] + uc4.z = args[2] + uc4.w = args[3] + return uc4._getvalue() + + +# float3 +# ------ + +# Float3 typing + + +class Float3(types.Type): + def __init__(self): + super().__init__(name="Float3") + + +float3 = Float3() + + +# Float2 typing (forward declaration) + + +class Float2(types.Type): + def __init__(self): + super().__init__(name="Float2") + + +float2 = Float2() + + +# Float3 data model + + +@register_model(Float3) +class Float3Model(models.StructModel): + def __init__(self, dmm, fe_type): + members = [ + ("x", types.float32), + ("y", types.float32), + ("z", types.float32), + ] + super().__init__(dmm, fe_type, members) + + +make_attribute_wrapper(Float3, "x", "x") +make_attribute_wrapper(Float3, "y", "y") +make_attribute_wrapper(Float3, "z", "z") + + +def lower_float3_ops(op): + class Float3_op_template(ConcreteTemplate): + key = op + cases = [ + signature(float3, float3, float3), + signature(float3, types.float32, float3), + signature(float3, float3, types.float32), + ] + + def float3_op_impl(context, builder, sig, args): + def op_attr(lhs, rhs, res, attr): + setattr( + res, + attr, + context.compile_internal( + builder, + lambda x, y: op(x, y), + signature(types.float32, types.float32, types.float32), + (getattr(lhs, attr), getattr(rhs, attr)), + ), + ) + + arg0, arg1 = args + + if isinstance(sig.args[0], types.Float): + lf3 = cgutils.create_struct_proxy(float3)(context, builder) + lf3.x = arg0 + lf3.y = arg0 + lf3.z = arg0 + else: + lf3 = cgutils.create_struct_proxy(float3)(context, builder, value=args[0]) + + if isinstance(sig.args[1], types.Float): + rf3 = cgutils.create_struct_proxy(float3)(context, builder) + rf3.x = arg1 + rf3.y = arg1 + rf3.z = arg1 + else: + rf3 = cgutils.create_struct_proxy(float3)(context, builder, value=args[1]) + + res = cgutils.create_struct_proxy(float3)(context, builder) + op_attr(lf3, rf3, res, "x") + op_attr(lf3, rf3, res, "y") + op_attr(lf3, rf3, res, "z") + return res._getvalue() + + register_global(op, types.Function(Float3_op_template)) + lower(op, float3, float3)(float3_op_impl) + lower(op, types.float32, float3)(float3_op_impl) + lower(op, float3, types.float32)(float3_op_impl) + + +lower_float3_ops(mul) +lower_float3_ops(add) + + +@lower(add, float32, float3) +def add_float32_float3_impl(context, builder, sig, args): + s = args[0] + rhs = cgutils.create_struct_proxy(float3)(context, builder, args[1]) + res = cgutils.create_struct_proxy(float3)(context, builder) + res.x = builder.fadd(s, rhs.x) + res.y = builder.fadd(s, rhs.y) + res.z = builder.fadd(s, rhs.z) + return res._getvalue() + + +@lower(add, float3, float32) +def add_float3_float32_impl(context, builder, sig, args): + lhs = cgutils.create_struct_proxy(float3)(context, builder, args[0]) + s = args[1] + res = cgutils.create_struct_proxy(float3)(context, builder) + res.x = builder.fadd(lhs.x, s) + res.y = builder.fadd(lhs.y, s) + res.z = builder.fadd(lhs.z, s) + return res._getvalue() + + +# Prototype a function to construct a float3 + + +def make_float3(x, y, z): + pass + + +@register +class MakeFloat3(ConcreteTemplate): + key = make_float3 + cases = [ + signature(float3, types.float32, types.float32, types.float32), + signature(float3, float2, types.float32), + ] + + +register_global(make_float3, types.Function(MakeFloat3)) + + +# make_float3 lowering + + +@lower(make_float3, types.float32, types.float32, types.float32) +def lower_make_float3(context, builder, sig, args): + f3 = cgutils.create_struct_proxy(float3)(context, builder) + f3.x = args[0] + f3.y = args[1] + f3.z = args[2] + return f3._getvalue() + + +@lower(make_float3, float2, types.float32) +def lower_make_float3(context, builder, sig, args): + f2 = cgutils.create_struct_proxy(float2)(context, builder, args[0]) + f3 = cgutils.create_struct_proxy(float3)(context, builder) + f3.x = f2.x + f3.y = f2.y + f3.z = args[1] + return f3._getvalue() + + +# float2 +# ------ + + +# Float2 data model + + +@register_model(Float2) +class Float2Model(models.StructModel): + def __init__(self, dmm, fe_type): + members = [ + ("x", types.float32), + ("y", types.float32), + ] + super().__init__(dmm, fe_type, members) + + +make_attribute_wrapper(Float2, "x", "x") +make_attribute_wrapper(Float2, "y", "y") + + +def lower_float2_ops(op): + class Float2_op_template(ConcreteTemplate): + key = op + cases = [ + signature(float2, float2, float2), + signature(float2, types.float32, float2), + signature(float2, float2, types.float32), + ] + + def float2_op_impl(context, builder, sig, args): + def op_attr(lhs, rhs, res, attr): + setattr( + res, + attr, + context.compile_internal( + builder, + lambda x, y: op(x, y), + signature(types.float32, types.float32, types.float32), + (getattr(lhs, attr), getattr(rhs, attr)), + ), + ) + + arg0, arg1 = args + + if isinstance(sig.args[0], types.Float): + lf2 = cgutils.create_struct_proxy(float2)(context, builder) + lf2.x = arg0 + lf2.y = arg0 + else: + lf2 = cgutils.create_struct_proxy(float2)(context, builder, value=args[0]) + + if isinstance(sig.args[1], types.Float): + rf2 = cgutils.create_struct_proxy(float2)(context, builder) + rf2.x = arg1 + rf2.y = arg1 + else: + rf2 = cgutils.create_struct_proxy(float2)(context, builder, value=args[1]) + + res = cgutils.create_struct_proxy(float2)(context, builder) + op_attr(lf2, rf2, res, "x") + op_attr(lf2, rf2, res, "y") + return res._getvalue() + + register_global(op, types.Function(Float2_op_template)) + lower(op, float2, float2)(float2_op_impl) + lower(op, types.Float, float2)(float2_op_impl) + lower(op, float2, types.Float)(float2_op_impl) + + +lower_float2_ops(mul) +lower_float2_ops(sub) + + +# Prototype a function to construct a float2 + + +def make_float2(x, y): + pass + + +@register +class MakeFloat2(ConcreteTemplate): + key = make_float2 + cases = [signature(float2, types.float32, types.float32)] + + +register_global(make_float2, types.Function(MakeFloat2)) + + +# make_float2 lowering + + +@lower(make_float2, types.float32, types.float32) +def lower_make_float2(context, builder, sig, args): + f2 = cgutils.create_struct_proxy(float2)(context, builder) + f2.x = args[0] + f2.y = args[1] + return f2._getvalue() + + +# uint3 +# ------ + + +class UInt3(types.Type): + def __init__(self): + super().__init__(name="UInt3") + + +uint3 = UInt3() + + +# UInt3 data model + + +@register_model(UInt3) +class UInt3Model(models.StructModel): + def __init__(self, dmm, fe_type): + members = [ + ("x", types.uint32), + ("y", types.uint32), + ("z", types.uint32), + ] + super().__init__(dmm, fe_type, members) + + +make_attribute_wrapper(UInt3, "x", "x") +make_attribute_wrapper(UInt3, "y", "y") +make_attribute_wrapper(UInt3, "z", "z") + + +# Prototype a function to construct a uint3 + + +def make_uint3(x, y, z): + pass + + +@register +class MakeUInt3(ConcreteTemplate): + key = make_uint3 + cases = [signature(uint3, types.uint32, types.uint32, types.uint32)] + + +register_global(make_uint3, types.Function(MakeUInt3)) + + +# make_uint3 lowering + + +@lower(make_uint3, types.uint32, types.uint32, types.uint32) +def lower_make_uint3(context, builder, sig, args): + # u4 = uint32 + u4_3 = cgutils.create_struct_proxy(uint3)(context, builder) + u4_3.x = args[0] + u4_3.y = args[1] + u4_3.z = args[2] + return u4_3._getvalue() + + +# Temporary Payload Parameter Pack +class PayloadPack(types.Type): + def __init__(self): + super().__init__(name="PayloadPack") + + +payload_pack = PayloadPack() + + +# UInt3 data model + + +@register_model(PayloadPack) +class PayloadPackModel(models.StructModel): + def __init__(self, dmm, fe_type): + members = [ + ("p0", types.uint32), + ("p1", types.uint32), + ("p2", types.uint32), + ] + super().__init__(dmm, fe_type, members) + + +make_attribute_wrapper(PayloadPack, "p0", "p0") +make_attribute_wrapper(PayloadPack, "p1", "p1") +make_attribute_wrapper(PayloadPack, "p2", "p2") + +# OptiX typedefs and enums +# ----------- + +OptixVisibilityMask = types.Integer("OptixVisibilityMask", bitwidth=32, signed=False) +OptixTraversableHandle = types.Integer( + "OptixTraversableHandle", bitwidth=64, signed=False +) + + +OPTIX_RAY_FLAG_NONE = 0 +# class OptixRayFlags(Enum): +# OPTIX_RAY_FLAG_NONE = 0 +# OPTIX_RAY_FLAG_DISABLE_ANYHIT = 1 << 0 +# OPTIX_RAY_FLAG_ENFORCE_ANYHIT = 1 << 1 +# OPTIX_RAY_FLAG_TERMINATE_ON_FIRST_HIT = 1 << 2 +# OPTIX_RAY_FLAG_DISABLE_CLOSESTHIT = 1 << 3, +# OPTIX_RAY_FLAG_CULL_BACK_FACING_TRIANGLES = 1 << 4 +# OPTIX_RAY_FLAG_CULL_FRONT_FACING_TRIANGLES = 1 << 5 +# OPTIX_RAY_FLAG_CULL_DISABLED_ANYHIT = 1 << 6 +# OPTIX_RAY_FLAG_CULL_ENFORCED_ANYHIT = 1 << 7 + + +# OptiX types +# ----------- + +# Typing for OptiX types + + +class SbtDataPointer(types.RawPointer): + def __init__(self): + super().__init__(name="SbtDataPointer") + + +sbt_data_pointer = SbtDataPointer() + + +# Models for OptiX types + + +@register_model(SbtDataPointer) +class SbtDataPointerModel(models.OpaqueModel): + pass + + +# Params +# ------------ + +# Structures as declared in triangle.h + + +class ParamsStruct: + fields = ( + ("image", "uchar4*"), + ("image_width", "unsigned int"), + ("image_height", "unsigned int"), + ("cam_eye", "float3"), + ("cam_u", "float3"), + ("cam_v", "float3"), + ("cam_w", "float3"), + ("handle", "OptixTraversableHandle"), + ) + + +# "Declare" a global called params +params = ParamsStruct() + + +class Params(types.Type): + def __init__(self): + super().__init__(name="ParamsType") + + +params_type = Params() + + +# ParamsStruct data model + + +@register_model(Params) +class ParamsModel(models.StructModel): + def __init__(self, dmm, fe_type): + members = [ + ("image", types.CPointer(uchar4)), + ("image_width", types.uint32), + ("image_height", types.uint32), + ("cam_eye", float3), + ("cam_u", float3), + ("cam_v", float3), + ("cam_w", float3), + ("handle", OptixTraversableHandle), + ] + super().__init__(dmm, fe_type, members) + + +make_attribute_wrapper(Params, "image", "image") +make_attribute_wrapper(Params, "image_width", "image_width") +make_attribute_wrapper(Params, "image_height", "image_height") +make_attribute_wrapper(Params, "cam_eye", "cam_eye") +make_attribute_wrapper(Params, "cam_u", "cam_u") +make_attribute_wrapper(Params, "cam_v", "cam_v") +make_attribute_wrapper(Params, "cam_w", "cam_w") +make_attribute_wrapper(Params, "handle", "handle") + + +@typeof_impl.register(ParamsStruct) +def typeof_params(val, c): + return params_type + + +# ParamsStruct lowering +# The below makes 'param' a global variable, accessible from any user defined +# kernels. + + +@lower_constant(Params) +def constant_params(context, builder, ty, pyval): + try: + gvar = builder.module.get_global("params") + except KeyError: + llty = context.get_value_type(ty) + gvar = cgutils.add_global_variable( + builder.module, llty, "params", addrspace=nvvm.ADDRSPACE_CONSTANT + ) + gvar.linkage = "external" + gvar.global_constant = True + + return builder.load(gvar) + + +# MissData +# ------------ + +# Structures as declared in triangle.h +class MissDataStruct: + fields = ("bg_color", "float3") + + +MissData = MissDataStruct() + + +class MissData(types.Type): + def __init__(self): + super().__init__(name="MissDataType") + + +miss_data_type = MissData() + + +@register_model(MissData) +class MissDataModel(models.StructModel): + def __init__(self, dmm, fe_type): + members = [ + ("bg_color", float3), + ] + super().__init__(dmm, fe_type, members) + + +make_attribute_wrapper(MissData, "bg_color", "bg_color") + + +@typeof_impl.register(MissDataStruct) +def typeof_miss_data(val, c): + return miss_data_type + + +# MissData Constructor +@type_callable(MissDataStruct) +def type_miss_data_struct(context): + def typer(sbt_data_pointer): + if isinstance(sbt_data_pointer, SbtDataPointer): + return miss_data_type + + return typer + + +@lower(MissDataStruct, sbt_data_pointer) +def lower_miss_data_ctor(context, builder, sig, args): + # Anyway to err if this ctor is not called inside __miss__* program? + # TODO: Optimize + ptr = args[0] + ptr = builder.bitcast(ptr, context.get_value_type(miss_data_type).as_pointer()) + + bg_color_ptr = cgutils.gep_inbounds(builder, ptr, 0, 0) + + xptr = cgutils.gep_inbounds(builder, bg_color_ptr, 0, 0) + yptr = cgutils.gep_inbounds(builder, bg_color_ptr, 0, 1) + zptr = cgutils.gep_inbounds(builder, bg_color_ptr, 0, 2) + + output_miss_data = cgutils.create_struct_proxy(miss_data_type)(context, builder) + output_bg_color_ptr = cgutils.gep_inbounds( + builder, output_miss_data._getpointer(), 0, 0 + ) + output_bg_color_x_ptr = cgutils.gep_inbounds(builder, output_bg_color_ptr, 0, 0) + output_bg_color_y_ptr = cgutils.gep_inbounds(builder, output_bg_color_ptr, 0, 1) + output_bg_color_z_ptr = cgutils.gep_inbounds(builder, output_bg_color_ptr, 0, 2) + + x = builder.load(xptr) + y = builder.load(yptr) + z = builder.load(zptr) + + builder.store(x, output_bg_color_x_ptr) + builder.store(y, output_bg_color_y_ptr) + builder.store(z, output_bg_color_z_ptr) + + # Doesn't seem to do what's expected? + # miss_data.bg_color.x = builder.load(xptr) + # miss_data.bg_color.y = builder.load(yptr) + # miss_data.bg_color.z = builder.load(zptr) + return output_miss_data._getvalue() + + +# OptiX functions +# --------------- + +# Here we "prototype" the OptiX functions that the user will call in their +# kernels, so that Numba has something to refer to when compiling the kernel. + + +def _optix_GetLaunchIndex(): + pass + + +def _optix_GetLaunchDimensions(): + pass + + +def _optix_GetSbtDataPointer(): + pass + + +def _optix_SetPayload_0(): + pass + + +def _optix_SetPayload_1(): + pass + + +def _optix_SetPayload_2(): + pass + + +def _optix_GetTriangleBarycentrics(): + pass + + +def _optix_Trace(): + pass + + +# Monkey-patch the functions into the optix module, so the user can write +# optix.GetLaunchIndex etc., for symmetry with the rest of the API implemented +# in PyOptiX. + +optix.GetLaunchIndex = _optix_GetLaunchIndex +optix.GetLaunchDimensions = _optix_GetLaunchDimensions +optix.GetSbtDataPointer = _optix_GetSbtDataPointer +optix.GetTriangleBarycentrics = _optix_GetTriangleBarycentrics +optix.SetPayload_0 = _optix_SetPayload_0 +optix.SetPayload_1 = _optix_SetPayload_1 +optix.SetPayload_2 = _optix_SetPayload_2 + +optix.Trace = _optix_Trace + + +# OptiX function typing + + +@register +class OptixGetLaunchIndex(ConcreteTemplate): + key = optix.GetLaunchIndex + cases = [signature(dim3)] + + +@register +class OptixGetLaunchDimensions(ConcreteTemplate): + key = optix.GetLaunchDimensions + cases = [signature(dim3)] + + +@register +class OptixGetSbtDataPointer(ConcreteTemplate): + key = optix.GetSbtDataPointer + cases = [signature(sbt_data_pointer)] + + +def registerSetPayload(reg): + class OptixSetPayloadReg(ConcreteTemplate): + key = getattr(optix, "SetPayload_" + str(reg)) + cases = [signature(types.void, uint32)] + + register(OptixSetPayloadReg) + return OptixSetPayloadReg + + +OptixSetPayload_0 = registerSetPayload(0) +OptixSetPayload_1 = registerSetPayload(1) +OptixSetPayload_2 = registerSetPayload(2) + + +@register +class OptixGetTriangleBarycentrics(ConcreteTemplate): + key = optix.GetTriangleBarycentrics + cases = [signature(float2)] + + +@register +class OptixTrace(ConcreteTemplate): + key = optix.Trace + cases = [ + signature( + payload_pack, + OptixTraversableHandle, + float3, + float3, + float32, + float32, + float32, + OptixVisibilityMask, + uint32, + uint32, + uint32, + uint32, + ) + ] + + +@register_attr +class OptixModuleTemplate(AttributeTemplate): + key = types.Module(optix) + + def resolve_GetLaunchIndex(self, mod): + return types.Function(OptixGetLaunchIndex) + + def resolve_GetLaunchDimensions(self, mod): + return types.Function(OptixGetLaunchDimensions) + + def resolve_GetSbtDataPointer(self, mod): + return types.Function(OptixGetSbtDataPointer) + + def resolve_SetPayload_0(self, mod): + return types.Function(OptixSetPayload_0) + + def resolve_SetPayload_1(self, mod): + return types.Function(OptixSetPayload_1) + + def resolve_SetPayload_2(self, mod): + return types.Function(OptixSetPayload_2) + + def resolve_GetTriangleBarycentrics(self, mod): + return types.Function(OptixGetTriangleBarycentrics) + + def resolve_Trace(self, mod): + return types.Function(OptixTrace) + + +# OptiX function lowering + + +@lower(optix.GetLaunchIndex) +def lower_optix_getLaunchIndex(context, builder, sig, args): + def get_launch_index(axis): + asm = ir.InlineAsm( + ir.FunctionType(ir.IntType(32), []), + f"call ($0), _optix_get_launch_index_{axis}, ();", + "=r", + ) + return builder.call(asm, []) + + index = cgutils.create_struct_proxy(dim3)(context, builder) + index.x = get_launch_index("x") + index.y = get_launch_index("y") + index.z = get_launch_index("z") + return index._getvalue() + + +@lower(optix.GetLaunchDimensions) +def lower_optix_getLaunchDimensions(context, builder, sig, args): + def get_launch_dimensions(axis): + asm = ir.InlineAsm( + ir.FunctionType(ir.IntType(32), []), + f"call ($0), _optix_get_launch_dimension_{axis}, ();", + "=r", + ) + return builder.call(asm, []) + + index = cgutils.create_struct_proxy(dim3)(context, builder) + index.x = get_launch_dimensions("x") + index.y = get_launch_dimensions("y") + index.z = get_launch_dimensions("z") + return index._getvalue() + + +@lower(optix.GetSbtDataPointer) +def lower_optix_getSbtDataPointer(context, builder, sig, args): + asm = ir.InlineAsm( + ir.FunctionType(ir.IntType(64), []), + "call ($0), _optix_get_sbt_data_ptr_64, ();", + "=l", + ) + ptr = builder.call(asm, []) + ptr = builder.inttoptr(ptr, ir.IntType(8).as_pointer()) + return ptr + + +def lower_optix_SetPayloadReg(reg): + def lower_optix_SetPayload_impl(context, builder, sig, args): + asm = ir.InlineAsm( + ir.FunctionType(ir.VoidType(), [ir.IntType(32), ir.IntType(32)]), + f"call _optix_set_payload, ($0, $1);", + "r,r", + ) + builder.call(asm, [context.get_constant(types.int32, reg), args[0]]) + + lower(getattr(optix, f"SetPayload_{reg}"), uint32)(lower_optix_SetPayload_impl) + + +lower_optix_SetPayloadReg(0) +lower_optix_SetPayloadReg(1) +lower_optix_SetPayloadReg(2) + + +@lower(optix.GetTriangleBarycentrics) +def lower_optix_getTriangleBarycentrics(context, builder, sig, args): + f2 = cgutils.create_struct_proxy(float2)(context, builder) + retty = ir.LiteralStructType([ir.FloatType(), ir.FloatType()]) + asm = ir.InlineAsm( + ir.FunctionType(retty, []), + "call ($0, $1), _optix_get_triangle_barycentrics, ();", + "=f,=f", + ) + ret = builder.call(asm, []) + f2.x = builder.extract_value(ret, 0) + f2.y = builder.extract_value(ret, 1) + return f2._getvalue() + + +@lower( + optix.Trace, + OptixTraversableHandle, + float3, + float3, + float32, + float32, + float32, + OptixVisibilityMask, + uint32, + uint32, + uint32, + uint32, +) +def lower_optix_Trace(context, builder, sig, args): + # Only implements the version that accepts 3 payload registers + # TODO: Optimize returns, adapt to 0-8 payload registers. + + ( + handle, + rayOrigin, + rayDirection, + tmin, + tmax, + rayTime, + visibilityMask, + rayFlags, + SBToffset, + SBTstride, + missSBTIndex, + ) = args + + rayOrigin = cgutils.create_struct_proxy(float3)(context, builder, rayOrigin) + rayDirection = cgutils.create_struct_proxy(float3)(context, builder, rayDirection) + output = cgutils.create_struct_proxy(payload_pack)(context, builder) + + ox, oy, oz = rayOrigin.x, rayOrigin.y, rayOrigin.z + dx, dy, dz = rayDirection.x, rayDirection.y, rayDirection.z + + n_payload_registers = 3 + n_stub_output_operands = 32 - n_payload_registers + outputs = [output.p0, output.p1, output.p2] + [ + builder.load(builder.alloca(ir.IntType(32))) + for _ in range(n_stub_output_operands) + ] + + retty = ir.LiteralStructType([ir.IntType(32)] * 32) + asm = ir.InlineAsm( + ir.FunctionType(retty, []), + "call " + "($0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29," + "$30,$31)," + "_optix_trace_typed_32," + "($32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59," + "$60,$61,$62,$63,$64,$65,$66,$67,$68,$69,$70,$71,$72,$73,$74,$75,$76,$77,$78,$79,$80);", + "=r," * 32 + "r,l,f,f,f,f,f,f,f,f,f,r,r,r,r,r,r," + "r," * 31 + "r", + side_effect=True, + ) + + zero = context.get_constant(types.int32, 0) + c_payload_registers = context.get_constant(types.int32, n_payload_registers) + args = [ + zero, + handle, + ox, + oy, + oz, + dx, + dy, + dz, + tmin, + tmax, + rayTime, + visibilityMask, + rayFlags, + SBToffset, + SBTstride, + missSBTIndex, + c_payload_registers, + ] + outputs + ret = builder.call(asm, args) + output.p0 = builder.extract_value(ret, 0) + output.p1 = builder.extract_value(ret, 1) + output.p2 = builder.extract_value(ret, 2) + return output._getvalue() diff --git a/examples/pyramid.py b/examples/pyramid.py new file mode 100755 index 0000000..2848c84 --- /dev/null +++ b/examples/pyramid.py @@ -0,0 +1,616 @@ +#!/usr/bin/env python3 + + +import ctypes # C interop helpers +import math +from enum import Enum + +import cupy as cp # CUDA bindings +import numpy as np # Packing of structures in C-compatible format +from numba import cuda, float32, int32, types, uint8, uint32 +from numba.core.extending import overload +from numba.cuda.compiler import compile_cuda as numba_compile_cuda +from numba.cuda import get_current_device + +from PIL import Image, ImageOps # Image IO + +import optix +from numba_support import float2, float3, uchar4, uint3, make_float2, make_float3, make_uchar4, make_uint3, params, OptixVisibilityMask, Float3, OPTIX_RAY_FLAG_NONE, MissDataStruct + +# ------------------------------------------------------------------------------- +# +# Util +# +# ------------------------------------------------------------------------------- +pix_width = 1024 +pix_height = 768 + + +class Logger: + def __init__(self): + self.num_mssgs = 0 + + def __call__(self, level, tag, mssg): + print("[{:>2}][{:>12}]: {}".format(level, tag, mssg)) + self.num_mssgs += 1 + + +def log_callback(level, tag, mssg): + print("[{:>2}][{:>12}]: {}".format(level, tag, mssg)) + + +def round_up(val, mult_of): + return val if val % mult_of == 0 else val + mult_of - val % mult_of + + +def get_aligned_itemsize(formats, alignment): + names = [] + for i in range(len(formats)): + names.append("x" + str(i)) + + temp_dtype = np.dtype({"names": names, "formats": formats, "align": True}) + return round_up(temp_dtype.itemsize, alignment) + + +def array_to_device_memory(numpy_array, stream=cp.cuda.Stream()): + + byte_size = numpy_array.size * numpy_array.dtype.itemsize + + h_ptr = ctypes.c_void_p(numpy_array.ctypes.data) + d_mem = cp.cuda.memory.alloc(byte_size) + d_mem.copy_from_async(h_ptr, byte_size, stream) + return d_mem + + +def compile_cuda(cuda_file): + with open(cuda_file, "rb") as f: + src = f.read() + from pynvrtc.compiler import Program + + prog = Program(src.decode(), cuda_file) + ptx = prog.compile( + [ + "-use_fast_math", + "-lineinfo", + "-default-device", + "-std=c++11", + "-rdc", + "true", + #'-IC:\\ProgramData\\NVIDIA Corporation\OptiX SDK 7.2.0\include', + #'-IC:\\Program Files\\NVIDIA GPU Computing Toolkit\CUDA\\v11.1\include' + "-I/usr/local/cuda/include", + f"-I{optix.include_path}", + ] + ) + return ptx + + +# ------------------------------------------------------------------------------- +# +# Optix setup +# +# ------------------------------------------------------------------------------- + + +def init_optix(): + print("Initializing cuda ...") + cp.cuda.runtime.free(0) + + print("Initializing optix ...") + optix.init() + + +def create_ctx(): + print("Creating optix device context ...") + + # Note that log callback data is no longer needed. We can + # instead send a callable class instance as the log-function + # which stores any data needed + global logger + logger = Logger() + + # OptiX param struct fields can be set with optional + # keyword constructor arguments. + ctx_options = optix.DeviceContextOptions( + logCallbackFunction=logger, logCallbackLevel=4 + ) + + # They can also be set and queried as properties on the struct + ctx_options.validationMode = optix.DEVICE_CONTEXT_VALIDATION_MODE_ALL + + cu_ctx = 0 + return optix.deviceContextCreate(cu_ctx, ctx_options) + + +def create_accel(ctx): + + accel_options = optix.AccelBuildOptions( + buildFlags=int(optix.BUILD_FLAG_ALLOW_RANDOM_VERTEX_ACCESS), + operation=optix.BUILD_OPERATION_BUILD, + ) + + global vertices + vertices = cp.array([-0.5, -0.5, 0.0, 0.5, -0.5, 0.0, 0.0, 0.5, 0.0], dtype="f4") + + triangle_input_flags = [optix.GEOMETRY_FLAG_NONE] + triangle_input = optix.BuildInputTriangleArray() + triangle_input.vertexFormat = optix.VERTEX_FORMAT_FLOAT3 + triangle_input.numVertices = len(vertices) + triangle_input.vertexBuffers = [vertices.data.ptr] + triangle_input.flags = triangle_input_flags + triangle_input.numSbtRecords = 1 + + gas_buffer_sizes = ctx.accelComputeMemoryUsage([accel_options], [triangle_input]) + + d_temp_buffer_gas = cp.cuda.alloc(gas_buffer_sizes.tempSizeInBytes) + d_gas_output_buffer = cp.cuda.alloc(gas_buffer_sizes.outputSizeInBytes) + + gas_handle = ctx.accelBuild( + 0, # CUDA stream + [accel_options], + [triangle_input], + d_temp_buffer_gas.ptr, + gas_buffer_sizes.tempSizeInBytes, + d_gas_output_buffer.ptr, + gas_buffer_sizes.outputSizeInBytes, + [], # emitted properties + ) + + return (gas_handle, d_gas_output_buffer) + + +def set_pipeline_options(): + return optix.PipelineCompileOptions( + usesMotionBlur=False, + traversableGraphFlags=int(optix.TRAVERSABLE_GRAPH_FLAG_ALLOW_SINGLE_GAS), + numPayloadValues=3, + numAttributeValues=3, + exceptionFlags=int(optix.EXCEPTION_FLAG_NONE), + pipelineLaunchParamsVariableName="params", + usesPrimitiveTypeFlags=optix.PRIMITIVE_TYPE_FLAGS_TRIANGLE, + ) + + +def create_module(ctx, pipeline_options, ptx): + print("Creating optix module ...") + + module_options = optix.ModuleCompileOptions( + maxRegisterCount=optix.COMPILE_DEFAULT_MAX_REGISTER_COUNT, + optLevel=optix.COMPILE_OPTIMIZATION_DEFAULT, + debugLevel=optix.COMPILE_DEBUG_LEVEL_LINEINFO, + ) + + module, log = ctx.moduleCreateFromPTX(module_options, pipeline_options, ptx) + print("\tModule create log: <<<{}>>>".format(log)) + return module + + +def create_program_groups(ctx, raygen_module, miss_prog_module, hitgroup_module): + print("Creating program groups ... ") + + program_group_options = optix.ProgramGroupOptions() + + raygen_prog_group_desc = optix.ProgramGroupDesc() + raygen_prog_group_desc.raygenModule = raygen_module + raygen_prog_group_desc.raygenEntryFunctionName = "__raygen__rg" + + miss_prog_group_desc = optix.ProgramGroupDesc() + miss_prog_group_desc.missModule = miss_prog_module + miss_prog_group_desc.missEntryFunctionName = "__miss__ms" + + hitgroup_prog_group_desc = optix.ProgramGroupDesc() + hitgroup_prog_group_desc.hitgroupModuleCH = hitgroup_module + hitgroup_prog_group_desc.hitgroupEntryFunctionNameCH = "__closesthit__ch" + + prog_group, log = ctx.programGroupCreate( + [raygen_prog_group_desc, miss_prog_group_desc, hitgroup_prog_group_desc], + program_group_options, + ) + print("\tProgramGroup create log: <<<{}>>>".format(log)) + + return prog_group + + +def create_pipeline(ctx, program_groups, pipeline_compile_options): + print("Creating pipeline ... ") + + max_trace_depth = 1 + pipeline_link_options = optix.PipelineLinkOptions() + pipeline_link_options.maxTraceDepth = max_trace_depth + pipeline_link_options.debugLevel = optix.COMPILE_DEBUG_LEVEL_FULL + + log = "" + pipeline = ctx.pipelineCreate( + pipeline_compile_options, pipeline_link_options, program_groups, log + ) + + stack_sizes = optix.StackSizes() + for prog_group in program_groups: + optix.util.accumulateStackSizes(prog_group, stack_sizes) + + ( + dc_stack_size_from_trav, + dc_stack_size_from_state, + cc_stack_size, + ) = optix.util.computeStackSizes( + stack_sizes, max_trace_depth, 0, 0 # maxCCDepth # maxDCDepth + ) + + pipeline.setStackSize( + dc_stack_size_from_trav, + dc_stack_size_from_state, + cc_stack_size, + 1, # maxTraversableDepth + ) + + return pipeline + + +def create_sbt(prog_groups): + print("Creating sbt ... ") + + (raygen_prog_group, miss_prog_group, hitgroup_prog_group) = prog_groups + + global d_raygen_sbt + global d_miss_sbt + + header_format = "{}B".format(optix.SBT_RECORD_HEADER_SIZE) + + # + # raygen record + # + formats = [header_format] + itemsize = get_aligned_itemsize(formats, optix.SBT_RECORD_ALIGNMENT) + dtype = np.dtype( + {"names": ["header"], "formats": formats, "itemsize": itemsize, "align": True} + ) + h_raygen_sbt = np.array([0], dtype=dtype) + optix.sbtRecordPackHeader(raygen_prog_group, h_raygen_sbt) + global d_raygen_sbt + d_raygen_sbt = array_to_device_memory(h_raygen_sbt) + + # + # miss record + # + formats = [header_format, "f4", "f4", "f4"] + itemsize = get_aligned_itemsize(formats, optix.SBT_RECORD_ALIGNMENT) + dtype = np.dtype( + { + "names": ["header", "r", "g", "b"], + "formats": formats, + "itemsize": itemsize, + "align": True, + } + ) + h_miss_sbt = np.array([(0, 0.3, 0.1, 0.2)], dtype=dtype) + optix.sbtRecordPackHeader(miss_prog_group, h_miss_sbt) + global d_miss_sbt + d_miss_sbt = array_to_device_memory(h_miss_sbt) + + # + # hitgroup record + # + formats = [header_format] + itemsize = get_aligned_itemsize(formats, optix.SBT_RECORD_ALIGNMENT) + dtype = np.dtype( + {"names": ["header"], "formats": formats, "itemsize": itemsize, "align": True} + ) + h_hitgroup_sbt = np.array([(0)], dtype=dtype) + optix.sbtRecordPackHeader(hitgroup_prog_group, h_hitgroup_sbt) + global d_hitgroup_sbt + d_hitgroup_sbt = array_to_device_memory(h_hitgroup_sbt) + + sbt = optix.ShaderBindingTable() + sbt.raygenRecord = d_raygen_sbt.ptr + sbt.missRecordBase = d_miss_sbt.ptr + sbt.missRecordStrideInBytes = d_miss_sbt.mem.size + sbt.missRecordCount = 1 + sbt.hitgroupRecordBase = d_hitgroup_sbt.ptr + sbt.hitgroupRecordStrideInBytes = d_hitgroup_sbt.mem.size + sbt.hitgroupRecordCount = 1 + return sbt + + +def launch(pipeline, sbt, trav_handle): + print("Launching ... ") + + pix_bytes = pix_width * pix_height * 4 + + h_pix = np.zeros((pix_width, pix_height, 4), "B") + h_pix[0:pix_width, 0:pix_height] = [255, 128, 0, 255] + d_pix = cp.array(h_pix) + + params = [ + ("u8", "image", d_pix.data.ptr), + ("u4", "image_width", pix_width), + ("u4", "image_height", pix_height), + ("f4", "cam_eye_x", 0), + ("f4", "cam_eye_y", 0), + ("f4", "cam_eye_z", 2.0), + ("f4", "cam_U_x", 1.10457), + ("f4", "cam_U_y", 0), + ("f4", "cam_U_z", 0), + ("f4", "cam_V_x", 0), + ("f4", "cam_V_y", 0.828427), + ("f4", "cam_V_z", 0), + ("f4", "cam_W_x", 0), + ("f4", "cam_W_y", 0), + ("f4", "cam_W_z", -2.0), + ("u8", "trav_handle", trav_handle), + ] + + formats = [x[0] for x in params] + names = [x[1] for x in params] + values = [x[2] for x in params] + itemsize = get_aligned_itemsize(formats, 8) + params_dtype = np.dtype( + {"names": names, "formats": formats, "itemsize": itemsize, "align": True} + ) + h_params = np.array([tuple(values)], dtype=params_dtype) + d_params = array_to_device_memory(h_params) + + stream = cp.cuda.Stream() + optix.launch( + pipeline, + stream.ptr, + d_params.ptr, + h_params.dtype.itemsize, + sbt, + pix_width, + pix_height, + 1, # depth + ) + + stream.synchronize() + + h_pix = cp.asnumpy(d_pix) + return h_pix + + +# Numba compilation +# ----------------- + +# An equivalent to the compile_cuda function for Python kernels. The types of +# the arguments to the kernel must be provided, if there are any. + + +def compile_numba(f, sig=(), debug=False): + # Based on numba.cuda.compile_ptx. We don't just use + # compile_ptx_for_current_device because it generates a kernel with a + # mangled name. For proceeding beyond this prototype, an option should be + # added to compile_ptx in Numba to not mangle the function name. + + nvvm_options = { + "debug": debug, + "fastmath": False, + "opt": 0 if debug else 3, + } + + cres = numba_compile_cuda(f, None, sig, debug=debug, nvvm_options=nvvm_options) + fname = cres.fndesc.llvm_func_name + tgt = cres.target_context + filename = cres.type_annotation.filename + linenum = int(cres.type_annotation.linenum) + lib, kernel = tgt.prepare_cuda_kernel( + cres.library, cres.fndesc, debug, nvvm_options, filename, linenum + ) + cc = get_current_device().compute_capability + ptx = lib.get_asm_str(cc=cc) + + # Demangle name + mangled_name = kernel.name + original_name = cres.library.name + return ptx.replace(mangled_name, original_name) + + +# ------------------------------------------------------------------------------- +# +# User code / kernel - the following section is what we'd expect a user of +# PyOptiX to write. +# +# ------------------------------------------------------------------------------- + +# vec_math + +# Overload for Clamp +def clamp(x, a, b): + pass + + +@overload(clamp, target="cuda") +def jit_clamp(x, a, b): + if ( + isinstance(x, types.Float) + and isinstance(a, types.Float) + and isinstance(b, types.Float) + ): + + def clamp_float_impl(x, a, b): + return max(a, min(x, b)) + + return clamp_float_impl + elif ( + isinstance(x, Float3) + and isinstance(a, types.Float) + and isinstance(b, types.Float) + ): + + def clamp_float3_impl(x, a, b): + return make_float3(clamp(x.x, a, b), clamp(x.y, a, b), clamp(x.z, a, b)) + + return clamp_float3_impl + + +def dot(a, b): + pass + + +@overload(dot, target="cuda") +def jit_dot(a, b): + if isinstance(a, Float3) and isinstance(b, Float3): + def dot_float3_impl(a, b): + return a.x * b.x + a.y * b.y + a.z * b.z + + return dot_float3_impl + + +@cuda.jit(device=True) +def normalize(v): + invLen = float32(1.0) / math.sqrt(dot(v, v)) + return v * invLen + + +# Helpers + + +@cuda.jit(device=True) +def toSRGB(c): + # Use float32 for constants + invGamma = float32(1.0) / float32(2.4) + powed = make_float3( + math.pow(c.x, invGamma), math.pow(c.y, invGamma), math.pow(c.z, invGamma) + ) + return make_float3( + float32(12.92) * c.x + if c.x < float32(0.0031308) + else float32(1.055) * powed.x - float32(0.055), + float32(12.92) * c.y + if c.y < float32(0.0031308) + else float32(1.055) * powed.y - float32(0.055), + float32(12.92) * c.z + if c.z < float32(0.0031308) + else float32(1.055) * powed.z - float32(0.055), + ) + + +@cuda.jit(device=True) +def quantizeUnsigned8Bits(x): + x = clamp(x, float32(0.0), float32(1.0)) + N, Np1 = (1 << 8) - 1, 1 << 8 + return uint8(min(uint32(x * float32(Np1)), uint32(N))) + + +@cuda.jit(device=True) +def make_color(c): + srgb = toSRGB(clamp(c, float32(0.0), float32(1.0))) + return make_uchar4( + quantizeUnsigned8Bits(srgb.x), + quantizeUnsigned8Bits(srgb.y), + quantizeUnsigned8Bits(srgb.z), + uint8(255), + ) + + +# ray functions + + +@cuda.jit(device=True) +def setPayload(p): + optix.SetPayload_0(cuda.libdevice.float_as_int(p.x)) + optix.SetPayload_1(cuda.libdevice.float_as_int(p.y)) + optix.SetPayload_2(cuda.libdevice.float_as_int(p.z)) + + +@cuda.jit(device=True) +def computeRay(idx, dim): + U = params.cam_u + V = params.cam_v + W = params.cam_w + # Normalizing coordinates to [-1.0, 1.0] + d = float32(2.0) * make_float2( + float32(idx.x) / float32(dim.x), float32(idx.y) / float32(dim.y) + ) - float32(1.0) + + origin = params.cam_eye + direction = normalize(d.x * U + d.y * V + W) + return origin, direction + + +def __raygen__rg(): + # Lookup our location within the launch grid + idx = optix.GetLaunchIndex() + dim = optix.GetLaunchDimensions() + + # Map our launch idx to a screen location and create a ray from the camera + # location through the screen + ray_origin, ray_direction = computeRay(make_uint3(idx.x, idx.y, 0), dim) + + # Trace the ray against our scene hierarchy + payload_pack = optix.Trace( + params.handle, + ray_origin, + ray_direction, + float32(0.0), # Min intersection distance + float32(1e16), # Max intersection distance + float32(0.0), # rayTime -- used for motion blur + OptixVisibilityMask(255), # Specify always visible + # OptixRayFlags.OPTIX_RAY_FLAG_NONE, + uint32(OPTIX_RAY_FLAG_NONE), + uint32(0), # SBT offset -- See SBT discussion + uint32(1), # SBT stride -- See SBT discussion + uint32(0), # missSBTIndex -- See SBT discussion + ) + result = make_float3( + cuda.libdevice.int_as_float(payload_pack.p0), + cuda.libdevice.int_as_float(payload_pack.p1), + cuda.libdevice.int_as_float(payload_pack.p2), + ) + + # Record results in our output raster + params.image[idx.y * params.image_width + idx.x] = make_color(result) + + +def __miss__ms(): + miss_data = MissDataStruct(optix.GetSbtDataPointer()) + setPayload(miss_data.bg_color) + + +def __closesthit__ch(): + # When built-in triangle intersection is used, a number of fundamental + # attributes are provided by the OptiX API, indlucing barycentric coordinates. + barycentrics = optix.GetTriangleBarycentrics() + + setPayload(make_float3(barycentrics, float32(1.0))) + + +# ------------------------------------------------------------------------------- +# +# main +# +# ------------------------------------------------------------------------------- + + +def main(): + raygen_ptx = compile_numba(__raygen__rg) + miss_ptx = compile_numba(__miss__ms) + hitgroup_ptx = compile_numba(__closesthit__ch) + + # triangle_ptx = compile_cuda( "examples/triangle.cu" ) + + init_optix() + + ctx = create_ctx() + gas_handle, d_gas_output_buffer = create_accel(ctx) + pipeline_options = set_pipeline_options() + + raygen_module = create_module(ctx, pipeline_options, raygen_ptx) + miss_module = create_module(ctx, pipeline_options, miss_ptx) + hitgroup_module = create_module(ctx, pipeline_options, hitgroup_ptx) + + prog_groups = create_program_groups( + ctx, raygen_module, miss_module, hitgroup_module + ) + pipeline = create_pipeline(ctx, prog_groups, pipeline_options) + sbt = create_sbt(prog_groups) + pix = launch(pipeline, sbt, gas_handle) + + print("Total number of log messages: {}".format(logger.num_mssgs)) + + pix = pix.reshape((pix_height, pix_width, 4)) # PIL expects [ y, x ] resolution + img = ImageOps.flip(Image.fromarray(pix, "RGBA")) # PIL expects y = 0 at bottom + img.save("pyramid.png") + img.show() + + +if __name__ == "__main__": + main() From 691270e28dbaab47e11d78663b47304e3f70a3ca Mon Sep 17 00:00:00 2001 From: Michael Wang Date: Wed, 19 Jan 2022 20:22:26 -0800 Subject: [PATCH 13/25] Getting pyramid example working --- examples/pyramid.py | 130 +++++++++++++++++++++++++++++++++----------- 1 file changed, 99 insertions(+), 31 deletions(-) diff --git a/examples/pyramid.py b/examples/pyramid.py index 2848c84..1c94ebe 100755 --- a/examples/pyramid.py +++ b/examples/pyramid.py @@ -9,13 +9,26 @@ import numpy as np # Packing of structures in C-compatible format from numba import cuda, float32, int32, types, uint8, uint32 from numba.core.extending import overload -from numba.cuda.compiler import compile_cuda as numba_compile_cuda from numba.cuda import get_current_device - +from numba.cuda.compiler import compile_cuda as numba_compile_cuda +from numba_support import ( + OPTIX_RAY_FLAG_NONE, + Float3, + MissDataStruct, + OptixVisibilityMask, + float2, + float3, + make_float2, + make_float3, + make_uchar4, + make_uint3, + params, + uchar4, + uint3, +) from PIL import Image, ImageOps # Image IO import optix -from numba_support import float2, float3, uchar4, uint3, make_float2, make_float3, make_uchar4, make_uint3, params, OptixVisibilityMask, Float3, OPTIX_RAY_FLAG_NONE, MissDataStruct # ------------------------------------------------------------------------------- # @@ -130,17 +143,42 @@ def create_accel(ctx): ) global vertices - vertices = cp.array([-0.5, -0.5, 0.0, 0.5, -0.5, 0.0, 0.0, 0.5, 0.0], dtype="f4") + global indices + # fmt: off + vertices = cp.array( + [ + -1.0, -1.0, 0, + 1.0, -1.0, 0, + 1.0, 1.0, 0, + -1.0, 1.0, 0, + 0, 0, 2.0 + ], dtype="f4" + ) + indices = cp.array( + [ + 0, 1, 2, + 0, 2, 3, + 0, 1, 4, + 1, 2, 4, + 2, 3, 4, + 3, 0, 4 + ], dtype="uint32" + ) + # fmt: on + + pyramid_input_flags = [optix.GEOMETRY_FLAG_NONE] - triangle_input_flags = [optix.GEOMETRY_FLAG_NONE] - triangle_input = optix.BuildInputTriangleArray() - triangle_input.vertexFormat = optix.VERTEX_FORMAT_FLOAT3 - triangle_input.numVertices = len(vertices) - triangle_input.vertexBuffers = [vertices.data.ptr] - triangle_input.flags = triangle_input_flags - triangle_input.numSbtRecords = 1 + pyramid_input = optix.BuildInputTriangleArray() + pyramid_input.vertexFormat = optix.VERTEX_FORMAT_FLOAT3 + pyramid_input.numVertices = len(vertices) + pyramid_input.vertexBuffers = [vertices.data.ptr] + pyramid_input.indexFormat = optix.INDICES_FORMAT_UNSIGNED_INT3 + pyramid_input.numIndexTriplets = len(indices) // 3 + pyramid_input.indexBuffer = indices.data.ptr + pyramid_input.flags = pyramid_input_flags + pyramid_input.numSbtRecords = 1 - gas_buffer_sizes = ctx.accelComputeMemoryUsage([accel_options], [triangle_input]) + gas_buffer_sizes = ctx.accelComputeMemoryUsage([accel_options], [pyramid_input]) d_temp_buffer_gas = cp.cuda.alloc(gas_buffer_sizes.tempSizeInBytes) d_gas_output_buffer = cp.cuda.alloc(gas_buffer_sizes.outputSizeInBytes) @@ -148,7 +186,7 @@ def create_accel(ctx): gas_handle = ctx.accelBuild( 0, # CUDA stream [accel_options], - [triangle_input], + [pyramid_input], d_temp_buffer_gas.ptr, gas_buffer_sizes.tempSizeInBytes, d_gas_output_buffer.ptr, @@ -311,7 +349,7 @@ def create_sbt(prog_groups): return sbt -def launch(pipeline, sbt, trav_handle): +def launch(pipeline, sbt, trav_handle, cam): print("Launching ... ") pix_bytes = pix_width * pix_height * 4 @@ -320,22 +358,24 @@ def launch(pipeline, sbt, trav_handle): h_pix[0:pix_width, 0:pix_height] = [255, 128, 0, 255] d_pix = cp.array(h_pix) + cam_eye, cam_U, cam_V, cam_W = cam + params = [ ("u8", "image", d_pix.data.ptr), ("u4", "image_width", pix_width), ("u4", "image_height", pix_height), - ("f4", "cam_eye_x", 0), - ("f4", "cam_eye_y", 0), - ("f4", "cam_eye_z", 2.0), - ("f4", "cam_U_x", 1.10457), - ("f4", "cam_U_y", 0), - ("f4", "cam_U_z", 0), - ("f4", "cam_V_x", 0), - ("f4", "cam_V_y", 0.828427), - ("f4", "cam_V_z", 0), - ("f4", "cam_W_x", 0), - ("f4", "cam_W_y", 0), - ("f4", "cam_W_z", -2.0), + ("f4", "cam_eye_x", cam_eye[0]), + ("f4", "cam_eye_y", cam_eye[1]), + ("f4", "cam_eye_z", cam_eye[2]), + ("f4", "cam_U_x", cam_U[0]), + ("f4", "cam_U_y", cam_U[1]), + ("f4", "cam_U_z", cam_U[2]), + ("f4", "cam_V_x", cam_V[0]), + ("f4", "cam_V_y", cam_V[1]), + ("f4", "cam_V_z", cam_V[2]), + ("f4", "cam_W_x", cam_W[0]), + ("f4", "cam_W_y", cam_W[1]), + ("f4", "cam_W_z", cam_W[2]), ("u8", "trav_handle", trav_handle), ] @@ -448,6 +488,7 @@ def dot(a, b): @overload(dot, target="cuda") def jit_dot(a, b): if isinstance(a, Float3) and isinstance(b, Float3): + def dot_float3_impl(a, b): return a.x * b.x + a.y * b.y + a.z * b.z @@ -575,12 +616,12 @@ def __closesthit__ch(): # ------------------------------------------------------------------------------- # -# main +# render # # ------------------------------------------------------------------------------- -def main(): +def render(cam, t): raygen_ptx = compile_numba(__raygen__rg) miss_ptx = compile_numba(__miss__ms) hitgroup_ptx = compile_numba(__closesthit__ch) @@ -602,15 +643,42 @@ def main(): ) pipeline = create_pipeline(ctx, prog_groups, pipeline_options) sbt = create_sbt(prog_groups) - pix = launch(pipeline, sbt, gas_handle) + pix = launch(pipeline, sbt, gas_handle, cam) print("Total number of log messages: {}".format(logger.num_mssgs)) pix = pix.reshape((pix_height, pix_width, 4)) # PIL expects [ y, x ] resolution img = ImageOps.flip(Image.fromarray(pix, "RGBA")) # PIL expects y = 0 at bottom - img.save("pyramid.png") + img.save(f"output/pyramid_{t}.png") img.show() +def lookat(eye, at, up): + W = at - eye + Wnorm = np.linalg.norm(W) + if np.allclose(Wnorm, 0.0): + raise ValueError("Target too close to eye.") + W = W / Wnorm + U = np.cross(W, up) + U = U / np.linalg.norm(U) + V = np.cross(U, W) + V = V / np.linalg.norm(V) + return U, V, W + + +def polar2cart(r, theta): + return (r * math.cos(theta), r * math.sin(theta)) + + if __name__ == "__main__": - main() + + for t in range(0, 361, 6): + rad = math.radians(t) + cart = polar2cart(1.5, rad) + eye = np.array([*cart, 2.5]) + at = np.array([0.0, 0.0, 0.0]) + up = np.array([0.0, 0.0, 1.0]) + U, V, W = lookat(eye, at, up) + + # print(eye, U, V, W) + render((eye, U, V, W), t) From 9c825c0381f79a78ed4175f24a8c127d7e31b1eb Mon Sep 17 00:00:00 2001 From: Michael Wang Date: Thu, 20 Jan 2022 11:08:22 -0800 Subject: [PATCH 14/25] Consolidate triangle example to numba_support too --- examples/triangle.py | 988 +------------------------------------------ 1 file changed, 17 insertions(+), 971 deletions(-) diff --git a/examples/triangle.py b/examples/triangle.py index aa73976..8a6169c 100755 --- a/examples/triangle.py +++ b/examples/triangle.py @@ -4,985 +4,31 @@ import ctypes # C interop helpers import math from enum import Enum -from operator import add, mul, sub import cupy as cp # CUDA bindings import numpy as np # Packing of structures in C-compatible format -from llvmlite import ir from numba import cuda, float32, int32, types, uint8, uint32 -from numba.core import cgutils -from numba.core.extending import ( - make_attribute_wrapper, - models, - overload, - register_model, - type_callable, - typeof_impl, -) -from numba.core.imputils import lower_constant -from numba.core.typing.templates import ( - AttributeTemplate, - ConcreteTemplate, - signature, -) +from numba.core.extending import overload from numba.cuda import get_current_device from numba.cuda.compiler import compile_cuda as numba_compile_cuda -from numba.cuda.cudadecl import register, register_attr, register_global -from numba.cuda.cudadrv import nvvm -from numba.cuda.cudaimpl import lower -from numba.cuda.types import dim3 -from PIL import Image, ImageOps # Image IO - -import optix - -# ------------------------------------------------------------------------------- -# -# Numba extensions for general CUDA / OptiX support -# -# ------------------------------------------------------------------------------- - -# UChar4 -# ------ - -# Numba presently doesn't implement the UChar4 type (which is fairly standard -# CUDA) so we provide some minimal support for it here. - - -# Prototype a function to construct a uchar4 - - -def make_uchar4(x, y, z, w): - pass - - -# UChar4 typing - - -class UChar4(types.Type): - def __init__(self): - super().__init__(name="UChar4") - - -uchar4 = UChar4() - - -@register -class MakeUChar4(ConcreteTemplate): - key = make_uchar4 - cases = [signature(uchar4, types.uchar, types.uchar, types.uchar, types.uchar)] - - -register_global(make_uchar4, types.Function(MakeUChar4)) - - -# UChar4 data model - - -@register_model(UChar4) -class UChar4Model(models.StructModel): - def __init__(self, dmm, fe_type): - members = [ - ("x", types.uchar), - ("y", types.uchar), - ("z", types.uchar), - ("w", types.uchar), - ] - super().__init__(dmm, fe_type, members) - - -make_attribute_wrapper(UChar4, "x", "x") -make_attribute_wrapper(UChar4, "y", "y") -make_attribute_wrapper(UChar4, "z", "z") -make_attribute_wrapper(UChar4, "w", "w") - - -# UChar4 lowering - - -@lower(make_uchar4, types.uchar, types.uchar, types.uchar, types.uchar) -def lower_make_uchar4(context, builder, sig, args): - uc4 = cgutils.create_struct_proxy(uchar4)(context, builder) - uc4.x = args[0] - uc4.y = args[1] - uc4.z = args[2] - uc4.w = args[3] - return uc4._getvalue() - - -# float3 -# ------ - -# Float3 typing - - -class Float3(types.Type): - def __init__(self): - super().__init__(name="Float3") - - -float3 = Float3() - - -# Float2 typing (forward declaration) - - -class Float2(types.Type): - def __init__(self): - super().__init__(name="Float2") - - -float2 = Float2() - - -# Float3 data model - - -@register_model(Float3) -class Float3Model(models.StructModel): - def __init__(self, dmm, fe_type): - members = [ - ("x", types.float32), - ("y", types.float32), - ("z", types.float32), - ] - super().__init__(dmm, fe_type, members) - - -make_attribute_wrapper(Float3, "x", "x") -make_attribute_wrapper(Float3, "y", "y") -make_attribute_wrapper(Float3, "z", "z") - - -def lower_float3_ops(op): - class Float3_op_template(ConcreteTemplate): - key = op - cases = [ - signature(float3, float3, float3), - signature(float3, types.float32, float3), - signature(float3, float3, types.float32), - ] - - def float3_op_impl(context, builder, sig, args): - def op_attr(lhs, rhs, res, attr): - setattr( - res, - attr, - context.compile_internal( - builder, - lambda x, y: op(x, y), - signature(types.float32, types.float32, types.float32), - (getattr(lhs, attr), getattr(rhs, attr)), - ), - ) - - arg0, arg1 = args - - if isinstance(sig.args[0], types.Float): - lf3 = cgutils.create_struct_proxy(float3)(context, builder) - lf3.x = arg0 - lf3.y = arg0 - lf3.z = arg0 - else: - lf3 = cgutils.create_struct_proxy(float3)(context, builder, value=args[0]) - - if isinstance(sig.args[1], types.Float): - rf3 = cgutils.create_struct_proxy(float3)(context, builder) - rf3.x = arg1 - rf3.y = arg1 - rf3.z = arg1 - else: - rf3 = cgutils.create_struct_proxy(float3)(context, builder, value=args[1]) - - res = cgutils.create_struct_proxy(float3)(context, builder) - op_attr(lf3, rf3, res, "x") - op_attr(lf3, rf3, res, "y") - op_attr(lf3, rf3, res, "z") - return res._getvalue() - - register_global(op, types.Function(Float3_op_template)) - lower(op, float3, float3)(float3_op_impl) - lower(op, types.float32, float3)(float3_op_impl) - lower(op, float3, types.float32)(float3_op_impl) - - -lower_float3_ops(mul) -lower_float3_ops(add) - - -@lower(add, float32, float3) -def add_float32_float3_impl(context, builder, sig, args): - s = args[0] - rhs = cgutils.create_struct_proxy(float3)(context, builder, args[1]) - res = cgutils.create_struct_proxy(float3)(context, builder) - res.x = builder.fadd(s, rhs.x) - res.y = builder.fadd(s, rhs.y) - res.z = builder.fadd(s, rhs.z) - return res._getvalue() - - -@lower(add, float3, float32) -def add_float3_float32_impl(context, builder, sig, args): - lhs = cgutils.create_struct_proxy(float3)(context, builder, args[0]) - s = args[1] - res = cgutils.create_struct_proxy(float3)(context, builder) - res.x = builder.fadd(lhs.x, s) - res.y = builder.fadd(lhs.y, s) - res.z = builder.fadd(lhs.z, s) - return res._getvalue() - - -# Prototype a function to construct a float3 - - -def make_float3(x, y, z): - pass - - -@register -class MakeFloat3(ConcreteTemplate): - key = make_float3 - cases = [ - signature(float3, types.float32, types.float32, types.float32), - signature(float3, float2, types.float32), - ] - - -register_global(make_float3, types.Function(MakeFloat3)) - - -# make_float3 lowering - - -@lower(make_float3, types.float32, types.float32, types.float32) -def lower_make_float3(context, builder, sig, args): - f3 = cgutils.create_struct_proxy(float3)(context, builder) - f3.x = args[0] - f3.y = args[1] - f3.z = args[2] - return f3._getvalue() - - -@lower(make_float3, float2, types.float32) -def lower_make_float3(context, builder, sig, args): - f2 = cgutils.create_struct_proxy(float2)(context, builder, args[0]) - f3 = cgutils.create_struct_proxy(float3)(context, builder) - f3.x = f2.x - f3.y = f2.y - f3.z = args[1] - return f3._getvalue() - - -# float2 -# ------ - - -# Float2 data model - - -@register_model(Float2) -class Float2Model(models.StructModel): - def __init__(self, dmm, fe_type): - members = [ - ("x", types.float32), - ("y", types.float32), - ] - super().__init__(dmm, fe_type, members) - - -make_attribute_wrapper(Float2, "x", "x") -make_attribute_wrapper(Float2, "y", "y") - - -def lower_float2_ops(op): - class Float2_op_template(ConcreteTemplate): - key = op - cases = [ - signature(float2, float2, float2), - signature(float2, types.float32, float2), - signature(float2, float2, types.float32), - ] - - def float2_op_impl(context, builder, sig, args): - def op_attr(lhs, rhs, res, attr): - setattr( - res, - attr, - context.compile_internal( - builder, - lambda x, y: op(x, y), - signature(types.float32, types.float32, types.float32), - (getattr(lhs, attr), getattr(rhs, attr)), - ), - ) - - arg0, arg1 = args - - if isinstance(sig.args[0], types.Float): - lf2 = cgutils.create_struct_proxy(float2)(context, builder) - lf2.x = arg0 - lf2.y = arg0 - else: - lf2 = cgutils.create_struct_proxy(float2)(context, builder, value=args[0]) - - if isinstance(sig.args[1], types.Float): - rf2 = cgutils.create_struct_proxy(float2)(context, builder) - rf2.x = arg1 - rf2.y = arg1 - else: - rf2 = cgutils.create_struct_proxy(float2)(context, builder, value=args[1]) - - res = cgutils.create_struct_proxy(float2)(context, builder) - op_attr(lf2, rf2, res, "x") - op_attr(lf2, rf2, res, "y") - return res._getvalue() - - register_global(op, types.Function(Float2_op_template)) - lower(op, float2, float2)(float2_op_impl) - lower(op, types.Float, float2)(float2_op_impl) - lower(op, float2, types.Float)(float2_op_impl) - - -lower_float2_ops(mul) -lower_float2_ops(sub) - - -# Prototype a function to construct a float2 - - -def make_float2(x, y): - pass - - -@register -class MakeFloat2(ConcreteTemplate): - key = make_float2 - cases = [signature(float2, types.float32, types.float32)] - - -register_global(make_float2, types.Function(MakeFloat2)) - - -# make_float2 lowering - - -@lower(make_float2, types.float32, types.float32) -def lower_make_float2(context, builder, sig, args): - f2 = cgutils.create_struct_proxy(float2)(context, builder) - f2.x = args[0] - f2.y = args[1] - return f2._getvalue() - - -# uint3 -# ------ - - -class UInt3(types.Type): - def __init__(self): - super().__init__(name="UInt3") - - -uint3 = UInt3() - - -# UInt3 data model - - -@register_model(UInt3) -class UInt3Model(models.StructModel): - def __init__(self, dmm, fe_type): - members = [ - ("x", types.uint32), - ("y", types.uint32), - ("z", types.uint32), - ] - super().__init__(dmm, fe_type, members) - - -make_attribute_wrapper(UInt3, "x", "x") -make_attribute_wrapper(UInt3, "y", "y") -make_attribute_wrapper(UInt3, "z", "z") - - -# Prototype a function to construct a uint3 - - -def make_uint3(x, y, z): - pass - - -@register -class MakeUInt3(ConcreteTemplate): - key = make_uint3 - cases = [signature(uint3, types.uint32, types.uint32, types.uint32)] - - -register_global(make_uint3, types.Function(MakeUInt3)) - - -# make_uint3 lowering - - -@lower(make_uint3, types.uint32, types.uint32, types.uint32) -def lower_make_uint3(context, builder, sig, args): - # u4 = uint32 - u4_3 = cgutils.create_struct_proxy(uint3)(context, builder) - u4_3.x = args[0] - u4_3.y = args[1] - u4_3.z = args[2] - return u4_3._getvalue() - - -# Temporary Payload Parameter Pack -class PayloadPack(types.Type): - def __init__(self): - super().__init__(name="PayloadPack") - - -payload_pack = PayloadPack() - - -# UInt3 data model - - -@register_model(PayloadPack) -class PayloadPackModel(models.StructModel): - def __init__(self, dmm, fe_type): - members = [ - ("p0", types.uint32), - ("p1", types.uint32), - ("p2", types.uint32), - ] - super().__init__(dmm, fe_type, members) - - -make_attribute_wrapper(PayloadPack, "p0", "p0") -make_attribute_wrapper(PayloadPack, "p1", "p1") -make_attribute_wrapper(PayloadPack, "p2", "p2") - -# OptiX typedefs and enums -# ----------- - -OptixVisibilityMask = types.Integer("OptixVisibilityMask", bitwidth=32, signed=False) -OptixTraversableHandle = types.Integer( - "OptixTraversableHandle", bitwidth=64, signed=False -) - - -OPTIX_RAY_FLAG_NONE = 0 -# class OptixRayFlags(Enum): -# OPTIX_RAY_FLAG_NONE = 0 -# OPTIX_RAY_FLAG_DISABLE_ANYHIT = 1 << 0 -# OPTIX_RAY_FLAG_ENFORCE_ANYHIT = 1 << 1 -# OPTIX_RAY_FLAG_TERMINATE_ON_FIRST_HIT = 1 << 2 -# OPTIX_RAY_FLAG_DISABLE_CLOSESTHIT = 1 << 3, -# OPTIX_RAY_FLAG_CULL_BACK_FACING_TRIANGLES = 1 << 4 -# OPTIX_RAY_FLAG_CULL_FRONT_FACING_TRIANGLES = 1 << 5 -# OPTIX_RAY_FLAG_CULL_DISABLED_ANYHIT = 1 << 6 -# OPTIX_RAY_FLAG_CULL_ENFORCED_ANYHIT = 1 << 7 - - -# OptiX types -# ----------- - -# Typing for OptiX types - - -class SbtDataPointer(types.RawPointer): - def __init__(self): - super().__init__(name="SbtDataPointer") - - -sbt_data_pointer = SbtDataPointer() - - -# Models for OptiX types - - -@register_model(SbtDataPointer) -class SbtDataPointerModel(models.OpaqueModel): - pass - - -# Params -# ------------ - -# Structures as declared in triangle.h - - -class ParamsStruct: - fields = ( - ("image", "uchar4*"), - ("image_width", "unsigned int"), - ("image_height", "unsigned int"), - ("cam_eye", "float3"), - ("cam_u", "float3"), - ("cam_v", "float3"), - ("cam_w", "float3"), - ("handle", "OptixTraversableHandle"), - ) - - -# "Declare" a global called params -params = ParamsStruct() - - -class Params(types.Type): - def __init__(self): - super().__init__(name="ParamsType") - - -params_type = Params() - - -# ParamsStruct data model - - -@register_model(Params) -class ParamsModel(models.StructModel): - def __init__(self, dmm, fe_type): - members = [ - ("image", types.CPointer(uchar4)), - ("image_width", types.uint32), - ("image_height", types.uint32), - ("cam_eye", float3), - ("cam_u", float3), - ("cam_v", float3), - ("cam_w", float3), - ("handle", OptixTraversableHandle), - ] - super().__init__(dmm, fe_type, members) - - -make_attribute_wrapper(Params, "image", "image") -make_attribute_wrapper(Params, "image_width", "image_width") -make_attribute_wrapper(Params, "image_height", "image_height") -make_attribute_wrapper(Params, "cam_eye", "cam_eye") -make_attribute_wrapper(Params, "cam_u", "cam_u") -make_attribute_wrapper(Params, "cam_v", "cam_v") -make_attribute_wrapper(Params, "cam_w", "cam_w") -make_attribute_wrapper(Params, "handle", "handle") - - -@typeof_impl.register(ParamsStruct) -def typeof_params(val, c): - return params_type - - -# ParamsStruct lowering -# The below makes 'param' a global variable, accessible from any user defined -# kernels. - - -@lower_constant(Params) -def constant_params(context, builder, ty, pyval): - try: - gvar = builder.module.get_global("params") - except KeyError: - llty = context.get_value_type(ty) - gvar = cgutils.add_global_variable( - builder.module, llty, "params", addrspace=nvvm.ADDRSPACE_CONSTANT - ) - gvar.linkage = "external" - gvar.global_constant = True - - return builder.load(gvar) - - -# MissData -# ------------ - -# Structures as declared in triangle.h -class MissDataStruct: - fields = ("bg_color", "float3") - - -MissData = MissDataStruct() - - -class MissData(types.Type): - def __init__(self): - super().__init__(name="MissDataType") - - -miss_data_type = MissData() - - -@register_model(MissData) -class MissDataModel(models.StructModel): - def __init__(self, dmm, fe_type): - members = [ - ("bg_color", float3), - ] - super().__init__(dmm, fe_type, members) - - -make_attribute_wrapper(MissData, "bg_color", "bg_color") - - -@typeof_impl.register(MissDataStruct) -def typeof_miss_data(val, c): - return miss_data_type - - -# MissData Constructor -@type_callable(MissDataStruct) -def type_miss_data_struct(context): - def typer(sbt_data_pointer): - if isinstance(sbt_data_pointer, SbtDataPointer): - return miss_data_type - - return typer - - -@lower(MissDataStruct, sbt_data_pointer) -def lower_miss_data_ctor(context, builder, sig, args): - # Anyway to err if this ctor is not called inside __miss__* program? - # TODO: Optimize - ptr = args[0] - ptr = builder.bitcast(ptr, context.get_value_type(miss_data_type).as_pointer()) - - bg_color_ptr = cgutils.gep_inbounds(builder, ptr, 0, 0) - - xptr = cgutils.gep_inbounds(builder, bg_color_ptr, 0, 0) - yptr = cgutils.gep_inbounds(builder, bg_color_ptr, 0, 1) - zptr = cgutils.gep_inbounds(builder, bg_color_ptr, 0, 2) - - output_miss_data = cgutils.create_struct_proxy(miss_data_type)(context, builder) - output_bg_color_ptr = cgutils.gep_inbounds( - builder, output_miss_data._getpointer(), 0, 0 - ) - output_bg_color_x_ptr = cgutils.gep_inbounds(builder, output_bg_color_ptr, 0, 0) - output_bg_color_y_ptr = cgutils.gep_inbounds(builder, output_bg_color_ptr, 0, 1) - output_bg_color_z_ptr = cgutils.gep_inbounds(builder, output_bg_color_ptr, 0, 2) - - x = builder.load(xptr) - y = builder.load(yptr) - z = builder.load(zptr) - - builder.store(x, output_bg_color_x_ptr) - builder.store(y, output_bg_color_y_ptr) - builder.store(z, output_bg_color_z_ptr) - - # Doesn't seem to do what's expected? - # miss_data.bg_color.x = builder.load(xptr) - # miss_data.bg_color.y = builder.load(yptr) - # miss_data.bg_color.z = builder.load(zptr) - return output_miss_data._getvalue() - - -# OptiX functions -# --------------- - -# Here we "prototype" the OptiX functions that the user will call in their -# kernels, so that Numba has something to refer to when compiling the kernel. - - -def _optix_GetLaunchIndex(): - pass - - -def _optix_GetLaunchDimensions(): - pass - - -def _optix_GetSbtDataPointer(): - pass - - -def _optix_SetPayload_0(): - pass - - -def _optix_SetPayload_1(): - pass - - -def _optix_SetPayload_2(): - pass - - -def _optix_GetTriangleBarycentrics(): - pass - - -def _optix_Trace(): - pass - - -# Monkey-patch the functions into the optix module, so the user can write -# optix.GetLaunchIndex etc., for symmetry with the rest of the API implemented -# in PyOptiX. - -optix.GetLaunchIndex = _optix_GetLaunchIndex -optix.GetLaunchDimensions = _optix_GetLaunchDimensions -optix.GetSbtDataPointer = _optix_GetSbtDataPointer -optix.GetTriangleBarycentrics = _optix_GetTriangleBarycentrics -optix.SetPayload_0 = _optix_SetPayload_0 -optix.SetPayload_1 = _optix_SetPayload_1 -optix.SetPayload_2 = _optix_SetPayload_2 - -optix.Trace = _optix_Trace - - -# OptiX function typing - - -@register -class OptixGetLaunchIndex(ConcreteTemplate): - key = optix.GetLaunchIndex - cases = [signature(dim3)] - - -@register -class OptixGetLaunchDimensions(ConcreteTemplate): - key = optix.GetLaunchDimensions - cases = [signature(dim3)] - - -@register -class OptixGetSbtDataPointer(ConcreteTemplate): - key = optix.GetSbtDataPointer - cases = [signature(sbt_data_pointer)] - - -def registerSetPayload(reg): - class OptixSetPayloadReg(ConcreteTemplate): - key = getattr(optix, "SetPayload_" + str(reg)) - cases = [signature(types.void, uint32)] - - register(OptixSetPayloadReg) - return OptixSetPayloadReg - - -OptixSetPayload_0 = registerSetPayload(0) -OptixSetPayload_1 = registerSetPayload(1) -OptixSetPayload_2 = registerSetPayload(2) - - -@register -class OptixGetTriangleBarycentrics(ConcreteTemplate): - key = optix.GetTriangleBarycentrics - cases = [signature(float2)] - - -@register -class OptixTrace(ConcreteTemplate): - key = optix.Trace - cases = [ - signature( - payload_pack, - OptixTraversableHandle, - float3, - float3, - float32, - float32, - float32, - OptixVisibilityMask, - uint32, - uint32, - uint32, - uint32, - ) - ] - - -@register_attr -class OptixModuleTemplate(AttributeTemplate): - key = types.Module(optix) - - def resolve_GetLaunchIndex(self, mod): - return types.Function(OptixGetLaunchIndex) - - def resolve_GetLaunchDimensions(self, mod): - return types.Function(OptixGetLaunchDimensions) - - def resolve_GetSbtDataPointer(self, mod): - return types.Function(OptixGetSbtDataPointer) - - def resolve_SetPayload_0(self, mod): - return types.Function(OptixSetPayload_0) - - def resolve_SetPayload_1(self, mod): - return types.Function(OptixSetPayload_1) - - def resolve_SetPayload_2(self, mod): - return types.Function(OptixSetPayload_2) - - def resolve_GetTriangleBarycentrics(self, mod): - return types.Function(OptixGetTriangleBarycentrics) - - def resolve_Trace(self, mod): - return types.Function(OptixTrace) - - -# OptiX function lowering - - -@lower(optix.GetLaunchIndex) -def lower_optix_getLaunchIndex(context, builder, sig, args): - def get_launch_index(axis): - asm = ir.InlineAsm( - ir.FunctionType(ir.IntType(32), []), - f"call ($0), _optix_get_launch_index_{axis}, ();", - "=r", - ) - return builder.call(asm, []) - - index = cgutils.create_struct_proxy(dim3)(context, builder) - index.x = get_launch_index("x") - index.y = get_launch_index("y") - index.z = get_launch_index("z") - return index._getvalue() - - -@lower(optix.GetLaunchDimensions) -def lower_optix_getLaunchDimensions(context, builder, sig, args): - def get_launch_dimensions(axis): - asm = ir.InlineAsm( - ir.FunctionType(ir.IntType(32), []), - f"call ($0), _optix_get_launch_dimension_{axis}, ();", - "=r", - ) - return builder.call(asm, []) - - index = cgutils.create_struct_proxy(dim3)(context, builder) - index.x = get_launch_dimensions("x") - index.y = get_launch_dimensions("y") - index.z = get_launch_dimensions("z") - return index._getvalue() - - -@lower(optix.GetSbtDataPointer) -def lower_optix_getSbtDataPointer(context, builder, sig, args): - asm = ir.InlineAsm( - ir.FunctionType(ir.IntType(64), []), - "call ($0), _optix_get_sbt_data_ptr_64, ();", - "=l", - ) - ptr = builder.call(asm, []) - ptr = builder.inttoptr(ptr, ir.IntType(8).as_pointer()) - return ptr - - -def lower_optix_SetPayloadReg(reg): - def lower_optix_SetPayload_impl(context, builder, sig, args): - asm = ir.InlineAsm( - ir.FunctionType(ir.VoidType(), [ir.IntType(32), ir.IntType(32)]), - f"call _optix_set_payload, ($0, $1);", - "r,r", - ) - builder.call(asm, [context.get_constant(types.int32, reg), args[0]]) - - lower(getattr(optix, f"SetPayload_{reg}"), uint32)(lower_optix_SetPayload_impl) - - -lower_optix_SetPayloadReg(0) -lower_optix_SetPayloadReg(1) -lower_optix_SetPayloadReg(2) - - -@lower(optix.GetTriangleBarycentrics) -def lower_optix_getTriangleBarycentrics(context, builder, sig, args): - f2 = cgutils.create_struct_proxy(float2)(context, builder) - retty = ir.LiteralStructType([ir.FloatType(), ir.FloatType()]) - asm = ir.InlineAsm( - ir.FunctionType(retty, []), - "call ($0, $1), _optix_get_triangle_barycentrics, ();", - "=f,=f", - ) - ret = builder.call(asm, []) - f2.x = builder.extract_value(ret, 0) - f2.y = builder.extract_value(ret, 1) - return f2._getvalue() - - -@lower( - optix.Trace, - OptixTraversableHandle, - float3, - float3, - float32, - float32, - float32, +from numba_support import ( + OPTIX_RAY_FLAG_NONE, + Float3, + MissDataStruct, OptixVisibilityMask, - uint32, - uint32, - uint32, - uint32, + float2, + float3, + make_float2, + make_float3, + make_uchar4, + make_uint3, + params, + uchar4, + uint3, ) -def lower_optix_Trace(context, builder, sig, args): - # Only implements the version that accepts 3 payload registers - # TODO: Optimize returns - - ( - handle, - rayOrigin, - rayDirection, - tmin, - tmax, - rayTime, - visibilityMask, - rayFlags, - SBToffset, - SBTstride, - missSBTIndex, - ) = args - - rayOrigin = cgutils.create_struct_proxy(float3)(context, builder, rayOrigin) - rayDirection = cgutils.create_struct_proxy(float3)(context, builder, rayDirection) - output = cgutils.create_struct_proxy(payload_pack)(context, builder) - - ox, oy, oz = rayOrigin.x, rayOrigin.y, rayOrigin.z - dx, dy, dz = rayDirection.x, rayDirection.y, rayDirection.z - - n_payload_registers = 3 - n_stub_output_operands = 32 - n_payload_registers - outputs = [output.p0, output.p1, output.p2] + [ - builder.load(builder.alloca(ir.IntType(32))) - for _ in range(n_stub_output_operands) - ] - - retty = ir.LiteralStructType([ir.IntType(32)] * 32) - asm = ir.InlineAsm( - ir.FunctionType(retty, []), - "call " - "($0,$1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,$17,$18,$19,$20,$21,$22,$23,$24,$25,$26,$27,$28,$29," - "$30,$31)," - "_optix_trace_typed_32," - "($32,$33,$34,$35,$36,$37,$38,$39,$40,$41,$42,$43,$44,$45,$46,$47,$48,$49,$50,$51,$52,$53,$54,$55,$56,$57,$58,$59," - "$60,$61,$62,$63,$64,$65,$66,$67,$68,$69,$70,$71,$72,$73,$74,$75,$76,$77,$78,$79,$80);", - "=r," * 32 + "r,l,f,f,f,f,f,f,f,f,f,r,r,r,r,r,r," + "r," * 31 + "r", - side_effect=True, - ) - - zero = context.get_constant(types.int32, 0) - c_payload_registers = context.get_constant(types.int32, n_payload_registers) - args = [ - zero, - handle, - ox, - oy, - oz, - dx, - dy, - dz, - tmin, - tmax, - rayTime, - visibilityMask, - rayFlags, - SBToffset, - SBTstride, - missSBTIndex, - c_payload_registers, - ] + outputs - ret = builder.call(asm, args) - output.p0 = builder.extract_value(ret, 0) - output.p1 = builder.extract_value(ret, 1) - output.p2 = builder.extract_value(ret, 2) - return output._getvalue() +from PIL import Image, ImageOps # Image IO +import optix # ------------------------------------------------------------------------------- # @@ -1576,7 +622,7 @@ def main(): pix = pix.reshape((pix_height, pix_width, 4)) # PIL expects [ y, x ] resolution img = ImageOps.flip(Image.fromarray(pix, "RGBA")) # PIL expects y = 0 at bottom - img.save("my.png") + img.save("triangle.png") img.show() From d69175a6f129c4903c09018172490b408e01f555 Mon Sep 17 00:00:00 2001 From: Michael Wang Date: Thu, 20 Jan 2022 11:09:51 -0800 Subject: [PATCH 15/25] Removing pre-commit files before syncing with upstream --- .flake8 | 19 ------------------- .gitignore | 22 ---------------------- .pre-commit-config.yaml | 33 --------------------------------- setup.cfg | 19 ------------------- 4 files changed, 93 deletions(-) delete mode 100644 .flake8 delete mode 100644 .gitignore delete mode 100644 .pre-commit-config.yaml delete mode 100644 setup.cfg diff --git a/.flake8 b/.flake8 deleted file mode 100644 index f2d388a..0000000 --- a/.flake8 +++ /dev/null @@ -1,19 +0,0 @@ -[flake8] -exclude = __init__.py -ignore = - # line break before binary operator - W503, - # whitespace before : - E203 - -[pydocstyle] -match = .*\.py -match-dir = examples -# In addition to numpy style, we additionally ignore: -add-ignore = - # magic methods - D105, - # no docstring in __init__ - D107, - # newlines before docstrings - D204 diff --git a/.gitignore b/.gitignore deleted file mode 100644 index c958bba..0000000 --- a/.gitignore +++ /dev/null @@ -1,22 +0,0 @@ -## Common -__pycache__ -*.py[cod] -*$py.class -*.a -*.o -*.so -*.dylib -.cache -.vscode -*.swp -*.pytest_cache -DartConfiguration.tcl -.DS_Store -*.manifest -*.spec -.nfs* -.clangd - -## build files -build/ -*.ptx diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml deleted file mode 100644 index 15b1a1b..0000000 --- a/.pre-commit-config.yaml +++ /dev/null @@ -1,33 +0,0 @@ -repos: - - repo: https://github.com/PyCQA/isort - rev: 5.6.4 - hooks: - - id: isort - alias: isort - name: isort - args: ["--settings-path=setup.cfg"] - files: examples/.* - exclude: __init__.py$ - types: [text] - types_or: [python] - - repo: https://github.com/psf/black - rev: 19.10b0 - hooks: - - id: black - files: examples/.* - - repo: https://github.com/PyCQA/flake8 - rev: 3.8.3 - hooks: - - id: flake8 - alias: flake8 - name: flake8 - args: ["--config=.flake8"] - files: python/.*\.py$ - # - repo: https://github.com/PyCQA/pydocstyle - # rev: 6.1.1 - # hooks: - # - id: pydocstyle - # args: ["--config=.flake8"] - -default_language_version: - python: python3 diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index d839db8..0000000 --- a/setup.cfg +++ /dev/null @@ -1,19 +0,0 @@ -[tool.black] -line-length = 79 -target-version = ["py36"] - -[isort] -line_length=79 -multi_line_output=3 -include_trailing_comma=True -force_grid_wrap=0 -combine_as_imports=True -order_by_type=True -skip= - .eggs - .git - .hg - .mypy_cache - .tox - .venv - __init__.py \ No newline at end of file From 74a30e224fd9d2a62eb9706e4fbab360271e1a1c Mon Sep 17 00:00:00 2001 From: Michael Wang Date: Fri, 21 Jan 2022 12:29:34 -0800 Subject: [PATCH 16/25] Removing pyramid example --- examples/pyramid.py | 684 -------------------------------------------- 1 file changed, 684 deletions(-) delete mode 100755 examples/pyramid.py diff --git a/examples/pyramid.py b/examples/pyramid.py deleted file mode 100755 index 1c94ebe..0000000 --- a/examples/pyramid.py +++ /dev/null @@ -1,684 +0,0 @@ -#!/usr/bin/env python3 - - -import ctypes # C interop helpers -import math -from enum import Enum - -import cupy as cp # CUDA bindings -import numpy as np # Packing of structures in C-compatible format -from numba import cuda, float32, int32, types, uint8, uint32 -from numba.core.extending import overload -from numba.cuda import get_current_device -from numba.cuda.compiler import compile_cuda as numba_compile_cuda -from numba_support import ( - OPTIX_RAY_FLAG_NONE, - Float3, - MissDataStruct, - OptixVisibilityMask, - float2, - float3, - make_float2, - make_float3, - make_uchar4, - make_uint3, - params, - uchar4, - uint3, -) -from PIL import Image, ImageOps # Image IO - -import optix - -# ------------------------------------------------------------------------------- -# -# Util -# -# ------------------------------------------------------------------------------- -pix_width = 1024 -pix_height = 768 - - -class Logger: - def __init__(self): - self.num_mssgs = 0 - - def __call__(self, level, tag, mssg): - print("[{:>2}][{:>12}]: {}".format(level, tag, mssg)) - self.num_mssgs += 1 - - -def log_callback(level, tag, mssg): - print("[{:>2}][{:>12}]: {}".format(level, tag, mssg)) - - -def round_up(val, mult_of): - return val if val % mult_of == 0 else val + mult_of - val % mult_of - - -def get_aligned_itemsize(formats, alignment): - names = [] - for i in range(len(formats)): - names.append("x" + str(i)) - - temp_dtype = np.dtype({"names": names, "formats": formats, "align": True}) - return round_up(temp_dtype.itemsize, alignment) - - -def array_to_device_memory(numpy_array, stream=cp.cuda.Stream()): - - byte_size = numpy_array.size * numpy_array.dtype.itemsize - - h_ptr = ctypes.c_void_p(numpy_array.ctypes.data) - d_mem = cp.cuda.memory.alloc(byte_size) - d_mem.copy_from_async(h_ptr, byte_size, stream) - return d_mem - - -def compile_cuda(cuda_file): - with open(cuda_file, "rb") as f: - src = f.read() - from pynvrtc.compiler import Program - - prog = Program(src.decode(), cuda_file) - ptx = prog.compile( - [ - "-use_fast_math", - "-lineinfo", - "-default-device", - "-std=c++11", - "-rdc", - "true", - #'-IC:\\ProgramData\\NVIDIA Corporation\OptiX SDK 7.2.0\include', - #'-IC:\\Program Files\\NVIDIA GPU Computing Toolkit\CUDA\\v11.1\include' - "-I/usr/local/cuda/include", - f"-I{optix.include_path}", - ] - ) - return ptx - - -# ------------------------------------------------------------------------------- -# -# Optix setup -# -# ------------------------------------------------------------------------------- - - -def init_optix(): - print("Initializing cuda ...") - cp.cuda.runtime.free(0) - - print("Initializing optix ...") - optix.init() - - -def create_ctx(): - print("Creating optix device context ...") - - # Note that log callback data is no longer needed. We can - # instead send a callable class instance as the log-function - # which stores any data needed - global logger - logger = Logger() - - # OptiX param struct fields can be set with optional - # keyword constructor arguments. - ctx_options = optix.DeviceContextOptions( - logCallbackFunction=logger, logCallbackLevel=4 - ) - - # They can also be set and queried as properties on the struct - ctx_options.validationMode = optix.DEVICE_CONTEXT_VALIDATION_MODE_ALL - - cu_ctx = 0 - return optix.deviceContextCreate(cu_ctx, ctx_options) - - -def create_accel(ctx): - - accel_options = optix.AccelBuildOptions( - buildFlags=int(optix.BUILD_FLAG_ALLOW_RANDOM_VERTEX_ACCESS), - operation=optix.BUILD_OPERATION_BUILD, - ) - - global vertices - global indices - # fmt: off - vertices = cp.array( - [ - -1.0, -1.0, 0, - 1.0, -1.0, 0, - 1.0, 1.0, 0, - -1.0, 1.0, 0, - 0, 0, 2.0 - ], dtype="f4" - ) - indices = cp.array( - [ - 0, 1, 2, - 0, 2, 3, - 0, 1, 4, - 1, 2, 4, - 2, 3, 4, - 3, 0, 4 - ], dtype="uint32" - ) - # fmt: on - - pyramid_input_flags = [optix.GEOMETRY_FLAG_NONE] - - pyramid_input = optix.BuildInputTriangleArray() - pyramid_input.vertexFormat = optix.VERTEX_FORMAT_FLOAT3 - pyramid_input.numVertices = len(vertices) - pyramid_input.vertexBuffers = [vertices.data.ptr] - pyramid_input.indexFormat = optix.INDICES_FORMAT_UNSIGNED_INT3 - pyramid_input.numIndexTriplets = len(indices) // 3 - pyramid_input.indexBuffer = indices.data.ptr - pyramid_input.flags = pyramid_input_flags - pyramid_input.numSbtRecords = 1 - - gas_buffer_sizes = ctx.accelComputeMemoryUsage([accel_options], [pyramid_input]) - - d_temp_buffer_gas = cp.cuda.alloc(gas_buffer_sizes.tempSizeInBytes) - d_gas_output_buffer = cp.cuda.alloc(gas_buffer_sizes.outputSizeInBytes) - - gas_handle = ctx.accelBuild( - 0, # CUDA stream - [accel_options], - [pyramid_input], - d_temp_buffer_gas.ptr, - gas_buffer_sizes.tempSizeInBytes, - d_gas_output_buffer.ptr, - gas_buffer_sizes.outputSizeInBytes, - [], # emitted properties - ) - - return (gas_handle, d_gas_output_buffer) - - -def set_pipeline_options(): - return optix.PipelineCompileOptions( - usesMotionBlur=False, - traversableGraphFlags=int(optix.TRAVERSABLE_GRAPH_FLAG_ALLOW_SINGLE_GAS), - numPayloadValues=3, - numAttributeValues=3, - exceptionFlags=int(optix.EXCEPTION_FLAG_NONE), - pipelineLaunchParamsVariableName="params", - usesPrimitiveTypeFlags=optix.PRIMITIVE_TYPE_FLAGS_TRIANGLE, - ) - - -def create_module(ctx, pipeline_options, ptx): - print("Creating optix module ...") - - module_options = optix.ModuleCompileOptions( - maxRegisterCount=optix.COMPILE_DEFAULT_MAX_REGISTER_COUNT, - optLevel=optix.COMPILE_OPTIMIZATION_DEFAULT, - debugLevel=optix.COMPILE_DEBUG_LEVEL_LINEINFO, - ) - - module, log = ctx.moduleCreateFromPTX(module_options, pipeline_options, ptx) - print("\tModule create log: <<<{}>>>".format(log)) - return module - - -def create_program_groups(ctx, raygen_module, miss_prog_module, hitgroup_module): - print("Creating program groups ... ") - - program_group_options = optix.ProgramGroupOptions() - - raygen_prog_group_desc = optix.ProgramGroupDesc() - raygen_prog_group_desc.raygenModule = raygen_module - raygen_prog_group_desc.raygenEntryFunctionName = "__raygen__rg" - - miss_prog_group_desc = optix.ProgramGroupDesc() - miss_prog_group_desc.missModule = miss_prog_module - miss_prog_group_desc.missEntryFunctionName = "__miss__ms" - - hitgroup_prog_group_desc = optix.ProgramGroupDesc() - hitgroup_prog_group_desc.hitgroupModuleCH = hitgroup_module - hitgroup_prog_group_desc.hitgroupEntryFunctionNameCH = "__closesthit__ch" - - prog_group, log = ctx.programGroupCreate( - [raygen_prog_group_desc, miss_prog_group_desc, hitgroup_prog_group_desc], - program_group_options, - ) - print("\tProgramGroup create log: <<<{}>>>".format(log)) - - return prog_group - - -def create_pipeline(ctx, program_groups, pipeline_compile_options): - print("Creating pipeline ... ") - - max_trace_depth = 1 - pipeline_link_options = optix.PipelineLinkOptions() - pipeline_link_options.maxTraceDepth = max_trace_depth - pipeline_link_options.debugLevel = optix.COMPILE_DEBUG_LEVEL_FULL - - log = "" - pipeline = ctx.pipelineCreate( - pipeline_compile_options, pipeline_link_options, program_groups, log - ) - - stack_sizes = optix.StackSizes() - for prog_group in program_groups: - optix.util.accumulateStackSizes(prog_group, stack_sizes) - - ( - dc_stack_size_from_trav, - dc_stack_size_from_state, - cc_stack_size, - ) = optix.util.computeStackSizes( - stack_sizes, max_trace_depth, 0, 0 # maxCCDepth # maxDCDepth - ) - - pipeline.setStackSize( - dc_stack_size_from_trav, - dc_stack_size_from_state, - cc_stack_size, - 1, # maxTraversableDepth - ) - - return pipeline - - -def create_sbt(prog_groups): - print("Creating sbt ... ") - - (raygen_prog_group, miss_prog_group, hitgroup_prog_group) = prog_groups - - global d_raygen_sbt - global d_miss_sbt - - header_format = "{}B".format(optix.SBT_RECORD_HEADER_SIZE) - - # - # raygen record - # - formats = [header_format] - itemsize = get_aligned_itemsize(formats, optix.SBT_RECORD_ALIGNMENT) - dtype = np.dtype( - {"names": ["header"], "formats": formats, "itemsize": itemsize, "align": True} - ) - h_raygen_sbt = np.array([0], dtype=dtype) - optix.sbtRecordPackHeader(raygen_prog_group, h_raygen_sbt) - global d_raygen_sbt - d_raygen_sbt = array_to_device_memory(h_raygen_sbt) - - # - # miss record - # - formats = [header_format, "f4", "f4", "f4"] - itemsize = get_aligned_itemsize(formats, optix.SBT_RECORD_ALIGNMENT) - dtype = np.dtype( - { - "names": ["header", "r", "g", "b"], - "formats": formats, - "itemsize": itemsize, - "align": True, - } - ) - h_miss_sbt = np.array([(0, 0.3, 0.1, 0.2)], dtype=dtype) - optix.sbtRecordPackHeader(miss_prog_group, h_miss_sbt) - global d_miss_sbt - d_miss_sbt = array_to_device_memory(h_miss_sbt) - - # - # hitgroup record - # - formats = [header_format] - itemsize = get_aligned_itemsize(formats, optix.SBT_RECORD_ALIGNMENT) - dtype = np.dtype( - {"names": ["header"], "formats": formats, "itemsize": itemsize, "align": True} - ) - h_hitgroup_sbt = np.array([(0)], dtype=dtype) - optix.sbtRecordPackHeader(hitgroup_prog_group, h_hitgroup_sbt) - global d_hitgroup_sbt - d_hitgroup_sbt = array_to_device_memory(h_hitgroup_sbt) - - sbt = optix.ShaderBindingTable() - sbt.raygenRecord = d_raygen_sbt.ptr - sbt.missRecordBase = d_miss_sbt.ptr - sbt.missRecordStrideInBytes = d_miss_sbt.mem.size - sbt.missRecordCount = 1 - sbt.hitgroupRecordBase = d_hitgroup_sbt.ptr - sbt.hitgroupRecordStrideInBytes = d_hitgroup_sbt.mem.size - sbt.hitgroupRecordCount = 1 - return sbt - - -def launch(pipeline, sbt, trav_handle, cam): - print("Launching ... ") - - pix_bytes = pix_width * pix_height * 4 - - h_pix = np.zeros((pix_width, pix_height, 4), "B") - h_pix[0:pix_width, 0:pix_height] = [255, 128, 0, 255] - d_pix = cp.array(h_pix) - - cam_eye, cam_U, cam_V, cam_W = cam - - params = [ - ("u8", "image", d_pix.data.ptr), - ("u4", "image_width", pix_width), - ("u4", "image_height", pix_height), - ("f4", "cam_eye_x", cam_eye[0]), - ("f4", "cam_eye_y", cam_eye[1]), - ("f4", "cam_eye_z", cam_eye[2]), - ("f4", "cam_U_x", cam_U[0]), - ("f4", "cam_U_y", cam_U[1]), - ("f4", "cam_U_z", cam_U[2]), - ("f4", "cam_V_x", cam_V[0]), - ("f4", "cam_V_y", cam_V[1]), - ("f4", "cam_V_z", cam_V[2]), - ("f4", "cam_W_x", cam_W[0]), - ("f4", "cam_W_y", cam_W[1]), - ("f4", "cam_W_z", cam_W[2]), - ("u8", "trav_handle", trav_handle), - ] - - formats = [x[0] for x in params] - names = [x[1] for x in params] - values = [x[2] for x in params] - itemsize = get_aligned_itemsize(formats, 8) - params_dtype = np.dtype( - {"names": names, "formats": formats, "itemsize": itemsize, "align": True} - ) - h_params = np.array([tuple(values)], dtype=params_dtype) - d_params = array_to_device_memory(h_params) - - stream = cp.cuda.Stream() - optix.launch( - pipeline, - stream.ptr, - d_params.ptr, - h_params.dtype.itemsize, - sbt, - pix_width, - pix_height, - 1, # depth - ) - - stream.synchronize() - - h_pix = cp.asnumpy(d_pix) - return h_pix - - -# Numba compilation -# ----------------- - -# An equivalent to the compile_cuda function for Python kernels. The types of -# the arguments to the kernel must be provided, if there are any. - - -def compile_numba(f, sig=(), debug=False): - # Based on numba.cuda.compile_ptx. We don't just use - # compile_ptx_for_current_device because it generates a kernel with a - # mangled name. For proceeding beyond this prototype, an option should be - # added to compile_ptx in Numba to not mangle the function name. - - nvvm_options = { - "debug": debug, - "fastmath": False, - "opt": 0 if debug else 3, - } - - cres = numba_compile_cuda(f, None, sig, debug=debug, nvvm_options=nvvm_options) - fname = cres.fndesc.llvm_func_name - tgt = cres.target_context - filename = cres.type_annotation.filename - linenum = int(cres.type_annotation.linenum) - lib, kernel = tgt.prepare_cuda_kernel( - cres.library, cres.fndesc, debug, nvvm_options, filename, linenum - ) - cc = get_current_device().compute_capability - ptx = lib.get_asm_str(cc=cc) - - # Demangle name - mangled_name = kernel.name - original_name = cres.library.name - return ptx.replace(mangled_name, original_name) - - -# ------------------------------------------------------------------------------- -# -# User code / kernel - the following section is what we'd expect a user of -# PyOptiX to write. -# -# ------------------------------------------------------------------------------- - -# vec_math - -# Overload for Clamp -def clamp(x, a, b): - pass - - -@overload(clamp, target="cuda") -def jit_clamp(x, a, b): - if ( - isinstance(x, types.Float) - and isinstance(a, types.Float) - and isinstance(b, types.Float) - ): - - def clamp_float_impl(x, a, b): - return max(a, min(x, b)) - - return clamp_float_impl - elif ( - isinstance(x, Float3) - and isinstance(a, types.Float) - and isinstance(b, types.Float) - ): - - def clamp_float3_impl(x, a, b): - return make_float3(clamp(x.x, a, b), clamp(x.y, a, b), clamp(x.z, a, b)) - - return clamp_float3_impl - - -def dot(a, b): - pass - - -@overload(dot, target="cuda") -def jit_dot(a, b): - if isinstance(a, Float3) and isinstance(b, Float3): - - def dot_float3_impl(a, b): - return a.x * b.x + a.y * b.y + a.z * b.z - - return dot_float3_impl - - -@cuda.jit(device=True) -def normalize(v): - invLen = float32(1.0) / math.sqrt(dot(v, v)) - return v * invLen - - -# Helpers - - -@cuda.jit(device=True) -def toSRGB(c): - # Use float32 for constants - invGamma = float32(1.0) / float32(2.4) - powed = make_float3( - math.pow(c.x, invGamma), math.pow(c.y, invGamma), math.pow(c.z, invGamma) - ) - return make_float3( - float32(12.92) * c.x - if c.x < float32(0.0031308) - else float32(1.055) * powed.x - float32(0.055), - float32(12.92) * c.y - if c.y < float32(0.0031308) - else float32(1.055) * powed.y - float32(0.055), - float32(12.92) * c.z - if c.z < float32(0.0031308) - else float32(1.055) * powed.z - float32(0.055), - ) - - -@cuda.jit(device=True) -def quantizeUnsigned8Bits(x): - x = clamp(x, float32(0.0), float32(1.0)) - N, Np1 = (1 << 8) - 1, 1 << 8 - return uint8(min(uint32(x * float32(Np1)), uint32(N))) - - -@cuda.jit(device=True) -def make_color(c): - srgb = toSRGB(clamp(c, float32(0.0), float32(1.0))) - return make_uchar4( - quantizeUnsigned8Bits(srgb.x), - quantizeUnsigned8Bits(srgb.y), - quantizeUnsigned8Bits(srgb.z), - uint8(255), - ) - - -# ray functions - - -@cuda.jit(device=True) -def setPayload(p): - optix.SetPayload_0(cuda.libdevice.float_as_int(p.x)) - optix.SetPayload_1(cuda.libdevice.float_as_int(p.y)) - optix.SetPayload_2(cuda.libdevice.float_as_int(p.z)) - - -@cuda.jit(device=True) -def computeRay(idx, dim): - U = params.cam_u - V = params.cam_v - W = params.cam_w - # Normalizing coordinates to [-1.0, 1.0] - d = float32(2.0) * make_float2( - float32(idx.x) / float32(dim.x), float32(idx.y) / float32(dim.y) - ) - float32(1.0) - - origin = params.cam_eye - direction = normalize(d.x * U + d.y * V + W) - return origin, direction - - -def __raygen__rg(): - # Lookup our location within the launch grid - idx = optix.GetLaunchIndex() - dim = optix.GetLaunchDimensions() - - # Map our launch idx to a screen location and create a ray from the camera - # location through the screen - ray_origin, ray_direction = computeRay(make_uint3(idx.x, idx.y, 0), dim) - - # Trace the ray against our scene hierarchy - payload_pack = optix.Trace( - params.handle, - ray_origin, - ray_direction, - float32(0.0), # Min intersection distance - float32(1e16), # Max intersection distance - float32(0.0), # rayTime -- used for motion blur - OptixVisibilityMask(255), # Specify always visible - # OptixRayFlags.OPTIX_RAY_FLAG_NONE, - uint32(OPTIX_RAY_FLAG_NONE), - uint32(0), # SBT offset -- See SBT discussion - uint32(1), # SBT stride -- See SBT discussion - uint32(0), # missSBTIndex -- See SBT discussion - ) - result = make_float3( - cuda.libdevice.int_as_float(payload_pack.p0), - cuda.libdevice.int_as_float(payload_pack.p1), - cuda.libdevice.int_as_float(payload_pack.p2), - ) - - # Record results in our output raster - params.image[idx.y * params.image_width + idx.x] = make_color(result) - - -def __miss__ms(): - miss_data = MissDataStruct(optix.GetSbtDataPointer()) - setPayload(miss_data.bg_color) - - -def __closesthit__ch(): - # When built-in triangle intersection is used, a number of fundamental - # attributes are provided by the OptiX API, indlucing barycentric coordinates. - barycentrics = optix.GetTriangleBarycentrics() - - setPayload(make_float3(barycentrics, float32(1.0))) - - -# ------------------------------------------------------------------------------- -# -# render -# -# ------------------------------------------------------------------------------- - - -def render(cam, t): - raygen_ptx = compile_numba(__raygen__rg) - miss_ptx = compile_numba(__miss__ms) - hitgroup_ptx = compile_numba(__closesthit__ch) - - # triangle_ptx = compile_cuda( "examples/triangle.cu" ) - - init_optix() - - ctx = create_ctx() - gas_handle, d_gas_output_buffer = create_accel(ctx) - pipeline_options = set_pipeline_options() - - raygen_module = create_module(ctx, pipeline_options, raygen_ptx) - miss_module = create_module(ctx, pipeline_options, miss_ptx) - hitgroup_module = create_module(ctx, pipeline_options, hitgroup_ptx) - - prog_groups = create_program_groups( - ctx, raygen_module, miss_module, hitgroup_module - ) - pipeline = create_pipeline(ctx, prog_groups, pipeline_options) - sbt = create_sbt(prog_groups) - pix = launch(pipeline, sbt, gas_handle, cam) - - print("Total number of log messages: {}".format(logger.num_mssgs)) - - pix = pix.reshape((pix_height, pix_width, 4)) # PIL expects [ y, x ] resolution - img = ImageOps.flip(Image.fromarray(pix, "RGBA")) # PIL expects y = 0 at bottom - img.save(f"output/pyramid_{t}.png") - img.show() - - -def lookat(eye, at, up): - W = at - eye - Wnorm = np.linalg.norm(W) - if np.allclose(Wnorm, 0.0): - raise ValueError("Target too close to eye.") - W = W / Wnorm - U = np.cross(W, up) - U = U / np.linalg.norm(U) - V = np.cross(U, W) - V = V / np.linalg.norm(V) - return U, V, W - - -def polar2cart(r, theta): - return (r * math.cos(theta), r * math.sin(theta)) - - -if __name__ == "__main__": - - for t in range(0, 361, 6): - rad = math.radians(t) - cart = polar2cart(1.5, rad) - eye = np.array([*cart, 2.5]) - at = np.array([0.0, 0.0, 0.0]) - up = np.array([0.0, 0.0, 1.0]) - U, V, W = lookat(eye, at, up) - - # print(eye, U, V, W) - render((eye, U, V, W), t) From 049937d3a1ea27b970107fc4e8c3e0685ab17b48 Mon Sep 17 00:00:00 2001 From: Michael Wang Date: Mon, 24 Jan 2022 15:57:15 -0800 Subject: [PATCH 17/25] Use fast_powf and fast_math, 50% speedup. --- examples/triangle.py | 39 ++++++++++++++++++++++----------------- 1 file changed, 22 insertions(+), 17 deletions(-) diff --git a/examples/triangle.py b/examples/triangle.py index 8a6169c..164dc9e 100755 --- a/examples/triangle.py +++ b/examples/triangle.py @@ -11,6 +11,7 @@ from numba.core.extending import overload from numba.cuda import get_current_device from numba.cuda.compiler import compile_cuda as numba_compile_cuda +from numba.cuda.libdevice import fast_powf, float_as_int, int_as_float from numba_support import ( OPTIX_RAY_FLAG_NONE, Float3, @@ -387,7 +388,7 @@ def launch(pipeline, sbt, trav_handle): # the arguments to the kernel must be provided, if there are any. -def compile_numba(f, sig=(), debug=False): +def compile_numba(f, sig=(), debug=False, lineinfo=False): # Based on numba.cuda.compile_ptx. We don't just use # compile_ptx_for_current_device because it generates a kernel with a # mangled name. For proceeding beyond this prototype, an option should be @@ -395,7 +396,8 @@ def compile_numba(f, sig=(), debug=False): nvvm_options = { "debug": debug, - "fastmath": False, + "lineinfo": lineinfo, + "fastmath": True, "opt": 0 if debug else 3, } @@ -430,7 +432,7 @@ def clamp(x, a, b): pass -@overload(clamp, target="cuda") +@overload(clamp, target="cuda", fast_math=True) def jit_clamp(x, a, b): if ( isinstance(x, types.Float) @@ -458,7 +460,7 @@ def dot(a, b): pass -@overload(dot, target="cuda") +@overload(dot, target="cuda", fast_math=True) def jit_dot(a, b): if isinstance(a, Float3) and isinstance(b, Float3): @@ -468,7 +470,7 @@ def dot_float3_impl(a, b): return dot_float3_impl -@cuda.jit(device=True) +@cuda.jit(device=True, fast_math=True) def normalize(v): invLen = float32(1.0) / math.sqrt(dot(v, v)) return v * invLen @@ -477,12 +479,15 @@ def normalize(v): # Helpers -@cuda.jit(device=True) +@cuda.jit(device=True, fast_math=True) def toSRGB(c): # Use float32 for constants invGamma = float32(1.0) / float32(2.4) powed = make_float3( - math.pow(c.x, invGamma), math.pow(c.y, invGamma), math.pow(c.z, invGamma) + # math.pow(c.x, invGamma), math.pow(c.y, invGamma), math.pow(c.z, invGamma) + fast_powf(c.x, invGamma), + fast_powf(c.y, invGamma), + fast_powf(c.z, invGamma), ) return make_float3( float32(12.92) * c.x @@ -497,14 +502,14 @@ def toSRGB(c): ) -@cuda.jit(device=True) +@cuda.jit(device=True, fast_math=True) def quantizeUnsigned8Bits(x): x = clamp(x, float32(0.0), float32(1.0)) N, Np1 = (1 << 8) - 1, 1 << 8 return uint8(min(uint32(x * float32(Np1)), uint32(N))) -@cuda.jit(device=True) +@cuda.jit(device=True, fast_math=True) def make_color(c): srgb = toSRGB(clamp(c, float32(0.0), float32(1.0))) return make_uchar4( @@ -518,14 +523,14 @@ def make_color(c): # ray functions -@cuda.jit(device=True) +@cuda.jit(device=True, fast_math=True) def setPayload(p): - optix.SetPayload_0(cuda.libdevice.float_as_int(p.x)) - optix.SetPayload_1(cuda.libdevice.float_as_int(p.y)) - optix.SetPayload_2(cuda.libdevice.float_as_int(p.z)) + optix.SetPayload_0(float_as_int(p.x)) + optix.SetPayload_1(float_as_int(p.y)) + optix.SetPayload_2(float_as_int(p.z)) -@cuda.jit(device=True) +@cuda.jit(device=True, fast_math=True) def computeRay(idx, dim): U = params.cam_u V = params.cam_v @@ -565,9 +570,9 @@ def __raygen__rg(): uint32(0), # missSBTIndex -- See SBT discussion ) result = make_float3( - cuda.libdevice.int_as_float(payload_pack.p0), - cuda.libdevice.int_as_float(payload_pack.p1), - cuda.libdevice.int_as_float(payload_pack.p2), + int_as_float(payload_pack.p0), + int_as_float(payload_pack.p1), + int_as_float(payload_pack.p2), ) # Record results in our output raster From 4be3b25d5f2f963c4c33b646bdc32a4093ce91db Mon Sep 17 00:00:00 2001 From: Michael Wang Date: Thu, 3 Feb 2022 21:57:30 -0800 Subject: [PATCH 18/25] Use automation to create vector types and factory functions. --- examples/numba_support.py | 390 +++++++++++++------------------------- examples/triangle.py | 6 +- 2 files changed, 132 insertions(+), 264 deletions(-) diff --git a/examples/numba_support.py b/examples/numba_support.py index 4d693b6..d61a62f 100644 --- a/examples/numba_support.py +++ b/examples/numba_support.py @@ -5,9 +5,10 @@ # ------------------------------------------------------------------------------- from operator import add, mul, sub +from typing import List, Tuple from llvmlite import ir -from numba import cuda, float32, int32, types, uint8, uint32 +from numba import cuda, float32, int32, types, uchar, uint8, uint32 from numba.core import cgutils from numba.core.extending import ( make_attribute_wrapper, @@ -30,116 +31,134 @@ import optix -# UChar4 -# ------ -# Numba presently doesn't implement the UChar4 type (which is fairly standard -# CUDA) so we provide some minimal support for it here. - - -# Prototype a function to construct a uchar4 - - -def make_uchar4(x, y, z, w): - pass - - -# UChar4 typing - - -class UChar4(types.Type): - def __init__(self): - super().__init__(name="UChar4") - - -uchar4 = UChar4() - - -@register -class MakeUChar4(ConcreteTemplate): - key = make_uchar4 - cases = [signature(uchar4, types.uchar, types.uchar, types.uchar, types.uchar)] - - -register_global(make_uchar4, types.Function(MakeUChar4)) - - -# UChar4 data model - - -@register_model(UChar4) -class UChar4Model(models.StructModel): - def __init__(self, dmm, fe_type): - members = [ - ("x", types.uchar), - ("y", types.uchar), - ("z", types.uchar), - ("w", types.uchar), - ] - super().__init__(dmm, fe_type, members) - - -make_attribute_wrapper(UChar4, "x", "x") -make_attribute_wrapper(UChar4, "y", "y") -make_attribute_wrapper(UChar4, "z", "z") -make_attribute_wrapper(UChar4, "w", "w") - - -# UChar4 lowering - - -@lower(make_uchar4, types.uchar, types.uchar, types.uchar, types.uchar) -def lower_make_uchar4(context, builder, sig, args): - uc4 = cgutils.create_struct_proxy(uchar4)(context, builder) - uc4.x = args[0] - uc4.y = args[1] - uc4.z = args[2] - uc4.w = args[3] - return uc4._getvalue() - - -# float3 -# ------ - -# Float3 typing - - -class Float3(types.Type): - def __init__(self): - super().__init__(name="Float3") - - -float3 = Float3() - - -# Float2 typing (forward declaration) - - -class Float2(types.Type): - def __init__(self): - super().__init__(name="Float2") - - -float2 = Float2() - - -# Float3 data model - - -@register_model(Float3) -class Float3Model(models.StructModel): - def __init__(self, dmm, fe_type): - members = [ - ("x", types.float32), - ("y", types.float32), - ("z", types.float32), - ] - super().__init__(dmm, fe_type, members) - - -make_attribute_wrapper(Float3, "x", "x") -make_attribute_wrapper(Float3, "y", "y") -make_attribute_wrapper(Float3, "z", "z") +class VectorType(types.Type): + def __init__(self, name, base_type, attr_names): + self._base_type = base_type + self._attr_names = attr_names + super().__init__(name=name) + + @property + def base_type(self): + return self._base_type + + @property + def attr_names(self): + return self._attr_names + + @property + def num_elements(self): + return len(self._attr_names) + + +def make_vector_type( + name: str, base_type: types.Type, attr_names: List[str] +) -> types.Type: + """Create a vector type. + + Parameters + ---------- + name: str + The name of the type. + base_type: numba.types.Type + The primitive type for each element in the vector. + attr_names: list of str + Name for each attribute. + """ + + class _VectorType(VectorType): + """Internal instantiation of VectorType.""" + + pass + + class VectorTypeModel(models.StructModel): + def __init__(self, dmm, fe_type): + members = [(attr_name, base_type) for attr_name in attr_names] + super().__init__(dmm, fe_type, members) + + vector_type = _VectorType(name, base_type, attr_names) + register_model(_VectorType)(VectorTypeModel) + for attr_name in attr_names: + make_attribute_wrapper(_VectorType, attr_name, attr_name) + + return vector_type + + +def make_vector_type_factory( + vector_type: types.Type, overloads: List[Tuple[types.Type]] +): + """Make a factory function for ``vector_type`` + + Parameters + ---------- + vector_type: VectorType + The type to create factory function for. + overloads: List of argument types tuples + A list containing different overloads of the factory function. Each + base type in the tuple should either be primitive type or VectorType. + """ + + def func(): + pass + + class FactoryTemplate(ConcreteTemplate): + key = func + cases = [signature(vector_type, *arglist) for arglist in overloads] + + def make_lower_factory(fml_arg_list): + """Meta function to create a lowering for the factory function. Flattens + the arguments by converting vector_type into load instructions for each + of its attributes. Such as float2 -> float2.x, float2.y. + """ + + def lower_factory(context, builder, sig, actual_args): + # A list of elements to assign from + source_list = [] + # Convert the list of argument types to a list of load IRs. + for argidx, fml_arg in enumerate(fml_arg_list): + if isinstance(fml_arg, VectorType): + pxy = cgutils.create_struct_proxy(fml_arg)( + context, builder, actual_args[argidx] + ) + source_list += [getattr(pxy, attr) for attr in fml_arg.attr_names] + else: + # assumed primitive type + source_list.append(actual_args[argidx]) + + if len(source_list) != vector_type.num_elements: + raise numba.core.TypingError( + f"Unmatched number of source elements ({len(source_list)}) " + "and target elements ({vector_type.num_elements})." + ) + + typ = cgutils.create_struct_proxy(vector_type)(context, builder) + + for attr_name, source in zip(vector_type.attr_names, source_list): + setattr(typ, attr_name, source) + return typ._getvalue() + + return lower_factory + + func.__name__ = f"make_{vector_type.name.lower()}" + register(FactoryTemplate) + register_global(func, types.Function(FactoryTemplate)) + for arglist in overloads: + lower_factory = make_lower_factory(arglist) + lower(func, *arglist)(lower_factory) + return func + + +# Register basic types +uchar4 = make_vector_type("UChar4", uchar, ["x", "y", "z", "w"]) +float3 = make_vector_type("Float3", float32, ["x", "y", "z"]) +float2 = make_vector_type("Float2", float32, ["x", "y"]) +uint3 = make_vector_type("UInt3", uint32, ["x", "y", "z"]) + +# Register factory functions +make_uchar4 = make_vector_type_factory(uchar4, [(uchar,) * 4]) +make_float3 = make_vector_type_factory(float3, [(float32,) * 3, (float2, float32)]) +make_float2 = make_vector_type_factory(float2, [(float32,) * 2]) +make_uint3 = make_vector_type_factory(uint3, [(uint32,) * 3]) def lower_float3_ops(op): @@ -220,68 +239,6 @@ def add_float3_float32_impl(context, builder, sig, args): return res._getvalue() -# Prototype a function to construct a float3 - - -def make_float3(x, y, z): - pass - - -@register -class MakeFloat3(ConcreteTemplate): - key = make_float3 - cases = [ - signature(float3, types.float32, types.float32, types.float32), - signature(float3, float2, types.float32), - ] - - -register_global(make_float3, types.Function(MakeFloat3)) - - -# make_float3 lowering - - -@lower(make_float3, types.float32, types.float32, types.float32) -def lower_make_float3(context, builder, sig, args): - f3 = cgutils.create_struct_proxy(float3)(context, builder) - f3.x = args[0] - f3.y = args[1] - f3.z = args[2] - return f3._getvalue() - - -@lower(make_float3, float2, types.float32) -def lower_make_float3(context, builder, sig, args): - f2 = cgutils.create_struct_proxy(float2)(context, builder, args[0]) - f3 = cgutils.create_struct_proxy(float3)(context, builder) - f3.x = f2.x - f3.y = f2.y - f3.z = args[1] - return f3._getvalue() - - -# float2 -# ------ - - -# Float2 data model - - -@register_model(Float2) -class Float2Model(models.StructModel): - def __init__(self, dmm, fe_type): - members = [ - ("x", types.float32), - ("y", types.float32), - ] - super().__init__(dmm, fe_type, members) - - -make_attribute_wrapper(Float2, "x", "x") -make_attribute_wrapper(Float2, "y", "y") - - def lower_float2_ops(op): class Float2_op_template(ConcreteTemplate): key = op @@ -335,93 +292,6 @@ def op_attr(lhs, rhs, res, attr): lower_float2_ops(sub) -# Prototype a function to construct a float2 - - -def make_float2(x, y): - pass - - -@register -class MakeFloat2(ConcreteTemplate): - key = make_float2 - cases = [signature(float2, types.float32, types.float32)] - - -register_global(make_float2, types.Function(MakeFloat2)) - - -# make_float2 lowering - - -@lower(make_float2, types.float32, types.float32) -def lower_make_float2(context, builder, sig, args): - f2 = cgutils.create_struct_proxy(float2)(context, builder) - f2.x = args[0] - f2.y = args[1] - return f2._getvalue() - - -# uint3 -# ------ - - -class UInt3(types.Type): - def __init__(self): - super().__init__(name="UInt3") - - -uint3 = UInt3() - - -# UInt3 data model - - -@register_model(UInt3) -class UInt3Model(models.StructModel): - def __init__(self, dmm, fe_type): - members = [ - ("x", types.uint32), - ("y", types.uint32), - ("z", types.uint32), - ] - super().__init__(dmm, fe_type, members) - - -make_attribute_wrapper(UInt3, "x", "x") -make_attribute_wrapper(UInt3, "y", "y") -make_attribute_wrapper(UInt3, "z", "z") - - -# Prototype a function to construct a uint3 - - -def make_uint3(x, y, z): - pass - - -@register -class MakeUInt3(ConcreteTemplate): - key = make_uint3 - cases = [signature(uint3, types.uint32, types.uint32, types.uint32)] - - -register_global(make_uint3, types.Function(MakeUInt3)) - - -# make_uint3 lowering - - -@lower(make_uint3, types.uint32, types.uint32, types.uint32) -def lower_make_uint3(context, builder, sig, args): - # u4 = uint32 - u4_3 = cgutils.create_struct_proxy(uint3)(context, builder) - u4_3.x = args[0] - u4_3.y = args[1] - u4_3.z = args[2] - return u4_3._getvalue() - - # Temporary Payload Parameter Pack class PayloadPack(types.Type): def __init__(self): @@ -902,7 +772,7 @@ def lower_optix_getTriangleBarycentrics(context, builder, sig, args): ) def lower_optix_Trace(context, builder, sig, args): # Only implements the version that accepts 3 payload registers - # TODO: Optimize returns, adapt to 0-8 payload registers. + # TODO: Optimize returns, adapt to 0-32 payload registers. ( handle, diff --git a/examples/triangle.py b/examples/triangle.py index 164dc9e..789b1cf 100755 --- a/examples/triangle.py +++ b/examples/triangle.py @@ -14,7 +14,6 @@ from numba.cuda.libdevice import fast_powf, float_as_int, int_as_float from numba_support import ( OPTIX_RAY_FLAG_NONE, - Float3, MissDataStruct, OptixVisibilityMask, float2, @@ -331,7 +330,6 @@ def launch(pipeline, sbt, trav_handle): pix_bytes = pix_width * pix_height * 4 h_pix = np.zeros((pix_width, pix_height, 4), "B") - h_pix[0:pix_width, 0:pix_height] = [255, 128, 0, 255] d_pix = cp.array(h_pix) params = [ @@ -445,7 +443,7 @@ def clamp_float_impl(x, a, b): return clamp_float_impl elif ( - isinstance(x, Float3) + isinstance(x, type(float3)) and isinstance(a, types.Float) and isinstance(b, types.Float) ): @@ -462,7 +460,7 @@ def dot(a, b): @overload(dot, target="cuda", fast_math=True) def jit_dot(a, b): - if isinstance(a, Float3) and isinstance(b, Float3): + if isinstance(a, type(float3)) and isinstance(b, type(float3)): def dot_float3_impl(a, b): return a.x * b.x + a.y * b.y + a.z * b.z From 4d16c3ba39d1431e56e2ea1b0b6734d41a22584f Mon Sep 17 00:00:00 2001 From: Michael Wang Date: Thu, 3 Feb 2022 22:56:25 -0800 Subject: [PATCH 19/25] Automatically generate ops between primitve x vector_type, and vector_type x vector_type --- examples/numba_support.py | 207 ++++++++++++++------------------------ 1 file changed, 73 insertions(+), 134 deletions(-) diff --git a/examples/numba_support.py b/examples/numba_support.py index d61a62f..8716d64 100644 --- a/examples/numba_support.py +++ b/examples/numba_support.py @@ -131,11 +131,11 @@ def lower_factory(context, builder, sig, actual_args): "and target elements ({vector_type.num_elements})." ) - typ = cgutils.create_struct_proxy(vector_type)(context, builder) + out = cgutils.create_struct_proxy(vector_type)(context, builder) for attr_name, source in zip(vector_type.attr_names, source_list): - setattr(typ, attr_name, source) - return typ._getvalue() + setattr(out, attr_name, source) + return out._getvalue() return lower_factory @@ -148,6 +148,61 @@ def lower_factory(context, builder, sig, actual_args): return func +def lower_vector_type_ops( + op, vector_type: VectorType, overloads: List[Tuple[types.Type]] +): + class Vector_op_template(ConcreteTemplate): + key = op + cases = [signature(vector_type, *arglist) for arglist in overloads] + + def make_lower_op(fml_arg_list): + def op_impl(context, builder, sig, actual_args): + def _make_load_IR(typ, actual_arg): + if isinstance(typ, VectorType): + pxy = cgutils.create_struct_proxy(typ)(context, builder, actual_arg) + oprands = [getattr(pxy, attr) for attr in typ.attr_names] + else: + # Assumed primitive type, broadcast + oprands = [actual_arg for _ in range(vector_type.num_elements)] + return oprands + + def element_wise_op(lhs, rhs, res, attr): + setattr( + res, + attr, + context.compile_internal( + builder, + lambda x, y: op(x, y), + signature(types.float32, types.float32, types.float32), + (lhs, rhs), + ), + ) + + lhs_typ, rhs_typ = fml_arg_list + # Construct a list of load IRs + lhs = _make_load_IR(lhs_typ, actual_args[0]) + rhs = _make_load_IR(rhs_typ, actual_args[1]) + + if not len(lhs) == len(rhs) == vector_type.num_elements: + raise numba.core.TypingError( + f"Unmatched number of lhs elements ({len(lhs)}), rhs elements ({len(rhs)}) " + "and target elements ({vector_type.num_elements})." + ) + + out = cgutils.create_struct_proxy(vector_type)(context, builder) + for attr, l, r in zip(vector_type.attr_names, lhs, rhs): + element_wise_op(l, r, out, attr) + + return out._getvalue() + + return op_impl + + register_global(op, types.Function(Vector_op_template)) + for arglist in overloads: + impl = make_lower_op(arglist) + lower(op, *arglist)(impl) + + # Register basic types uchar4 = make_vector_type("UChar4", uchar, ["x", "y", "z", "w"]) float3 = make_vector_type("Float3", float32, ["x", "y", "z"]) @@ -160,137 +215,21 @@ def lower_factory(context, builder, sig, actual_args): make_float2 = make_vector_type_factory(float2, [(float32,) * 2]) make_uint3 = make_vector_type_factory(uint3, [(uint32,) * 3]) - -def lower_float3_ops(op): - class Float3_op_template(ConcreteTemplate): - key = op - cases = [ - signature(float3, float3, float3), - signature(float3, types.float32, float3), - signature(float3, float3, types.float32), - ] - - def float3_op_impl(context, builder, sig, args): - def op_attr(lhs, rhs, res, attr): - setattr( - res, - attr, - context.compile_internal( - builder, - lambda x, y: op(x, y), - signature(types.float32, types.float32, types.float32), - (getattr(lhs, attr), getattr(rhs, attr)), - ), - ) - - arg0, arg1 = args - - if isinstance(sig.args[0], types.Float): - lf3 = cgutils.create_struct_proxy(float3)(context, builder) - lf3.x = arg0 - lf3.y = arg0 - lf3.z = arg0 - else: - lf3 = cgutils.create_struct_proxy(float3)(context, builder, value=args[0]) - - if isinstance(sig.args[1], types.Float): - rf3 = cgutils.create_struct_proxy(float3)(context, builder) - rf3.x = arg1 - rf3.y = arg1 - rf3.z = arg1 - else: - rf3 = cgutils.create_struct_proxy(float3)(context, builder, value=args[1]) - - res = cgutils.create_struct_proxy(float3)(context, builder) - op_attr(lf3, rf3, res, "x") - op_attr(lf3, rf3, res, "y") - op_attr(lf3, rf3, res, "z") - return res._getvalue() - - register_global(op, types.Function(Float3_op_template)) - lower(op, float3, float3)(float3_op_impl) - lower(op, types.float32, float3)(float3_op_impl) - lower(op, float3, types.float32)(float3_op_impl) - - -lower_float3_ops(mul) -lower_float3_ops(add) - - -@lower(add, float32, float3) -def add_float32_float3_impl(context, builder, sig, args): - s = args[0] - rhs = cgutils.create_struct_proxy(float3)(context, builder, args[1]) - res = cgutils.create_struct_proxy(float3)(context, builder) - res.x = builder.fadd(s, rhs.x) - res.y = builder.fadd(s, rhs.y) - res.z = builder.fadd(s, rhs.z) - return res._getvalue() - - -@lower(add, float3, float32) -def add_float3_float32_impl(context, builder, sig, args): - lhs = cgutils.create_struct_proxy(float3)(context, builder, args[0]) - s = args[1] - res = cgutils.create_struct_proxy(float3)(context, builder) - res.x = builder.fadd(lhs.x, s) - res.y = builder.fadd(lhs.y, s) - res.z = builder.fadd(lhs.z, s) - return res._getvalue() - - -def lower_float2_ops(op): - class Float2_op_template(ConcreteTemplate): - key = op - cases = [ - signature(float2, float2, float2), - signature(float2, types.float32, float2), - signature(float2, float2, types.float32), - ] - - def float2_op_impl(context, builder, sig, args): - def op_attr(lhs, rhs, res, attr): - setattr( - res, - attr, - context.compile_internal( - builder, - lambda x, y: op(x, y), - signature(types.float32, types.float32, types.float32), - (getattr(lhs, attr), getattr(rhs, attr)), - ), - ) - - arg0, arg1 = args - - if isinstance(sig.args[0], types.Float): - lf2 = cgutils.create_struct_proxy(float2)(context, builder) - lf2.x = arg0 - lf2.y = arg0 - else: - lf2 = cgutils.create_struct_proxy(float2)(context, builder, value=args[0]) - - if isinstance(sig.args[1], types.Float): - rf2 = cgutils.create_struct_proxy(float2)(context, builder) - rf2.x = arg1 - rf2.y = arg1 - else: - rf2 = cgutils.create_struct_proxy(float2)(context, builder, value=args[1]) - - res = cgutils.create_struct_proxy(float2)(context, builder) - op_attr(lf2, rf2, res, "x") - op_attr(lf2, rf2, res, "y") - return res._getvalue() - - register_global(op, types.Function(Float2_op_template)) - lower(op, float2, float2)(float2_op_impl) - lower(op, types.Float, float2)(float2_op_impl) - lower(op, float2, types.Float)(float2_op_impl) - - -lower_float2_ops(mul) -lower_float2_ops(sub) - +# Lower Vector Type Ops +## float3 +lower_vector_type_ops( + add, float3, [(float3, float3), (float32, float3), (float3, float32)] +) +lower_vector_type_ops( + mul, float3, [(float3, float3), (float32, float3), (float3, float32)] +) +## float2 +lower_vector_type_ops( + mul, float2, [(float2, float2), (float32, float2), (float2, float32)] +) +lower_vector_type_ops( + sub, float2, [(float2, float2), (float32, float2), (float2, float32)] +) # Temporary Payload Parameter Pack class PayloadPack(types.Type): From 0d377add609b56075a2abfe39fde66a8b1b519c2 Mon Sep 17 00:00:00 2001 From: Michael Wang Date: Thu, 3 Feb 2022 23:01:57 -0800 Subject: [PATCH 20/25] docstrings --- examples/numba_support.py | 36 ++++++++++++++++++++++++++---------- 1 file changed, 26 insertions(+), 10 deletions(-) diff --git a/examples/numba_support.py b/examples/numba_support.py index 8716d64..4705ca7 100644 --- a/examples/numba_support.py +++ b/examples/numba_support.py @@ -148,11 +148,27 @@ def lower_factory(context, builder, sig, actual_args): return func -def lower_vector_type_ops( - op, vector_type: VectorType, overloads: List[Tuple[types.Type]] +def lower_vector_type_binops( + binop, vector_type: VectorType, overloads: List[Tuple[types.Type]] ): + """Lower ops for ``vector_type`` + + Parameters + ---------- + binop: operation + The binop to lower + vector_type: VectorType + The type to lower op for. + overloads: List of argument types tuples + A list containing different overloads of the binop. Expected to be either + - vector_type x vector_type + - primitive_type x vector_type + - vector_type x primitive_type. + In case one of the oprand is primitive_type, the operation is broadcasted. + """ + # Should we assume the above are the only possible types? class Vector_op_template(ConcreteTemplate): - key = op + key = binop cases = [signature(vector_type, *arglist) for arglist in overloads] def make_lower_op(fml_arg_list): @@ -172,7 +188,7 @@ def element_wise_op(lhs, rhs, res, attr): attr, context.compile_internal( builder, - lambda x, y: op(x, y), + lambda x, y: binop(x, y), signature(types.float32, types.float32, types.float32), (lhs, rhs), ), @@ -197,10 +213,10 @@ def element_wise_op(lhs, rhs, res, attr): return op_impl - register_global(op, types.Function(Vector_op_template)) + register_global(binop, types.Function(Vector_op_template)) for arglist in overloads: impl = make_lower_op(arglist) - lower(op, *arglist)(impl) + lower(binop, *arglist)(impl) # Register basic types @@ -217,17 +233,17 @@ def element_wise_op(lhs, rhs, res, attr): # Lower Vector Type Ops ## float3 -lower_vector_type_ops( +lower_vector_type_binops( add, float3, [(float3, float3), (float32, float3), (float3, float32)] ) -lower_vector_type_ops( +lower_vector_type_binops( mul, float3, [(float3, float3), (float32, float3), (float3, float32)] ) ## float2 -lower_vector_type_ops( +lower_vector_type_binops( mul, float2, [(float2, float2), (float32, float2), (float2, float32)] ) -lower_vector_type_ops( +lower_vector_type_binops( sub, float2, [(float2, float2), (float32, float2), (float2, float32)] ) From a2a6c89ed559fa2403f1b5d2142d5c887f053753 Mon Sep 17 00:00:00 2001 From: Michael Wang Date: Thu, 3 Feb 2022 23:04:17 -0800 Subject: [PATCH 21/25] A few minor notes --- examples/numba_support.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/numba_support.py b/examples/numba_support.py index 4705ca7..49a6643 100644 --- a/examples/numba_support.py +++ b/examples/numba_support.py @@ -151,7 +151,7 @@ def lower_factory(context, builder, sig, actual_args): def lower_vector_type_binops( binop, vector_type: VectorType, overloads: List[Tuple[types.Type]] ): - """Lower ops for ``vector_type`` + """Lower binops for ``vector_type`` Parameters ---------- @@ -166,7 +166,7 @@ def lower_vector_type_binops( - vector_type x primitive_type. In case one of the oprand is primitive_type, the operation is broadcasted. """ - # Should we assume the above are the only possible types? + # Should we assume the above are the only possible cases? class Vector_op_template(ConcreteTemplate): key = binop cases = [signature(vector_type, *arglist) for arglist in overloads] From 3f4a1452239e4d1fae42ad5c74569ccac0ae616c Mon Sep 17 00:00:00 2001 From: Michael Wang Date: Sun, 13 Feb 2022 13:57:54 -0800 Subject: [PATCH 22/25] Get correct `fastmath` behavior based on dc73113981f919144b4e6efd87a73de87e857f00 --- examples/triangle.py | 49 +++++++++++++++++++++++--------------------- 1 file changed, 26 insertions(+), 23 deletions(-) diff --git a/examples/triangle.py b/examples/triangle.py index 789b1cf..edfb022 100755 --- a/examples/triangle.py +++ b/examples/triangle.py @@ -13,8 +13,8 @@ from numba.cuda.compiler import compile_cuda as numba_compile_cuda from numba.cuda.libdevice import fast_powf, float_as_int, int_as_float from numba_support import ( - OPTIX_RAY_FLAG_NONE, MissDataStruct, + OptixRayFlags, OptixVisibilityMask, float2, float3, @@ -386,7 +386,7 @@ def launch(pipeline, sbt, trav_handle): # the arguments to the kernel must be provided, if there are any. -def compile_numba(f, sig=(), debug=False, lineinfo=False): +def compile_numba(f, sig=(), fastmath=True, debug=False, lineinfo=False): # Based on numba.cuda.compile_ptx. We don't just use # compile_ptx_for_current_device because it generates a kernel with a # mangled name. For proceeding beyond this prototype, an option should be @@ -395,11 +395,19 @@ def compile_numba(f, sig=(), debug=False, lineinfo=False): nvvm_options = { "debug": debug, "lineinfo": lineinfo, - "fastmath": True, + "fastmath": fastmath, "opt": 0 if debug else 3, } - cres = numba_compile_cuda(f, None, sig, debug=debug, nvvm_options=nvvm_options) + cres = numba_compile_cuda( + f, + None, + sig, + fastmath=fastmath, + debug=debug, + lineinfo=lineinfo, + nvvm_options=nvvm_options, + ) fname = cres.fndesc.llvm_func_name tgt = cres.target_context filename = cres.type_annotation.filename @@ -430,7 +438,7 @@ def clamp(x, a, b): pass -@overload(clamp, target="cuda", fast_math=True) +@overload(clamp, target="cuda", fastmath=True) def jit_clamp(x, a, b): if ( isinstance(x, types.Float) @@ -458,7 +466,7 @@ def dot(a, b): pass -@overload(dot, target="cuda", fast_math=True) +@overload(dot, target="cuda", fastmath=True) def jit_dot(a, b): if isinstance(a, type(float3)) and isinstance(b, type(float3)): @@ -468,7 +476,7 @@ def dot_float3_impl(a, b): return dot_float3_impl -@cuda.jit(device=True, fast_math=True) +@cuda.jit(device=True, fastmath=True) def normalize(v): invLen = float32(1.0) / math.sqrt(dot(v, v)) return v * invLen @@ -477,16 +485,14 @@ def normalize(v): # Helpers -@cuda.jit(device=True, fast_math=True) +@cuda.jit(device=True, fastmath=True) def toSRGB(c): # Use float32 for constants invGamma = float32(1.0) / float32(2.4) powed = make_float3( - # math.pow(c.x, invGamma), math.pow(c.y, invGamma), math.pow(c.z, invGamma) - fast_powf(c.x, invGamma), - fast_powf(c.y, invGamma), - fast_powf(c.z, invGamma), + math.pow(c.x, invGamma), math.pow(c.y, invGamma), math.pow(c.z, invGamma) ) + return make_float3( float32(12.92) * c.x if c.x < float32(0.0031308) @@ -500,14 +506,14 @@ def toSRGB(c): ) -@cuda.jit(device=True, fast_math=True) +@cuda.jit(device=True, fastmath=True) def quantizeUnsigned8Bits(x): x = clamp(x, float32(0.0), float32(1.0)) N, Np1 = (1 << 8) - 1, 1 << 8 return uint8(min(uint32(x * float32(Np1)), uint32(N))) -@cuda.jit(device=True, fast_math=True) +@cuda.jit(device=True, fastmath=True) def make_color(c): srgb = toSRGB(clamp(c, float32(0.0), float32(1.0))) return make_uchar4( @@ -521,14 +527,14 @@ def make_color(c): # ray functions -@cuda.jit(device=True, fast_math=True) +@cuda.jit(device=True, fastmath=True) def setPayload(p): optix.SetPayload_0(float_as_int(p.x)) optix.SetPayload_1(float_as_int(p.y)) optix.SetPayload_2(float_as_int(p.z)) -@cuda.jit(device=True, fast_math=True) +@cuda.jit(device=True, fastmath=True) def computeRay(idx, dim): U = params.cam_u V = params.cam_v @@ -561,8 +567,7 @@ def __raygen__rg(): float32(1e16), # Max intersection distance float32(0.0), # rayTime -- used for motion blur OptixVisibilityMask(255), # Specify always visible - # OptixRayFlags.OPTIX_RAY_FLAG_NONE, - uint32(OPTIX_RAY_FLAG_NONE), + OptixRayFlags.OPTIX_RAY_FLAG_NONE, uint32(0), # SBT offset -- See SBT discussion uint32(1), # SBT stride -- See SBT discussion uint32(0), # missSBTIndex -- See SBT discussion @@ -598,11 +603,9 @@ def __closesthit__ch(): def main(): - raygen_ptx = compile_numba(__raygen__rg) - miss_ptx = compile_numba(__miss__ms) - hitgroup_ptx = compile_numba(__closesthit__ch) - - # triangle_ptx = compile_cuda( "examples/triangle.cu" ) + raygen_ptx = compile_numba(__raygen__rg, fastmath=True) + miss_ptx = compile_numba(__miss__ms, fastmath=True) + hitgroup_ptx = compile_numba(__closesthit__ch, fastmath=True) init_optix() From 09295e4e7fa5b2e662be73f1cdeb6987c0f9f34c Mon Sep 17 00:00:00 2001 From: Michael Wang Date: Thu, 24 Feb 2022 19:47:34 -0800 Subject: [PATCH 23/25] Use IntEnum --- examples/numba_support.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/examples/numba_support.py b/examples/numba_support.py index 49a6643..9f1e9a3 100644 --- a/examples/numba_support.py +++ b/examples/numba_support.py @@ -6,6 +6,7 @@ from operator import add, mul, sub from typing import List, Tuple +from enum import IntEnum from llvmlite import ir from numba import cuda, float32, int32, types, uchar, uint8, uint32 @@ -283,17 +284,16 @@ def __init__(self, dmm, fe_type): ) -OPTIX_RAY_FLAG_NONE = 0 -# class OptixRayFlags(Enum): -# OPTIX_RAY_FLAG_NONE = 0 -# OPTIX_RAY_FLAG_DISABLE_ANYHIT = 1 << 0 -# OPTIX_RAY_FLAG_ENFORCE_ANYHIT = 1 << 1 -# OPTIX_RAY_FLAG_TERMINATE_ON_FIRST_HIT = 1 << 2 -# OPTIX_RAY_FLAG_DISABLE_CLOSESTHIT = 1 << 3, -# OPTIX_RAY_FLAG_CULL_BACK_FACING_TRIANGLES = 1 << 4 -# OPTIX_RAY_FLAG_CULL_FRONT_FACING_TRIANGLES = 1 << 5 -# OPTIX_RAY_FLAG_CULL_DISABLED_ANYHIT = 1 << 6 -# OPTIX_RAY_FLAG_CULL_ENFORCED_ANYHIT = 1 << 7 +class OptixRayFlags(IntEnum): + OPTIX_RAY_FLAG_NONE = 0 + OPTIX_RAY_FLAG_DISABLE_ANYHIT = 1 << 0 + OPTIX_RAY_FLAG_ENFORCE_ANYHIT = 1 << 1 + OPTIX_RAY_FLAG_TERMINATE_ON_FIRST_HIT = 1 << 2 + OPTIX_RAY_FLAG_DISABLE_CLOSESTHIT = 1 << 3, + OPTIX_RAY_FLAG_CULL_BACK_FACING_TRIANGLES = 1 << 4 + OPTIX_RAY_FLAG_CULL_FRONT_FACING_TRIANGLES = 1 << 5 + OPTIX_RAY_FLAG_CULL_DISABLED_ANYHIT = 1 << 6 + OPTIX_RAY_FLAG_CULL_ENFORCED_ANYHIT = 1 << 7 # OptiX types From 117529f0ac95494559f6b14bc8edfdabc09f2bad Mon Sep 17 00:00:00 2001 From: Michael Wang Date: Thu, 24 Feb 2022 19:59:54 -0800 Subject: [PATCH 24/25] Update readme with triangle example --- README.md | 31 +++++++++++++++++-------------- triangle.png | Bin 0 -> 20688 bytes 2 files changed, 17 insertions(+), 14 deletions(-) create mode 100644 triangle.png diff --git a/README.md b/README.md index 2fbfff0..f92a00a 100644 --- a/README.md +++ b/README.md @@ -45,32 +45,35 @@ pip3 install --global-option build --global-option --debug . The example can be run from the examples directory with: ``` -python examples/hello.py +python examples/.py ``` -If the example runs successfully, a square will be rendered: +If the example runs successfully, the example output will be rendered: -![Example output](example_output.png) +![Hello output](example_output.png) +![Triangle output](triangle.png) +Currently supported examples: +- hello.py +- triangle.py ## Explanation The Python implementation of the OptiX kernel and Numba extensions consists of -three parts, all in [examples/hello.py](examples/hello.py): - -- Generic OptiX extensions for Numba - these implement things like - `GetSbtDataPointer`, etc., and are a sort of equivalent of the implementations - in the headers in the OptiX SDK. -- The user's code, which I tried to write exactly as I'd expect a PyOptiX Python - user to write it - it contains declarations of the data structures as in - hello.h, and the kernel as in hello.cu - you can, in this example modify the - Python `__raygen__hello` function and see the changes reflected in the output - image. +three parts: + +- Generic OptiX extention types for Numba. These include new types introduced in +the OptiX SDK. They can be vector math types such as `float3`, `uint4` etc. Or it +could be OptiX intrinsic methods such as `GetSbtDataPointer`. These are included in +examples/numba_support.py. We intend to build more examples by reusing these extensions. +- The second part are the user code. These are the ray tracing kernels that user +of PyOptiX will write. They are in each of the example files, such as `hello.py`, +`triangle.py`. - Code that should be generated from the user's code - these tell Numba how to support the data structures that the user declared, and how to create them from the `SbtDataPointer`, etc. I've handwritten these for this example, to understand what a code generator should generate, and because it would have taken too long and been too risky to write something to generate this off the bat. The correspondence between the user's code and the "hand-written - generated" code is mechanical - there is aclear path to write a generator for + generated" code is mechanical - there is a clear path to write a generator for these based on the example code. diff --git a/triangle.png b/triangle.png new file mode 100644 index 0000000000000000000000000000000000000000..938219690a7d18de7cb46f0dcc205ba44fb43cfe GIT binary patch literal 20688 zcmeHvdtA)f|NlFowJhsSY_sduT2h42WwI{0$J*U)m{tf;?S@Lt1-na|O0~M^N^Ev@ zQK^a1%=qZS7UMcaO*1YPrkZZk)ZE_R*O{3JH75W49>4vAYC7+8Uf1XAcHRZ8UuXB( zr-q+m81~u96-&Rvu>SC^KK98#_;TgS=+7{0ee}wui@*1aZ{YbqFuCRYdApJs@W0Vt zeff!AzfW#VF`vV^zCE$-$sA7P+(~C{crW%_thabuV%_49r4hE-j>Y4=r$#J|oNAbT z_^ZJw5hnz+Uy~>O(qb@i0f~R#izt3@F+e0$a1T28KD?hRd{FY_x4OSykrSoGb(Cw@ z2LKKEyX;BQVek#smWuwzMy^Ny(*}KwzOeg`LSJ0_AqFrrJ@f}7c{dU+ag#m{Bz7c@oW*TuH-O9r$ zisJ*I@Yfk&HqyVTlE}HzM(B@?TLqvnEn6wzwWJ*_ZE5qus}#xbnAu=}j+h(I7`VNu zlAwph1mTQfSfnGhsd0P|o|2Tg4{er=HuJmJh;L*m^sy=%^bA|7&k2-JWv%mw0-&X8 zk?KP-17D|yu_wYOg5q>y({o>(VA|iRU}mYs0$51d%eCE85%Ih(G&{w%sX;RgbvukG3Cfts&<$ zqTW@Jm?u(tWfJAqXrEl9qzQL?B>X|J^fRM_zNC$@GNX(MZzx? zh`K-0`8){^-&;>qctRKuID^>OqjO}#u%+m`lg_MwG6r5+`-rd~Hq3|s9ya~Gx?X7s z@vD3_$U5VWbh8^@LA_dFwJ8dJ%7}+^JW|j(FCsuyu*ZIgeY6=ll=AJ>3etQsu@TPQ zG#arp=K!up8-l)@B=p_V&chqr<+&nc5E;LrgN8G3%r**r7bLdiG&kbyb(xm1++i^~ zxNy8Lo-hQX36U@yTos`{puG@0MC1({+y;qY8SGeC0t5+@VTm!FQnHY*;7bP~Mjs&V zpJ)7}lKKK7?7ARJT**qgb|3KPM4s>v-&(KFm8u5Sf zs95q;^!#aYWrh%#$Y+W7B85DHNXj1c2pq^6sV35i)vyT8x54=fVR>qc;2srdh&8gh z1meC_2SyZsLz+&{lHohcsCYWs!U=5=RCNy|X@IdWqn}n*J#NG&Q!7C|3kL9kOXHtc ziM~ihHVa`L?sB@ zl^X&oC1*RT?YEU}*(2gK`N^n?fuLDMv$za=9-x2- z9|nvj#~MT{^)a@If76{t>e(+s7J@n(7Qo7oH0rfRYAG0ovYCWD#2LZPa?q(qrNV$# zihdecuarDeQ1{{X$cuL)r5H9fQq%T=;mj(5oz7pPgz)f01|F8dYMO_5{|v}F4|;KA z?2zz-;Cmnf8(psqQ|&`VRPRFjeB&n}ay@fpqA3{md?C=YQkX=S5r2{3;Usx0+Ua-D z>J)!il?UEXmkiqs5<2irI0y4*+Kc2?;tKH#_eYmKMTA^B+5QiaVrg-GQM+=8BsV3i zL9s!OD|aLQ|3te~fUX*p`fPus%+ulvNLO}lm9m|3{Ab|&5A^X@3_R-qhetK;hb05I z*eFNVzYIJY0#2o5HK_))i~!J950>SDhc6%iC!aW@HQ@*5<`r-XlboY!k`RAisbVC} zqgiko>AbCxHBt}z{291&|1|3DMjT*@&{;=CcCg?Ku%%bsMIL;@h6Cy#Ane0Uk<0^- zGcFKlY$-ktxl)*TMuU=cy%A3lno`e!nOw93`87U)MY16P4E-@Sf^ydq>1f~59HUvK z_brL`)(5PmP6sk7Wp)jF`H@cph~cvXvI>aT$qhUA1SAC!#3sX{YeW%JWpXEFS|ZaI9C_Ro<;=SAGYkDMzj(07!Cm z0_)yO5;me>EP(xH3X`&A)KawO0udp&5B$RQAks%BSrrI&0Q>9*TvLWiBIF8*)t&yY zvV?az**?5#;%+Zl;wZi>==h`H2F3VB4`fkmRnV)*GLX_R!Ba4T(2=7xLgLFJ2!5%C?qbetjIf8!L zrI^H*i{kLsa(8k#D7eZ9#OxWN@u@N}Sz25UCEORE+~o@mMQ3o{!A+1U2)H0%Zy&X| zMRoi1F#2+|KAUpd7cY>JsN}X*!CW}(GyBW&DCv@i=Te0O8I(i-e4f4CYb zenbJEu&YwR;=ky`Tt^whuLcQQ>y?!XFP03yG1!U&oMj*L7(+?*|hUyh1S|Pex^( z@=y326BxoS-al|T;#)YglcjW)-ogXiNF6>%`UjEpf6z$3e5=ZLGeSUO4sMD~BBkA9 z00Wzm#8;KU+eb9~Sbi)b3uhvx2jM`u!*vTBg;|ligg|qGV8s5x5r;GmkhoGo9TgBx zMDSmSpx_5}_`@j>lw)FH|5Q39R3rZ%`Er(|J~%KGHdbYasqPy3s{xeSh`>6FLm8N;W#NfP}xzzGXxBZ^XM;~=Q-tlw#je{J+~#>#l}po7bAU7afIM>F>U z)V=SFpM}_BX;@k9{?v{fkhSIv@0H<(Iqf@nERj6~A@)KfENA4=DOo8A@)W}K-d96`tgH`D;ZDTeB``4%yweph7=E^mp=_> zMiefLJn#OX%q4pkmOl`wZjmqv@}xIt*0o7gT7XkjO%eVpk_E!()MsE>mWB+|pF#>J zlcxsxOxwUYA4}i1`P^(Qe-v2Sy6uX3MvW6hp_d6sP@5SkW?GgvC{o(_um zWh=9yA^D+2He~`ghRjYPH?+8M^{UCuD@ZG!#;<%}zPr}M$zj(JY}TzoK;%z~XMjI% zaAHo-feRi>Z&pO?=Z_5tQk62%$WEps)9q=e(s9`FGq9j2J*g%RC$~}+;+bGKV*_q- zegQg4p*9jc)z~qZBi4v^TLz1JQg_}q-qb9d4XU1XgVhhE?zL(mPk%?qCT@^%j|6)! zla@eeeBT?U3TK6wlxAOjK!93amHY{|qEyh%%`T{JgVjBUhz0H=#K%whujW3S#FeoI zU{#QG(1IO__bqT)EiM&4f=H1mI!C+@)D(j%1e3aPfx^c)PsN%ZJrSL4{Wrq6<*lFv zvv@D1%4!R9=FTR6*If<4j+cgE2Y&$vIL#&QehbVZchavc|MR$slkcv2CNsJ-1;x`) zu!bclg$ykYjVL?Cd$Hz?O4WM-p{_Av^T1Y?mok&JMQDMG#EPRyPk>%G6k3nJ!9!cs zZfQx7XT9%upow!?5zJ^PifMZC8BmWrQG4*$r}u=%L`}IJyVN>mM|H&h++z>h^lXQL z@7&>CdHg1yqL1wYg4Nv3hNW!FK7W;B3S2P&p0=g(*M$7hr z*}Om+%IJznrK&U()h@Oc!qm5;Y&YShE#sF$No{Ma91U&h=39^y&pq&`Ro;qkSC6P$ zb0%?vx!27r&V|krJ=@iYi6W_8uLhM{FVm+M0t=5p%*lgMl6JRZLrcY^K_+hyZTD~w zo&lDQ7_E}-Vs`Q^@N$wn7;*vAp<^zvu@kxqwAwB7SOu0;m)4{ zFlf63v2m(J!^YQnD;B)P#&XcbV8lk5^zUB$I=XmyTfx&*wAlo-5?O?$q2;PkE>ohr z!I;U`$x&xFt-e6i1U|I zRuxC3G3L(yN7O7McH9J?SdWZiw>s&Px^UJ&L*=ba7rl(!-Mf zP}98TJIBXw;llhgnNx;h8$e1Y@+gpB_2MQQ$dV#bXCzb)Um{}Ws_epX?RVoEx&^x3 z!zbKz7-2vg0koY>OM0Qk%7+#>A18_mIxYmB0h{_0{(V`bDo9a8E>#Y})}wT8tI+sh zZ>h6V)hkui`xO6zp~t z?FQ)cxjG9-n?L<{+sW<_67s@hya7#ymSj_^-fnxeEB8to6zzN)zAP<;xv(+D^zEoI|BTnRiC7dWKZd3eJhbR1tb-tRg#9YuAG<&uhH_Vu zx@RyKmhD+e%JdWA3;TKVfWV4e%5e<@-UtII4J-Wuro_WBJh=0YwZC1*sbu4An0}cV%2j_gjf#3Y=4D#!f{vZTFx%6w0NCzsWtd6LHLfZV7!S9-ea5QJkCs+(BgjjBh6{^xlm>Evs%Si;HR3v_( z%q89f=%JvrjD^bly%LT0Jj#d8#wz&T#hInqX$0tnLgJsh?dao1stu)$?-776BB0O+ z1jy&VPXA41=G!1E(z^lmp@4n>AIq0hAG|5yx}h@oURXzB3dtLgM5rZKGW!wd@Vj`{WRdrfWzeqZ;;OpXlk1+RD(*7La#n!i?Oc)8`JVA zF&d+Z@JfXmsl0)#{_5L9-bE4?(ES*Pv8;!&ei$vQWc{U4o8Ij>zx&Qm4SjwOi1pph_PK^@3mcJ<@*TP$Kt z#^S?Z;YcT9p%&0GYPL=vDJ9Ck0=<@ftB1Kx>OxiKd#JjlIdQ<@F~8|!88H7~PvH(@ zzTH}>W;HMi_)5m}fTaCxqIg2c>+Xs~A#1(c)fFhoea`LI zIIlUFV$84aIei>Xw!3bDBZw&79(jlR!vv}*9$`APk-x>{Hm;!I>~*I>#BNrOA9*#O!`-oaqt zo`3cVuQh7NYFF_Ej-H_EYCp`o zh%33`lBl@Th;MkWyP+zwHjq_5R+VleyTuoJcK%HuZ)qSH&;}zt|AF9ItDcRL2C)|+ zvpdI?psw3{?GJUnjKcW&DGaeYt0 zMz6)T%BkJI-)cQ+7xmsmFGU7~f%pV)ItM^ZXntvNx;AAXcAEnp6zRf;Iz$^XrHl{` zUnwJd_{aG|S}G_5gk9Ho4O&%D!1fv%PWd0C+F<3^Kx4*v%vN4YJ+!~WdT!sSki-0& zETiyCcHEZvqCmR*7OVa0I%+>`HwPi?(T3 zNL0}N9f`wztadT!UJ@bg}SZSW3Ga-3Z!@F;Ht;d>voduwA7sk3niD< zq?xxKk9o%ItZw{56qVG0MkVnmeK7_SwZqmyDrKJycKbS_%83v2PHh@887VBYnp}IV zOmaZvO_vu^q3g?bLZX9v2_F=lU7fyQ6PVh9aPdOz&ax}k(Y<#3Y`cBT)Tr(Lvbri0 zT+D|iq+keoysLuj9baUvDXn_mHjqdfXW7)jtr^Hdjp>@XTufVz@&5VOZ0b@SdUHekec^6cLU|^jqIY1{>OqXR0NkXPV0X~Xhq2jA7?EpB7&ZLd zoqUs8LPyvz0wLDtAB;1|4swn3Z3!r5u%%qxtd27UPpF1b8SO4Q1=INyBAhZLk9krz zUd(S&+o~)JPHx1o$q;SLVpLtT580V>m|il9oCBTt_jJDLXYbSv5EF<>HeIo+hyikL%A!J)iEUntuBxtn;izt9{F%mnM6p) zEQ%SvY$%p^6Kqb8YeCnVbJ&*_E;8pFTj*>IHkbH(b9EIcsu~nEPB#}aDL(%7FEI3g zCtmC!o9pDQDf=T>7PU4a=??Iuw&(rfx{FS5=P`T7(r|s6G0daj{>$1AWVGFdc*92| z{Rw?a@ z$js9HtVsVxUUda%jEm&2tH!n^>L*)+(Hpw(0aHIuRoyS=^ESt{Y+qr%=7tv|lc96Dfa3ry;E^(NSyDS~sm6LYo7k zoSV%&)CWvxf@b@#n#vQ1Ok6EmS= zt+0pg;|ZwiZKvvn%Nl-BoEgn6&($4UvN;$sCO4+w{! zL%C;rLT7BXm^6@W11?GvKNLl|6EDfMwu?1|#-(*FlHzSi!tnuo2bPO6GKpMmZQoA- zx}g~atAIrpSVu=I!0;AaWWrS!U9vJYGI%bnS8QqU>fQFkah;OH(jwJoWZO{g?y? z7~WJ$KLu0hdHAg!v%Kv$mTvX>q-(5CV6+27%ZKV5UmDA1W6-OJf)@|OHND)T=@}go z7Pc=orPAdmsr3wucG-Ub<_h--?)En8>FPN_Q3GKO9ZLxwiLu+b(a_`8(x-a|li<|z zip8c7O;1vx|7&In$Et30YHbTvt26a7A9wIo1~ekq>?J#CAM?F;Zr)yMQ>k>U9?c&t zc+D-H(caQ0t=5U=0yi_(c#y~UdPPRv8GgrxLmRAR*GzrrayvX))DNqDHX3tv)*T}f z7XH{elbsHAP|Q8)y*^=+c}ZxZ;zTslMcHBsGm|nfH^&HH3q2B;3MspED+gRa^T|KlKY^HCF(RQVo{-jfDO%w4feP+t)k1`8+ zEz!!qH)8Ati0^-)%ig*bGQPDH=iT)(;I5+CZ{ngpRB3>ap$~0jRbwIAJ|)5{3Hf{B zMqte0zUVk5aX|h7z#9uM^6b8Wq+5C-xiOe$LTx5Nf>AYV04Ata-{jKTEY-xmgoWIe zN7XDlY!2djM{vw*V?W^1{Dc@oiZS`JJsJ?}~X*GI>8 zwq{tgUng~Vc3vaFg|>Wn1&zH{_e}lvM32rrBK~B|{po7#P+Si4&WA@7?|~`z)uWUw zLi4x215)%*hh0#sd%oo!+QmOy^6>mudbp3C?jhc@uiow|yt$~4|LY6#hl|$v)BBQ< zm?SkQ0{RP%k3u)VZ7Tax)&oXea9>?Qd7dMR65$r0{&IM+Cx84$9G?ak5!q`UN$ftt zYwp3kef25rc1#*|IF?UA{;PX!2+^!2r+^xU#RKGci@64U7yiqNy|NUeDigkB0&_5y zzXi^MTPQYX<-MYo}ElSG#Am0clHk- zNdp*R2KP#z3T1F1@t&6zDIDpG8qdyZxGgA)xOMIE&(N63k3Akqa=T)Gp97M~7jX*9 zmf0){J&-dP5n%vC{L^DUFCJzZ-lMssJ|@_J&G)j>m7#(4nLHD|A^>)|jpFR5eZ#2qPV4Vl zWn%?~ppR6#@}G}z(!HI&&*23Yn??r!%Y#3GakfG!xDGCfe1tSM!JFxV2V9c6-KTk% z>&>eIy%jksX_(-lR!S#=WQK34Z^1Fp3VQz{Rkw*0>NBYlmQ(!6^H@~9j3Cg8v(U7% zFRe8VVnk8x65AdT@KT^dX@lee=(A3E%}bpC>6|@)tK1TiR4;DF-;;7a#$QWWlO^Q# z5Rd`j_)>Fbre!TFERp*zw>8;`;F4Q(PUoYY~#`zaJE)r$~FO6&hk|Fe*~mgwun~j zZ5ALDd08il@(P#=x+w*(nRUDGC^FKI=7K`p1F6qPn#kj)*Pb=lOD{8ODz~7pKL<1& ze6T4_wVn3M8|vu}RMsX*2wLOV`c@t=4KN;lD z1C#X@+J7Xt?hZBxZ%>uC@mfqW;l)|>hLKs-5h}ds_n~V%-78Z7`rJW_*d`FGes{*9 zf1!ILT2&lIg+kyvo9DOwp7KwY4?VDS0!+M)pN0WFye8IHc%AkFqr=~78Mi|LSKD+S ztbm=RCz^brcAqcrD<33Xabb4e0yB|NWOke6F$OsPUxeHz{w18Ngq2Vs#Fz1UvANnF-1r1WzCG5xax<~v1GWe zm&`E0)pt&GYcQ)5HRmaU=zV#_Vd#x(S_YSp*SBtHH)nPpL$GN$3K%D~1&p&n+cB&0 zkxLF8jf2}j4OV1yUm4y>LT?)X_b062eihsn|L+f-!8 b9l{AS*C(6Z!INx}09Gzrw>0Bh*Teq@BRf(| literal 0 HcmV?d00001 From 4bbf2840fdb652440a1ae1f1ef6f94dc1221b9fb Mon Sep 17 00:00:00 2001 From: Michael Wang Date: Thu, 24 Feb 2022 20:23:03 -0800 Subject: [PATCH 25/25] Move vector maths to vec_math.py --- examples/vec_math.py | 55 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) create mode 100644 examples/vec_math.py diff --git a/examples/vec_math.py b/examples/vec_math.py new file mode 100644 index 0000000..f533c90 --- /dev/null +++ b/examples/vec_math.py @@ -0,0 +1,55 @@ +import math + +from numba import cuda, float32, types +from numba.core.extending import overload +from numba_support import float3, make_float3 + +__all__ = ["clamp", "dot", "normalize"] + + +def clamp(x, a, b): + pass + + +@overload(clamp, target="cuda", fastmath=True) +def jit_clamp(x, a, b): + if ( + isinstance(x, types.Float) + and isinstance(a, types.Float) + and isinstance(b, types.Float) + ): + + def clamp_float_impl(x, a, b): + return max(a, min(x, b)) + + return clamp_float_impl + elif ( + isinstance(x, type(float3)) + and isinstance(a, types.Float) + and isinstance(b, types.Float) + ): + + def clamp_float3_impl(x, a, b): + return make_float3(clamp(x.x, a, b), clamp(x.y, a, b), clamp(x.z, a, b)) + + return clamp_float3_impl + + +def dot(a, b): + pass + + +@overload(dot, target="cuda", fastmath=True) +def jit_dot(a, b): + if isinstance(a, type(float3)) and isinstance(b, type(float3)): + + def dot_float3_impl(a, b): + return a.x * b.x + a.y * b.y + a.z * b.z + + return dot_float3_impl + + +@cuda.jit(device=True, fastmath=True) +def normalize(v): + invLen = float32(1.0) / math.sqrt(dot(v, v)) + return v * invLen