diff --git a/CMakeLists.txt b/CMakeLists.txt index 8499a6d..ac73026 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -51,7 +51,7 @@ foreach(EXAMPLE_SOURCE ${EXAMPLE_SOURCES}) # CUDA properties provided by CMAKE set_target_properties(${EXAMPLE_NAME} PROPERTIES CUDA_SEPARABLE_COMPILATION ON) - set_target_properties(${EXAMPLE_NAME} PROPERTIES CUDA_ARCHITECTURES 90a) + set_target_properties(${EXAMPLE_NAME} PROPERTIES CUDA_ARCHITECTURES 100) # Convert the flags string into a list of flags separate_arguments(EXTRA_CUDA_FLAGS_LIST UNIX_COMMAND "${EXTRA_CUDA_FLAGS}") diff --git a/examples/mx/num_to_ue8mo.cu b/examples/mx/num_to_ue8mo.cu new file mode 100644 index 0000000..9c59eda --- /dev/null +++ b/examples/mx/num_to_ue8mo.cu @@ -0,0 +1,56 @@ +#include +#include + +__global__ void convert_to_e8m0(float *in, __nv_fp8_storage_t *out) { + const float input_val = in[0]; + printf("Device input value: %f\n", input_val); + __nv_fp8_storage_t result = + __nv_cvt_float_to_e8m0(input_val, __NV_SATFINITE, cudaRoundNearest); + printf("Device output value (hex): 0x%02x, (decimal): %u\n", + (unsigned char)result, (unsigned char)result); + out[0] = result; +} + +int main() { + float h_in = 1.0f / 448.0f; + float *d_in; + __nv_fp8_storage_t *d_out, h_out; + + cudaMalloc(&d_in, sizeof(float)); + cudaMalloc(&d_out, sizeof(__nv_fp8_storage_t)); + + cudaError_t err = + cudaMemcpy(d_in, &h_in, sizeof(float), cudaMemcpyHostToDevice); + if (err != cudaSuccess) { + printf("Memcpy error: %s\n", cudaGetErrorString(err)); + return 1; + } + + convert_to_e8m0<<<1, 1>>>(d_in, d_out); + cudaDeviceSynchronize(); // Need this to see printf from kernel + err = cudaGetLastError(); + if (err != cudaSuccess) { + printf("Kernel error: %s\n", cudaGetErrorString(err)); + return 1; + } + + err = cudaMemcpy(&h_out, d_out, sizeof(__nv_fp8_storage_t), + cudaMemcpyDeviceToHost); + if (err != cudaSuccess) { + printf("Memcpy error: %s\n", cudaGetErrorString(err)); + return 1; + } + + printf("Host input float: %f\n", h_in); + printf("Host output e8m0 hex: 0x%02x, decimal: %u\n", (unsigned char)h_out, + (unsigned char)h_out); + printf("Host output e8m0 bits: "); + for (int i = 7; i >= 0; i--) { + printf("%d", (h_out >> i) & 0x1); + } + printf("\n"); + + cudaFree(d_in); + cudaFree(d_out); + return 0; +}