From aad292dd787d2cd7e0be8a936a5603e2acad8b15 Mon Sep 17 00:00:00 2001
From: Raph Levien <raph.levien@gmail.com>
Date: Mon, 1 Apr 2024 09:29:37 -0700
Subject: [PATCH] Add robustness to GPU shaders (#537)

* Add robustness to GPU shaders

Make each stage quit early if a previous stage has failed.

The CPU shaders are minimally changed to be layout compatible. For the most part, they'll panic on a bounds check if sizes are exceeded. That's arguably useful for debugging, but a case can be made that they should have the same behavior as the GPU shaders.

Work towards #366

* Address review feedback

Clarify some nits, and also make a distinction between reporting failure in path_count and coarse.
---
 crates/encoding/src/config.rs       |  6 ++++++
 shader/binning.wgsl                 | 14 +++++++++++++-
 shader/coarse.wgsl                  | 19 ++++++++++++++++---
 shader/fine.wgsl                    |  5 +++++
 shader/flatten.wgsl                 |  4 +++-
 shader/path_count.wgsl              | 16 +++++++++++-----
 shader/path_count_setup.wgsl        |  8 ++++++--
 shader/path_tiling_setup.wgsl       | 13 +++++++++++--
 shader/shared/bump.wgsl             |  5 +++--
 shader/shared/config.wgsl           |  2 ++
 shader/tile_alloc.wgsl              |  9 +++++----
 src/cpu_shader/path_count.rs        | 11 ++++++-----
 src/cpu_shader/path_tiling_setup.rs |  1 +
 src/render.rs                       | 11 +++++++++--
 src/shaders.rs                      |  4 ++--
 15 files changed, 99 insertions(+), 29 deletions(-)

diff --git a/crates/encoding/src/config.rs b/crates/encoding/src/config.rs
index a6a40db86..78c9d7f0f 100644
--- a/crates/encoding/src/config.rs
+++ b/crates/encoding/src/config.rs
@@ -137,10 +137,14 @@ pub struct ConfigUniform {
     pub base_color: u32,
     /// Layout of packed scene data.
     pub layout: Layout,
+    /// Size of line soup buffer allocation (in [`LineSoup`]s)
+    pub lines_size: u32,
     /// Size of binning buffer allocation (in `u32`s).
     pub binning_size: u32,
     /// Size of tile buffer allocation (in [`Tile`]s).
     pub tiles_size: u32,
+    /// Size of segment count buffer allocation (in [`SegmentCount`]s).
+    pub seg_counts_size: u32,
     /// Size of segment buffer allocation (in [`PathSegment`]s).
     pub segments_size: u32,
     /// Size of per-tile command list buffer allocation (in `u32`s).
@@ -175,8 +179,10 @@ impl RenderConfig {
                 target_width: width,
                 target_height: height,
                 base_color: base_color.to_premul_u32(),
+                lines_size: buffer_sizes.lines.len(),
                 binning_size: buffer_sizes.bin_data.len() - layout.bin_data_start,
                 tiles_size: buffer_sizes.tiles.len(),
+                seg_counts_size: buffer_sizes.seg_counts.len(),
                 segments_size: buffer_sizes.segments.len(),
                 ptcl_size: buffer_sizes.ptcl.len(),
                 layout: *layout,
diff --git a/shader/binning.wgsl b/shader/binning.wgsl
index daeb6f095..55d80c350 100644
--- a/shader/binning.wgsl
+++ b/shader/binning.wgsl
@@ -53,6 +53,7 @@ var<workgroup> sh_bitmaps: array<array<atomic<u32>, N_TILE>, N_SLICE>;
 // store count values packed two u16's to a u32
 var<workgroup> sh_count: array<array<u32, N_TILE>, N_SUBSLICE>;
 var<workgroup> sh_chunk_offset: array<u32, N_TILE>;
+var<workgroup> sh_previous_failed: u32;
 
 @compute @workgroup_size(256)
 fn main(
@@ -63,7 +64,18 @@ fn main(
     for (var i = 0u; i < N_SLICE; i += 1u) {
         atomicStore(&sh_bitmaps[i][local_id.x], 0u);
     }
-    workgroupBarrier();
+    if local_id.x == 0u {
+        let failed = bump.lines > config.lines_size;
+        sh_previous_failed = u32(failed);
+    }
+    // also functions as barrier to protect zeroing of bitmaps
+    let failed = workgroupUniformLoad(&sh_previous_failed);
+    if failed != 0u {
+        if global_id.x == 0u {
+            bump.failed |= STAGE_FLATTEN;
+        }
+        return;
+    }
 
     // Read inputs and determine coverage of bins
     let element_ix = global_id.x;
diff --git a/shader/coarse.wgsl b/shader/coarse.wgsl
index 758244c84..c28f8d28d 100644
--- a/shader/coarse.wgsl
+++ b/shader/coarse.wgsl
@@ -73,6 +73,9 @@ fn alloc_cmd(size: u32) {
         let ptcl_dyn_start = config.width_in_tiles * config.height_in_tiles * PTCL_INITIAL_ALLOC;
         var new_cmd = ptcl_dyn_start + atomicAdd(&bump.ptcl, PTCL_INCREMENT);
         if new_cmd + PTCL_INCREMENT > config.ptcl_size {
+            // This sets us up for technical UB, as lots of threads will be writing
+            // to the same locations. But I think it's fine, and predicating the
+            // writes would probably slow things down.
             new_cmd = 0u;
             atomicOr(&bump.failed, STAGE_COARSE);
         }
@@ -152,11 +155,19 @@ fn main(
     // We need to check only prior stages, as if this stage has failed in another workgroup, 
     // we still want to know this workgroup's memory requirement.   
     if local_id.x == 0u {
+        var failed = atomicLoad(&bump.failed) & (STAGE_BINNING | STAGE_TILE_ALLOC | STAGE_FLATTEN);
+        if atomicLoad(&bump.seg_counts) > config.seg_counts_size {
+            failed |= STAGE_PATH_COUNT;
+        }
         // Reuse sh_part_count to hold failed flag, shmem is tight
-        sh_part_count[0] = atomicLoad(&bump.failed);
+        sh_part_count[0] = u32(failed);
     }
     let failed = workgroupUniformLoad(&sh_part_count[0]);
-    if (failed & (STAGE_BINNING | STAGE_TILE_ALLOC | STAGE_PATH_COARSE)) != 0u {
+    if failed != 0u {
+        if wg_id.x == 0u && local_id.x == 0u {
+            // propagate PATH_COUNT failure to path_tiling_setup so it doesn't need to bind config
+            atomicOr(&bump.failed, failed);
+        }
         return;
     }
     let width_in_bins = (config.width_in_tiles + N_TILE_X - 1u) / N_TILE_X;
@@ -431,9 +442,11 @@ fn main(
     }
     if bin_tile_x + tile_x < config.width_in_tiles && bin_tile_y + tile_y < config.height_in_tiles {
         ptcl[cmd_offset] = CMD_END;
+        var blend_ix = 0u;
         if max_blend_depth > BLEND_STACK_SPLIT {
             let scratch_size = max_blend_depth * TILE_WIDTH * TILE_HEIGHT;
-            ptcl[blend_offset] = atomicAdd(&bump.blend, scratch_size);
+            blend_ix = atomicAdd(&bump.blend, scratch_size);
         }
+        ptcl[blend_offset] = blend_ix;
     }
 }
diff --git a/shader/fine.wgsl b/shader/fine.wgsl
index 539ccbe75..64621cb8b 100644
--- a/shader/fine.wgsl
+++ b/shader/fine.wgsl
@@ -867,6 +867,11 @@ fn main(
     @builtin(local_invocation_id) local_id: vec3<u32>,
     @builtin(workgroup_id) wg_id: vec3<u32>,
 ) {
+    if ptcl[0] == ~0u {
+        // An earlier stage has failed, don't try to render.
+        // We use ptcl[0] for this so we don't use up a binding for bump.
+        return;
+    }
     let tile_ix = wg_id.y * config.width_in_tiles + wg_id.x;
     let xy = vec2(f32(global_id.x * PIXELS_PER_THREAD), f32(global_id.y));
     let local_xy = vec2(f32(local_id.x * PIXELS_PER_THREAD), f32(local_id.y));
diff --git a/shader/flatten.wgsl b/shader/flatten.wgsl
index 3f1a855e6..80da1880c 100644
--- a/shader/flatten.wgsl
+++ b/shader/flatten.wgsl
@@ -746,7 +746,9 @@ fn read_path_segment(tag: PathTagData, is_stroke: bool) -> CubicPoints {
 // Writes a line into a the `lines` buffer at a pre-allocated location designated by `line_ix`.
 fn write_line(line_ix: u32, path_ix: u32, p0: vec2f, p1: vec2f) {
     bbox = vec4(min(bbox.xy, min(p0, p1)), max(bbox.zw, max(p0, p1)));
-    lines[line_ix] = LineSoup(path_ix, p0, p1);
+    if line_ix < config.lines_size {
+        lines[line_ix] = LineSoup(path_ix, p0, p1);
+    }
 }
 
 fn write_line_with_transform(line_ix: u32, path_ix: u32, p0: vec2f, p1: vec2f, t: Transform) {
diff --git a/shader/path_count.wgsl b/shader/path_count.wgsl
index 34913eb5a..7de89278d 100644
--- a/shader/path_count.wgsl
+++ b/shader/path_count.wgsl
@@ -15,18 +15,21 @@ struct AtomicTile {
 }
 
 @group(0) @binding(0)
-var<storage, read_write> bump: BumpAllocators;
+var<uniform> config: Config;
 
 @group(0) @binding(1)
-var<storage> lines: array<LineSoup>;
+var<storage, read_write> bump: BumpAllocators;
 
 @group(0) @binding(2)
-var<storage> paths: array<Path>;
+var<storage> lines: array<LineSoup>;
 
 @group(0) @binding(3)
-var<storage, read_write> tile: array<AtomicTile>;
+var<storage> paths: array<Path>;
 
 @group(0) @binding(4)
+var<storage, read_write> tile: array<AtomicTile>;
+
+@group(0) @binding(5)
 var<storage, read_write> seg_counts: array<SegmentCount>;
 
 // number of integer cells spanned by interval defined by a, b
@@ -187,7 +190,10 @@ fn main(
             // Pack two count values into a single u32
             let counts = (seg_within_slice << 16u) | subix;
             let seg_count = SegmentCount(line_ix, counts);
-            seg_counts[seg_base + i - imin] = seg_count;
+            let seg_ix = seg_base + i - imin;
+            if seg_ix < config.seg_counts_size {
+                seg_counts[seg_ix] = seg_count;
+            }
             // Note: since we're iterating, we have a reliable value for
             // last_z.
             last_z = z;
diff --git a/shader/path_count_setup.wgsl b/shader/path_count_setup.wgsl
index 92793590e..a9c4a916a 100644
--- a/shader/path_count_setup.wgsl
+++ b/shader/path_count_setup.wgsl
@@ -16,8 +16,12 @@ let WG_SIZE = 256u;
 
 @compute @workgroup_size(1)
 fn main() {
-    let lines = atomicLoad(&bump.lines);
-    indirect.count_x = (lines + (WG_SIZE - 1u)) / WG_SIZE;
+    if atomicLoad(&bump.failed) != 0u {
+        indirect.count_x = 0u;
+    } else {
+        let lines = atomicLoad(&bump.lines);
+        indirect.count_x = (lines + (WG_SIZE - 1u)) / WG_SIZE;
+    }
     indirect.count_y = 1u;
     indirect.count_z = 1u;
 }
diff --git a/shader/path_tiling_setup.wgsl b/shader/path_tiling_setup.wgsl
index 6fc70c39d..4d5bf2e30 100644
--- a/shader/path_tiling_setup.wgsl
+++ b/shader/path_tiling_setup.wgsl
@@ -11,13 +11,22 @@ var<storage, read_write> bump: BumpAllocators;
 @group(0) @binding(1)
 var<storage, read_write> indirect: IndirectCount;
 
+@group(0) @binding(2)
+var<storage, read_write> ptcl: array<u32>;
+
 // Partition size for path tiling stage
 let WG_SIZE = 256u;
 
 @compute @workgroup_size(1)
 fn main() {
-    let segments = atomicLoad(&bump.seg_counts);
-    indirect.count_x = (segments + (WG_SIZE - 1u)) / WG_SIZE;
+    if atomicLoad(&bump.failed) != 0u {
+        indirect.count_x = 0u;
+        // signal fine rasterizer that failure happened (it doesn't bind bump)
+        ptcl[0] = ~0u;
+    } else {
+        let segments = atomicLoad(&bump.seg_counts);
+        indirect.count_x = (segments + (WG_SIZE - 1u)) / WG_SIZE;
+    }
     indirect.count_y = 1u;
     indirect.count_z = 1u;
 }
diff --git a/shader/shared/bump.wgsl b/shader/shared/bump.wgsl
index 48338f59f..9270fc2f8 100644
--- a/shader/shared/bump.wgsl
+++ b/shader/shared/bump.wgsl
@@ -4,8 +4,9 @@
 // Bitflags for each stage that can fail allocation.
 let STAGE_BINNING: u32 = 0x1u;
 let STAGE_TILE_ALLOC: u32 = 0x2u;
-let STAGE_PATH_COARSE: u32 = 0x4u;
-let STAGE_COARSE: u32 = 0x8u;
+let STAGE_FLATTEN: u32 = 0x4u;
+let STAGE_PATH_COUNT: u32 = 0x8u;
+let STAGE_COARSE: u32 = 0x10u;
 
 // This must be kept in sync with the struct in config.rs in the encoding crate.
 struct BumpAllocators {
diff --git a/shader/shared/config.wgsl b/shader/shared/config.wgsl
index fe8580fd2..ef7b928c4 100644
--- a/shader/shared/config.wgsl
+++ b/shader/shared/config.wgsl
@@ -33,8 +33,10 @@ struct Config {
     style_base: u32,
 
     // Sizes of bump allocated buffers (in element size units)
+    lines_size: u32,
     binning_size: u32,
     tiles_size: u32,
+    seg_counts_size: u32,
     segments_size: u32,
     ptcl_size: u32,
 }
diff --git a/shader/tile_alloc.wgsl b/shader/tile_alloc.wgsl
index 7b05bdceb..c6073d128 100644
--- a/shader/tile_alloc.wgsl
+++ b/shader/tile_alloc.wgsl
@@ -30,7 +30,7 @@ let WG_SIZE = 256u;
 
 var<workgroup> sh_tile_count: array<u32, WG_SIZE>;
 var<workgroup> sh_tile_offset: u32;
-var<workgroup> sh_atomic_failed: u32;
+var<workgroup> sh_previous_failed: u32;
 
 @compute @workgroup_size(256)
 fn main(
@@ -41,10 +41,11 @@ fn main(
     // We need to check only prior stages, as if this stage has failed in another workgroup, 
     // we still want to know this workgroup's memory requirement.
     if local_id.x == 0u {
-        sh_atomic_failed = atomicLoad(&bump.failed);
+        let failed = (atomicLoad(&bump.failed) & (STAGE_BINNING | STAGE_FLATTEN)) != 0u;
+        sh_previous_failed = u32(failed);
     }
-    let failed = workgroupUniformLoad(&sh_atomic_failed);
-    if (failed & STAGE_BINNING) != 0u {
+    let failed = workgroupUniformLoad(&sh_previous_failed);
+    if failed != 0u {
         return;
     }    
     // scale factors useful for converting coordinates to tiles
diff --git a/src/cpu_shader/path_count.rs b/src/cpu_shader/path_count.rs
index ed04d7a1f..80eafc9ab 100644
--- a/src/cpu_shader/path_count.rs
+++ b/src/cpu_shader/path_count.rs
@@ -153,10 +153,11 @@ fn path_count_main(
 }
 
 pub fn path_count(_n_wg: u32, resources: &[CpuBinding]) {
-    let mut bump = resources[0].as_typed_mut();
-    let lines = resources[1].as_slice();
-    let paths = resources[2].as_slice();
-    let mut tile = resources[3].as_slice_mut();
-    let mut seg_counts = resources[4].as_slice_mut();
+    // config is binding 0
+    let mut bump = resources[1].as_typed_mut();
+    let lines = resources[2].as_slice();
+    let paths = resources[3].as_slice();
+    let mut tile = resources[4].as_slice_mut();
+    let mut seg_counts = resources[5].as_slice_mut();
     path_count_main(&mut bump, &lines, &paths, &mut tile, &mut seg_counts);
 }
diff --git a/src/cpu_shader/path_tiling_setup.rs b/src/cpu_shader/path_tiling_setup.rs
index 33ca6bb8d..2b9014505 100644
--- a/src/cpu_shader/path_tiling_setup.rs
+++ b/src/cpu_shader/path_tiling_setup.rs
@@ -17,5 +17,6 @@ fn path_tiling_setup_main(bump: &BumpAllocators, indirect: &mut IndirectCount) {
 pub fn path_tiling_setup(_n_wg: u32, resources: &[CpuBinding]) {
     let bump = resources[0].as_typed();
     let mut indirect = resources[1].as_typed_mut();
+    // binding 2 is ptcl, which we would need if we propagate failure
     path_tiling_setup_main(&bump, &mut indirect);
 }
diff --git a/src/render.rs b/src/render.rs
index 264b241d0..279f1451a 100644
--- a/src/render.rs
+++ b/src/render.rs
@@ -345,7 +345,14 @@ impl Render {
             shaders.path_count,
             indirect_count_buf,
             0,
-            [bump_buf, lines_buf, path_buf, tile_buf, seg_counts_buf],
+            [
+                config_buf,
+                bump_buf,
+                lines_buf,
+                path_buf,
+                tile_buf,
+                seg_counts_buf,
+            ],
         );
         recording.dispatch(
             shaders.backdrop,
@@ -370,7 +377,7 @@ impl Render {
         recording.dispatch(
             shaders.path_tiling_setup,
             wg_counts.path_tiling_setup,
-            [bump_buf, indirect_count_buf.into()],
+            [bump_buf, indirect_count_buf.into(), ptcl_buf],
         );
         recording.dispatch_indirect(
             shaders.path_tiling,
diff --git a/src/shaders.rs b/src/shaders.rs
index fab6b7f1e..e00c1a96a 100644
--- a/src/shaders.rs
+++ b/src/shaders.rs
@@ -222,7 +222,7 @@ pub fn full_shaders(
     let path_count_setup = add_shader!(path_count_setup, [Buffer, Buffer], &empty);
     let path_count = add_shader!(
         path_count,
-        [Buffer, BufReadOnly, BufReadOnly, Buffer, Buffer]
+        [Uniform, Buffer, BufReadOnly, BufReadOnly, Buffer, Buffer]
     );
     let backdrop = add_shader!(
         backdrop_dyn,
@@ -245,7 +245,7 @@ pub fn full_shaders(
         ],
         &empty
     );
-    let path_tiling_setup = add_shader!(path_tiling_setup, [Buffer, Buffer], &empty);
+    let path_tiling_setup = add_shader!(path_tiling_setup, [Buffer, Buffer, Buffer], &empty);
     let path_tiling = add_shader!(
         path_tiling,
         [