From 3ed213791651ea697af384987e60f5ad3a194005 Mon Sep 17 00:00:00 2001 From: Nate Morrical Date: Sun, 8 Dec 2024 22:36:49 -0800 Subject: [PATCH] updating sample 13 to use newer pointer compatible atomics --- samples/s13-differentiable/deviceCode.slang | 187 +++++++++++++++----- samples/s13-differentiable/hostCode.cpp | 24 +-- samples/s13-differentiable/sharedCode.h | 22 +-- 3 files changed, 169 insertions(+), 64 deletions(-) diff --git a/samples/s13-differentiable/deviceCode.slang b/samples/s13-differentiable/deviceCode.slang index 3039492..630f870 100644 --- a/samples/s13-differentiable/deviceCode.slang +++ b/samples/s13-differentiable/deviceCode.slang @@ -25,19 +25,19 @@ CompositeGui(uint3 DispatchThreadID: SV_DispatchThreadID, uniform CompositeGuiCo float2 uv = (fragCoord) / float2(pc.fbSize); // Load color of the rendered image - float4 imageColor = gprt::load(pc.imageBuffer, fbOfs); - - // Gamma correction - imageColor = pow(imageColor, 1.f / 2.2f); + float4 imageColor = pc.imageBuffer[fbOfs]; // Sample the color from the GUI texture SamplerState sampler = gprt::getDefaultSampler(); Texture2D guiTexture = gprt::getTexture2DHandle(pc.guiTexture); float4 guiColor = guiTexture.SampleGrad(sampler, uv, float2(0.f, 0.f), float2(0.f, 0.f)); + // Gamma correction + guiColor = pow(guiColor, 2.2f); + // Composite the GUI on top of the scene float4 pixelColor = over(guiColor, imageColor); - gprt::store(pc.frameBuffer, fbOfs, gprt::make_bgra(pixelColor)); + pc.frameBuffer[fbOfs] = gprt::make_bgra(pixelColor); } [[vk::push_constant]] @@ -66,7 +66,7 @@ raygen(uniform RayGenData record) { TraceRay(obbAccel, RAY_FLAG_FORCE_OPAQUE, 0xff, 0, 1, /*miss type*/ 1, rayDesc, payload); const int fbOfs = pixelID.x + fbSize.x * pixelID.y; - gprt::store(record.imageBuffer, fbOfs, float4(payload.color, 1.f)); + record.imageBuffer[fbOfs] = float4(payload.color, 1.f); } [shader("closesthit")] @@ -120,9 +120,116 @@ eul_to_mat3(float3 eul) { return mat; } +void +atomicAccumulate(float *derivBufferPtr, uint idx, float val) { + // No need to accumulate zeros. + if (val == 0.f) + return; + + Atomic *derivBuffer = (Atomic *) (derivBufferPtr); + + // Loop for as long as the compareExchange() fails, which means another thread + // is trying to write to the same location. + // + for (;;) { + uint oldInt = derivBuffer[idx].load(); + float oldFloat = asfloat(oldInt); + + float newFloat = oldFloat + val; + + uint newInt = asuint(newFloat); + + // compareExchange() returns the value at the location before the operation. + // If it's changed, we have contention between threads & need to try again. + // + if (derivBuffer[idx].compareExchange(oldInt, newInt) == oldInt) + break; + } +} + +// no_diff float +// atomicMin32f(no_diff in Buffer buffer, no_diff uint32_t index, inout float value) { +// uint ret_i = asuint(buffers[buffer.index].Load(index * sizeof(float))); +// while (value < asfloat(ret_i)) { +// uint old = ret_i; +// buffers[buffer.index].InterlockedCompareExchange(index * sizeof(float), old, asuint(value), ret_i); +// if (ret_i == old) +// break; +// } +// return asfloat(ret_i); +// } + +no_diff float +atomicMin(no_diff uint64_t bufferAddr, no_diff uint idx, inout float val) { + Atomic *buffer = (Atomic *) (((float *) bufferAddr)); + + // Loop for as long as the compareExchange() fails, which means another thread + // is trying to write to the same location. + // + for (;;) { + uint oldInt = buffer[idx].load(); + float oldFloat = asfloat(oldInt); + + val = min(oldFloat, val); + + uint newInt = asuint(val); + + // compareExchange() returns the value at the location before the operation. + // If it's changed, we have contention between threads & need to try again. + // + if (buffer[idx].compareExchange(oldInt, newInt) == oldInt) + break; + } + return val; +} + +no_diff float +atomicMax(uint64_t bufferAddr, uint idx, inout float val) { + Atomic *buffer = (Atomic *) (((float *) bufferAddr)); + + // Loop for as long as the compareExchange() fails, which means another thread + // is trying to write to the same location. + // + for (;;) { + uint oldInt = buffer[idx].load(); + float oldFloat = asfloat(oldInt); + + val = max(oldFloat, val); + + uint newInt = asuint(val); + + // compareExchange() returns the value at the location before the operation. + // If it's changed, we have contention between threads & need to try again. + // + if (buffer[idx].compareExchange(oldInt, newInt) == oldInt) + break; + } + return val; +} + +// Derivative of an atomic max is discontinuous, where only an exact match between value +// and the currently stored max is non-zero +[BackwardDerivativeOf(atomicMin)] +void +atomicMin(in uint64_t bufferAddr, uint32_t index, inout DifferentialPair value) { + float *buffer = (float *) bufferAddr; + float ret = buffer[index]; + value = diffPair(value.p, abs(value.p - ret) < .000001f ? ret : 0.f); +} + +// Derivative of an atomic max is discontinuous, where only an exact match between value +// and the currently stored max is non-zero +[BackwardDerivativeOf(atomicMax)] +void +atomicMax(in uint64_t bufferAddr, uint32_t index, inout DifferentialPair value) { + float *buffer = (float *) bufferAddr; + float ret = buffer[index]; + value = diffPair(value.p, abs(value.p - ret) < .000001f ? ret : 0.f); +} + [Differentiable] float -computeOBB(float3 eul, no_diff float3 a, no_diff float3 b, no_diff float3 c, no_diff gprt::Buffer aabbBuffer) { +computeOBB(float3 eul, no_diff float3 a, no_diff float3 b, no_diff float3 c, no_diff uint64_t aabbsAddr) { // Compute triangle OBB float3x3 rot = eul_to_mat3(eul); float3 aabbMin = float3(1e38f); @@ -138,12 +245,12 @@ computeOBB(float3 eul, no_diff float3 a, no_diff float3 b, no_diff float3 c, no_ aabbMax = max(aabbMax, c); // In forward pass, we atomically min/max the OBB - gprt::atomicMin32f(aabbBuffer, 0, aabbMin.x); - gprt::atomicMin32f(aabbBuffer, 1, aabbMin.y); - gprt::atomicMin32f(aabbBuffer, 2, aabbMin.z); - gprt::atomicMax32f(aabbBuffer, 3, aabbMax.x); - gprt::atomicMax32f(aabbBuffer, 4, aabbMax.y); - gprt::atomicMax32f(aabbBuffer, 5, aabbMax.z); + atomicMin(aabbsAddr, 0, aabbMin.x); + atomicMin(aabbsAddr, 1, aabbMin.y); + atomicMin(aabbsAddr, 2, aabbMin.z); + atomicMax(aabbsAddr, 3, aabbMax.x); + atomicMax(aabbsAddr, 4, aabbMax.y); + atomicMax(aabbsAddr, 5, aabbMax.z); // Note, in the backward pass, combined aabb min and max will take on the // atomically combined OBB @@ -156,11 +263,11 @@ void ClearOBB(uint3 DispatchThreadID: SV_DispatchThreadID, uniform ComputeOBBConstants params) { if (DispatchThreadID.x >= 1) return; - gprt::store(params.aabbs, 0, float3(+1e38f)); - gprt::store(params.aabbs, 1, float3(-1e38f)); + params.aabbs[0] = float3(+1e38f); + params.aabbs[1] = float3(-1e38f); // Clear the gradient - gprt::store(params.eulRots, 1, float3(0.f)); + params.eulRots[1] = float3(0.f); } [shader("compute")] @@ -171,22 +278,20 @@ ComputeOBB(uint3 DispatchThreadID: SV_DispatchThreadID, uniform ComputeOBBConsta return; // temp int triID = DispatchThreadID.x; - uint3 tri = gprt::load(params.indices, triID); - float3 a = gprt::load(params.vertices, tri.x); - float3 b = gprt::load(params.vertices, tri.y); - float3 c = gprt::load(params.vertices, tri.z); + uint3 tri = params.indices[triID]; + float3 a = params.vertices[tri.x]; + float3 b = params.vertices[tri.y]; + float3 c = params.vertices[tri.z]; // Current euler rotation - float3 eul = gprt::load(params.eulRots, 0); + float3 eul = params.eulRots[0]; // Compute the current OBB - computeOBB(eul, a, b, c, params.aabbs); + computeOBB(eul, a, b, c, (uint64_t) params.aabbs); // Also update the visualization of the OBB float3x3 rot = eul_to_mat3(eul); - gprt::Instance instance = gprt::load(params.instance, 0); - instance.transform = float3x4(float4(rot[0], 0), float4(rot[1], 0), float4(rot[2], 0)); - gprt::store(params.instance, 0, instance); + params.instance->transform = float3x4(float4(rot[0], 0), float4(rot[1], 0), float4(rot[2], 0)); } [shader("compute")] @@ -197,30 +302,30 @@ BackPropOBB(uint3 DispatchThreadID: SV_DispatchThreadID, uniform ComputeOBBConst return; // temp int triID = DispatchThreadID.x; - uint3 tri = gprt::load(params.indices, triID); - float3 a = gprt::load(params.vertices, tri.x); - float3 b = gprt::load(params.vertices, tri.y); - float3 c = gprt::load(params.vertices, tri.z); + uint3 tri = params.indices[triID]; + float3 a = params.vertices[tri.x]; + float3 b = params.vertices[tri.y]; + float3 c = params.vertices[tri.z]; // Current euler rotation - float3 eul = gprt::load(params.eulRots, 0); + float3 eul = params.eulRots[0]; DifferentialPair diffEul = diffPair(eul, float3(0)); - bwd_diff(computeOBB)(diffEul, a, b, c, params.aabbs, /*dSurfaceArea*/ 1.0f); + bwd_diff(computeOBB)(diffEul, a, b, c, (uint64_t) params.aabbs, /*dSurfaceArea*/ 1.0f); // In our application, the gradient is differentiated rotation float3 grad = diffEul.d; - gprt::atomicAdd32f(params.eulRots, 3 + 0, grad.x); - gprt::atomicAdd32f(params.eulRots, 3 + 1, grad.y); - gprt::atomicAdd32f(params.eulRots, 3 + 2, grad.z); + atomicAccumulate((float *) params.eulRots, 3 + 0, grad.x); + atomicAccumulate((float *) params.eulRots, 3 + 1, grad.y); + atomicAccumulate((float *) params.eulRots, 3 + 2, grad.z); // Have thread 0 report the current surface area if (DispatchThreadID.x == 0) { - float3 aabbMin = gprt::load(params.aabbs, 0); - float3 aabbMax = gprt::load(params.aabbs, 1); + float3 aabbMin = params.aabbs[0]; + float3 aabbMax = params.aabbs[1]; float SA = getSurfaceArea(aabbMin, aabbMax); - gprt::store(params.eulRots, 2, float3(SA, 0.f, 0.f)); + params.eulRots[2] = float3(SA, 0.f, 0.f); } } @@ -242,8 +347,8 @@ intersectBoundingBox(uniform BoundingBoxData record) { BBoxAttributes attr; // raytrace bounding box - float3 bbmin = gprt::load(record.aabbs, 0); - float3 bbmax = gprt::load(record.aabbs, 1); + float3 bbmin = record.aabbs[0]; + float3 bbmax = record.aabbs[1]; attr.cen = 0.5 * (bbmin + bbmax); attr.rad = 0.5 * (bbmax - bbmin); @@ -285,8 +390,8 @@ hitBoundingBox(uniform BoundingBoxData record, inout Payload payload, in Boundin float tcur = payload.tHit; // raytrace bounding box - float3 bbmin = gprt::load(record.aabbs, 0); - float3 bbmax = gprt::load(record.aabbs, 1); + float3 bbmin = record.aabbs[0]; + float3 bbmax = record.aabbs[1]; float3 cen = 0.5 * (bbmin + bbmax); float3 rad = 0.5 * (bbmax - bbmin); diff --git a/samples/s13-differentiable/hostCode.cpp b/samples/s13-differentiable/hostCode.cpp index 507c810..b854353 100644 --- a/samples/s13-differentiable/hostCode.cpp +++ b/samples/s13-differentiable/hostCode.cpp @@ -77,7 +77,7 @@ template struct Mesh { GPRTGeomOf geometry; GPRTAccel accel; - Mesh(){}; + Mesh() {}; Mesh(GPRTContext context, GPRTGeomTypeOf geomType, T generator) { // Use the generator to generate vertices and indices auto vertGenerator = generator.vertices(); @@ -103,8 +103,8 @@ template struct Mesh { gprtTrianglesSetVertices(geometry, vertexBuffer, vertices.size()); gprtTrianglesSetIndices(geometry, indexBuffer, indices.size()); TrianglesGeomData *geomData = gprtGeomGetParameters(geometry); - geomData->vertex = gprtBufferGetHandle(vertexBuffer); - geomData->index = gprtBufferGetHandle(indexBuffer); + geomData->vertex = gprtBufferGetDevicePointer(vertexBuffer); + geomData->index = gprtBufferGetDevicePointer(indexBuffer); // Build the bottom level acceleration structure accel = gprtTriangleAccelCreate(context, 1, &geometry); @@ -188,7 +188,7 @@ main(int ac, char **av) { // Raygen program frame buffer RayGenData *rayGenData = gprtRayGenGetParameters(rayGen); - rayGenData->imageBuffer = gprtBufferGetHandle(imageBuffer); + rayGenData->imageBuffer = gprtBufferGetDevicePointer(imageBuffer); // Miss program checkerboard background colors MissProgData *missData = gprtMissGetParameters(triMiss); @@ -238,7 +238,7 @@ main(int ac, char **av) { // Create initial aabb geometry GPRTGeomOf aabbGeom = gprtGeomCreate(context, aabbType); BoundingBoxData *aabbGeomData = gprtGeomGetParameters(aabbGeom); - aabbGeomData->aabbs = gprtBufferGetHandle(aabbPositions); + aabbGeomData->aabbs = gprtBufferGetDevicePointer(aabbPositions); gprtAABBsSetPositions(aabbGeom, aabbPositions, 1); // Place that geometry into an AABB BLAS. @@ -259,11 +259,11 @@ main(int ac, char **av) { gprtBuildShaderBindingTable(context, GPRT_SBT_ALL); ComputeOBBConstants obbPC; - obbPC.aabbs = gprtBufferGetHandle(aabbPositions); - obbPC.eulRots = gprtBufferGetHandle(eulRots); - obbPC.vertices = gprtBufferGetHandle(mesh.vertexBuffer); - obbPC.indices = gprtBufferGetHandle(mesh.indexBuffer); - obbPC.instance = gprtBufferGetHandle(aabbInstanceBuffer); + obbPC.aabbs = gprtBufferGetDevicePointer(aabbPositions); + obbPC.eulRots = gprtBufferGetDevicePointer(eulRots); + obbPC.vertices = gprtBufferGetDevicePointer(mesh.vertexBuffer); + obbPC.indices = gprtBufferGetDevicePointer(mesh.indexBuffer); + obbPC.instance = gprtBufferGetDevicePointer(aabbInstanceBuffer); obbPC.numIndices = mesh.indices.size(); obbPC.numTrisToInclude = mesh.indices.size(); @@ -280,8 +280,8 @@ main(int ac, char **av) { CompositeGuiConstants guiPC; guiPC.fbSize = fbSize; - guiPC.frameBuffer = gprtBufferGetHandle(frameBuffer); - guiPC.imageBuffer = gprtBufferGetHandle(imageBuffer); + guiPC.frameBuffer = gprtBufferGetDevicePointer(frameBuffer); + guiPC.imageBuffer = gprtBufferGetDevicePointer(imageBuffer); guiPC.guiTexture = gprtTextureGetHandle(guiColorAttachment); RTPushConstants rtPC; diff --git a/samples/s13-differentiable/sharedCode.h b/samples/s13-differentiable/sharedCode.h index bfba8cd..35b25c9 100644 --- a/samples/s13-differentiable/sharedCode.h +++ b/samples/s13-differentiable/sharedCode.h @@ -23,16 +23,16 @@ #include "gprt.h" struct TrianglesGeomData { - gprt::Buffer vertex; - gprt::Buffer index; + float3 *vertex; + uint3 *index; }; struct BoundingBoxData { - gprt::Buffer aabbs; + float3 *aabbs; }; struct RayGenData { - gprt::Buffer imageBuffer; + float4 *imageBuffer; gprt::Accel triangleTLAS; gprt::Accel obbAccel; }; @@ -56,17 +56,17 @@ struct RTPushConstants { struct ComputeOBBConstants { int numIndices; - gprt::Buffer vertices; - gprt::Buffer indices; - gprt::Buffer eulRots; - gprt::Buffer aabbs; - gprt::Buffer instance; + float3 *vertices; + uint3 *indices; + float3 *eulRots; + float3 *aabbs; + gprt::Instance *instance; int numTrisToInclude; }; struct CompositeGuiConstants { uint2 fbSize; - gprt::Buffer imageBuffer; - gprt::Buffer frameBuffer; + float4 *imageBuffer; + uint *frameBuffer; gprt::Texture guiTexture; };