Skip to content

Commit

Permalink
updating multiple aabb geoms test
Browse files Browse the repository at this point in the history
  • Loading branch information
natevm committed Dec 22, 2024
1 parent 4716723 commit 3a868a5
Show file tree
Hide file tree
Showing 10 changed files with 227 additions and 225 deletions.
16 changes: 8 additions & 8 deletions gprt/gprt_host.h
Original file line number Diff line number Diff line change
Expand Up @@ -305,8 +305,8 @@ GPRT_API void gprtGeomSetParameters(GPRTGeom geometry, void *parameters, int dev
*/
template <typename T>
void
gprtGeomSetParameters(GPRTGeomOf<T> geometry, T *parameters, int deviceID GPRT_IF_CPP(= 0)) {
gprtGeomSetParameters((GPRTGeom) geometry, (void *) parameters, deviceID);
gprtGeomSetParameters(GPRTGeomOf<T> geometry, T &parameters, int deviceID GPRT_IF_CPP(= 0)) {
gprtGeomSetParameters((GPRTGeom) geometry, (void *) &parameters, deviceID);
}

// ==================================================================
Expand Down Expand Up @@ -791,8 +791,8 @@ GPRT_API void gprtRayGenSetParameters(GPRTRayGen rayGen, void *parameters, int d
*/
template <typename T>
void
gprtRayGenSetParameters(GPRTRayGenOf<T> rayGen, T *parameters, int deviceID GPRT_IF_CPP(= 0)) {
gprtRayGenSetParameters((GPRTRayGen) rayGen, (void *) parameters, deviceID);
gprtRayGenSetParameters(GPRTRayGenOf<T> rayGen, T &parameters, int deviceID GPRT_IF_CPP(= 0)) {
gprtRayGenSetParameters((GPRTRayGen) rayGen, (void *) &parameters, deviceID);
}

/**
Expand Down Expand Up @@ -889,8 +889,8 @@ GPRT_API void gprtMissSetParameters(GPRTMiss miss, void *parameters, int deviceI
*/
template <typename T>
void
gprtMissSetParameters(GPRTMissOf<T> miss, T *parameters, int deviceID GPRT_IF_CPP(= 0)) {
gprtMissSetParameters((GPRTMiss) miss, (void *) parameters, deviceID);
gprtMissSetParameters(GPRTMissOf<T> miss, T &parameters, int deviceID GPRT_IF_CPP(= 0)) {
gprtMissSetParameters((GPRTMiss) miss, (void *) &parameters, deviceID);
}

GPRT_API uint32_t gprtMissGetIndex(GPRTMiss missProg, int deviceID GPRT_IF_CPP(= 0));
Expand Down Expand Up @@ -934,8 +934,8 @@ GPRT_API void gprtCallableSetParameters(GPRTCallable callable, void *parameters,

template <typename T>
void
gprtCallableSetParameters(GPRTCallableOf<T> callable, T *parameters, int deviceID GPRT_IF_CPP(= 0)) {
gprtCallableSetParameters((GPRTCallable) callable, (void *) parameters, deviceID);
gprtCallableSetParameters(GPRTCallableOf<T> callable, T &parameters, int deviceID GPRT_IF_CPP(= 0)) {
gprtCallableSetParameters((GPRTCallable) callable, (void *) &parameters, deviceID);
}

// ------------------------------------------------------------------
Expand Down
4 changes: 2 additions & 2 deletions samples/s0-rayGenPrograms/s0-1-multipleRayGens/hostCode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ int main() {
data.color0 = float3(0.1f, 0.1f, 0.1f); // Background color
data.color1 = float3(0.0f, 0.0f, 0.0f); // Secondary color
data.frameBuffer = gprtBufferGetDevicePointer(frameBuffer);
gprtRayGenSetParameters(firstRayGen, &data);
gprtRayGenSetParameters(secondRayGen, &data);
gprtRayGenSetParameters(firstRayGen, data);
gprtRayGenSetParameters(secondRayGen, data);

// Build the Shader Binding Table (SBT)
gprtBuildShaderBindingTable(context, GPRT_SBT_RAYGEN);
Expand Down
2 changes: 1 addition & 1 deletion samples/s2-hitPrograms/s2-1-spheres/hostCode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ int main(int ac, char **av) {
// when rays hit spheres.
SphereGeomData sphereParams;
sphereParams.posAndRadius = gprtBufferGetDevicePointer(vertexBuffer);
gprtGeomSetParameters(sphereGeom, &sphereParams);
gprtGeomSetParameters(sphereGeom, sphereParams);

// Create and build BLAS
GPRTAccel sphereAccel = gprtSphereAccelCreate(context, 1, &sphereGeom);
Expand Down
2 changes: 1 addition & 1 deletion samples/s2-hitPrograms/s2-2-lss/hostCode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ int main(int ac, char **av) {
LSSGeomData lssParams;
lssParams.indices = gprtBufferGetDevicePointer(indexBuffer);
lssParams.vertices = gprtBufferGetDevicePointer(vertexBuffer);
gprtGeomSetParameters(lssGeom, &lssParams);
gprtGeomSetParameters(lssGeom, lssParams);

// Create and build BLAS
GPRTAccel lssAccel = gprtLSSAccelCreate(context, 1, &lssGeom);
Expand Down
4 changes: 2 additions & 2 deletions samples/s3-instancing/s3-2-triGeoInBLAS/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ embed_devicecode(
${CMAKE_CURRENT_SOURCE_DIR}/deviceCode.slang
)

add_executable(s3_2_multipleGeometry hostCode.cpp)
target_link_libraries(s3_2_multipleGeometry
add_executable(s3_2_triGeoInBLAS hostCode.cpp)
target_link_libraries(s3_2_triGeoInBLAS
PRIVATE
s3_2_deviceCode
gprt::gprt
Expand Down
4 changes: 2 additions & 2 deletions samples/s3-instancing/s3-3-AABBGeoInBLAS/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ embed_devicecode(
${CMAKE_CURRENT_SOURCE_DIR}/deviceCode.slang
)

add_executable(s3_3_multipleGeometry hostCode.cpp)
target_link_libraries(s3_3_multipleGeometry
add_executable(s3_3_AABBGeoInBLAS hostCode.cpp)
target_link_libraries(s3_3_AABBGeoInBLAS
PRIVATE
s3_3_deviceCode
gprt::gprt
Expand Down
178 changes: 143 additions & 35 deletions samples/s3-instancing/s3-3-AABBGeoInBLAS/deviceCode.slang
Original file line number Diff line number Diff line change
Expand Up @@ -7,50 +7,158 @@ struct Payload {
float3 color;
};

// This intersection program will be called when rays hit our axis
// aligned bounding boxes. Here, we can fetch per-geometry data and
// process that data, but we do not have access to the ray payload
// structure here.
//
// Instead, we pass data through a customizable Attributes structure
// for further processing by closest hit / any hit programs.
struct BBoxAttributes {
float3 cen;
float tN;
float3 rad;
float tF;
};

[shader("intersection")]
void AABBIntersection(uniform AABBGeomData record) {
uint primID = PrimitiveIndex();
float3 ro = ObjectRayOrigin();
float3 rd = ObjectRayDirection();
float tcur = RayTCurrent();

BBoxAttributes attr;

// raytrace bounding box
float3 bbmin = record.aabbs[primID * 2 + 0];
float3 bbmax = record.aabbs[primID * 2 + 1];

attr.cen = 0.5 * (bbmin + bbmax);
attr.rad = 0.5 * (bbmax - bbmin);

float3 m = 1.0 / rd;
float3 n = m * (ro - attr.cen);
float3 k = abs(m) * attr.rad;

float3 t1 = -n - k;
float3 t2 = -n + k;

attr.tN = max(max(t1.x, t1.y), t1.z);
attr.tF = min(min(t2.x, t2.y), t2.z);

if (attr.tN > attr.tF || attr.tF < 0.0)
return;

int hitKind = 0;
if (attr.tN > 0.0) {
// front face
if (attr.tN <= tcur)
hitKind |= 1;
// back face
if (attr.tF <= tcur)
hitKind |= 2;
}

if (hitKind != 0)
ReportHit(attr.tN, hitKind, attr);
}

// This closest hit program will be called when our intersection program
// reports a hit between our ray and our custom primitives.
// Here, we can fetch per-geometry data, process that data, and send
// it back to our ray generation program.
//
// Note, since this is a custom AABB primitive, our intersection program
// above defines what attributes are passed to our closest hit program.
//
// Also note, this program is also called after all ReportHit's have been
// called and we can conclude which reported hit is closest.
[shader("closesthit")]
void AABBClosestHit(uniform AABBGeomData record, inout Payload payload, in BBoxAttributes attr) {
float3 ro = ObjectRayOrigin();
float3 rd = ObjectRayDirection();
int hitKind = HitKind();
int geoIndex = GeometryIndex();

// front
if (bool(hitKind & 1)) {
float3 pos = ro + rd * attr.tN;
float3 e = smoothstep(attr.rad - 0.03, attr.rad - 0.02, abs(pos - attr.cen));
float al = 1.0 - (1.0 - e.x * e.y) * (1.0 - e.y * e.z) * (1.0 - e.z * e.x);
payload.color = lerp(float3(0.0), float3(1.0), 0.15 + 0.85 * al);
}

// back
if (bool(hitKind & 2)) {
float3 pos = ro + rd * attr.tF;
float3 e = smoothstep(attr.rad - 0.03, attr.rad - 0.02, abs(pos - attr.cen));
float al = 1.0 - (1.0 - e.x * e.y) * (1.0 - e.y * e.z) * (1.0 - e.z * e.x);
payload.color = lerp(payload.color, float3(1.0), 0.25 + 0.75 * al);
}

if (geoIndex == 0) payload.color *= float3(1.0, 0.5, 0.5);
if (geoIndex == 1) payload.color *= float3(0.5, 1.5, 0.5);
if (geoIndex == 2) payload.color *= float3(0.5, 0.5, 1.5);
}

// This ray generation program will kick off the ray tracing process,
// generating rays and tracing them into the world.
[shader("raygeneration")]
void raygen(uniform RayGenData record) {
Payload payload;
uint2 pixelID = DispatchRaysIndex().xy;
uint2 fbSize = DispatchRaysDimensions().xy;
float2 screen = (float2(pixelID) + float2(.5f, .5f)) / float2(fbSize);

// Generate ray
RayDesc rayDesc;
rayDesc.Origin = pc.camera.pos;
rayDesc.Direction = normalize(pc.camera.dir_00 + screen.x * pc.camera.dir_du + screen.y * pc.camera.dir_dv);
rayDesc.TMin = 0.0;
rayDesc.TMax = 1e38f;
uint2 iResolution = DispatchRaysDimensions().xy;

// Trace ray against surface
RaytracingAccelerationStructure world = gprt::getAccelHandle(record.world);
TraceRay(world, // the tree
RAY_FLAG_FORCE_OPAQUE, // ray flags
0xff, // instance inclusion mask
0, // ray type
1, // number of ray types
0, // miss type
rayDesc, // the ray to trace
payload // the payload IO
);

const int fbOfs = pixelID.x + fbSize.x * pixelID.y;
record.frameBuffer[fbOfs] = gprt::make_bgra(payload.color);
}

[shader("closesthit")]
void closesthit(uniform TrianglesGeomData record, inout Payload payload) {
// compute normal:
uint primID = PrimitiveIndex();
uint3 index = record.index[primID];
float3 A = record.vertex[index.x];
float3 B = record.vertex[index.y];
float3 C = record.vertex[index.z];
float3 Ng = normalize(cross(B - A, C - A));
float3 rayDir = normalize(ObjectRayDirection());
payload.color = (.2f + .8f * abs(dot(rayDir, Ng))) * record.color;
// camera movement
float an = pc.time;
float3 ro = float3(-6.0 * sin(an), 0.0, -6.0 * cos(an));
float3 ta = float3(0.0, 0.0, 0.0);

// camera matrix
float3 ww = normalize(ta - ro);
float3 uu = normalize(cross(ww, float3(0.0, -1.0, 0.0)));
float3 vv = normalize(cross(uu, ww));

float3 tot = float3(0.0);

for (int m = 0; m < AA; m++) {
for (int n = 0; n < AA; n++) {
// pixel coordinates
float2 o = float2(float(m), float(n)) / float(AA) - 0.5;
float2 p = (2.0 * (pixelID + o) - iResolution.xy) / iResolution.y;

// create view ray
float3 rd = normalize(p.x * uu + p.y * vv + 4.0 * ww);

RayDesc rayDesc;
rayDesc.Origin = ro;
rayDesc.Direction = rd;
rayDesc.TMin = 0.0;
rayDesc.TMax = 10000.0;
TraceRay(world, // the tree
RAY_FLAG_NONE, // ray flags
0xff, // instance inclusion mask
0, // ray type
1, // number of ray types
0, // miss index
rayDesc, // the ray to trace
payload // the payload IO
);

tot += payload.color;
}
}

tot /= float(AA * AA);

const int fbOfs = pixelID.x + iResolution.x * pixelID.y;
record.frameBuffer[fbOfs] = gprt::make_bgra(tot);
}

[shader("miss")]
void miss(inout Payload payload) {
payload.color = float3(0.0);
}
}
Loading

0 comments on commit 3a868a5

Please sign in to comment.