Skip to content

Commit

Permalink
Add fma and sub
Browse files Browse the repository at this point in the history
  • Loading branch information
vosen committed Aug 21, 2024
1 parent 6cd18bf commit 798bbf0
Show file tree
Hide file tree
Showing 2 changed files with 190 additions and 10 deletions.
19 changes: 19 additions & 0 deletions ptx_parser/src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,25 @@ gen::generate_instruction_type!(
src3: T,
}
},
// Fma: one destination plus three source operands (dst, src1, src2, src3).
// The instruction type is read from the `type_` field of the `ArithFloat`
// payload.
Fma {
type: { Type::from(data.type_) },
data: ArithFloat,
arguments<T>: {
dst: T,
src1: T,
src2: T,
src3: T,
}
},
// Sub: one destination plus two source operands (dst, src1, src2).
// The payload is `ArithDetails` (the parser builds it as either `Integer` or
// `Float`), so the instruction type is obtained through the `type_()` method
// call rather than a direct field read.
Sub {
type: { Type::from(data.type_()) },
data: ArithDetails,
arguments<T>: {
dst: T,
src1: T,
src2: T,
}
},
// Trap: declared with an empty body - no `type`, `data` or `arguments` entries.
Trap { }
}
);
Expand Down
181 changes: 171 additions & 10 deletions ptx_parser/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1481,7 +1481,7 @@ derive_parser!(
mul{.rnd}{.ftz}{.sat}.f32 d, a, b => {
ast::Instruction::Mul {
data: ast::MulDetails::Float (
ArithFloat {
ast::ArithFloat {
type_: f32,
rounding: rnd.map(Into::into),
flush_to_zero: Some(ftz),
Expand All @@ -1494,7 +1494,7 @@ derive_parser!(
mul{.rnd}.f64 d, a, b => {
ast::Instruction::Mul {
data: ast::MulDetails::Float (
ArithFloat {
ast::ArithFloat {
type_: f64,
rounding: rnd.map(Into::into),
flush_to_zero: None,
Expand All @@ -1510,7 +1510,7 @@ derive_parser!(
mul{.rnd}{.ftz}{.sat}.f16 d, a, b => {
ast::Instruction::Mul {
data: ast::MulDetails::Float (
ArithFloat {
ast::ArithFloat {
type_: f16,
rounding: rnd.map(Into::into),
flush_to_zero: Some(ftz),
Expand All @@ -1523,7 +1523,7 @@ derive_parser!(
mul{.rnd}{.ftz}{.sat}.f16x2 d, a, b => {
ast::Instruction::Mul {
data: ast::MulDetails::Float (
ArithFloat {
ast::ArithFloat {
type_: f16x2,
rounding: rnd.map(Into::into),
flush_to_zero: Some(ftz),
Expand All @@ -1536,7 +1536,7 @@ derive_parser!(
mul{.rnd}.bf16 d, a, b => {
ast::Instruction::Mul {
data: ast::MulDetails::Float (
ArithFloat {
ast::ArithFloat {
type_: bf16,
rounding: rnd.map(Into::into),
flush_to_zero: None,
Expand All @@ -1549,7 +1549,7 @@ derive_parser!(
mul{.rnd}.bf16x2 d, a, b => {
ast::Instruction::Mul {
data: ast::MulDetails::Float (
ArithFloat {
ast::ArithFloat {
type_: bf16x2,
rounding: rnd.map(Into::into),
flush_to_zero: None,
Expand Down Expand Up @@ -1835,7 +1835,7 @@ derive_parser!(
mad{.ftz}{.sat}.f32 d, a, b, c => {
ast::Instruction::Mad {
data: ast::MadDetails::Float(
ArithFloat {
ast::ArithFloat {
type_: f32,
rounding: None,
flush_to_zero: Some(ftz),
Expand All @@ -1848,7 +1848,7 @@ derive_parser!(
mad.rnd{.ftz}{.sat}.f32 d, a, b, c => {
ast::Instruction::Mad {
data: ast::MadDetails::Float(
ArithFloat {
ast::ArithFloat {
type_: f32,
rounding: Some(rnd.into()),
flush_to_zero: Some(ftz),
Expand All @@ -1861,18 +1861,179 @@ derive_parser!(
mad.rnd.f64 d, a, b, c => {
ast::Instruction::Mad {
data: ast::MadDetails::Float(
ArithFloat {
ast::ArithFloat {
type_: f64,
rounding: Some(rnd.into()),
flush_to_zero: None,
saturate: false
}
),
arguments: MadArgs { dst: d, src1: a, src2: b, src3: c }
}}
}
}
.rnd: RawRoundingMode = { .rn, .rz, .rm, .rp };
ScalarType = { .f32, .f64 };

// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-fma
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-fma
//
// Single- and double-precision `fma`. Both rules produce
// `ast::Instruction::Fma` carrying an `ast::ArithFloat` payload and the four
// operands d, a, b, c.
// `.rnd` is written without braces in both patterns and is consumed directly
// through `rnd.into()`, so `rounding` is always `Some(..)` in this group.
// NOTE(review): `{.ftz}` / `{.sat}` look like optional modifiers that
// `derive_parser!` binds as flags (`ftz`, `sat`) - confirm against the macro.
fma.rnd{.ftz}{.sat}.f32 d, a, b, c => {
ast::Instruction::Fma {
data: ast::ArithFloat {
type_: f32,
rounding: Some(rnd.into()),
flush_to_zero: Some(ftz),
saturate: sat
},
arguments: FmaArgs { dst: d, src1: a, src2: b, src3: c }
}
}
// The f64 pattern has no `.ftz` / `.sat` modifiers, so the corresponding
// fields are fixed to `None` and `false`.
fma.rnd.f64 d, a, b, c => {
ast::Instruction::Fma {
data: ast::ArithFloat {
type_: f64,
rounding: Some(rnd.into()),
flush_to_zero: None,
saturate: false
},
arguments: FmaArgs { dst: d, src1: a, src2: b, src3: c }
}
}
// Values accepted for `.rnd` (typed as `RawRoundingMode`) and the type tokens
// used by the two rules above.
// NOTE(review): presumably these definitions are scoped to this rule group,
// since the f16 `fma` group declares its own `.rnd` set - confirm.
.rnd: RawRoundingMode = { .rn, .rz, .rm, .rp };
ScalarType = { .f32, .f64 };

// Half-precision `fma`, `.f16` form only. Produces `ast::Instruction::Fma`
// with an `ast::ArithFloat` payload; `.rnd` is unbraced in the pattern and
// consumed via `rnd.into()`, so `rounding` is always `Some(..)`.
fma.rnd{.ftz}{.sat}.f16 d, a, b, c => {
ast::Instruction::Fma {
data: ast::ArithFloat {
type_: f16,
rounding: Some(rnd.into()),
flush_to_zero: Some(ftz),
saturate: sat
},
arguments: FmaArgs { dst: d, src1: a, src2: b, src3: c }
}
}
// Remaining half-precision `fma` forms. They are left commented out, so no
// rule in this group parses them (`.f16x2`, `.relu`, `.bf16`, `.bf16x2`, `.oob`).
//fma.rnd{.ftz}{.sat}.f16x2 d, a, b, c;
//fma.rnd{.ftz}.relu.f16 d, a, b, c;
//fma.rnd{.ftz}.relu.f16x2 d, a, b, c;
//fma.rnd{.relu}.bf16 d, a, b, c;
//fma.rnd{.relu}.bf16x2 d, a, b, c;
//fma.rnd.oob.{relu}.type d, a, b, c;
// Only `.rn` is listed as a `.rnd` value for this group, and `.f16` is the
// only type token it declares.
.rnd: RawRoundingMode = { .rn };
ScalarType = { .f16 };

// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-sub
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-sub
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-sub
//
// Integer `sub`. Both rules produce `ast::Instruction::Sub` with an
// `ast::ArithDetails::Integer` payload and the three operands d, a, b.
// Plain form: the matched `.type` token is passed through as `type_`
// (field-init shorthand) and `saturate` is always `false`.
sub.type d, a, b => {
ast::Instruction::Sub {
data: ast::ArithDetails::Integer(
ArithInteger {
type_,
saturate: false
}
),
arguments: SubArgs { dst: d, src1: a, src2: b }
}
}
// Saturating form: matched only as the literal `sub.sat.s32`, so the type is
// fixed to `s32` and `saturate` is `true`.
sub.sat.s32 d, a, b => {
ast::Instruction::Sub {
data: ast::ArithDetails::Integer(
ArithInteger {
type_: s32,
saturate: true
}
),
arguments: SubArgs { dst: d, src1: a, src2: b }
}
}
// `.type` ranges over the 16/32/64-bit unsigned and signed integer tokens.
.type: ScalarType = { .u16, .u32, .u64,
.s16, .s32, .s64 };
// NOTE(review): presumably declares the literal `.s32` token used by the
// `sub.sat.s32` rule - confirm against `derive_parser!`.
ScalarType = { .s32 };

// Single- and double-precision `sub`. Both rules produce
// `ast::Instruction::Sub` with an `ast::ArithDetails::Float` payload wrapping
// `ast::ArithFloat`, plus the three operands d, a, b.
// `{.rnd}` is braced here and converted with `rnd.map(Into::into)`, i.e. the
// bound `rnd` is mapped rather than used directly.
// NOTE(review): braces presumably mark optional modifiers, with `rnd` bound as
// an `Option` and `ftz` / `sat` as flags - confirm against `derive_parser!`.
sub{.rnd}{.ftz}{.sat}.f32 d, a, b => {
ast::Instruction::Sub {
data: ast::ArithDetails::Float(
ast::ArithFloat {
type_: f32,
rounding: rnd.map(Into::into),
flush_to_zero: Some(ftz),
saturate: sat
}
),
arguments: SubArgs { dst: d, src1: a, src2: b }
}
}
// The f64 pattern has no `.ftz` / `.sat` modifiers, so the corresponding
// fields are fixed to `None` and `false`.
sub{.rnd}.f64 d, a, b => {
ast::Instruction::Sub {
data: ast::ArithDetails::Float(
ast::ArithFloat {
type_: f64,
rounding: rnd.map(Into::into),
flush_to_zero: None,
saturate: false
}
),
arguments: SubArgs { dst: d, src1: a, src2: b }
}
}
// Values accepted for `.rnd` (typed as `RawRoundingMode`) and the type tokens
// used by the two rules above.
.rnd: RawRoundingMode = { .rn, .rz, .rm, .rp };
ScalarType = { .f32, .f64 };

// Half-precision `sub` for `.f16`, `.f16x2`, `.bf16` and `.bf16x2`. Every rule
// produces `ast::Instruction::Sub` with an `ast::ArithDetails::Float` payload
// and the three operands d, a, b; the rules differ only in `type_` and in
// which modifiers the pattern lists.
// The `.f16` / `.f16x2` patterns list `{.ftz}` and `{.sat}`, and forward the
// bound `ftz` / `sat` values into the payload.
sub{.rnd}{.ftz}{.sat}.f16 d, a, b => {
ast::Instruction::Sub {
data: ast::ArithDetails::Float(
ast::ArithFloat {
type_: f16,
rounding: rnd.map(Into::into),
flush_to_zero: Some(ftz),
saturate: sat
}
),
arguments: SubArgs { dst: d, src1: a, src2: b }
}
}
sub{.rnd}{.ftz}{.sat}.f16x2 d, a, b => {
ast::Instruction::Sub {
data: ast::ArithDetails::Float(
ast::ArithFloat {
type_: f16x2,
rounding: rnd.map(Into::into),
flush_to_zero: Some(ftz),
saturate: sat
}
),
arguments: SubArgs { dst: d, src1: a, src2: b }
}
}
// The `.bf16` / `.bf16x2` patterns list only `{.rnd}`, so `flush_to_zero` is
// fixed to `None` and `saturate` to `false`.
sub{.rnd}.bf16 d, a, b => {
ast::Instruction::Sub {
data: ast::ArithDetails::Float(
ast::ArithFloat {
type_: bf16,
rounding: rnd.map(Into::into),
flush_to_zero: None,
saturate: false
}
),
arguments: SubArgs { dst: d, src1: a, src2: b }
}
}
sub{.rnd}.bf16x2 d, a, b => {
ast::Instruction::Sub {
data: ast::ArithDetails::Float(
ast::ArithFloat {
type_: bf16x2,
rounding: rnd.map(Into::into),
flush_to_zero: None,
saturate: false
}
),
arguments: SubArgs { dst: d, src1: a, src2: b }
}
}
// Only `.rn` is listed as a `.rnd` value for this group; the type tokens are
// the four half-precision formats matched above.
.rnd: RawRoundingMode = { .rn };
ScalarType = { .f16, .f16x2, .bf16, .bf16x2 };

// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-ret
ret{.uni} => {
Instruction::Ret { data: RetData { uniform: uni } }
Expand Down

0 comments on commit 798bbf0

Please sign in to comment.