From daaa350e10390982a32194038609512e7c5b7940 Mon Sep 17 00:00:00 2001 From: Raph Levien Date: Tue, 19 Mar 2024 10:15:55 -0700 Subject: [PATCH] Allow large numbers of draw objects Previously the number of draw objects was limited to the workgroup size squared, which is 64k in practice (65,536 for a workgroup size of 256). This PR makes each workgroup iterate over multiple blocks when that limit is exceeded, borrowing a technique from FidelityFX sort. WIP: this causes hangs on Mac. Uploading to test on other hardware. Also contains some test-only changes that should probably not be committed as is. Fixes #334 --- crates/encoding/src/config.rs | 12 +- examples/scenes/src/test_scenes.rs | 20 +- shader/draw_leaf.wgsl | 346 +++++++++++++++-------------- shader/draw_reduce.wgsl | 21 +- src/cpu_shader/draw_leaf.rs | 13 +- src/cpu_shader/draw_reduce.rs | 13 +- 6 files changed, 238 insertions(+), 187 deletions(-) diff --git a/crates/encoding/src/config.rs b/crates/encoding/src/config.rs index b6ef8857a..a6a40db86 100644 --- a/crates/encoding/src/config.rs +++ b/crates/encoding/src/config.rs @@ -234,6 +234,7 @@ impl WorkgroupCounts { path_tag_wgs }; let draw_object_wgs = (n_draw_objects + PATH_BBOX_WG - 1) / PATH_BBOX_WG; + let draw_monoid_wgs = draw_object_wgs.min(PATH_BBOX_WG); let flatten_wgs = (n_path_tags + FLATTEN_WG - 1) / FLATTEN_WG; let clip_reduce_wgs = n_clips.saturating_sub(1) / CLIP_REDUCE_WG; let clip_wgs = (n_clips + CLIP_REDUCE_WG - 1) / CLIP_REDUCE_WG; @@ -248,8 +249,8 @@ impl WorkgroupCounts { path_scan: (path_tag_wgs, 1, 1), bbox_clear: (draw_object_wgs, 1, 1), flatten: (flatten_wgs, 1, 1), - draw_reduce: (draw_object_wgs, 1, 1), - draw_leaf: (draw_object_wgs, 1, 1), + draw_reduce: (draw_monoid_wgs, 1, 1), + draw_leaf: (draw_monoid_wgs, 1, 1), clip_reduce: (clip_reduce_wgs, 1, 1), clip_leaf: (clip_wgs, 1, 1), binning: (draw_object_wgs, 1, 1), @@ -364,8 +365,9 @@ impl BufferSizes { let path_reduced_scan = BufferSize::new(path_tag_wgs); let path_monoids = BufferSize::new(path_tag_wgs * 
PATH_REDUCE_WG); let path_bboxes = BufferSize::new(n_paths); - let draw_object_wgs = workgroups.draw_reduce.0; - let draw_reduced = BufferSize::new(draw_object_wgs); + let binning_wgs = workgroups.binning.0; + let draw_monoid_wgs = workgroups.draw_reduce.0; + let draw_reduced = BufferSize::new(draw_monoid_wgs); let draw_monoids = BufferSize::new(n_draw_objects); let info = BufferSize::new(layout.bin_data_start); let clip_inps = BufferSize::new(n_clips); @@ -375,7 +377,7 @@ impl BufferSizes { let draw_bboxes = BufferSize::new(n_paths); let bump_alloc = BufferSize::new(1); let indirect_count = BufferSize::new(1); - let bin_headers = BufferSize::new(draw_object_wgs * 256); + let bin_headers = BufferSize::new(binning_wgs * 256); let n_paths_aligned = align_up(n_paths, 256); let paths = BufferSize::new(n_paths_aligned); diff --git a/examples/scenes/src/test_scenes.rs b/examples/scenes/src/test_scenes.rs index 02c31ac35..fa6f71712 100644 --- a/examples/scenes/src/test_scenes.rs +++ b/examples/scenes/src/test_scenes.rs @@ -2,7 +2,9 @@ // SPDX-License-Identifier: Apache-2.0 OR MIT use crate::{ExampleScene, SceneConfig, SceneParams, SceneSet}; -use vello::kurbo::{Affine, BezPath, Cap, Ellipse, Join, PathEl, Point, Rect, Shape, Stroke, Vec2}; +use vello::kurbo::{ + Affine, BezPath, Cap, Circle, Ellipse, Join, PathEl, Point, Rect, Shape, Stroke, Vec2, +}; use vello::peniko::*; use vello::*; @@ -31,6 +33,7 @@ macro_rules! 
scene { pub fn test_scenes() -> SceneSet { let scenes = vec![ + scene!(many_draw), scene!(splash_with_tiger(), "splash_with_tiger", false), scene!(funky_paths), scene!(stroke_styles(Affine::IDENTITY), "stroke_styles", false), @@ -67,6 +70,21 @@ pub fn test_scenes() -> SceneSet { // Scenes +fn many_draw(scene: &mut Scene, _: &mut SceneParams) { + const N_WIDE: usize = 300; + const N_HIGH: usize = 300; + const SCENE_WIDTH: f64 = 2000.0; + const SCENE_HEIGHT: f64 = 1500.0; + for j in 0..N_HIGH { + let y = (j as f64 + 0.5) * (SCENE_HEIGHT / N_HIGH as f64); + for i in 0..N_WIDE { + let x = (i as f64 + 0.5) * (SCENE_WIDTH / N_WIDE as f64); + let c = Circle::new((x, y), 3.0); + scene.fill(Fill::NonZero, Affine::IDENTITY, Color::YELLOW, None, &c); + } + } +} + fn funky_paths(scene: &mut Scene, _: &mut SceneParams) { use PathEl::*; let missing_movetos = [ diff --git a/shader/draw_leaf.wgsl b/shader/draw_leaf.wgsl index 6f1a2e6c2..e5126615d 100644 --- a/shader/draw_leaf.wgsl +++ b/shader/draw_leaf.wgsl @@ -51,11 +51,9 @@ var sh_scratch: array; @compute @workgroup_size(256) fn main( - @builtin(global_invocation_id) global_id: vec3, @builtin(local_invocation_id) local_id: vec3, @builtin(workgroup_id) wg_id: vec3, ) { - let ix = global_id.x; // Reduce prefix of workgroups up to this one var agg = draw_monoid_identity(); if local_id.x < wg_id.x { @@ -74,184 +72,198 @@ fn main( // Two barriers can be eliminated if we use separate shared arrays // for prefix and intra-workgroup prefix sum. 
workgroupBarrier(); - var m = sh_scratch[0]; - workgroupBarrier(); - let tag_word = read_draw_tag_from_scene(ix); - agg = map_draw_tag(tag_word); - sh_scratch[local_id.x] = agg; - for (var i = 0u; i < firstTrailingBit(WG_SIZE); i += 1u) { + var prefix = sh_scratch[0]; + + let num_blocks_total = (config.n_drawobj + (WG_SIZE - 1u)) / WG_SIZE; + let n_blocks_base = num_blocks_total / WG_SIZE; + let remainder = num_blocks_total % WG_SIZE; + let first_block = n_blocks_base * wg_id.x + min(wg_id.x, remainder); + let n_blocks = n_blocks_base + u32(wg_id.x < remainder); + var ix = first_block * WG_SIZE + local_id.x; + let ix_end = ix + n_blocks * WG_SIZE; + while ix != ix_end { + let tag_word = read_draw_tag_from_scene(ix); + agg = map_draw_tag(tag_word); workgroupBarrier(); - if local_id.x >= 1u << i { - let other = sh_scratch[local_id.x - (1u << i)]; - agg = combine_draw_monoid(agg, other); + sh_scratch[local_id.x] = agg; + for (var i = 0u; i < firstTrailingBit(WG_SIZE); i += 1u) { + workgroupBarrier(); + if local_id.x >= 1u << i { + let other = sh_scratch[local_id.x - (1u << i)]; + agg = combine_draw_monoid(agg, other); + } + workgroupBarrier(); + sh_scratch[local_id.x] = agg; } + var m = prefix; workgroupBarrier(); - sh_scratch[local_id.x] = agg; - } - workgroupBarrier(); - if local_id.x > 0u { - m = combine_draw_monoid(m, sh_scratch[local_id.x - 1u]); - } - // m now contains exclusive prefix sum of draw monoid - if ix < config.n_drawobj { - draw_monoid[ix] = m; - } - let dd = config.drawdata_base + m.scene_offset; - let di = m.info_offset; - if tag_word == DRAWTAG_FILL_COLOR || tag_word == DRAWTAG_FILL_LIN_GRADIENT || - tag_word == DRAWTAG_FILL_RAD_GRADIENT || tag_word == DRAWTAG_FILL_SWEEP_GRADIENT || - tag_word == DRAWTAG_FILL_IMAGE || tag_word == DRAWTAG_BEGIN_CLIP - { - let bbox = path_bbox[m.path_ix]; - // TODO: bbox is mostly yagni here, sort that out. Maybe clips? 
- // let x0 = f32(bbox.x0); - // let y0 = f32(bbox.y0); - // let x1 = f32(bbox.x1); - // let y1 = f32(bbox.y1); - // let bbox_f = vec4(x0, y0, x1, y1); - var transform = Transform(); - let draw_flags = bbox.draw_flags; - if tag_word == DRAWTAG_FILL_LIN_GRADIENT || tag_word == DRAWTAG_FILL_RAD_GRADIENT || - tag_word == DRAWTAG_FILL_SWEEP_GRADIENT || tag_word == DRAWTAG_FILL_IMAGE - { - transform = read_transform(config.transform_base, bbox.trans_ix); + if local_id.x > 0u { + m = combine_draw_monoid(m, sh_scratch[local_id.x - 1u]); } - switch tag_word { - case DRAWTAG_FILL_COLOR: { - info[di] = draw_flags; - } - case DRAWTAG_FILL_LIN_GRADIENT: { - info[di] = draw_flags; - var p0 = bitcast>(vec2(scene[dd + 1u], scene[dd + 2u])); - var p1 = bitcast>(vec2(scene[dd + 3u], scene[dd + 4u])); - p0 = transform_apply(transform, p0); - p1 = transform_apply(transform, p1); - let dxy = p1 - p0; - let scale = 1.0 / dot(dxy, dxy); - let line_xy = dxy * scale; - let line_c = -dot(p0, line_xy); - info[di + 1u] = bitcast(line_xy.x); - info[di + 2u] = bitcast(line_xy.y); - info[di + 3u] = bitcast(line_c); + // m now contains exclusive prefix sum of draw monoid + if ix < config.n_drawobj { + draw_monoid[ix] = m; + } + let dd = config.drawdata_base + m.scene_offset; + let di = m.info_offset; + if tag_word == DRAWTAG_FILL_COLOR || tag_word == DRAWTAG_FILL_LIN_GRADIENT || + tag_word == DRAWTAG_FILL_RAD_GRADIENT || tag_word == DRAWTAG_FILL_SWEEP_GRADIENT || + tag_word == DRAWTAG_FILL_IMAGE || tag_word == DRAWTAG_BEGIN_CLIP + { + let bbox = path_bbox[m.path_ix]; + // TODO: bbox is mostly yagni here, sort that out. Maybe clips? 
+ // let x0 = f32(bbox.x0); + // let y0 = f32(bbox.y0); + // let x1 = f32(bbox.x1); + // let y1 = f32(bbox.y1); + // let bbox_f = vec4(x0, y0, x1, y1); + var transform = Transform(); + let draw_flags = bbox.draw_flags; + if tag_word == DRAWTAG_FILL_LIN_GRADIENT || tag_word == DRAWTAG_FILL_RAD_GRADIENT || + tag_word == DRAWTAG_FILL_SWEEP_GRADIENT || tag_word == DRAWTAG_FILL_IMAGE + { + transform = read_transform(config.transform_base, bbox.trans_ix); } - case DRAWTAG_FILL_RAD_GRADIENT: { - // Two-point conical gradient implementation based - // on the algorithm at - // This epsilon matches what Skia uses - let GRADIENT_EPSILON = 1.0 / f32(1u << 12u); - info[di] = draw_flags; - var p0 = bitcast>(vec2(scene[dd + 1u], scene[dd + 2u])); - var p1 = bitcast>(vec2(scene[dd + 3u], scene[dd + 4u])); - var r0 = bitcast(scene[dd + 5u]); - var r1 = bitcast(scene[dd + 6u]); - let user_to_gradient = transform_inverse(transform); - // Output variables - var xform = Transform(); - var focal_x = 0.0; - var radius = 0.0; - var kind = 0u; - var flags = 0u; - if abs(r0 - r1) <= GRADIENT_EPSILON { - // When the radii are the same, emit a strip gradient - kind = RAD_GRAD_KIND_STRIP; - let scaled = r0 / distance(p0, p1); - xform = transform_mul( - two_point_to_unit_line(p0, p1), - user_to_gradient - ); - radius = scaled * scaled; - } else { - // Assume a two point conical gradient unless the centers - // are equal. - kind = RAD_GRAD_KIND_CONE; - if all(p0 == p1) { - kind = RAD_GRAD_KIND_CIRCULAR; - // Nudge p0 a bit to avoid denormals. 
- p0 += GRADIENT_EPSILON; - } - if r1 == 0.0 { - // If r1 == 0.0, swap the points and radii - flags |= RAD_GRAD_SWAPPED; - let tmp_p = p0; - p0 = p1; - p1 = tmp_p; - let tmp_r = r0; - r0 = r1; - r1 = tmp_r; - } - focal_x = r0 / (r0 - r1); - let cf = (1.0 - focal_x) * p0 + focal_x * p1; - radius = r1 / (distance(cf, p1)); - let user_to_unit_line = transform_mul( - two_point_to_unit_line(cf, p1), - user_to_gradient - ); - var user_to_scaled = user_to_unit_line; - // When r == 1.0, focal point is on circle - if abs(radius - 1.0) <= GRADIENT_EPSILON { - kind = RAD_GRAD_KIND_FOCAL_ON_CIRCLE; - let scale = 0.5 * abs(1.0 - focal_x); - user_to_scaled = transform_mul( - Transform(vec4(scale, 0.0, 0.0, scale), vec2(0.0)), - user_to_unit_line + switch tag_word { + case DRAWTAG_FILL_COLOR: { + info[di] = draw_flags; + } + case DRAWTAG_FILL_LIN_GRADIENT: { + info[di] = draw_flags; + var p0 = bitcast>(vec2(scene[dd + 1u], scene[dd + 2u])); + var p1 = bitcast>(vec2(scene[dd + 3u], scene[dd + 4u])); + p0 = transform_apply(transform, p0); + p1 = transform_apply(transform, p1); + let dxy = p1 - p0; + let scale = 1.0 / dot(dxy, dxy); + let line_xy = dxy * scale; + let line_c = -dot(p0, line_xy); + info[di + 1u] = bitcast(line_xy.x); + info[di + 2u] = bitcast(line_xy.y); + info[di + 3u] = bitcast(line_c); + } + case DRAWTAG_FILL_RAD_GRADIENT: { + // Two-point conical gradient implementation based + // on the algorithm at + // This epsilon matches what Skia uses + let GRADIENT_EPSILON = 1.0 / f32(1u << 12u); + info[di] = draw_flags; + var p0 = bitcast>(vec2(scene[dd + 1u], scene[dd + 2u])); + var p1 = bitcast>(vec2(scene[dd + 3u], scene[dd + 4u])); + var r0 = bitcast(scene[dd + 5u]); + var r1 = bitcast(scene[dd + 6u]); + let user_to_gradient = transform_inverse(transform); + // Output variables + var xform = Transform(); + var focal_x = 0.0; + var radius = 0.0; + var kind = 0u; + var flags = 0u; + if abs(r0 - r1) <= GRADIENT_EPSILON { + // When the radii are the same, emit a strip 
gradient + kind = RAD_GRAD_KIND_STRIP; + let scaled = r0 / distance(p0, p1); + xform = transform_mul( + two_point_to_unit_line(p0, p1), + user_to_gradient ); + radius = scaled * scaled; } else { - let a = radius * radius - 1.0; - let scale_ratio = abs(1.0 - focal_x) / a; - let scale_x = radius * scale_ratio; - let scale_y = sqrt(abs(a)) * scale_ratio; - user_to_scaled = transform_mul( - Transform(vec4(scale_x, 0.0, 0.0, scale_y), vec2(0.0)), - user_to_unit_line + // Assume a two point conical gradient unless the centers + // are equal. + kind = RAD_GRAD_KIND_CONE; + if all(p0 == p1) { + kind = RAD_GRAD_KIND_CIRCULAR; + // Nudge p0 a bit to avoid denormals. + p0 += GRADIENT_EPSILON; + } + if r1 == 0.0 { + // If r1 == 0.0, swap the points and radii + flags |= RAD_GRAD_SWAPPED; + let tmp_p = p0; + p0 = p1; + p1 = tmp_p; + let tmp_r = r0; + r0 = r1; + r1 = tmp_r; + } + focal_x = r0 / (r0 - r1); + let cf = (1.0 - focal_x) * p0 + focal_x * p1; + radius = r1 / (distance(cf, p1)); + let user_to_unit_line = transform_mul( + two_point_to_unit_line(cf, p1), + user_to_gradient ); + var user_to_scaled = user_to_unit_line; + // When r == 1.0, focal point is on circle + if abs(radius - 1.0) <= GRADIENT_EPSILON { + kind = RAD_GRAD_KIND_FOCAL_ON_CIRCLE; + let scale = 0.5 * abs(1.0 - focal_x); + user_to_scaled = transform_mul( + Transform(vec4(scale, 0.0, 0.0, scale), vec2(0.0)), + user_to_unit_line + ); + } else { + let a = radius * radius - 1.0; + let scale_ratio = abs(1.0 - focal_x) / a; + let scale_x = radius * scale_ratio; + let scale_y = sqrt(abs(a)) * scale_ratio; + user_to_scaled = transform_mul( + Transform(vec4(scale_x, 0.0, 0.0, scale_y), vec2(0.0)), + user_to_unit_line + ); + } + xform = user_to_scaled; } - xform = user_to_scaled; + info[di + 1u] = bitcast(xform.matrx.x); + info[di + 2u] = bitcast(xform.matrx.y); + info[di + 3u] = bitcast(xform.matrx.z); + info[di + 4u] = bitcast(xform.matrx.w); + info[di + 5u] = bitcast(xform.translate.x); + info[di + 6u] = 
bitcast(xform.translate.y); + info[di + 7u] = bitcast(focal_x); + info[di + 8u] = bitcast(radius); + info[di + 9u] = bitcast((flags << 3u) | kind); } - info[di + 1u] = bitcast(xform.matrx.x); - info[di + 2u] = bitcast(xform.matrx.y); - info[di + 3u] = bitcast(xform.matrx.z); - info[di + 4u] = bitcast(xform.matrx.w); - info[di + 5u] = bitcast(xform.translate.x); - info[di + 6u] = bitcast(xform.translate.y); - info[di + 7u] = bitcast(focal_x); - info[di + 8u] = bitcast(radius); - info[di + 9u] = bitcast((flags << 3u) | kind); - } - case DRAWTAG_FILL_SWEEP_GRADIENT: { - info[di] = draw_flags; - let p0 = bitcast>(vec2(scene[dd + 1u], scene[dd + 2u])); - let xform = transform_mul(transform, Transform(vec4(1.0, 0.0, 0.0, 1.0), p0)); - let inv = transform_inverse(xform); - info[di + 1u] = bitcast(inv.matrx.x); - info[di + 2u] = bitcast(inv.matrx.y); - info[di + 3u] = bitcast(inv.matrx.z); - info[di + 4u] = bitcast(inv.matrx.w); - info[di + 5u] = bitcast(inv.translate.x); - info[di + 6u] = bitcast(inv.translate.y); - info[di + 7u] = scene[dd + 3u]; - info[di + 8u] = scene[dd + 4u]; - } - case DRAWTAG_FILL_IMAGE: { - info[di] = draw_flags; - let inv = transform_inverse(transform); - info[di + 1u] = bitcast(inv.matrx.x); - info[di + 2u] = bitcast(inv.matrx.y); - info[di + 3u] = bitcast(inv.matrx.z); - info[di + 4u] = bitcast(inv.matrx.w); - info[di + 5u] = bitcast(inv.translate.x); - info[di + 6u] = bitcast(inv.translate.y); - info[di + 7u] = scene[dd]; - info[di + 8u] = scene[dd + 1u]; + case DRAWTAG_FILL_SWEEP_GRADIENT: { + info[di] = draw_flags; + let p0 = bitcast>(vec2(scene[dd + 1u], scene[dd + 2u])); + let xform = transform_mul(transform, Transform(vec4(1.0, 0.0, 0.0, 1.0), p0)); + let inv = transform_inverse(xform); + info[di + 1u] = bitcast(inv.matrx.x); + info[di + 2u] = bitcast(inv.matrx.y); + info[di + 3u] = bitcast(inv.matrx.z); + info[di + 4u] = bitcast(inv.matrx.w); + info[di + 5u] = bitcast(inv.translate.x); + info[di + 6u] = bitcast(inv.translate.y); + 
info[di + 7u] = scene[dd + 3u]; + info[di + 8u] = scene[dd + 4u]; + } + case DRAWTAG_FILL_IMAGE: { + info[di] = draw_flags; + let inv = transform_inverse(transform); + info[di + 1u] = bitcast(inv.matrx.x); + info[di + 2u] = bitcast(inv.matrx.y); + info[di + 3u] = bitcast(inv.matrx.z); + info[di + 4u] = bitcast(inv.matrx.w); + info[di + 5u] = bitcast(inv.translate.x); + info[di + 6u] = bitcast(inv.translate.y); + info[di + 7u] = scene[dd]; + info[di + 8u] = scene[dd + 1u]; + } + default: {} } - default: {} } - } - if tag_word == DRAWTAG_BEGIN_CLIP || tag_word == DRAWTAG_END_CLIP { - var path_ix = ~ix; - if tag_word == DRAWTAG_BEGIN_CLIP { - path_ix = m.path_ix; + if tag_word == DRAWTAG_BEGIN_CLIP || tag_word == DRAWTAG_END_CLIP { + var path_ix = ~ix; + if tag_word == DRAWTAG_BEGIN_CLIP { + path_ix = m.path_ix; + } + clip_inp[m.clip_ix] = ClipInp(ix, i32(path_ix)); } - clip_inp[m.clip_ix] = ClipInp(ix, i32(path_ix)); + ix += WG_SIZE; + // break here on end to save monoid aggregation? + prefix = combine_draw_monoid(prefix, sh_scratch[WG_SIZE - 1u]); } } diff --git a/shader/draw_reduce.wgsl b/shader/draw_reduce.wgsl index a39758bd3..4a7fb6dcb 100644 --- a/shader/draw_reduce.wgsl +++ b/shader/draw_reduce.wgsl @@ -13,7 +13,7 @@ var scene: array; @group(0) @binding(2) var reduced: array; -let WG_SIZE = 256u; +const WG_SIZE = 256u; var sh_scratch: array; @@ -21,12 +21,21 @@ var sh_scratch: array; @compute @workgroup_size(256) fn main( - @builtin(global_invocation_id) global_id: vec3, @builtin(local_invocation_id) local_id: vec3, + @builtin(workgroup_id) wg_id: vec3, ) { - let ix = global_id.x; - let tag_word = read_draw_tag_from_scene(ix); - var agg = map_draw_tag(tag_word); + let num_blocks_total = (config.n_drawobj + (WG_SIZE - 1u)) / WG_SIZE; + let n_blocks_base = num_blocks_total / WG_SIZE; + let remainder = num_blocks_total % WG_SIZE; + let first_block = n_blocks_base * wg_id.x + min(wg_id.x, remainder); + let n_blocks = n_blocks_base + u32(wg_id.x < remainder); + var 
block_index = first_block * WG_SIZE + local_id.x; + var agg = draw_monoid_identity(); + for (var i = 0u; i < n_blocks; i++) { + let tag_word = read_draw_tag_from_scene(block_index); + agg = combine_draw_monoid(agg, map_draw_tag(tag_word)); + block_index += WG_SIZE; + } sh_scratch[local_id.x] = agg; for (var i = 0u; i < firstTrailingBit(WG_SIZE); i += 1u) { workgroupBarrier(); @@ -38,6 +47,6 @@ fn main( sh_scratch[local_id.x] = agg; } if local_id.x == 0u { - reduced[ix >> firstTrailingBit(WG_SIZE)] = agg; + reduced[wg_id.x] = agg; } } diff --git a/src/cpu_shader/draw_leaf.rs b/src/cpu_shader/draw_leaf.rs index df86f6c3b..c1a958c1e 100644 --- a/src/cpu_shader/draw_leaf.rs +++ b/src/cpu_shader/draw_leaf.rs @@ -23,11 +23,16 @@ fn draw_leaf_main( info: &mut [u32], clip_inp: &mut [Clip], ) { + let num_blocks_total = (config.layout.n_draw_objects as usize + (WG_SIZE - 1)) / WG_SIZE; + let n_blocks_base = num_blocks_total / WG_SIZE; + let remainder = num_blocks_total % WG_SIZE; let mut prefix = DrawMonoid::default(); - for i in 0..n_wg { + for i in 0..n_wg as usize { + let first_block = n_blocks_base * i + i.min(remainder); + let n_blocks = n_blocks_base + (i < remainder) as usize; let mut m = prefix; - for j in 0..WG_SIZE { - let ix = i * WG_SIZE as u32 + j as u32; + for j in 0..WG_SIZE * n_blocks { + let ix = (first_block * WG_SIZE) as u32 + j as u32; let tag_raw = read_draw_tag_from_scene(config, scene, ix); let tag_word = DrawTag(tag_raw); // store exclusive prefix sum @@ -185,7 +190,7 @@ fn draw_leaf_main( } m = m_next; } - prefix = prefix.combine(&reduced[i as usize]); + prefix = prefix.combine(&reduced[i]); } } diff --git a/src/cpu_shader/draw_reduce.rs b/src/cpu_shader/draw_reduce.rs index 24e15a134..bc2104efc 100644 --- a/src/cpu_shader/draw_reduce.rs +++ b/src/cpu_shader/draw_reduce.rs @@ -10,14 +10,19 @@ use super::util::read_draw_tag_from_scene; const WG_SIZE: usize = 256; fn draw_reduce_main(n_wg: u32, config: &ConfigUniform, scene: &[u32], reduced: &mut 
[DrawMonoid]) { - for i in 0..n_wg { + let num_blocks_total = (config.layout.n_draw_objects as usize + (WG_SIZE - 1)) / WG_SIZE; + let n_blocks_base = num_blocks_total / WG_SIZE; + let remainder = num_blocks_total % WG_SIZE; + for i in 0..n_wg as usize { + let first_block = n_blocks_base * i + i.min(remainder); + let n_blocks = n_blocks_base + (i < remainder) as usize; let mut m = DrawMonoid::default(); - for j in 0..WG_SIZE { - let ix = i * WG_SIZE as u32 + j as u32; + for j in 0..WG_SIZE * n_blocks { + let ix = (first_block * WG_SIZE) as u32 + j as u32; let tag = read_draw_tag_from_scene(config, scene, ix); m = m.combine(&DrawMonoid::new(DrawTag(tag))); } - reduced[i as usize] = m; + reduced[i] = m; } }