Skip to content

Commit

Permalink
updating sample 13 to use newer pointer compatible atomics
Browse files Browse the repository at this point in the history
  • Loading branch information
natevm committed Dec 9, 2024
1 parent 1ae443d commit 3ed2137
Show file tree
Hide file tree
Showing 3 changed files with 169 additions and 64 deletions.
187 changes: 146 additions & 41 deletions samples/s13-differentiable/deviceCode.slang
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,19 @@ CompositeGui(uint3 DispatchThreadID: SV_DispatchThreadID, uniform CompositeGuiCo
float2 uv = (fragCoord) / float2(pc.fbSize);

// Load color of the rendered image
float4 imageColor = gprt::load<float4>(pc.imageBuffer, fbOfs);

// Gamma correction
imageColor = pow(imageColor, 1.f / 2.2f);
float4 imageColor = pc.imageBuffer[fbOfs];

// Sample the color from the GUI texture
SamplerState sampler = gprt::getDefaultSampler();
Texture2D guiTexture = gprt::getTexture2DHandle(pc.guiTexture);
float4 guiColor = guiTexture.SampleGrad(sampler, uv, float2(0.f, 0.f), float2(0.f, 0.f));

// Gamma correction
guiColor = pow(guiColor, 2.2f);

// Composite the GUI on top of the scene
float4 pixelColor = over(guiColor, imageColor);
gprt::store(pc.frameBuffer, fbOfs, gprt::make_bgra(pixelColor));
pc.frameBuffer[fbOfs] = gprt::make_bgra(pixelColor);
}

[[vk::push_constant]]
Expand Down Expand Up @@ -66,7 +66,7 @@ raygen(uniform RayGenData record) {
TraceRay(obbAccel, RAY_FLAG_FORCE_OPAQUE, 0xff, 0, 1, /*miss type*/ 1, rayDesc, payload);

const int fbOfs = pixelID.x + fbSize.x * pixelID.y;
gprt::store(record.imageBuffer, fbOfs, float4(payload.color, 1.f));
record.imageBuffer[fbOfs] = float4(payload.color, 1.f);
}

[shader("closesthit")]
Expand Down Expand Up @@ -120,9 +120,116 @@ eul_to_mat3(float3 eul) {
return mat;
}

void
atomicAccumulate(float *derivBufferPtr, uint idx, float val) {
// No need to accumulate zeros.
if (val == 0.f)
return;

Atomic<uint> *derivBuffer = (Atomic<uint> *) (derivBufferPtr);

// Loop for as long as the compareExchange() fails, which means another thread
// is trying to write to the same location.
//
for (;;) {
uint oldInt = derivBuffer[idx].load();
float oldFloat = asfloat(oldInt);

float newFloat = oldFloat + val;

uint newInt = asuint(newFloat);

// compareExchange() returns the value at the location before the operation.
// If it's changed, we have contention between threads & need to try again.
//
if (derivBuffer[idx].compareExchange(oldInt, newInt) == oldInt)
break;
}
}

// no_diff float
// atomicMin32f(no_diff in Buffer buffer, no_diff uint32_t index, inout float value) {
// uint ret_i = asuint(buffers[buffer.index].Load<float>(index * sizeof(float)));
// while (value < asfloat(ret_i)) {
// uint old = ret_i;
// buffers[buffer.index].InterlockedCompareExchange(index * sizeof(float), old, asuint(value), ret_i);
// if (ret_i == old)
// break;
// }
// return asfloat(ret_i);
// }

no_diff float
atomicMin(no_diff uint64_t bufferAddr, no_diff uint idx, inout float val) {
Atomic<uint> *buffer = (Atomic<uint> *) (((float *) bufferAddr));

// Loop for as long as the compareExchange() fails, which means another thread
// is trying to write to the same location.
//
for (;;) {
uint oldInt = buffer[idx].load();
float oldFloat = asfloat(oldInt);

val = min(oldFloat, val);

uint newInt = asuint(val);

// compareExchange() returns the value at the location before the operation.
// If it's changed, we have contention between threads & need to try again.
//
if (buffer[idx].compareExchange(oldInt, newInt) == oldInt)
break;
}
return val;
}

no_diff float
atomicMax(uint64_t bufferAddr, uint idx, inout float val) {
Atomic<uint> *buffer = (Atomic<uint> *) (((float *) bufferAddr));

// Loop for as long as the compareExchange() fails, which means another thread
// is trying to write to the same location.
//
for (;;) {
uint oldInt = buffer[idx].load();
float oldFloat = asfloat(oldInt);

val = max(oldFloat, val);

uint newInt = asuint(val);

// compareExchange() returns the value at the location before the operation.
// If it's changed, we have contention between threads & need to try again.
//
if (buffer[idx].compareExchange(oldInt, newInt) == oldInt)
break;
}
return val;
}

