From 798bbf06e102892113224a20af952f34503a72b8 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Wed, 21 Aug 2024 03:02:41 +0200 Subject: [PATCH] Add fma and sub --- ptx_parser/src/ast.rs | 19 +++++ ptx_parser/src/main.rs | 181 ++++++++++++++++++++++++++++++++++++++--- 2 files changed, 190 insertions(+), 10 deletions(-) diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index 248a6f32..e1725c88 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -221,6 +221,25 @@ gen::generate_instruction_type!( src3: T, } }, + Fma { + type: { Type::from(data.type_) }, + data: ArithFloat, + arguments: { + dst: T, + src1: T, + src2: T, + src3: T, + } + }, + Sub { + type: { Type::from(data.type_()) }, + data: ArithDetails, + arguments: { + dst: T, + src1: T, + src2: T, + } + }, Trap { } } ); diff --git a/ptx_parser/src/main.rs b/ptx_parser/src/main.rs index ce1f56dd..9531f1c5 100644 --- a/ptx_parser/src/main.rs +++ b/ptx_parser/src/main.rs @@ -1481,7 +1481,7 @@ derive_parser!( mul{.rnd}{.ftz}{.sat}.f32 d, a, b => { ast::Instruction::Mul { data: ast::MulDetails::Float ( - ArithFloat { + ast::ArithFloat { type_: f32, rounding: rnd.map(Into::into), flush_to_zero: Some(ftz), @@ -1494,7 +1494,7 @@ derive_parser!( mul{.rnd}.f64 d, a, b => { ast::Instruction::Mul { data: ast::MulDetails::Float ( - ArithFloat { + ast::ArithFloat { type_: f64, rounding: rnd.map(Into::into), flush_to_zero: None, @@ -1510,7 +1510,7 @@ derive_parser!( mul{.rnd}{.ftz}{.sat}.f16 d, a, b => { ast::Instruction::Mul { data: ast::MulDetails::Float ( - ArithFloat { + ast::ArithFloat { type_: f16, rounding: rnd.map(Into::into), flush_to_zero: Some(ftz), @@ -1523,7 +1523,7 @@ derive_parser!( mul{.rnd}{.ftz}{.sat}.f16x2 d, a, b => { ast::Instruction::Mul { data: ast::MulDetails::Float ( - ArithFloat { + ast::ArithFloat { type_: f16x2, rounding: rnd.map(Into::into), flush_to_zero: Some(ftz), @@ -1536,7 +1536,7 @@ derive_parser!( mul{.rnd}.bf16 d, a, b => { ast::Instruction::Mul { data: ast::MulDetails::Float ( - ArithFloat { + ast::ArithFloat { type_: bf16, rounding: rnd.map(Into::into), flush_to_zero: None, @@ -1549,7 +1549,7 @@ derive_parser!( mul{.rnd}.bf16x2 d, a, b => { ast::Instruction::Mul { data: ast::MulDetails::Float ( - ArithFloat { + ast::ArithFloat { type_: bf16x2, rounding: rnd.map(Into::into), flush_to_zero: None, @@ -1835,7 +1835,7 @@ derive_parser!( mad{.ftz}{.sat}.f32 d, a, b, c => { ast::Instruction::Mad { data: ast::MadDetails::Float( - ArithFloat { + ast::ArithFloat { type_: f32, rounding: None, flush_to_zero: Some(ftz), @@ -1848,7 +1848,7 @@ derive_parser!( mad.rnd{.ftz}{.sat}.f32 d, a, b, c => { ast::Instruction::Mad { data: ast::MadDetails::Float( - ArithFloat { + ast::ArithFloat { type_: f32, rounding: Some(rnd.into()), flush_to_zero: Some(ftz), @@ -1861,7 +1861,7 @@ derive_parser!( mad.rnd.f64 d, a, b, c => { ast::Instruction::Mad { data: ast::MadDetails::Float( - ArithFloat { + ast::ArithFloat { type_: f64, rounding: Some(rnd.into()), flush_to_zero: None, @@ -1869,10 +1869,171 @@ derive_parser!( } ), arguments: MadArgs { dst: d, src1: a, src2: b, src3: c } - }} + } + } .rnd: RawRoundingMode = { .rn, .rz, .rm, .rp }; ScalarType = { .f32, .f64 }; + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-fma + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-fma + fma.rnd{.ftz}{.sat}.f32 d, a, b, c => { + ast::Instruction::Fma { + data: ast::ArithFloat { + type_: f32, + rounding: Some(rnd.into()), + flush_to_zero: Some(ftz), + saturate: sat + }, + arguments: FmaArgs { dst: d, src1: a, src2: b, src3: c } + } + } + fma.rnd.f64 d, a, b, c => { + ast::Instruction::Fma { + data: ast::ArithFloat { + type_: f64, + rounding: Some(rnd.into()), + flush_to_zero: None, + saturate: false + }, + arguments: FmaArgs { dst: d, src1: a, src2: b, src3: c } + } + } + .rnd: RawRoundingMode = { .rn, .rz, .rm, .rp }; + ScalarType = { .f32, .f64 }; + + fma.rnd{.ftz}{.sat}.f16 d, a, b, c => { + ast::Instruction::Fma { + data: ast::ArithFloat { + type_: f16, + rounding: Some(rnd.into()), + flush_to_zero: Some(ftz), + saturate: sat + }, + arguments: FmaArgs { dst: d, src1: a, src2: b, src3: c } + } + } + //fma.rnd{.ftz}{.sat}.f16x2 d, a, b, c; + //fma.rnd{.ftz}.relu.f16 d, a, b, c; + //fma.rnd{.ftz}.relu.f16x2 d, a, b, c; + //fma.rnd{.relu}.bf16 d, a, b, c; + //fma.rnd{.relu}.bf16x2 d, a, b, c; + //fma.rnd.oob.{relu}.type d, a, b, c; + .rnd: RawRoundingMode = { .rn }; + ScalarType = { .f16 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-sub + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-sub + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-sub + sub.type d, a, b => { + ast::Instruction::Sub { + data: ast::ArithDetails::Integer( + ArithInteger { + type_, + saturate: false + } + ), + arguments: SubArgs { dst: d, src1: a, src2: b } + } + } + sub.sat.s32 d, a, b => { + ast::Instruction::Sub { + data: ast::ArithDetails::Integer( + ArithInteger { + type_: s32, + saturate: true + } + ), + arguments: SubArgs { dst: d, src1: a, src2: b } + } + } + .type: ScalarType = { .u16, .u32, .u64, + .s16, .s32, .s64 }; + ScalarType = { .s32 }; + + sub{.rnd}{.ftz}{.sat}.f32 d, a, b => { + ast::Instruction::Sub { + data: ast::ArithDetails::Float( + ast::ArithFloat { + type_: f32, + rounding: rnd.map(Into::into), + flush_to_zero: Some(ftz), + saturate: sat + } + ), + arguments: SubArgs { dst: d, src1: a, src2: b } + } + } + sub{.rnd}.f64 d, a, b => { + ast::Instruction::Sub { + data: ast::ArithDetails::Float( + ast::ArithFloat { + type_: f64, + rounding: rnd.map(Into::into), + flush_to_zero: None, + saturate: false + } + ), + arguments: SubArgs { dst: d, src1: a, src2: b } + } + } + .rnd: RawRoundingMode = { .rn, .rz, .rm, .rp }; + ScalarType = { .f32, .f64 }; + + sub{.rnd}{.ftz}{.sat}.f16 d, a, b => { + ast::Instruction::Sub { + data: ast::ArithDetails::Float( + ast::ArithFloat { + type_: f16, + rounding: rnd.map(Into::into), + flush_to_zero: Some(ftz), + saturate: sat + } + ), + arguments: SubArgs { dst: d, src1: a, src2: b } + } + } + sub{.rnd}{.ftz}{.sat}.f16x2 d, a, b => { + ast::Instruction::Sub { + data: ast::ArithDetails::Float( + ast::ArithFloat { + type_: f16x2, + rounding: rnd.map(Into::into), + flush_to_zero: Some(ftz), + saturate: sat + } + ), + arguments: SubArgs { dst: d, src1: a, src2: b } + } + } + sub{.rnd}.bf16 d, a, b => { + ast::Instruction::Sub { + data: ast::ArithDetails::Float( + ast::ArithFloat { + type_: bf16, + rounding: rnd.map(Into::into), + flush_to_zero: None, + saturate: false + } + ), + arguments: SubArgs { dst: d, src1: a, src2: b } + } + } + sub{.rnd}.bf16x2 d, a, b => { + ast::Instruction::Sub { + data: ast::ArithDetails::Float( + ast::ArithFloat { + type_: bf16x2, + rounding: rnd.map(Into::into), + flush_to_zero: None, + saturate: false + } + ), + arguments: SubArgs { dst: d, src1: a, src2: b } + } + } + .rnd: RawRoundingMode = { .rn }; + ScalarType = { .f16, .f16x2, .bf16, .bf16x2 }; + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-ret ret{.uni} => { Instruction::Ret { data: RetData { uniform: uni } }