Skip to content

Commit e0183ea

Browse files
mvaligurskyMartin Valigursky
andauthored
Half-precision (f16) type aliases for WGSL shaders (playcanvas#8439)
Co-authored-by: Martin Valigursky <mvaligursky@snapchat.com>
1 parent e6855f8 commit e0183ea

6 files changed

Lines changed: 104 additions & 29 deletions

File tree

examples/assets/scripts/misc/hatch-material.mjs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -218,27 +218,27 @@ const createHatchMaterial = (device, textures) => {
218218
fn fragmentMain(input: FragmentInput) -> FragmentOutput
219219
{
220220
var output: FragmentOutput;
221-
var colorLinear: vec3f;
221+
var colorLinear: half3;
222222
223223
#ifdef TOON
224224
225225
// just a simple toon shader - no texture sampling
226-
let level: f32 = f32(i32(input.brightness * uniform.uNumTextures)) / uniform.uNumTextures;
227-
colorLinear = level * uniform.uColor;
226+
let level: half = half(i32(input.brightness * uniform.uNumTextures)) / half(uniform.uNumTextures);
227+
colorLinear = level * half3(uniform.uColor);
228228
229229
#else
230230
// brightness dictates the hatch texture level
231-
let level: f32 = (1.0 - input.brightness) * uniform.uNumTextures;
231+
let level: half = (half(1.0) - half(input.brightness)) * half(uniform.uNumTextures);
232232
233233
// sample the two nearest levels and interpolate between them
234-
let hatchUnder: vec3f = textureSample(uDiffuseMap, uDiffuseMapSampler, input.uv0 * uniform.uDensity, i32(floor(level))).xyz;
235-
let hatchAbove: vec3f = textureSample(uDiffuseMap, uDiffuseMapSampler, input.uv0 * uniform.uDensity, i32(min(ceil(level), uniform.uNumTextures - 1.0))).xyz;
236-
colorLinear = mix(hatchUnder, hatchAbove, fract(level)) * uniform.uColor;
234+
let hatchUnder: half3 = half3(textureSample(uDiffuseMap, uDiffuseMapSampler, input.uv0 * uniform.uDensity, i32(floor(level))).xyz);
235+
let hatchAbove: half3 = half3(textureSample(uDiffuseMap, uDiffuseMapSampler, input.uv0 * uniform.uDensity, i32(min(ceil(level), half(uniform.uNumTextures - 1.0)))).xyz);
236+
colorLinear = mix(hatchUnder, hatchAbove, fract(level)) * half3(uniform.uColor);
237237
#endif
238238
239239
// handle standard color processing - the called functions are automatically attached to the
240240
// shader based on the current fog / tone-mapping / gamma settings
241-
let fogged: vec3f = addFog(colorLinear);
241+
let fogged: vec3f = addFog(vec3f(colorLinear));
242242
let toneMapped: vec3f = toneMap(fogged);
243243
output.color = vec4f(gammaCorrectOutput(toneMapped), 1.0);
244244
Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
// Include half-precision type aliases (resolves to f16 when supported, f32 otherwise)
2+
#include "halfTypesCS"
3+
14
@group(0) @binding(0) var inputTexture: texture_2d<f32>;
25
@group(0) @binding(1) var inputTexture_sampler: sampler;
36
@group(0) @binding(2) var outputTexture: texture_storage_2d<rgba8unorm, write>;
@@ -14,43 +17,43 @@ fn main(@builtin(global_invocation_id) global_id : vec3u) {
1417

1518
// Sample the center pixel
1619
let uvFloat = (vec2f(uv) + vec2f(0.5)) / vec2f(texSize);
17-
var color = textureSampleLevel(inputTexture, inputTexture_sampler, uvFloat, 0.0);
20+
var color = half4(textureSampleLevel(inputTexture, inputTexture_sampler, uvFloat, 0.0));
1821

1922
// Sobel edge detection using 3x3 kernel
2023
let texelSize = 1.0 / vec2f(texSize);
2124

22-
// Sample 3x3 neighborhood and convert to grayscale
23-
var samples: array<f32, 9>;
25+
// Sample 3x3 neighborhood and convert to grayscale (using half precision)
26+
var samples: array<half, 9>;
2427
var idx = 0;
2528
for (var y = -1; y <= 1; y++) {
2629
for (var x = -1; x <= 1; x++) {
2730
let offset = vec2f(f32(x), f32(y)) * texelSize;
2831
let sampleUV = uvFloat + offset;
29-
let sampleColor = textureSampleLevel(inputTexture, inputTexture_sampler, sampleUV, 0.0);
32+
let sampleColor = half3(textureSampleLevel(inputTexture, inputTexture_sampler, sampleUV, 0.0).rgb);
3033
// Convert to grayscale using standard luminance weights
31-
samples[idx] = dot(sampleColor.rgb, vec3f(0.299, 0.587, 0.114));
34+
samples[idx] = dot(sampleColor, half3(0.299, 0.587, 0.114));
3235
idx++;
3336
}
3437
}
3538

3639
// Sobel horizontal and vertical kernels
3740
// Horizontal: [-1, 0, 1; -2, 0, 2; -1, 0, 1]
38-
let gx = -samples[0] + samples[2] - 2.0 * samples[3] + 2.0 * samples[5] - samples[6] + samples[8];
41+
let gx: half = -samples[0] + samples[2] - half(2.0) * samples[3] + half(2.0) * samples[5] - samples[6] + samples[8];
3942

4043
// Vertical: [-1, -2, -1; 0, 0, 0; 1, 2, 1]
41-
let gy = -samples[0] - 2.0 * samples[1] - samples[2] + samples[6] + 2.0 * samples[7] + samples[8];
44+
let gy: half = -samples[0] - half(2.0) * samples[1] - samples[2] + samples[6] + half(2.0) * samples[7] + samples[8];
4245

4346
// Calculate edge magnitude
44-
let edgeStrength = sqrt(gx * gx + gy * gy);
47+
let edgeStrength: half = sqrt(gx * gx + gy * gy);
4548

4649
// Make edges red: stronger edges = more red
47-
let edgeAmount = clamp(edgeStrength * 3.0, 0.0, 1.0);
48-
let edgeColor = vec3f(1.0, 0.0, 0.0); // Red
50+
let edgeAmount: half = clamp(edgeStrength * half(3.0), half(0.0), half(1.0));
51+
let edgeColor = half3(1.0, 0.0, 0.0); // Red
4952

5053
// Blend original color with red edges
51-
var finalColor = mix(color.rgb, edgeColor, edgeAmount);
54+
var finalColor: half3 = mix(color.rgb, edgeColor, edgeAmount);
5255

53-
// Write to output storage texture (no channel swap - keep edges red)
54-
textureStore(outputTexture, vec2i(uv), vec4f(finalColor, color.a));
56+
// Write to output storage texture (convert half back to f32 for storage)
57+
textureStore(outputTexture, vec2i(uv), vec4f(vec3f(finalColor), f32(color.a)));
5558
}
5659

src/platform/graphics/graphics-device.js

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,18 @@ class GraphicsDevice extends EventHandler {
367367
*/
368368
supportsPrimitiveIndex = false;
369369

370+
/**
371+
* True if the device supports 16-bit floating-point types in shaders (WebGPU only). When
372+
* supported, shaders can use native WGSL types: `f16`, `vec2h`, `vec3h`, `vec4h`, `mat2x2h`,
373+
* `mat3x3h`, `mat4x4h`. For convenience, PlayCanvas also provides type aliases (`half`,
374+
* `half2`, `half3`, `half4`, `half2x2`, `half3x3`, `half4x4`) that resolve to f16 types when
375+
* supported, or fall back to f32 types when not supported.
376+
*
377+
* @type {boolean}
378+
* @readonly
379+
*/
380+
supportsShaderF16 = false;
381+
370382
/**
371383
* True if 32-bit floating-point textures can be used as a frame buffer.
372384
*
@@ -598,6 +610,7 @@ class GraphicsDevice extends EventHandler {
598610
if (this.textureFloatRenderable) capsDefines.set('CAPS_TEXTURE_FLOAT_RENDERABLE', '');
599611
if (this.supportsMultiDraw) capsDefines.set('CAPS_MULTI_DRAW', '');
600612
if (this.supportsPrimitiveIndex) capsDefines.set('CAPS_PRIMITIVE_INDEX', '');
613+
if (this.supportsShaderF16) capsDefines.set('CAPS_SHADER_F16', '');
601614

602615
// Platform defines
603616
if (platform.desktop) capsDefines.set('PLATFORM_DESKTOP', '');
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
/**
2+
* WGSL shader chunk providing half-precision type aliases. When the device supports f16
3+
* (CAPS_SHADER_F16), these resolve to native f16 types. Otherwise, they fall back to f32.
4+
*
5+
* Available types: half, half2, half3, half4, half2x2, half3x3, half4x4
6+
*
7+
* Usage in WGSL shaders:
8+
* - Vertex/Fragment: automatically included
9+
* - Compute: #include "halfTypesCS"
10+
*
11+
* @ignore
12+
*/
13+
export default /* wgsl */`
14+
#ifdef CAPS_SHADER_F16
15+
alias half = f16;
16+
alias half2 = vec2<f16>;
17+
alias half3 = vec3<f16>;
18+
alias half4 = vec4<f16>;
19+
alias half2x2 = mat2x2<f16>;
20+
alias half3x3 = mat3x3<f16>;
21+
alias half4x4 = mat4x4<f16>;
22+
#else
23+
alias half = f32;
24+
alias half2 = vec2f;
25+
alias half3 = vec3f;
26+
alias half4 = vec4f;
27+
alias half2x2 = mat2x2f;
28+
alias half3x3 = mat3x3f;
29+
alias half4x4 = mat4x4f;
30+
#endif
31+
`;

src/platform/graphics/shader-definition-utils.js

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import wgslFS from './shader-chunks/frag/webgpu-wgsl.js';
1515
import wgslVS from './shader-chunks/vert/webgpu-wgsl.js';
1616
import sharedGLSL from './shader-chunks/frag/shared.js';
1717
import sharedWGSL from './shader-chunks/frag/shared-wgsl.js';
18+
import halfTypes from './shader-chunks/frag/half-types.js';
1819

1920
/**
2021
* @import { GraphicsDevice } from './graphics-device.js'
@@ -115,12 +116,8 @@ class ShaderDefinitionUtils {
115116

116117
const getDefinesWgsl = (isVertex, options) => {
117118

118-
let code = '';
119-
120119
// Enable directives must come before all global declarations
121-
if (!isVertex && device.supportsPrimitiveIndex) {
122-
code += 'enable primitive_index;\n';
123-
}
120+
let code = ShaderDefinitionUtils.getWGSLEnables(device, isVertex ? 'vertex' : 'fragment');
124121

125122
// Define the fragment shader output type, vec4 by default
126123
if (!isVertex) {
@@ -151,6 +148,7 @@ class ShaderDefinitionUtils {
151148
vertCode = `
152149
${getDefinesWgsl(true, options)}
153150
${vertexDefinesCode}
151+
${halfTypes}
154152
${wgslVS}
155153
${sharedWGSL}
156154
${options.vertexCode}
@@ -159,6 +157,7 @@ class ShaderDefinitionUtils {
159157
fragCode = `
160158
${getDefinesWgsl(false, options)}
161159
${fragmentDefinesCode}
160+
${halfTypes}
162161
${wgslFS}
163162
${sharedWGSL}
164163
${options.fragmentCode}
@@ -205,6 +204,26 @@ class ShaderDefinitionUtils {
205204
};
206205
}
207206

207+
/**
208+
* Generates WGSL enable directives based on device capabilities. Enable directives must come
209+
* before all global declarations in WGSL shaders.
210+
*
211+
* @param {GraphicsDevice} device - The graphics device.
212+
* @param {'vertex'|'fragment'|'compute'} shaderType - The type of shader.
213+
* @returns {string} The WGSL enable directives code.
214+
* @ignore
215+
*/
216+
static getWGSLEnables(device, shaderType) {
217+
let code = '';
218+
if (device.supportsShaderF16) {
219+
code += 'enable f16;\n';
220+
}
221+
if (shaderType === 'fragment' && device.supportsPrimitiveIndex) {
222+
code += 'enable primitive_index;\n';
223+
}
224+
return code;
225+
}
226+
208227
/**
209228
* @param {GraphicsDevice} device - The graphics device.
210229
* @param {Map<string, string>} [defines] - A map containing key-value pairs.

src/platform/graphics/shader.js

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import { Preprocessor } from '../../core/preprocessor.js';
55
import { SHADERLANGUAGE_GLSL, SHADERLANGUAGE_WGSL } from './constants.js';
66
import { DebugGraphics } from './debug-graphics.js';
77
import { ShaderDefinitionUtils } from './shader-definition-utils.js';
8+
import halfTypes from './shader-chunks/frag/half-types.js';
89

910
/**
1011
* @import { BindGroupFormat } from './bind-group-format.js'
@@ -139,12 +140,20 @@ class Shader {
139140
this.cUnmodified = definition.cshader;
140141
});
141142

142-
// Prepend defines to compute shader source
143+
// Prepend enables and defines to compute shader source
144+
const enablesCode = ShaderDefinitionUtils.getWGSLEnables(graphicsDevice, 'compute');
143145
const definesCode = ShaderDefinitionUtils.getDefinesCode(graphicsDevice, definition.cdefines);
144-
const cshader = definesCode + definition.cshader;
146+
147+
const cshader = enablesCode + definesCode + definition.cshader;
148+
149+
// Add built-in halfTypesCS include for compute shaders (if not already provided by user)
150+
const cincludes = definition.cincludes ?? new Map();
151+
if (!cincludes.has('halfTypesCS')) {
152+
cincludes.set('halfTypesCS', halfTypes);
153+
}
145154

146155
// pre-process compute shader source
147-
definition.cshader = Preprocessor.run(cshader, definition.cincludes, {
156+
definition.cshader = Preprocessor.run(cshader, cincludes, {
148157
sourceName: `compute shader for ${this.label}`,
149158
stripDefines: true
150159
});

0 commit comments

Comments
 (0)