// Derivative of an atomic max is discontinuous, where only an exact match between value
// and the currently stored max is non-zero
[BackwardDerivativeOf(atomicMin)]
void
atomicMin(in uint64_t bufferAddr, uint32_t index, inout DifferentialPair<float> value) {
float *buffer = (float *) bufferAddr;
float ret = buffer[index];
value = diffPair(value.p, abs(value.p - ret) < .000001f ? ret : 0.f);
}

// Derivative of an atomic max is discontinuous, where only an exact match between value
// and the currently stored max is non-zero
[BackwardDerivativeOf(atomicMax)]
void
atomicMax(in uint64_t bufferAddr, uint32_t index, inout DifferentialPair<float> value) {
float *buffer = (float *) bufferAddr;
float ret = buffer[index];
value = diffPair(value.p, abs(value.p - ret) < .000001f ? ret : 0.f);
}

[Differentiable]
float
computeOBB(float3 eul, no_diff float3 a, no_diff float3 b, no_diff float3 c, no_diff gprt::Buffer aabbBuffer) {
computeOBB(float3 eul, no_diff float3 a, no_diff float3 b, no_diff float3 c, no_diff uint64_t aabbsAddr) {
// Compute triangle OBB
float3x3 rot = eul_to_mat3(eul);
float3 aabbMin = float3(1e38f);
Expand All @@ -138,12 +245,12 @@ computeOBB(float3 eul, no_diff float3 a, no_diff float3 b, no_diff float3 c, no_
aabbMax = max(aabbMax, c);

// In forward pass, we atomically min/max the OBB
gprt::atomicMin32f(aabbBuffer, 0, aabbMin.x);
gprt::atomicMin32f(aabbBuffer, 1, aabbMin.y);
gprt::atomicMin32f(aabbBuffer, 2, aabbMin.z);
gprt::atomicMax32f(aabbBuffer, 3, aabbMax.x);
gprt::atomicMax32f(aabbBuffer, 4, aabbMax.y);
gprt::atomicMax32f(aabbBuffer, 5, aabbMax.z);
atomicMin(aabbsAddr, 0, aabbMin.x);
atomicMin(aabbsAddr, 1, aabbMin.y);
atomicMin(aabbsAddr, 2, aabbMin.z);
atomicMax(aabbsAddr, 3, aabbMax.x);
atomicMax(aabbsAddr, 4, aabbMax.y);
atomicMax(aabbsAddr, 5, aabbMax.z);

// Note, in the backward pass, combined aabb min and max will take on the
// atomically combined OBB
Expand All @@ -156,11 +263,11 @@ void
ClearOBB(uint3 DispatchThreadID: SV_DispatchThreadID, uniform ComputeOBBConstants params) {
if (DispatchThreadID.x >= 1)
return;
gprt::store<float3>(params.aabbs, 0, float3(+1e38f));
gprt::store<float3>(params.aabbs, 1, float3(-1e38f));
params.aabbs[0] = float3(+1e38f);
params.aabbs[1] = float3(-1e38f);

// Clear the gradient
gprt::store<float3>(params.eulRots, 1, float3(0.f));
params.eulRots[1] = float3(0.f);
}

[shader("compute")]
Expand All @@ -171,22 +278,20 @@ ComputeOBB(uint3 DispatchThreadID: SV_DispatchThreadID, uniform ComputeOBBConsta
return; // temp
int triID = DispatchThreadID.x;

uint3 tri = gprt::load<uint3>(params.indices, triID);
float3 a = gprt::load<float3>(params.vertices, tri.x);
float3 b = gprt::load<float3>(params.vertices, tri.y);
float3 c = gprt::load<float3>(params.vertices, tri.z);
uint3 tri = params.indices[triID];
float3 a = params.vertices[tri.x];
float3 b = params.vertices[tri.y];
float3 c = params.vertices[tri.z];

// Current euler rotation
float3 eul = gprt::load<float3>(params.eulRots, 0);
float3 eul = params.eulRots[0];

// Compute the current OBB
computeOBB(eul, a, b, c, params.aabbs);
computeOBB(eul, a, b, c, (uint64_t) params.aabbs);

// Also update the visualization of the OBB
float3x3 rot = eul_to_mat3(eul);
gprt::Instance instance = gprt::load<gprt::Instance>(params.instance, 0);
instance.transform = float3x4(float4(rot[0], 0), float4(rot[1], 0), float4(rot[2], 0));
gprt::store<gprt::Instance>(params.instance, 0, instance);
params.instance->transform = float3x4(float4(rot[0], 0), float4(rot[1], 0), float4(rot[2], 0));
}

[shader("compute")]
Expand All @@ -197,30 +302,30 @@ BackPropOBB(uint3 DispatchThreadID: SV_DispatchThreadID, uniform ComputeOBBConst
return; // temp
int triID = DispatchThreadID.x;

uint3 tri = gprt::load<uint3>(params.indices, triID);
float3 a = gprt::load<float3>(params.vertices, tri.x);
float3 b = gprt::load<float3>(params.vertices, tri.y);
float3 c = gprt::load<float3>(params.vertices, tri.z);
uint3 tri = params.indices[triID];
float3 a = params.vertices[tri.x];
float3 b = params.vertices[tri.y];
float3 c = params.vertices[tri.z];

// Current euler rotation
float3 eul = gprt::load<float3>(params.eulRots, 0);
float3 eul = params.eulRots[0];
DifferentialPair<float3> diffEul = diffPair(eul, float3(0));

bwd_diff(computeOBB)(diffEul, a, b, c, params.aabbs, /*dSurfaceArea*/ 1.0f);
bwd_diff(computeOBB)(diffEul, a, b, c, (uint64_t) params.aabbs, /*dSurfaceArea*/ 1.0f);

// In our application, the gradient is differentiated rotation
float3 grad = diffEul.d;

gprt::atomicAdd32f(params.eulRots, 3 + 0, grad.x);
gprt::atomicAdd32f(params.eulRots, 3 + 1, grad.y);
gprt::atomicAdd32f(params.eulRots, 3 + 2, grad.z);
atomicAccumulate((float *) params.eulRots, 3 + 0, grad.x);
atomicAccumulate((float *) params.eulRots, 3 + 1, grad.y);
atomicAccumulate((float *) params.eulRots, 3 + 2, grad.z);

// Have thread 0 report the current surface area
if (DispatchThreadID.x == 0) {
float3 aabbMin = gprt::load<float3>(params.aabbs, 0);
float3 aabbMax = gprt::load<float3>(params.aabbs, 1);
float3 aabbMin = params.aabbs[0];
float3 aabbMax = params.aabbs[1];
float SA = getSurfaceArea(aabbMin, aabbMax);
gprt::store<float3>(params.eulRots, 2, float3(SA, 0.f, 0.f));
params.eulRots[2] = float3(SA, 0.f, 0.f);
}
}

Expand All @@ -242,8 +347,8 @@ intersectBoundingBox(uniform BoundingBoxData record) {
BBoxAttributes attr;

// raytrace bounding box
float3 bbmin = gprt::load<float3>(record.aabbs, 0);
float3 bbmax = gprt::load<float3>(record.aabbs, 1);
float3 bbmin = record.aabbs[0];
float3 bbmax = record.aabbs[1];

attr.cen = 0.5 * (bbmin + bbmax);
attr.rad = 0.5 * (bbmax - bbmin);
Expand Down Expand Up @@ -285,8 +390,8 @@ hitBoundingBox(uniform BoundingBoxData record, inout Payload payload, in Boundin
float tcur = payload.tHit;

// raytrace bounding box
float3 bbmin = gprt::load<float3>(record.aabbs, 0);
float3 bbmax = gprt::load<float3>(record.aabbs, 1);
float3 bbmin = record.aabbs[0];
float3 bbmax = record.aabbs[1];

float3 cen = 0.5 * (bbmin + bbmax);
float3 rad = 0.5 * (bbmax - bbmin);
Expand Down
24 changes: 12 additions & 12 deletions samples/s13-differentiable/hostCode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ template <typename T> struct Mesh {
GPRTGeomOf<TrianglesGeomData> geometry;
GPRTAccel accel;

Mesh(){};
Mesh() {};
Mesh(GPRTContext context, GPRTGeomTypeOf<TrianglesGeomData> geomType, T generator) {
// Use the generator to generate vertices and indices
auto vertGenerator = generator.vertices();
Expand All @@ -103,8 +103,8 @@ template <typename T> struct Mesh {
gprtTrianglesSetVertices(geometry, vertexBuffer, vertices.size());
gprtTrianglesSetIndices(geometry, indexBuffer, indices.size());
TrianglesGeomData *geomData = gprtGeomGetParameters(geometry);
geomData->vertex = gprtBufferGetHandle(vertexBuffer);
geomData->index = gprtBufferGetHandle(indexBuffer);
geomData->vertex = gprtBufferGetDevicePointer(vertexBuffer);
geomData->index = gprtBufferGetDevicePointer(indexBuffer);

// Build the bottom level acceleration structure
accel = gprtTriangleAccelCreate(context, 1, &geometry);
Expand Down Expand Up @@ -188,7 +188,7 @@ main(int ac, char **av) {

// Raygen program frame buffer
RayGenData *rayGenData = gprtRayGenGetParameters(rayGen);
rayGenData->imageBuffer = gprtBufferGetHandle(imageBuffer);
rayGenData->imageBuffer = gprtBufferGetDevicePointer(imageBuffer);

// Miss program checkerboard background colors
MissProgData *missData = gprtMissGetParameters(triMiss);
Expand Down Expand Up @@ -238,7 +238,7 @@ main(int ac, char **av) {
// Create initial aabb geometry
GPRTGeomOf<BoundingBoxData> aabbGeom = gprtGeomCreate<BoundingBoxData>(context, aabbType);
BoundingBoxData *aabbGeomData = gprtGeomGetParameters(aabbGeom);
aabbGeomData->aabbs = gprtBufferGetHandle(aabbPositions);
aabbGeomData->aabbs = gprtBufferGetDevicePointer(aabbPositions);
gprtAABBsSetPositions(aabbGeom, aabbPositions, 1);

// Place that geometry into an AABB BLAS.
Expand All @@ -259,11 +259,11 @@ main(int ac, char **av) {
gprtBuildShaderBindingTable(context, GPRT_SBT_ALL);

ComputeOBBConstants obbPC;
obbPC.aabbs = gprtBufferGetHandle(aabbPositions);
obbPC.eulRots = gprtBufferGetHandle(eulRots);
obbPC.vertices = gprtBufferGetHandle(mesh.vertexBuffer);
obbPC.indices = gprtBufferGetHandle(mesh.indexBuffer);
obbPC.instance = gprtBufferGetHandle(aabbInstanceBuffer);
obbPC.aabbs = gprtBufferGetDevicePointer(aabbPositions);
obbPC.eulRots = gprtBufferGetDevicePointer(eulRots);
obbPC.vertices = gprtBufferGetDevicePointer(mesh.vertexBuffer);
obbPC.indices = gprtBufferGetDevicePointer(mesh.indexBuffer);
obbPC.instance = gprtBufferGetDevicePointer(aabbInstanceBuffer);
obbPC.numIndices = mesh.indices.size();
obbPC.numTrisToInclude = mesh.indices.size();

Expand All @@ -280,8 +280,8 @@ main(int ac, char **av) {

CompositeGuiConstants guiPC;
guiPC.fbSize = fbSize;
guiPC.frameBuffer = gprtBufferGetHandle(frameBuffer);
guiPC.imageBuffer = gprtBufferGetHandle(imageBuffer);
guiPC.frameBuffer = gprtBufferGetDevicePointer(frameBuffer);
guiPC.imageBuffer = gprtBufferGetDevicePointer(imageBuffer);
guiPC.guiTexture = gprtTextureGetHandle(guiColorAttachment);

RTPushConstants rtPC;
Expand Down
22 changes: 11 additions & 11 deletions samples/s13-differentiable/sharedCode.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,16 @@
#include "gprt.h"

struct TrianglesGeomData {
gprt::Buffer vertex;
gprt::Buffer index;
float3 *vertex;
uint3 *index;
};

struct BoundingBoxData {
gprt::Buffer aabbs;
float3 *aabbs;
};

struct RayGenData {
gprt::Buffer imageBuffer;
float4 *imageBuffer;
gprt::Accel triangleTLAS;
gprt::Accel obbAccel;
};
Expand All @@ -56,17 +56,17 @@ struct RTPushConstants {

struct ComputeOBBConstants {
int numIndices;
gprt::Buffer vertices;
gprt::Buffer indices;
gprt::Buffer eulRots;
gprt::Buffer aabbs;
gprt::Buffer instance;
float3 *vertices;
uint3 *indices;
float3 *eulRots;
float3 *aabbs;
gprt::Instance *instance;
int numTrisToInclude;
};

struct CompositeGuiConstants {
uint2 fbSize;
gprt::Buffer imageBuffer;
gprt::Buffer frameBuffer;
float4 *imageBuffer;
uint *frameBuffer;
gprt::Texture guiTexture;
};

0 comments on commit 3ed2137

Please sign in to comment.