diff --git a/Cargo.toml b/Cargo.toml index 63719813..350d568f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,7 @@ [workspace] +resolver = "2" + members = [ "cuda_base", "cuda_types", @@ -15,6 +17,9 @@ members = [ "zluda_redirect", "zluda_ml", "ptx", + "ptx_parser", + "ptx_parser_macros", + "ptx_parser_macros_impl", ] default-members = ["zluda_lib", "zluda_ml", "zluda_inject", "zluda_redirect"] diff --git a/ptx/Cargo.toml b/ptx/Cargo.toml index 2ac1f689..d4852862 100644 --- a/ptx/Cargo.toml +++ b/ptx/Cargo.toml @@ -7,7 +7,7 @@ edition = "2018" [lib] [dependencies] -lalrpop-util = "0.19" +ptx_parser = { path = "../ptx_parser" } regex = "1" rspirv = "0.7" spirv_headers = "1.5" @@ -17,8 +17,12 @@ bit-vec = "0.6" half ="1.6" bitflags = "1.2" +[dependencies.lalrpop-util] +version = "0.19.12" +features = ["lexer"] + [build-dependencies.lalrpop] -version = "0.19" +version = "0.19.12" features = ["lexer"] [dev-dependencies] diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index d308479b..358b8cef 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -16,6 +16,8 @@ pub enum PtxError { source: ParseFloatError, }, #[error("")] + Unsupported32Bit, + #[error("")] SyntaxError, #[error("")] NonF32Ftz, @@ -32,15 +34,9 @@ pub enum PtxError { #[error("")] NonExternPointer, #[error("{start}:{end}")] - UnrecognizedStatement { - start: usize, - end: usize, - }, + UnrecognizedStatement { start: usize, end: usize }, #[error("{start}:{end}")] - UnrecognizedDirective { - start: usize, - end: usize, - }, + UnrecognizedDirective { start: usize, end: usize }, } // For some weird reson this is illegal: @@ -576,11 +572,15 @@ impl CvtDetails { if saturate { if src.kind() == ScalarKind::Signed { if dst.kind() == ScalarKind::Signed && dst.size_of() >= src.size_of() { - err.push(ParseError::from(PtxError::SyntaxError)); + err.push(ParseError::User { + error: PtxError::SyntaxError, + }); } } else { if dst == src || dst.size_of() >= src.size_of() { - err.push(ParseError::from(PtxError::SyntaxError)); + 
err.push(ParseError::User { + error: PtxError::SyntaxError, + }); } } } @@ -596,7 +596,9 @@ impl CvtDetails { err: &'err mut Vec, PtxError>>, ) -> Self { if flush_to_zero && dst != ScalarType::F32 { - err.push(ParseError::from(PtxError::NonF32Ftz)); + err.push(ParseError::from(lalrpop_util::ParseError::User { + error: PtxError::NonF32Ftz, + })); } CvtDetails::FloatFromInt(CvtDesc { dst, @@ -616,7 +618,9 @@ impl CvtDetails { err: &'err mut Vec, PtxError>>, ) -> Self { if flush_to_zero && src != ScalarType::F32 { - err.push(ParseError::from(PtxError::NonF32Ftz)); + err.push(ParseError::from(lalrpop_util::ParseError::User { + error: PtxError::NonF32Ftz, + })); } CvtDetails::IntFromFloat(CvtDesc { dst, diff --git a/ptx/src/lib.rs b/ptx/src/lib.rs index 1cb96308..5e95dae2 100644 --- a/ptx/src/lib.rs +++ b/ptx/src/lib.rs @@ -24,6 +24,7 @@ lalrpop_mod!( ); pub mod ast; +pub(crate) mod pass; #[cfg(test)] mod test; mod translate; diff --git a/ptx/src/pass/convert_dynamic_shared_memory_usage.rs b/ptx/src/pass/convert_dynamic_shared_memory_usage.rs new file mode 100644 index 00000000..1dac7fd7 --- /dev/null +++ b/ptx/src/pass/convert_dynamic_shared_memory_usage.rs @@ -0,0 +1,299 @@ +use std::collections::{BTreeMap, BTreeSet}; + +use super::*; + +/* + PTX represents dynamically allocated shared local memory as + .extern .shared .b32 shared_mem[]; + In SPIRV/OpenCL world this is expressed as an additional argument to the kernel + And in AMD compilation + This pass looks for all uses of .extern .shared and converts them to + an additional method argument + The question is how this artificial argument should be expressed. 
There are + several options: + * Straight conversion: + .shared .b32 shared_mem[] + * Introduce .param_shared statespace: + .param_shared .b32 shared_mem + or + .param_shared .b32 shared_mem[] + * Introduce .shared_ptr type: + .param .shared_ptr .b32 shared_mem + * Reuse .ptr hint: + .param .u64 .ptr shared_mem + This is the most tempting, but also the most nonsensical, .ptr is just a + hint, which has no semantical meaning (and the output of our + transformation has a semantical meaning - we emit additional + "OpFunctionParameter ..." with type "OpTypePointer Workgroup ...") +*/ +pub(super) fn run<'input>( + module: Vec>, + kernels_methods_call_map: &MethodsCallMap<'input>, + new_id: &mut impl FnMut() -> SpirvWord, +) -> Result>, TranslateError> { + let mut globals_shared = HashMap::new(); + for dir in module.iter() { + match dir { + Directive::Variable( + _, + ast::Variable { + state_space: ast::StateSpace::Shared, + name, + v_type, + .. + }, + ) => { + globals_shared.insert(*name, v_type.clone()); + } + _ => {} + } + } + if globals_shared.len() == 0 { + return Ok(module); + } + let mut methods_to_directly_used_shared_globals = HashMap::<_, HashSet>::new(); + let module = module + .into_iter() + .map(|directive| match directive { + Directive::Method(Function { + func_decl, + globals, + body: Some(statements), + import_as, + tuning, + linkage, + }) => { + let call_key = (*func_decl).borrow().name; + let statements = statements + .into_iter() + .map(|statement| { + statement.visit_map( + &mut |id, _: Option<(&ast::Type, ast::StateSpace)>, _, _| { + if let Some(_) = globals_shared.get(&id) { + methods_to_directly_used_shared_globals + .entry(call_key) + .or_insert_with(HashSet::new) + .insert(id); + } + Ok::<_, TranslateError>(id) + }, + ) + }) + .collect::, _>>()?; + Ok::<_, TranslateError>(Directive::Method(Function { + func_decl, + globals, + body: Some(statements), + import_as, + tuning, + linkage, + })) + } + directive => Ok(directive), + }) + .collect::, 
_>>()?; + // If there's a chain `kernel` -> `fn1` -> `fn2`, where only `fn2` uses extern shared, + // make sure it gets propagated to `fn1` and `kernel` + let methods_to_indirectly_used_shared_globals = resolve_indirect_uses_of_globals_shared( + methods_to_directly_used_shared_globals, + kernels_methods_call_map, + ); + // now visit every method declaration and inject those additional arguments + let mut directives = Vec::with_capacity(module.len()); + for directive in module.into_iter() { + match directive { + Directive::Method(Function { + func_decl, + globals, + body: Some(statements), + import_as, + tuning, + linkage, + }) => { + let statements = { + let func_decl_ref = &mut (*func_decl).borrow_mut(); + let method_name = func_decl_ref.name; + insert_arguments_remap_statements( + new_id, + kernels_methods_call_map, + &globals_shared, + &methods_to_indirectly_used_shared_globals, + method_name, + &mut directives, + func_decl_ref, + statements, + )? + }; + directives.push(Directive::Method(Function { + func_decl, + globals, + body: Some(statements), + import_as, + tuning, + linkage, + })); + } + directive => directives.push(directive), + } + } + Ok(directives) +} + +// We need to compute two kinds of information: +// * If it's a kernel -> size of .shared globals in use (direct or indirect) +// * If it's a function -> does it use .shared global (directly or indirectly) +fn resolve_indirect_uses_of_globals_shared<'input>( + methods_use_of_globals_shared: HashMap, HashSet>, + kernels_methods_call_map: &MethodsCallMap<'input>, +) -> HashMap, BTreeSet> { + let mut result = HashMap::new(); + for (method, callees) in kernels_methods_call_map.methods() { + let mut indirect_globals = methods_use_of_globals_shared + .get(&method) + .into_iter() + .flatten() + .copied() + .collect::>(); + for &callee in callees { + indirect_globals.extend( + methods_use_of_globals_shared + .get(&ast::MethodName::Func(callee)) + .into_iter() + .flatten() + .copied(), + ); + } + 
result.insert(method, indirect_globals); + } + result +} + +fn insert_arguments_remap_statements<'input>( + new_id: &mut impl FnMut() -> SpirvWord, + kernels_methods_call_map: &MethodsCallMap<'input>, + globals_shared: &HashMap, + methods_to_indirectly_used_shared_globals: &HashMap< + ast::MethodName<'input, SpirvWord>, + BTreeSet, + >, + method_name: ast::MethodName, + result: &mut Vec, + func_decl_ref: &mut std::cell::RefMut>, + statements: Vec, SpirvWord>>, +) -> Result, SpirvWord>>, TranslateError> { + let remapped_globals_in_method = + if let Some(method_globals) = methods_to_indirectly_used_shared_globals.get(&method_name) { + match method_name { + ast::MethodName::Func(..) => { + let remapped_globals = method_globals + .iter() + .map(|global| { + ( + *global, + ( + new_id(), + globals_shared + .get(&global) + .unwrap_or_else(|| todo!()) + .clone(), + ), + ) + }) + .collect::>(); + for (_, (new_shared_global_id, shared_global_type)) in remapped_globals.iter() { + func_decl_ref.input_arguments.push(ast::Variable { + align: None, + v_type: shared_global_type.clone(), + state_space: ast::StateSpace::Shared, + name: *new_shared_global_id, + array_init: Vec::new(), + }); + } + remapped_globals + } + ast::MethodName::Kernel(..) 
=> method_globals + .iter() + .map(|global| { + ( + *global, + ( + *global, + globals_shared + .get(&global) + .unwrap_or_else(|| todo!()) + .clone(), + ), + ) + }) + .collect::>(), + } + } else { + return Ok(statements); + }; + replace_uses_of_shared_memory( + new_id, + methods_to_indirectly_used_shared_globals, + statements, + remapped_globals_in_method, + ) +} + +fn replace_uses_of_shared_memory<'input>( + new_id: &mut impl FnMut() -> SpirvWord, + methods_to_indirectly_used_shared_globals: &HashMap< + ast::MethodName<'input, SpirvWord>, + BTreeSet, + >, + statements: Vec, + remapped_globals_in_method: BTreeMap, +) -> Result, TranslateError> { + let mut result = Vec::with_capacity(statements.len()); + for statement in statements { + match statement { + Statement::Instruction(ast::Instruction::Call { + mut data, + mut arguments, + }) => { + // We can safely skip checking call arguments, + // because there's simply no way to pass shared ptr + // without converting it to .b64 first + if let Some(shared_globals_used_by_callee) = + methods_to_indirectly_used_shared_globals + .get(&ast::MethodName::Func(arguments.func)) + { + for &shared_global_used_by_callee in shared_globals_used_by_callee { + let (remapped_shared_id, type_) = remapped_globals_in_method + .get(&shared_global_used_by_callee) + .unwrap_or_else(|| todo!()); + data.input_arguments + .push((type_.clone(), ast::StateSpace::Shared)); + arguments.input_arguments.push(*remapped_shared_id); + } + } + result.push(Statement::Instruction(ast::Instruction::Call { + data, + arguments, + })) + } + statement => { + let new_statement = + statement.visit_map(&mut |id, + _: Option<(&ast::Type, ast::StateSpace)>, + _, + _| { + Ok::<_, TranslateError>( + if let Some((remapped_shared_id, _)) = + remapped_globals_in_method.get(&id) + { + *remapped_shared_id + } else { + id + }, + ) + })?; + result.push(new_statement); + } + } + } + Ok(result) +} diff --git a/ptx/src/pass/convert_to_stateful_memory_access.rs 
b/ptx/src/pass/convert_to_stateful_memory_access.rs new file mode 100644 index 00000000..455a8c2e --- /dev/null +++ b/ptx/src/pass/convert_to_stateful_memory_access.rs @@ -0,0 +1,524 @@ +use super::*; +use ptx_parser as ast; +use std::{ + collections::{BTreeSet, HashSet}, + iter, + rc::Rc, +}; + +/* + Our goal here is to transform + .visible .entry foobar(.param .u64 input) { + .reg .b64 in_addr; + .reg .b64 in_addr2; + ld.param.u64 in_addr, [input]; + cvta.to.global.u64 in_addr2, in_addr; + } + into: + .visible .entry foobar(.param .u8 input[]) { + .reg .u8 in_addr[]; + .reg .u8 in_addr2[]; + ld.param.u8[] in_addr, [input]; + mov.u8[] in_addr2, in_addr; + } + or: + .visible .entry foobar(.reg .u8 input[]) { + .reg .u8 in_addr[]; + .reg .u8 in_addr2[]; + mov.u8[] in_addr, input; + mov.u8[] in_addr2, in_addr; + } + or: + .visible .entry foobar(.param ptr input) { + .reg ptr in_addr; + .reg ptr in_addr2; + ld.param.ptr in_addr, [input]; + mov.ptr in_addr2, in_addr; + } +*/ +// TODO: detect more patterns (mov, call via reg, call via param) +// TODO: don't convert to ptr if the register is not ultimately used for ld/st +// TODO: once insert_mem_ssa_statements is moved to later, move this pass after +// argument expansion +// TODO: propagate out of calls and into calls +pub(super) fn run<'a, 'input>( + func_args: Rc>>, + func_body: Vec, + id_defs: &mut NumericIdResolver<'a>, +) -> Result< + ( + Rc>>, + Vec, + ), + TranslateError, +> { + let mut method_decl = func_args.borrow_mut(); + if !matches!(method_decl.name, ast::MethodName::Kernel(..)) { + drop(method_decl); + return Ok((func_args, func_body)); + } + if Rc::strong_count(&func_args) != 1 { + return Err(error_unreachable()); + } + let func_args_64bit = (*method_decl) + .input_arguments + .iter() + .filter_map(|arg| match arg.v_type { + ast::Type::Scalar(ast::ScalarType::U64) + | ast::Type::Scalar(ast::ScalarType::B64) + | ast::Type::Scalar(ast::ScalarType::S64) => Some(arg.name), + _ => None, + }) + .collect::>(); 
+ let mut stateful_markers = Vec::new(); + let mut stateful_init_reg = HashMap::<_, Vec<_>>::new(); + for statement in func_body.iter() { + match statement { + Statement::Instruction(ast::Instruction::Cvta { + data: + ast::CvtaDetails { + state_space: ast::StateSpace::Global, + direction: ast::CvtaDirection::GenericToExplicit, + }, + arguments, + }) => { + if let (TypedOperand::Reg(dst), Some(src)) = + (arguments.dst, arguments.src.underlying_register()) + { + if is_64_bit_integer(id_defs, src) && is_64_bit_integer(id_defs, dst) { + stateful_markers.push((dst, src)); + } + } + } + Statement::Instruction(ast::Instruction::Ld { + data: + ast::LdDetails { + state_space: ast::StateSpace::Param, + typ: ast::Type::Scalar(ast::ScalarType::U64), + .. + }, + arguments, + }) + | Statement::Instruction(ast::Instruction::Ld { + data: + ast::LdDetails { + state_space: ast::StateSpace::Param, + typ: ast::Type::Scalar(ast::ScalarType::S64), + .. + }, + arguments, + }) + | Statement::Instruction(ast::Instruction::Ld { + data: + ast::LdDetails { + state_space: ast::StateSpace::Param, + typ: ast::Type::Scalar(ast::ScalarType::B64), + .. 
+ }, + arguments, + }) => { + if let (TypedOperand::Reg(dst), Some(src)) = + (arguments.dst, arguments.src.underlying_register()) + { + if func_args_64bit.contains(&src) { + multi_hash_map_append(&mut stateful_init_reg, dst, src); + } + } + } + _ => {} + } + } + if stateful_markers.len() == 0 { + drop(method_decl); + return Ok((func_args, func_body)); + } + let mut func_args_ptr = HashSet::new(); + let mut regs_ptr_current = HashSet::new(); + for (dst, src) in stateful_markers { + if let Some(func_args) = stateful_init_reg.get(&src) { + for a in func_args { + func_args_ptr.insert(*a); + regs_ptr_current.insert(src); + regs_ptr_current.insert(dst); + } + } + } + // BTreeSet here to have a stable order of iteration, + // unfortunately our tests rely on it + let mut regs_ptr_seen = BTreeSet::new(); + while regs_ptr_current.len() > 0 { + let mut regs_ptr_new = HashSet::new(); + for statement in func_body.iter() { + match statement { + Statement::Instruction(ast::Instruction::Add { + data: + ast::ArithDetails::Integer(ast::ArithInteger { + type_: ast::ScalarType::U64, + saturate: false, + }), + arguments, + }) + | Statement::Instruction(ast::Instruction::Add { + data: + ast::ArithDetails::Integer(ast::ArithInteger { + type_: ast::ScalarType::S64, + saturate: false, + }), + arguments, + }) => { + // TODO: don't mark result of double pointer sub or double + // pointer add as ptr result + if let (TypedOperand::Reg(dst), Some(src1)) = + (arguments.dst, arguments.src1.underlying_register()) + { + if regs_ptr_current.contains(&src1) && !regs_ptr_seen.contains(&src1) { + regs_ptr_new.insert(dst); + } + } else if let (TypedOperand::Reg(dst), Some(src2)) = + (arguments.dst, arguments.src2.underlying_register()) + { + if regs_ptr_current.contains(&src2) && !regs_ptr_seen.contains(&src2) { + regs_ptr_new.insert(dst); + } + } + } + + Statement::Instruction(ast::Instruction::Sub { + data: + ast::ArithDetails::Integer(ast::ArithInteger { + type_: ast::ScalarType::U64, + saturate: 
false, + }), + arguments, + }) + | Statement::Instruction(ast::Instruction::Sub { + data: + ast::ArithDetails::Integer(ast::ArithInteger { + type_: ast::ScalarType::S64, + saturate: false, + }), + arguments, + }) => { + // TODO: don't mark result of double pointer sub or double + // pointer add as ptr result + if let (TypedOperand::Reg(dst), Some(src1)) = + (arguments.dst, arguments.src1.underlying_register()) + { + if regs_ptr_current.contains(&src1) && !regs_ptr_seen.contains(&src1) { + regs_ptr_new.insert(dst); + } + } else if let (TypedOperand::Reg(dst), Some(src2)) = + (arguments.dst, arguments.src2.underlying_register()) + { + if regs_ptr_current.contains(&src2) && !regs_ptr_seen.contains(&src2) { + regs_ptr_new.insert(dst); + } + } + } + _ => {} + } + } + for id in regs_ptr_current { + regs_ptr_seen.insert(id); + } + regs_ptr_current = regs_ptr_new; + } + drop(regs_ptr_current); + let mut remapped_ids = HashMap::new(); + let mut result = Vec::with_capacity(regs_ptr_seen.len() + func_body.len()); + for reg in regs_ptr_seen { + let new_id = id_defs.register_variable( + ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global), + ast::StateSpace::Reg, + ); + result.push(Statement::Variable(ast::Variable { + align: None, + name: new_id, + array_init: Vec::new(), + v_type: ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global), + state_space: ast::StateSpace::Reg, + })); + remapped_ids.insert(reg, new_id); + } + for arg in (*method_decl).input_arguments.iter_mut() { + if !func_args_ptr.contains(&arg.name) { + continue; + } + let new_id = id_defs.register_variable( + ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global), + ast::StateSpace::Param, + ); + let old_name = arg.name; + arg.v_type = ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global); + arg.name = new_id; + remapped_ids.insert(old_name, new_id); + } + for statement in func_body { + match statement { + l @ Statement::Label(_) => result.push(l), + c @ 
Statement::Conditional(_) => result.push(c), + c @ Statement::Constant(..) => result.push(c), + Statement::Variable(var) => { + if !remapped_ids.contains_key(&var.name) { + result.push(Statement::Variable(var)); + } + } + Statement::Instruction(ast::Instruction::Add { + data: + ast::ArithDetails::Integer(ast::ArithInteger { + type_: ast::ScalarType::U64, + saturate: false, + }), + arguments, + }) + | Statement::Instruction(ast::Instruction::Add { + data: + ast::ArithDetails::Integer(ast::ArithInteger { + type_: ast::ScalarType::S64, + saturate: false, + }), + arguments, + }) if is_add_ptr_direct(&remapped_ids, &arguments) => { + let (ptr, offset) = match arguments.src1.underlying_register() { + Some(src1) if remapped_ids.contains_key(&src1) => { + (remapped_ids.get(&src1).unwrap(), arguments.src2) + } + Some(src2) if remapped_ids.contains_key(&src2) => { + (remapped_ids.get(&src2).unwrap(), arguments.src1) + } + _ => return Err(error_unreachable()), + }; + let dst = arguments.dst.unwrap_reg()?; + result.push(Statement::PtrAccess(PtrAccess { + underlying_type: ast::Type::Scalar(ast::ScalarType::U8), + state_space: ast::StateSpace::Global, + dst: *remapped_ids.get(&dst).unwrap(), + ptr_src: *ptr, + offset_src: offset, + })) + } + Statement::Instruction(ast::Instruction::Sub { + data: + ast::ArithDetails::Integer(ast::ArithInteger { + type_: ast::ScalarType::U64, + saturate: false, + }), + arguments, + }) + | Statement::Instruction(ast::Instruction::Sub { + data: + ast::ArithDetails::Integer(ast::ArithInteger { + type_: ast::ScalarType::S64, + saturate: false, + }), + arguments, + }) if is_sub_ptr_direct(&remapped_ids, &arguments) => { + let (ptr, offset) = match arguments.src1.underlying_register() { + Some(ref src1) => (remapped_ids.get(src1).unwrap(), arguments.src2), + _ => return Err(error_unreachable()), + }; + let offset_neg = id_defs.register_intermediate(Some(( + ast::Type::Scalar(ast::ScalarType::S64), + ast::StateSpace::Reg, + ))); + 
result.push(Statement::Instruction(ast::Instruction::Neg { + data: ast::TypeFtz { + type_: ast::ScalarType::S64, + flush_to_zero: None, + }, + arguments: ast::NegArgs { + src: offset, + dst: TypedOperand::Reg(offset_neg), + }, + })); + let dst = arguments.dst.unwrap_reg()?; + result.push(Statement::PtrAccess(PtrAccess { + underlying_type: ast::Type::Scalar(ast::ScalarType::U8), + state_space: ast::StateSpace::Global, + dst: *remapped_ids.get(&dst).unwrap(), + ptr_src: *ptr, + offset_src: TypedOperand::Reg(offset_neg), + })) + } + inst @ Statement::Instruction(_) => { + let mut post_statements = Vec::new(); + let new_statement = inst.visit_map(&mut FnVisitor::new( + |operand, type_space, is_dst, relaxed_conversion| { + convert_to_stateful_memory_access_postprocess( + id_defs, + &remapped_ids, + &mut result, + &mut post_statements, + operand, + type_space, + is_dst, + relaxed_conversion, + ) + }, + ))?; + result.push(new_statement); + result.extend(post_statements); + } + repack @ Statement::RepackVector(_) => { + let mut post_statements = Vec::new(); + let new_statement = repack.visit_map(&mut FnVisitor::new( + |operand, type_space, is_dst, relaxed_conversion| { + convert_to_stateful_memory_access_postprocess( + id_defs, + &remapped_ids, + &mut result, + &mut post_statements, + operand, + type_space, + is_dst, + relaxed_conversion, + ) + }, + ))?; + result.push(new_statement); + result.extend(post_statements); + } + _ => return Err(error_unreachable()), + } + } + drop(method_decl); + Ok((func_args, result)) +} + +fn is_64_bit_integer(id_defs: &NumericIdResolver, id: SpirvWord) -> bool { + match id_defs.get_typed(id) { + Ok((ast::Type::Scalar(ast::ScalarType::U64), _, _)) + | Ok((ast::Type::Scalar(ast::ScalarType::S64), _, _)) + | Ok((ast::Type::Scalar(ast::ScalarType::B64), _, _)) => true, + _ => false, + } +} + +fn is_add_ptr_direct( + remapped_ids: &HashMap, + arg: &ast::AddArgs, +) -> bool { + match arg.dst { + TypedOperand::Imm(..) | TypedOperand::RegOffset(..) 
| TypedOperand::VecMember(..) => { + return false + } + TypedOperand::Reg(dst) => { + if !remapped_ids.contains_key(&dst) { + return false; + } + if let Some(ref src1_reg) = arg.src1.underlying_register() { + if remapped_ids.contains_key(src1_reg) { + // don't trigger optimization when adding two pointers + if let Some(ref src2_reg) = arg.src2.underlying_register() { + return !remapped_ids.contains_key(src2_reg); + } + } + } + if let Some(ref src2_reg) = arg.src2.underlying_register() { + remapped_ids.contains_key(src2_reg) + } else { + false + } + } + } +} + +fn is_sub_ptr_direct( + remapped_ids: &HashMap, + arg: &ast::SubArgs, +) -> bool { + match arg.dst { + TypedOperand::Imm(..) | TypedOperand::RegOffset(..) | TypedOperand::VecMember(..) => { + return false + } + TypedOperand::Reg(dst) => { + if !remapped_ids.contains_key(&dst) { + return false; + } + match arg.src1.underlying_register() { + Some(ref src1_reg) => { + if remapped_ids.contains_key(src1_reg) { + // don't trigger optimization when subtracting two pointers + arg.src2 + .underlying_register() + .map_or(true, |ref src2_reg| !remapped_ids.contains_key(src2_reg)) + } else { + false + } + } + None => false, + } + } + } +} + +fn convert_to_stateful_memory_access_postprocess( + id_defs: &mut NumericIdResolver, + remapped_ids: &HashMap, + result: &mut Vec, + post_statements: &mut Vec, + operand: TypedOperand, + type_space: Option<(&ast::Type, ast::StateSpace)>, + is_dst: bool, + relaxed_conversion: bool, +) -> Result { + operand.map(|operand, _| { + Ok(match remapped_ids.get(&operand) { + Some(new_id) => { + let (new_operand_type, new_operand_space, _) = id_defs.get_typed(*new_id)?; + // TODO: readd if required + if let Some((expected_type, expected_space)) = type_space { + let implicit_conversion = if relaxed_conversion { + if is_dst { + super::insert_implicit_conversions::should_convert_relaxed_dst_wrapper + } else { + super::insert_implicit_conversions::should_convert_relaxed_src_wrapper + } + } else { + 
super::insert_implicit_conversions::default_implicit_conversion + }; + if implicit_conversion( + (new_operand_space, &new_operand_type), + (expected_space, expected_type), + ) + .is_ok() + { + return Ok(*new_id); + } + } + let (old_operand_type, old_operand_space, _) = id_defs.get_typed(operand)?; + let converting_id = id_defs + .register_intermediate(Some((old_operand_type.clone(), old_operand_space))); + let kind = if space_is_compatible(new_operand_space, ast::StateSpace::Reg) { + ConversionKind::Default + } else { + ConversionKind::PtrToPtr + }; + if is_dst { + post_statements.push(Statement::Conversion(ImplicitConversion { + src: converting_id, + dst: *new_id, + from_type: old_operand_type, + from_space: old_operand_space, + to_type: new_operand_type, + to_space: new_operand_space, + kind, + })); + converting_id + } else { + result.push(Statement::Conversion(ImplicitConversion { + src: *new_id, + dst: converting_id, + from_type: new_operand_type, + from_space: new_operand_space, + to_type: old_operand_type, + to_space: old_operand_space, + kind, + })); + converting_id + } + } + None => operand, + }) + }) +} diff --git a/ptx/src/pass/convert_to_typed.rs b/ptx/src/pass/convert_to_typed.rs new file mode 100644 index 00000000..550c662e --- /dev/null +++ b/ptx/src/pass/convert_to_typed.rs @@ -0,0 +1,138 @@ +use super::*; +use ptx_parser as ast; + +pub(crate) fn run( + func: Vec, + fn_defs: &GlobalFnDeclResolver, + id_defs: &mut NumericIdResolver, +) -> Result, TranslateError> { + let mut result = Vec::::with_capacity(func.len()); + for s in func { + match s { + Statement::Instruction(inst) => match inst { + ast::Instruction::Mov { + data, + arguments: + ast::MovArgs { + dst: ast::ParsedOperand::Reg(dst_reg), + src: ast::ParsedOperand::Reg(src_reg), + }, + } if fn_defs.fns.contains_key(&src_reg) => { + if data.typ != ast::Type::Scalar(ast::ScalarType::U64) { + return Err(error_mismatched_type()); + } + 
result.push(TypedStatement::FunctionPointer(FunctionPointerDetails { + dst: dst_reg, + src: src_reg, + })); + } + ast::Instruction::Call { data, arguments } => { + let resolver = fn_defs.get_fn_sig_resolver(arguments.func)?; + let resolved_call = resolver.resolve_in_spirv_repr(data, arguments)?; + let mut visitor = VectorRepackVisitor::new(&mut result, id_defs); + let reresolved_call = + Statement::Instruction(ast::visit_map(resolved_call, &mut visitor)?); + visitor.func.push(reresolved_call); + visitor.func.extend(visitor.post_stmts); + } + inst => { + let mut visitor = VectorRepackVisitor::new(&mut result, id_defs); + let instruction = Statement::Instruction(ast::visit_map(inst, &mut visitor)?); + visitor.func.push(instruction); + visitor.func.extend(visitor.post_stmts); + } + }, + Statement::Label(i) => result.push(Statement::Label(i)), + Statement::Variable(v) => result.push(Statement::Variable(v)), + Statement::Conditional(c) => result.push(Statement::Conditional(c)), + _ => return Err(error_unreachable()), + } + } + Ok(result) +} + +struct VectorRepackVisitor<'a, 'b> { + func: &'b mut Vec, + id_def: &'b mut NumericIdResolver<'a>, + post_stmts: Option, +} + +impl<'a, 'b> VectorRepackVisitor<'a, 'b> { + fn new(func: &'b mut Vec, id_def: &'b mut NumericIdResolver<'a>) -> Self { + VectorRepackVisitor { + func, + id_def, + post_stmts: None, + } + } + + fn convert_vector( + &mut self, + is_dst: bool, + relaxed_type_check: bool, + typ: &ast::Type, + state_space: ast::StateSpace, + idx: Vec, + ) -> Result { + // mov.u32 foobar, {a,b}; + let scalar_t = match typ { + ast::Type::Vector(_, scalar_t) => *scalar_t, + _ => return Err(error_mismatched_type()), + }; + let temp_vec = self + .id_def + .register_intermediate(Some((typ.clone(), state_space))); + let statement = Statement::RepackVector(RepackVectorDetails { + is_extract: is_dst, + typ: scalar_t, + packed: temp_vec, + unpacked: idx, + relaxed_type_check, + }); + if is_dst { + self.post_stmts = Some(statement); + } 
else { + self.func.push(statement); + } + Ok(temp_vec) + } +} + +impl<'a, 'b> ast::VisitorMap, TypedOperand, TranslateError> + for VectorRepackVisitor<'a, 'b> +{ + fn visit_ident( + &mut self, + ident: SpirvWord, + _: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>, + _: bool, + _: bool, + ) -> Result { + Ok(ident) + } + + fn visit( + &mut self, + op: ast::ParsedOperand, + type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>, + is_dst: bool, + relaxed_type_check: bool, + ) -> Result { + Ok(match op { + ast::ParsedOperand::Reg(reg) => TypedOperand::Reg(reg), + ast::ParsedOperand::RegOffset(reg, offset) => TypedOperand::RegOffset(reg, offset), + ast::ParsedOperand::Imm(x) => TypedOperand::Imm(x), + ast::ParsedOperand::VecMember(vec, idx) => TypedOperand::VecMember(vec, idx), + ast::ParsedOperand::VecPack(vec) => { + let (type_, space) = type_space.ok_or_else(|| error_mismatched_type())?; + TypedOperand::Reg(self.convert_vector( + is_dst, + relaxed_type_check, + type_, + space, + vec, + )?) 
+ } + }) + } +} diff --git a/ptx/src/pass/emit_spirv.rs b/ptx/src/pass/emit_spirv.rs new file mode 100644 index 00000000..5147b79f --- /dev/null +++ b/ptx/src/pass/emit_spirv.rs @@ -0,0 +1,2763 @@ +use super::*; +use half::f16; +use ptx_parser as ast; +use rspirv::{binary::Assemble, dr}; +use std::{ + collections::{HashMap, HashSet}, + ffi::CString, + mem, +}; + +pub(super) fn run<'input>( + mut builder: dr::Builder, + id_defs: &GlobalStringIdResolver<'input>, + call_map: MethodsCallMap<'input>, + denorm_information: HashMap< + ptx_parser::MethodName, + HashMap, + >, + directives: Vec>, +) -> Result<(dr::Module, HashMap, CString), TranslateError> { + builder.set_version(1, 3); + emit_capabilities(&mut builder); + emit_extensions(&mut builder); + let opencl_id = emit_opencl_import(&mut builder); + emit_memory_model(&mut builder); + let mut map = TypeWordMap::new(&mut builder); + //emit_builtins(&mut builder, &mut map, &id_defs); + let mut kernel_info = HashMap::new(); + let (build_options, should_flush_denorms) = + emit_denorm_build_string(&call_map, &denorm_information); + let (directives, globals_use_map) = get_globals_use_map(directives); + emit_directives( + &mut builder, + &mut map, + &id_defs, + opencl_id, + should_flush_denorms, + &call_map, + globals_use_map, + directives, + &mut kernel_info, + )?; + Ok((builder.module(), kernel_info, build_options)) +} + +fn emit_capabilities(builder: &mut dr::Builder) { + builder.capability(spirv::Capability::GenericPointer); + builder.capability(spirv::Capability::Linkage); + builder.capability(spirv::Capability::Addresses); + builder.capability(spirv::Capability::Kernel); + builder.capability(spirv::Capability::Int8); + builder.capability(spirv::Capability::Int16); + builder.capability(spirv::Capability::Int64); + builder.capability(spirv::Capability::Float16); + builder.capability(spirv::Capability::Float64); + builder.capability(spirv::Capability::DenormFlushToZero); + // TODO: re-enable when Intel float control 
extension works + //builder.capability(spirv::Capability::FunctionFloatControlINTEL); +} + +// http://htmlpreview.github.io/?https://github.com/KhronosGroup/SPIRV-Registry/blob/master/extensions/KHR/SPV_KHR_float_controls.html +fn emit_extensions(builder: &mut dr::Builder) { + // TODO: re-enable when Intel float control extension works + //builder.extension("SPV_INTEL_float_controls2"); + builder.extension("SPV_KHR_float_controls"); + builder.extension("SPV_KHR_no_integer_wrap_decoration"); +} + +fn emit_opencl_import(builder: &mut dr::Builder) -> spirv::Word { + builder.ext_inst_import("OpenCL.std") +} + +fn emit_memory_model(builder: &mut dr::Builder) { + builder.memory_model( + spirv::AddressingModel::Physical64, + spirv::MemoryModel::OpenCL, + ); +} + +struct TypeWordMap { + void: spirv::Word, + complex: HashMap, + constants: HashMap<(SpirvType, u64), SpirvWord>, +} + +impl TypeWordMap { + fn new(b: &mut dr::Builder) -> TypeWordMap { + let void = b.type_void(None); + TypeWordMap { + void: void, + complex: HashMap::::new(), + constants: HashMap::new(), + } + } + + fn void(&self) -> spirv::Word { + self.void + } + + fn get_or_add_scalar(&mut self, b: &mut dr::Builder, t: ast::ScalarType) -> SpirvWord { + let key: SpirvScalarKey = t.into(); + self.get_or_add_spirv_scalar(b, key) + } + + fn get_or_add_spirv_scalar(&mut self, b: &mut dr::Builder, key: SpirvScalarKey) -> SpirvWord { + *self.complex.entry(SpirvType::Base(key)).or_insert_with(|| { + SpirvWord(match key { + SpirvScalarKey::B8 => b.type_int(None, 8, 0), + SpirvScalarKey::B16 => b.type_int(None, 16, 0), + SpirvScalarKey::B32 => b.type_int(None, 32, 0), + SpirvScalarKey::B64 => b.type_int(None, 64, 0), + SpirvScalarKey::F16 => b.type_float(None, 16), + SpirvScalarKey::F32 => b.type_float(None, 32), + SpirvScalarKey::F64 => b.type_float(None, 64), + SpirvScalarKey::Pred => b.type_bool(None), + SpirvScalarKey::F16x2 => todo!(), + }) + }) + } + + fn get_or_add(&mut self, b: &mut dr::Builder, t: SpirvType) -> 
SpirvWord { + match t { + SpirvType::Base(key) => self.get_or_add_spirv_scalar(b, key), + SpirvType::Pointer(ref typ, storage) => { + let base = self.get_or_add(b, *typ.clone()); + *self + .complex + .entry(t) + .or_insert_with(|| SpirvWord(b.type_pointer(None, storage, base.0))) + } + SpirvType::Vector(typ, len) => { + let base = self.get_or_add_spirv_scalar(b, typ); + *self + .complex + .entry(t) + .or_insert_with(|| SpirvWord(b.type_vector(None, base.0, len as u32))) + } + SpirvType::Array(typ, array_dimensions) => { + let (base_type, length) = match &*array_dimensions { + &[] => { + return self.get_or_add(b, SpirvType::Base(typ)); + } + &[len] => { + let u32_type = self.get_or_add_scalar(b, ast::ScalarType::U32); + let base = self.get_or_add_spirv_scalar(b, typ); + let len_const = b.constant_u32(u32_type.0, None, len); + (base, len_const) + } + array_dimensions => { + let u32_type = self.get_or_add_scalar(b, ast::ScalarType::U32); + let base = self + .get_or_add(b, SpirvType::Array(typ, array_dimensions[1..].to_vec())); + let len_const = b.constant_u32(u32_type.0, None, array_dimensions[0]); + (base, len_const) + } + }; + *self + .complex + .entry(SpirvType::Array(typ, array_dimensions)) + .or_insert_with(|| SpirvWord(b.type_array(None, base_type.0, length))) + } + SpirvType::Func(ref out_params, ref in_params) => { + let out_t = match out_params { + Some(p) => self.get_or_add(b, *p.clone()), + None => SpirvWord(self.void()), + }; + let in_t = in_params + .iter() + .map(|t| self.get_or_add(b, t.clone()).0) + .collect::>(); + *self + .complex + .entry(t) + .or_insert_with(|| SpirvWord(b.type_function(None, out_t.0, in_t))) + } + SpirvType::Struct(ref underlying) => { + let underlying_ids = underlying + .iter() + .map(|t| self.get_or_add_spirv_scalar(b, *t).0) + .collect::>(); + *self + .complex + .entry(t) + .or_insert_with(|| SpirvWord(b.type_struct(None, underlying_ids))) + } + } + } + + fn get_or_add_fn( + &mut self, + b: &mut dr::Builder, + in_params: impl 
Iterator, + mut out_params: impl ExactSizeIterator, + ) -> (SpirvWord, SpirvWord) { + let (out_args, out_spirv_type) = if out_params.len() == 0 { + (None, SpirvWord(self.void())) + } else if out_params.len() == 1 { + let arg_as_key = out_params.next().unwrap(); + ( + Some(Box::new(arg_as_key.clone())), + self.get_or_add(b, arg_as_key), + ) + } else { + // TODO: support multiple return values + todo!() + }; + ( + out_spirv_type, + self.get_or_add(b, SpirvType::Func(out_args, in_params.collect::>())), + ) + } + + fn get_or_add_constant( + &mut self, + b: &mut dr::Builder, + typ: &ast::Type, + init: &[u8], + ) -> Result { + Ok(match typ { + ast::Type::Scalar(t) => match t { + ast::ScalarType::B8 | ast::ScalarType::U8 | ast::ScalarType::S8 => self + .get_or_add_constant_single::( + b, + *t, + init, + |v| v as u64, + |b, result_type, v| b.constant_u32(result_type, None, v as u32), + ), + ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => self + .get_or_add_constant_single::( + b, + *t, + init, + |v| v as u64, + |b, result_type, v| b.constant_u32(result_type, None, v as u32), + ), + ast::ScalarType::B32 | ast::ScalarType::U32 | ast::ScalarType::S32 => self + .get_or_add_constant_single::( + b, + *t, + init, + |v| v as u64, + |b, result_type, v| b.constant_u32(result_type, None, v), + ), + ast::ScalarType::B64 | ast::ScalarType::U64 | ast::ScalarType::S64 => self + .get_or_add_constant_single::( + b, + *t, + init, + |v| v, + |b, result_type, v| b.constant_u64(result_type, None, v), + ), + ast::ScalarType::F16 => self.get_or_add_constant_single::( + b, + *t, + init, + |v| unsafe { mem::transmute::<_, u16>(v) } as u64, + |b, result_type, v| b.constant_f32(result_type, None, v.to_f32()), + ), + ast::ScalarType::F32 => self.get_or_add_constant_single::( + b, + *t, + init, + |v| unsafe { mem::transmute::<_, u32>(v) } as u64, + |b, result_type, v| b.constant_f32(result_type, None, v), + ), + ast::ScalarType::F64 => self.get_or_add_constant_single::( + b, + 
*t, + init, + |v| unsafe { mem::transmute::<_, u64>(v) }, + |b, result_type, v| b.constant_f64(result_type, None, v), + ), + ast::ScalarType::F16x2 => return Err(TranslateError::Todo), + ast::ScalarType::Pred => self.get_or_add_constant_single::( + b, + *t, + init, + |v| v as u64, + |b, result_type, v| { + if v == 0 { + b.constant_false(result_type, None) + } else { + b.constant_true(result_type, None) + } + }, + ), + ast::ScalarType::S16x2 + | ast::ScalarType::U16x2 + | ast::ScalarType::BF16 + | ast::ScalarType::BF16x2 + | ast::ScalarType::B128 => todo!(), + }, + ast::Type::Vector(len, typ) => { + let result_type = + self.get_or_add(b, SpirvType::Vector(SpirvScalarKey::from(*typ), *len)); + let size_of_t = typ.size_of(); + let components = (0..*len) + .map(|x| { + Ok::<_, TranslateError>( + self.get_or_add_constant( + b, + &ast::Type::Scalar(*typ), + &init[((size_of_t as usize) * (x as usize))..], + )? + .0, + ) + }) + .collect::, _>>()?; + SpirvWord(b.constant_composite(result_type.0, None, components.into_iter())) + } + ast::Type::Array(_, typ, dims) => match dims.as_slice() { + [] => return Err(error_unreachable()), + [dim] => { + let result_type = self + .get_or_add(b, SpirvType::Array(SpirvScalarKey::from(*typ), vec![*dim])); + let size_of_t = typ.size_of(); + let components = (0..*dim) + .map(|x| { + Ok::<_, TranslateError>( + self.get_or_add_constant( + b, + &ast::Type::Scalar(*typ), + &init[((size_of_t as usize) * (x as usize))..], + )? + .0, + ) + }) + .collect::, _>>()?; + SpirvWord(b.constant_composite(result_type.0, None, components.into_iter())) + } + [first_dim, rest @ ..] 
=> { + let result_type = self.get_or_add( + b, + SpirvType::Array(SpirvScalarKey::from(*typ), rest.to_vec()), + ); + let size_of_t = rest + .iter() + .fold(typ.size_of() as u32, |x, y| (x as u32) * (*y)); + let components = (0..*first_dim) + .map(|x| { + Ok::<_, TranslateError>( + self.get_or_add_constant( + b, + &ast::Type::Array(None, *typ, rest.to_vec()), + &init[((size_of_t as usize) * (x as usize))..], + )? + .0, + ) + }) + .collect::, _>>()?; + SpirvWord(b.constant_composite(result_type.0, None, components.into_iter())) + } + }, + ast::Type::Pointer(..) => return Err(error_unreachable()), + }) + } + + fn get_or_add_constant_single< + T: Copy, + CastAsU64: FnOnce(T) -> u64, + InsertConstant: FnOnce(&mut dr::Builder, spirv::Word, T) -> spirv::Word, + >( + &mut self, + b: &mut dr::Builder, + key: ast::ScalarType, + init: &[u8], + cast: CastAsU64, + f: InsertConstant, + ) -> SpirvWord { + let value = unsafe { *(init.as_ptr() as *const T) }; + let value_64 = cast(value); + let ht_key = (SpirvType::Base(SpirvScalarKey::from(key)), value_64); + match self.constants.get(&ht_key) { + Some(value) => *value, + None => { + let spirv_type = self.get_or_add_scalar(b, key); + let result = SpirvWord(f(b, spirv_type.0, value)); + self.constants.insert(ht_key, result); + result + } + } + } +} + +#[derive(PartialEq, Eq, Hash, Clone)] +enum SpirvType { + Base(SpirvScalarKey), + Vector(SpirvScalarKey, u8), + Array(SpirvScalarKey, Vec), + Pointer(Box, spirv::StorageClass), + Func(Option>, Vec), + Struct(Vec), +} + +impl SpirvType { + fn new(t: ast::Type) -> Self { + match t { + ast::Type::Scalar(t) => SpirvType::Base(t.into()), + ast::Type::Vector(len, typ) => SpirvType::Vector(typ.into(), len), + ast::Type::Array(_, t, len) => SpirvType::Array(t.into(), len), + ast::Type::Pointer(pointer_t, space) => SpirvType::Pointer( + Box::new(SpirvType::Base(pointer_t.into())), + space_to_spirv(space), + ), + } + } + + fn pointer_to(t: ast::Type, outer_space: spirv::StorageClass) -> Self { + 
let key = Self::new(t); + SpirvType::Pointer(Box::new(key), outer_space) + } +} + +impl From for SpirvType { + fn from(t: ast::ScalarType) -> Self { + SpirvType::Base(t.into()) + } +} +// SPIR-V integer type definitions are signless, more below: +// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_unsignedsigned_a_unsigned_versus_signed_integers +// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_validation_rules_for_kernel_a_href_capability_capabilities_a +#[derive(PartialEq, Eq, Hash, Clone, Copy)] +enum SpirvScalarKey { + B8, + B16, + B32, + B64, + F16, + F32, + F64, + Pred, + F16x2, +} + +impl From for SpirvScalarKey { + fn from(t: ast::ScalarType) -> Self { + match t { + ast::ScalarType::B8 | ast::ScalarType::U8 | ast::ScalarType::S8 => SpirvScalarKey::B8, + ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => { + SpirvScalarKey::B16 + } + ast::ScalarType::B32 | ast::ScalarType::U32 | ast::ScalarType::S32 => { + SpirvScalarKey::B32 + } + ast::ScalarType::B64 | ast::ScalarType::U64 | ast::ScalarType::S64 => { + SpirvScalarKey::B64 + } + ast::ScalarType::F16 => SpirvScalarKey::F16, + ast::ScalarType::F32 => SpirvScalarKey::F32, + ast::ScalarType::F64 => SpirvScalarKey::F64, + ast::ScalarType::F16x2 => SpirvScalarKey::F16x2, + ast::ScalarType::Pred => SpirvScalarKey::Pred, + ast::ScalarType::S16x2 + | ast::ScalarType::U16x2 + | ast::ScalarType::BF16 + | ast::ScalarType::BF16x2 + | ast::ScalarType::B128 => todo!(), + } + } +} + +fn space_to_spirv(this: ast::StateSpace) -> spirv::StorageClass { + match this { + ast::StateSpace::Const => spirv::StorageClass::UniformConstant, + ast::StateSpace::Generic => spirv::StorageClass::Generic, + ast::StateSpace::Global => spirv::StorageClass::CrossWorkgroup, + ast::StateSpace::Local => spirv::StorageClass::Function, + ast::StateSpace::Shared => spirv::StorageClass::Workgroup, + ast::StateSpace::Param => spirv::StorageClass::Function, + ast::StateSpace::Reg => 
spirv::StorageClass::Function, + ast::StateSpace::Sreg => spirv::StorageClass::Input, + ast::StateSpace::ParamEntry + | ast::StateSpace::ParamFunc + | ast::StateSpace::SharedCluster + | ast::StateSpace::SharedCta => todo!(), + } +} + +// TODO: remove this once we have per-function support for denorms +fn emit_denorm_build_string<'input>( + call_map: &MethodsCallMap, + denorm_information: &HashMap< + ast::MethodName<'input, SpirvWord>, + HashMap, + >, +) -> (CString, bool) { + let denorm_counts = denorm_information + .iter() + .map(|(method, meth_denorm)| { + let f16_count = meth_denorm + .get(&(mem::size_of::() as u8)) + .unwrap_or(&(spirv::FPDenormMode::FlushToZero, 0)) + .1; + let f32_count = meth_denorm + .get(&(mem::size_of::() as u8)) + .unwrap_or(&(spirv::FPDenormMode::FlushToZero, 0)) + .1; + (method, (f16_count + f32_count)) + }) + .collect::>(); + let mut flush_over_preserve = 0; + for (kernel, children) in call_map.kernels() { + flush_over_preserve += *denorm_counts + .get(&ast::MethodName::Kernel(kernel)) + .unwrap_or(&0); + for child_fn in children { + flush_over_preserve += *denorm_counts + .get(&ast::MethodName::Func(*child_fn)) + .unwrap_or(&0); + } + } + if flush_over_preserve > 0 { + ( + CString::new("-ze-take-global-address -ze-denorms-are-zero").unwrap(), + true, + ) + } else { + (CString::new("-ze-take-global-address").unwrap(), false) + } +} + +fn get_globals_use_map<'input>( + directives: Vec>, +) -> ( + Vec>, + HashMap, HashSet>, +) { + let mut known_globals = HashSet::new(); + for directive in directives.iter() { + match directive { + Directive::Variable(_, ast::Variable { name, .. }) => { + known_globals.insert(*name); + } + Directive::Method(..) => {} + } + } + let mut symbol_uses_map = HashMap::new(); + let directives = directives + .into_iter() + .map(|directive| match directive { + Directive::Variable(..) | Directive::Method(Function { body: None, .. 
}) => directive, + Directive::Method(Function { + func_decl, + body: Some(mut statements), + globals, + import_as, + tuning, + linkage, + }) => { + let method_name = func_decl.borrow().name; + statements = statements + .into_iter() + .map(|statement| { + statement.visit_map( + &mut |symbol, _: Option<(&ast::Type, ast::StateSpace)>, _, _| { + if known_globals.contains(&symbol) { + multi_hash_map_append( + &mut symbol_uses_map, + method_name, + symbol, + ); + } + Ok::<_, TranslateError>(symbol) + }, + ) + }) + .collect::, _>>() + .unwrap(); + Directive::Method(Function { + func_decl, + body: Some(statements), + globals, + import_as, + tuning, + linkage, + }) + } + }) + .collect::>(); + (directives, symbol_uses_map) +} + +fn emit_directives<'input>( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + id_defs: &GlobalStringIdResolver<'input>, + opencl_id: spirv::Word, + should_flush_denorms: bool, + call_map: &MethodsCallMap<'input>, + globals_use_map: HashMap, HashSet>, + directives: Vec>, + kernel_info: &mut HashMap, +) -> Result<(), TranslateError> { + let empty_body = Vec::new(); + for d in directives.iter() { + match d { + Directive::Variable(linking, var) => { + emit_variable(builder, map, id_defs, *linking, &var)?; + } + Directive::Method(f) => { + let f_body = match &f.body { + Some(f) => f, + None => { + if f.linkage.contains(ast::LinkingDirective::EXTERN) { + &empty_body + } else { + continue; + } + } + }; + for var in f.globals.iter() { + emit_variable(builder, map, id_defs, ast::LinkingDirective::NONE, var)?; + } + let func_decl = (*f.func_decl).borrow(); + let fn_id = emit_function_header( + builder, + map, + &id_defs, + &*func_decl, + call_map, + &globals_use_map, + kernel_info, + )?; + if matches!(func_decl.name, ast::MethodName::Kernel(_)) { + if should_flush_denorms { + builder.execution_mode( + fn_id.0, + spirv_headers::ExecutionMode::DenormFlushToZero, + [16], + ); + builder.execution_mode( + fn_id.0, + 
spirv_headers::ExecutionMode::DenormFlushToZero, + [32], + ); + builder.execution_mode( + fn_id.0, + spirv_headers::ExecutionMode::DenormFlushToZero, + [64], + ); + } + // FP contraction happens when compiling source -> PTX and is illegal at this stage (unless you force it in cuModuleLoadDataEx) + builder.execution_mode( + fn_id.0, + spirv_headers::ExecutionMode::ContractionOff, + [], + ); + for t in f.tuning.iter() { + match *t { + ast::TuningDirective::MaxNtid(nx, ny, nz) => { + builder.execution_mode( + fn_id.0, + spirv_headers::ExecutionMode::MaxWorkgroupSizeINTEL, + [nx, ny, nz], + ); + } + ast::TuningDirective::ReqNtid(nx, ny, nz) => { + builder.execution_mode( + fn_id.0, + spirv_headers::ExecutionMode::LocalSize, + [nx, ny, nz], + ); + } + // Too architecture specific + ast::TuningDirective::MaxNReg(..) + | ast::TuningDirective::MinNCtaPerSm(..) => {} + } + } + } + emit_function_body_ops(builder, map, id_defs, opencl_id, &f_body)?; + emit_function_linkage(builder, id_defs, f, fn_id)?; + builder.select_block(None)?; + builder.end_function()?; + } + } + } + Ok(()) +} + +fn emit_variable<'input>( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + id_defs: &GlobalStringIdResolver<'input>, + linking: ast::LinkingDirective, + var: &ast::Variable, +) -> Result<(), TranslateError> { + let (must_init, st_class) = match var.state_space { + ast::StateSpace::Reg | ast::StateSpace::Param | ast::StateSpace::Local => { + (false, spirv::StorageClass::Function) + } + ast::StateSpace::Global => (true, spirv::StorageClass::CrossWorkgroup), + ast::StateSpace::Shared => (false, spirv::StorageClass::Workgroup), + ast::StateSpace::Const => (false, spirv::StorageClass::UniformConstant), + ast::StateSpace::Generic => todo!(), + ast::StateSpace::Sreg => todo!(), + ast::StateSpace::ParamEntry + | ast::StateSpace::ParamFunc + | ast::StateSpace::SharedCluster + | ast::StateSpace::SharedCta => todo!(), + }; + let initalizer = if var.array_init.len() > 0 { + Some( + 
map.get_or_add_constant( + builder, + &ast::Type::from(var.v_type.clone()), + &*var.array_init, + )? + .0, + ) + } else if must_init { + let type_id = map.get_or_add(builder, SpirvType::new(var.v_type.clone())); + Some(builder.constant_null(type_id.0, None)) + } else { + None + }; + let ptr_type_id = map.get_or_add(builder, SpirvType::pointer_to(var.v_type.clone(), st_class)); + builder.variable(ptr_type_id.0, Some(var.name.0), st_class, initalizer); + if let Some(align) = var.align { + builder.decorate( + var.name.0, + spirv::Decoration::Alignment, + [dr::Operand::LiteralInt32(align)].iter().cloned(), + ); + } + if var.state_space != ast::StateSpace::Shared + || !linking.contains(ast::LinkingDirective::EXTERN) + { + emit_linking_decoration(builder, id_defs, None, var.name, linking); + } + Ok(()) +} + +fn emit_function_header<'input>( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + defined_globals: &GlobalStringIdResolver<'input>, + func_decl: &ast::MethodDeclaration<'input, SpirvWord>, + call_map: &MethodsCallMap<'input>, + globals_use_map: &HashMap, HashSet>, + kernel_info: &mut HashMap, +) -> Result { + if let ast::MethodName::Kernel(name) = func_decl.name { + let args_lens = func_decl + .input_arguments + .iter() + .map(|param| { + ( + type_size_of(¶m.v_type), + matches!(param.v_type, ast::Type::Pointer(..)), + ) + }) + .collect(); + kernel_info.insert( + name.to_string(), + KernelInfo { + arguments_sizes: args_lens, + uses_shared_mem: func_decl.shared_mem.is_some(), + }, + ); + } + let (ret_type, func_type) = get_function_type( + builder, + map, + effective_input_arguments(func_decl).map(|(_, typ)| typ), + &func_decl.return_arguments, + ); + let fn_id = match func_decl.name { + ast::MethodName::Kernel(name) => { + let fn_id = defined_globals.get_id(name)?; + let interface = globals_use_map + .get(&ast::MethodName::Kernel(name)) + .into_iter() + .flatten() + .copied() + .chain({ + call_map + .get_kernel_children(name) + .copied() + 
.flat_map(|subfunction| { + globals_use_map + .get(&ast::MethodName::Func(subfunction)) + .into_iter() + .flatten() + .copied() + }) + .into_iter() + }) + .map(|word| word.0) + .collect::>(); + builder.entry_point(spirv::ExecutionModel::Kernel, fn_id.0, name, interface); + fn_id + } + ast::MethodName::Func(name) => name, + }; + builder.begin_function( + ret_type.0, + Some(fn_id.0), + spirv::FunctionControl::NONE, + func_type.0, + )?; + for (name, typ) in effective_input_arguments(func_decl) { + let result_type = map.get_or_add(builder, typ); + builder.function_parameter(Some(name.0), result_type.0)?; + } + Ok(fn_id) +} + +pub fn type_size_of(this: &ast::Type) -> usize { + match this { + ast::Type::Scalar(typ) => typ.size_of() as usize, + ast::Type::Vector(len, typ) => (typ.size_of() as usize) * (*len as usize), + ast::Type::Array(_, typ, len) => len + .iter() + .fold(typ.size_of() as usize, |x, y| (x as usize) * (*y as usize)), + ast::Type::Pointer(..) => mem::size_of::(), + } +} +fn emit_function_body_ops<'input>( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + id_defs: &GlobalStringIdResolver<'input>, + opencl: spirv::Word, + func: &[ExpandedStatement], +) -> Result<(), TranslateError> { + for s in func { + match s { + Statement::Label(id) => { + if builder.selected_block().is_some() { + builder.branch(id.0)?; + } + builder.begin_block(Some(id.0))?; + } + _ => { + if builder.selected_block().is_none() && builder.selected_function().is_some() { + builder.begin_block(None)?; + } + } + } + match s { + Statement::Label(_) => (), + Statement::Variable(var) => { + emit_variable(builder, map, id_defs, ast::LinkingDirective::NONE, var)?; + } + Statement::Constant(cnst) => { + let typ_id = map.get_or_add_scalar(builder, cnst.typ); + match (cnst.typ, cnst.value) { + (ast::ScalarType::B8, ast::ImmediateValue::U64(value)) + | (ast::ScalarType::U8, ast::ImmediateValue::U64(value)) => { + builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as u8 as u32); + } + 
(ast::ScalarType::B16, ast::ImmediateValue::U64(value)) + | (ast::ScalarType::U16, ast::ImmediateValue::U64(value)) => { + builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as u16 as u32); + } + (ast::ScalarType::B32, ast::ImmediateValue::U64(value)) + | (ast::ScalarType::U32, ast::ImmediateValue::U64(value)) => { + builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as u32); + } + (ast::ScalarType::B64, ast::ImmediateValue::U64(value)) + | (ast::ScalarType::U64, ast::ImmediateValue::U64(value)) => { + builder.constant_u64(typ_id.0, Some(cnst.dst.0), value); + } + (ast::ScalarType::S8, ast::ImmediateValue::U64(value)) => { + builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as i8 as u32); + } + (ast::ScalarType::S16, ast::ImmediateValue::U64(value)) => { + builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as i16 as u32); + } + (ast::ScalarType::S32, ast::ImmediateValue::U64(value)) => { + builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as i32 as u32); + } + (ast::ScalarType::S64, ast::ImmediateValue::U64(value)) => { + builder.constant_u64(typ_id.0, Some(cnst.dst.0), value as i64 as u64); + } + (ast::ScalarType::B8, ast::ImmediateValue::S64(value)) + | (ast::ScalarType::U8, ast::ImmediateValue::S64(value)) => { + builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as u8 as u32); + } + (ast::ScalarType::B16, ast::ImmediateValue::S64(value)) + | (ast::ScalarType::U16, ast::ImmediateValue::S64(value)) => { + builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as u16 as u32); + } + (ast::ScalarType::B32, ast::ImmediateValue::S64(value)) + | (ast::ScalarType::U32, ast::ImmediateValue::S64(value)) => { + builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as u32); + } + (ast::ScalarType::B64, ast::ImmediateValue::S64(value)) + | (ast::ScalarType::U64, ast::ImmediateValue::S64(value)) => { + builder.constant_u64(typ_id.0, Some(cnst.dst.0), value as u64); + } + (ast::ScalarType::S8, ast::ImmediateValue::S64(value)) => { + 
builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as i8 as u32); + } + (ast::ScalarType::S16, ast::ImmediateValue::S64(value)) => { + builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as i16 as u32); + } + (ast::ScalarType::S32, ast::ImmediateValue::S64(value)) => { + builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as i32 as u32); + } + (ast::ScalarType::S64, ast::ImmediateValue::S64(value)) => { + builder.constant_u64(typ_id.0, Some(cnst.dst.0), value as u64); + } + (ast::ScalarType::F16, ast::ImmediateValue::F32(value)) => { + builder.constant_f32( + typ_id.0, + Some(cnst.dst.0), + f16::from_f32(value).to_f32(), + ); + } + (ast::ScalarType::F32, ast::ImmediateValue::F32(value)) => { + builder.constant_f32(typ_id.0, Some(cnst.dst.0), value); + } + (ast::ScalarType::F64, ast::ImmediateValue::F32(value)) => { + builder.constant_f64(typ_id.0, Some(cnst.dst.0), value as f64); + } + (ast::ScalarType::F16, ast::ImmediateValue::F64(value)) => { + builder.constant_f32( + typ_id.0, + Some(cnst.dst.0), + f16::from_f64(value).to_f32(), + ); + } + (ast::ScalarType::F32, ast::ImmediateValue::F64(value)) => { + builder.constant_f32(typ_id.0, Some(cnst.dst.0), value as f32); + } + (ast::ScalarType::F64, ast::ImmediateValue::F64(value)) => { + builder.constant_f64(typ_id.0, Some(cnst.dst.0), value); + } + (ast::ScalarType::Pred, ast::ImmediateValue::U64(value)) => { + let bool_type = map.get_or_add_scalar(builder, ast::ScalarType::Pred).0; + if value == 0 { + builder.constant_false(bool_type, Some(cnst.dst.0)); + } else { + builder.constant_true(bool_type, Some(cnst.dst.0)); + } + } + (ast::ScalarType::Pred, ast::ImmediateValue::S64(value)) => { + let bool_type = map.get_or_add_scalar(builder, ast::ScalarType::Pred).0; + if value == 0 { + builder.constant_false(bool_type, Some(cnst.dst.0)); + } else { + builder.constant_true(bool_type, Some(cnst.dst.0)); + } + } + _ => return Err(error_mismatched_type()), + } + } + Statement::Conversion(cv) => 
emit_implicit_conversion(builder, map, cv)?, + Statement::Conditional(bra) => { + builder.branch_conditional( + bra.predicate.0, + bra.if_true.0, + bra.if_false.0, + iter::empty(), + )?; + } + Statement::FunctionPointer(FunctionPointerDetails { dst, src }) => { + // TODO: implement properly + let zero = map.get_or_add_constant( + builder, + &ast::Type::Scalar(ast::ScalarType::U64), + &vec_repr(0u64), + )?; + let result_type = map.get_or_add_scalar(builder, ast::ScalarType::U64); + builder.copy_object(result_type.0, Some(dst.0), zero.0)?; + } + Statement::Instruction(inst) => match inst { + ast::Instruction::PrmtSlow { .. } | ast::Instruction::Trap { .. } => todo!(), + ast::Instruction::Call { data, arguments } => { + let (result_type, result_id) = + match (&*data.return_arguments, &*arguments.return_arguments) { + ([(type_, space)], [id]) => { + if *space != ast::StateSpace::Reg { + return Err(error_unreachable()); + } + ( + map.get_or_add(builder, SpirvType::new(type_.clone())).0, + Some(id.0), + ) + } + ([], []) => (map.void(), None), + _ => todo!(), + }; + let arg_list = arguments + .input_arguments + .iter() + .map(|id| id.0) + .collect::>(); + builder.function_call(result_type, result_id, arguments.func.0, arg_list)?; + } + ast::Instruction::Abs { data, arguments } => { + emit_abs(builder, map, opencl, data, arguments)? + } + // SPIR-V does not support marking jumps as guaranteed-converged + ast::Instruction::Bra { arguments, .. 
} => { + builder.branch(arguments.src.0)?; + } + ast::Instruction::Ld { data, arguments } => { + let mem_access = match data.qualifier { + ast::LdStQualifier::Weak => spirv::MemoryAccess::NONE, + // ld.volatile does not match Volatile OpLoad nor Relaxed OpAtomicLoad + ast::LdStQualifier::Volatile => spirv::MemoryAccess::VOLATILE, + _ => return Err(TranslateError::Todo), + }; + let result_type = + map.get_or_add(builder, SpirvType::new(ast::Type::from(data.typ.clone()))); + builder.load( + result_type.0, + Some(arguments.dst.0), + arguments.src.0, + Some(mem_access | spirv::MemoryAccess::ALIGNED), + [dr::Operand::LiteralInt32( + type_size_of(&ast::Type::from(data.typ.clone())) as u32, + )] + .iter() + .cloned(), + )?; + } + ast::Instruction::St { data, arguments } => { + let mem_access = match data.qualifier { + ast::LdStQualifier::Weak => spirv::MemoryAccess::NONE, + // st.volatile does not match Volatile OpStore nor Relaxed OpAtomicStore + ast::LdStQualifier::Volatile => spirv::MemoryAccess::VOLATILE, + _ => return Err(TranslateError::Todo), + }; + builder.store( + arguments.src1.0, + arguments.src2.0, + Some(mem_access | spirv::MemoryAccess::ALIGNED), + [dr::Operand::LiteralInt32( + type_size_of(&ast::Type::from(data.typ.clone())) as u32, + )] + .iter() + .cloned(), + )?; + } + // SPIR-V does not support ret as guaranteed-converged + ast::Instruction::Ret { .. } => builder.ret()?, + ast::Instruction::Mov { data, arguments } => { + let result_type = + map.get_or_add(builder, SpirvType::new(ast::Type::from(data.typ.clone()))); + builder.copy_object(result_type.0, Some(arguments.dst.0), arguments.src.0)?; + } + ast::Instruction::Mul { data, arguments } => match data { + ast::MulDetails::Integer { type_, control } => { + emit_mul_int(builder, map, opencl, *type_, *control, arguments)? + } + ast::MulDetails::Float(ref ctr) => { + emit_mul_float(builder, map, ctr, arguments)? 
+ } + }, + ast::Instruction::Add { data, arguments } => match data { + ast::ArithDetails::Integer(desc) => { + emit_add_int(builder, map, desc.type_.into(), desc.saturate, arguments)? + } + ast::ArithDetails::Float(desc) => { + emit_add_float(builder, map, desc, arguments)? + } + }, + ast::Instruction::Setp { data, arguments } => { + if arguments.dst2.is_some() { + todo!() + } + emit_setp(builder, map, data, arguments)?; + } + ast::Instruction::Not { data, arguments } => { + let result_type = map.get_or_add(builder, SpirvType::from(*data)); + let result_id = Some(arguments.dst.0); + let operand = arguments.src; + match data { + ast::ScalarType::Pred => { + logical_not(builder, result_type.0, result_id, operand.0) + } + _ => builder.not(result_type.0, result_id, operand.0), + }?; + } + ast::Instruction::Shl { data, arguments } => { + let full_type = ast::Type::Scalar(*data); + let size_of = type_size_of(&full_type); + let result_type = map.get_or_add(builder, SpirvType::new(full_type)); + let offset_src = insert_shift_hack(builder, map, arguments.src2.0, size_of)?; + builder.shift_left_logical( + result_type.0, + Some(arguments.dst.0), + arguments.src1.0, + offset_src, + )?; + } + ast::Instruction::Shr { data, arguments } => { + let full_type = ast::ScalarType::from(data.type_); + let size_of = full_type.size_of(); + let result_type = map.get_or_add_scalar(builder, full_type).0; + let offset_src = + insert_shift_hack(builder, map, arguments.src2.0, size_of as usize)?; + match data.kind { + ptx_parser::RightShiftKind::Arithmetic => { + builder.shift_right_arithmetic( + result_type, + Some(arguments.dst.0), + arguments.src1.0, + offset_src, + )?; + } + ptx_parser::RightShiftKind::Logical => { + builder.shift_right_logical( + result_type, + Some(arguments.dst.0), + arguments.src1.0, + offset_src, + )?; + } + } + } + ast::Instruction::Cvt { data, arguments } => { + emit_cvt(builder, map, opencl, data, arguments)?; + } + ast::Instruction::Cvta { data, arguments } => { + 
// This would be only meaningful if const/slm/global pointers + // had a different format than generic pointers, but they don't pretty much by ptx definition + // Honestly, I have no idea why this instruction exists and is emitted by the compiler + let result_type = map.get_or_add_scalar(builder, ast::ScalarType::B64); + builder.copy_object(result_type.0, Some(arguments.dst.0), arguments.src.0)?; + } + ast::Instruction::SetpBool { .. } => todo!(), + ast::Instruction::Mad { data, arguments } => match data { + ast::MadDetails::Integer { + type_, + control, + saturate, + } => { + if *saturate { + todo!() + } + if type_.kind() == ast::ScalarKind::Signed { + emit_mad_sint(builder, map, opencl, *type_, *control, arguments)? + } else { + emit_mad_uint(builder, map, opencl, *type_, *control, arguments)? + } + } + ast::MadDetails::Float(desc) => { + emit_mad_float(builder, map, opencl, desc, arguments)? + } + }, + ast::Instruction::Fma { data, arguments } => { + emit_fma_float(builder, map, opencl, data, arguments)? 
+ } + ast::Instruction::Or { data, arguments } => { + let result_type = map.get_or_add_scalar(builder, *data).0; + if *data == ast::ScalarType::Pred { + builder.logical_or( + result_type, + Some(arguments.dst.0), + arguments.src1.0, + arguments.src2.0, + )?; + } else { + builder.bitwise_or( + result_type, + Some(arguments.dst.0), + arguments.src1.0, + arguments.src2.0, + )?; + } + } + ast::Instruction::Sub { data, arguments } => match data { + ast::ArithDetails::Integer(desc) => { + emit_sub_int(builder, map, desc.type_.into(), desc.saturate, arguments)?; + } + ast::ArithDetails::Float(desc) => { + emit_sub_float(builder, map, desc, arguments)?; + } + }, + ast::Instruction::Min { data, arguments } => { + emit_min(builder, map, opencl, data, arguments)?; + } + ast::Instruction::Max { data, arguments } => { + emit_max(builder, map, opencl, data, arguments)?; + } + ast::Instruction::Rcp { data, arguments } => { + emit_rcp(builder, map, opencl, data, arguments)?; + } + ast::Instruction::And { data, arguments } => { + let result_type = map.get_or_add_scalar(builder, *data); + if *data == ast::ScalarType::Pred { + builder.logical_and( + result_type.0, + Some(arguments.dst.0), + arguments.src1.0, + arguments.src2.0, + )?; + } else { + builder.bitwise_and( + result_type.0, + Some(arguments.dst.0), + arguments.src1.0, + arguments.src2.0, + )?; + } + } + ast::Instruction::Selp { data, arguments } => { + let result_type = map.get_or_add_scalar(builder, *data); + builder.select( + result_type.0, + Some(arguments.dst.0), + arguments.src3.0, + arguments.src1.0, + arguments.src2.0, + )?; + } + // TODO: implement named barriers + ast::Instruction::Bar { data, arguments } => { + let workgroup_scope = map.get_or_add_constant( + builder, + &ast::Type::Scalar(ast::ScalarType::U32), + &vec_repr(spirv::Scope::Workgroup as u32), + )?; + let barrier_semantics = map.get_or_add_constant( + builder, + &ast::Type::Scalar(ast::ScalarType::U32), + &vec_repr( + 
spirv::MemorySemantics::CROSS_WORKGROUP_MEMORY + | spirv::MemorySemantics::WORKGROUP_MEMORY + | spirv::MemorySemantics::SEQUENTIALLY_CONSISTENT, + ), + )?; + builder.control_barrier( + workgroup_scope.0, + workgroup_scope.0, + barrier_semantics.0, + )?; + } + ast::Instruction::Atom { data, arguments } => { + emit_atom(builder, map, data, arguments)?; + } + ast::Instruction::AtomCas { data, arguments } => { + let result_type = map.get_or_add_scalar(builder, data.type_); + let memory_const = map.get_or_add_constant( + builder, + &ast::Type::Scalar(ast::ScalarType::U32), + &vec_repr(scope_to_spirv(data.scope) as u32), + )?; + let semantics_const = map.get_or_add_constant( + builder, + &ast::Type::Scalar(ast::ScalarType::U32), + &vec_repr(semantics_to_spirv(data.semantics).bits()), + )?; + builder.atomic_compare_exchange( + result_type.0, + Some(arguments.dst.0), + arguments.src1.0, + memory_const.0, + semantics_const.0, + semantics_const.0, + arguments.src3.0, + arguments.src2.0, + )?; + } + ast::Instruction::Div { data, arguments } => match data { + ast::DivDetails::Unsigned(t) => { + let result_type = map.get_or_add_scalar(builder, (*t).into()); + builder.u_div( + result_type.0, + Some(arguments.dst.0), + arguments.src1.0, + arguments.src2.0, + )?; + } + ast::DivDetails::Signed(t) => { + let result_type = map.get_or_add_scalar(builder, (*t).into()); + builder.s_div( + result_type.0, + Some(arguments.dst.0), + arguments.src1.0, + arguments.src2.0, + )?; + } + ast::DivDetails::Float(t) => { + let result_type = map.get_or_add_scalar(builder, t.type_.into()); + builder.f_div( + result_type.0, + Some(arguments.dst.0), + arguments.src1.0, + arguments.src2.0, + )?; + emit_float_div_decoration(builder, arguments.dst, t.kind); + } + }, + ast::Instruction::Sqrt { data, arguments } => { + emit_sqrt(builder, map, opencl, data, arguments)?; + } + ast::Instruction::Rsqrt { data, arguments } => { + let result_type = map.get_or_add_scalar(builder, data.type_.into()); + 
builder.ext_inst( + result_type.0, + Some(arguments.dst.0), + opencl, + spirv::CLOp::rsqrt as spirv::Word, + [dr::Operand::IdRef(arguments.src.0)].iter().cloned(), + )?; + } + ast::Instruction::Neg { data, arguments } => { + let result_type = map.get_or_add_scalar(builder, data.type_); + let negate_func = if data.type_.kind() == ast::ScalarKind::Float { + dr::Builder::f_negate + } else { + dr::Builder::s_negate + }; + negate_func( + builder, + result_type.0, + Some(arguments.dst.0), + arguments.src.0, + )?; + } + ast::Instruction::Sin { arguments, .. } => { + let result_type = map.get_or_add_scalar(builder, ast::ScalarType::F32); + builder.ext_inst( + result_type.0, + Some(arguments.dst.0), + opencl, + spirv::CLOp::sin as u32, + [dr::Operand::IdRef(arguments.src.0)].iter().cloned(), + )?; + } + ast::Instruction::Cos { arguments, .. } => { + let result_type = map.get_or_add_scalar(builder, ast::ScalarType::F32); + builder.ext_inst( + result_type.0, + Some(arguments.dst.0), + opencl, + spirv::CLOp::cos as u32, + [dr::Operand::IdRef(arguments.src.0)].iter().cloned(), + )?; + } + ast::Instruction::Lg2 { arguments, .. } => { + let result_type = map.get_or_add_scalar(builder, ast::ScalarType::F32); + builder.ext_inst( + result_type.0, + Some(arguments.dst.0), + opencl, + spirv::CLOp::log2 as u32, + [dr::Operand::IdRef(arguments.src.0)].iter().cloned(), + )?; + } + ast::Instruction::Ex2 { arguments, .. 
} => { + let result_type = map.get_or_add_scalar(builder, ast::ScalarType::F32); + builder.ext_inst( + result_type.0, + Some(arguments.dst.0), + opencl, + spirv::CLOp::exp2 as u32, + [dr::Operand::IdRef(arguments.src.0)].iter().cloned(), + )?; + } + ast::Instruction::Clz { data, arguments } => { + let result_type = map.get_or_add_scalar(builder, (*data).into()); + builder.ext_inst( + result_type.0, + Some(arguments.dst.0), + opencl, + spirv::CLOp::clz as u32, + [dr::Operand::IdRef(arguments.src.0)].iter().cloned(), + )?; + } + ast::Instruction::Brev { data, arguments } => { + let result_type = map.get_or_add_scalar(builder, (*data).into()); + builder.bit_reverse(result_type.0, Some(arguments.dst.0), arguments.src.0)?; + } + ast::Instruction::Popc { data, arguments } => { + let result_type = map.get_or_add_scalar(builder, (*data).into()); + builder.bit_count(result_type.0, Some(arguments.dst.0), arguments.src.0)?; + } + ast::Instruction::Xor { data, arguments } => { + let builder_fn: fn( + &mut dr::Builder, + u32, + Option, + u32, + u32, + ) -> Result = match data { + ast::ScalarType::Pred => emit_logical_xor_spirv, + _ => dr::Builder::bitwise_xor, + }; + let result_type = map.get_or_add_scalar(builder, (*data).into()); + builder_fn( + builder, + result_type.0, + Some(arguments.dst.0), + arguments.src1.0, + arguments.src2.0, + )?; + } + ast::Instruction::Bfe { .. } + | ast::Instruction::Bfi { .. } + | ast::Instruction::Activemask { .. 
} => { + // Should have been replaced with a function call earlier + return Err(error_unreachable()); + } + + ast::Instruction::Rem { data, arguments } => { + let builder_fn = if data.kind() == ast::ScalarKind::Signed { + dr::Builder::s_mod + } else { + dr::Builder::u_mod + }; + let result_type = map.get_or_add_scalar(builder, (*data).into()); + builder_fn( + builder, + result_type.0, + Some(arguments.dst.0), + arguments.src1.0, + arguments.src2.0, + )?; + } + ast::Instruction::Prmt { data, arguments } => { + let control = *data as u32; + let components = [ + (control >> 0) & 0b1111, + (control >> 4) & 0b1111, + (control >> 8) & 0b1111, + (control >> 12) & 0b1111, + ]; + if components.iter().any(|&c| c > 7) { + return Err(TranslateError::Todo); + } + let vec4_b8_type = + map.get_or_add(builder, SpirvType::Vector(SpirvScalarKey::B8, 4)); + let b32_type = map.get_or_add_scalar(builder, ast::ScalarType::B32); + let src1_vector = builder.bitcast(vec4_b8_type.0, None, arguments.src1.0)?; + let src2_vector = builder.bitcast(vec4_b8_type.0, None, arguments.src2.0)?; + let dst_vector = builder.vector_shuffle( + vec4_b8_type.0, + None, + src1_vector, + src2_vector, + components, + )?; + builder.bitcast(b32_type.0, Some(arguments.dst.0), dst_vector)?; + } + ast::Instruction::Membar { data } => { + let (scope, semantics) = match data { + ast::MemScope::Cta => ( + spirv::Scope::Workgroup, + spirv::MemorySemantics::CROSS_WORKGROUP_MEMORY + | spirv::MemorySemantics::WORKGROUP_MEMORY + | spirv::MemorySemantics::SEQUENTIALLY_CONSISTENT, + ), + ast::MemScope::Gpu => ( + spirv::Scope::Device, + spirv::MemorySemantics::CROSS_WORKGROUP_MEMORY + | spirv::MemorySemantics::WORKGROUP_MEMORY + | spirv::MemorySemantics::SEQUENTIALLY_CONSISTENT, + ), + ast::MemScope::Sys => ( + spirv::Scope::CrossDevice, + spirv::MemorySemantics::CROSS_WORKGROUP_MEMORY + | spirv::MemorySemantics::WORKGROUP_MEMORY + | spirv::MemorySemantics::SEQUENTIALLY_CONSISTENT, + ), + + ast::MemScope::Cluster => 
todo!(), + }; + let spirv_scope = map.get_or_add_constant( + builder, + &ast::Type::Scalar(ast::ScalarType::U32), + &vec_repr(scope as u32), + )?; + let spirv_semantics = map.get_or_add_constant( + builder, + &ast::Type::Scalar(ast::ScalarType::U32), + &vec_repr(semantics), + )?; + builder.memory_barrier(spirv_scope.0, spirv_semantics.0)?; + } + }, + Statement::LoadVar(details) => { + emit_load_var(builder, map, details)?; + } + Statement::StoreVar(details) => { + let dst_ptr = match details.member_index { + Some(index) => { + let result_ptr_type = map.get_or_add( + builder, + SpirvType::pointer_to( + details.typ.clone(), + spirv::StorageClass::Function, + ), + ); + let index_spirv = map.get_or_add_constant( + builder, + &ast::Type::Scalar(ast::ScalarType::U32), + &vec_repr(index as u32), + )?; + builder.in_bounds_access_chain( + result_ptr_type.0, + None, + details.arg.src1.0, + [index_spirv.0].iter().copied(), + )? + } + None => details.arg.src1.0, + }; + builder.store(dst_ptr, details.arg.src2.0, None, iter::empty())?; + } + Statement::RetValue(_, id) => { + builder.ret_value(id.0)?; + } + Statement::PtrAccess(PtrAccess { + underlying_type, + state_space, + dst, + ptr_src, + offset_src, + }) => { + let u8_pointer = map.get_or_add( + builder, + SpirvType::new(ast::Type::Pointer(ast::ScalarType::U8, *state_space)), + ); + let result_type = map.get_or_add( + builder, + SpirvType::pointer_to(underlying_type.clone(), space_to_spirv(*state_space)), + ); + let ptr_src_u8 = builder.bitcast(u8_pointer.0, None, ptr_src.0)?; + let temp = builder.in_bounds_ptr_access_chain( + u8_pointer.0, + None, + ptr_src_u8, + offset_src.0, + iter::empty(), + )?; + builder.bitcast(result_type.0, Some(dst.0), temp)?; + } + Statement::RepackVector(repack) => { + if repack.is_extract { + let scalar_type = map.get_or_add_scalar(builder, repack.typ); + for (index, dst_id) in repack.unpacked.iter().enumerate() { + builder.composite_extract( + scalar_type.0, + Some(dst_id.0), + repack.packed.0, 
+ [index as u32].iter().copied(), + )?; + } + } else { + let vector_type = map.get_or_add( + builder, + SpirvType::Vector( + SpirvScalarKey::from(repack.typ), + repack.unpacked.len() as u8, + ), + ); + let mut temp_vec = builder.undef(vector_type.0, None); + for (index, src_id) in repack.unpacked.iter().enumerate() { + temp_vec = builder.composite_insert( + vector_type.0, + None, + src_id.0, + temp_vec, + [index as u32].iter().copied(), + )?; + } + builder.copy_object(vector_type.0, Some(repack.packed.0), temp_vec)?; + } + } + } + } + Ok(()) +} + +fn emit_function_linkage<'input>( + builder: &mut dr::Builder, + id_defs: &GlobalStringIdResolver<'input>, + f: &Function, + fn_name: SpirvWord, +) -> Result<(), TranslateError> { + if f.linkage == ast::LinkingDirective::NONE { + return Ok(()); + }; + let linking_name = match f.func_decl.borrow().name { + // According to SPIR-V rules linkage attributes are invalid on kernels + ast::MethodName::Kernel(..) => return Ok(()), + ast::MethodName::Func(fn_id) => f.import_as.as_deref().map_or_else( + || match id_defs.reverse_variables.get(&fn_id) { + Some(fn_name) => Ok(fn_name), + None => Err(error_unknown_symbol()), + }, + Result::Ok, + )?, + }; + emit_linking_decoration(builder, id_defs, Some(linking_name), fn_name, f.linkage); + Ok(()) +} + +fn get_function_type( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + spirv_input: impl Iterator, + spirv_output: &[ast::Variable], +) -> (SpirvWord, SpirvWord) { + map.get_or_add_fn( + builder, + spirv_input, + spirv_output + .iter() + .map(|var| SpirvType::new(var.v_type.clone())), + ) +} + +fn emit_linking_decoration<'input>( + builder: &mut dr::Builder, + id_defs: &GlobalStringIdResolver<'input>, + name_override: Option<&str>, + name: SpirvWord, + linking: ast::LinkingDirective, +) { + if linking == ast::LinkingDirective::NONE { + return; + } + if linking.contains(ast::LinkingDirective::VISIBLE) { + let string_name = + name_override.unwrap_or_else(|| 
id_defs.reverse_variables.get(&name).unwrap()); + builder.decorate( + name.0, + spirv::Decoration::LinkageAttributes, + [ + dr::Operand::LiteralString(string_name.to_string()), + dr::Operand::LinkageType(spirv::LinkageType::Export), + ] + .iter() + .cloned(), + ); + } else if linking.contains(ast::LinkingDirective::EXTERN) { + let string_name = + name_override.unwrap_or_else(|| id_defs.reverse_variables.get(&name).unwrap()); + builder.decorate( + name.0, + spirv::Decoration::LinkageAttributes, + [ + dr::Operand::LiteralString(string_name.to_string()), + dr::Operand::LinkageType(spirv::LinkageType::Import), + ] + .iter() + .cloned(), + ); + } + // TODO: handle LinkingDirective::WEAK +} + +fn effective_input_arguments<'a>( + this: &'a ast::MethodDeclaration<'a, SpirvWord>, +) -> impl Iterator + 'a { + let is_kernel = matches!(this.name, ast::MethodName::Kernel(_)); + this.input_arguments.iter().map(move |arg| { + if !is_kernel && arg.state_space != ast::StateSpace::Reg { + let spirv_type = + SpirvType::pointer_to(arg.v_type.clone(), space_to_spirv(arg.state_space)); + (arg.name, spirv_type) + } else { + (arg.name, SpirvType::new(arg.v_type.clone())) + } + }) +} + +fn emit_implicit_conversion( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + cv: &ImplicitConversion, +) -> Result<(), TranslateError> { + let from_parts = to_parts(&cv.from_type); + let to_parts = to_parts(&cv.to_type); + match (from_parts.kind, to_parts.kind, &cv.kind) { + (_, _, &ConversionKind::BitToPtr) => { + let dst_type = map.get_or_add( + builder, + SpirvType::pointer_to(cv.to_type.clone(), space_to_spirv(cv.to_space)), + ); + builder.convert_u_to_ptr(dst_type.0, Some(cv.dst.0), cv.src.0)?; + } + (TypeKind::Scalar, TypeKind::Scalar, &ConversionKind::Default) => { + if from_parts.width == to_parts.width { + let dst_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone())); + if from_parts.scalar_kind != ast::ScalarKind::Float + && to_parts.scalar_kind != ast::ScalarKind::Float + { + 
// It is a no-op, but another instruction expects result of this conversion + builder.copy_object(dst_type.0, Some(cv.dst.0), cv.src.0)?; + } else { + builder.bitcast(dst_type.0, Some(cv.dst.0), cv.src.0)?; + } + } else { + // This block is safe because it's illegal to implicitly convert between floating point values + let same_width_bit_type = map.get_or_add( + builder, + SpirvType::new(type_from_parts(TypeParts { + scalar_kind: ast::ScalarKind::Bit, + ..from_parts + })), + ); + let same_width_bit_value = + builder.bitcast(same_width_bit_type.0, None, cv.src.0)?; + let wide_bit_type = type_from_parts(TypeParts { + scalar_kind: ast::ScalarKind::Bit, + ..to_parts + }); + let wide_bit_type_spirv = + map.get_or_add(builder, SpirvType::new(wide_bit_type.clone())); + if to_parts.scalar_kind == ast::ScalarKind::Unsigned + || to_parts.scalar_kind == ast::ScalarKind::Bit + { + builder.u_convert( + wide_bit_type_spirv.0, + Some(cv.dst.0), + same_width_bit_value, + )?; + } else { + let conversion_fn = if from_parts.scalar_kind == ast::ScalarKind::Signed + && to_parts.scalar_kind == ast::ScalarKind::Signed + { + dr::Builder::s_convert + } else { + dr::Builder::u_convert + }; + let wide_bit_value = + conversion_fn(builder, wide_bit_type_spirv.0, None, same_width_bit_value)?; + emit_implicit_conversion( + builder, + map, + &ImplicitConversion { + src: SpirvWord(wide_bit_value), + dst: cv.dst, + from_type: wide_bit_type, + from_space: cv.from_space, + to_type: cv.to_type.clone(), + to_space: cv.to_space, + kind: ConversionKind::Default, + }, + )?; + } + } + } + (TypeKind::Scalar, TypeKind::Scalar, &ConversionKind::SignExtend) => { + let result_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone())); + builder.s_convert(result_type.0, Some(cv.dst.0), cv.src.0)?; + } + (TypeKind::Vector, TypeKind::Scalar, &ConversionKind::Default) + | (TypeKind::Scalar, TypeKind::Array, &ConversionKind::Default) + | (TypeKind::Array, TypeKind::Scalar, &ConversionKind::Default) => { + let 
into_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone())); + builder.bitcast(into_type.0, Some(cv.dst.0), cv.src.0)?; + } + (_, _, &ConversionKind::PtrToPtr) => { + let result_type = map.get_or_add( + builder, + SpirvType::Pointer( + Box::new(SpirvType::new(cv.to_type.clone())), + space_to_spirv(cv.to_space), + ), + ); + if cv.to_space == ast::StateSpace::Generic && cv.from_space != ast::StateSpace::Generic + { + let src = if cv.from_type != cv.to_type { + let temp_type = map.get_or_add( + builder, + SpirvType::Pointer( + Box::new(SpirvType::new(cv.to_type.clone())), + space_to_spirv(cv.from_space), + ), + ); + builder.bitcast(temp_type.0, None, cv.src.0)? + } else { + cv.src.0 + }; + builder.ptr_cast_to_generic(result_type.0, Some(cv.dst.0), src)?; + } else if cv.from_space == ast::StateSpace::Generic + && cv.to_space != ast::StateSpace::Generic + { + let src = if cv.from_type != cv.to_type { + let temp_type = map.get_or_add( + builder, + SpirvType::Pointer( + Box::new(SpirvType::new(cv.to_type.clone())), + space_to_spirv(cv.from_space), + ), + ); + builder.bitcast(temp_type.0, None, cv.src.0)? 
+ } else { + cv.src.0 + }; + builder.generic_cast_to_ptr(result_type.0, Some(cv.dst.0), src)?; + } else { + builder.bitcast(result_type.0, Some(cv.dst.0), cv.src.0)?; + } + } + (_, _, &ConversionKind::AddressOf) => { + let dst_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone())); + builder.convert_ptr_to_u(dst_type.0, Some(cv.dst.0), cv.src.0)?; + } + (TypeKind::Pointer, TypeKind::Scalar, &ConversionKind::Default) => { + let result_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone())); + builder.convert_ptr_to_u(result_type.0, Some(cv.dst.0), cv.src.0)?; + } + (TypeKind::Scalar, TypeKind::Pointer, &ConversionKind::Default) => { + let result_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone())); + builder.convert_u_to_ptr(result_type.0, Some(cv.dst.0), cv.src.0)?; + } + _ => unreachable!(), + } + Ok(()) +} + +fn vec_repr(t: T) -> Vec { + let mut result = vec![0; mem::size_of::()]; + unsafe { std::ptr::copy_nonoverlapping(&t, result.as_mut_ptr() as *mut _, 1) }; + result +} + +fn emit_abs( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + opencl: spirv::Word, + d: &ast::TypeFtz, + arg: &ast::AbsArgs, +) -> Result<(), dr::Error> { + let scalar_t = ast::ScalarType::from(d.type_); + let result_type = map.get_or_add(builder, SpirvType::from(scalar_t)); + let cl_abs = if scalar_t.kind() == ast::ScalarKind::Signed { + spirv::CLOp::s_abs + } else { + spirv::CLOp::fabs + }; + builder.ext_inst( + result_type.0, + Some(arg.dst.0), + opencl, + cl_abs as spirv::Word, + [dr::Operand::IdRef(arg.src.0)].iter().cloned(), + )?; + Ok(()) +} + +fn emit_mul_int( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + opencl: spirv::Word, + type_: ast::ScalarType, + control: ast::MulIntControl, + arg: &ast::MulArgs, +) -> Result<(), dr::Error> { + let inst_type = map.get_or_add(builder, SpirvType::from(type_)); + match control { + ast::MulIntControl::Low => { + builder.i_mul(inst_type.0, Some(arg.dst.0), arg.src1.0, arg.src2.0)?; + } + 
ast::MulIntControl::High => { + let opencl_inst = if type_.kind() == ast::ScalarKind::Signed { + spirv::CLOp::s_mul_hi + } else { + spirv::CLOp::u_mul_hi + }; + builder.ext_inst( + inst_type.0, + Some(arg.dst.0), + opencl, + opencl_inst as spirv::Word, + [ + dr::Operand::IdRef(arg.src1.0), + dr::Operand::IdRef(arg.src2.0), + ] + .iter() + .cloned(), + )?; + } + ast::MulIntControl::Wide => { + let instr_width = type_.size_of(); + let instr_kind = type_.kind(); + let dst_type = scalar_from_parts(instr_width * 2, instr_kind); + let dst_type_id = map.get_or_add_scalar(builder, dst_type); + let (src1, src2) = if type_.kind() == ast::ScalarKind::Signed { + let src1 = builder.s_convert(dst_type_id.0, None, arg.src1.0)?; + let src2 = builder.s_convert(dst_type_id.0, None, arg.src2.0)?; + (src1, src2) + } else { + let src1 = builder.u_convert(dst_type_id.0, None, arg.src1.0)?; + let src2 = builder.u_convert(dst_type_id.0, None, arg.src2.0)?; + (src1, src2) + }; + builder.i_mul(dst_type_id.0, Some(arg.dst.0), src1, src2)?; + builder.decorate(arg.dst.0, spirv::Decoration::NoSignedWrap, iter::empty()); + } + } + Ok(()) +} + +fn emit_mul_float( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + ctr: &ast::ArithFloat, + arg: &ast::MulArgs, +) -> Result<(), dr::Error> { + if ctr.saturate { + todo!() + } + let result_type = map.get_or_add_scalar(builder, ctr.type_.into()); + builder.f_mul(result_type.0, Some(arg.dst.0), arg.src1.0, arg.src2.0)?; + emit_rounding_decoration(builder, arg.dst, ctr.rounding); + Ok(()) +} + +fn scalar_from_parts(width: u8, kind: ast::ScalarKind) -> ast::ScalarType { + match kind { + ast::ScalarKind::Float => match width { + 2 => ast::ScalarType::F16, + 4 => ast::ScalarType::F32, + 8 => ast::ScalarType::F64, + _ => unreachable!(), + }, + ast::ScalarKind::Bit => match width { + 1 => ast::ScalarType::B8, + 2 => ast::ScalarType::B16, + 4 => ast::ScalarType::B32, + 8 => ast::ScalarType::B64, + _ => unreachable!(), + }, + ast::ScalarKind::Signed => match 
width { + 1 => ast::ScalarType::S8, + 2 => ast::ScalarType::S16, + 4 => ast::ScalarType::S32, + 8 => ast::ScalarType::S64, + _ => unreachable!(), + }, + ast::ScalarKind::Unsigned => match width { + 1 => ast::ScalarType::U8, + 2 => ast::ScalarType::U16, + 4 => ast::ScalarType::U32, + 8 => ast::ScalarType::U64, + _ => unreachable!(), + }, + ast::ScalarKind::Pred => ast::ScalarType::Pred, + } +} + +fn emit_rounding_decoration( + builder: &mut dr::Builder, + dst: SpirvWord, + rounding: Option, +) { + if let Some(rounding) = rounding { + builder.decorate( + dst.0, + spirv::Decoration::FPRoundingMode, + [rounding_to_spirv(rounding)].iter().cloned(), + ); + } +} + +fn rounding_to_spirv(this: ast::RoundingMode) -> rspirv::dr::Operand { + let mode = match this { + ast::RoundingMode::NearestEven => spirv::FPRoundingMode::RTE, + ast::RoundingMode::Zero => spirv::FPRoundingMode::RTZ, + ast::RoundingMode::PositiveInf => spirv::FPRoundingMode::RTP, + ast::RoundingMode::NegativeInf => spirv::FPRoundingMode::RTN, + }; + rspirv::dr::Operand::FPRoundingMode(mode) +} + +fn emit_add_int( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + typ: ast::ScalarType, + saturate: bool, + arg: &ast::AddArgs, +) -> Result<(), dr::Error> { + if saturate { + todo!() + } + let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(typ))); + builder.i_add(inst_type.0, Some(arg.dst.0), arg.src1.0, arg.src2.0)?; + Ok(()) +} + +fn emit_add_float( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + desc: &ast::ArithFloat, + arg: &ast::AddArgs, +) -> Result<(), dr::Error> { + let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.type_))); + builder.f_add(inst_type.0, Some(arg.dst.0), arg.src1.0, arg.src2.0)?; + emit_rounding_decoration(builder, arg.dst, desc.rounding); + Ok(()) +} + +fn emit_setp( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + setp: &ast::SetpData, + arg: &ast::SetpArgs, +) -> Result<(), dr::Error> { + let result_type = 
map + .get_or_add(builder, SpirvType::Base(SpirvScalarKey::Pred)) + .0; + let result_id = Some(arg.dst1.0); + let operand_1 = arg.src1.0; + let operand_2 = arg.src2.0; + match setp.cmp_op { + ast::SetpCompareOp::Integer(ast::SetpCompareInt::Eq) => { + builder.i_equal(result_type, result_id, operand_1, operand_2) + } + ast::SetpCompareOp::Float(ast::SetpCompareFloat::Eq) => { + builder.f_ord_equal(result_type, result_id, operand_1, operand_2) + } + ast::SetpCompareOp::Integer(ast::SetpCompareInt::NotEq) => { + builder.i_not_equal(result_type, result_id, operand_1, operand_2) + } + ast::SetpCompareOp::Float(ast::SetpCompareFloat::NotEq) => { + builder.f_ord_not_equal(result_type, result_id, operand_1, operand_2) + } + ast::SetpCompareOp::Integer(ast::SetpCompareInt::UnsignedLess) => { + builder.u_less_than(result_type, result_id, operand_1, operand_2) + } + ast::SetpCompareOp::Integer(ast::SetpCompareInt::SignedLess) => { + builder.s_less_than(result_type, result_id, operand_1, operand_2) + } + ast::SetpCompareOp::Float(ast::SetpCompareFloat::Less) => { + builder.f_ord_less_than(result_type, result_id, operand_1, operand_2) + } + ast::SetpCompareOp::Integer(ast::SetpCompareInt::UnsignedLessOrEq) => { + builder.u_less_than_equal(result_type, result_id, operand_1, operand_2) + } + ast::SetpCompareOp::Integer(ast::SetpCompareInt::SignedLessOrEq) => { + builder.s_less_than_equal(result_type, result_id, operand_1, operand_2) + } + ast::SetpCompareOp::Float(ast::SetpCompareFloat::LessOrEq) => { + builder.f_ord_less_than_equal(result_type, result_id, operand_1, operand_2) + } + ast::SetpCompareOp::Integer(ast::SetpCompareInt::UnsignedGreater) => { + builder.u_greater_than(result_type, result_id, operand_1, operand_2) + } + ast::SetpCompareOp::Integer(ast::SetpCompareInt::SignedGreater) => { + builder.s_greater_than(result_type, result_id, operand_1, operand_2) + } + ast::SetpCompareOp::Float(ast::SetpCompareFloat::Greater) => { + builder.f_ord_greater_than(result_type, 
result_id, operand_1, operand_2) + } + ast::SetpCompareOp::Integer(ast::SetpCompareInt::UnsignedGreaterOrEq) => { + builder.u_greater_than_equal(result_type, result_id, operand_1, operand_2) + } + ast::SetpCompareOp::Integer(ast::SetpCompareInt::SignedGreaterOrEq) => { + builder.s_greater_than_equal(result_type, result_id, operand_1, operand_2) + } + ast::SetpCompareOp::Float(ast::SetpCompareFloat::GreaterOrEq) => { + builder.f_ord_greater_than_equal(result_type, result_id, operand_1, operand_2) + } + ast::SetpCompareOp::Float(ast::SetpCompareFloat::NanEq) => { + builder.f_unord_equal(result_type, result_id, operand_1, operand_2) + } + ast::SetpCompareOp::Float(ast::SetpCompareFloat::NanNotEq) => { + builder.f_unord_not_equal(result_type, result_id, operand_1, operand_2) + } + ast::SetpCompareOp::Float(ast::SetpCompareFloat::NanLess) => { + builder.f_unord_less_than(result_type, result_id, operand_1, operand_2) + } + ast::SetpCompareOp::Float(ast::SetpCompareFloat::NanLessOrEq) => { + builder.f_unord_less_than_equal(result_type, result_id, operand_1, operand_2) + } + ast::SetpCompareOp::Float(ast::SetpCompareFloat::NanGreater) => { + builder.f_unord_greater_than(result_type, result_id, operand_1, operand_2) + } + ast::SetpCompareOp::Float(ast::SetpCompareFloat::NanGreaterOrEq) => { + builder.f_unord_greater_than_equal(result_type, result_id, operand_1, operand_2) + } + ast::SetpCompareOp::Float(ast::SetpCompareFloat::IsAnyNan) => { + let temp1 = builder.is_nan(result_type, None, operand_1)?; + let temp2 = builder.is_nan(result_type, None, operand_2)?; + builder.logical_or(result_type, result_id, temp1, temp2) + } + ast::SetpCompareOp::Float(ast::SetpCompareFloat::IsNotNan) => { + let temp1 = builder.is_nan(result_type, None, operand_1)?; + let temp2 = builder.is_nan(result_type, None, operand_2)?; + let any_nan = builder.logical_or(result_type, None, temp1, temp2)?; + logical_not(builder, result_type, result_id, any_nan) + } + _ => todo!(), + }?; + Ok(()) +} + +// 
HACK ALERT +// Temporary workaround until IGC gets its shit together +// Currently IGC carries two copies of SPIRV-LLVM translator +// a new one in /llvm-spirv/ and old one in /IGC/AdaptorOCL/SPIRV/. +// Obviously, old and buggy one is used for compiling L0 SPIRV +// https://github.com/intel/intel-graphics-compiler/issues/148 +fn logical_not( + builder: &mut dr::Builder, + result_type: spirv::Word, + result_id: Option, + operand: spirv::Word, +) -> Result { + let const_true = builder.constant_true(result_type, None); + let const_false = builder.constant_false(result_type, None); + builder.select(result_type, result_id, operand, const_false, const_true) +} + +// HACK ALERT +// For some reason IGC fails linking if the value and shift size are of different type +fn insert_shift_hack( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + offset_var: spirv::Word, + size_of: usize, +) -> Result { + let result_type = match size_of { + 2 => map.get_or_add_scalar(builder, ast::ScalarType::B16), + 8 => map.get_or_add_scalar(builder, ast::ScalarType::B64), + 4 => return Ok(offset_var), + _ => return Err(error_unreachable()), + }; + Ok(builder.u_convert(result_type.0, None, offset_var)?) 
+} + +fn emit_cvt( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + opencl: spirv::Word, + dets: &ast::CvtDetails, + arg: &ast::CvtArgs, +) -> Result<(), TranslateError> { + match dets.mode { + ptx_parser::CvtMode::SignExtend => { + let cv = ImplicitConversion { + src: arg.src, + dst: arg.dst, + from_type: dets.from.into(), + from_space: ast::StateSpace::Reg, + to_type: dets.to.into(), + to_space: ast::StateSpace::Reg, + kind: ConversionKind::SignExtend, + }; + emit_implicit_conversion(builder, map, &cv)?; + } + ptx_parser::CvtMode::ZeroExtend + | ptx_parser::CvtMode::Truncate + | ptx_parser::CvtMode::Bitcast => { + let cv = ImplicitConversion { + src: arg.src, + dst: arg.dst, + from_type: dets.from.into(), + from_space: ast::StateSpace::Reg, + to_type: dets.to.into(), + to_space: ast::StateSpace::Reg, + kind: ConversionKind::Default, + }; + emit_implicit_conversion(builder, map, &cv)?; + } + ptx_parser::CvtMode::SaturateUnsignedToSigned => { + let result_type = map.get_or_add(builder, SpirvType::from(dets.to)); + builder.sat_convert_u_to_s(result_type.0, Some(arg.dst.0), arg.src.0)?; + } + ptx_parser::CvtMode::SaturateSignedToUnsigned => { + let result_type = map.get_or_add(builder, SpirvType::from(dets.to)); + builder.sat_convert_s_to_u(result_type.0, Some(arg.dst.0), arg.src.0)?; + } + ptx_parser::CvtMode::FPExtend { flush_to_zero } => { + let result_type = map.get_or_add(builder, SpirvType::from(dets.to)); + builder.f_convert(result_type.0, Some(arg.dst.0), arg.src.0)?; + } + ptx_parser::CvtMode::FPTruncate { + rounding, + flush_to_zero, + } => { + let result_type = map.get_or_add(builder, SpirvType::from(dets.to)); + builder.f_convert(result_type.0, Some(arg.dst.0), arg.src.0)?; + emit_rounding_decoration(builder, arg.dst, Some(rounding)); + } + ptx_parser::CvtMode::FPRound { + integer_rounding, + flush_to_zero, + } => { + if flush_to_zero == Some(true) { + todo!() + } + let result_type = map.get_or_add(builder, SpirvType::from(dets.to)); + match 
integer_rounding { + Some(ast::RoundingMode::NearestEven) => { + builder.ext_inst( + result_type.0, + Some(arg.dst.0), + opencl, + spirv::CLOp::rint as u32, + [dr::Operand::IdRef(arg.src.0)].iter().cloned(), + )?; + } + Some(ast::RoundingMode::Zero) => { + builder.ext_inst( + result_type.0, + Some(arg.dst.0), + opencl, + spirv::CLOp::trunc as u32, + [dr::Operand::IdRef(arg.src.0)].iter().cloned(), + )?; + } + Some(ast::RoundingMode::NegativeInf) => { + builder.ext_inst( + result_type.0, + Some(arg.dst.0), + opencl, + spirv::CLOp::floor as u32, + [dr::Operand::IdRef(arg.src.0)].iter().cloned(), + )?; + } + Some(ast::RoundingMode::PositiveInf) => { + builder.ext_inst( + result_type.0, + Some(arg.dst.0), + opencl, + spirv::CLOp::ceil as u32, + [dr::Operand::IdRef(arg.src.0)].iter().cloned(), + )?; + } + None => { + builder.copy_object(result_type.0, Some(arg.dst.0), arg.src.0)?; + } + } + } + ptx_parser::CvtMode::SignedFromFP { + rounding, + flush_to_zero, + } => { + let dest_t: ast::ScalarType = dets.to.into(); + let result_type = map.get_or_add(builder, SpirvType::from(dest_t)); + builder.convert_f_to_s(result_type.0, Some(arg.dst.0), arg.src.0)?; + emit_rounding_decoration(builder, arg.dst, Some(rounding)); + } + ptx_parser::CvtMode::UnsignedFromFP { + rounding, + flush_to_zero, + } => { + let dest_t: ast::ScalarType = dets.to.into(); + let result_type = map.get_or_add(builder, SpirvType::from(dest_t)); + builder.convert_f_to_u(result_type.0, Some(arg.dst.0), arg.src.0)?; + emit_rounding_decoration(builder, arg.dst, Some(rounding)); + } + ptx_parser::CvtMode::FPFromSigned(rounding) => { + let dest_t: ast::ScalarType = dets.to.into(); + let result_type = map.get_or_add(builder, SpirvType::from(dest_t)); + builder.convert_s_to_f(result_type.0, Some(arg.dst.0), arg.src.0)?; + emit_rounding_decoration(builder, arg.dst, Some(rounding)); + } + ptx_parser::CvtMode::FPFromUnsigned(rounding) => { + let dest_t: ast::ScalarType = dets.to.into(); + let result_type = 
map.get_or_add(builder, SpirvType::from(dest_t)); + builder.convert_u_to_f(result_type.0, Some(arg.dst.0), arg.src.0)?; + emit_rounding_decoration(builder, arg.dst, Some(rounding)); + } + } + Ok(()) +} + +fn emit_mad_uint( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + opencl: spirv::Word, + type_: ast::ScalarType, + control: ast::MulIntControl, + arg: &ast::MadArgs, +) -> Result<(), dr::Error> { + let inst_type = map + .get_or_add(builder, SpirvType::from(ast::ScalarType::from(type_))) + .0; + match control { + ast::MulIntControl::Low => { + let mul_result = builder.i_mul(inst_type, None, arg.src1.0, arg.src2.0)?; + builder.i_add(inst_type, Some(arg.dst.0), arg.src3.0, mul_result)?; + } + ast::MulIntControl::High => { + builder.ext_inst( + inst_type, + Some(arg.dst.0), + opencl, + spirv::CLOp::u_mad_hi as spirv::Word, + [ + dr::Operand::IdRef(arg.src1.0), + dr::Operand::IdRef(arg.src2.0), + dr::Operand::IdRef(arg.src3.0), + ] + .iter() + .cloned(), + )?; + } + ast::MulIntControl::Wide => todo!(), + }; + Ok(()) +} + +fn emit_mad_sint( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + opencl: spirv::Word, + type_: ast::ScalarType, + control: ast::MulIntControl, + arg: &ast::MadArgs, +) -> Result<(), dr::Error> { + let inst_type = map.get_or_add(builder, SpirvType::from(type_)).0; + match control { + ast::MulIntControl::Low => { + let mul_result = builder.i_mul(inst_type, None, arg.src1.0, arg.src2.0)?; + builder.i_add(inst_type, Some(arg.dst.0), arg.src3.0, mul_result)?; + } + ast::MulIntControl::High => { + builder.ext_inst( + inst_type, + Some(arg.dst.0), + opencl, + spirv::CLOp::s_mad_hi as spirv::Word, + [ + dr::Operand::IdRef(arg.src1.0), + dr::Operand::IdRef(arg.src2.0), + dr::Operand::IdRef(arg.src3.0), + ] + .iter() + .cloned(), + )?; + } + ast::MulIntControl::Wide => todo!(), + }; + Ok(()) +} + +fn emit_mad_float( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + opencl: spirv::Word, + desc: &ast::ArithFloat, + arg: &ast::MadArgs, +) -> 
Result<(), dr::Error> { + let inst_type = map + .get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.type_))) + .0; + builder.ext_inst( + inst_type, + Some(arg.dst.0), + opencl, + spirv::CLOp::mad as spirv::Word, + [ + dr::Operand::IdRef(arg.src1.0), + dr::Operand::IdRef(arg.src2.0), + dr::Operand::IdRef(arg.src3.0), + ] + .iter() + .cloned(), + )?; + Ok(()) +} + +fn emit_fma_float( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + opencl: spirv::Word, + desc: &ast::ArithFloat, + arg: &ast::FmaArgs, +) -> Result<(), dr::Error> { + let inst_type = map + .get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.type_))) + .0; + builder.ext_inst( + inst_type, + Some(arg.dst.0), + opencl, + spirv::CLOp::fma as spirv::Word, + [ + dr::Operand::IdRef(arg.src1.0), + dr::Operand::IdRef(arg.src2.0), + dr::Operand::IdRef(arg.src3.0), + ] + .iter() + .cloned(), + )?; + Ok(()) +} + +fn emit_sub_int( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + typ: ast::ScalarType, + saturate: bool, + arg: &ast::SubArgs, +) -> Result<(), dr::Error> { + if saturate { + todo!() + } + let inst_type = map + .get_or_add(builder, SpirvType::from(ast::ScalarType::from(typ))) + .0; + builder.i_sub(inst_type, Some(arg.dst.0), arg.src1.0, arg.src2.0)?; + Ok(()) +} + +fn emit_sub_float( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + desc: &ast::ArithFloat, + arg: &ast::SubArgs, +) -> Result<(), dr::Error> { + let inst_type = map + .get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.type_))) + .0; + builder.f_sub(inst_type, Some(arg.dst.0), arg.src1.0, arg.src2.0)?; + emit_rounding_decoration(builder, arg.dst, desc.rounding); + Ok(()) +} + +fn emit_min( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + opencl: spirv::Word, + desc: &ast::MinMaxDetails, + arg: &ast::MinArgs, +) -> Result<(), dr::Error> { + let cl_op = match desc { + ast::MinMaxDetails::Signed(_) => spirv::CLOp::s_min, + ast::MinMaxDetails::Unsigned(_) => spirv::CLOp::u_min, + 
ast::MinMaxDetails::Float(_) => spirv::CLOp::fmin, + }; + let inst_type = map.get_or_add(builder, SpirvType::from(desc.type_())); + builder.ext_inst( + inst_type.0, + Some(arg.dst.0), + opencl, + cl_op as spirv::Word, + [ + dr::Operand::IdRef(arg.src1.0), + dr::Operand::IdRef(arg.src2.0), + ] + .iter() + .cloned(), + )?; + Ok(()) +} + +fn emit_max( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + opencl: spirv::Word, + desc: &ast::MinMaxDetails, + arg: &ast::MaxArgs, +) -> Result<(), dr::Error> { + let cl_op = match desc { + ast::MinMaxDetails::Signed(_) => spirv::CLOp::s_max, + ast::MinMaxDetails::Unsigned(_) => spirv::CLOp::u_max, + ast::MinMaxDetails::Float(_) => spirv::CLOp::fmax, + }; + let inst_type = map.get_or_add(builder, SpirvType::from(desc.type_())); + builder.ext_inst( + inst_type.0, + Some(arg.dst.0), + opencl, + cl_op as spirv::Word, + [ + dr::Operand::IdRef(arg.src1.0), + dr::Operand::IdRef(arg.src2.0), + ] + .iter() + .cloned(), + )?; + Ok(()) +} + +fn emit_rcp( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + opencl: spirv::Word, + desc: &ast::RcpData, + arg: &ast::RcpArgs, +) -> Result<(), TranslateError> { + let is_f64 = desc.type_ == ast::ScalarType::F64; + let (instr_type, constant) = if is_f64 { + (ast::ScalarType::F64, vec_repr(1.0f64)) + } else { + (ast::ScalarType::F32, vec_repr(1.0f32)) + }; + let result_type = map.get_or_add_scalar(builder, instr_type); + let rounding = match desc.kind { + ptx_parser::RcpKind::Approx => { + builder.ext_inst( + result_type.0, + Some(arg.dst.0), + opencl, + spirv::CLOp::native_recip as u32, + [dr::Operand::IdRef(arg.src.0)].iter().cloned(), + )?; + return Ok(()); + } + ptx_parser::RcpKind::Compliant(rounding) => rounding, + }; + let one = map.get_or_add_constant(builder, &ast::Type::Scalar(instr_type), &constant)?; + builder.f_div(result_type.0, Some(arg.dst.0), one.0, arg.src.0)?; + emit_rounding_decoration(builder, arg.dst, Some(rounding)); + builder.decorate( + arg.dst.0, + 
spirv::Decoration::FPFastMathMode, + [dr::Operand::FPFastMathMode( + spirv::FPFastMathMode::ALLOW_RECIP, + )] + .iter() + .cloned(), + ); + Ok(()) +} + +fn emit_atom( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + details: &ast::AtomDetails, + arg: &ast::AtomArgs, +) -> Result<(), TranslateError> { + let spirv_op = match details.op { + ptx_parser::AtomicOp::And => dr::Builder::atomic_and, + ptx_parser::AtomicOp::Or => dr::Builder::atomic_or, + ptx_parser::AtomicOp::Xor => dr::Builder::atomic_xor, + ptx_parser::AtomicOp::Exchange => dr::Builder::atomic_exchange, + ptx_parser::AtomicOp::Add => dr::Builder::atomic_i_add, + ptx_parser::AtomicOp::IncrementWrap | ptx_parser::AtomicOp::DecrementWrap => { + return Err(error_unreachable()) + } + ptx_parser::AtomicOp::SignedMin => dr::Builder::atomic_s_min, + ptx_parser::AtomicOp::UnsignedMin => dr::Builder::atomic_u_min, + ptx_parser::AtomicOp::SignedMax => dr::Builder::atomic_s_max, + ptx_parser::AtomicOp::UnsignedMax => dr::Builder::atomic_u_max, + ptx_parser::AtomicOp::FloatAdd => dr::Builder::atomic_f_add_ext, + ptx_parser::AtomicOp::FloatMin => todo!(), + ptx_parser::AtomicOp::FloatMax => todo!(), + }; + let result_type = map.get_or_add(builder, SpirvType::new(details.type_.clone())); + let memory_const = map.get_or_add_constant( + builder, + &ast::Type::Scalar(ast::ScalarType::U32), + &vec_repr(scope_to_spirv(details.scope) as u32), + )?; + let semantics_const = map.get_or_add_constant( + builder, + &ast::Type::Scalar(ast::ScalarType::U32), + &vec_repr(semantics_to_spirv(details.semantics).bits()), + )?; + spirv_op( + builder, + result_type.0, + Some(arg.dst.0), + arg.src1.0, + memory_const.0, + semantics_const.0, + arg.src2.0, + )?; + Ok(()) +} + +fn scope_to_spirv(this: ast::MemScope) -> spirv::Scope { + match this { + ast::MemScope::Cta => spirv::Scope::Workgroup, + ast::MemScope::Gpu => spirv::Scope::Device, + ast::MemScope::Sys => spirv::Scope::CrossDevice, + ptx_parser::MemScope::Cluster => todo!(), + } 
+} + +fn semantics_to_spirv(this: ast::AtomSemantics) -> spirv::MemorySemantics { + match this { + ast::AtomSemantics::Relaxed => spirv::MemorySemantics::RELAXED, + ast::AtomSemantics::Acquire => spirv::MemorySemantics::ACQUIRE, + ast::AtomSemantics::Release => spirv::MemorySemantics::RELEASE, + ast::AtomSemantics::AcqRel => spirv::MemorySemantics::ACQUIRE_RELEASE, + } +} + +fn emit_float_div_decoration(builder: &mut dr::Builder, dst: SpirvWord, kind: ast::DivFloatKind) { + match kind { + ast::DivFloatKind::Approx => { + builder.decorate( + dst.0, + spirv::Decoration::FPFastMathMode, + [dr::Operand::FPFastMathMode( + spirv::FPFastMathMode::ALLOW_RECIP, + )] + .iter() + .cloned(), + ); + } + ast::DivFloatKind::Rounding(rnd) => { + emit_rounding_decoration(builder, dst, Some(rnd)); + } + ast::DivFloatKind::ApproxFull => {} + } +} + +fn emit_sqrt( + builder: &mut dr::Builder, + map: &mut TypeWordMap, + opencl: spirv::Word, + details: &ast::RcpData, + a: &ast::SqrtArgs, +) -> Result<(), TranslateError> { + let result_type = map.get_or_add_scalar(builder, details.type_.into()); + let (ocl_op, rounding) = match details.kind { + ast::RcpKind::Approx => (spirv::CLOp::sqrt, None), + ast::RcpKind::Compliant(rnd) => (spirv::CLOp::sqrt, Some(rnd)), + }; + builder.ext_inst( + result_type.0, + Some(a.dst.0), + opencl, + ocl_op as spirv::Word, + [dr::Operand::IdRef(a.src.0)].iter().cloned(), + )?; + emit_rounding_decoration(builder, a.dst, rounding); + Ok(()) +} + +// TODO: check what kind of assembly do we emit +fn emit_logical_xor_spirv( + builder: &mut dr::Builder, + result_type: spirv::Word, + result_id: Option, + op1: spirv::Word, + op2: spirv::Word, +) -> Result { + let temp_or = builder.logical_or(result_type, None, op1, op2)?; + let temp_and = builder.logical_and(result_type, None, op1, op2)?; + let temp_neg = logical_not(builder, result_type, None, temp_and)?; + builder.logical_and(result_type, result_id, temp_or, temp_neg) +} + +fn emit_load_var( + builder: &mut 
dr::Builder, + map: &mut TypeWordMap, + details: &LoadVarDetails, +) -> Result<(), TranslateError> { + let result_type = map.get_or_add(builder, SpirvType::new(details.typ.clone())); + match details.member_index { + Some((index, Some(width))) => { + let vector_type = match details.typ { + ast::Type::Scalar(scalar_t) => ast::Type::Vector(width, scalar_t), + _ => return Err(error_mismatched_type()), + }; + let vector_type_spirv = map.get_or_add(builder, SpirvType::new(vector_type)); + let vector_temp = builder.load( + vector_type_spirv.0, + None, + details.arg.src.0, + None, + iter::empty(), + )?; + builder.composite_extract( + result_type.0, + Some(details.arg.dst.0), + vector_temp, + [index as u32].iter().copied(), + )?; + } + Some((index, None)) => { + let result_ptr_type = map.get_or_add( + builder, + SpirvType::pointer_to(details.typ.clone(), spirv::StorageClass::Function), + ); + let index_spirv = map.get_or_add_constant( + builder, + &ast::Type::Scalar(ast::ScalarType::U32), + &vec_repr(index as u32), + )?; + let src = builder.in_bounds_access_chain( + result_ptr_type.0, + None, + details.arg.src.0, + [index_spirv.0].iter().copied(), + )?; + builder.load( + result_type.0, + Some(details.arg.dst.0), + src, + None, + iter::empty(), + )?; + } + None => { + builder.load( + result_type.0, + Some(details.arg.dst.0), + details.arg.src.0, + None, + iter::empty(), + )?; + } + }; + Ok(()) +} + +fn to_parts(this: &ast::Type) -> TypeParts { + match this { + ast::Type::Scalar(scalar) => TypeParts { + kind: TypeKind::Scalar, + state_space: ast::StateSpace::Reg, + scalar_kind: scalar.kind(), + width: scalar.size_of(), + components: Vec::new(), + }, + ast::Type::Vector(components, scalar) => TypeParts { + kind: TypeKind::Vector, + state_space: ast::StateSpace::Reg, + scalar_kind: scalar.kind(), + width: scalar.size_of(), + components: vec![*components as u32], + }, + ast::Type::Array(_, scalar, components) => TypeParts { + kind: TypeKind::Array, + state_space: 
ast::StateSpace::Reg, + scalar_kind: scalar.kind(), + width: scalar.size_of(), + components: components.clone(), + }, + ast::Type::Pointer(scalar, space) => TypeParts { + kind: TypeKind::Pointer, + state_space: *space, + scalar_kind: scalar.kind(), + width: scalar.size_of(), + components: Vec::new(), + }, + } +} + +fn type_from_parts(t: TypeParts) -> ast::Type { + match t.kind { + TypeKind::Scalar => ast::Type::Scalar(scalar_from_parts(t.width, t.scalar_kind)), + TypeKind::Vector => ast::Type::Vector( + t.components[0] as u8, + scalar_from_parts(t.width, t.scalar_kind), + ), + TypeKind::Array => ast::Type::Array( + None, + scalar_from_parts(t.width, t.scalar_kind), + t.components, + ), + TypeKind::Pointer => { + ast::Type::Pointer(scalar_from_parts(t.width, t.scalar_kind), t.state_space) + } + } +} + +#[derive(Eq, PartialEq, Clone)] +struct TypeParts { + kind: TypeKind, + scalar_kind: ast::ScalarKind, + width: u8, + state_space: ast::StateSpace, + components: Vec, +} + +#[derive(Eq, PartialEq, Copy, Clone)] +enum TypeKind { + Scalar, + Vector, + Array, + Pointer, +} diff --git a/ptx/src/pass/expand_arguments.rs b/ptx/src/pass/expand_arguments.rs new file mode 100644 index 00000000..d0c7c981 --- /dev/null +++ b/ptx/src/pass/expand_arguments.rs @@ -0,0 +1,181 @@ +use super::*; +use ptx_parser as ast; + +pub(super) fn run<'a, 'b>( + func: Vec, + id_def: &'b mut MutableNumericIdResolver<'a>, +) -> Result, TranslateError> { + let mut result = Vec::with_capacity(func.len()); + for s in func { + match s { + Statement::Label(id) => result.push(Statement::Label(id)), + Statement::Conditional(bra) => result.push(Statement::Conditional(bra)), + Statement::LoadVar(details) => result.push(Statement::LoadVar(details)), + Statement::StoreVar(details) => result.push(Statement::StoreVar(details)), + Statement::RetValue(d, id) => result.push(Statement::RetValue(d, id)), + Statement::Conversion(conv) => result.push(Statement::Conversion(conv)), + Statement::Constant(c) => 
result.push(Statement::Constant(c)), + Statement::FunctionPointer(d) => result.push(Statement::FunctionPointer(d)), + s => { + let (new_statement, post_stmts) = { + let mut visitor = FlattenArguments::new(&mut result, id_def); + (s.visit_map(&mut visitor)?, visitor.post_stmts) + }; + result.push(new_statement); + result.extend(post_stmts); + } + } + } + Ok(result) +} + +struct FlattenArguments<'a, 'b> { + func: &'b mut Vec, + id_def: &'b mut MutableNumericIdResolver<'a>, + post_stmts: Vec, +} + +impl<'a, 'b> FlattenArguments<'a, 'b> { + fn new( + func: &'b mut Vec, + id_def: &'b mut MutableNumericIdResolver<'a>, + ) -> Self { + FlattenArguments { + func, + id_def, + post_stmts: Vec::new(), + } + } + + fn reg(&mut self, name: SpirvWord) -> Result { + Ok(name) + } + + fn reg_offset( + &mut self, + reg: SpirvWord, + offset: i32, + type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>, + _is_dst: bool, + ) -> Result { + let (type_, state_space) = if let Some((type_, state_space)) = type_space { + (type_, state_space) + } else { + return Err(TranslateError::UntypedSymbol); + }; + if state_space == ast::StateSpace::Reg || state_space == ast::StateSpace::Sreg { + let (reg_type, reg_space) = self.id_def.get_typed(reg)?; + if !space_is_compatible(reg_space, ast::StateSpace::Reg) { + return Err(error_mismatched_type()); + } + let reg_scalar_type = match reg_type { + ast::Type::Scalar(underlying_type) => underlying_type, + _ => return Err(error_mismatched_type()), + }; + let id_constant_stmt = self + .id_def + .register_intermediate(reg_type.clone(), ast::StateSpace::Reg); + self.func.push(Statement::Constant(ConstantDefinition { + dst: id_constant_stmt, + typ: reg_scalar_type, + value: ast::ImmediateValue::S64(offset as i64), + })); + let arith_details = match reg_scalar_type.kind() { + ast::ScalarKind::Signed => ast::ArithDetails::Integer(ast::ArithInteger { + type_: reg_scalar_type, + saturate: false, + }), + ast::ScalarKind::Unsigned | ast::ScalarKind::Bit => { 
+ ast::ArithDetails::Integer(ast::ArithInteger { + type_: reg_scalar_type, + saturate: false, + }) + } + _ => return Err(error_unreachable()), + }; + let id_add_result = self.id_def.register_intermediate(reg_type, state_space); + self.func + .push(Statement::Instruction(ast::Instruction::Add { + data: arith_details, + arguments: ast::AddArgs { + dst: id_add_result, + src1: reg, + src2: id_constant_stmt, + }, + })); + Ok(id_add_result) + } else { + let id_constant_stmt = self.id_def.register_intermediate( + ast::Type::Scalar(ast::ScalarType::S64), + ast::StateSpace::Reg, + ); + self.func.push(Statement::Constant(ConstantDefinition { + dst: id_constant_stmt, + typ: ast::ScalarType::S64, + value: ast::ImmediateValue::S64(offset as i64), + })); + let dst = self + .id_def + .register_intermediate(type_.clone(), state_space); + self.func.push(Statement::PtrAccess(PtrAccess { + underlying_type: type_.clone(), + state_space: state_space, + dst, + ptr_src: reg, + offset_src: id_constant_stmt, + })); + Ok(dst) + } + } + + fn immediate( + &mut self, + value: ast::ImmediateValue, + type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>, + ) -> Result { + let (scalar_t, state_space) = + if let Some((ast::Type::Scalar(scalar), state_space)) = type_space { + (*scalar, state_space) + } else { + return Err(TranslateError::UntypedSymbol); + }; + let id = self + .id_def + .register_intermediate(ast::Type::Scalar(scalar_t), state_space); + self.func.push(Statement::Constant(ConstantDefinition { + dst: id, + typ: scalar_t, + value, + })); + Ok(id) + } +} + +impl<'a, 'b> ast::VisitorMap for FlattenArguments<'a, 'b> { + fn visit( + &mut self, + args: TypedOperand, + type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>, + is_dst: bool, + _relaxed_type_check: bool, + ) -> Result { + match args { + TypedOperand::Reg(r) => self.reg(r), + TypedOperand::Imm(x) => self.immediate(x, type_space), + TypedOperand::RegOffset(reg, offset) => { + self.reg_offset(reg, offset, 
type_space, is_dst) + } + TypedOperand::VecMember(..) => Err(error_unreachable()), + } + } + + fn visit_ident( + &mut self, + name: ::Ident, + _type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>, + _is_dst: bool, + _relaxed_type_check: bool, + ) -> Result<::Ident, TranslateError> { + self.reg(name) + } +} diff --git a/ptx/src/pass/extract_globals.rs b/ptx/src/pass/extract_globals.rs new file mode 100644 index 00000000..680a5eee --- /dev/null +++ b/ptx/src/pass/extract_globals.rs @@ -0,0 +1,282 @@ +use super::*; + +pub(super) fn run<'input, 'b>( + sorted_statements: Vec, + ptx_impl_imports: &mut HashMap, + id_def: &mut NumericIdResolver, +) -> Result<(Vec, Vec>), TranslateError> { + let mut local = Vec::with_capacity(sorted_statements.len()); + let mut global = Vec::new(); + for statement in sorted_statements { + match statement { + Statement::Variable( + var @ ast::Variable { + state_space: ast::StateSpace::Shared, + .. + }, + ) + | Statement::Variable( + var @ ast::Variable { + state_space: ast::StateSpace::Global, + .. 
+ }, + ) => global.push(var), + Statement::Instruction(ast::Instruction::Bfe { data, arguments }) => { + let fn_name = [ZLUDA_PTX_PREFIX, "bfe_", scalar_to_ptx_name(data)].concat(); + local.push(instruction_to_fn_call( + id_def, + ptx_impl_imports, + ast::Instruction::Bfe { data, arguments }, + fn_name, + )?); + } + Statement::Instruction(ast::Instruction::Bfi { data, arguments }) => { + let fn_name = [ZLUDA_PTX_PREFIX, "bfi_", scalar_to_ptx_name(data)].concat(); + local.push(instruction_to_fn_call( + id_def, + ptx_impl_imports, + ast::Instruction::Bfi { data, arguments }, + fn_name, + )?); + } + Statement::Instruction(ast::Instruction::Brev { data, arguments }) => { + let fn_name: String = + [ZLUDA_PTX_PREFIX, "brev_", scalar_to_ptx_name(data)].concat(); + local.push(instruction_to_fn_call( + id_def, + ptx_impl_imports, + ast::Instruction::Brev { data, arguments }, + fn_name, + )?); + } + Statement::Instruction(ast::Instruction::Activemask { arguments }) => { + let fn_name = [ZLUDA_PTX_PREFIX, "activemask"].concat(); + local.push(instruction_to_fn_call( + id_def, + ptx_impl_imports, + ast::Instruction::Activemask { arguments }, + fn_name, + )?); + } + Statement::Instruction(ast::Instruction::Atom { + data: + data @ ast::AtomDetails { + op: ast::AtomicOp::IncrementWrap, + semantics, + scope, + space, + .. + }, + arguments, + }) => { + let fn_name = [ + ZLUDA_PTX_PREFIX, + "atom_", + semantics_to_ptx_name(semantics), + "_", + scope_to_ptx_name(scope), + "_", + space_to_ptx_name(space), + "_inc", + ] + .concat(); + local.push(instruction_to_fn_call( + id_def, + ptx_impl_imports, + ast::Instruction::Atom { data, arguments }, + fn_name, + )?); + } + Statement::Instruction(ast::Instruction::Atom { + data: + data @ ast::AtomDetails { + op: ast::AtomicOp::DecrementWrap, + semantics, + scope, + space, + .. 
+ }, + arguments, + }) => { + let fn_name = [ + ZLUDA_PTX_PREFIX, + "atom_", + semantics_to_ptx_name(semantics), + "_", + scope_to_ptx_name(scope), + "_", + space_to_ptx_name(space), + "_dec", + ] + .concat(); + local.push(instruction_to_fn_call( + id_def, + ptx_impl_imports, + ast::Instruction::Atom { data, arguments }, + fn_name, + )?); + } + Statement::Instruction(ast::Instruction::Atom { + data: + data @ ast::AtomDetails { + op: ast::AtomicOp::FloatAdd, + semantics, + scope, + space, + .. + }, + arguments, + }) => { + let scalar_type = match data.type_ { + ptx_parser::Type::Scalar(scalar) => scalar, + _ => return Err(error_unreachable()), + }; + let fn_name = [ + ZLUDA_PTX_PREFIX, + "atom_", + semantics_to_ptx_name(semantics), + "_", + scope_to_ptx_name(scope), + "_", + space_to_ptx_name(space), + "_add_", + scalar_to_ptx_name(scalar_type), + ] + .concat(); + local.push(instruction_to_fn_call( + id_def, + ptx_impl_imports, + ast::Instruction::Atom { data, arguments }, + fn_name, + )?); + } + s => local.push(s), + } + } + Ok((local, global)) +} + +fn instruction_to_fn_call( + id_defs: &mut NumericIdResolver, + ptx_impl_imports: &mut HashMap, + inst: ast::Instruction, + fn_name: String, +) -> Result { + let mut arguments = Vec::new(); + ast::visit_map(inst, &mut |operand, + type_space: Option<( + &ast::Type, + ast::StateSpace, + )>, + is_dst, + _| { + let (typ, space) = match type_space { + Some((typ, space)) => (typ.clone(), space), + None => return Err(error_unreachable()), + }; + arguments.push((operand, is_dst, typ, space)); + Ok(SpirvWord(0)) + })?; + let return_arguments_count = arguments + .iter() + .position(|(desc, is_dst, _, _)| !is_dst) + .unwrap_or(arguments.len()); + let (return_arguments, input_arguments) = arguments.split_at(return_arguments_count); + let fn_id = register_external_fn_call( + id_defs, + ptx_impl_imports, + fn_name, + return_arguments + .iter() + .map(|(_, _, typ, state)| (typ, *state)), + input_arguments + .iter() + .map(|(_, _, 
typ, state)| (typ, *state)), + )?; + Ok(Statement::Instruction(ast::Instruction::Call { + data: ast::CallDetails { + uniform: false, + return_arguments: return_arguments + .iter() + .map(|(_, _, typ, state)| (typ.clone(), *state)) + .collect::>(), + input_arguments: input_arguments + .iter() + .map(|(_, _, typ, state)| (typ.clone(), *state)) + .collect::>(), + }, + arguments: ast::CallArgs { + return_arguments: return_arguments + .iter() + .map(|(name, _, _, _)| *name) + .collect::>(), + func: fn_id, + input_arguments: input_arguments + .iter() + .map(|(name, _, _, _)| *name) + .collect::>(), + }, + })) +} + +fn scalar_to_ptx_name(this: ast::ScalarType) -> &'static str { + match this { + ast::ScalarType::B8 => "b8", + ast::ScalarType::B16 => "b16", + ast::ScalarType::B32 => "b32", + ast::ScalarType::B64 => "b64", + ast::ScalarType::B128 => "b128", + ast::ScalarType::U8 => "u8", + ast::ScalarType::U16 => "u16", + ast::ScalarType::U16x2 => "u16x2", + ast::ScalarType::U32 => "u32", + ast::ScalarType::U64 => "u64", + ast::ScalarType::S8 => "s8", + ast::ScalarType::S16 => "s16", + ast::ScalarType::S16x2 => "s16x2", + ast::ScalarType::S32 => "s32", + ast::ScalarType::S64 => "s64", + ast::ScalarType::F16 => "f16", + ast::ScalarType::F16x2 => "f16x2", + ast::ScalarType::F32 => "f32", + ast::ScalarType::F64 => "f64", + ast::ScalarType::BF16 => "bf16", + ast::ScalarType::BF16x2 => "bf16x2", + ast::ScalarType::Pred => "pred", + } +} + +fn semantics_to_ptx_name(this: ast::AtomSemantics) -> &'static str { + match this { + ast::AtomSemantics::Relaxed => "relaxed", + ast::AtomSemantics::Acquire => "acquire", + ast::AtomSemantics::Release => "release", + ast::AtomSemantics::AcqRel => "acq_rel", + } +} + +fn scope_to_ptx_name(this: ast::MemScope) -> &'static str { + match this { + ast::MemScope::Cta => "cta", + ast::MemScope::Gpu => "gpu", + ast::MemScope::Sys => "sys", + ast::MemScope::Cluster => "cluster", + } +} + +fn space_to_ptx_name(this: ast::StateSpace) -> &'static str { + 
match this { + ast::StateSpace::Generic => "generic", + ast::StateSpace::Global => "global", + ast::StateSpace::Shared => "shared", + ast::StateSpace::Reg => "reg", + ast::StateSpace::Const => "const", + ast::StateSpace::Local => "local", + ast::StateSpace::Param => "param", + ast::StateSpace::Sreg => "sreg", + ast::StateSpace::SharedCluster => "shared_cluster", + ast::StateSpace::ParamEntry => "param_entry", + ast::StateSpace::SharedCta => "shared_cta", + ast::StateSpace::ParamFunc => "param_func", + } +} diff --git a/ptx/src/pass/fix_special_registers.rs b/ptx/src/pass/fix_special_registers.rs new file mode 100644 index 00000000..c0290167 --- /dev/null +++ b/ptx/src/pass/fix_special_registers.rs @@ -0,0 +1,130 @@ +use super::*; +use std::collections::HashMap; + +pub(super) fn run<'a, 'b, 'input>( + ptx_impl_imports: &'a mut HashMap>, + typed_statements: Vec, + numeric_id_defs: &'a mut NumericIdResolver<'b>, +) -> Result, TranslateError> { + let result = Vec::with_capacity(typed_statements.len()); + let mut sreg_sresolver = SpecialRegisterResolver { + ptx_impl_imports, + numeric_id_defs, + result, + }; + for statement in typed_statements { + let statement = statement.visit_map(&mut sreg_sresolver)?; + sreg_sresolver.result.push(statement); + } + Ok(sreg_sresolver.result) +} + +struct SpecialRegisterResolver<'a, 'b, 'input> { + ptx_impl_imports: &'a mut HashMap>, + numeric_id_defs: &'a mut NumericIdResolver<'b>, + result: Vec, +} + +impl<'a, 'b, 'input> ast::VisitorMap + for SpecialRegisterResolver<'a, 'b, 'input> +{ + fn visit( + &mut self, + operand: TypedOperand, + _type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>, + is_dst: bool, + _relaxed_type_check: bool, + ) -> Result { + operand.map(|name, vector_index| self.replace_sreg(name, is_dst, vector_index)) + } + + fn visit_ident( + &mut self, + args: SpirvWord, + _type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>, + is_dst: bool, + _relaxed_type_check: bool, + ) -> Result { + 
self.replace_sreg(args, is_dst, None) + } +} + +impl<'a, 'b, 'input> SpecialRegisterResolver<'a, 'b, 'input> { + fn replace_sreg( + &mut self, + name: SpirvWord, + is_dst: bool, + vector_index: Option, + ) -> Result { + if let Some(sreg) = self.numeric_id_defs.special_registers.get(name) { + if is_dst { + return Err(error_mismatched_type()); + } + let input_arguments = match (vector_index, sreg.get_function_input_type()) { + (Some(idx), Some(inp_type)) => { + if inp_type != ast::ScalarType::U8 { + return Err(TranslateError::Unreachable); + } + let constant = self.numeric_id_defs.register_intermediate(Some(( + ast::Type::Scalar(inp_type), + ast::StateSpace::Reg, + ))); + self.result.push(Statement::Constant(ConstantDefinition { + dst: constant, + typ: inp_type, + value: ast::ImmediateValue::U64(idx as u64), + })); + vec![( + TypedOperand::Reg(constant), + ast::Type::Scalar(inp_type), + ast::StateSpace::Reg, + )] + } + (None, None) => Vec::new(), + _ => return Err(error_mismatched_type()), + }; + let ocl_fn_name = [ZLUDA_PTX_PREFIX, sreg.get_unprefixed_function_name()].concat(); + let return_type = sreg.get_function_return_type(); + let fn_result = self.numeric_id_defs.register_intermediate(Some(( + ast::Type::Scalar(return_type), + ast::StateSpace::Reg, + ))); + let return_arguments = vec![( + fn_result, + ast::Type::Scalar(return_type), + ast::StateSpace::Reg, + )]; + let fn_call = register_external_fn_call( + self.numeric_id_defs, + self.ptx_impl_imports, + ocl_fn_name.to_string(), + return_arguments.iter().map(|(_, typ, space)| (typ, *space)), + input_arguments.iter().map(|(_, typ, space)| (typ, *space)), + )?; + let data = ast::CallDetails { + uniform: false, + return_arguments: return_arguments + .iter() + .map(|(_, typ, space)| (typ.clone(), *space)) + .collect(), + input_arguments: input_arguments + .iter() + .map(|(_, typ, space)| (typ.clone(), *space)) + .collect(), + }; + let arguments = ast::CallArgs { + return_arguments: 
return_arguments.iter().map(|(name, _, _)| *name).collect(), + func: fn_call, + input_arguments: input_arguments.iter().map(|(name, _, _)| *name).collect(), + }; + self.result + .push(Statement::Instruction(ast::Instruction::Call { + data, + arguments, + })); + Ok(fn_result) + } else { + Ok(name) + } + } +} diff --git a/ptx/src/pass/insert_implicit_conversions.rs b/ptx/src/pass/insert_implicit_conversions.rs new file mode 100644 index 00000000..25e80f05 --- /dev/null +++ b/ptx/src/pass/insert_implicit_conversions.rs @@ -0,0 +1,432 @@ +use std::mem; + +use super::*; +use ptx_parser as ast; + +/* + There are several kinds of implicit conversions in PTX: + * auto-bitcast: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#type-information-for-instructions-and-operands + * special ld/st/cvt conversion rules: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size + - ld.param: not documented, but for instruction `ld.param. 
x, [y]`, + semantics are to first zext/chop/bitcast `y` as needed and then do + documented special ld/st/cvt conversion rules for destination operands + - st.param [x] y (used as function return arguments) same rule as above applies + - generic/global ld: for instruction `ld x, [y]`, y must be of type + b64/u64/s64, which is bitcast to a pointer, dereferenced and then + documented special ld/st/cvt conversion rules are applied to dst + - generic/global st: for instruction `st [x], y`, x must be of type + b64/u64/s64, which is bitcast to a pointer +*/ +pub(super) fn run( + func: Vec, + id_def: &mut MutableNumericIdResolver, +) -> Result, TranslateError> { + let mut result = Vec::with_capacity(func.len()); + for s in func.into_iter() { + match s { + Statement::Instruction(inst) => { + insert_implicit_conversions_impl( + &mut result, + id_def, + Statement::Instruction(inst), + )?; + } + Statement::PtrAccess(access) => { + insert_implicit_conversions_impl( + &mut result, + id_def, + Statement::PtrAccess(access), + )?; + } + Statement::RepackVector(repack) => { + insert_implicit_conversions_impl( + &mut result, + id_def, + Statement::RepackVector(repack), + )?; + } + s @ Statement::Conditional(_) + | s @ Statement::Conversion(_) + | s @ Statement::Label(_) + | s @ Statement::Constant(_) + | s @ Statement::Variable(_) + | s @ Statement::LoadVar(..) + | s @ Statement::StoreVar(..) + | s @ Statement::RetValue(..) + | s @ Statement::FunctionPointer(..) 
=> result.push(s), + } + } + Ok(result) +} + +fn insert_implicit_conversions_impl( + func: &mut Vec, + id_def: &mut MutableNumericIdResolver, + stmt: ExpandedStatement, +) -> Result<(), TranslateError> { + let mut post_conv = Vec::new(); + let statement = stmt.visit_map::( + &mut |operand, + type_state: Option<(&ast::Type, ast::StateSpace)>, + is_dst, + relaxed_type_check| { + let (instr_type, instruction_space) = match type_state { + None => return Ok(operand), + Some(t) => t, + }; + let (operand_type, operand_space) = id_def.get_typed(operand)?; + let conversion_fn = if relaxed_type_check { + if is_dst { + should_convert_relaxed_dst_wrapper + } else { + should_convert_relaxed_src_wrapper + } + } else { + default_implicit_conversion + }; + match conversion_fn( + (operand_space, &operand_type), + (instruction_space, instr_type), + )? { + Some(conv_kind) => { + let conv_output = if is_dst { &mut post_conv } else { &mut *func }; + let mut from_type = instr_type.clone(); + let mut from_space = instruction_space; + let mut to_type = operand_type; + let mut to_space = operand_space; + let mut src = + id_def.register_intermediate(instr_type.clone(), instruction_space); + let mut dst = operand; + let result = Ok::<_, TranslateError>(src); + if !is_dst { + mem::swap(&mut src, &mut dst); + mem::swap(&mut from_type, &mut to_type); + mem::swap(&mut from_space, &mut to_space); + } + conv_output.push(Statement::Conversion(ImplicitConversion { + src, + dst, + from_type, + from_space, + to_type, + to_space, + kind: conv_kind, + })); + result + } + None => Ok(operand), + } + }, + )?; + func.push(statement); + func.append(&mut post_conv); + Ok(()) +} + +pub(crate) fn default_implicit_conversion( + (operand_space, operand_type): (ast::StateSpace, &ast::Type), + (instruction_space, instruction_type): (ast::StateSpace, &ast::Type), +) -> Result, TranslateError> { + if instruction_space == ast::StateSpace::Reg { + if space_is_compatible(operand_space, ast::StateSpace::Reg) { + if let 
(ast::Type::Vector(vec_len, vec_underlying_type), ast::Type::Scalar(scalar)) = + (operand_type, instruction_type) + { + if scalar.kind() == ast::ScalarKind::Bit + && scalar.size_of() == (vec_underlying_type.size_of() * vec_len) + { + return Ok(Some(ConversionKind::Default)); + } + } + } else if is_addressable(operand_space) { + return Ok(Some(ConversionKind::AddressOf)); + } + } + if !space_is_compatible(instruction_space, operand_space) { + default_implicit_conversion_space( + (operand_space, operand_type), + (instruction_space, instruction_type), + ) + } else if instruction_type != operand_type { + default_implicit_conversion_type(instruction_space, operand_type, instruction_type) + } else { + Ok(None) + } +} + +fn is_addressable(this: ast::StateSpace) -> bool { + match this { + ast::StateSpace::Const + | ast::StateSpace::Generic + | ast::StateSpace::Global + | ast::StateSpace::Local + | ast::StateSpace::Shared => true, + ast::StateSpace::Param | ast::StateSpace::Reg | ast::StateSpace::Sreg => false, + ast::StateSpace::SharedCluster + | ast::StateSpace::SharedCta + | ast::StateSpace::ParamEntry + | ast::StateSpace::ParamFunc => todo!(), + } +} + +// Space is different +fn default_implicit_conversion_space( + (operand_space, operand_type): (ast::StateSpace, &ast::Type), + (instruction_space, instruction_type): (ast::StateSpace, &ast::Type), +) -> Result, TranslateError> { + if (instruction_space == ast::StateSpace::Generic && coerces_to_generic(operand_space)) + || (operand_space == ast::StateSpace::Generic && coerces_to_generic(instruction_space)) + { + Ok(Some(ConversionKind::PtrToPtr)) + } else if space_is_compatible(operand_space, ast::StateSpace::Reg) { + match operand_type { + ast::Type::Pointer(operand_ptr_type, operand_ptr_space) + if *operand_ptr_space == instruction_space => + { + if instruction_type != &ast::Type::Scalar(*operand_ptr_type) { + Ok(Some(ConversionKind::PtrToPtr)) + } else { + Ok(None) + } + } + // TODO: 32 bit + 
ast::Type::Scalar(ast::ScalarType::B64) + | ast::Type::Scalar(ast::ScalarType::U64) + | ast::Type::Scalar(ast::ScalarType::S64) => match instruction_space { + ast::StateSpace::Global + | ast::StateSpace::Generic + | ast::StateSpace::Const + | ast::StateSpace::Local + | ast::StateSpace::Shared => Ok(Some(ConversionKind::BitToPtr)), + _ => Err(error_mismatched_type()), + }, + ast::Type::Scalar(ast::ScalarType::B32) + | ast::Type::Scalar(ast::ScalarType::U32) + | ast::Type::Scalar(ast::ScalarType::S32) => match instruction_space { + ast::StateSpace::Const | ast::StateSpace::Local | ast::StateSpace::Shared => { + Ok(Some(ConversionKind::BitToPtr)) + } + _ => Err(error_mismatched_type()), + }, + _ => Err(error_mismatched_type()), + } + } else if space_is_compatible(instruction_space, ast::StateSpace::Reg) { + match instruction_type { + ast::Type::Pointer(instruction_ptr_type, instruction_ptr_space) + if operand_space == *instruction_ptr_space => + { + if operand_type != &ast::Type::Scalar(*instruction_ptr_type) { + Ok(Some(ConversionKind::PtrToPtr)) + } else { + Ok(None) + } + } + _ => Err(error_mismatched_type()), + } + } else { + Err(error_mismatched_type()) + } +} + +// Space is same, but type is different +fn default_implicit_conversion_type( + space: ast::StateSpace, + operand_type: &ast::Type, + instruction_type: &ast::Type, +) -> Result, TranslateError> { + if space_is_compatible(space, ast::StateSpace::Reg) { + if should_bitcast(instruction_type, operand_type) { + Ok(Some(ConversionKind::Default)) + } else { + Err(TranslateError::MismatchedType) + } + } else { + Ok(Some(ConversionKind::PtrToPtr)) + } +} + +fn coerces_to_generic(this: ast::StateSpace) -> bool { + match this { + ast::StateSpace::Global + | ast::StateSpace::Const + | ast::StateSpace::Local + | ptx_parser::StateSpace::SharedCta + | ast::StateSpace::SharedCluster + | ast::StateSpace::Shared => true, + ast::StateSpace::Reg + | ast::StateSpace::Param + | ast::StateSpace::ParamEntry + | 
ast::StateSpace::ParamFunc + | ast::StateSpace::Generic + | ast::StateSpace::Sreg => false, + } +} + +fn should_bitcast(instr: &ast::Type, operand: &ast::Type) -> bool { + match (instr, operand) { + (ast::Type::Scalar(inst), ast::Type::Scalar(operand)) => { + if inst.size_of() != operand.size_of() { + return false; + } + match inst.kind() { + ast::ScalarKind::Bit => operand.kind() != ast::ScalarKind::Bit, + ast::ScalarKind::Float => operand.kind() == ast::ScalarKind::Bit, + ast::ScalarKind::Signed => { + operand.kind() == ast::ScalarKind::Bit + || operand.kind() == ast::ScalarKind::Unsigned + } + ast::ScalarKind::Unsigned => { + operand.kind() == ast::ScalarKind::Bit + || operand.kind() == ast::ScalarKind::Signed + } + ast::ScalarKind::Pred => false, + } + } + (ast::Type::Vector(_, inst), ast::Type::Vector(_, operand)) + | (ast::Type::Array(_, inst, _), ast::Type::Array(_, operand, _)) => { + should_bitcast(&ast::Type::Scalar(*inst), &ast::Type::Scalar(*operand)) + } + _ => false, + } +} + +pub(crate) fn should_convert_relaxed_dst_wrapper( + (operand_space, operand_type): (ast::StateSpace, &ast::Type), + (instruction_space, instruction_type): (ast::StateSpace, &ast::Type), +) -> Result, TranslateError> { + if !space_is_compatible(operand_space, instruction_space) { + return Err(TranslateError::MismatchedType); + } + if operand_type == instruction_type { + return Ok(None); + } + match should_convert_relaxed_dst(operand_type, instruction_type) { + conv @ Some(_) => Ok(conv), + None => Err(TranslateError::MismatchedType), + } +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-destination-operands +fn should_convert_relaxed_dst( + dst_type: &ast::Type, + instr_type: &ast::Type, +) -> Option { + if dst_type == instr_type { + return None; + } + match (dst_type, instr_type) { + (ast::Type::Scalar(dst_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() { + 
ast::ScalarKind::Bit => { + if instr_type.size_of() <= dst_type.size_of() { + Some(ConversionKind::Default) + } else { + None + } + } + ast::ScalarKind::Signed => { + if dst_type.kind() != ast::ScalarKind::Float { + if instr_type.size_of() == dst_type.size_of() { + Some(ConversionKind::Default) + } else if instr_type.size_of() < dst_type.size_of() { + Some(ConversionKind::SignExtend) + } else { + None + } + } else { + None + } + } + ast::ScalarKind::Unsigned => { + if instr_type.size_of() <= dst_type.size_of() + && dst_type.kind() != ast::ScalarKind::Float + { + Some(ConversionKind::Default) + } else { + None + } + } + ast::ScalarKind::Float => { + if instr_type.size_of() <= dst_type.size_of() + && dst_type.kind() == ast::ScalarKind::Bit + { + Some(ConversionKind::Default) + } else { + None + } + } + ast::ScalarKind::Pred => None, + }, + (ast::Type::Vector(_, dst_type), ast::Type::Vector(_, instr_type)) + | (ast::Type::Array(_, dst_type, _), ast::Type::Array(_, instr_type, _)) => { + should_convert_relaxed_dst( + &ast::Type::Scalar(*dst_type), + &ast::Type::Scalar(*instr_type), + ) + } + _ => None, + } +} + +pub(crate) fn should_convert_relaxed_src_wrapper( + (operand_space, operand_type): (ast::StateSpace, &ast::Type), + (instruction_space, instruction_type): (ast::StateSpace, &ast::Type), +) -> Result, TranslateError> { + if !space_is_compatible(operand_space, instruction_space) { + return Err(error_mismatched_type()); + } + if operand_type == instruction_type { + return Ok(None); + } + match should_convert_relaxed_src(operand_type, instruction_type) { + conv @ Some(_) => Ok(conv), + None => Err(error_mismatched_type()), + } +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-source-operands +fn should_convert_relaxed_src( + src_type: &ast::Type, + instr_type: &ast::Type, +) -> Option { + if src_type == instr_type { + return None; + } + match (src_type, instr_type) { 
+ (ast::Type::Scalar(src_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() { + ast::ScalarKind::Bit => { + if instr_type.size_of() <= src_type.size_of() { + Some(ConversionKind::Default) + } else { + None + } + } + ast::ScalarKind::Signed | ast::ScalarKind::Unsigned => { + if instr_type.size_of() <= src_type.size_of() + && src_type.kind() != ast::ScalarKind::Float + { + Some(ConversionKind::Default) + } else { + None + } + } + ast::ScalarKind::Float => { + if instr_type.size_of() <= src_type.size_of() + && src_type.kind() == ast::ScalarKind::Bit + { + Some(ConversionKind::Default) + } else { + None + } + } + ast::ScalarKind::Pred => None, + }, + (ast::Type::Vector(_, dst_type), ast::Type::Vector(_, instr_type)) + | (ast::Type::Array(_, dst_type, _), ast::Type::Array(_, instr_type, _)) => { + should_convert_relaxed_src( + &ast::Type::Scalar(*dst_type), + &ast::Type::Scalar(*instr_type), + ) + } + _ => None, + } +} diff --git a/ptx/src/pass/insert_mem_ssa_statements.rs b/ptx/src/pass/insert_mem_ssa_statements.rs new file mode 100644 index 00000000..e314b05d --- /dev/null +++ b/ptx/src/pass/insert_mem_ssa_statements.rs @@ -0,0 +1,275 @@ +use super::*; +use ptx_parser as ast; + +/* + How do we handle arguments: + - input .params in kernels + .param .b64 in_arg + get turned into this SPIR-V: + %1 = OpFunctionParameter %ulong + %2 = OpVariable %_ptr_Function_ulong Function + OpStore %2 %1 + We do this for two reasons. 
One, common treatment for argument-declared + .param variables and .param variables inside function (we assume that + at SPIR-V level every .param is a pointer in Function storage class) + - input .params in functions + .param .b64 in_arg + get turned into this SPIR-V: + %1 = OpFunctionParameter %_ptr_Function_ulong + - input .regs + .reg .b64 in_arg + get turned into the same SPIR-V as kernel .params: + %1 = OpFunctionParameter %ulong + %2 = OpVariable %_ptr_Function_ulong Function + OpStore %2 %1 + - output .regs + .reg .b64 out_arg + get just a variable declaration: + %2 = OpVariable %%_ptr_Function_ulong Function + - output .params don't exist, they have been moved to input positions + by an earlier pass + Distinguishing betweem kernel .params and function .params is not the + cleanest solution. Alternatively, we could "deparamize" all kernel .param + arguments by turning them into .reg arguments like this: + .param .b64 arg -> .reg ptr<.b64,.param> arg + This has the massive downside that this transformation would have to run + very early and would muddy up already difficult code. It's simpler to just + have an if here +*/ +pub(super) fn run<'a, 'b>( + func: Vec, + id_def: &mut NumericIdResolver, + fn_decl: &'a mut ast::MethodDeclaration<'b, SpirvWord>, +) -> Result, TranslateError> { + let mut result = Vec::with_capacity(func.len()); + for arg in fn_decl.input_arguments.iter_mut() { + insert_mem_ssa_argument( + id_def, + &mut result, + arg, + matches!(fn_decl.name, ast::MethodName::Kernel(_)), + ); + } + for arg in fn_decl.return_arguments.iter() { + insert_mem_ssa_argument_reg_return(&mut result, arg); + } + for s in func { + match s { + Statement::Instruction(inst) => match inst { + ast::Instruction::Ret { data } => { + // TODO: handle multiple output args + match &fn_decl.return_arguments[..] 
{ + [return_reg] => { + let new_id = id_def.register_intermediate(Some(( + return_reg.v_type.clone(), + ast::StateSpace::Reg, + ))); + result.push(Statement::LoadVar(LoadVarDetails { + arg: ast::LdArgs { + dst: new_id, + src: return_reg.name, + }, + typ: return_reg.v_type.clone(), + member_index: None, + })); + result.push(Statement::RetValue(data, new_id)); + } + [] => result.push(Statement::Instruction(ast::Instruction::Ret { data })), + _ => unimplemented!(), + } + } + inst => insert_mem_ssa_statement_default( + id_def, + &mut result, + Statement::Instruction(inst), + )?, + }, + Statement::Conditional(bra) => { + insert_mem_ssa_statement_default(id_def, &mut result, Statement::Conditional(bra))? + } + Statement::Conversion(conv) => { + insert_mem_ssa_statement_default(id_def, &mut result, Statement::Conversion(conv))? + } + Statement::PtrAccess(ptr_access) => insert_mem_ssa_statement_default( + id_def, + &mut result, + Statement::PtrAccess(ptr_access), + )?, + Statement::RepackVector(repack) => insert_mem_ssa_statement_default( + id_def, + &mut result, + Statement::RepackVector(repack), + )?, + Statement::FunctionPointer(func_ptr) => insert_mem_ssa_statement_default( + id_def, + &mut result, + Statement::FunctionPointer(func_ptr), + )?, + s @ Statement::Variable(_) | s @ Statement::Label(_) | s @ Statement::Constant(..) 
=> { + result.push(s) + } + _ => return Err(error_unreachable()), + } + } + Ok(result) +} + +fn insert_mem_ssa_argument( + id_def: &mut NumericIdResolver, + func: &mut Vec, + arg: &mut ast::Variable, + is_kernel: bool, +) { + if !is_kernel && arg.state_space == ast::StateSpace::Param { + return; + } + let new_id = id_def.register_intermediate(Some((arg.v_type.clone(), arg.state_space))); + func.push(Statement::Variable(ast::Variable { + align: arg.align, + v_type: arg.v_type.clone(), + state_space: ast::StateSpace::Reg, + name: arg.name, + array_init: Vec::new(), + })); + func.push(Statement::StoreVar(StoreVarDetails { + arg: ast::StArgs { + src1: arg.name, + src2: new_id, + }, + typ: arg.v_type.clone(), + member_index: None, + })); + arg.name = new_id; +} + +fn insert_mem_ssa_argument_reg_return( + func: &mut Vec, + arg: &ast::Variable, +) { + func.push(Statement::Variable(ast::Variable { + align: arg.align, + v_type: arg.v_type.clone(), + state_space: arg.state_space, + name: arg.name, + array_init: arg.array_init.clone(), + })); +} + +fn insert_mem_ssa_statement_default<'a, 'input>( + id_def: &'a mut NumericIdResolver<'input>, + func: &'a mut Vec, + stmt: TypedStatement, +) -> Result<(), TranslateError> { + let mut visitor = InsertMemSSAVisitor { + id_def, + func, + post_statements: Vec::new(), + }; + let new_stmt = stmt.visit_map(&mut visitor)?; + visitor.func.push(new_stmt); + visitor.func.extend(visitor.post_statements); + Ok(()) +} + +struct InsertMemSSAVisitor<'a, 'input> { + id_def: &'a mut NumericIdResolver<'input>, + func: &'a mut Vec, + post_statements: Vec, +} + +impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { + fn symbol( + &mut self, + symbol: SpirvWord, + member_index: Option, + expected: Option<(&ast::Type, ast::StateSpace)>, + is_dst: bool, + ) -> Result { + if expected.is_none() { + return Ok(symbol); + }; + let (mut var_type, var_space, is_variable) = self.id_def.get_typed(symbol)?; + if !space_is_compatible(var_space, ast::StateSpace::Reg) 
|| !is_variable { + return Ok(symbol); + }; + let member_index = match member_index { + Some(idx) => { + let vector_width = match var_type { + ast::Type::Vector(width, scalar_t) => { + var_type = ast::Type::Scalar(scalar_t); + width + } + _ => return Err(error_mismatched_type()), + }; + Some(( + idx, + if self.id_def.special_registers.get(symbol).is_some() { + Some(vector_width) + } else { + None + }, + )) + } + None => None, + }; + let generated_id = self + .id_def + .register_intermediate(Some((var_type.clone(), ast::StateSpace::Reg))); + if !is_dst { + self.func.push(Statement::LoadVar(LoadVarDetails { + arg: ast::LdArgs { + dst: generated_id, + src: symbol, + }, + typ: var_type, + member_index, + })); + } else { + self.post_statements + .push(Statement::StoreVar(StoreVarDetails { + arg: ast::StArgs { + src1: symbol, + src2: generated_id, + }, + typ: var_type, + member_index: member_index.map(|(idx, _)| idx), + })); + } + Ok(generated_id) + } +} + +impl<'a, 'input> ast::VisitorMap + for InsertMemSSAVisitor<'a, 'input> +{ + fn visit( + &mut self, + operand: TypedOperand, + type_space: Option<(&ast::Type, ast::StateSpace)>, + is_dst: bool, + _relaxed_type_check: bool, + ) -> Result { + Ok(match operand { + TypedOperand::Reg(reg) => { + TypedOperand::Reg(self.symbol(reg, None, type_space, is_dst)?) + } + TypedOperand::RegOffset(reg, offset) => { + TypedOperand::RegOffset(self.symbol(reg, None, type_space, is_dst)?, offset) + } + op @ TypedOperand::Imm(..) => op, + TypedOperand::VecMember(symbol, index) => { + TypedOperand::Reg(self.symbol(symbol, Some(index), type_space, is_dst)?) 
+ } + }) + } + + fn visit_ident( + &mut self, + args: SpirvWord, + type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>, + is_dst: bool, + relaxed_type_check: bool, + ) -> Result { + self.symbol(args, None, type_space, is_dst) + } +} diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs new file mode 100644 index 00000000..2be6297a --- /dev/null +++ b/ptx/src/pass/mod.rs @@ -0,0 +1,1677 @@ +use ptx_parser as ast; +use rspirv::{binary::Assemble, dr}; +use std::hash::Hash; +use std::num::NonZeroU8; +use std::{ + borrow::Cow, + cell::RefCell, + collections::{hash_map, HashMap, HashSet}, + ffi::CString, + iter, + marker::PhantomData, + mem, + rc::Rc, +}; + +mod convert_dynamic_shared_memory_usage; +mod convert_to_stateful_memory_access; +mod convert_to_typed; +mod emit_spirv; +mod expand_arguments; +mod extract_globals; +mod fix_special_registers; +mod insert_implicit_conversions; +mod insert_mem_ssa_statements; +mod normalize_identifiers; +mod normalize_labels; +mod normalize_predicates; + +static ZLUDA_PTX_IMPL_INTEL: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.spv"); +static ZLUDA_PTX_IMPL_AMD: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.bc"); +const ZLUDA_PTX_PREFIX: &'static str = "__zluda_ptx_impl__"; + +pub fn to_spirv_module<'input>(ast: ast::Module<'input>) -> Result { + let mut id_defs = GlobalStringIdResolver::<'input>::new(SpirvWord(1)); + let mut ptx_impl_imports = HashMap::new(); + let directives = ast + .directives + .into_iter() + .filter_map(|directive| { + translate_directive(&mut id_defs, &mut ptx_impl_imports, directive).transpose() + }) + .collect::, _>>()?; + let directives = hoist_function_globals(directives); + let must_link_ptx_impl = ptx_impl_imports.len() > 0; + let mut directives = ptx_impl_imports + .into_iter() + .map(|(_, v)| v) + .chain(directives.into_iter()) + .collect::>(); + let mut builder = dr::Builder::new(); + builder.reserve_ids(id_defs.current_id().0); + let call_map = 
MethodsCallMap::new(&directives); + let mut directives = + convert_dynamic_shared_memory_usage::run(directives, &call_map, &mut || { + SpirvWord(builder.id()) + })?; + normalize_variable_decls(&mut directives); + let denorm_information = compute_denorm_information(&directives); + let (spirv, kernel_info, build_options) = + emit_spirv::run(builder, &id_defs, call_map, denorm_information, directives)?; + Ok(Module { + spirv, + kernel_info, + should_link_ptx_impl: if must_link_ptx_impl { + Some((ZLUDA_PTX_IMPL_INTEL, ZLUDA_PTX_IMPL_AMD)) + } else { + None + }, + build_options, + }) +} + +fn translate_directive<'input, 'a>( + id_defs: &'a mut GlobalStringIdResolver<'input>, + ptx_impl_imports: &'a mut HashMap>, + d: ast::Directive<'input, ast::ParsedOperand<&'input str>>, +) -> Result>, TranslateError> { + Ok(match d { + ast::Directive::Variable(linking, var) => Some(Directive::Variable( + linking, + ast::Variable { + align: var.align, + v_type: var.v_type.clone(), + state_space: var.state_space, + name: id_defs.get_or_add_def_typed(var.name, var.v_type, var.state_space, true), + array_init: var.array_init, + }, + )), + ast::Directive::Method(linkage, f) => { + translate_function(id_defs, ptx_impl_imports, linkage, f)?.map(Directive::Method) + } + }) +} + +type ParsedFunction<'a> = ast::Function<'a, &'a str, ast::Statement>>; + +fn translate_function<'input, 'a>( + id_defs: &'a mut GlobalStringIdResolver<'input>, + ptx_impl_imports: &'a mut HashMap>, + linkage: ast::LinkingDirective, + f: ParsedFunction<'input>, +) -> Result>, TranslateError> { + let import_as = match &f.func_directive { + ast::MethodDeclaration { + name: ast::MethodName::Func(func_name), + .. 
+ } if *func_name == "__assertfail" || *func_name == "vprintf" => { + Some([ZLUDA_PTX_PREFIX, func_name].concat()) + } + _ => None, + }; + let (str_resolver, fn_resolver, fn_decl) = id_defs.start_fn(&f.func_directive)?; + let mut func = to_ssa( + ptx_impl_imports, + str_resolver, + fn_resolver, + fn_decl, + f.body, + f.tuning, + linkage, + )?; + func.import_as = import_as; + if func.import_as.is_some() { + ptx_impl_imports.insert( + func.import_as.as_ref().unwrap().clone(), + Directive::Method(func), + ); + Ok(None) + } else { + Ok(Some(func)) + } +} + +fn to_ssa<'input, 'b>( + ptx_impl_imports: &'b mut HashMap>, + mut id_defs: FnStringIdResolver<'input, 'b>, + fn_defs: GlobalFnDeclResolver<'input, 'b>, + func_decl: Rc>>, + f_body: Option>>>, + tuning: Vec, + linkage: ast::LinkingDirective, +) -> Result, TranslateError> { + //deparamize_function_decl(&func_decl)?; + let f_body = match f_body { + Some(vec) => vec, + None => { + return Ok(Function { + func_decl: func_decl, + body: None, + globals: Vec::new(), + import_as: None, + tuning, + linkage, + }) + } + }; + let normalized_ids = normalize_identifiers::run(&mut id_defs, &fn_defs, f_body)?; + let mut numeric_id_defs = id_defs.finish(); + let unadorned_statements = normalize_predicates::run(normalized_ids, &mut numeric_id_defs)?; + let typed_statements = + convert_to_typed::run(unadorned_statements, &fn_defs, &mut numeric_id_defs)?; + let typed_statements = + fix_special_registers::run(ptx_impl_imports, typed_statements, &mut numeric_id_defs)?; + let (func_decl, typed_statements) = + convert_to_stateful_memory_access::run(func_decl, typed_statements, &mut numeric_id_defs)?; + let ssa_statements = insert_mem_ssa_statements::run( + typed_statements, + &mut numeric_id_defs, + &mut (*func_decl).borrow_mut(), + )?; + let mut numeric_id_defs = numeric_id_defs.finish(); + let expanded_statements = expand_arguments::run(ssa_statements, &mut numeric_id_defs)?; + let expanded_statements = + 
insert_implicit_conversions::run(expanded_statements, &mut numeric_id_defs)?; + let mut numeric_id_defs = numeric_id_defs.unmut(); + let labeled_statements = normalize_labels::run(expanded_statements, &mut numeric_id_defs); + let (f_body, globals) = + extract_globals::run(labeled_statements, ptx_impl_imports, &mut numeric_id_defs)?; + Ok(Function { + func_decl: func_decl, + globals: globals, + body: Some(f_body), + import_as: None, + tuning, + linkage, + }) +} + +pub struct Module { + pub spirv: dr::Module, + pub kernel_info: HashMap, + pub should_link_ptx_impl: Option<(&'static [u8], &'static [u8])>, + pub build_options: CString, +} + +impl Module { + pub fn assemble(&self) -> Vec { + self.spirv.assemble() + } +} + +struct GlobalStringIdResolver<'input> { + current_id: SpirvWord, + variables: HashMap, SpirvWord>, + reverse_variables: HashMap, + variables_type_check: HashMap>, + special_registers: SpecialRegistersMap, + fns: HashMap>, +} + +impl<'input> GlobalStringIdResolver<'input> { + fn new(start_id: SpirvWord) -> Self { + Self { + current_id: start_id, + variables: HashMap::new(), + reverse_variables: HashMap::new(), + variables_type_check: HashMap::new(), + special_registers: SpecialRegistersMap::new(), + fns: HashMap::new(), + } + } + + fn get_or_add_def(&mut self, id: &'input str) -> SpirvWord { + self.get_or_add_impl(id, None) + } + + fn get_or_add_def_typed( + &mut self, + id: &'input str, + typ: ast::Type, + state_space: ast::StateSpace, + is_variable: bool, + ) -> SpirvWord { + self.get_or_add_impl(id, Some((typ, state_space, is_variable))) + } + + fn get_or_add_impl( + &mut self, + id: &'input str, + typ: Option<(ast::Type, ast::StateSpace, bool)>, + ) -> SpirvWord { + let id = match self.variables.entry(Cow::Borrowed(id)) { + hash_map::Entry::Occupied(e) => *(e.get()), + hash_map::Entry::Vacant(e) => { + let numeric_id = self.current_id; + e.insert(numeric_id); + self.reverse_variables.insert(numeric_id, id); + self.current_id.0 += 1; + numeric_id + } 
+ }; + self.variables_type_check.insert(id, typ); + id + } + + fn get_id(&self, id: &str) -> Result { + self.variables + .get(id) + .copied() + .ok_or_else(error_unknown_symbol) + } + + fn current_id(&self) -> SpirvWord { + self.current_id + } + + fn start_fn<'b>( + &'b mut self, + header: &'b ast::MethodDeclaration<'input, &'input str>, + ) -> Result< + ( + FnStringIdResolver<'input, 'b>, + GlobalFnDeclResolver<'input, 'b>, + Rc>>, + ), + TranslateError, + > { + // In case a function decl was inserted earlier we want to use its id + let name_id = self.get_or_add_def(header.name()); + let mut fn_resolver = FnStringIdResolver { + current_id: &mut self.current_id, + global_variables: &self.variables, + global_type_check: &self.variables_type_check, + special_registers: &mut self.special_registers, + variables: vec![HashMap::new(); 1], + type_check: HashMap::new(), + }; + let return_arguments = rename_fn_params(&mut fn_resolver, &header.return_arguments); + let input_arguments = rename_fn_params(&mut fn_resolver, &header.input_arguments); + let name = match header.name { + ast::MethodName::Kernel(name) => ast::MethodName::Kernel(name), + ast::MethodName::Func(_) => ast::MethodName::Func(name_id), + }; + let fn_decl = ast::MethodDeclaration { + return_arguments, + name, + input_arguments, + shared_mem: None, + }; + let new_fn_decl = if !matches!(fn_decl.name, ast::MethodName::Kernel(_)) { + let resolver = FnSigMapper::remap_to_spirv_repr(fn_decl); + let new_fn_decl = resolver.func_decl.clone(); + self.fns.insert(name_id, resolver); + new_fn_decl + } else { + Rc::new(RefCell::new(fn_decl)) + }; + Ok(( + fn_resolver, + GlobalFnDeclResolver { fns: &self.fns }, + new_fn_decl, + )) + } +} + +fn rename_fn_params<'a, 'b>( + fn_resolver: &mut FnStringIdResolver<'a, 'b>, + args: &'b [ast::Variable<&'a str>], +) -> Vec> { + args.iter() + .map(|a| ast::Variable { + name: fn_resolver.add_def(a.name, Some((a.v_type.clone(), a.state_space)), true), + v_type: a.v_type.clone(), + 
state_space: a.state_space, + align: a.align, + array_init: a.array_init.clone(), + }) + .collect() +} + +pub struct KernelInfo { + pub arguments_sizes: Vec<(usize, bool)>, + pub uses_shared_mem: bool, +} + +#[derive(Ord, PartialOrd, Eq, PartialEq, Hash, Copy, Clone)] +enum PtxSpecialRegister { + Tid, + Ntid, + Ctaid, + Nctaid, + Clock, + LanemaskLt, +} + +impl PtxSpecialRegister { + fn try_parse(s: &str) -> Option { + match s { + "%tid" => Some(Self::Tid), + "%ntid" => Some(Self::Ntid), + "%ctaid" => Some(Self::Ctaid), + "%nctaid" => Some(Self::Nctaid), + "%clock" => Some(Self::Clock), + "%lanemask_lt" => Some(Self::LanemaskLt), + _ => None, + } + } + + fn get_type(self) -> ast::Type { + match self { + PtxSpecialRegister::Tid + | PtxSpecialRegister::Ntid + | PtxSpecialRegister::Ctaid + | PtxSpecialRegister::Nctaid => ast::Type::Vector(4, self.get_function_return_type()), + _ => ast::Type::Scalar(self.get_function_return_type()), + } + } + + fn get_function_return_type(self) -> ast::ScalarType { + match self { + PtxSpecialRegister::Tid => ast::ScalarType::U32, + PtxSpecialRegister::Ntid => ast::ScalarType::U32, + PtxSpecialRegister::Ctaid => ast::ScalarType::U32, + PtxSpecialRegister::Nctaid => ast::ScalarType::U32, + PtxSpecialRegister::Clock => ast::ScalarType::U32, + PtxSpecialRegister::LanemaskLt => ast::ScalarType::U32, + } + } + + fn get_function_input_type(self) -> Option { + match self { + PtxSpecialRegister::Tid + | PtxSpecialRegister::Ntid + | PtxSpecialRegister::Ctaid + | PtxSpecialRegister::Nctaid => Some(ast::ScalarType::U8), + PtxSpecialRegister::Clock | PtxSpecialRegister::LanemaskLt => None, + } + } + + fn get_unprefixed_function_name(self) -> &'static str { + match self { + PtxSpecialRegister::Tid => "sreg_tid", + PtxSpecialRegister::Ntid => "sreg_ntid", + PtxSpecialRegister::Ctaid => "sreg_ctaid", + PtxSpecialRegister::Nctaid => "sreg_nctaid", + PtxSpecialRegister::Clock => "sreg_clock", + PtxSpecialRegister::LanemaskLt => "sreg_lanemask_lt", + } 
+ } +} + +struct SpecialRegistersMap { + reg_to_id: HashMap, + id_to_reg: HashMap, +} + +impl SpecialRegistersMap { + fn new() -> Self { + SpecialRegistersMap { + reg_to_id: HashMap::new(), + id_to_reg: HashMap::new(), + } + } + + fn get(&self, id: SpirvWord) -> Option { + self.id_to_reg.get(&id).copied() + } + + fn get_or_add(&mut self, current_id: &mut SpirvWord, reg: PtxSpecialRegister) -> SpirvWord { + match self.reg_to_id.entry(reg) { + hash_map::Entry::Occupied(e) => *e.get(), + hash_map::Entry::Vacant(e) => { + let numeric_id = SpirvWord(current_id.0); + current_id.0 += 1; + e.insert(numeric_id); + self.id_to_reg.insert(numeric_id, reg); + numeric_id + } + } + } +} + +struct FnStringIdResolver<'input, 'b> { + current_id: &'b mut SpirvWord, + global_variables: &'b HashMap, SpirvWord>, + global_type_check: &'b HashMap>, + special_registers: &'b mut SpecialRegistersMap, + variables: Vec, SpirvWord>>, + type_check: HashMap>, +} + +impl<'a, 'b> FnStringIdResolver<'a, 'b> { + fn finish(self) -> NumericIdResolver<'b> { + NumericIdResolver { + current_id: self.current_id, + global_type_check: self.global_type_check, + type_check: self.type_check, + special_registers: self.special_registers, + } + } + + fn start_block(&mut self) { + self.variables.push(HashMap::new()) + } + + fn end_block(&mut self) { + self.variables.pop(); + } + + fn get_id(&mut self, id: &str) -> Result { + for scope in self.variables.iter().rev() { + match scope.get(id) { + Some(id) => return Ok(*id), + None => continue, + } + } + match self.global_variables.get(id) { + Some(id) => Ok(*id), + None => { + let sreg = PtxSpecialRegister::try_parse(id).ok_or_else(error_unknown_symbol)?; + Ok(self.special_registers.get_or_add(self.current_id, sreg)) + } + } + } + + fn add_def( + &mut self, + id: &'a str, + typ: Option<(ast::Type, ast::StateSpace)>, + is_variable: bool, + ) -> SpirvWord { + let numeric_id = *self.current_id; + self.variables + .last_mut() + .unwrap() + .insert(Cow::Borrowed(id), 
numeric_id); + self.type_check.insert( + numeric_id, + typ.map(|(typ, space)| (typ, space, is_variable)), + ); + self.current_id.0 += 1; + numeric_id + } + + #[must_use] + fn add_defs( + &mut self, + base_id: &'a str, + count: u32, + typ: ast::Type, + state_space: ast::StateSpace, + is_variable: bool, + ) -> impl Iterator { + let numeric_id = *self.current_id; + for i in 0..count { + self.variables.last_mut().unwrap().insert( + Cow::Owned(format!("{}{}", base_id, i)), + SpirvWord(numeric_id.0 + i), + ); + self.type_check.insert( + SpirvWord(numeric_id.0 + i), + Some((typ.clone(), state_space, is_variable)), + ); + } + self.current_id.0 += count; + (0..count) + .into_iter() + .map(move |i| SpirvWord(i + numeric_id.0)) + } +} + +struct NumericIdResolver<'b> { + current_id: &'b mut SpirvWord, + global_type_check: &'b HashMap>, + type_check: HashMap>, + special_registers: &'b mut SpecialRegistersMap, +} + +impl<'b> NumericIdResolver<'b> { + fn finish(self) -> MutableNumericIdResolver<'b> { + MutableNumericIdResolver { base: self } + } + + fn get_typed( + &self, + id: SpirvWord, + ) -> Result<(ast::Type, ast::StateSpace, bool), TranslateError> { + match self.type_check.get(&id) { + Some(Some(x)) => Ok(x.clone()), + Some(None) => Err(TranslateError::UntypedSymbol), + None => match self.special_registers.get(id) { + Some(x) => Ok((x.get_type(), ast::StateSpace::Sreg, true)), + None => match self.global_type_check.get(&id) { + Some(Some(result)) => Ok(result.clone()), + Some(None) | None => Err(TranslateError::UntypedSymbol), + }, + }, + } + } + + // This is for identifiers which will be emitted later as OpVariable + // They are candidates for insertion of LoadVar/StoreVar + fn register_variable(&mut self, typ: ast::Type, state_space: ast::StateSpace) -> SpirvWord { + let new_id = *self.current_id; + self.type_check + .insert(new_id, Some((typ, state_space, true))); + self.current_id.0 += 1; + new_id + } + + fn register_intermediate(&mut self, typ: Option<(ast::Type, 
ast::StateSpace)>) -> SpirvWord { + let new_id = *self.current_id; + self.type_check + .insert(new_id, typ.map(|(t, space)| (t, space, false))); + self.current_id.0 += 1; + new_id + } +} + +struct MutableNumericIdResolver<'b> { + base: NumericIdResolver<'b>, +} + +impl<'b> MutableNumericIdResolver<'b> { + fn unmut(self) -> NumericIdResolver<'b> { + self.base + } + + fn get_typed(&self, id: SpirvWord) -> Result<(ast::Type, ast::StateSpace), TranslateError> { + self.base.get_typed(id).map(|(t, space, _)| (t, space)) + } + + fn register_intermediate(&mut self, typ: ast::Type, state_space: ast::StateSpace) -> SpirvWord { + self.base.register_intermediate(Some((typ, state_space))) + } +} + +quick_error! { + #[derive(Debug)] + pub enum TranslateError { + UnknownSymbol {} + UntypedSymbol {} + MismatchedType {} + Spirv(err: rspirv::dr::Error) { + from() + display("{}", err) + cause(err) + } + Unreachable {} + Todo {} + } +} + +#[cfg(debug_assertions)] +fn error_unreachable() -> TranslateError { + unreachable!() +} + +#[cfg(not(debug_assertions))] +fn error_unreachable() -> TranslateError { + TranslateError::Unreachable +} + +fn error_unknown_symbol() -> TranslateError { + panic!() +} + +#[cfg(not(debug_assertions))] +fn error_unknown_symbol() -> TranslateError { + TranslateError::UnknownSymbol +} + +fn error_mismatched_type() -> TranslateError { + panic!() +} + +#[cfg(not(debug_assertions))] +fn error_mismatched_type() -> TranslateError { + TranslateError::MismatchedType +} + +pub struct GlobalFnDeclResolver<'input, 'a> { + fns: &'a HashMap>, +} + +impl<'input, 'a> GlobalFnDeclResolver<'input, 'a> { + fn get_fn_sig_resolver(&self, id: SpirvWord) -> Result<&FnSigMapper<'input>, TranslateError> { + self.fns.get(&id).ok_or_else(error_unknown_symbol) + } +} + +struct FnSigMapper<'input> { + // true - stays as return argument + // false - is moved to input argument + return_param_args: Vec, + func_decl: Rc>>, +} + +impl<'input> FnSigMapper<'input> { + fn remap_to_spirv_repr(mut 
method: ast::MethodDeclaration<'input, SpirvWord>) -> Self { + let return_param_args = method + .return_arguments + .iter() + .map(|a| a.state_space != ast::StateSpace::Param) + .collect::>(); + let mut new_return_arguments = Vec::new(); + for arg in method.return_arguments.into_iter() { + if arg.state_space == ast::StateSpace::Param { + method.input_arguments.push(arg); + } else { + new_return_arguments.push(arg); + } + } + method.return_arguments = new_return_arguments; + FnSigMapper { + return_param_args, + func_decl: Rc::new(RefCell::new(method)), + } + } + + fn resolve_in_spirv_repr( + &self, + data: ast::CallDetails, + arguments: ast::CallArgs>, + ) -> Result>, TranslateError> { + let func_decl = (*self.func_decl).borrow(); + let mut data_return = Vec::new(); + let mut arguments_return = Vec::new(); + let mut data_input = data.input_arguments; + let mut arguments_input = arguments.input_arguments; + let mut func_decl_return_iter = func_decl.return_arguments.iter(); + let mut func_decl_input_iter = func_decl.input_arguments[arguments_input.len()..].iter(); + for (idx, id) in arguments.return_arguments.iter().enumerate() { + let stays_as_return = match self.return_param_args.get(idx) { + Some(x) => *x, + None => return Err(TranslateError::MismatchedType), + }; + if stays_as_return { + if let Some(var) = func_decl_return_iter.next() { + data_return.push((var.v_type.clone(), var.state_space)); + arguments_return.push(*id); + } else { + return Err(TranslateError::MismatchedType); + } + } else { + if let Some(var) = func_decl_input_iter.next() { + data_input.push((var.v_type.clone(), var.state_space)); + arguments_input.push(ast::ParsedOperand::Reg(*id)); + } else { + return Err(TranslateError::MismatchedType); + } + } + } + if arguments_return.len() != func_decl.return_arguments.len() + || arguments_input.len() != func_decl.input_arguments.len() + { + return Err(TranslateError::MismatchedType); + } + let data = ast::CallDetails { + uniform: data.uniform, + 
return_arguments: data_return, + input_arguments: data_input, + }; + let arguments = ast::CallArgs { + func: arguments.func, + return_arguments: arguments_return, + input_arguments: arguments_input, + }; + Ok(ast::Instruction::Call { data, arguments }) + } +} + +enum Statement { + Label(SpirvWord), + Variable(ast::Variable), + Instruction(I), + // SPIR-V compatible replacement for PTX predicates + Conditional(BrachCondition), + LoadVar(LoadVarDetails), + StoreVar(StoreVarDetails), + Conversion(ImplicitConversion), + Constant(ConstantDefinition), + RetValue(ast::RetData, SpirvWord), + PtrAccess(PtrAccess

), + RepackVector(RepackVectorDetails), + FunctionPointer(FunctionPointerDetails), +} + +impl> Statement, T> { + fn visit_map, Err>( + self, + visitor: &mut impl ast::VisitorMap, + ) -> std::result::Result, To>, Err> { + Ok(match self { + Statement::Instruction(i) => { + return ast::visit_map(i, visitor).map(Statement::Instruction) + } + Statement::Label(label) => { + Statement::Label(visitor.visit_ident(label, None, false, false)?) + } + Statement::Variable(var) => { + let name = visitor.visit_ident( + var.name, + Some((&var.v_type, var.state_space)), + true, + false, + )?; + Statement::Variable(ast::Variable { + align: var.align, + v_type: var.v_type, + state_space: var.state_space, + name, + array_init: var.array_init, + }) + } + Statement::Conditional(conditional) => { + let predicate = visitor.visit_ident( + conditional.predicate, + Some((&ast::ScalarType::Pred.into(), ast::StateSpace::Reg)), + false, + false, + )?; + let if_true = visitor.visit_ident(conditional.if_true, None, false, false)?; + let if_false = visitor.visit_ident(conditional.if_false, None, false, false)?; + Statement::Conditional(BrachCondition { + predicate, + if_true, + if_false, + }) + } + Statement::LoadVar(LoadVarDetails { + arg, + typ, + member_index, + }) => { + let dst = visitor.visit_ident( + arg.dst, + Some((&typ, ast::StateSpace::Reg)), + true, + false, + )?; + let src = visitor.visit_ident( + arg.src, + Some((&typ, ast::StateSpace::Local)), + false, + false, + )?; + Statement::LoadVar(LoadVarDetails { + arg: ast::LdArgs { dst, src }, + typ, + member_index, + }) + } + Statement::StoreVar(StoreVarDetails { + arg, + typ, + member_index, + }) => { + let src1 = visitor.visit_ident( + arg.src1, + Some((&typ, ast::StateSpace::Local)), + false, + false, + )?; + let src2 = visitor.visit_ident( + arg.src2, + Some((&typ, ast::StateSpace::Reg)), + false, + false, + )?; + Statement::StoreVar(StoreVarDetails { + arg: ast::StArgs { src1, src2 }, + typ, + member_index, + }) + } + 
Statement::Conversion(ImplicitConversion { + src, + dst, + from_type, + to_type, + from_space, + to_space, + kind, + }) => { + let dst = visitor.visit_ident( + dst, + Some((&to_type, ast::StateSpace::Reg)), + true, + false, + )?; + let src = visitor.visit_ident( + src, + Some((&from_type, ast::StateSpace::Reg)), + false, + false, + )?; + Statement::Conversion(ImplicitConversion { + src, + dst, + from_type, + to_type, + from_space, + to_space, + kind, + }) + } + Statement::Constant(ConstantDefinition { dst, typ, value }) => { + let dst = visitor.visit_ident( + dst, + Some((&typ.into(), ast::StateSpace::Reg)), + true, + false, + )?; + Statement::Constant(ConstantDefinition { dst, typ, value }) + } + Statement::RetValue(data, value) => { + // TODO: + // We should report type here + let value = visitor.visit_ident(value, None, false, false)?; + Statement::RetValue(data, value) + } + Statement::PtrAccess(PtrAccess { + underlying_type, + state_space, + dst, + ptr_src, + offset_src, + }) => { + let dst = + visitor.visit_ident(dst, Some((&underlying_type, state_space)), true, false)?; + let ptr_src = visitor.visit_ident( + ptr_src, + Some((&underlying_type, state_space)), + false, + false, + )?; + let offset_src = visitor.visit( + offset_src, + Some(( + &ast::Type::Scalar(ast::ScalarType::S64), + ast::StateSpace::Reg, + )), + false, + false, + )?; + Statement::PtrAccess(PtrAccess { + underlying_type, + state_space, + dst, + ptr_src, + offset_src, + }) + } + Statement::RepackVector(RepackVectorDetails { + is_extract, + typ, + packed, + unpacked, + relaxed_type_check, + }) => { + let (packed, unpacked) = if is_extract { + let unpacked = unpacked + .into_iter() + .map(|ident| { + visitor.visit_ident( + ident, + Some((&typ.into(), ast::StateSpace::Reg)), + true, + relaxed_type_check, + ) + }) + .collect::, _>>()?; + let packed = visitor.visit_ident( + packed, + Some(( + &ast::Type::Vector(unpacked.len() as u8, typ), + ast::StateSpace::Reg, + )), + false, + false, + )?; + 
(packed, unpacked) + } else { + let packed = visitor.visit_ident( + packed, + Some(( + &ast::Type::Vector(unpacked.len() as u8, typ), + ast::StateSpace::Reg, + )), + true, + false, + )?; + let unpacked = unpacked + .into_iter() + .map(|ident| { + visitor.visit_ident( + ident, + Some((&typ.into(), ast::StateSpace::Reg)), + false, + relaxed_type_check, + ) + }) + .collect::, _>>()?; + (packed, unpacked) + }; + Statement::RepackVector(RepackVectorDetails { + is_extract, + typ, + packed, + unpacked, + relaxed_type_check, + }) + } + Statement::FunctionPointer(FunctionPointerDetails { dst, src }) => { + let dst = visitor.visit_ident( + dst, + Some(( + &ast::Type::Scalar(ast::ScalarType::U64), + ast::StateSpace::Reg, + )), + true, + false, + )?; + let src = visitor.visit_ident(src, None, false, false)?; + Statement::FunctionPointer(FunctionPointerDetails { dst, src }) + } + }) + } +} + +struct BrachCondition { + predicate: SpirvWord, + if_true: SpirvWord, + if_false: SpirvWord, +} +struct LoadVarDetails { + arg: ast::LdArgs, + typ: ast::Type, + // (index, vector_width) + // HACK ALERT + // For some reason IGC explodes when you try to load from builtin vectors + // using OpInBoundsAccessChain, the one true way to do it is to + // OpLoad+OpCompositeExtract + member_index: Option<(u8, Option)>, +} + +struct StoreVarDetails { + arg: ast::StArgs, + typ: ast::Type, + member_index: Option, +} + +#[derive(Clone)] +struct ImplicitConversion { + src: SpirvWord, + dst: SpirvWord, + from_type: ast::Type, + to_type: ast::Type, + from_space: ast::StateSpace, + to_space: ast::StateSpace, + kind: ConversionKind, +} + +#[derive(PartialEq, Clone)] +enum ConversionKind { + Default, + // zero-extend/chop/bitcast depending on types + SignExtend, + BitToPtr, + PtrToPtr, + AddressOf, +} + +struct ConstantDefinition { + pub dst: SpirvWord, + pub typ: ast::ScalarType, + pub value: ast::ImmediateValue, +} + +pub struct PtrAccess { + underlying_type: ast::Type, + state_space: ast::StateSpace, + 
dst: SpirvWord, + ptr_src: SpirvWord, + offset_src: T, +} + +struct RepackVectorDetails { + is_extract: bool, + typ: ast::ScalarType, + packed: SpirvWord, + unpacked: Vec, + relaxed_type_check: bool, +} + +struct FunctionPointerDetails { + dst: SpirvWord, + src: SpirvWord, +} + +#[derive(Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +struct SpirvWord(spirv::Word); + +impl From for SpirvWord { + fn from(value: spirv::Word) -> Self { + Self(value) + } +} +impl From for spirv::Word { + fn from(value: SpirvWord) -> Self { + value.0 + } +} + +impl ast::Operand for SpirvWord { + type Ident = Self; + + fn from_ident(ident: Self::Ident) -> Self { + ident + } +} + +fn pred_map_variable Result>( + this: ast::PredAt, + f: &mut F, +) -> Result, TranslateError> { + let new_label = f(this.label)?; + Ok(ast::PredAt { + not: this.not, + label: new_label, + }) +} + +pub(crate) enum Directive<'input> { + Variable(ast::LinkingDirective, ast::Variable), + Method(Function<'input>), +} + +pub(crate) struct Function<'input> { + pub func_decl: Rc>>, + pub globals: Vec>, + pub body: Option>, + import_as: Option, + tuning: Vec, + linkage: ast::LinkingDirective, +} + +type ExpandedStatement = Statement, SpirvWord>; + +type NormalizedStatement = Statement< + ( + Option>, + ast::Instruction>, + ), + ast::ParsedOperand, +>; + +type UnconditionalStatement = + Statement>, ast::ParsedOperand>; + +type TypedStatement = Statement, TypedOperand>; + +#[derive(Copy, Clone)] +enum TypedOperand { + Reg(SpirvWord), + RegOffset(SpirvWord, i32), + Imm(ast::ImmediateValue), + VecMember(SpirvWord, u8), +} + +impl TypedOperand { + fn map( + self, + fn_: impl FnOnce(SpirvWord, Option) -> Result, + ) -> Result { + Ok(match self { + TypedOperand::Reg(reg) => TypedOperand::Reg(fn_(reg, None)?), + TypedOperand::RegOffset(reg, off) => TypedOperand::RegOffset(fn_(reg, None)?, off), + TypedOperand::Imm(imm) => TypedOperand::Imm(imm), + TypedOperand::VecMember(reg, idx) => TypedOperand::VecMember(fn_(reg, 
Some(idx))?, idx), + }) + } + + fn underlying_register(&self) -> Option { + match self { + Self::Reg(r) | Self::RegOffset(r, _) | Self::VecMember(r, _) => Some(*r), + Self::Imm(_) => None, + } + } + + fn unwrap_reg(&self) -> Result { + match self { + TypedOperand::Reg(reg) => Ok(*reg), + _ => Err(error_unreachable()), + } + } +} + +impl ast::Operand for TypedOperand { + type Ident = SpirvWord; + + fn from_ident(ident: Self::Ident) -> Self { + TypedOperand::Reg(ident) + } +} + +impl ast::VisitorMap + for FnVisitor +where + Fn: FnMut( + TypedOperand, + Option<(&ast::Type, ast::StateSpace)>, + bool, + bool, + ) -> Result, +{ + fn visit( + &mut self, + args: TypedOperand, + type_space: Option<(&ast::Type, ast::StateSpace)>, + is_dst: bool, + relaxed_type_check: bool, + ) -> Result { + (self.fn_)(args, type_space, is_dst, relaxed_type_check) + } + + fn visit_ident( + &mut self, + args: SpirvWord, + type_space: Option<(&ast::Type, ast::StateSpace)>, + is_dst: bool, + relaxed_type_check: bool, + ) -> Result { + match (self.fn_)( + TypedOperand::Reg(args), + type_space, + is_dst, + relaxed_type_check, + )? 
{ + TypedOperand::Reg(reg) => Ok(reg), + _ => Err(TranslateError::Unreachable), + } + } +} + +struct FnVisitor< + T, + U, + Err, + Fn: FnMut(T, Option<(&ast::Type, ast::StateSpace)>, bool, bool) -> Result, +> { + fn_: Fn, + _marker: PhantomData Result>, +} + +impl< + T, + U, + Err, + Fn: FnMut(T, Option<(&ast::Type, ast::StateSpace)>, bool, bool) -> Result, + > FnVisitor +{ + fn new(fn_: Fn) -> Self { + Self { + fn_, + _marker: PhantomData, + } + } +} + +fn space_is_compatible(this: ast::StateSpace, other: ast::StateSpace) -> bool { + this == other + || this == ast::StateSpace::Reg && other == ast::StateSpace::Sreg + || this == ast::StateSpace::Sreg && other == ast::StateSpace::Reg +} + +fn register_external_fn_call<'a>( + id_defs: &mut NumericIdResolver, + ptx_impl_imports: &mut HashMap, + name: String, + return_arguments: impl Iterator, + input_arguments: impl Iterator, +) -> Result { + match ptx_impl_imports.entry(name) { + hash_map::Entry::Vacant(entry) => { + let fn_id = id_defs.register_intermediate(None); + let return_arguments = fn_arguments_to_variables(id_defs, return_arguments); + let input_arguments = fn_arguments_to_variables(id_defs, input_arguments); + let func_decl = ast::MethodDeclaration:: { + return_arguments, + name: ast::MethodName::Func(fn_id), + input_arguments, + shared_mem: None, + }; + let func = Function { + func_decl: Rc::new(RefCell::new(func_decl)), + globals: Vec::new(), + body: None, + import_as: Some(entry.key().clone()), + tuning: Vec::new(), + linkage: ast::LinkingDirective::EXTERN, + }; + entry.insert(Directive::Method(func)); + Ok(fn_id) + } + hash_map::Entry::Occupied(entry) => match entry.get() { + Directive::Method(Function { func_decl, .. 
}) => match (**func_decl).borrow().name { + ast::MethodName::Func(fn_id) => Ok(fn_id), + ast::MethodName::Kernel(_) => Err(error_unreachable()), + }, + _ => Err(error_unreachable()), + }, + } +} + +fn fn_arguments_to_variables<'a>( + id_defs: &mut NumericIdResolver, + args: impl Iterator, +) -> Vec> { + args.map(|(typ, space)| ast::Variable { + align: None, + v_type: typ.clone(), + state_space: space, + name: id_defs.register_intermediate(None), + array_init: Vec::new(), + }) + .collect::>() +} + +fn hoist_function_globals(directives: Vec) -> Vec { + let mut result = Vec::with_capacity(directives.len()); + for directive in directives { + match directive { + Directive::Method(method) => { + for variable in method.globals { + result.push(Directive::Variable(ast::LinkingDirective::NONE, variable)); + } + result.push(Directive::Method(Function { + globals: Vec::new(), + ..method + })) + } + _ => result.push(directive), + } + } + result +} + +struct MethodsCallMap<'input> { + map: HashMap, HashSet>, +} + +impl<'input> MethodsCallMap<'input> { + fn new(module: &[Directive<'input>]) -> Self { + let mut directly_called_by = HashMap::new(); + for directive in module { + match directive { + Directive::Method(Function { + func_decl, + body: Some(statements), + .. 
+ }) => { + let call_key: ast::MethodName<_> = (**func_decl).borrow().name; + if let hash_map::Entry::Vacant(entry) = directly_called_by.entry(call_key) { + entry.insert(Vec::new()); + } + for statement in statements { + match statement { + Statement::Instruction(ast::Instruction::Call { data, arguments }) => { + multi_hash_map_append( + &mut directly_called_by, + call_key, + arguments.func, + ); + } + _ => {} + } + } + } + _ => {} + } + } + let mut result = HashMap::new(); + for (&method_key, children) in directly_called_by.iter() { + let mut visited = HashSet::new(); + for child in children { + Self::add_call_map_single(&directly_called_by, &mut visited, *child); + } + result.insert(method_key, visited); + } + MethodsCallMap { map: result } + } + + fn add_call_map_single( + directly_called_by: &HashMap, Vec>, + visited: &mut HashSet, + current: SpirvWord, + ) { + if !visited.insert(current) { + return; + } + if let Some(children) = directly_called_by.get(&ast::MethodName::Func(current)) { + for child in children { + Self::add_call_map_single(directly_called_by, visited, *child); + } + } + } + + fn get_kernel_children(&self, name: &'input str) -> impl Iterator { + self.map + .get(&ast::MethodName::Kernel(name)) + .into_iter() + .flatten() + } + + fn kernels(&self) -> impl Iterator)> { + self.map + .iter() + .filter_map(|(method, children)| match method { + ast::MethodName::Kernel(kernel) => Some((*kernel, children)), + ast::MethodName::Func(..) 
=> None, + }) + } + + fn methods( + &self, + ) -> impl Iterator, &HashSet)> { + self.map + .iter() + .map(|(method, children)| (*method, children)) + } + + fn visit_callees(&self, method: ast::MethodName<'input, SpirvWord>, f: impl FnMut(SpirvWord)) { + self.map + .get(&method) + .into_iter() + .flatten() + .copied() + .for_each(f); + } +} + +fn multi_hash_map_append< + K: Eq + std::hash::Hash, + V, + Collection: std::iter::Extend + std::default::Default, +>( + m: &mut HashMap, + key: K, + value: V, +) { + match m.entry(key) { + hash_map::Entry::Occupied(mut entry) => { + entry.get_mut().extend(iter::once(value)); + } + hash_map::Entry::Vacant(entry) => { + entry.insert(Default::default()).extend(iter::once(value)); + } + } +} + +fn normalize_variable_decls(directives: &mut Vec) { + for directive in directives { + match directive { + Directive::Method(Function { + body: Some(func), .. + }) => { + func[1..].sort_by_key(|s| match s { + Statement::Variable(_) => 0, + _ => 1, + }); + } + _ => (), + } + } +} + +// HACK ALERT! +// This function is a "good enough" heuristic of whetever to mark f16/f32 operations +// in the kernel as flushing denorms to zero or preserving them +// PTX support per-instruction ftz information. Unfortunately SPIR-V has no +// such capability, so instead we guesstimate which use is more common in the kernel +// and emit suitable execution mode +fn compute_denorm_information<'input>( + module: &[Directive<'input>], +) -> HashMap, HashMap> { + let mut denorm_methods = HashMap::new(); + for directive in module { + match directive { + Directive::Variable(..) | Directive::Method(Function { body: None, .. }) => {} + Directive::Method(Function { + func_decl, + body: Some(statements), + .. 
+ }) => { + let mut flush_counter = DenormCountMap::new(); + let method_key = (**func_decl).borrow().name; + for statement in statements { + match statement { + Statement::Instruction(inst) => { + if let Some((flush, width)) = flush_to_zero(inst) { + denorm_count_map_update(&mut flush_counter, width, flush); + } + } + Statement::LoadVar(..) => {} + Statement::StoreVar(..) => {} + Statement::Conditional(_) => {} + Statement::Conversion(_) => {} + Statement::Constant(_) => {} + Statement::RetValue(_, _) => {} + Statement::Label(_) => {} + Statement::Variable(_) => {} + Statement::PtrAccess { .. } => {} + Statement::RepackVector(_) => {} + Statement::FunctionPointer(_) => {} + } + } + denorm_methods.insert(method_key, flush_counter); + } + } + } + denorm_methods + .into_iter() + .map(|(name, v)| { + let width_to_denorm = v + .into_iter() + .map(|(k, flush_over_preserve)| { + let mode = if flush_over_preserve > 0 { + spirv::FPDenormMode::FlushToZero + } else { + spirv::FPDenormMode::Preserve + }; + (k, (mode, flush_over_preserve)) + }) + .collect(); + (name, width_to_denorm) + }) + .collect() +} + +fn flush_to_zero(this: &ast::Instruction) -> Option<(bool, u8)> { + match this { + ast::Instruction::Ld { .. } => None, + ast::Instruction::St { .. } => None, + ast::Instruction::Mov { .. } => None, + ast::Instruction::Not { .. } => None, + ast::Instruction::Bra { .. } => None, + ast::Instruction::Shl { .. } => None, + ast::Instruction::Shr { .. } => None, + ast::Instruction::Ret { .. } => None, + ast::Instruction::Call { .. } => None, + ast::Instruction::Or { .. } => None, + ast::Instruction::And { .. } => None, + ast::Instruction::Cvta { .. } => None, + ast::Instruction::Selp { .. } => None, + ast::Instruction::Bar { .. } => None, + ast::Instruction::Atom { .. } => None, + ast::Instruction::AtomCas { .. } => None, + ast::Instruction::Sub { + data: ast::ArithDetails::Integer(_), + .. + } => None, + ast::Instruction::Add { + data: ast::ArithDetails::Integer(_), + .. 
+ } => None, + ast::Instruction::Mul { + data: ast::MulDetails::Integer { .. }, + .. + } => None, + ast::Instruction::Mad { + data: ast::MadDetails::Integer { .. }, + .. + } => None, + ast::Instruction::Min { + data: ast::MinMaxDetails::Signed(_), + .. + } => None, + ast::Instruction::Min { + data: ast::MinMaxDetails::Unsigned(_), + .. + } => None, + ast::Instruction::Max { + data: ast::MinMaxDetails::Signed(_), + .. + } => None, + ast::Instruction::Max { + data: ast::MinMaxDetails::Unsigned(_), + .. + } => None, + ast::Instruction::Cvt { + data: + ast::CvtDetails { + mode: + ast::CvtMode::ZeroExtend + | ast::CvtMode::SignExtend + | ast::CvtMode::Truncate + | ast::CvtMode::Bitcast + | ast::CvtMode::SaturateUnsignedToSigned + | ast::CvtMode::SaturateSignedToUnsigned + | ast::CvtMode::FPFromSigned(_) + | ast::CvtMode::FPFromUnsigned(_), + .. + }, + .. + } => None, + ast::Instruction::Div { + data: ast::DivDetails::Unsigned(_), + .. + } => None, + ast::Instruction::Div { + data: ast::DivDetails::Signed(_), + .. + } => None, + ast::Instruction::Clz { .. } => None, + ast::Instruction::Brev { .. } => None, + ast::Instruction::Popc { .. } => None, + ast::Instruction::Xor { .. } => None, + ast::Instruction::Bfe { .. } => None, + ast::Instruction::Bfi { .. } => None, + ast::Instruction::Rem { .. } => None, + ast::Instruction::Prmt { .. } => None, + ast::Instruction::Activemask { .. } => None, + ast::Instruction::Membar { .. } => None, + ast::Instruction::Sub { + data: ast::ArithDetails::Float(float_control), + .. + } + | ast::Instruction::Add { + data: ast::ArithDetails::Float(float_control), + .. + } + | ast::Instruction::Mul { + data: ast::MulDetails::Float(float_control), + .. + } + | ast::Instruction::Mad { + data: ast::MadDetails::Float(float_control), + .. + } => float_control + .flush_to_zero + .map(|ftz| (ftz, float_control.type_.size_of())), + ast::Instruction::Fma { data, .. 
} => { + data.flush_to_zero.map(|ftz| (ftz, data.type_.size_of())) + } + ast::Instruction::Setp { data, .. } => { + data.flush_to_zero.map(|ftz| (ftz, data.type_.size_of())) + } + ast::Instruction::SetpBool { data, .. } => data + .base + .flush_to_zero + .map(|ftz| (ftz, data.base.type_.size_of())), + ast::Instruction::Abs { data, .. } + | ast::Instruction::Rsqrt { data, .. } + | ast::Instruction::Neg { data, .. } + | ast::Instruction::Ex2 { data, .. } => { + data.flush_to_zero.map(|ftz| (ftz, data.type_.size_of())) + } + ast::Instruction::Min { + data: ast::MinMaxDetails::Float(float_control), + .. + } + | ast::Instruction::Max { + data: ast::MinMaxDetails::Float(float_control), + .. + } => float_control + .flush_to_zero + .map(|ftz| (ftz, ast::ScalarType::from(float_control.type_).size_of())), + ast::Instruction::Sqrt { data, .. } | ast::Instruction::Rcp { data, .. } => { + data.flush_to_zero.map(|ftz| (ftz, data.type_.size_of())) + } + // Modifier .ftz can only be specified when either .dtype or .atype + // is .f32 and applies only to single precision (.f32) inputs and results. + ast::Instruction::Cvt { + data: + ast::CvtDetails { + mode: + ast::CvtMode::FPExtend { flush_to_zero } + | ast::CvtMode::FPTruncate { flush_to_zero, .. } + | ast::CvtMode::FPRound { flush_to_zero, .. } + | ast::CvtMode::SignedFromFP { flush_to_zero, .. } + | ast::CvtMode::UnsignedFromFP { flush_to_zero, .. }, + .. + }, + .. + } => flush_to_zero.map(|ftz| (ftz, 4)), + ast::Instruction::Div { + data: + ast::DivDetails::Float(ast::DivFloatDetails { + type_, + flush_to_zero, + .. + }), + .. + } => flush_to_zero.map(|ftz| (ftz, type_.size_of())), + ast::Instruction::Sin { data, .. } + | ast::Instruction::Cos { data, .. } + | ast::Instruction::Lg2 { data, .. } => { + Some((data.flush_to_zero, mem::size_of::() as u8)) + } + ptx_parser::Instruction::PrmtSlow { .. 
} => None, + ptx_parser::Instruction::Trap {} => None, + } +} + +type DenormCountMap = HashMap; + +fn denorm_count_map_update(map: &mut DenormCountMap, key: T, value: bool) { + let num_value = if value { 1 } else { -1 }; + denorm_count_map_update_impl(map, key, num_value); +} + +fn denorm_count_map_update_impl( + map: &mut DenormCountMap, + key: T, + num_value: isize, +) { + match map.entry(key) { + hash_map::Entry::Occupied(mut counter) => { + *(counter.get_mut()) += num_value; + } + hash_map::Entry::Vacant(entry) => { + entry.insert(num_value); + } + } +} diff --git a/ptx/src/pass/normalize_identifiers.rs b/ptx/src/pass/normalize_identifiers.rs new file mode 100644 index 00000000..b5983453 --- /dev/null +++ b/ptx/src/pass/normalize_identifiers.rs @@ -0,0 +1,80 @@ +use super::*; +use ptx_parser as ast; + +pub(crate) fn run<'input, 'b>( + id_defs: &mut FnStringIdResolver<'input, 'b>, + fn_defs: &GlobalFnDeclResolver<'input, 'b>, + func: Vec>>, +) -> Result, TranslateError> { + for s in func.iter() { + match s { + ast::Statement::Label(id) => { + id_defs.add_def(*id, None, false); + } + _ => (), + } + } + let mut result = Vec::new(); + for s in func { + expand_map_variables(id_defs, fn_defs, &mut result, s)?; + } + Ok(result) +} + +fn expand_map_variables<'a, 'b>( + id_defs: &mut FnStringIdResolver<'a, 'b>, + fn_defs: &GlobalFnDeclResolver<'a, 'b>, + result: &mut Vec, + s: ast::Statement>, +) -> Result<(), TranslateError> { + match s { + ast::Statement::Block(block) => { + id_defs.start_block(); + for s in block { + expand_map_variables(id_defs, fn_defs, result, s)?; + } + id_defs.end_block(); + } + ast::Statement::Label(name) => result.push(Statement::Label(id_defs.get_id(name)?)), + ast::Statement::Instruction(p, i) => result.push(Statement::Instruction(( + p.map(|p| pred_map_variable(p, &mut |id| id_defs.get_id(id))) + .transpose()?, + ast::visit_map(i, &mut |id, + _: Option<(&ast::Type, ast::StateSpace)>, + _: bool, + _: bool| { + id_defs.get_id(id) + })?, + 
))), + ast::Statement::Variable(var) => { + let var_type = var.var.v_type.clone(); + match var.count { + Some(count) => { + for new_id in + id_defs.add_defs(var.var.name, count, var_type, var.var.state_space, true) + { + result.push(Statement::Variable(ast::Variable { + align: var.var.align, + v_type: var.var.v_type.clone(), + state_space: var.var.state_space, + name: new_id, + array_init: var.var.array_init.clone(), + })) + } + } + None => { + let new_id = + id_defs.add_def(var.var.name, Some((var_type, var.var.state_space)), true); + result.push(Statement::Variable(ast::Variable { + align: var.var.align, + v_type: var.var.v_type.clone(), + state_space: var.var.state_space, + name: new_id, + array_init: var.var.array_init, + })); + } + } + } + }; + Ok(()) +} diff --git a/ptx/src/pass/normalize_labels.rs b/ptx/src/pass/normalize_labels.rs new file mode 100644 index 00000000..097d87c7 --- /dev/null +++ b/ptx/src/pass/normalize_labels.rs @@ -0,0 +1,48 @@ +use std::{collections::HashSet, iter}; + +use super::*; + +pub(super) fn run( + func: Vec, + id_def: &mut NumericIdResolver, +) -> Vec { + let mut labels_in_use = HashSet::new(); + for s in func.iter() { + match s { + Statement::Instruction(i) => { + if let Some(target) = jump_target(i) { + labels_in_use.insert(target); + } + } + Statement::Conditional(cond) => { + labels_in_use.insert(cond.if_true); + labels_in_use.insert(cond.if_false); + } + Statement::Variable(..) + | Statement::LoadVar(..) + | Statement::StoreVar(..) + | Statement::RetValue(..) + | Statement::Conversion(..) + | Statement::Constant(..) + | Statement::Label(..) + | Statement::PtrAccess { .. } + | Statement::RepackVector(..) + | Statement::FunctionPointer(..) 
=> {} + } + } + iter::once(Statement::Label(id_def.register_intermediate(None))) + .chain(func.into_iter().filter(|s| match s { + Statement::Label(i) => labels_in_use.contains(i), + _ => true, + })) + .collect::>() +} + +fn jump_target>( + this: &ast::Instruction, +) -> Option { + match this { + ast::Instruction::Bra { arguments } => Some(arguments.src), + _ => None, + } +} diff --git a/ptx/src/pass/normalize_predicates.rs b/ptx/src/pass/normalize_predicates.rs new file mode 100644 index 00000000..c971cfaa --- /dev/null +++ b/ptx/src/pass/normalize_predicates.rs @@ -0,0 +1,44 @@ +use super::*; +use ptx_parser as ast; + +pub(crate) fn run( + func: Vec, + id_def: &mut NumericIdResolver, +) -> Result, TranslateError> { + let mut result = Vec::with_capacity(func.len()); + for s in func { + match s { + Statement::Label(id) => result.push(Statement::Label(id)), + Statement::Instruction((pred, inst)) => { + if let Some(pred) = pred { + let if_true = id_def.register_intermediate(None); + let if_false = id_def.register_intermediate(None); + let folded_bra = match &inst { + ast::Instruction::Bra { arguments, .. 
} => Some(arguments.src), + _ => None, + }; + let mut branch = BrachCondition { + predicate: pred.label, + if_true: folded_bra.unwrap_or(if_true), + if_false, + }; + if pred.not { + std::mem::swap(&mut branch.if_true, &mut branch.if_false); + } + result.push(Statement::Conditional(branch)); + if folded_bra.is_none() { + result.push(Statement::Label(if_true)); + result.push(Statement::Instruction(inst)); + } + result.push(Statement::Label(if_false)); + } else { + result.push(Statement::Instruction(inst)); + } + } + Statement::Variable(var) => result.push(Statement::Variable(var)), + // Blocks are flattened when resolving ids + _ => return Err(error_unreachable()), + } + } + Ok(result) +} diff --git a/ptx/src/test/spirv_run/clz.spvtxt b/ptx/src/test/spirv_run/clz.spvtxt index 9a7f2542..1feb5a0a 100644 --- a/ptx/src/test/spirv_run/clz.spvtxt +++ b/ptx/src/test/spirv_run/clz.spvtxt @@ -7,20 +7,24 @@ OpCapability Int64 OpCapability Float16 OpCapability Float64 - %21 = OpExtInstImport "OpenCL.std" + OpCapability DenormFlushToZero + OpExtension "SPV_KHR_float_controls" + OpExtension "SPV_KHR_no_integer_wrap_decoration" + %22 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL OpEntryPoint Kernel %1 "clz" + OpExecutionMode %1 ContractionOff %void = OpTypeVoid %ulong = OpTypeInt 64 0 - %24 = OpTypeFunction %void %ulong %ulong + %25 = OpTypeFunction %void %ulong %ulong %_ptr_Function_ulong = OpTypePointer Function %ulong %uint = OpTypeInt 32 0 %_ptr_Function_uint = OpTypePointer Function %uint %_ptr_Generic_uint = OpTypePointer Generic %uint - %1 = OpFunction %void None %24 + %1 = OpFunction %void None %25 %7 = OpFunctionParameter %ulong %8 = OpFunctionParameter %ulong - %19 = OpLabel + %20 = OpLabel %2 = OpVariable %_ptr_Function_ulong Function %3 = OpVariable %_ptr_Function_ulong Function %4 = OpVariable %_ptr_Function_ulong Function @@ -37,11 +41,12 @@ %11 = OpLoad %uint %17 Aligned 4 OpStore %6 %11 %14 = OpLoad %uint %6 - %13 = OpExtInst %uint %21 clz %14 + 
%18 = OpExtInst %uint %22 clz %14 + %13 = OpCopyObject %uint %18 OpStore %6 %13 %15 = OpLoad %ulong %5 %16 = OpLoad %uint %6 - %18 = OpConvertUToPtr %_ptr_Generic_uint %15 - OpStore %18 %16 Aligned 4 + %19 = OpConvertUToPtr %_ptr_Generic_uint %15 + OpStore %19 %16 Aligned 4 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/cvt_s16_s8.spvtxt b/ptx/src/test/spirv_run/cvt_s16_s8.spvtxt index 5f4b050a..92322ecc 100644 --- a/ptx/src/test/spirv_run/cvt_s16_s8.spvtxt +++ b/ptx/src/test/spirv_run/cvt_s16_s8.spvtxt @@ -7,6 +7,9 @@ OpCapability Int64 OpCapability Float16 OpCapability Float64 + OpCapability DenormFlushToZero + OpExtension "SPV_KHR_float_controls" + OpExtension "SPV_KHR_no_integer_wrap_decoration" %24 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL OpEntryPoint Kernel %1 "cvt_s16_s8" @@ -45,9 +48,7 @@ %32 = OpBitcast %uint %15 %34 = OpUConvert %uchar %32 %20 = OpCopyObject %uchar %34 - %35 = OpBitcast %uchar %20 - %37 = OpSConvert %ushort %35 - %19 = OpCopyObject %ushort %37 + %19 = OpSConvert %ushort %20 %14 = OpSConvert %uint %19 OpStore %6 %14 %16 = OpLoad %ulong %5 diff --git a/ptx/src/test/spirv_run/cvt_s64_s32.spvtxt b/ptx/src/test/spirv_run/cvt_s64_s32.spvtxt index 3f461034..11652905 100644 --- a/ptx/src/test/spirv_run/cvt_s64_s32.spvtxt +++ b/ptx/src/test/spirv_run/cvt_s64_s32.spvtxt @@ -7,9 +7,13 @@ OpCapability Int64 OpCapability Float16 OpCapability Float64 + OpCapability DenormFlushToZero + OpExtension "SPV_KHR_float_controls" + OpExtension "SPV_KHR_no_integer_wrap_decoration" %24 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL OpEntryPoint Kernel %1 "cvt_s64_s32" + OpExecutionMode %1 ContractionOff %void = OpTypeVoid %ulong = OpTypeInt 64 0 %27 = OpTypeFunction %void %ulong %ulong @@ -40,9 +44,7 @@ %12 = OpCopyObject %uint %18 OpStore %6 %12 %15 = OpLoad %uint %6 - %32 = OpBitcast %uint %15 - %33 = OpSConvert %ulong %32 - %14 = OpCopyObject %ulong %33 + %14 = OpSConvert %ulong %15 OpStore %7 %14 %16 = OpLoad 
%ulong %5 %17 = OpLoad %ulong %7 diff --git a/ptx/src/test/spirv_run/cvt_sat_s_u.spvtxt b/ptx/src/test/spirv_run/cvt_sat_s_u.spvtxt index b6760499..07b228e8 100644 --- a/ptx/src/test/spirv_run/cvt_sat_s_u.spvtxt +++ b/ptx/src/test/spirv_run/cvt_sat_s_u.spvtxt @@ -7,9 +7,13 @@ OpCapability Int64 OpCapability Float16 OpCapability Float64 + OpCapability DenormFlushToZero + OpExtension "SPV_KHR_float_controls" + OpExtension "SPV_KHR_no_integer_wrap_decoration" %25 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL OpEntryPoint Kernel %1 "cvt_sat_s_u" + OpExecutionMode %1 ContractionOff %void = OpTypeVoid %ulong = OpTypeInt 64 0 %28 = OpTypeFunction %void %ulong %ulong @@ -42,7 +46,7 @@ %15 = OpSatConvertSToU %uint %16 OpStore %7 %15 %18 = OpLoad %uint %7 - %17 = OpBitcast %uint %18 + %17 = OpCopyObject %uint %18 OpStore %8 %17 %19 = OpLoad %ulong %5 %20 = OpLoad %uint %8 diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index f5dfa640..a798720b 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -1,3 +1,4 @@ +use crate::pass; use crate::ptx; use crate::translate; use hip_runtime_sys::hipError_t; @@ -385,10 +386,8 @@ fn test_spvtxt_assert<'a>( spirv_txt: &'a [u8], spirv_file_name: &'a str, ) -> Result<(), Box> { - let mut errors = Vec::new(); - let ast = ptx::ModuleParser::new().parse(&mut errors, ptx_txt)?; - assert!(errors.len() == 0); - let spirv_module = translate::to_spirv_module(ast)?; + let ast = ptx_parser::parse_module_checked(ptx_txt).unwrap(); + let spirv_module = pass::to_spirv_module(ast)?; let spv_context = unsafe { spirv_tools::spvContextCreate(spv_target_env::SPV_ENV_UNIVERSAL_1_3) }; assert!(spv_context != ptr::null_mut()); diff --git a/ptx/src/test/spirv_run/popc.spvtxt b/ptx/src/test/spirv_run/popc.spvtxt index 845add7a..c41e7926 100644 --- a/ptx/src/test/spirv_run/popc.spvtxt +++ b/ptx/src/test/spirv_run/popc.spvtxt @@ -7,20 +7,24 @@ OpCapability Int64 OpCapability Float16 
OpCapability Float64 - %21 = OpExtInstImport "OpenCL.std" + OpCapability DenormFlushToZero + OpExtension "SPV_KHR_float_controls" + OpExtension "SPV_KHR_no_integer_wrap_decoration" + %22 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL OpEntryPoint Kernel %1 "popc" + OpExecutionMode %1 ContractionOff %void = OpTypeVoid %ulong = OpTypeInt 64 0 - %24 = OpTypeFunction %void %ulong %ulong + %25 = OpTypeFunction %void %ulong %ulong %_ptr_Function_ulong = OpTypePointer Function %ulong %uint = OpTypeInt 32 0 %_ptr_Function_uint = OpTypePointer Function %uint %_ptr_Generic_uint = OpTypePointer Generic %uint - %1 = OpFunction %void None %24 + %1 = OpFunction %void None %25 %7 = OpFunctionParameter %ulong %8 = OpFunctionParameter %ulong - %19 = OpLabel + %20 = OpLabel %2 = OpVariable %_ptr_Function_ulong Function %3 = OpVariable %_ptr_Function_ulong Function %4 = OpVariable %_ptr_Function_ulong Function @@ -37,11 +41,12 @@ %11 = OpLoad %uint %17 Aligned 4 OpStore %6 %11 %14 = OpLoad %uint %6 - %13 = OpBitCount %uint %14 + %18 = OpBitCount %uint %14 + %13 = OpCopyObject %uint %18 OpStore %6 %13 %15 = OpLoad %ulong %5 %16 = OpLoad %uint %6 - %18 = OpConvertUToPtr %_ptr_Generic_uint %15 - OpStore %18 %16 Aligned 4 + %19 = OpConvertUToPtr %_ptr_Generic_uint %15 + OpStore %19 %16 Aligned 4 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/vector.ptx b/ptx/src/test/spirv_run/vector.ptx index 90b8ad30..ba07e15e 100644 --- a/ptx/src/test/spirv_run/vector.ptx +++ b/ptx/src/test/spirv_run/vector.ptx @@ -1,4 +1,4 @@ -// Excersise as many features of vector types as possible +// Exercise as many features of vector types as possible .version 6.5 .target sm_60 diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index db1063b6..9b422fda 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -1608,17 +1608,13 @@ fn extract_globals<'input, 'b>( for statement in sorted_statements { match statement { Statement::Variable( - var - @ - ast::Variable { + var @ 
ast::Variable { state_space: ast::StateSpace::Shared, .. }, ) | Statement::Variable( - var - @ - ast::Variable { + var @ ast::Variable { state_space: ast::StateSpace::Global, .. }, @@ -1660,9 +1656,7 @@ fn extract_globals<'input, 'b>( )?); } Statement::Instruction(ast::Instruction::Atom( - details - @ - ast::AtomDetails { + details @ ast::AtomDetails { inner: ast::AtomInnerDetails::Unsigned { op: ast::AtomUIntOp::Inc, @@ -1691,9 +1685,7 @@ fn extract_globals<'input, 'b>( )?); } Statement::Instruction(ast::Instruction::Atom( - details - @ - ast::AtomDetails { + details @ ast::AtomDetails { inner: ast::AtomInnerDetails::Unsigned { op: ast::AtomUIntOp::Dec, @@ -1722,9 +1714,7 @@ fn extract_globals<'input, 'b>( )?); } Statement::Instruction(ast::Instruction::Atom( - details - @ - ast::AtomDetails { + details @ ast::AtomDetails { inner: ast::AtomInnerDetails::Float { op: ast::AtomFloatOp::Add, diff --git a/ptx_parser/Cargo.toml b/ptx_parser/Cargo.toml new file mode 100644 index 00000000..9032de5c --- /dev/null +++ b/ptx_parser/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "ptx_parser" +version = "0.0.0" +authors = ["Andrzej Janik "] +edition = "2021" + +[lib] + +[dependencies] +logos = "0.14" +winnow = { version = "0.6.18" } +#winnow = { version = "0.6.18", features = ["debug"] } +ptx_parser_macros = { path = "../ptx_parser_macros" } +thiserror = "1.0" +bitflags = "1.2" +rustc-hash = "2.0.0" +derive_more = { version = "1", features = ["display"] } diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs new file mode 100644 index 00000000..d0dc303c --- /dev/null +++ b/ptx_parser/src/ast.rs @@ -0,0 +1,1695 @@ +use super::{ + AtomSemantics, MemScope, RawRoundingMode, RawSetpCompareOp, ScalarType, SetpBoolPostOp, + StateSpace, VectorPrefix, +}; +use crate::{PtxError, PtxParserState}; +use bitflags::bitflags; +use std::{cmp::Ordering, num::NonZeroU8}; + +pub enum Statement { + Label(P::Ident), + Variable(MultiVariable), + Instruction(Option>, Instruction

), + Block(Vec>), +} + +// We define the instruction enum through the macro instead of normally, because of how +// we use this type in the compiler. Each instruction can be logically split into two parts: +// properties that define instruction semantics (e.g. is memory load volatile?) that don't change +// during compilation and arguments (e.g. memory load source and destination) that evolve during +// compilation. To support compilation passes we need to be able to visit (and change) every +// argument in a generic way. This macro has visibility over all the fields. Consequently, we use it +// to generate visitor functions. There are three functions to support three different semantics: +// visit-by-ref, visit-by-mutable-ref, visit-and-map. In a previous version of the compiler it was +// done by hand and was very limiting (we supported only visit-and-map). +// The visitor must implement the appropriate visitor trait defined below this macro. For convenience, +// we implemented visitors for some corresponding FnMut(...) types. +// Properties in this macro are used to encode information about the instruction arguments (what +// Rust type is used for it post-parsing, what PTX type does it expect, what PTX address space does +// it expect, etc.). +// This information is then available to a visitor.
+ptx_parser_macros::generate_instruction_type!( + pub enum Instruction { + Mov { + type: { &data.typ }, + data: MovDetails, + arguments: { + dst: T, + src: T + } + }, + Ld { + type: { &data.typ }, + data: LdDetails, + arguments: { + dst: { + repr: T, + relaxed_type_check: true, + }, + src: { + repr: T, + space: { data.state_space }, + } + } + }, + Add { + type: { Type::from(data.type_()) }, + data: ArithDetails, + arguments: { + dst: T, + src1: T, + src2: T, + } + }, + St { + type: { &data.typ }, + data: StData, + arguments: { + src1: { + repr: T, + space: { data.state_space }, + }, + src2: { + repr: T, + relaxed_type_check: true, + } + } + }, + Mul { + type: { Type::from(data.type_()) }, + data: MulDetails, + arguments: { + dst: { + repr: T, + type: { Type::from(data.dst_type()) }, + }, + src1: T, + src2: T, + } + }, + Setp { + data: SetpData, + arguments: { + dst1: { + repr: T, + type: Type::from(ScalarType::Pred) + }, + dst2: { + repr: Option, + type: Type::from(ScalarType::Pred) + }, + src1: { + repr: T, + type: Type::from(data.type_), + }, + src2: { + repr: T, + type: Type::from(data.type_), + } + } + }, + SetpBool { + data: SetpBoolData, + arguments: { + dst1: { + repr: T, + type: Type::from(ScalarType::Pred) + }, + dst2: { + repr: Option, + type: Type::from(ScalarType::Pred) + }, + src1: { + repr: T, + type: Type::from(data.base.type_), + }, + src2: { + repr: T, + type: Type::from(data.base.type_), + }, + src3: { + repr: T, + type: Type::from(ScalarType::Pred) + } + } + }, + Not { + data: ScalarType, + type: { Type::Scalar(data.clone()) }, + arguments: { + dst: T, + src: T, + } + }, + Or { + data: ScalarType, + type: { Type::Scalar(data.clone()) }, + arguments: { + dst: T, + src1: T, + src2: T, + } + }, + And { + data: ScalarType, + type: { Type::Scalar(data.clone()) }, + arguments: { + dst: T, + src1: T, + src2: T, + } + }, + Bra { + type: !, + arguments: { + src: T + } + }, + Call { + data: CallDetails, + arguments: CallArgs, + visit: arguments.visit(data, 
visitor)?, + visit_mut: arguments.visit_mut(data, visitor)?, + map: Instruction::Call{ arguments: arguments.map(&data, visitor)?, data } + }, + Cvt { + data: CvtDetails, + arguments: { + dst: { + repr: T, + type: { Type::Scalar(data.to) }, + // TODO: double check + relaxed_type_check: true, + }, + src: { + repr: T, + type: { Type::Scalar(data.from) }, + relaxed_type_check: true, + }, + } + }, + Shr { + data: ShrData, + type: { Type::Scalar(data.type_.clone()) }, + arguments: { + dst: T, + src1: T, + src2: { + repr: T, + type: { Type::Scalar(ScalarType::U32) }, + }, + } + }, + Shl { + data: ScalarType, + type: { Type::Scalar(data.clone()) }, + arguments: { + dst: T, + src1: T, + src2: { + repr: T, + type: { Type::Scalar(ScalarType::U32) }, + }, + } + }, + Ret { + data: RetData + }, + Cvta { + data: CvtaDetails, + type: { Type::Scalar(ScalarType::B64) }, + arguments: { + dst: T, + src: T, + } + }, + Abs { + data: TypeFtz, + type: { Type::Scalar(data.type_) }, + arguments: { + dst: T, + src: T, + } + }, + Mad { + type: { Type::from(data.type_()) }, + data: MadDetails, + arguments: { + dst: { + repr: T, + type: { Type::from(data.dst_type()) }, + }, + src1: T, + src2: T, + src3: T, + } + }, + Fma { + type: { Type::from(data.type_) }, + data: ArithFloat, + arguments: { + dst: T, + src1: T, + src2: T, + src3: T, + } + }, + Sub { + type: { Type::from(data.type_()) }, + data: ArithDetails, + arguments: { + dst: T, + src1: T, + src2: T, + } + }, + Min { + type: { Type::from(data.type_()) }, + data: MinMaxDetails, + arguments: { + dst: T, + src1: T, + src2: T, + } + }, + Max { + type: { Type::from(data.type_()) }, + data: MinMaxDetails, + arguments: { + dst: T, + src1: T, + src2: T, + } + }, + Rcp { + type: { Type::from(data.type_) }, + data: RcpData, + arguments: { + dst: T, + src: T, + } + }, + Sqrt { + type: { Type::from(data.type_) }, + data: RcpData, + arguments: { + dst: T, + src: T, + } + }, + Rsqrt { + type: { Type::from(data.type_) }, + data: TypeFtz, + arguments: { 
+ dst: T, + src: T, + } + }, + Selp { + type: { Type::Scalar(data.clone()) }, + data: ScalarType, + arguments: { + dst: T, + src1: T, + src2: T, + src3: { + repr: T, + type: Type::Scalar(ScalarType::Pred) + }, + } + }, + Bar { + type: Type::Scalar(ScalarType::U32), + data: BarData, + arguments: { + src1: T, + src2: Option, + } + }, + Atom { + type: &data.type_, + data: AtomDetails, + arguments: { + dst: T, + src1: { + repr: T, + space: { data.space }, + }, + src2: T, + } + }, + AtomCas { + type: Type::Scalar(data.type_), + data: AtomCasDetails, + arguments: { + dst: T, + src1: { + repr: T, + space: { data.space }, + }, + src2: T, + src3: T, + } + }, + Div { + type: Type::Scalar(data.type_()), + data: DivDetails, + arguments: { + dst: T, + src1: T, + src2: T, + } + }, + Neg { + type: Type::Scalar(data.type_), + data: TypeFtz, + arguments: { + dst: T, + src: T + } + }, + Sin { + type: Type::Scalar(ScalarType::F32), + data: FlushToZero, + arguments: { + dst: T, + src: T + } + }, + Cos { + type: Type::Scalar(ScalarType::F32), + data: FlushToZero, + arguments: { + dst: T, + src: T + } + }, + Lg2 { + type: Type::Scalar(ScalarType::F32), + data: FlushToZero, + arguments: { + dst: T, + src: T + } + }, + Ex2 { + type: Type::Scalar(ScalarType::F32), + data: TypeFtz, + arguments: { + dst: T, + src: T + } + }, + Clz { + type: Type::Scalar(data.clone()), + data: ScalarType, + arguments: { + dst: { + repr: T, + type: Type::Scalar(ScalarType::U32) + }, + src: T + } + }, + Brev { + type: Type::Scalar(data.clone()), + data: ScalarType, + arguments: { + dst: T, + src: T + } + }, + Popc { + type: Type::Scalar(data.clone()), + data: ScalarType, + arguments: { + dst: { + repr: T, + type: Type::Scalar(ScalarType::U32) + }, + src: T + } + }, + Xor { + type: Type::Scalar(data.clone()), + data: ScalarType, + arguments: { + dst: T, + src1: T, + src2: T + } + }, + Rem { + type: Type::Scalar(data.clone()), + data: ScalarType, + arguments: { + dst: T, + src1: T, + src2: T + } + }, + Bfe { + 
type: Type::Scalar(data.clone()), + data: ScalarType, + arguments: { + dst: T, + src1: T, + src2: { + repr: T, + type: Type::Scalar(ScalarType::U32) + }, + src3: { + repr: T, + type: Type::Scalar(ScalarType::U32) + }, + } + }, + Bfi { + type: Type::Scalar(data.clone()), + data: ScalarType, + arguments: { + dst: T, + src1: T, + src2: T, + src3: { + repr: T, + type: Type::Scalar(ScalarType::U32) + }, + src4: { + repr: T, + type: Type::Scalar(ScalarType::U32) + }, + } + }, + PrmtSlow { + type: Type::Scalar(ScalarType::U32), + arguments: { + dst: T, + src1: T, + src2: T, + src3: T + } + }, + Prmt { + type: Type::Scalar(ScalarType::B32), + data: u16, + arguments: { + dst: T, + src1: T, + src2: T + } + }, + Activemask { + type: Type::Scalar(ScalarType::B32), + arguments: { + dst: T + } + }, + Membar { + data: MemScope + }, + Trap { } + } +); + +pub trait Visitor { + fn visit( + &mut self, + args: &T, + type_space: Option<(&Type, StateSpace)>, + is_dst: bool, + relaxed_type_check: bool, + ) -> Result<(), Err>; + fn visit_ident( + &mut self, + args: &T::Ident, + type_space: Option<(&Type, StateSpace)>, + is_dst: bool, + relaxed_type_check: bool, + ) -> Result<(), Err>; +} + +impl< + T: Operand, + Err, + Fn: FnMut(&T, Option<(&Type, StateSpace)>, bool, bool) -> Result<(), Err>, + > Visitor for Fn +{ + fn visit( + &mut self, + args: &T, + type_space: Option<(&Type, StateSpace)>, + is_dst: bool, + relaxed_type_check: bool, + ) -> Result<(), Err> { + (self)(args, type_space, is_dst, relaxed_type_check) + } + + fn visit_ident( + &mut self, + args: &T::Ident, + type_space: Option<(&Type, StateSpace)>, + is_dst: bool, + relaxed_type_check: bool, + ) -> Result<(), Err> { + (self)( + &T::from_ident(*args), + type_space, + is_dst, + relaxed_type_check, + ) + } +} + +pub trait VisitorMut { + fn visit( + &mut self, + args: &mut T, + type_space: Option<(&Type, StateSpace)>, + is_dst: bool, + relaxed_type_check: bool, + ) -> Result<(), Err>; + fn visit_ident( + &mut self, + args: &mut 
T::Ident, + type_space: Option<(&Type, StateSpace)>, + is_dst: bool, + relaxed_type_check: bool, + ) -> Result<(), Err>; +} + +pub trait VisitorMap { + fn visit( + &mut self, + args: From, + type_space: Option<(&Type, StateSpace)>, + is_dst: bool, + relaxed_type_check: bool, + ) -> Result; + fn visit_ident( + &mut self, + args: From::Ident, + type_space: Option<(&Type, StateSpace)>, + is_dst: bool, + relaxed_type_check: bool, + ) -> Result; +} + +impl VisitorMap, ParsedOperand, Err> for Fn +where + Fn: FnMut(T, Option<(&Type, StateSpace)>, bool, bool) -> Result, +{ + fn visit( + &mut self, + args: ParsedOperand, + type_space: Option<(&Type, StateSpace)>, + is_dst: bool, + relaxed_type_check: bool, + ) -> Result, Err> { + Ok(match args { + ParsedOperand::Reg(ident) => { + ParsedOperand::Reg((self)(ident, type_space, is_dst, relaxed_type_check)?) + } + ParsedOperand::RegOffset(ident, imm) => ParsedOperand::RegOffset( + (self)(ident, type_space, is_dst, relaxed_type_check)?, + imm, + ), + ParsedOperand::Imm(imm) => ParsedOperand::Imm(imm), + ParsedOperand::VecMember(ident, index) => ParsedOperand::VecMember( + (self)(ident, type_space, is_dst, relaxed_type_check)?, + index, + ), + ParsedOperand::VecPack(vec) => ParsedOperand::VecPack( + vec.into_iter() + .map(|ident| (self)(ident, type_space, is_dst, relaxed_type_check)) + .collect::, _>>()?, + ), + }) + } + + fn visit_ident( + &mut self, + args: T, + type_space: Option<(&Type, StateSpace)>, + is_dst: bool, + relaxed_type_check: bool, + ) -> Result { + (self)(args, type_space, is_dst, relaxed_type_check) + } +} + +impl, U: Operand, Err, Fn> VisitorMap for Fn +where + Fn: FnMut(T, Option<(&Type, StateSpace)>, bool, bool) -> Result, +{ + fn visit( + &mut self, + args: T, + type_space: Option<(&Type, StateSpace)>, + is_dst: bool, + relaxed_type_check: bool, + ) -> Result { + (self)(args, type_space, is_dst, relaxed_type_check) + } + + fn visit_ident( + &mut self, + args: T, + type_space: Option<(&Type, StateSpace)>, + 
is_dst: bool, + relaxed_type_check: bool, + ) -> Result { + (self)(args, type_space, is_dst, relaxed_type_check) + } +} + +trait VisitOperand { + type Operand: Operand; + #[allow(unused)] // Used by generated code + fn visit(&self, fn_: impl FnMut(&Self::Operand) -> Result<(), Err>) -> Result<(), Err>; + #[allow(unused)] // Used by generated code + fn visit_mut( + &mut self, + fn_: impl FnMut(&mut Self::Operand) -> Result<(), Err>, + ) -> Result<(), Err>; +} + +impl VisitOperand for T { + type Operand = Self; + fn visit(&self, mut fn_: impl FnMut(&Self::Operand) -> Result<(), Err>) -> Result<(), Err> { + fn_(self) + } + fn visit_mut( + &mut self, + mut fn_: impl FnMut(&mut Self::Operand) -> Result<(), Err>, + ) -> Result<(), Err> { + fn_(self) + } +} + +impl VisitOperand for Option { + type Operand = T; + fn visit(&self, mut fn_: impl FnMut(&Self::Operand) -> Result<(), Err>) -> Result<(), Err> { + if let Some(x) = self { + fn_(x)?; + } + Ok(()) + } + fn visit_mut( + &mut self, + mut fn_: impl FnMut(&mut Self::Operand) -> Result<(), Err>, + ) -> Result<(), Err> { + if let Some(x) = self { + fn_(x)?; + } + Ok(()) + } +} + +impl VisitOperand for Vec { + type Operand = T; + fn visit(&self, mut fn_: impl FnMut(&Self::Operand) -> Result<(), Err>) -> Result<(), Err> { + for o in self { + fn_(o)?; + } + Ok(()) + } + fn visit_mut( + &mut self, + mut fn_: impl FnMut(&mut Self::Operand) -> Result<(), Err>, + ) -> Result<(), Err> { + for o in self { + fn_(o)?; + } + Ok(()) + } +} + +trait MapOperand: Sized { + type Input; + type Output; + #[allow(unused)] // Used by generated code + fn map( + self, + fn_: impl FnOnce(Self::Input) -> Result, + ) -> Result, Err>; +} + +impl MapOperand for T { + type Input = Self; + type Output = U; + fn map(self, fn_: impl FnOnce(T) -> Result) -> Result { + fn_(self) + } +} + +impl MapOperand for Option { + type Input = T; + type Output = Option; + fn map(self, fn_: impl FnOnce(T) -> Result) -> Result, Err> { + self.map(|x| fn_(x)).transpose() 
+ } +} + +pub struct MultiVariable { + pub var: Variable, + pub count: Option, +} + +#[derive(Clone)] +pub struct Variable { + pub align: Option, + pub v_type: Type, + pub state_space: StateSpace, + pub name: ID, + pub array_init: Vec, +} + +pub struct PredAt { + pub not: bool, + pub label: ID, +} + +#[derive(PartialEq, Eq, Clone, Hash)] +pub enum Type { + // .param.b32 foo; + Scalar(ScalarType), + // .param.v2.b32 foo; + Vector(u8, ScalarType), + // .param.b32 foo[4]; + Array(Option, ScalarType, Vec), + Pointer(ScalarType, StateSpace), +} + +impl Type { + pub(crate) fn maybe_vector(vector: Option, scalar: ScalarType) -> Self { + match vector { + Some(prefix) => Type::Vector(prefix.len().get(), scalar), + None => Type::Scalar(scalar), + } + } + + pub(crate) fn maybe_vector_parsed(prefix: Option, scalar: ScalarType) -> Self { + match prefix { + Some(prefix) => Type::Vector(prefix.get(), scalar), + None => Type::Scalar(scalar), + } + } + + pub(crate) fn maybe_array( + prefix: Option, + scalar: ScalarType, + array: Option>, + ) -> Self { + match array { + Some(dimensions) => Type::Array(prefix, scalar, dimensions), + None => Self::maybe_vector_parsed(prefix, scalar), + } + } +} + +impl ScalarType { + pub fn size_of(self) -> u8 { + match self { + ScalarType::U8 | ScalarType::S8 | ScalarType::B8 => 1, + ScalarType::U16 + | ScalarType::S16 + | ScalarType::B16 + | ScalarType::F16 + | ScalarType::BF16 => 2, + ScalarType::U32 + | ScalarType::S32 + | ScalarType::B32 + | ScalarType::F32 + | ScalarType::U16x2 + | ScalarType::S16x2 + | ScalarType::F16x2 + | ScalarType::BF16x2 => 4, + ScalarType::U64 | ScalarType::S64 | ScalarType::B64 | ScalarType::F64 => 8, + ScalarType::B128 => 16, + ScalarType::Pred => 1, + } + } + + pub fn kind(self) -> ScalarKind { + match self { + ScalarType::U8 => ScalarKind::Unsigned, + ScalarType::U16 => ScalarKind::Unsigned, + ScalarType::U16x2 => ScalarKind::Unsigned, + ScalarType::U32 => ScalarKind::Unsigned, + ScalarType::U64 => 
ScalarKind::Unsigned, + ScalarType::S8 => ScalarKind::Signed, + ScalarType::S16 => ScalarKind::Signed, + ScalarType::S16x2 => ScalarKind::Signed, + ScalarType::S32 => ScalarKind::Signed, + ScalarType::S64 => ScalarKind::Signed, + ScalarType::B8 => ScalarKind::Bit, + ScalarType::B16 => ScalarKind::Bit, + ScalarType::B32 => ScalarKind::Bit, + ScalarType::B64 => ScalarKind::Bit, + ScalarType::B128 => ScalarKind::Bit, + ScalarType::F16 => ScalarKind::Float, + ScalarType::F16x2 => ScalarKind::Float, + ScalarType::F32 => ScalarKind::Float, + ScalarType::F64 => ScalarKind::Float, + ScalarType::BF16 => ScalarKind::Float, + ScalarType::BF16x2 => ScalarKind::Float, + ScalarType::Pred => ScalarKind::Pred, + } + } +} + +#[derive(Clone, Copy, PartialEq, Eq)] +pub enum ScalarKind { + Bit, + Unsigned, + Signed, + Float, + Pred, +} +impl From for Type { + fn from(value: ScalarType) -> Self { + Type::Scalar(value) + } +} + +#[derive(Clone)] +pub struct MovDetails { + pub typ: super::Type, + pub src_is_address: bool, + // two fields below are in use by member moves + pub dst_width: u8, + pub src_width: u8, + // This is in use by auto-generated movs + pub relaxed_src2_conv: bool, +} + +impl MovDetails { + pub(crate) fn new(vector: Option, scalar: ScalarType) -> Self { + MovDetails { + typ: Type::maybe_vector(vector, scalar), + src_is_address: false, + dst_width: 0, + src_width: 0, + relaxed_src2_conv: false, + } + } +} + +#[derive(Clone)] +pub enum ParsedOperand { + Reg(Ident), + RegOffset(Ident, i32), + Imm(ImmediateValue), + VecMember(Ident, u8), + VecPack(Vec), +} + +impl Operand for ParsedOperand { + type Ident = Ident; + + fn from_ident(ident: Self::Ident) -> Self { + ParsedOperand::Reg(ident) + } +} + +pub trait Operand: Sized { + type Ident: Copy; + + fn from_ident(ident: Self::Ident) -> Self; +} + +#[derive(Copy, Clone)] +pub enum ImmediateValue { + U64(u64), + S64(i64), + F32(f32), + F64(f64), +} + +#[derive(Copy, Clone, PartialEq, Eq)] +pub enum StCacheOperator { + 
Writeback, + L2Only, + Streaming, + Writethrough, +} + +#[derive(Copy, Clone, PartialEq, Eq)] +pub enum LdCacheOperator { + Cached, + L2Only, + Streaming, + LastUse, + Uncached, +} + +#[derive(Copy, Clone)] +pub enum ArithDetails { + Integer(ArithInteger), + Float(ArithFloat), +} + +impl ArithDetails { + pub fn type_(&self) -> ScalarType { + match self { + ArithDetails::Integer(t) => t.type_, + ArithDetails::Float(arith) => arith.type_, + } + } +} + +#[derive(Copy, Clone)] +pub struct ArithInteger { + pub type_: ScalarType, + pub saturate: bool, +} + +#[derive(Copy, Clone)] +pub struct ArithFloat { + pub type_: ScalarType, + pub rounding: Option, + pub flush_to_zero: Option, + pub saturate: bool, +} + +#[derive(Copy, Clone, PartialEq, Eq)] +pub enum LdStQualifier { + Weak, + Volatile, + Relaxed(MemScope), + Acquire(MemScope), + Release(MemScope), +} + +#[derive(PartialEq, Eq, Copy, Clone)] +pub enum RoundingMode { + NearestEven, + Zero, + NegativeInf, + PositiveInf, +} + +pub struct LdDetails { + pub qualifier: LdStQualifier, + pub state_space: StateSpace, + pub caching: LdCacheOperator, + pub typ: Type, + pub non_coherent: bool, +} + +pub struct StData { + pub qualifier: LdStQualifier, + pub state_space: StateSpace, + pub caching: StCacheOperator, + pub typ: Type, +} + +#[derive(Copy, Clone)] +pub struct RetData { + pub uniform: bool, +} + +#[derive(Copy, Clone, PartialEq, Eq)] +pub enum TuningDirective { + MaxNReg(u32), + MaxNtid(u32, u32, u32), + ReqNtid(u32, u32, u32), + MinNCtaPerSm(u32), +} + +pub struct MethodDeclaration<'input, ID> { + pub return_arguments: Vec>, + pub name: MethodName<'input, ID>, + pub input_arguments: Vec>, + pub shared_mem: Option, +} + +impl<'input> MethodDeclaration<'input, &'input str> { + pub fn name(&self) -> &'input str { + match self.name { + MethodName::Kernel(n) => n, + MethodName::Func(n) => n, + } + } +} + +#[derive(Hash, PartialEq, Eq, Copy, Clone)] +pub enum MethodName<'input, ID> { + Kernel(&'input str), + Func(ID), +} + 
+bitflags! { + pub struct LinkingDirective: u8 { + const NONE = 0b000; + const EXTERN = 0b001; + const VISIBLE = 0b10; + const WEAK = 0b100; + } +} + +pub struct Function<'a, ID, S> { + pub func_directive: MethodDeclaration<'a, ID>, + pub tuning: Vec, + pub body: Option>, +} + +pub enum Directive<'input, O: Operand> { + Variable(LinkingDirective, Variable), + Method( + LinkingDirective, + Function<'input, &'input str, Statement>, + ), +} + +pub struct Module<'input> { + pub version: (u8, u8), + pub directives: Vec>>, +} + +#[derive(Copy, Clone)] +pub enum MulDetails { + Integer { + type_: ScalarType, + control: MulIntControl, + }, + Float(ArithFloat), +} + +impl MulDetails { + pub fn type_(&self) -> ScalarType { + match self { + MulDetails::Integer { type_, .. } => *type_, + MulDetails::Float(arith) => arith.type_, + } + } + + pub fn dst_type(&self) -> ScalarType { + match self { + MulDetails::Integer { + type_, + control: MulIntControl::Wide, + } => match type_ { + ScalarType::U16 => ScalarType::U32, + ScalarType::S16 => ScalarType::S32, + ScalarType::U32 => ScalarType::U64, + ScalarType::S32 => ScalarType::S64, + _ => unreachable!(), + }, + _ => self.type_(), + } + } +} + +#[derive(Copy, Clone, PartialEq, Eq)] +pub enum MulIntControl { + Low, + High, + Wide, +} + +pub struct SetpData { + pub type_: ScalarType, + pub flush_to_zero: Option, + pub cmp_op: SetpCompareOp, +} + +impl SetpData { + pub(crate) fn try_parse( + state: &mut PtxParserState, + cmp_op: super::RawSetpCompareOp, + ftz: bool, + type_: ScalarType, + ) -> Self { + let flush_to_zero = match (ftz, type_) { + (_, ScalarType::F32) => Some(ftz), + (true, _) => { + state.errors.push(PtxError::NonF32Ftz); + None + } + _ => None + }; + let type_kind = type_.kind(); + let cmp_op = if type_kind == ScalarKind::Float { + SetpCompareOp::Float(SetpCompareFloat::from(cmp_op)) + } else { + match SetpCompareInt::try_from((cmp_op, type_kind)) { + Ok(op) => SetpCompareOp::Integer(op), + Err(err) => { + 
state.errors.push(err); + SetpCompareOp::Integer(SetpCompareInt::Eq) + } + } + }; + Self { + type_, + flush_to_zero, + cmp_op, + } + } +} + +pub struct SetpBoolData { + pub base: SetpData, + pub bool_op: SetpBoolPostOp, + pub negate_src3: bool, +} + +#[derive(PartialEq, Eq, Copy, Clone)] +pub enum SetpCompareOp { + Integer(SetpCompareInt), + Float(SetpCompareFloat), +} + +#[derive(PartialEq, Eq, Copy, Clone)] +pub enum SetpCompareInt { + Eq, + NotEq, + UnsignedLess, + UnsignedLessOrEq, + UnsignedGreater, + UnsignedGreaterOrEq, + SignedLess, + SignedLessOrEq, + SignedGreater, + SignedGreaterOrEq, +} + +#[derive(PartialEq, Eq, Copy, Clone)] +pub enum SetpCompareFloat { + Eq, + NotEq, + Less, + LessOrEq, + Greater, + GreaterOrEq, + NanEq, + NanNotEq, + NanLess, + NanLessOrEq, + NanGreater, + NanGreaterOrEq, + IsNotNan, + IsAnyNan, +} + +impl TryFrom<(RawSetpCompareOp, ScalarKind)> for SetpCompareInt { + type Error = PtxError; + + fn try_from((value, kind): (RawSetpCompareOp, ScalarKind)) -> Result { + match (value, kind) { + (RawSetpCompareOp::Eq, _) => Ok(SetpCompareInt::Eq), + (RawSetpCompareOp::Ne, _) => Ok(SetpCompareInt::NotEq), + (RawSetpCompareOp::Lt | RawSetpCompareOp::Lo, ScalarKind::Signed) => { + Ok(SetpCompareInt::SignedLess) + } + (RawSetpCompareOp::Lt | RawSetpCompareOp::Lo, _) => Ok(SetpCompareInt::UnsignedLess), + (RawSetpCompareOp::Le | RawSetpCompareOp::Ls, ScalarKind::Signed) => { + Ok(SetpCompareInt::SignedLessOrEq) + } + (RawSetpCompareOp::Le | RawSetpCompareOp::Ls, _) => { + Ok(SetpCompareInt::UnsignedLessOrEq) + } + (RawSetpCompareOp::Gt | RawSetpCompareOp::Hi, ScalarKind::Signed) => { + Ok(SetpCompareInt::SignedGreater) + } + (RawSetpCompareOp::Gt | RawSetpCompareOp::Hi, _) => Ok(SetpCompareInt::UnsignedGreater), + (RawSetpCompareOp::Ge | RawSetpCompareOp::Hs, ScalarKind::Signed) => { + Ok(SetpCompareInt::SignedGreaterOrEq) + } + (RawSetpCompareOp::Ge | RawSetpCompareOp::Hs, _) => { + Ok(SetpCompareInt::UnsignedGreaterOrEq) + } + 
(RawSetpCompareOp::Equ, _) => Err(PtxError::WrongType), + (RawSetpCompareOp::Neu, _) => Err(PtxError::WrongType), + (RawSetpCompareOp::Ltu, _) => Err(PtxError::WrongType), + (RawSetpCompareOp::Leu, _) => Err(PtxError::WrongType), + (RawSetpCompareOp::Gtu, _) => Err(PtxError::WrongType), + (RawSetpCompareOp::Geu, _) => Err(PtxError::WrongType), + (RawSetpCompareOp::Num, _) => Err(PtxError::WrongType), + (RawSetpCompareOp::Nan, _) => Err(PtxError::WrongType), + } + } +} + +impl From for SetpCompareFloat { + fn from(value: RawSetpCompareOp) -> Self { + match value { + RawSetpCompareOp::Eq => SetpCompareFloat::Eq, + RawSetpCompareOp::Ne => SetpCompareFloat::NotEq, + RawSetpCompareOp::Lt => SetpCompareFloat::Less, + RawSetpCompareOp::Le => SetpCompareFloat::LessOrEq, + RawSetpCompareOp::Gt => SetpCompareFloat::Greater, + RawSetpCompareOp::Ge => SetpCompareFloat::GreaterOrEq, + RawSetpCompareOp::Lo => SetpCompareFloat::Less, + RawSetpCompareOp::Ls => SetpCompareFloat::LessOrEq, + RawSetpCompareOp::Hi => SetpCompareFloat::Greater, + RawSetpCompareOp::Hs => SetpCompareFloat::GreaterOrEq, + RawSetpCompareOp::Equ => SetpCompareFloat::NanEq, + RawSetpCompareOp::Neu => SetpCompareFloat::NanNotEq, + RawSetpCompareOp::Ltu => SetpCompareFloat::NanLess, + RawSetpCompareOp::Leu => SetpCompareFloat::NanLessOrEq, + RawSetpCompareOp::Gtu => SetpCompareFloat::NanGreater, + RawSetpCompareOp::Geu => SetpCompareFloat::NanGreaterOrEq, + RawSetpCompareOp::Num => SetpCompareFloat::IsNotNan, + RawSetpCompareOp::Nan => SetpCompareFloat::IsAnyNan, + } + } +} + +pub struct CallDetails { + pub uniform: bool, + pub return_arguments: Vec<(Type, StateSpace)>, + pub input_arguments: Vec<(Type, StateSpace)>, +} + +pub struct CallArgs { + pub return_arguments: Vec, + pub func: T::Ident, + pub input_arguments: Vec, +} + +impl CallArgs { + #[allow(dead_code)] // Used by generated code + fn visit( + &self, + details: &CallDetails, + visitor: &mut impl Visitor, + ) -> Result<(), Err> { + for (param, 
(type_, space)) in self + .return_arguments + .iter() + .zip(details.return_arguments.iter()) + { + visitor.visit_ident(param, Some((type_, *space)), true, false)?; + } + visitor.visit_ident(&self.func, None, false, false)?; + for (param, (type_, space)) in self + .input_arguments + .iter() + .zip(details.input_arguments.iter()) + { + visitor.visit(param, Some((type_, *space)), false, false)?; + } + Ok(()) + } + + #[allow(dead_code)] // Used by generated code + fn visit_mut( + &mut self, + details: &CallDetails, + visitor: &mut impl VisitorMut, + ) -> Result<(), Err> { + for (param, (type_, space)) in self + .return_arguments + .iter_mut() + .zip(details.return_arguments.iter()) + { + visitor.visit_ident(param, Some((type_, *space)), true, false)?; + } + visitor.visit_ident(&mut self.func, None, false, false)?; + for (param, (type_, space)) in self + .input_arguments + .iter_mut() + .zip(details.input_arguments.iter()) + { + visitor.visit(param, Some((type_, *space)), false, false)?; + } + Ok(()) + } + + #[allow(dead_code)] // Used by generated code + fn map( + self, + details: &CallDetails, + visitor: &mut impl VisitorMap, + ) -> Result, Err> { + let return_arguments = self + .return_arguments + .into_iter() + .zip(details.return_arguments.iter()) + .map(|(param, (type_, space))| { + visitor.visit_ident(param, Some((type_, *space)), true, false) + }) + .collect::, _>>()?; + let func = visitor.visit_ident(self.func, None, false, false)?; + let input_arguments = self + .input_arguments + .into_iter() + .zip(details.input_arguments.iter()) + .map(|(param, (type_, space))| { + visitor.visit(param, Some((type_, *space)), false, false) + }) + .collect::, _>>()?; + Ok(CallArgs { + return_arguments, + func, + input_arguments, + }) + } +} + +pub struct CvtDetails { + pub from: ScalarType, + pub to: ScalarType, + pub mode: CvtMode, +} + +pub enum CvtMode { + // int from int + ZeroExtend, + SignExtend, + Truncate, + Bitcast, + SaturateUnsignedToSigned, + 
SaturateSignedToUnsigned, + // float from float + FPExtend { + flush_to_zero: Option, + }, + FPTruncate { + // float rounding + rounding: RoundingMode, + flush_to_zero: Option, + }, + FPRound { + integer_rounding: Option, + flush_to_zero: Option, + }, + // int from float + SignedFromFP { + rounding: RoundingMode, + flush_to_zero: Option, + }, // integer rounding + UnsignedFromFP { + rounding: RoundingMode, + flush_to_zero: Option, + }, // integer rounding + // float from int, ftz is allowed in the grammar, but clearly nonsensical + FPFromSigned(RoundingMode), // float rounding + FPFromUnsigned(RoundingMode), // float rounding +} + +impl CvtDetails { + pub(crate) fn new( + errors: &mut Vec, + rnd: Option, + ftz: bool, + saturate: bool, + dst: ScalarType, + src: ScalarType, + ) -> Self { + if saturate && dst.kind() == ScalarKind::Float { + errors.push(PtxError::SyntaxError); + } + // Modifier .ftz can only be specified when either .dtype or .atype is .f32 and applies only to single precision (.f32) inputs and results. 
+ let flush_to_zero = match (dst, src) { + (ScalarType::F32, _) | (_, ScalarType::F32) => Some(ftz), + _ => { + if ftz { + errors.push(PtxError::NonF32Ftz); + } + None + } + }; + let rounding = rnd.map(Into::into); + let mut unwrap_rounding = || match rounding { + Some(rnd) => rnd, + None => { + errors.push(PtxError::SyntaxError); + RoundingMode::NearestEven + } + }; + let mode = match (dst.kind(), src.kind()) { + (ScalarKind::Float, ScalarKind::Float) => match dst.size_of().cmp(&src.size_of()) { + Ordering::Less => CvtMode::FPTruncate { + rounding: unwrap_rounding(), + flush_to_zero, + }, + Ordering::Equal => CvtMode::FPRound { + integer_rounding: rounding, + flush_to_zero, + }, + Ordering::Greater => { + if rounding.is_some() { + errors.push(PtxError::SyntaxError); + } + CvtMode::FPExtend { flush_to_zero } + } + }, + (ScalarKind::Unsigned, ScalarKind::Float) => CvtMode::UnsignedFromFP { + rounding: unwrap_rounding(), + flush_to_zero, + }, + (ScalarKind::Signed, ScalarKind::Float) => CvtMode::SignedFromFP { + rounding: unwrap_rounding(), + flush_to_zero, + }, + (ScalarKind::Float, ScalarKind::Unsigned) => CvtMode::FPFromUnsigned(unwrap_rounding()), + (ScalarKind::Float, ScalarKind::Signed) => CvtMode::FPFromSigned(unwrap_rounding()), + (ScalarKind::Signed, ScalarKind::Unsigned) if saturate => { + CvtMode::SaturateUnsignedToSigned + } + (ScalarKind::Unsigned, ScalarKind::Signed) if saturate => { + CvtMode::SaturateSignedToUnsigned + } + (ScalarKind::Unsigned, ScalarKind::Signed) + | (ScalarKind::Signed, ScalarKind::Unsigned) + if dst.size_of() == src.size_of() => + { + CvtMode::Bitcast + } + (ScalarKind::Unsigned, ScalarKind::Unsigned) + | (ScalarKind::Signed, ScalarKind::Signed) => match dst.size_of().cmp(&src.size_of()) { + Ordering::Less => CvtMode::Truncate, + Ordering::Equal => CvtMode::Bitcast, + Ordering::Greater => { + if src.kind() == ScalarKind::Signed { + CvtMode::SignExtend + } else { + CvtMode::ZeroExtend + } + } + }, + (ScalarKind::Unsigned, 
ScalarKind::Signed) => CvtMode::SaturateSignedToUnsigned, + (_, _) => { + errors.push(PtxError::SyntaxError); + CvtMode::Bitcast + } + }; + CvtDetails { + mode, + to: dst, + from: src, + } + } +} + +pub struct CvtIntToIntDesc { + pub dst: ScalarType, + pub src: ScalarType, + pub saturate: bool, +} + +pub struct CvtDesc { + pub rounding: Option, + pub flush_to_zero: Option, + pub saturate: bool, + pub dst: ScalarType, + pub src: ScalarType, +} + +pub struct ShrData { + pub type_: ScalarType, + pub kind: RightShiftKind, +} + +pub enum RightShiftKind { + Arithmetic, + Logical, +} + +pub struct CvtaDetails { + pub state_space: StateSpace, + pub direction: CvtaDirection, +} + +pub enum CvtaDirection { + GenericToExplicit, + ExplicitToGeneric, +} + +#[derive(Copy, Clone, PartialEq, Eq)] +pub struct TypeFtz { + pub flush_to_zero: Option, + pub type_: ScalarType, +} + +#[derive(Copy, Clone)] +pub enum MadDetails { + Integer { + control: MulIntControl, + saturate: bool, + type_: ScalarType, + }, + Float(ArithFloat), +} + +impl MadDetails { + pub fn dst_type(&self) -> ScalarType { + match self { + MadDetails::Integer { + type_, + control: MulIntControl::Wide, + .. + } => match type_ { + ScalarType::U16 => ScalarType::U32, + ScalarType::S16 => ScalarType::S32, + ScalarType::U32 => ScalarType::U64, + ScalarType::S32 => ScalarType::S64, + _ => unreachable!(), + }, + _ => self.type_(), + } + } + + fn type_(&self) -> ScalarType { + match self { + MadDetails::Integer { type_, .. 
} => *type_, + MadDetails::Float(arith) => arith.type_, + } + } +} + +#[derive(Copy, Clone)] +pub enum MinMaxDetails { + Signed(ScalarType), + Unsigned(ScalarType), + Float(MinMaxFloat), +} + +impl MinMaxDetails { + pub fn type_(&self) -> ScalarType { + match self { + MinMaxDetails::Signed(t) => *t, + MinMaxDetails::Unsigned(t) => *t, + MinMaxDetails::Float(float) => float.type_, + } + } +} + +#[derive(Copy, Clone)] +pub struct MinMaxFloat { + pub flush_to_zero: Option, + pub nan: bool, + pub type_: ScalarType, +} + +#[derive(Copy, Clone)] +pub struct RcpData { + pub kind: RcpKind, + pub flush_to_zero: Option, + pub type_: ScalarType, +} + +#[derive(Copy, Clone, Eq, PartialEq)] +pub enum RcpKind { + Approx, + Compliant(RoundingMode), +} + +pub struct BarData { + pub aligned: bool, +} + +pub struct AtomDetails { + pub type_: Type, + pub semantics: AtomSemantics, + pub scope: MemScope, + pub space: StateSpace, + pub op: AtomicOp, +} + +#[derive(Copy, Clone)] +pub enum AtomicOp { + And, + Or, + Xor, + Exchange, + Add, + IncrementWrap, + DecrementWrap, + SignedMin, + UnsignedMin, + SignedMax, + UnsignedMax, + FloatAdd, + FloatMin, + FloatMax, +} + +impl AtomicOp { + pub(crate) fn new(op: super::RawAtomicOp, kind: ScalarKind) -> Self { + use super::RawAtomicOp; + match (op, kind) { + (RawAtomicOp::And, _) => Self::And, + (RawAtomicOp::Or, _) => Self::Or, + (RawAtomicOp::Xor, _) => Self::Xor, + (RawAtomicOp::Exch, _) => Self::Exchange, + (RawAtomicOp::Add, ScalarKind::Float) => Self::FloatAdd, + (RawAtomicOp::Add, _) => Self::Add, + (RawAtomicOp::Inc, _) => Self::IncrementWrap, + (RawAtomicOp::Dec, _) => Self::DecrementWrap, + (RawAtomicOp::Min, ScalarKind::Signed) => Self::SignedMin, + (RawAtomicOp::Min, ScalarKind::Float) => Self::FloatMin, + (RawAtomicOp::Min, _) => Self::UnsignedMin, + (RawAtomicOp::Max, ScalarKind::Signed) => Self::SignedMax, + (RawAtomicOp::Max, ScalarKind::Float) => Self::FloatMax, + (RawAtomicOp::Max, _) => Self::UnsignedMax, + } + } +} + +pub 
struct AtomCasDetails { + pub type_: ScalarType, + pub semantics: AtomSemantics, + pub scope: MemScope, + pub space: StateSpace, +} + +#[derive(Copy, Clone)] +pub enum DivDetails { + Unsigned(ScalarType), + Signed(ScalarType), + Float(DivFloatDetails), +} + +impl DivDetails { + pub fn type_(&self) -> ScalarType { + match self { + DivDetails::Unsigned(t) => *t, + DivDetails::Signed(t) => *t, + DivDetails::Float(float) => float.type_, + } + } +} + +#[derive(Copy, Clone)] +pub struct DivFloatDetails { + pub type_: ScalarType, + pub flush_to_zero: Option, + pub kind: DivFloatKind, +} + +#[derive(Copy, Clone, Eq, PartialEq)] +pub enum DivFloatKind { + Approx, + ApproxFull, + Rounding(RoundingMode), +} + +#[derive(Copy, Clone, Eq, PartialEq)] +pub struct FlushToZero { + pub flush_to_zero: bool, +} diff --git a/ptx_parser/src/check_args.py b/ptx_parser/src/check_args.py new file mode 100644 index 00000000..04ffdb91 --- /dev/null +++ b/ptx_parser/src/check_args.py @@ -0,0 +1,69 @@ +import os, sys, subprocess + + +SPACE = [".reg", ".sreg", ".param", ".param::entry", ".param::func", ".local", ".global", ".const", ".shared", ".shared::cta", ".shared::cluster"] +TYPE_AND_INIT = ["", " = 1", "[1]", "[1] = {1}"] +MULTIVAR = ["", "<1>" ] +VECTOR = ["", ".v2" ] + +HEADER = """ + .version 8.5 + .target sm_90 + .address_size 64 +""" + + +def directive(space, variable, multivar, vector): + return """{3} + {0} {4} .b32 variable{2} {1}; + """.format(space, variable, multivar, HEADER, vector) + +def entry_arg(space, variable, multivar, vector): + return """{3} + .entry foobar ( {0} {4} .b32 variable{2} {1}) + {{ + ret; + }} + """.format(space, variable, multivar, HEADER, vector) + + +def fn_arg(space, variable, multivar, vector): + return """{3} + .func foobar ( {0} {4} .b32 variable{2} {1}) + {{ + ret; + }} + """.format(space, variable, multivar, HEADER, vector) + + +def fn_body(space, variable, multivar, vector): + return """{3} + .func foobar () + {{ + {0} {4} .b32 variable{2} {1}; + 
ret; + }} + """.format(space, variable, multivar, HEADER, vector) + + +def generate(generator): + legal = [] + for space in SPACE: + for init in TYPE_AND_INIT: + for multi in MULTIVAR: + for vector in VECTOR: + ptx = generator(space, init, multi, vector) + if 0 == subprocess.call(["C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v12.6\\bin\\ptxas.exe", "-arch", "sm_90", "-ias", ptx], stdout = subprocess.DEVNULL): # + legal.append((space, vector, init, multi)) + print(generator.__name__) + print(legal) + + +def main(): + generate(directive) + generate(entry_arg) + generate(fn_arg) + generate(fn_body) + +if __name__ == "__main__": + main() diff --git a/ptx_parser/src/lib.rs b/ptx_parser/src/lib.rs new file mode 100644 index 00000000..f842ace6 --- /dev/null +++ b/ptx_parser/src/lib.rs @@ -0,0 +1,3269 @@ +use derive_more::Display; +use logos::Logos; +use ptx_parser_macros::derive_parser; +use rustc_hash::FxHashMap; +use std::fmt::Debug; +use std::iter; +use std::num::{NonZeroU8, ParseFloatError, ParseIntError}; +use winnow::ascii::dec_uint; +use winnow::combinator::*; +use winnow::error::{ErrMode, ErrorKind}; +use winnow::stream::Accumulate; +use winnow::token::any; +use winnow::{ + error::{ContextError, ParserError}, + stream::{Offset, Stream, StreamIsPartial}, + PResult, +}; +use winnow::{prelude::*, Stateful}; + +mod ast; +pub use ast::*; + +impl From for ast::MulIntControl { + fn from(value: RawMulIntControl) -> Self { + match value { + RawMulIntControl::Lo => ast::MulIntControl::Low, + RawMulIntControl::Hi => ast::MulIntControl::High, + RawMulIntControl::Wide => ast::MulIntControl::Wide, + } + } +} + +impl From for ast::StCacheOperator { + fn from(value: RawStCacheOperator) -> Self { + match value { + RawStCacheOperator::Wb => ast::StCacheOperator::Writeback, + RawStCacheOperator::Cg => ast::StCacheOperator::L2Only, + RawStCacheOperator::Cs => ast::StCacheOperator::Streaming, + RawStCacheOperator::Wt => ast::StCacheOperator::Writethrough, + } + } +} + +impl 
From for ast::LdCacheOperator { + fn from(value: RawLdCacheOperator) -> Self { + match value { + RawLdCacheOperator::Ca => ast::LdCacheOperator::Cached, + RawLdCacheOperator::Cg => ast::LdCacheOperator::L2Only, + RawLdCacheOperator::Cs => ast::LdCacheOperator::Streaming, + RawLdCacheOperator::Lu => ast::LdCacheOperator::LastUse, + RawLdCacheOperator::Cv => ast::LdCacheOperator::Uncached, + } + } +} + +impl From for ast::LdStQualifier { + fn from(value: RawLdStQualifier) -> Self { + match value { + RawLdStQualifier::Weak => ast::LdStQualifier::Weak, + RawLdStQualifier::Volatile => ast::LdStQualifier::Volatile, + } + } +} + +impl From for ast::RoundingMode { + fn from(value: RawRoundingMode) -> Self { + match value { + RawRoundingMode::Rn | RawRoundingMode::Rni => ast::RoundingMode::NearestEven, + RawRoundingMode::Rz | RawRoundingMode::Rzi => ast::RoundingMode::Zero, + RawRoundingMode::Rm | RawRoundingMode::Rmi => ast::RoundingMode::NegativeInf, + RawRoundingMode::Rp | RawRoundingMode::Rpi => ast::RoundingMode::PositiveInf, + } + } +} + +impl VectorPrefix { + pub(crate) fn len(self) -> NonZeroU8 { + unsafe { + match self { + VectorPrefix::V2 => NonZeroU8::new_unchecked(2), + VectorPrefix::V4 => NonZeroU8::new_unchecked(4), + VectorPrefix::V8 => NonZeroU8::new_unchecked(8), + } + } + } +} + +struct PtxParserState<'a, 'input> { + errors: &'a mut Vec, + function_declarations: + FxHashMap<&'input str, (Vec<(ast::Type, StateSpace)>, Vec<(ast::Type, StateSpace)>)>, +} + +impl<'a, 'input> PtxParserState<'a, 'input> { + fn new(errors: &'a mut Vec) -> Self { + Self { + errors, + function_declarations: FxHashMap::default(), + } + } + + fn record_function(&mut self, function_decl: &MethodDeclaration<'input, &'input str>) { + let name = match function_decl.name { + MethodName::Kernel(name) => name, + MethodName::Func(name) => name, + }; + let return_arguments = Self::get_type_space(&*function_decl.return_arguments); + let input_arguments = 
Self::get_type_space(&*function_decl.input_arguments); + // TODO: check if declarations match + self.function_declarations + .insert(name, (return_arguments, input_arguments)); + } + + fn get_type_space(input_arguments: &[Variable<&str>]) -> Vec<(Type, StateSpace)> { + input_arguments + .iter() + .map(|var| (var.v_type.clone(), var.state_space)) + .collect::>() + } +} + +impl<'a, 'input> Debug for PtxParserState<'a, 'input> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("PtxParserState") + .field("errors", &self.errors) /* .field("function_decl", &self.function_decl) */ + .finish() + } +} + +type PtxParser<'a, 'input> = Stateful<&'a [Token<'input>], PtxParserState<'a, 'input>>; + +fn ident<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<&'input str> { + any.verify_map(|t| { + if let Token::Ident(text) = t { + Some(text) + } else if let Some(text) = t.opcode_text() { + Some(text) + } else { + None + } + }) + .parse_next(stream) +} + +fn dot_ident<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<&'input str> { + any.verify_map(|t| { + if let Token::DotIdent(text) = t { + Some(text) + } else { + None + } + }) + .parse_next(stream) +} + +fn num<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<(&'input str, u32, bool)> { + any.verify_map(|t| { + Some(match t { + Token::Hex(s) => { + if s.ends_with('U') { + (&s[2..s.len() - 1], 16, true) + } else { + (&s[2..], 16, false) + } + } + Token::Decimal(s) => { + let radix = if s.starts_with('0') { 8 } else { 10 }; + if s.ends_with('U') { + (&s[..s.len() - 1], radix, true) + } else { + (s, radix, false) + } + } + _ => return None, + }) + }) + .parse_next(stream) +} + +fn take_error<'a, 'input: 'a, O, E>( + mut parser: impl Parser, Result, E>, +) -> impl Parser, O, E> { + move |input: &mut PtxParser<'a, 'input>| { + Ok(match parser.parse_next(input)? 
{ + Ok(x) => x, + Err((x, err)) => { + input.state.errors.push(err); + x + } + }) + } +} + +fn int_immediate<'a, 'input>(input: &mut PtxParser<'a, 'input>) -> PResult { + take_error((opt(Token::Minus), num).map(|(neg, x)| { + let (num, radix, is_unsigned) = x; + if neg.is_some() { + match i64::from_str_radix(num, radix) { + Ok(x) => Ok(ast::ImmediateValue::S64(-x)), + Err(err) => Err((ast::ImmediateValue::S64(0), PtxError::from(err))), + } + } else if is_unsigned { + match u64::from_str_radix(num, radix) { + Ok(x) => Ok(ast::ImmediateValue::U64(x)), + Err(err) => Err((ast::ImmediateValue::U64(0), PtxError::from(err))), + } + } else { + match i64::from_str_radix(num, radix) { + Ok(x) => Ok(ast::ImmediateValue::S64(x)), + Err(_) => match u64::from_str_radix(num, radix) { + Ok(x) => Ok(ast::ImmediateValue::U64(x)), + Err(err) => Err((ast::ImmediateValue::U64(0), PtxError::from(err))), + }, + } + } + })) + .parse_next(input) +} + +fn f32<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult { + take_error(any.verify_map(|t| match t { + Token::F32(f) => Some(match u32::from_str_radix(&f[2..], 16) { + Ok(x) => Ok(f32::from_bits(x)), + Err(err) => Err((0.0, PtxError::from(err))), + }), + _ => None, + })) + .parse_next(stream) +} + +fn f64<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult { + take_error(any.verify_map(|t| match t { + Token::F64(f) => Some(match u64::from_str_radix(&f[2..], 16) { + Ok(x) => Ok(f64::from_bits(x)), + Err(err) => Err((0.0, PtxError::from(err))), + }), + _ => None, + })) + .parse_next(stream) +} + +fn s32<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult { + take_error((opt(Token::Minus), num).map(|(sign, x)| { + let (text, radix, _) = x; + match i32::from_str_radix(text, radix) { + Ok(x) => Ok(if sign.is_some() { -x } else { x }), + Err(err) => Err((0, PtxError::from(err))), + } + })) + .parse_next(stream) +} + +fn u8<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult { + take_error(num.map(|x| { + let (text, 
radix, _) = x; + match u8::from_str_radix(text, radix) { + Ok(x) => Ok(x), + Err(err) => Err((0, PtxError::from(err))), + } + })) + .parse_next(stream) +} + +fn u32<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult { + take_error(num.map(|x| { + let (text, radix, _) = x; + match u32::from_str_radix(text, radix) { + Ok(x) => Ok(x), + Err(err) => Err((0, PtxError::from(err))), + } + })) + .parse_next(stream) +} + +fn immediate_value<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult { + alt(( + int_immediate, + f32.map(ast::ImmediateValue::F32), + f64.map(ast::ImmediateValue::F64), + )) + .parse_next(stream) +} + +pub fn parse_module_unchecked<'input>(text: &'input str) -> Option> { + let lexer = Token::lexer(text); + let input = lexer.collect::, _>>().ok()?; + let mut errors = Vec::new(); + let state = PtxParserState::new(&mut errors); + let parser = PtxParser { + state, + input: &input[..], + }; + let parsing_result = module.parse(parser).ok(); + if !errors.is_empty() { + None + } else { + parsing_result + } +} + +pub fn parse_module_checked<'input>( + text: &'input str, +) -> Result, Vec> { + let mut lexer = Token::lexer(text); + let mut errors = Vec::new(); + let mut tokens = Vec::new(); + loop { + let maybe_token = match lexer.next() { + Some(maybe_token) => maybe_token, + None => break, + }; + match maybe_token { + Ok(token) => tokens.push(token), + Err(mut err) => { + err.0 = lexer.span(); + errors.push(PtxError::from(err)) + } + } + } + if !errors.is_empty() { + return Err(errors); + } + let parse_result = { + let state = PtxParserState::new(&mut errors); + let parser = PtxParser { + state, + input: &tokens[..], + }; + module + .parse(parser) + .map_err(|err| PtxError::Parser(err.into_inner())) + }; + match parse_result { + Ok(result) if errors.is_empty() => Ok(result), + Ok(_) => Err(errors), + Err(err) => { + errors.push(err); + Err(errors) + } + } +} + +fn module<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult> { + ( + version, 
+ target, + opt(address_size), + repeat_without_none(directive), + eof, + ) + .map(|(version, _, _, directives, _)| ast::Module { + version, + directives, + }) + .parse_next(stream) +} + +fn address_size<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<()> { + (Token::DotAddressSize, u8_literal(64)) + .void() + .parse_next(stream) +} + +fn version<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<(u8, u8)> { + (Token::DotVersion, u8, Token::Dot, u8) + .map(|(_, major, _, minor)| (major, minor)) + .parse_next(stream) +} + +fn target<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<(u32, Option)> { + preceded(Token::DotTarget, ident.and_then(shader_model)).parse_next(stream) +} + +fn shader_model<'a>(stream: &mut &str) -> PResult<(u32, Option)> { + ( + "sm_", + dec_uint, + opt(any.verify(|c: &char| c.is_ascii_lowercase())), + eof, + ) + .map(|(_, digits, arch_variant, _)| (digits, arch_variant)) + .parse_next(stream) +} + +fn directive<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult>>> { + alt(( + function.map(|(linking, func)| Some(ast::Directive::Method(linking, func))), + file.map(|_| None), + section.map(|_| None), + (module_variable, Token::Semicolon) + .map(|((linking, var), _)| Some(ast::Directive::Variable(linking, var))), + )) + .parse_next(stream) +} + +fn module_variable<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult<(ast::LinkingDirective, ast::Variable<&'input str>)> { + let linking = linking_directives.parse_next(stream)?; + let var = global_space + .flat_map(|space| multi_variable(linking.contains(LinkingDirective::EXTERN), space)) + // TODO: support multi var in globals + .map(|multi_var| multi_var.var) + .parse_next(stream)?; + Ok((linking, var)) +} + +fn file<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<()> { + ( + Token::DotFile, + u32, + Token::String, + opt((Token::Comma, u32, Token::Comma, u32)), + ) + .void() + .parse_next(stream) +} + +fn section<'a, 'input>(stream: 
&mut PtxParser<'a, 'input>) -> PResult<()> { + ( + Token::DotSection.void(), + dot_ident.void(), + Token::LBrace.void(), + repeat::<_, _, (), _, _>(0.., section_dwarf_line), + Token::RBrace.void(), + ) + .void() + .parse_next(stream) +} + +fn section_dwarf_line<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<()> { + alt(( + (section_label, Token::Colon).void(), + (Token::DotB32, section_label, opt((Token::Add, u32))).void(), + (Token::DotB64, section_label, opt((Token::Add, u32))).void(), + ( + any_bit_type, + separated::<_, _, (), _, _, _, _>(1.., u32, Token::Comma), + ) + .void(), + )) + .parse_next(stream) +} + +fn any_bit_type<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<()> { + alt((Token::DotB8, Token::DotB16, Token::DotB32, Token::DotB64)) + .void() + .parse_next(stream) +} + +fn section_label<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<()> { + alt((ident, dot_ident)).void().parse_next(stream) +} + +fn function<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult<( + ast::LinkingDirective, + ast::Function<'input, &'input str, ast::Statement>>, +)> { + let (linking, function) = ( + linking_directives, + method_declaration, + repeat(0.., tuning_directive), + function_body, + ) + .map(|(linking, func_directive, tuning, body)| { + ( + linking, + ast::Function { + func_directive, + tuning, + body, + }, + ) + }) + .parse_next(stream)?; + stream.state.record_function(&function.func_directive); + Ok((linking, function)) +} + +fn linking_directives<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult { + repeat( + 0.., + dispatch! 
{ any; + Token::DotExtern => empty.value(ast::LinkingDirective::EXTERN), + Token::DotVisible => empty.value(ast::LinkingDirective::VISIBLE), + Token::DotWeak => empty.value(ast::LinkingDirective::WEAK), + _ => fail + }, + ) + .fold(|| ast::LinkingDirective::NONE, |x, y| x | y) + .parse_next(stream) +} + +fn tuning_directive<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult { + dispatch! {any; + Token::DotMaxnreg => u32.map(ast::TuningDirective::MaxNReg), + Token::DotMaxntid => tuple1to3_u32.map(|(nx, ny, nz)| ast::TuningDirective::MaxNtid(nx, ny, nz)), + Token::DotReqntid => tuple1to3_u32.map(|(nx, ny, nz)| ast::TuningDirective::ReqNtid(nx, ny, nz)), + Token::DotMinnctapersm => u32.map(ast::TuningDirective::MinNCtaPerSm), + _ => fail + } + .parse_next(stream) +} + +fn method_declaration<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult> { + dispatch! {any; + Token::DotEntry => (ident, kernel_arguments).map(|(name, input_arguments)| ast::MethodDeclaration{ + return_arguments: Vec::new(), name: ast::MethodName::Kernel(name), input_arguments, shared_mem: None + }), + Token::DotFunc => (opt(fn_arguments), ident, fn_arguments).map(|(return_arguments, name,input_arguments)| { + let return_arguments = return_arguments.unwrap_or_else(|| Vec::new()); + let name = ast::MethodName::Func(name); + ast::MethodDeclaration{ return_arguments, name, input_arguments, shared_mem: None } + }), + _ => fail + } + .parse_next(stream) +} + +fn fn_arguments<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult>> { + delimited( + Token::LParen, + separated(0.., fn_input, Token::Comma), + Token::RParen, + ) + .parse_next(stream) +} + +fn kernel_arguments<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult>> { + delimited( + Token::LParen, + separated(0.., kernel_input, Token::Comma), + Token::RParen, + ) + .parse_next(stream) +} + +fn kernel_input<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult> { + 
preceded(Token::DotParam, method_parameter(StateSpace::Param)).parse_next(stream) +} + +fn fn_input<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult> { + dispatch! { any; + Token::DotParam => method_parameter(StateSpace::Param), + Token::DotReg => method_parameter(StateSpace::Reg), + _ => fail + } + .parse_next(stream) +} + +fn tuple1to3_u32<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<(u32, u32, u32)> { + struct Tuple3AccumulateU32 { + index: usize, + value: (u32, u32, u32), + } + + impl Accumulate for Tuple3AccumulateU32 { + fn initial(_: Option) -> Self { + Self { + index: 0, + value: (1, 1, 1), + } + } + + fn accumulate(&mut self, value: u32) { + match self.index { + 0 => { + self.value = (value, self.value.1, self.value.2); + self.index = 1; + } + 1 => { + self.value = (self.value.0, value, self.value.2); + self.index = 2; + } + 2 => { + self.value = (self.value.0, self.value.1, value); + self.index = 3; + } + _ => unreachable!(), + } + } + } + + separated::<_, _, Tuple3AccumulateU32, _, _, _, _>(1..=3, u32, Token::Comma) + .map(|acc| acc.value) + .parse_next(stream) +} + +fn function_body<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult>>>> { + dispatch! 
{any; + Token::LBrace => terminated(repeat_without_none(statement), Token::RBrace).map(Some), + Token::Semicolon => empty.map(|_| None), + _ => fail + } + .parse_next(stream) +} + +fn statement<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult>>> { + alt(( + label.map(Some), + debug_directive.map(|_| None), + terminated( + method_space + .flat_map(|space| multi_variable(false, space)) + .map(|var| Some(Statement::Variable(var))), + Token::Semicolon, + ), + predicated_instruction.map(Some), + pragma.map(|_| None), + block_statement.map(Some), + )) + .parse_next(stream) +} + +fn pragma<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<()> { + (Token::DotPragma, Token::String, Token::Semicolon) + .void() + .parse_next(stream) +} + +fn method_parameter<'a, 'input: 'a>( + state_space: StateSpace, +) -> impl Parser, Variable<&'input str>, ContextError> { + move |stream: &mut PtxParser<'a, 'input>| { + let (align, vector, type_, name) = variable_declaration.parse_next(stream)?; + let array_dimensions = if state_space != StateSpace::Reg { + opt(array_dimensions).parse_next(stream)? + } else { + None + }; + // TODO: push this check into array_dimensions(...) 
+ if let Some(ref dims) = array_dimensions { + if dims[0] == 0 { + return Err(ErrMode::from_error_kind(stream, ErrorKind::Verify)); + } + } + Ok(Variable { + align, + v_type: Type::maybe_array(vector, type_, array_dimensions), + state_space, + name, + array_init: Vec::new(), + }) + } +} + +// TODO: split to a separate type +fn variable_declaration<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult<(Option, Option, ScalarType, &'input str)> { + ( + opt(align.verify(|x| x.count_ones() == 1)), + vector_prefix, + scalar_type, + ident, + ) + .parse_next(stream) +} + +fn multi_variable<'a, 'input: 'a>( + extern_: bool, + state_space: StateSpace, +) -> impl Parser, MultiVariable<&'input str>, ContextError> { + move |stream: &mut PtxParser<'a, 'input>| { + let ((align, vector, type_, name), count) = ( + variable_declaration, + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameterized-variable-names + opt(delimited(Token::Lt, u32.verify(|x| *x != 0), Token::Gt)), + ) + .parse_next(stream)?; + if count.is_some() { + return Ok(MultiVariable { + var: Variable { + align, + v_type: Type::maybe_vector_parsed(vector, type_), + state_space, + name, + array_init: Vec::new(), + }, + count, + }); + } + let mut array_dimensions = if state_space != StateSpace::Reg { + opt(array_dimensions).parse_next(stream)? + } else { + None + }; + let initializer = match state_space { + StateSpace::Global | StateSpace::Const => match array_dimensions { + Some(ref mut dimensions) => { + opt(array_initializer(vector, type_, dimensions)).parse_next(stream)? 
+ } + None => opt(value_initializer(vector, type_)).parse_next(stream)?, + }, + _ => None, + }; + if let Some(ref dims) = array_dimensions { + if !extern_ && dims[0] == 0 { + return Err(ErrMode::from_error_kind(stream, ErrorKind::Verify)); + } + } + Ok(MultiVariable { + var: Variable { + align, + v_type: Type::maybe_array(vector, type_, array_dimensions), + state_space, + name, + array_init: initializer.unwrap_or(Vec::new()), + }, + count, + }) + } +} + +fn array_initializer<'a, 'input: 'a>( + vector: Option, + type_: ScalarType, + array_dimensions: &mut Vec, +) -> impl Parser, Vec, ContextError> + '_ { + move |stream: &mut PtxParser<'a, 'input>| { + Token::Eq.parse_next(stream)?; + let mut result = Vec::new(); + // TODO: vector constants and multi dim arrays + if vector.is_some() || array_dimensions[0] == 0 || array_dimensions.len() > 1 { + return Err(ErrMode::from_error_kind(stream, ErrorKind::Verify)); + } + delimited( + Token::LBrace, + separated( + 0..=array_dimensions[0] as usize, + single_value_append(&mut result, type_), + Token::Comma, + ), + Token::RBrace, + ) + .parse_next(stream)?; + // pad with zeros + let result_size = type_.size_of() as usize * array_dimensions[0] as usize; + result.extend(iter::repeat(0u8).take(result_size - result.len())); + Ok(result) + } +} + +fn value_initializer<'a, 'input: 'a>( + vector: Option, + type_: ScalarType, +) -> impl Parser, Vec, ContextError> { + move |stream: &mut PtxParser<'a, 'input>| { + Token::Eq.parse_next(stream)?; + let mut result = Vec::new(); + // TODO: vector constants + if vector.is_some() { + return Err(ErrMode::from_error_kind(stream, ErrorKind::Verify)); + } + single_value_append(&mut result, type_).parse_next(stream)?; + Ok(result) + } +} + +fn single_value_append<'a, 'input: 'a>( + accumulator: &mut Vec, + type_: ScalarType, +) -> impl Parser, (), ContextError> + '_ { + move |stream: &mut PtxParser<'a, 'input>| { + let value = immediate_value.parse_next(stream)?; + match (type_, value) { + 
(ScalarType::U8 | ScalarType::B8, ImmediateValue::U64(x)) => { + accumulator.extend_from_slice(&(x as u8).to_le_bytes()) + } + (ScalarType::U8 | ScalarType::B8, ImmediateValue::S64(x)) => { + accumulator.extend_from_slice(&(x as u8).to_le_bytes()) + } + (ScalarType::U16 | ScalarType::B16, ImmediateValue::U64(x)) => { + accumulator.extend_from_slice(&(x as u16).to_le_bytes()) + } + (ScalarType::U16 | ScalarType::B16, ImmediateValue::S64(x)) => { + accumulator.extend_from_slice(&(x as u16).to_le_bytes()) + } + (ScalarType::U32 | ScalarType::B32, ImmediateValue::U64(x)) => { + accumulator.extend_from_slice(&(x as u32).to_le_bytes()) + } + (ScalarType::U32 | ScalarType::B32, ImmediateValue::S64(x)) => { + accumulator.extend_from_slice(&(x as u32).to_le_bytes()) + } + (ScalarType::U64 | ScalarType::B64, ImmediateValue::U64(x)) => { + accumulator.extend_from_slice(&(x as u64).to_le_bytes()) + } + (ScalarType::U64 | ScalarType::B64, ImmediateValue::S64(x)) => { + accumulator.extend_from_slice(&(x as u64).to_le_bytes()) + } + (ScalarType::S8, ImmediateValue::U64(x)) => { + accumulator.extend_from_slice(&(x as i8).to_le_bytes()) + } + (ScalarType::S8, ImmediateValue::S64(x)) => { + accumulator.extend_from_slice(&(x as i8).to_le_bytes()) + } + (ScalarType::S16, ImmediateValue::U64(x)) => { + accumulator.extend_from_slice(&(x as i16).to_le_bytes()) + } + (ScalarType::S16, ImmediateValue::S64(x)) => { + accumulator.extend_from_slice(&(x as i16).to_le_bytes()) + } + (ScalarType::S32, ImmediateValue::U64(x)) => { + accumulator.extend_from_slice(&(x as i32).to_le_bytes()) + } + (ScalarType::S32, ImmediateValue::S64(x)) => { + accumulator.extend_from_slice(&(x as i32).to_le_bytes()) + } + (ScalarType::S64, ImmediateValue::U64(x)) => { + accumulator.extend_from_slice(&(x as i64).to_le_bytes()) + } + (ScalarType::S64, ImmediateValue::S64(x)) => { + accumulator.extend_from_slice(&(x as i64).to_le_bytes()) + } + (ScalarType::F32, ImmediateValue::F32(x)) => { + 
accumulator.extend_from_slice(&x.to_le_bytes()) + } + (ScalarType::F64, ImmediateValue::F64(x)) => { + accumulator.extend_from_slice(&x.to_le_bytes()) + } + _ => return Err(ErrMode::from_error_kind(stream, ErrorKind::Verify)), + } + Ok(()) + } +} + +fn array_dimensions<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult> { + let dimension = delimited( + Token::LBracket, + opt(u32).verify(|dim| *dim != Some(0)), + Token::RBracket, + ) + .parse_next(stream)?; + let result = vec![dimension.unwrap_or(0)]; + repeat_fold_0_or_more( + delimited( + Token::LBracket, + u32.verify(|dim| *dim != 0), + Token::RBracket, + ), + move || result, + |mut result: Vec, x| { + result.push(x); + result + }, + stream, + ) +} + +// Copied and fixed from Winnow sources (fold_repeat0_) +// Winnow Repeat::fold takes FnMut() -> Result to initalize accumulator, +// this really should be FnOnce() -> Result +fn repeat_fold_0_or_more( + mut f: F, + init: H, + mut g: G, + input: &mut I, +) -> PResult +where + I: Stream, + F: Parser, + G: FnMut(R, O) -> R, + H: FnOnce() -> R, + E: ParserError, +{ + use winnow::error::ErrMode; + let mut res = init(); + loop { + let start = input.checkpoint(); + match f.parse_next(input) { + Ok(o) => { + res = g(res, o); + } + Err(ErrMode::Backtrack(_)) => { + input.reset(&start); + return Ok(res); + } + Err(e) => { + return Err(e); + } + } + } +} + +fn global_space<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult { + alt(( + Token::DotGlobal.value(StateSpace::Global), + Token::DotConst.value(StateSpace::Const), + Token::DotShared.value(StateSpace::Shared), + )) + .parse_next(stream) +} + +fn method_space<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult { + alt(( + Token::DotReg.value(StateSpace::Reg), + Token::DotLocal.value(StateSpace::Local), + Token::DotParam.value(StateSpace::Param), + global_space, + )) + .parse_next(stream) +} + +fn align<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult { + preceded(Token::DotAlign, 
u32).parse_next(stream) +} + +fn vector_prefix<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult> { + opt(alt(( + Token::DotV2.value(unsafe { NonZeroU8::new_unchecked(2) }), + Token::DotV4.value(unsafe { NonZeroU8::new_unchecked(4) }), + Token::DotV8.value(unsafe { NonZeroU8::new_unchecked(8) }), + ))) + .parse_next(stream) +} + +fn scalar_type<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult { + any.verify_map(|t| { + Some(match t { + Token::DotS8 => ScalarType::S8, + Token::DotS16 => ScalarType::S16, + Token::DotS16x2 => ScalarType::S16x2, + Token::DotS32 => ScalarType::S32, + Token::DotS64 => ScalarType::S64, + Token::DotU8 => ScalarType::U8, + Token::DotU16 => ScalarType::U16, + Token::DotU16x2 => ScalarType::U16x2, + Token::DotU32 => ScalarType::U32, + Token::DotU64 => ScalarType::U64, + Token::DotB8 => ScalarType::B8, + Token::DotB16 => ScalarType::B16, + Token::DotB32 => ScalarType::B32, + Token::DotB64 => ScalarType::B64, + Token::DotB128 => ScalarType::B128, + Token::DotPred => ScalarType::Pred, + Token::DotF16 => ScalarType::F16, + Token::DotF16x2 => ScalarType::F16x2, + Token::DotF32 => ScalarType::F32, + Token::DotF64 => ScalarType::F64, + Token::DotBF16 => ScalarType::BF16, + Token::DotBF16x2 => ScalarType::BF16x2, + _ => return None, + }) + }) + .parse_next(stream) +} + +fn predicated_instruction<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult>> { + (opt(pred_at), parse_instruction, Token::Semicolon) + .map(|(p, i, _)| ast::Statement::Instruction(p, i)) + .parse_next(stream) +} + +fn pred_at<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult> { + (Token::At, opt(Token::Exclamation), ident) + .map(|(_, not, label)| ast::PredAt { + not: not.is_some(), + label, + }) + .parse_next(stream) +} + +fn label<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult>> { + terminated(ident, Token::Colon) + .map(|l| ast::Statement::Label(l)) + .parse_next(stream) +} + +fn debug_directive<'a, 'input>(stream: &mut 
PtxParser<'a, 'input>) -> PResult<()> { + ( + Token::DotLoc, + u32, + u32, + u32, + opt(( + Token::Comma, + ident_literal("function_name"), + ident, + dispatch! { any; + Token::Comma => (ident_literal("inlined_at"), u32, u32, u32).void(), + Token::Plus => (u32, Token::Comma, ident_literal("inlined_at"), u32, u32, u32).void(), + _ => fail + }, + )), + ) + .void() + .parse_next(stream) +} + +fn block_statement<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult>> { + delimited(Token::LBrace, repeat_without_none(statement), Token::RBrace) + .map(|s| ast::Statement::Block(s)) + .parse_next(stream) +} + +fn repeat_without_none>( + parser: impl Parser, Error>, +) -> impl Parser, Error> { + repeat(0.., parser).fold(Vec::new, |mut acc: Vec<_>, item| { + if let Some(item) = item { + acc.push(item); + } + acc + }) +} + +fn ident_literal< + 'a, + 'input, + I: Stream> + StreamIsPartial, + E: ParserError, +>( + s: &'input str, +) -> impl Parser + 'input { + move |stream: &mut I| { + any.verify(|t| matches!(t, Token::Ident(text) if *text == s)) + .void() + .parse_next(stream) + } +} + +fn u8_literal<'a, 'input>(x: u8) -> impl Parser, (), ContextError> { + move |stream: &mut PtxParser| u8.verify(|t| *t == x).void().parse_next(stream) +} + +impl ast::ParsedOperand { + fn parse<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, + ) -> PResult> { + use winnow::combinator::*; + use winnow::token::any; + fn vector_index<'input>(inp: &'input str) -> Result { + match inp { + ".x" | ".r" => Ok(0), + ".y" | ".g" => Ok(1), + ".z" | ".b" => Ok(2), + ".w" | ".a" => Ok(3), + _ => Err(PtxError::WrongVectorElement), + } + } + fn ident_operands<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, + ) -> PResult> { + let main_ident = ident.parse_next(stream)?; + alt(( + preceded(Token::Plus, s32) + .map(move |offset| ast::ParsedOperand::RegOffset(main_ident, offset)), + take_error(dot_ident.map(move |suffix| { + let vector_index = vector_index(suffix) + .map_err(move |e| 
(ast::ParsedOperand::VecMember(main_ident, 0), e))?; + Ok(ast::ParsedOperand::VecMember(main_ident, vector_index)) + })), + empty.value(ast::ParsedOperand::Reg(main_ident)), + )) + .parse_next(stream) + } + fn vector_operand<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, + ) -> PResult> { + let (_, r1, _, r2) = (Token::LBrace, ident, Token::Comma, ident).parse_next(stream)?; + // TODO: parse .v8 literals + dispatch! {any; + Token::RBrace => empty.map(|_| vec![r1, r2]), + Token::Comma => (ident, Token::Comma, ident, Token::RBrace).map(|(r3, _, r4, _)| vec![r1, r2, r3, r4]), + _ => fail + } + .parse_next(stream) + } + alt(( + ident_operands, + immediate_value.map(ast::ParsedOperand::Imm), + vector_operand.map(ast::ParsedOperand::VecPack), + )) + .parse_next(stream) + } +} + +#[derive(Debug, thiserror::Error)] +pub enum PtxError { + #[error("{source}")] + ParseInt { + #[from] + source: ParseIntError, + }, + #[error("{source}")] + ParseFloat { + #[from] + source: ParseFloatError, + }, + #[error("{source}")] + Lexer { + #[from] + source: TokenError, + }, + #[error("")] + Parser(ContextError), + #[error("")] + Todo, + #[error("")] + SyntaxError, + #[error("")] + NonF32Ftz, + #[error("")] + Unsupported32Bit, + #[error("")] + WrongType, + #[error("")] + UnknownFunction, + #[error("")] + MalformedCall, + #[error("")] + WrongArrayType, + #[error("")] + WrongVectorElement, + #[error("")] + MultiArrayVariable, + #[error("")] + ZeroDimensionArray, + #[error("")] + ArrayInitalizer, + #[error("")] + NonExternPointer, + #[error("{start}:{end}")] + UnrecognizedStatement { start: usize, end: usize }, + #[error("{start}:{end}")] + UnrecognizedDirective { start: usize, end: usize }, +} + +#[derive(Debug)] +struct ReverseStream<'a, T>(pub &'a [T]); + +impl<'i, T> Stream for ReverseStream<'i, T> +where + T: Clone + ::std::fmt::Debug, +{ + type Token = T; + type Slice = &'i [T]; + + type IterOffsets = + std::iter::Enumerate>>>; + + type Checkpoint = &'i [T]; + + fn 
iter_offsets(&self) -> Self::IterOffsets { + self.0.iter().rev().cloned().enumerate() + } + + fn eof_offset(&self) -> usize { + self.0.len() + } + + fn next_token(&mut self) -> Option { + let (token, next) = self.0.split_last()?; + self.0 = next; + Some(token.clone()) + } + + fn offset_for

(&self, predicate: P) -> Option + where + P: Fn(Self::Token) -> bool, + { + self.0.iter().rev().position(|b| predicate(b.clone())) + } + + fn offset_at(&self, tokens: usize) -> Result { + if let Some(needed) = tokens + .checked_sub(self.0.len()) + .and_then(std::num::NonZeroUsize::new) + { + Err(winnow::error::Needed::Size(needed)) + } else { + Ok(tokens) + } + } + + fn next_slice(&mut self, offset: usize) -> Self::Slice { + let offset = self.0.len() - offset; + let (next, slice) = self.0.split_at(offset); + self.0 = next; + slice + } + + fn checkpoint(&self) -> Self::Checkpoint { + self.0 + } + + fn reset(&mut self, checkpoint: &Self::Checkpoint) { + self.0 = checkpoint; + } + + fn raw(&self) -> &dyn std::fmt::Debug { + self + } +} + +impl<'a, T> Offset<&'a [T]> for ReverseStream<'a, T> { + fn offset_from(&self, start: &&'a [T]) -> usize { + let fst = start.as_ptr(); + let snd = self.0.as_ptr(); + + debug_assert!( + snd <= fst, + "`Offset::offset_from({snd:?}, {fst:?})` only accepts slices of `self`" + ); + (fst as usize - snd as usize) / std::mem::size_of::() + } +} + +impl<'a, T> StreamIsPartial for ReverseStream<'a, T> { + type PartialState = (); + + fn complete(&mut self) -> Self::PartialState {} + + fn restore_partial(&mut self, _state: Self::PartialState) {} + + fn is_partial_supported() -> bool { + false + } +} + +impl<'input, I: Stream + StreamIsPartial, E: ParserError> Parser + for Token<'input> +{ + fn parse_next(&mut self, input: &mut I) -> PResult { + any.verify(|t| t == self).parse_next(input) + } +} + +fn bra<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult>> { + preceded( + opt(Token::DotUni), + any.verify_map(|t| match t { + Token::Ident(ident) => Some(ast::Instruction::Bra { + arguments: BraArgs { src: ident }, + }), + _ => None, + }), + ) + .parse_next(stream) +} + +fn call<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult>> { + let (uni, return_arguments, name, input_arguments) = ( + opt(Token::DotUni), + opt(( + 
Token::LParen, + separated(1.., ident, Token::Comma).map(|x: Vec<_>| x), + Token::RParen, + Token::Comma, + ) + .map(|(_, arguments, _, _)| arguments)), + ident, + opt(( + Token::Comma.void(), + Token::LParen.void(), + separated(1.., ParsedOperand::<&'input str>::parse, Token::Comma).map(|x: Vec<_>| x), + Token::RParen.void(), + ) + .map(|(_, _, arguments, _)| arguments)), + ) + .parse_next(stream)?; + let uniform = uni.is_some(); + let recorded_fn = match stream.state.function_declarations.get(name) { + Some(decl) => decl, + None => { + stream.state.errors.push(PtxError::UnknownFunction); + return Ok(empty_call(uniform, name)); + } + }; + let return_arguments = return_arguments.unwrap_or(Vec::new()); + let input_arguments = input_arguments.unwrap_or(Vec::new()); + if recorded_fn.0.len() != return_arguments.len() || recorded_fn.1.len() != input_arguments.len() + { + stream.state.errors.push(PtxError::MalformedCall); + return Ok(empty_call(uniform, name)); + } + let data = CallDetails { + uniform, + return_arguments: recorded_fn.0.clone(), + input_arguments: recorded_fn.1.clone(), + }; + let arguments = CallArgs { + return_arguments, + func: name, + input_arguments, + }; + Ok(ast::Instruction::Call { data, arguments }) +} + +fn empty_call<'input>( + uniform: bool, + name: &'input str, +) -> ast::Instruction> { + ast::Instruction::Call { + data: CallDetails { + uniform, + return_arguments: Vec::new(), + input_arguments: Vec::new(), + }, + arguments: CallArgs { + return_arguments: Vec::new(), + func: name, + input_arguments: Vec::new(), + }, + } +} + +type ParsedOperandStr<'input> = ast::ParsedOperand<&'input str>; + +#[derive(Clone, PartialEq, Default, Debug, Display)] +#[display("({}:{})", _0.start, _0.end)] +pub struct TokenError(std::ops::Range); + +impl std::error::Error for TokenError {} + +// This macro is responsible for generating parser code for instruction parser. 
+// Instruction parsing is by far the most complex part of parsing PTX code: +// * There are tens of instruction kinds, each with slightly different parsing rules +// * After parsing, each instruction needs to do some early validation and generate a specific, +// strongly-typed object. We want strong-typing because we have a single PTX parser frontend, but +// there can be multiple different code emitter backends +// * Most importantly, instruction modifiers can come in any order, so e.g. both +// `ld.relaxed.global.u32 a, b` and `ld.global.relaxed.u32 a, b` are equally valid. This makes +// classic parsing generators fail: if we tried to generate parsing rules that cover every possible +// ordering we'd need thousands of rules. This is not a purely theoretical problem. NVCC and Clang +// will always emit modifiers in the correct order, but people who write inline assembly usually +// get it wrong (even first party developers) +// +// This macro exists purely to generate repetitive code for parsing each instruction. It is +// _not_ self-contained and is _not_ general-purpose: it relies on certain types and functions from +// the enclosing module +// +// derive_parser!(...) input is split into three parts: +// * Token type definition +// * Partial enums +// * Parsing definitions +// +// Token type definition: +// This is the enum type that will be used by the instruction parser. For every instruction and +// modifier, derive_parser!(...) will add an appropriate variant into this type. So e.g. if there is a +// rule for `bar.sync` then those two variants will be appended to the Token enum: +// #[token("bar")] Bar, +// #[token(".sync")] DotSync, +// +// Partial enums: +// With proper annotations, derive_parser!(...) parsing definitions are able to interpret +// instruction modifiers as variants of a single enum type. So e.g. for definitions `ld.u32` and +// `ld.u64` the macro can generate `enum ScalarType { U32, U64 }`. 
The problem is that for some +// (but not all) of those generated enum types we want to add some attributes and additional +// variants. In order to do so, you need to define this enum and derive_parser!(...) will append to +// the type instead of creating a new type. This is sort of a replacement for partial classes known +// from C# +// +// Parsing definitions: +// Parsing definitions consist of a list of patterns and rules: +// * Pattern consists of: +// * Opcode: `ld` +// * Modifiers, always start with a dot: `.global`, `.relaxed`. Optionals are enclosed in braces +// * Arguments: `a`, `b`. Optionals are enclosed in braces +// * Code block: => { }. Code blocks implicitly take all modifiers and arguments +// as parameters. All modifiers and arguments are passed to the code block: +// * If it is an alternative (as defined in rules list later): +// * If it is mandatory then its type is Foo (as defined by the relevant rule) +// * If it is optional then its type is Option +// * Otherwise: +// * If it is mandatory then it is skipped +// * If it is optional then its type is `bool` +// * List of rules. They are associated with the preceding patterns (until different opcode or +// different rules). Rules are used to resolve modifiers. There are two types of rules: +// * Normal rule: `.foobar: FoobarEnum => { .a, .b, .c }`. This means that instead of `.foobar` we +// expect one of `.a`, `.b`, `.c` and will emit value FoobarEnum::DotA, FoobarEnum::DotB, +// FoobarEnum::DotC appropriately +// * Type-only rule: `FoobarEnum => { .a, .b, .c }` this means that all the occurrences of `.a` will +// emit FoobarEnum::DotA to the code block. This helps to avoid copy-paste errors +// Additionally, you can opt out from the usual parsing rule generation with a special `<=` pattern. 
+// See `call` instruction to see it in action +derive_parser!( + #[derive(Logos, PartialEq, Eq, Debug, Clone, Copy)] + #[logos(skip r"(?:\s+)|(?://[^\n\r]*[\n\r]*)|(?:/\*[^*]*\*+(?:[^/*][^*]*\*+)*/)")] + #[logos(error = TokenError)] + enum Token<'input> { + #[token(",")] + Comma, + #[token(".")] + Dot, + #[token(":")] + Colon, + #[token(";")] + Semicolon, + #[token("@")] + At, + #[regex(r"[a-zA-Z][a-zA-Z0-9_$]*|[_$%][a-zA-Z0-9_$]+", |lex| lex.slice(), priority = 0)] + Ident(&'input str), + #[regex(r"\.[a-zA-Z][a-zA-Z0-9_$]*|\.[_$%][a-zA-Z0-9_$]+", |lex| lex.slice(), priority = 0)] + DotIdent(&'input str), + #[regex(r#""[^"]*""#)] + String, + #[token("|")] + Pipe, + #[token("!")] + Exclamation, + #[token("(")] + LParen, + #[token(")")] + RParen, + #[token("[")] + LBracket, + #[token("]")] + RBracket, + #[token("{")] + LBrace, + #[token("}")] + RBrace, + #[token("<")] + Lt, + #[token(">")] + Gt, + #[regex(r"0[fF][0-9a-zA-Z]{8}", |lex| lex.slice())] + F32(&'input str), + #[regex(r"0[dD][0-9a-zA-Z]{16}", |lex| lex.slice())] + F64(&'input str), + #[regex(r"0[xX][0-9a-zA-Z]+U?", |lex| lex.slice())] + Hex(&'input str), + #[regex(r"[0-9]+U?", |lex| lex.slice())] + Decimal(&'input str), + #[token("-")] + Minus, + #[token("+")] + Plus, + #[token("=")] + Eq, + #[token(".version")] + DotVersion, + #[token(".loc")] + DotLoc, + #[token(".reg")] + DotReg, + #[token(".align")] + DotAlign, + #[token(".pragma")] + DotPragma, + #[token(".maxnreg")] + DotMaxnreg, + #[token(".maxntid")] + DotMaxntid, + #[token(".reqntid")] + DotReqntid, + #[token(".minnctapersm")] + DotMinnctapersm, + #[token(".entry")] + DotEntry, + #[token(".func")] + DotFunc, + #[token(".extern")] + DotExtern, + #[token(".visible")] + DotVisible, + #[token(".target")] + DotTarget, + #[token(".address_size")] + DotAddressSize, + #[token(".action")] + DotSection, + #[token(".file")] + DotFile + } + + #[derive(Copy, Clone, PartialEq, Eq, Hash)] + pub enum StateSpace { + Reg, + Generic, + Sreg, + } + + #[derive(Copy, 
Clone, PartialEq, Eq, Hash)] + pub enum MemScope { } + + #[derive(Copy, Clone, PartialEq, Eq, Hash)] + pub enum ScalarType { } + + #[derive(Copy, Clone, PartialEq, Eq, Hash)] + pub enum SetpBoolPostOp { } + + #[derive(Copy, Clone, PartialEq, Eq, Hash)] + pub enum AtomSemantics { } + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov + mov{.vec}.type d, a => { + Instruction::Mov { + data: ast::MovDetails::new(vec, type_), + arguments: MovArgs { dst: d, src: a }, + } + } + .vec: VectorPrefix = { .v2, .v4 }; + .type: ScalarType = { .pred, + .b16, .b32, .b64, + .u16, .u32, .u64, + .s16, .s32, .s64, + .f32, .f64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-st + st{.weak}{.ss}{.cop}{.level::eviction_priority}{.level::cache_hint}{.vec}.type [a], b{, cache_policy} => { + if level_eviction_priority.is_some() || level_cache_hint || cache_policy.is_some() { + state.errors.push(PtxError::Todo); + } + Instruction::St { + data: StData { + qualifier: weak.unwrap_or(RawLdStQualifier::Weak).into(), + state_space: ss.unwrap_or(StateSpace::Generic), + caching: cop.unwrap_or(RawStCacheOperator::Wb).into(), + typ: ast::Type::maybe_vector(vec, type_) + }, + arguments: StArgs { src1:a, src2:b } + } + } + st.volatile{.ss}{.vec}.type [a], b => { + Instruction::St { + data: StData { + qualifier: volatile.into(), + state_space: ss.unwrap_or(StateSpace::Generic), + caching: ast::StCacheOperator::Writeback, + typ: ast::Type::maybe_vector(vec, type_) + }, + arguments: StArgs { src1:a, src2:b } + } + } + st.relaxed.scope{.ss}{.level::eviction_priority}{.level::cache_hint}{.vec}.type [a], b{, cache_policy} => { + if level_eviction_priority.is_some() || level_cache_hint || cache_policy.is_some() { + state.errors.push(PtxError::Todo); + } + Instruction::St { + data: StData { + qualifier: ast::LdStQualifier::Relaxed(scope), + state_space: ss.unwrap_or(StateSpace::Generic), 
+ caching: ast::StCacheOperator::Writeback, + typ: ast::Type::maybe_vector(vec, type_) + }, + arguments: StArgs { src1:a, src2:b } + } + } + st.release.scope{.ss}{.level::eviction_priority}{.level::cache_hint}{.vec}.type [a], b{, cache_policy} => { + if level_eviction_priority.is_some() || level_cache_hint || cache_policy.is_some() { + state.errors.push(PtxError::Todo); + } + Instruction::St { + data: StData { + qualifier: ast::LdStQualifier::Release(scope), + state_space: ss.unwrap_or(StateSpace::Generic), + caching: ast::StCacheOperator::Writeback, + typ: ast::Type::maybe_vector(vec, type_) + }, + arguments: StArgs { src1:a, src2:b } + } + } + st.mmio.relaxed.sys{.global}.type [a], b => { + state.errors.push(PtxError::Todo); + Instruction::St { + data: ast::StData { + qualifier: ast::LdStQualifier::Relaxed(MemScope::Sys), + state_space: global.unwrap_or(StateSpace::Generic), + caching: ast::StCacheOperator::Writeback, + typ: type_.into() + }, + arguments: ast::StArgs { src1:a, src2:b } + } + } + .ss: StateSpace = { .global, .local, .param{::func}, .shared{::cta, ::cluster} }; + .level::eviction_priority: EvictionPriority = + { .L1::evict_normal, .L1::evict_unchanged, .L1::evict_first, .L1::evict_last, .L1::no_allocate }; + .level::cache_hint = { .L2::cache_hint }; + .cop: RawStCacheOperator = { .wb, .cg, .cs, .wt }; + .scope: MemScope = { .cta, .cluster, .gpu, .sys }; + .vec: VectorPrefix = { .v2, .v4 }; + .type: ScalarType = { .b8, .b16, .b32, .b64, .b128, + .u8, .u16, .u32, .u64, + .s8, .s16, .s32, .s64, + .f32, .f64 }; + RawLdStQualifier = { .weak, .volatile }; + StateSpace = { .global }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-ld + ld{.weak}{.ss}{.cop}{.level::eviction_priority}{.level::cache_hint}{.level::prefetch_size}{.vec}.type d, [a]{.unified}{, cache_policy} => { + let (a, unified) = a; + if level_eviction_priority.is_some() || level_cache_hint || level_prefetch_size.is_some() || unified 
|| cache_policy.is_some() { + state.errors.push(PtxError::Todo); + } + Instruction::Ld { + data: LdDetails { + qualifier: weak.unwrap_or(RawLdStQualifier::Weak).into(), + state_space: ss.unwrap_or(StateSpace::Generic), + caching: cop.unwrap_or(RawLdCacheOperator::Ca).into(), + typ: ast::Type::maybe_vector(vec, type_), + non_coherent: false + }, + arguments: LdArgs { dst:d, src:a } + } + } + ld.volatile{.ss}{.level::prefetch_size}{.vec}.type d, [a] => { + if level_prefetch_size.is_some() { + state.errors.push(PtxError::Todo); + } + Instruction::Ld { + data: LdDetails { + qualifier: volatile.into(), + state_space: ss.unwrap_or(StateSpace::Generic), + caching: ast::LdCacheOperator::Cached, + typ: ast::Type::maybe_vector(vec, type_), + non_coherent: false + }, + arguments: LdArgs { dst:d, src:a } + } + } + ld.relaxed.scope{.ss}{.level::eviction_priority}{.level::cache_hint}{.level::prefetch_size}{.vec}.type d, [a]{, cache_policy} => { + if level_eviction_priority.is_some() || level_cache_hint || level_prefetch_size.is_some() || cache_policy.is_some() { + state.errors.push(PtxError::Todo); + } + Instruction::Ld { + data: LdDetails { + qualifier: ast::LdStQualifier::Relaxed(scope), + state_space: ss.unwrap_or(StateSpace::Generic), + caching: ast::LdCacheOperator::Cached, + typ: ast::Type::maybe_vector(vec, type_), + non_coherent: false + }, + arguments: LdArgs { dst:d, src:a } + } + } + ld.acquire.scope{.ss}{.level::eviction_priority}{.level::cache_hint}{.level::prefetch_size}{.vec}.type d, [a]{, cache_policy} => { + if level_eviction_priority.is_some() || level_cache_hint || level_prefetch_size.is_some() || cache_policy.is_some() { + state.errors.push(PtxError::Todo); + } + Instruction::Ld { + data: LdDetails { + qualifier: ast::LdStQualifier::Acquire(scope), + state_space: ss.unwrap_or(StateSpace::Generic), + caching: ast::LdCacheOperator::Cached, + typ: ast::Type::maybe_vector(vec, type_), + non_coherent: false + }, + arguments: LdArgs { dst:d, src:a } + } + } + 
ld.mmio.relaxed.sys{.global}.type d, [a] => { + state.errors.push(PtxError::Todo); + Instruction::Ld { + data: LdDetails { + qualifier: ast::LdStQualifier::Relaxed(MemScope::Sys), + state_space: global.unwrap_or(StateSpace::Generic), + caching: ast::LdCacheOperator::Cached, + typ: type_.into(), + non_coherent: false + }, + arguments: LdArgs { dst:d, src:a } + } + } + .ss: StateSpace = { .const, .global, .local, .param{::entry, ::func}, .shared{::cta, ::cluster} }; + .cop: RawLdCacheOperator = { .ca, .cg, .cs, .lu, .cv }; + .level::eviction_priority: EvictionPriority = + { .L1::evict_normal, .L1::evict_unchanged, .L1::evict_first, .L1::evict_last, .L1::no_allocate }; + .level::cache_hint = { .L2::cache_hint }; + .level::prefetch_size: PrefetchSize = { .L2::64B, .L2::128B, .L2::256B }; + .scope: MemScope = { .cta, .cluster, .gpu, .sys }; + .vec: VectorPrefix = { .v2, .v4 }; + .type: ScalarType = { .b8, .b16, .b32, .b64, .b128, + .u8, .u16, .u32, .u64, + .s8, .s16, .s32, .s64, + .f32, .f64 }; + RawLdStQualifier = { .weak, .volatile }; + StateSpace = { .global }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld-global-nc + ld.global{.cop}.nc{.level::eviction_priority}{.level::cache_hint}{.level::prefetch_size}{.vec}.type d, [a]{, cache_policy} => { + if cop.is_some() && level_eviction_priority.is_some() { + state.errors.push(PtxError::SyntaxError); + } + if level_eviction_priority.is_some() || level_cache_hint || level_prefetch_size.is_some() || cache_policy.is_some() { + state.errors.push(PtxError::Todo); + } + Instruction::Ld { + data: LdDetails { + qualifier: ast::LdStQualifier::Weak, + state_space: global, + caching: cop.unwrap_or(RawLdCacheOperator::Ca).into(), + typ: Type::maybe_vector(vec, type_), + non_coherent: true + }, + arguments: LdArgs { dst:d, src:a } + } + } + .cop: RawLdCacheOperator = { .ca, .cg, .cs }; + .level::eviction_priority: EvictionPriority = + { .L1::evict_normal, 
.L1::evict_unchanged, + .L1::evict_first, .L1::evict_last, .L1::no_allocate}; + .level::cache_hint = { .L2::cache_hint }; + .level::prefetch_size: PrefetchSize = { .L2::64B, .L2::128B, .L2::256B }; + .vec: VectorPrefix = { .v2, .v4 }; + .type: ScalarType = { .b8, .b16, .b32, .b64, .b128, + .u8, .u16, .u32, .u64, + .s8, .s16, .s32, .s64, + .f32, .f64 }; + StateSpace = { .global }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/#integer-arithmetic-instructions-add + add.type d, a, b => { + Instruction::Add { + data: ast::ArithDetails::Integer( + ast::ArithInteger { + type_, + saturate: false + } + ), + arguments: AddArgs { + dst: d, src1: a, src2: b + } + } + } + add{.sat}.s32 d, a, b => { + Instruction::Add { + data: ast::ArithDetails::Integer( + ast::ArithInteger { + type_: s32, + saturate: sat + } + ), + arguments: AddArgs { + dst: d, src1: a, src2: b + } + } + } + .type: ScalarType = { .u16, .u32, .u64, + .s16, .s64, + .u16x2, .s16x2 }; + ScalarType = { .s32 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/#floating-point-instructions-add + add{.rnd}{.ftz}{.sat}.f32 d, a, b => { + Instruction::Add { + data: ast::ArithDetails::Float( + ast::ArithFloat { + type_: f32, + rounding: rnd.map(Into::into), + flush_to_zero: Some(ftz), + saturate: sat + } + ), + arguments: AddArgs { + dst: d, src1: a, src2: b + } + } + } + add{.rnd}.f64 d, a, b => { + Instruction::Add { + data: ast::ArithDetails::Float( + ast::ArithFloat { + type_: f64, + rounding: rnd.map(Into::into), + flush_to_zero: None, + saturate: false + } + ), + arguments: AddArgs { + dst: d, src1: a, src2: b + } + } + } + .rnd: RawRoundingMode = { .rn, .rz, .rm, .rp }; + ScalarType = { .f32, .f64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/#half-precision-floating-point-instructions-add + add{.rnd}{.ftz}{.sat}.f16 d, a, b => { + Instruction::Add { + data: ast::ArithDetails::Float( + ast::ArithFloat { + type_: f16, + rounding: rnd.map(Into::into), + flush_to_zero: 
Some(ftz), + saturate: sat + } + ), + arguments: AddArgs { + dst: d, src1: a, src2: b + } + } + } + add{.rnd}{.ftz}{.sat}.f16x2 d, a, b => { + Instruction::Add { + data: ast::ArithDetails::Float( + ast::ArithFloat { + type_: f16x2, + rounding: rnd.map(Into::into), + flush_to_zero: Some(ftz), + saturate: sat + } + ), + arguments: AddArgs { + dst: d, src1: a, src2: b + } + } + } + add{.rnd}.bf16 d, a, b => { + Instruction::Add { + data: ast::ArithDetails::Float( + ast::ArithFloat { + type_: bf16, + rounding: rnd.map(Into::into), + flush_to_zero: None, + saturate: false + } + ), + arguments: AddArgs { + dst: d, src1: a, src2: b + } + } + } + add{.rnd}.bf16x2 d, a, b => { + Instruction::Add { + data: ast::ArithDetails::Float( + ast::ArithFloat { + type_: bf16x2, + rounding: rnd.map(Into::into), + flush_to_zero: None, + saturate: false + } + ), + arguments: AddArgs { + dst: d, src1: a, src2: b + } + } + } + .rnd: RawRoundingMode = { .rn }; + ScalarType = { .f16, .f16x2, .bf16, .bf16x2 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-mul + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-mul + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-mul + mul.mode.type d, a, b => { + ast::Instruction::Mul { + data: ast::MulDetails::Integer { + type_, + control: mode.into() + }, + arguments: MulArgs { dst: d, src1: a, src2: b } + } + } + .mode: RawMulIntControl = { .hi, .lo }; + .type: ScalarType = { .u16, .u32, .u64, + .s16, .s32, .s64 }; + // "The .wide suffix is supported only for 16- and 32-bit integer types" + mul.wide.type d, a, b => { + ast::Instruction::Mul { + data: ast::MulDetails::Integer { + type_, + control: wide.into() + }, + arguments: MulArgs { dst: d, src1: a, src2: b } + } + } + .type: ScalarType = { .u16, .u32, + .s16, .s32 }; + RawMulIntControl = { .wide }; + + mul{.rnd}{.ftz}{.sat}.f32 d, a, 
b => { + ast::Instruction::Mul { + data: ast::MulDetails::Float ( + ast::ArithFloat { + type_: f32, + rounding: rnd.map(Into::into), + flush_to_zero: Some(ftz), + saturate: sat, + } + ), + arguments: MulArgs { dst: d, src1: a, src2: b } + } + } + mul{.rnd}.f64 d, a, b => { + ast::Instruction::Mul { + data: ast::MulDetails::Float ( + ast::ArithFloat { + type_: f64, + rounding: rnd.map(Into::into), + flush_to_zero: None, + saturate: false, + } + ), + arguments: MulArgs { dst: d, src1: a, src2: b } + } + } + .rnd: RawRoundingMode = { .rn, .rz, .rm, .rp }; + ScalarType = { .f32, .f64 }; + + mul{.rnd}{.ftz}{.sat}.f16 d, a, b => { + ast::Instruction::Mul { + data: ast::MulDetails::Float ( + ast::ArithFloat { + type_: f16, + rounding: rnd.map(Into::into), + flush_to_zero: Some(ftz), + saturate: sat, + } + ), + arguments: MulArgs { dst: d, src1: a, src2: b } + } + } + mul{.rnd}{.ftz}{.sat}.f16x2 d, a, b => { + ast::Instruction::Mul { + data: ast::MulDetails::Float ( + ast::ArithFloat { + type_: f16x2, + rounding: rnd.map(Into::into), + flush_to_zero: Some(ftz), + saturate: sat, + } + ), + arguments: MulArgs { dst: d, src1: a, src2: b } + } + } + mul{.rnd}.bf16 d, a, b => { + ast::Instruction::Mul { + data: ast::MulDetails::Float ( + ast::ArithFloat { + type_: bf16, + rounding: rnd.map(Into::into), + flush_to_zero: None, + saturate: false, + } + ), + arguments: MulArgs { dst: d, src1: a, src2: b } + } + } + mul{.rnd}.bf16x2 d, a, b => { + ast::Instruction::Mul { + data: ast::MulDetails::Float ( + ast::ArithFloat { + type_: bf16x2, + rounding: rnd.map(Into::into), + flush_to_zero: None, + saturate: false, + } + ), + arguments: MulArgs { dst: d, src1: a, src2: b } + } + } + .rnd: RawRoundingMode = { .rn }; + ScalarType = { .f16, .f16x2, .bf16, .bf16x2 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#comparison-and-selection-instructions-setp + // 
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-comparison-instructions-setp + setp.CmpOp{.ftz}.type p[|q], a, b => { + let data = ast::SetpData::try_parse(state, cmpop, ftz, type_); + ast::Instruction::Setp { + data, + arguments: SetpArgs { dst1: p, dst2: q, src1: a, src2: b } + } + } + setp.CmpOp.BoolOp{.ftz}.type p[|q], a, b, {!}c => { + let (negate_src3, c) = c; + let base = ast::SetpData::try_parse(state, cmpop, ftz, type_); + let data = ast::SetpBoolData { + base, + bool_op: boolop, + negate_src3 + }; + ast::Instruction::SetpBool { + data, + arguments: SetpBoolArgs { dst1: p, dst2: q, src1: a, src2: b, src3: c } + } + } + .CmpOp: RawSetpCompareOp = { .eq, .ne, .lt, .le, .gt, .ge, + .lo, .ls, .hi, .hs, // signed + .equ, .neu, .ltu, .leu, .gtu, .geu, .num, .nan }; // float-only + .BoolOp: SetpBoolPostOp = { .and, .or, .xor }; + .type: ScalarType = { .b16, .b32, .b64, + .u16, .u32, .u64, + .s16, .s32, .s64, + .f32, .f64, + .f16, .f16x2, .bf16, .bf16x2 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-not + not.type d, a => { + ast::Instruction::Not { + data: type_, + arguments: NotArgs { dst: d, src: a } + } + } + .type: ScalarType = { .pred, .b16, .b32, .b64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-or + or.type d, a, b => { + ast::Instruction::Or { + data: type_, + arguments: OrArgs { dst: d, src1: a, src2: b } + } + } + .type: ScalarType = { .pred, .b16, .b32, .b64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-and + and.type d, a, b => { + ast::Instruction::And { + data: type_, + arguments: AndArgs { dst: d, src1: a, src2: b } + } + } + .type: ScalarType = { .pred, .b16, .b32, .b64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-bra + bra <= { bra(stream) } + + // 
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-call + call <= { call(stream) } + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt + cvt{.ifrnd}{.ftz}{.sat}.dtype.atype d, a => { + let data = ast::CvtDetails::new(&mut state.errors, ifrnd, ftz, sat, dtype, atype); + let arguments = ast::CvtArgs { dst: d, src: a }; + ast::Instruction::Cvt { + data, arguments + } + } + // cvt.frnd2{.relu}{.satfinite}.f16.f32 d, a; + // cvt.frnd2{.relu}{.satfinite}.f16x2.f32 d, a, b; + // cvt.frnd2{.relu}{.satfinite}.bf16.f32 d, a; + // cvt.frnd2{.relu}{.satfinite}.bf16x2.f32 d, a, b; + // cvt.rna{.satfinite}.tf32.f32 d, a; + // cvt.frnd2{.relu}.tf32.f32 d, a; + // cvt.rn.satfinite{.relu}.f8x2type.f32 d, a, b; + // cvt.rn.satfinite{.relu}.f8x2type.f16x2 d, a; + // cvt.rn.{.relu}.f16x2.f8x2type d, a; + + .ifrnd: RawRoundingMode = { .rn, .rz, .rm, .rp, .rni, .rzi, .rmi, .rpi }; + .frnd2: RawRoundingMode = { .rn, .rz }; + .dtype: ScalarType = { .u8, .u16, .u32, .u64, + .s8, .s16, .s32, .s64, + .bf16, .f16, .f32, .f64 }; + .atype: ScalarType = { .u8, .u16, .u32, .u64, + .s8, .s16, .s32, .s64, + .bf16, .f16, .f32, .f64 }; + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-shl + shl.type d, a, b => { + ast::Instruction::Shl { data: type_, arguments: ShlArgs { dst: d, src1: a, src2: b } } + } + .type: ScalarType = { .b16, .b32, .b64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-shr + shr.type d, a, b => { + let kind = if type_.kind() == ast::ScalarKind::Signed { RightShiftKind::Arithmetic} else { RightShiftKind::Logical }; + ast::Instruction::Shr { + data: ast::ShrData { type_, kind }, + arguments: ShrArgs { dst: d, src1: a, src2: b } + } + } + .type: ScalarType = { .b16, .b32, .b64, + .u16, .u32, .u64, + .s16, .s32, .s64 }; + + // 
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvta + cvta.space.size p, a => { + if size != ScalarType::U64 { + state.errors.push(PtxError::Unsupported32Bit); + } + let data = ast::CvtaDetails { + state_space: space, + direction: ast::CvtaDirection::ExplicitToGeneric + }; + let arguments = ast::CvtaArgs { + dst: p, src: a + }; + ast::Instruction::Cvta { + data, arguments + } + } + cvta.to.space.size p, a => { + if size != ScalarType::U64 { + state.errors.push(PtxError::Unsupported32Bit); + } + let data = ast::CvtaDetails { + state_space: space, + direction: ast::CvtaDirection::GenericToExplicit + }; + let arguments = ast::CvtaArgs { + dst: p, src: a + }; + ast::Instruction::Cvta { + data, arguments + } + } + .space: StateSpace = { .const, .global, .local, .shared{::cta, ::cluster}, .param{::entry} }; + .size: ScalarType = { .u32, .u64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-abs + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-abs + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-abs + abs.type d, a => { + ast::Instruction::Abs { + data: ast::TypeFtz { + flush_to_zero: None, + type_ + }, + arguments: ast::AbsArgs { + dst: d, src: a + } + } + } + abs{.ftz}.f32 d, a => { + ast::Instruction::Abs { + data: ast::TypeFtz { + flush_to_zero: Some(ftz), + type_: f32 + }, + arguments: ast::AbsArgs { + dst: d, src: a + } + } + } + abs.f64 d, a => { + ast::Instruction::Abs { + data: ast::TypeFtz { + flush_to_zero: None, + type_: f64 + }, + arguments: ast::AbsArgs { + dst: d, src: a + } + } + } + abs{.ftz}.f16 d, a => { + ast::Instruction::Abs { + data: ast::TypeFtz { + flush_to_zero: Some(ftz), + type_: f16 + }, + arguments: ast::AbsArgs { + dst: d, src: a + } + } + } + abs{.ftz}.f16x2 d, a => { + ast::Instruction::Abs { + data: 
ast::TypeFtz { + flush_to_zero: Some(ftz), + type_: f16x2 + }, + arguments: ast::AbsArgs { + dst: d, src: a + } + } + } + abs.bf16 d, a => { + ast::Instruction::Abs { + data: ast::TypeFtz { + flush_to_zero: None, + type_: bf16 + }, + arguments: ast::AbsArgs { + dst: d, src: a + } + } + } + abs.bf16x2 d, a => { + ast::Instruction::Abs { + data: ast::TypeFtz { + flush_to_zero: None, + type_: bf16x2 + }, + arguments: ast::AbsArgs { + dst: d, src: a + } + } + } + .type: ScalarType = { .s16, .s32, .s64 }; + ScalarType = { .f32, .f64, .f16, .f16x2, .bf16, .bf16x2 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-mad + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-mad + mad.mode.type d, a, b, c => { + ast::Instruction::Mad { + data: ast::MadDetails::Integer { + type_, + control: mode.into(), + saturate: false + }, + arguments: MadArgs { dst: d, src1: a, src2: b, src3: c } + } + } + .type: ScalarType = { .u16, .u32, .u64, + .s16, .s32, .s64 }; + .mode: RawMulIntControl = { .hi, .lo }; + + // The .wide suffix is supported only for 16-bit and 32-bit integer types. 
+ mad.wide.type d, a, b, c => { + ast::Instruction::Mad { + data: ast::MadDetails::Integer { + type_, + control: wide.into(), + saturate: false + }, + arguments: MadArgs { dst: d, src1: a, src2: b, src3: c } + } + } + .type: ScalarType = { .u16, .u32, + .s16, .s32 }; + RawMulIntControl = { .wide }; + + mad.hi.sat.s32 d, a, b, c => { + ast::Instruction::Mad { + data: ast::MadDetails::Integer { + type_: s32, + control: hi.into(), + saturate: true + }, + arguments: MadArgs { dst: d, src1: a, src2: b, src3: c } + } + } + RawMulIntControl = { .hi }; + ScalarType = { .s32 }; + + mad{.ftz}{.sat}.f32 d, a, b, c => { + ast::Instruction::Mad { + data: ast::MadDetails::Float( + ast::ArithFloat { + type_: f32, + rounding: None, + flush_to_zero: Some(ftz), + saturate: sat + } + ), + arguments: MadArgs { dst: d, src1: a, src2: b, src3: c } + } + } + mad.rnd{.ftz}{.sat}.f32 d, a, b, c => { + ast::Instruction::Mad { + data: ast::MadDetails::Float( + ast::ArithFloat { + type_: f32, + rounding: Some(rnd.into()), + flush_to_zero: Some(ftz), + saturate: sat + } + ), + arguments: MadArgs { dst: d, src1: a, src2: b, src3: c } + } + } + mad.rnd.f64 d, a, b, c => { + ast::Instruction::Mad { + data: ast::MadDetails::Float( + ast::ArithFloat { + type_: f64, + rounding: Some(rnd.into()), + flush_to_zero: None, + saturate: false + } + ), + arguments: MadArgs { dst: d, src1: a, src2: b, src3: c } + } + } + .rnd: RawRoundingMode = { .rn, .rz, .rm, .rp }; + ScalarType = { .f32, .f64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-fma + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-fma + fma.rnd{.ftz}{.sat}.f32 d, a, b, c => { + ast::Instruction::Fma { + data: ast::ArithFloat { + type_: f32, + rounding: Some(rnd.into()), + flush_to_zero: Some(ftz), + saturate: sat + }, + arguments: FmaArgs { dst: d, src1: a, src2: b, src3: c } + } + } + fma.rnd.f64 d, a, b, c => { + 
ast::Instruction::Fma { + data: ast::ArithFloat { + type_: f64, + rounding: Some(rnd.into()), + flush_to_zero: None, + saturate: false + }, + arguments: FmaArgs { dst: d, src1: a, src2: b, src3: c } + } + } + .rnd: RawRoundingMode = { .rn, .rz, .rm, .rp }; + ScalarType = { .f32, .f64 }; + + fma.rnd{.ftz}{.sat}.f16 d, a, b, c => { + ast::Instruction::Fma { + data: ast::ArithFloat { + type_: f16, + rounding: Some(rnd.into()), + flush_to_zero: Some(ftz), + saturate: sat + }, + arguments: FmaArgs { dst: d, src1: a, src2: b, src3: c } + } + } + //fma.rnd{.ftz}{.sat}.f16x2 d, a, b, c; + //fma.rnd{.ftz}.relu.f16 d, a, b, c; + //fma.rnd{.ftz}.relu.f16x2 d, a, b, c; + //fma.rnd{.relu}.bf16 d, a, b, c; + //fma.rnd{.relu}.bf16x2 d, a, b, c; + //fma.rnd.oob.{relu}.type d, a, b, c; + .rnd: RawRoundingMode = { .rn }; + ScalarType = { .f16 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-sub + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-sub + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-sub + sub.type d, a, b => { + ast::Instruction::Sub { + data: ast::ArithDetails::Integer( + ArithInteger { + type_, + saturate: false + } + ), + arguments: SubArgs { dst: d, src1: a, src2: b } + } + } + sub.sat.s32 d, a, b => { + ast::Instruction::Sub { + data: ast::ArithDetails::Integer( + ArithInteger { + type_: s32, + saturate: true + } + ), + arguments: SubArgs { dst: d, src1: a, src2: b } + } + } + .type: ScalarType = { .u16, .u32, .u64, + .s16, .s32, .s64 }; + ScalarType = { .s32 }; + + sub{.rnd}{.ftz}{.sat}.f32 d, a, b => { + ast::Instruction::Sub { + data: ast::ArithDetails::Float( + ast::ArithFloat { + type_: f32, + rounding: rnd.map(Into::into), + flush_to_zero: Some(ftz), + saturate: sat + } + ), + arguments: SubArgs { dst: d, src1: a, src2: b } + } + } + sub{.rnd}.f64 d, a, b => { + ast::Instruction::Sub { + 
data: ast::ArithDetails::Float( + ast::ArithFloat { + type_: f64, + rounding: rnd.map(Into::into), + flush_to_zero: None, + saturate: false + } + ), + arguments: SubArgs { dst: d, src1: a, src2: b } + } + } + .rnd: RawRoundingMode = { .rn, .rz, .rm, .rp }; + ScalarType = { .f32, .f64 }; + + sub{.rnd}{.ftz}{.sat}.f16 d, a, b => { + ast::Instruction::Sub { + data: ast::ArithDetails::Float( + ast::ArithFloat { + type_: f16, + rounding: rnd.map(Into::into), + flush_to_zero: Some(ftz), + saturate: sat + } + ), + arguments: SubArgs { dst: d, src1: a, src2: b } + } + } + sub{.rnd}{.ftz}{.sat}.f16x2 d, a, b => { + ast::Instruction::Sub { + data: ast::ArithDetails::Float( + ast::ArithFloat { + type_: f16x2, + rounding: rnd.map(Into::into), + flush_to_zero: Some(ftz), + saturate: sat + } + ), + arguments: SubArgs { dst: d, src1: a, src2: b } + } + } + sub{.rnd}.bf16 d, a, b => { + ast::Instruction::Sub { + data: ast::ArithDetails::Float( + ast::ArithFloat { + type_: bf16, + rounding: rnd.map(Into::into), + flush_to_zero: None, + saturate: false + } + ), + arguments: SubArgs { dst: d, src1: a, src2: b } + } + } + sub{.rnd}.bf16x2 d, a, b => { + ast::Instruction::Sub { + data: ast::ArithDetails::Float( + ast::ArithFloat { + type_: bf16x2, + rounding: rnd.map(Into::into), + flush_to_zero: None, + saturate: false + } + ), + arguments: SubArgs { dst: d, src1: a, src2: b } + } + } + .rnd: RawRoundingMode = { .rn }; + ScalarType = { .f16, .f16x2, .bf16, .bf16x2 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-min + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-min + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-min + min.atype d, a, b => { + ast::Instruction::Min { + data: if atype.kind() == ast::ScalarKind::Signed { + ast::MinMaxDetails::Signed(atype) + } else { + ast::MinMaxDetails::Unsigned(atype) + }, + 
arguments: MinArgs { dst: d, src1: a, src2: b } + } + } + //min{.relu}.btype d, a, b => { todo!() } + min.btype d, a, b => { + ast::Instruction::Min { + data: ast::MinMaxDetails::Signed(btype), + arguments: MinArgs { dst: d, src1: a, src2: b } + } + } + .atype: ScalarType = { .u16, .u32, .u64, + .u16x2, .s16, .s64 }; + .btype: ScalarType = { .s16x2, .s32 }; + + //min{.ftz}{.NaN}{.xorsign.abs}.f32 d, a, b; + min{.ftz}{.NaN}.f32 d, a, b => { + ast::Instruction::Min { + data: ast::MinMaxDetails::Float( + MinMaxFloat { + flush_to_zero: Some(ftz), + nan, + type_: f32 + } + ), + arguments: MinArgs { dst: d, src1: a, src2: b } + } + } + min.f64 d, a, b => { + ast::Instruction::Min { + data: ast::MinMaxDetails::Float( + MinMaxFloat { + flush_to_zero: None, + nan: false, + type_: f64 + } + ), + arguments: MinArgs { dst: d, src1: a, src2: b } + } + } + ScalarType = { .f32, .f64 }; + + //min{.ftz}{.NaN}{.xorsign.abs}.f16 d, a, b; + //min{.ftz}{.NaN}{.xorsign.abs}.f16x2 d, a, b; + //min{.NaN}{.xorsign.abs}.bf16 d, a, b; + //min{.NaN}{.xorsign.abs}.bf16x2 d, a, b; + min{.ftz}{.NaN}.f16 d, a, b => { + ast::Instruction::Min { + data: ast::MinMaxDetails::Float( + MinMaxFloat { + flush_to_zero: Some(ftz), + nan, + type_: f16 + } + ), + arguments: MinArgs { dst: d, src1: a, src2: b } + } + } + min{.ftz}{.NaN}.f16x2 d, a, b => { + ast::Instruction::Min { + data: ast::MinMaxDetails::Float( + MinMaxFloat { + flush_to_zero: Some(ftz), + nan, + type_: f16x2 + } + ), + arguments: MinArgs { dst: d, src1: a, src2: b } + } + } + min{.NaN}.bf16 d, a, b => { + ast::Instruction::Min { + data: ast::MinMaxDetails::Float( + MinMaxFloat { + flush_to_zero: None, + nan, + type_: bf16 + } + ), + arguments: MinArgs { dst: d, src1: a, src2: b } + } + } + min{.NaN}.bf16x2 d, a, b => { + ast::Instruction::Min { + data: ast::MinMaxDetails::Float( + MinMaxFloat { + flush_to_zero: None, + nan, + type_: bf16x2 + } + ), + arguments: MinArgs { dst: d, src1: a, src2: b } + } + } + ScalarType = { .f16, .f16x2, 
.bf16, .bf16x2 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-max + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-max + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-max + max.atype d, a, b => { + ast::Instruction::Max { + data: if atype.kind() == ast::ScalarKind::Signed { + ast::MinMaxDetails::Signed(atype) + } else { + ast::MinMaxDetails::Unsigned(atype) + }, + arguments: MaxArgs { dst: d, src1: a, src2: b } + } + } + //max{.relu}.btype d, a, b => { todo!() } + max.btype d, a, b => { + ast::Instruction::Max { + data: ast::MinMaxDetails::Signed(btype), + arguments: MaxArgs { dst: d, src1: a, src2: b } + } + } + .atype: ScalarType = { .u16, .u32, .u64, + .u16x2, .s16, .s64 }; + .btype: ScalarType = { .s16x2, .s32 }; + + //max{.ftz}{.NaN}{.xorsign.abs}.f32 d, a, b; + max{.ftz}{.NaN}.f32 d, a, b => { + ast::Instruction::Max { + data: ast::MinMaxDetails::Float( + MinMaxFloat { + flush_to_zero: Some(ftz), + nan, + type_: f32 + } + ), + arguments: MaxArgs { dst: d, src1: a, src2: b } + } + } + max.f64 d, a, b => { + ast::Instruction::Max { + data: ast::MinMaxDetails::Float( + MinMaxFloat { + flush_to_zero: None, + nan: false, + type_: f64 + } + ), + arguments: MaxArgs { dst: d, src1: a, src2: b } + } + } + ScalarType = { .f32, .f64 }; + + //max{.ftz}{.NaN}{.xorsign.abs}.f16 d, a, b; + //max{.ftz}{.NaN}{.xorsign.abs}.f16x2 d, a, b; + //max{.NaN}{.xorsign.abs}.bf16 d, a, b; + //max{.NaN}{.xorsign.abs}.bf16x2 d, a, b; + max{.ftz}{.NaN}.f16 d, a, b => { + ast::Instruction::Max { + data: ast::MinMaxDetails::Float( + MinMaxFloat { + flush_to_zero: Some(ftz), + nan, + type_: f16 + } + ), + arguments: MaxArgs { dst: d, src1: a, src2: b } + } + } + max{.ftz}{.NaN}.f16x2 d, a, b => { + ast::Instruction::Max { + data: ast::MinMaxDetails::Float( + MinMaxFloat { + flush_to_zero: Some(ftz), + nan, + type_: 
f16x2 + } + ), + arguments: MaxArgs { dst: d, src1: a, src2: b } + } + } + max{.NaN}.bf16 d, a, b => { + ast::Instruction::Max { + data: ast::MinMaxDetails::Float( + MinMaxFloat { + flush_to_zero: None, + nan, + type_: bf16 + } + ), + arguments: MaxArgs { dst: d, src1: a, src2: b } + } + } + max{.NaN}.bf16x2 d, a, b => { + ast::Instruction::Max { + data: ast::MinMaxDetails::Float( + MinMaxFloat { + flush_to_zero: None, + nan, + type_: bf16x2 + } + ), + arguments: MaxArgs { dst: d, src1: a, src2: b } + } + } + ScalarType = { .f16, .f16x2, .bf16, .bf16x2 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-rcp + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-rcp-approx-ftz-f64 + rcp.approx{.ftz}.type d, a => { + ast::Instruction::Rcp { + data: ast::RcpData { + kind: ast::RcpKind::Approx, + flush_to_zero: Some(ftz), + type_ + }, + arguments: RcpArgs { dst: d, src: a } + } + } + rcp.rnd{.ftz}.f32 d, a => { + ast::Instruction::Rcp { + data: ast::RcpData { + kind: ast::RcpKind::Compliant(rnd.into()), + flush_to_zero: Some(ftz), + type_: f32 + }, + arguments: RcpArgs { dst: d, src: a } + } + } + rcp.rnd.f64 d, a => { + ast::Instruction::Rcp { + data: ast::RcpData { + kind: ast::RcpKind::Compliant(rnd.into()), + flush_to_zero: None, + type_: f64 + }, + arguments: RcpArgs { dst: d, src: a } + } + } + .type: ScalarType = { .f32, .f64 }; + .rnd: RawRoundingMode = { .rn, .rz, .rm, .rp }; + ScalarType = { .f32, .f64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-sqrt + sqrt.approx{.ftz}.f32 d, a => { + ast::Instruction::Sqrt { + data: ast::RcpData { + kind: ast::RcpKind::Approx, + flush_to_zero: Some(ftz), + type_: f32 + }, + arguments: SqrtArgs { dst: d, src: a } + } + } + sqrt.rnd{.ftz}.f32 d, a => { + ast::Instruction::Sqrt { + data: ast::RcpData { + kind: ast::RcpKind::Compliant(rnd.into()), + flush_to_zero: Some(ftz), 
+ type_: f32 + }, + arguments: SqrtArgs { dst: d, src: a } + } + } + sqrt.rnd.f64 d, a => { + ast::Instruction::Sqrt { + data: ast::RcpData { + kind: ast::RcpKind::Compliant(rnd.into()), + flush_to_zero: None, + type_: f64 + }, + arguments: SqrtArgs { dst: d, src: a } + } + } + .rnd: RawRoundingMode = { .rn, .rz, .rm, .rp }; + ScalarType = { .f32, .f64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-rsqrt + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-rsqrt-approx-ftz-f64 + rsqrt.approx{.ftz}.f32 d, a => { + ast::Instruction::Rsqrt { + data: ast::TypeFtz { + flush_to_zero: Some(ftz), + type_: f32 + }, + arguments: RsqrtArgs { dst: d, src: a } + } + } + rsqrt.approx.f64 d, a => { + ast::Instruction::Rsqrt { + data: ast::TypeFtz { + flush_to_zero: None, + type_: f64 + }, + arguments: RsqrtArgs { dst: d, src: a } + } + } + rsqrt.approx.ftz.f64 d, a => { + ast::Instruction::Rsqrt { + data: ast::TypeFtz { + flush_to_zero: None, + type_: f64 + }, + arguments: RsqrtArgs { dst: d, src: a } + } + } + ScalarType = { .f32, .f64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#comparison-and-selection-instructions-selp + selp.type d, a, b, c => { + ast::Instruction::Selp { + data: type_, + arguments: SelpArgs { dst: d, src1: a, src2: b, src3: c } + } + } + .type: ScalarType = { .b16, .b32, .b64, + .u16, .u32, .u64, + .s16, .s32, .s64, + .f32, .f64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-bar + barrier{.cta}.sync{.aligned} a{, b} => { + let _ = cta; + ast::Instruction::Bar { + data: ast::BarData { aligned }, + arguments: BarArgs { src1: a, src2: b } + } + } + //barrier{.cta}.arrive{.aligned} a, b; + //barrier{.cta}.red.popc{.aligned}.u32 d, a{, b}, {!}c; + //barrier{.cta}.red.op{.aligned}.pred p, a{, b}, {!}c; + bar{.cta}.sync a{, b} => { + let _ = cta; + 
ast::Instruction::Bar { + data: ast::BarData { aligned: true }, + arguments: BarArgs { src1: a, src2: b } + } + } + //bar{.cta}.arrive a, b; + //bar{.cta}.red.popc.u32 d, a{, b}, {!}c; + //bar{.cta}.red.op.pred p, a{, b}, {!}c; + //.op = { .and, .or }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-atom + atom{.sem}{.scope}{.space}.op{.level::cache_hint}.type d, [a], b{, cache_policy} => { + if level_cache_hint || cache_policy.is_some() { + state.errors.push(PtxError::Todo); + } + ast::Instruction::Atom { + data: AtomDetails { + semantics: sem.map(Into::into).unwrap_or(AtomSemantics::Relaxed), + scope: scope.unwrap_or(MemScope::Gpu), + space: space.unwrap_or(StateSpace::Generic), + op: ast::AtomicOp::new(op, type_.kind()), + type_: type_.into() + }, + arguments: AtomArgs { dst: d, src1: a, src2: b } + } + } + atom{.sem}{.scope}{.space}.cas.cas_type d, [a], b, c => { + ast::Instruction::AtomCas { + data: AtomCasDetails { + semantics: sem.map(Into::into).unwrap_or(AtomSemantics::Relaxed), + scope: scope.unwrap_or(MemScope::Gpu), + space: space.unwrap_or(StateSpace::Generic), + type_: cas_type + }, + arguments: AtomCasArgs { dst: d, src1: a, src2: b, src3: c } + } + } + atom{.sem}{.scope}{.space}.exch{.level::cache_hint}.b128 d, [a], b{, cache_policy} => { + if level_cache_hint || cache_policy.is_some() { + state.errors.push(PtxError::Todo); + } + ast::Instruction::Atom { + data: AtomDetails { + semantics: sem.map(Into::into).unwrap_or(AtomSemantics::Relaxed), + scope: scope.unwrap_or(MemScope::Gpu), + space: space.unwrap_or(StateSpace::Generic), + op: ast::AtomicOp::new(exch, b128.kind()), + type_: b128.into() + }, + arguments: AtomArgs { dst: d, src1: a, src2: b } + } + } + atom{.sem}{.scope}{.global}.float_op{.level::cache_hint}.vec_32_bit.f32 d, [a], b{, cache_policy} => { + if level_cache_hint || cache_policy.is_some() { + state.errors.push(PtxError::Todo); + } + 
ast::Instruction::Atom { + data: AtomDetails { + semantics: sem.map(Into::into).unwrap_or(AtomSemantics::Relaxed), + scope: scope.unwrap_or(MemScope::Gpu), + space: global.unwrap_or(StateSpace::Generic), + op: ast::AtomicOp::new(float_op, f32.kind()), + type_: ast::Type::Vector(vec_32_bit.len().get(), f32) + }, + arguments: AtomArgs { dst: d, src1: a, src2: b } + } + } + atom{.sem}{.scope}{.global}.float_op.noftz{.level::cache_hint}{.vec_16_bit}.half_word_type d, [a], b{, cache_policy} => { + if level_cache_hint || cache_policy.is_some() { + state.errors.push(PtxError::Todo); + } + ast::Instruction::Atom { + data: AtomDetails { + semantics: sem.map(Into::into).unwrap_or(AtomSemantics::Relaxed), + scope: scope.unwrap_or(MemScope::Gpu), + space: global.unwrap_or(StateSpace::Generic), + op: ast::AtomicOp::new(float_op, half_word_type.kind()), + type_: ast::Type::maybe_vector(vec_16_bit, half_word_type) + }, + arguments: AtomArgs { dst: d, src1: a, src2: b } + } + } + atom{.sem}{.scope}{.global}.float_op.noftz{.level::cache_hint}{.vec_32_bit}.packed_type d, [a], b{, cache_policy} => { + if level_cache_hint || cache_policy.is_some() { + state.errors.push(PtxError::Todo); + } + ast::Instruction::Atom { + data: AtomDetails { + semantics: sem.map(Into::into).unwrap_or(AtomSemantics::Relaxed), + scope: scope.unwrap_or(MemScope::Gpu), + space: global.unwrap_or(StateSpace::Generic), + op: ast::AtomicOp::new(float_op, packed_type.kind()), + type_: ast::Type::maybe_vector(vec_32_bit, packed_type) + }, + arguments: AtomArgs { dst: d, src1: a, src2: b } + } + } + .space: StateSpace = { .global, .shared{::cta, ::cluster} }; + .sem: AtomSemantics = { .relaxed, .acquire, .release, .acq_rel }; + .scope: MemScope = { .cta, .cluster, .gpu, .sys }; + .op: RawAtomicOp = { .and, .or, .xor, + .exch, + .add, .inc, .dec, + .min, .max }; + .level::cache_hint = { .L2::cache_hint }; + .type: ScalarType = { .b32, .b64, .u32, .u64, .s32, .s64, .f32, .f64 }; + .cas_type: ScalarType = { .b32, .b64, 
.u32, .u64, .s32, .s64, .f32, .f64, .b16, .b128 }; + .half_word_type: ScalarType = { .f16, .bf16 }; + .packed_type: ScalarType = { .f16x2, .bf16x2 }; + .vec_16_bit: VectorPrefix = { .v2, .v4, .v8 }; + .vec_32_bit: VectorPrefix = { .v2, .v4 }; + .float_op: RawAtomicOp = { .add, .min, .max }; + ScalarType = { .b16, .b128, .f32 }; + StateSpace = { .global }; + RawAtomicOp = { .exch }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-div + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-div + div.type d, a, b => { + ast::Instruction::Div { + data: if type_.kind() == ast::ScalarKind::Signed { + ast::DivDetails::Signed(type_) + } else { + ast::DivDetails::Unsigned(type_) + }, + arguments: DivArgs { + dst: d, + src1: a, + src2: b, + }, + } + } + .type: ScalarType = { .u16, .u32, .u64, + .s16, .s32, .s64 }; + + div.approx{.ftz}.f32 d, a, b => { + ast::Instruction::Div { + data: ast::DivDetails::Float(ast::DivFloatDetails{ + type_: f32, + flush_to_zero: Some(ftz), + kind: ast::DivFloatKind::Approx + }), + arguments: DivArgs { + dst: d, + src1: a, + src2: b, + }, + } + } + div.full{.ftz}.f32 d, a, b => { + ast::Instruction::Div { + data: ast::DivDetails::Float(ast::DivFloatDetails{ + type_: f32, + flush_to_zero: Some(ftz), + kind: ast::DivFloatKind::ApproxFull + }), + arguments: DivArgs { + dst: d, + src1: a, + src2: b, + }, + } + } + div.rnd{.ftz}.f32 d, a, b => { + ast::Instruction::Div { + data: ast::DivDetails::Float(ast::DivFloatDetails{ + type_: f32, + flush_to_zero: Some(ftz), + kind: ast::DivFloatKind::Rounding(rnd.into()) + }), + arguments: DivArgs { + dst: d, + src1: a, + src2: b, + }, + } + } + div.rnd.f64 d, a, b => { + ast::Instruction::Div { + data: ast::DivDetails::Float(ast::DivFloatDetails{ + type_: f64, + flush_to_zero: None, + kind: ast::DivFloatKind::Rounding(rnd.into()) + }), + arguments: DivArgs { + dst: d, + src1: a, + src2: b, + }, + } + } + .rnd: 
RawRoundingMode = { .rn, .rz, .rm, .rp }; + ScalarType = { .f32, .f64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-neg + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-neg + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-neg + neg.type d, a => { + ast::Instruction::Neg { + data: TypeFtz { + type_, + flush_to_zero: None + }, + arguments: NegArgs { dst: d, src: a, }, + } + } + .type: ScalarType = { .s16, .s32, .s64 }; + + neg{.ftz}.f32 d, a => { + ast::Instruction::Neg { + data: TypeFtz { + type_: f32, + flush_to_zero: Some(ftz) + }, + arguments: NegArgs { dst: d, src: a, }, + } + } + neg.f64 d, a => { + ast::Instruction::Neg { + data: TypeFtz { + type_: f64, + flush_to_zero: None + }, + arguments: NegArgs { dst: d, src: a, }, + } + } + neg{.ftz}.f16 d, a => { + ast::Instruction::Neg { + data: TypeFtz { + type_: f16, + flush_to_zero: Some(ftz) + }, + arguments: NegArgs { dst: d, src: a, }, + } + } + neg{.ftz}.f16x2 d, a => { + ast::Instruction::Neg { + data: TypeFtz { + type_: f16x2, + flush_to_zero: Some(ftz) + }, + arguments: NegArgs { dst: d, src: a, }, + } + } + neg.bf16 d, a => { + ast::Instruction::Neg { + data: TypeFtz { + type_: bf16, + flush_to_zero: None + }, + arguments: NegArgs { dst: d, src: a, }, + } + } + neg.bf16x2 d, a => { + ast::Instruction::Neg { + data: TypeFtz { + type_: bf16x2, + flush_to_zero: None + }, + arguments: NegArgs { dst: d, src: a, }, + } + } + ScalarType = { .f32, .f64, .f16, .f16x2, .bf16, .bf16x2 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-sin + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-cos + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-lg2 + sin.approx{.ftz}.f32 d, a => { + 
ast::Instruction::Sin { + data: ast::FlushToZero { + flush_to_zero: ftz + }, + arguments: SinArgs { dst: d, src: a, }, + } + } + cos.approx{.ftz}.f32 d, a => { + ast::Instruction::Cos { + data: ast::FlushToZero { + flush_to_zero: ftz + }, + arguments: CosArgs { dst: d, src: a, }, + } + } + lg2.approx{.ftz}.f32 d, a => { + ast::Instruction::Lg2 { + data: ast::FlushToZero { + flush_to_zero: ftz + }, + arguments: Lg2Args { dst: d, src: a, }, + } + } + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-ex2 + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-ex2 + ex2.approx{.ftz}.f32 d, a => { + ast::Instruction::Ex2 { + data: ast::TypeFtz { + type_: f32, + flush_to_zero: Some(ftz) + }, + arguments: Ex2Args { dst: d, src: a, }, + } + } + ex2.approx.atype d, a => { + ast::Instruction::Ex2 { + data: ast::TypeFtz { + type_: atype, + flush_to_zero: None + }, + arguments: Ex2Args { dst: d, src: a, }, + } + } + ex2.approx.ftz.btype d, a => { + ast::Instruction::Ex2 { + data: ast::TypeFtz { + type_: btype, + flush_to_zero: Some(true) + }, + arguments: Ex2Args { dst: d, src: a, }, + } + } + .atype: ScalarType = { .f16, .f16x2 }; + .btype: ScalarType = { .bf16, .bf16x2 }; + ScalarType = { .f32 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-clz + clz.type d, a => { + ast::Instruction::Clz { + data: type_, + arguments: ClzArgs { dst: d, src: a, }, + } + } + .type: ScalarType = { .b32, .b64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-brev + brev.type d, a => { + ast::Instruction::Brev { + data: type_, + arguments: BrevArgs { dst: d, src: a, }, + } + } + .type: ScalarType = { .b32, .b64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-popc + popc.type d, a => { + ast::Instruction::Popc { + 
data: type_, + arguments: PopcArgs { dst: d, src: a, }, + } + } + .type: ScalarType = { .b32, .b64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-xor + xor.type d, a, b => { + ast::Instruction::Xor { + data: type_, + arguments: XorArgs { dst: d, src1: a, src2: b, }, + } + } + .type: ScalarType = { .pred, .b16, .b32, .b64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-rem + rem.type d, a, b => { + ast::Instruction::Rem { + data: type_, + arguments: RemArgs { dst: d, src1: a, src2: b, }, + } + } + .type: ScalarType = { .u16, .u32, .u64, .s16, .s32, .s64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-bfe + bfe.type d, a, b, c => { + ast::Instruction::Bfe { + data: type_, + arguments: BfeArgs { dst: d, src1: a, src2: b, src3: c }, + } + } + .type: ScalarType = { .u32, .u64, .s32, .s64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-bfi + bfi.type f, a, b, c, d => { + ast::Instruction::Bfi { + data: type_, + arguments: BfiArgs { dst: f, src1: a, src2: b, src3: c, src4: d }, + } + } + .type: ScalarType = { .b32, .b64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-prmt + // prmt.b32{.mode} d, a, b, c; + // .mode = { .f4e, .b4e, .rc8, .ecl, .ecr, .rc16 }; + prmt.b32 d, a, b, c => { + match c { + ast::ParsedOperand::Imm(ImmediateValue::S64(control)) => ast::Instruction::Prmt { + data: control as u16, + arguments: PrmtArgs { + dst: d, src1: a, src2: b + } + }, + _ => ast::Instruction::PrmtSlow { + arguments: PrmtSlowArgs { + dst: d, src1: a, src2: b, src3: c + } + } + } + } + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-activemask + activemask.b32 d => { + ast::Instruction::Activemask { + 
arguments: ActivemaskArgs { dst: d } + } + } + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-membar + // fence{.sem}.scope; + // fence.op_restrict.release.cluster; + // fence.proxy.proxykind; + // fence.proxy.to_proxykind::from_proxykind.release.scope; + // fence.proxy.to_proxykind::from_proxykind.acquire.scope [addr], size; + //membar.proxy.proxykind; + //.sem = { .sc, .acq_rel }; + //.scope = { .cta, .cluster, .gpu, .sys }; + //.proxykind = { .alias, .async, async.global, .async.shared::{cta, cluster} }; + //.op_restrict = { .mbarrier_init }; + //.to_proxykind::from_proxykind = {.tensormap::generic}; + + membar.level => { + ast::Instruction::Membar { data: level } + } + membar.gl => { + ast::Instruction::Membar { data: MemScope::Gpu } + } + .level: MemScope = { .cta, .sys }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-ret + ret{.uni} => { + Instruction::Ret { data: RetData { uniform: uni } } + } + +); + +#[cfg(test)] +mod tests { + use super::target; + use super::PtxParserState; + use super::Token; + use logos::Logos; + use winnow::prelude::*; + + #[test] + fn sm_11() { + let tokens = Token::lexer(".target sm_11") + .collect::, ()>>() + .unwrap(); + let stream = super::PtxParser { + input: &tokens[..], + state: PtxParserState::new(), + }; + assert_eq!(target.parse(stream).unwrap(), (11, None)); + } + + #[test] + fn sm_90a() { + let tokens = Token::lexer(".target sm_90a") + .collect::, ()>>() + .unwrap(); + let stream = super::PtxParser { + input: &tokens[..], + state: PtxParserState::new(), + }; + assert_eq!(target.parse(stream).unwrap(), (90, Some('a'))); + } + + #[test] + fn sm_90ab() { + let tokens = Token::lexer(".target sm_90ab") + .collect::, ()>>() + .unwrap(); + let stream = super::PtxParser { + input: &tokens[..], + state: PtxParserState::new(), + }; + assert!(target.parse(stream).is_err()); + } +} diff --git 
a/ptx_parser_macros/Cargo.toml b/ptx_parser_macros/Cargo.toml new file mode 100644 index 00000000..62a5081b --- /dev/null +++ b/ptx_parser_macros/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "ptx_parser_macros" +version = "0.0.0" +authors = ["Andrzej Janik "] +edition = "2021" + +[lib] +proc-macro = true + +[dependencies] +ptx_parser_macros_impl = { path = "../ptx_parser_macros_impl" } +convert_case = "0.6.0" +rustc-hash = "2.0.0" +syn = "2.0.67" +quote = "1.0" +proc-macro2 = "1.0.86" +either = "1.13.0" diff --git a/ptx_parser_macros/src/lib.rs b/ptx_parser_macros/src/lib.rs new file mode 100644 index 00000000..5f47fac7 --- /dev/null +++ b/ptx_parser_macros/src/lib.rs @@ -0,0 +1,1023 @@ +use either::Either; +use ptx_parser_macros_impl::parser; +use proc_macro2::{Span, TokenStream}; +use quote::{format_ident, quote, ToTokens}; +use rustc_hash::{FxHashMap, FxHashSet}; +use std::{collections::hash_map, hash::Hash, iter, rc::Rc}; +use syn::{ + parse_macro_input, parse_quote, punctuated::Punctuated, Ident, ItemEnum, Token, Type, TypePath, + Variant, +}; + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#vectors +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#fundamental-types +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#alternate-floating-point-data-formats +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#packed-floating-point-data-types +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#packed-integer-data-types +#[rustfmt::skip] +static POSTFIX_MODIFIERS: &[&str] = &[ + ".v2", ".v4", ".v8", + ".s8", ".s16", ".s16x2", ".s32", ".s64", + ".u8", ".u16", ".u16x2", ".u32", ".u64", + ".f16", ".f16x2", ".f32", ".f64", + ".b8", ".b16", ".b32", ".b64", ".b128", + ".pred", + ".bf16", ".bf16x2", ".e4m3", ".e5m2", ".tf32", +]; + +static POSTFIX_TYPES: &[&str] = &["ScalarType", "VectorPrefix"]; + +struct OpcodeDefinitions { + definitions: Vec, + block_selection: Vec>, 
usize)>>, +} + +impl OpcodeDefinitions { + fn new(opcode: &Ident, definitions: Vec) -> Self { + let mut selections = vec![None; definitions.len()]; + let mut generation = 0usize; + loop { + let mut selected_something = false; + let unselected = selections + .iter() + .enumerate() + .filter_map(|(idx, s)| if s.is_none() { Some(idx) } else { None }) + .collect::>(); + match &*unselected { + [] => break, + [remaining] => { + selections[*remaining] = Some((None, generation)); + break; + } + _ => {} + } + 'check_definitions: for i in unselected.iter().copied() { + let mut candidates = definitions[i] + .unordered_modifiers + .iter() + .chain(definitions[i].ordered_modifiers.iter()) + .filter(|modifier| match modifier { + DotModifierRef::Direct { + optional: false, .. + } + | DotModifierRef::Indirect { + optional: false, .. + } => true, + _ => false, + }) + .collect::>(); + candidates.sort_by_key(|modifier| match modifier { + DotModifierRef::Direct { .. } => 1, + DotModifierRef::Indirect { value, .. } => value.alternatives.len(), + }); + // Attempt every modifier + 'check_candidates: for candidate_modifier in candidates { + // check all other unselected patterns + for j in unselected.iter().copied() { + if i == j { + continue; + } + let candidate_set = match candidate_modifier { + DotModifierRef::Direct { value, .. } => Either::Left(iter::once(value)), + DotModifierRef::Indirect { value, .. } => { + Either::Right(value.alternatives.iter()) + } + }; + for candidate_value in candidate_set { + if definitions[j].possible_modifiers.contains(candidate_value) { + continue 'check_candidates; + } + } + } + // it's unique + let candidate_vec = match candidate_modifier { + DotModifierRef::Direct { value, .. } => vec![value.clone()], + DotModifierRef::Indirect { value, .. 
} => { + value.alternatives.iter().cloned().collect::>() + } + }; + selections[i] = Some((Some(candidate_vec), generation)); + selected_something = true; + continue 'check_definitions; + } + } + if !selected_something { + panic!( + "Failed to generate pattern selection for `{}`. State: {:?}", + opcode, + selections.into_iter().rev().collect::>() + ); + } + generation += 1; + } + let mut block_selection = Vec::new(); + for current_generation in 0usize.. { + let mut current_generation_definitions = Vec::new(); + for (idx, selection) in selections.iter_mut().enumerate() { + match selection { + Some((modifier_set, generation)) => { + if *generation == current_generation { + current_generation_definitions.push((modifier_set.clone(), idx)); + *selection = None; + } + } + None => {} + } + } + if current_generation_definitions.is_empty() { + break; + } + block_selection.push(current_generation_definitions); + } + #[cfg(debug_assertions)] + { + let selected = block_selection + .iter() + .map(|x| x.len()) + .reduce(|x, y| x + y) + .unwrap(); + if selected != definitions.len() { + panic!( + "Internal error when generating pattern selection for `{}`: {:?}", + opcode, &block_selection + ); + } + } + Self { + definitions, + block_selection, + } + } + + fn get_enum_types( + parse_definitions: &[parser::OpcodeDefinition], + ) -> FxHashMap> { + let mut result = FxHashMap::default(); + for parser::OpcodeDefinition(_, rules) in parse_definitions.iter() { + for rule in rules { + let type_ = match rule.type_ { + Some(ref type_) => type_.clone(), + None => continue, + }; + let insert_values = |set: &mut FxHashSet<_>| { + for value in rule.alternatives.iter().cloned() { + set.insert(value); + } + }; + match result.entry(type_) { + hash_map::Entry::Occupied(mut entry) => insert_values(entry.get_mut()), + hash_map::Entry::Vacant(entry) => { + insert_values(entry.insert(FxHashSet::default())) + } + }; + } + } + result + } +} + +struct SingleOpcodeDefinition { + possible_modifiers: 
FxHashSet, + unordered_modifiers: Vec, + ordered_modifiers: Vec, + arguments: parser::Arguments, + code_block: parser::CodeBlock, +} + +impl SingleOpcodeDefinition { + fn function_arguments_declarations(&self) -> impl Iterator + '_ { + self.unordered_modifiers + .iter() + .chain(self.ordered_modifiers.iter()) + .filter_map(|modf| { + let type_ = modf.type_of(); + type_.map(|t| { + let name = modf.ident(); + quote! { #name : #t } + }) + }) + .chain(self.arguments.0.iter().map(|arg| { + let name = &arg.ident; + let arg_type = if arg.unified { + quote! { (ParsedOperandStr<'input>, bool) } + } else if arg.can_be_negated { + quote! { (bool, ParsedOperandStr<'input>) } + } else { + quote! { ParsedOperandStr<'input> } + }; + if arg.optional { + quote! { #name : Option<#arg_type> } + } else { + quote! { #name : #arg_type } + } + })) + } + + fn function_arguments(&self) -> impl Iterator + '_ { + self.unordered_modifiers + .iter() + .chain(self.ordered_modifiers.iter()) + .filter_map(|modf| { + let type_ = modf.type_of(); + type_.map(|_| { + let name = modf.ident(); + quote! { #name } + }) + }) + .chain(self.arguments.0.iter().map(|arg| { + let name = &arg.ident; + quote! 
{ #name } + })) + } + + fn extract_and_insert( + definitions: &mut FxHashMap>, + special_definitions: &mut FxHashMap, + parser::OpcodeDefinition(pattern_seq, rules): parser::OpcodeDefinition, + ) { + let (mut named_rules, mut unnamed_rules) = gather_rules(rules); + let mut last_opcode = pattern_seq.0.last().unwrap().0 .0.name.clone(); + for (opcode_decl, code_block) in pattern_seq.0.into_iter().rev() { + let current_opcode = opcode_decl.0.name.clone(); + if last_opcode != current_opcode { + named_rules = FxHashMap::default(); + unnamed_rules = FxHashMap::default(); + } + let parser::OpcodeDecl(instruction, arguments) = opcode_decl; + if code_block.special { + if !instruction.modifiers.is_empty() || !arguments.0.is_empty() { + panic!( + "`{}`: no modifiers or arguments are allowed in parser definition.", + instruction.name + ); + } + special_definitions.insert(instruction.name, code_block.code); + continue; + } + let mut possible_modifiers = FxHashSet::default(); + let mut unordered_modifiers = instruction + .modifiers + .into_iter() + .map(|parser::MaybeDotModifier { optional, modifier }| { + match named_rules.get(&modifier) { + Some(alts) => { + possible_modifiers.extend(alts.alternatives.iter().cloned()); + if alts.alternatives.len() == 1 && alts.type_.is_none() { + DotModifierRef::Direct { + optional, + value: alts.alternatives[0].clone(), + name: modifier, + type_: alts.type_.clone(), + } + } else { + DotModifierRef::Indirect { + optional, + value: alts.clone(), + name: modifier, + } + } + } + None => { + let type_ = unnamed_rules.get(&modifier).cloned(); + possible_modifiers.insert(modifier.clone()); + DotModifierRef::Direct { + optional, + value: modifier.clone(), + name: modifier, + type_, + } + } + } + }) + .collect::>(); + let ordered_modifiers = Self::extract_ordered_modifiers(&mut unordered_modifiers); + let entry = Self { + possible_modifiers, + unordered_modifiers, + ordered_modifiers, + arguments, + code_block, + }; + multihash_extend(definitions, 
current_opcode.clone(), entry); + last_opcode = current_opcode; + } + } + + fn extract_ordered_modifiers( + unordered_modifiers: &mut Vec, + ) -> Vec { + let mut result = Vec::new(); + loop { + let is_ordered = match unordered_modifiers.last() { + Some(DotModifierRef::Direct { value, .. }) => { + let name = value.to_string(); + POSTFIX_MODIFIERS.contains(&&*name) + } + Some(DotModifierRef::Indirect { value, .. }) => { + let type_ = value.type_.to_token_stream().to_string(); + //panic!("{} {}", type_, POSTFIX_TYPES.contains(&&*type_)); + POSTFIX_TYPES.contains(&&*type_) + } + None => break, + }; + if is_ordered { + result.push(unordered_modifiers.pop().unwrap()); + } else { + break; + } + } + if unordered_modifiers.len() == 1 { + result.push(unordered_modifiers.pop().unwrap()); + } + result.reverse(); + result + } +} + +fn gather_rules( + rules: Vec, +) -> ( + FxHashMap>, + FxHashMap, +) { + let mut named = FxHashMap::default(); + let mut unnamed = FxHashMap::default(); + for rule in rules { + match rule.modifier { + Some(ref modifier) => { + named.insert(modifier.clone(), Rc::new(rule)); + } + None => unnamed.extend( + rule.alternatives + .into_iter() + .map(|alt| (alt, rule.type_.as_ref().unwrap().clone())), + ), + } + } + (named, unnamed) +} + +#[proc_macro] +pub fn derive_parser(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream { + let parse_definitions = parse_macro_input!(tokens as ptx_parser_macros_impl::parser::ParseDefinitions); + let mut definitions = FxHashMap::default(); + let mut special_definitions = FxHashMap::default(); + let types = OpcodeDefinitions::get_enum_types(&parse_definitions.definitions); + let enum_types_tokens = emit_enum_types(types, parse_definitions.additional_enums); + for definition in parse_definitions.definitions.into_iter() { + SingleOpcodeDefinition::extract_and_insert( + &mut definitions, + &mut special_definitions, + definition, + ); + } + let definitions = definitions + .into_iter() + .map(|(k, v)| { + let v = 
OpcodeDefinitions::new(&k, v); + (k, v) + }) + .collect::>(); + let mut token_enum = parse_definitions.token_type; + let (all_opcode, all_modifier) = write_definitions_into_tokens( + &definitions, + special_definitions.keys(), + &mut token_enum.variants, + ); + let token_impl = emit_parse_function(&token_enum.ident, &definitions, &special_definitions, all_opcode, all_modifier); + let tokens = quote! { + #enum_types_tokens + + #token_enum + + #token_impl + }; + tokens.into() +} + +fn emit_enum_types( + types: FxHashMap>, + mut existing_enums: FxHashMap, +) -> TokenStream { + let token_types = types.into_iter().filter_map(|(type_, variants)| { + match type_ { + syn::Type::Path(TypePath { + qself: None, + ref path, + }) => { + if let Some(ident) = path.get_ident() { + if let Some(enum_) = existing_enums.get_mut(ident) { + enum_.variants.extend(variants.into_iter().map(|modifier| { + let ident = modifier.variant_capitalized(); + let variant: syn::Variant = syn::parse_quote! { + #ident + }; + variant + })); + return None; + } + } + } + _ => {} + } + let variants = variants.iter().map(|v| v.variant_capitalized()); + Some(quote! { + #[derive(Copy, Clone, PartialEq, Eq, Hash)] + enum #type_ { + #(#variants),* + } + }) + }); + let mut result = TokenStream::new(); + for tokens in token_types { + tokens.to_tokens(&mut result); + } + for (_, enum_) in existing_enums { + quote! 
{ #enum_ }.to_tokens(&mut result); + } + result +} + +fn emit_parse_function( + type_name: &Ident, + defs: &FxHashMap, + special_defs: &FxHashMap, + all_opcode: Vec<&Ident>, + all_modifier: FxHashSet<&parser::DotModifier>, +) -> TokenStream { + use std::fmt::Write; + let fns_ = defs + .iter() + .map(|(opcode, defs)| { + defs.definitions.iter().enumerate().map(|(idx, def)| { + let mut fn_name = opcode.to_string(); + write!(&mut fn_name, "_{}", idx).ok(); + let fn_name = Ident::new(&fn_name, Span::call_site()); + let code_block = &def.code_block.code; + let args = def.function_arguments_declarations(); + quote! { + fn #fn_name<'input>(state: &mut PtxParserState, #(#args),* ) -> Instruction> #code_block + } + }) + }) + .flatten(); + let selectors = defs.iter().map(|(opcode, def)| { + let opcode_variant = Ident::new(&capitalize(&opcode.to_string()), opcode.span()); + let mut result = TokenStream::new(); + let mut selectors = TokenStream::new(); + quote! { + if false { + unsafe { std::hint::unreachable_unchecked() } + } + } + .to_tokens(&mut selectors); + let mut has_default_selector = false; + for selection_layer in def.block_selection.iter() { + for (selection_key, selected_definition) in selection_layer { + let def_parser = emit_definition_parser(type_name, (opcode,*selected_definition), &def.definitions[*selected_definition]); + match selection_key { + Some(selection_keys) => { + let selection_keys = selection_keys.iter().map(|k| k.dot_capitalized()); + quote! { + else if false #(|| modifiers.contains(& #type_name :: #selection_keys))* { + #def_parser + } + } + .to_tokens(&mut selectors); + } + None => { + has_default_selector = true; + quote! { + else { + #def_parser + } + } + .to_tokens(&mut selectors); + } + } + } + } + if !has_default_selector { + quote! { + else { + return Err(winnow::error::ErrMode::from_error_kind(stream, winnow::error::ErrorKind::Token)) + } + } + .to_tokens(&mut selectors); + } + quote! 
{ + #opcode_variant => { + let modifers_start = stream.checkpoint(); + let modifiers = take_while(0.., Token::modifier).parse_next(stream)?; + #selectors + } + } + .to_tokens(&mut result); + result + }).chain(special_defs.iter().map(|(opcode, code)| { + let opcode_variant = Ident::new(&capitalize(&opcode.to_string()), opcode.span()); + quote! { + #opcode_variant => { #code? } + } + })); + let opcodes = all_opcode.into_iter().map(|op_ident| { + let op = op_ident.to_string(); + let variant = Ident::new(&capitalize(&op), op_ident.span()); + let value = op; + quote! { + #type_name :: #variant => Some(#value), + } + }); + let modifier_names = iter::once(Ident::new("DotUnified", Span::call_site())) + .chain(all_modifier.iter().map(|m| m.dot_capitalized())); + quote! { + impl<'input> #type_name<'input> { + fn opcode_text(self) -> Option<&'static str> { + match self { + #(#opcodes)* + _ => None + } + } + + fn modifier(self) -> bool { + match self { + #( + #type_name :: #modifier_names => true, + )* + _ => false + } + } + } + + #(#fns_)* + + fn parse_instruction<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> winnow::error::PResult>> + { + use winnow::Parser; + use winnow::token::*; + use winnow::combinator::*; + let opcode = any.parse_next(stream)?; + let modifiers_start = stream.checkpoint(); + Ok(match opcode { + #( + #type_name :: #selectors + )* + _ => return Err(winnow::error::ErrMode::from_error_kind(stream, winnow::error::ErrorKind::Token)) + }) + } + } +} + +fn emit_definition_parser( + token_type: &Ident, + (opcode, fn_idx): (&Ident, usize), + definition: &SingleOpcodeDefinition, +) -> TokenStream { + let return_error_ref = quote! { + return Err(winnow::error::ErrMode::from_error_kind(&stream, winnow::error::ErrorKind::Token)) + }; + let return_error = quote! 
{ + return Err(winnow::error::ErrMode::from_error_kind(stream, winnow::error::ErrorKind::Token)) + }; + let ordered_parse_declarations = definition.ordered_modifiers.iter().map(|modifier| { + modifier.type_of().map(|type_| { + let name = modifier.ident(); + quote! { + let #name : #type_; + } + }) + }); + let ordered_parse = definition.ordered_modifiers.iter().rev().map(|modifier| { + let arg_name = modifier.ident(); + match modifier { + DotModifierRef::Direct { optional, value, type_: None, .. } => { + let variant = value.dot_capitalized(); + if *optional { + quote! { + #arg_name = opt(any.verify(|t| *t == #token_type :: #variant)).parse_next(&mut stream)?.is_some(); + } + } else { + quote! { + any.verify(|t| *t == #token_type :: #variant).parse_next(&mut stream)?; + } + } + } + DotModifierRef::Direct { optional: false, type_: Some(type_), name, value } => { + let variable = name.ident(); + let variant = value.dot_capitalized(); + let parsed_variant = value.variant_capitalized(); + quote! { + any.verify(|t| *t == #token_type :: #variant).parse_next(&mut stream)?; + #variable = #type_ :: #parsed_variant; + } + } + DotModifierRef::Direct { optional: true, type_: Some(_), .. } => { todo!() } + DotModifierRef::Indirect { optional, value, .. } => { + let variants = value.alternatives.iter().map(|alt| { + let type_ = value.type_.as_ref().unwrap(); + let token_variant = alt.dot_capitalized(); + let parsed_variant = alt.variant_capitalized(); + quote! { + #token_type :: #token_variant => #type_ :: #parsed_variant, + } + }); + if *optional { + quote! { + #arg_name = opt(any.verify_map(|tok| { + Some(match tok { + #(#variants)* + _ => return None + }) + })).parse_next(&mut stream)?; + } + } else { + quote! 
{ + #arg_name = any.verify_map(|tok| { + Some(match tok { + #(#variants)* + _ => return None + }) + }).parse_next(&mut stream)?; + } + } + } + } + }); + let unordered_parse_declarations = definition.unordered_modifiers.iter().map(|modifier| { + let name = modifier.ident(); + let type_ = modifier.type_of_check(); + quote! { + let mut #name : #type_ = std::default::Default::default(); + } + }); + let unordered_parse = definition + .unordered_modifiers + .iter() + .map(|modifier| match modifier { + DotModifierRef::Direct { + name, + value, + type_: None, + .. + } => { + let name = name.ident(); + let token_variant = value.dot_capitalized(); + quote! { + #token_type :: #token_variant => { + if #name { + #return_error_ref; + } + #name = true; + } + } + } + DotModifierRef::Direct { + name, + value, + type_: Some(type_), + .. + } => { + let variable = name.ident(); + let token_variant = value.dot_capitalized(); + let enum_variant = value.variant_capitalized(); + quote! { + #token_type :: #token_variant => { + if #variable.is_some() { + #return_error_ref; + } + #variable = Some(#type_ :: #enum_variant); + } + } + } + DotModifierRef::Indirect { value, name, .. } => { + let variable = name.ident(); + let type_ = value.type_.as_ref().unwrap(); + let alternatives = value.alternatives.iter().map(|alt| { + let token_variant = alt.dot_capitalized(); + let enum_variant = alt.variant_capitalized(); + quote! { + #token_type :: #token_variant => { + if #variable.is_some() { + #return_error_ref; + } + #variable = Some(#type_ :: #enum_variant); + } + } + }); + quote! { + #(#alternatives)* + } + } + }); + let unordered_parse_validations = + definition + .unordered_modifiers + .iter() + .map(|modifier| match modifier { + DotModifierRef::Direct { + optional: false, + name, + type_: None, + .. + } => { + let variable = name.ident(); + quote! { + if !#variable { + #return_error; + } + } + } + DotModifierRef::Direct { + optional: false, + name, + type_: Some(_), + .. 
+ } => { + let variable = name.ident(); + quote! { + let #variable = match #variable { + Some(x) => x, + None => #return_error + }; + } + } + DotModifierRef::Indirect { + optional: false, + name, + .. + } => { + let variable = name.ident(); + quote! { + let #variable = match #variable { + Some(x) => x, + None => #return_error + }; + } + } + DotModifierRef::Direct { optional: true, .. } + | DotModifierRef::Indirect { optional: true, .. } => TokenStream::new(), + }); + let arguments_parse = definition.arguments.0.iter().enumerate().map(|(idx, arg)| { + let comma = if idx == 0 || arg.pre_pipe { + quote! { empty } + } else { + quote! { any.verify(|t| *t == #token_type::Comma).void() } + }; + let pre_bracket = if arg.pre_bracket { + quote! { + any.verify(|t| *t == #token_type::LBracket).void() + } + } else { + quote! { + empty + } + }; + let pre_pipe = if arg.pre_pipe { + quote! { + any.verify(|t| *t == #token_type::Pipe).void() + } + } else { + quote! { + empty + } + }; + let can_be_negated = if arg.can_be_negated { + quote! { + opt(any.verify(|t| *t == #token_type::Not)).map(|o| o.is_some()) + } + } else { + quote! { + empty + } + }; + let operand = { + quote! { + ParsedOperandStr::parse + } + }; + let post_bracket = if arg.post_bracket { + quote! { + any.verify(|t| *t == #token_type::RBracket).void() + } + } else { + quote! { + empty + } + }; + let unified = if arg.unified { + quote! { + opt(any.verify(|t| *t == #token_type::DotUnified).void()).map(|u| u.is_some()) + } + } else { + quote! { + empty + } + }; + let pattern = quote! { + (#comma, #pre_bracket, #pre_pipe, #can_be_negated, #operand, #post_bracket, #unified) + }; + let arg_name = &arg.ident; + if arg.unified && arg.can_be_negated { + panic!("TODO: argument can't be both prefixed by `!` and suffixed by `.unified`") + } + let inner_parser = if arg.unified { + quote! { + #pattern.map(|(_, _, _, _, name, _, unified)| (name, unified)) + } + } else if arg.can_be_negated { + quote! 
{ + #pattern.map(|(_, _, _, negated, name, _, _)| (negated, name)) + } + } else { + quote! { + #pattern.map(|(_, _, _, _, name, _, _)| name) + } + }; + if arg.optional { + quote! { + let #arg_name = opt(#inner_parser).parse_next(stream)?; + } + } else { + quote! { + let #arg_name = #inner_parser.parse_next(stream)?; + } + } + }); + let fn_args = definition.function_arguments(); + let fn_name = format_ident!("{}_{}", opcode, fn_idx); + let fn_call = quote! { + #fn_name(&mut stream.state, #(#fn_args),* ) + }; + quote! { + #(#unordered_parse_declarations)* + #(#ordered_parse_declarations)* + { + let mut stream = ReverseStream(modifiers); + #(#ordered_parse)* + let mut stream: &[#token_type] = stream.0; + for token in stream.iter().copied() { + match token { + #(#unordered_parse)* + _ => #return_error_ref + } + } + } + #(#unordered_parse_validations)* + #(#arguments_parse)* + #fn_call + } +} + +fn write_definitions_into_tokens<'a>( + defs: &'a FxHashMap, + special_definitions: impl Iterator, + variants: &mut Punctuated, +) -> (Vec<&'a Ident>, FxHashSet<&'a parser::DotModifier>) { + let mut all_opcodes = Vec::new(); + let mut all_modifiers = FxHashSet::default(); + for (opcode, definitions) in defs.iter() { + all_opcodes.push(opcode); + let opcode_as_string = opcode.to_string(); + let variant_name = Ident::new(&capitalize(&opcode_as_string), opcode.span()); + let arg: Variant = syn::parse_quote! { + #[token(#opcode_as_string)] + #variant_name + }; + variants.push(arg); + for definition in definitions.definitions.iter() { + for modifier in definition.possible_modifiers.iter() { + all_modifiers.insert(modifier); + } + } + } + for opcode in special_definitions { + all_opcodes.push(opcode); + let opcode_as_string = opcode.to_string(); + let variant_name = Ident::new(&capitalize(&opcode_as_string), opcode.span()); + let arg: Variant = syn::parse_quote! 
{ + #[token(#opcode_as_string)] + #variant_name + }; + variants.push(arg); + } + for modifier in all_modifiers.iter() { + let modifier_as_string = modifier.to_string(); + let variant_name = modifier.dot_capitalized(); + let arg: Variant = syn::parse_quote! { + #[token(#modifier_as_string)] + #variant_name + }; + variants.push(arg); + } + variants.push(parse_quote! { + #[token(".unified")] + DotUnified + }); + (all_opcodes, all_modifiers) +} + +fn capitalize(s: &str) -> String { + let mut c = s.chars(); + match c.next() { + None => String::new(), + Some(f) => f.to_uppercase().collect::() + c.as_str(), + } +} + +fn multihash_extend(multimap: &mut FxHashMap>, k: K, v: V) { + match multimap.entry(k) { + hash_map::Entry::Occupied(mut entry) => entry.get_mut().push(v), + hash_map::Entry::Vacant(entry) => { + entry.insert(vec![v]); + } + } +} + +enum DotModifierRef { + Direct { + optional: bool, + value: parser::DotModifier, + name: parser::DotModifier, + type_: Option, + }, + Indirect { + optional: bool, + name: parser::DotModifier, + value: Rc, + }, +} + +impl DotModifierRef { + fn ident(&self) -> Ident { + match self { + DotModifierRef::Direct { name, .. } => name.ident(), + DotModifierRef::Indirect { name, .. } => name.ident(), + } + } + + fn type_of(&self) -> Option { + Some(match self { + DotModifierRef::Direct { + optional: true, + type_: None, + .. + } => syn::parse_quote! { bool }, + DotModifierRef::Direct { + optional: false, + type_: None, + .. + } => return None, + DotModifierRef::Direct { + optional: true, + type_: Some(type_), + .. + } => syn::parse_quote! { Option<#type_> }, + DotModifierRef::Direct { + optional: false, + type_: Some(type_), + .. + } => type_.clone(), + DotModifierRef::Indirect { + optional, value, .. + } => { + let type_ = value + .type_ + .as_ref() + .expect("Indirect modifer must have a type"); + if *optional { + syn::parse_quote! 
{ Option<#type_> } + } else { + type_.clone() + } + } + }) + } + + fn type_of_check(&self) -> syn::Type { + match self { + DotModifierRef::Direct { type_: None, .. } => syn::parse_quote! { bool }, + DotModifierRef::Direct { + type_: Some(type_), .. + } => syn::parse_quote! { Option<#type_> }, + DotModifierRef::Indirect { value, .. } => { + let type_ = value + .type_ + .as_ref() + .expect("Indirect modifer must have a type"); + syn::parse_quote! { Option<#type_> } + } + } + } +} + +#[proc_macro] +pub fn generate_instruction_type(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream { + let input = parse_macro_input!(tokens as ptx_parser_macros_impl::GenerateInstructionType); + let mut result = proc_macro2::TokenStream::new(); + input.emit_arg_types(&mut result); + input.emit_instruction_type(&mut result); + input.emit_visit(&mut result); + input.emit_visit_mut(&mut result); + input.emit_visit_map(&mut result); + result.into() +} diff --git a/ptx_parser_macros_impl/Cargo.toml b/ptx_parser_macros_impl/Cargo.toml new file mode 100644 index 00000000..96f3b749 --- /dev/null +++ b/ptx_parser_macros_impl/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "ptx_parser_macros_impl" +version = "0.0.0" +authors = ["Andrzej Janik "] +edition = "2021" + +[lib] + +[dependencies] +syn = { version = "2.0.67", features = ["extra-traits", "full"] } +quote = "1.0" +proc-macro2 = "1.0.86" +rustc-hash = "2.0.0" diff --git a/ptx_parser_macros_impl/src/lib.rs b/ptx_parser_macros_impl/src/lib.rs new file mode 100644 index 00000000..2f2c87a0 --- /dev/null +++ b/ptx_parser_macros_impl/src/lib.rs @@ -0,0 +1,881 @@ +use proc_macro2::TokenStream; +use quote::{format_ident, quote, ToTokens}; +use syn::{ + braced, parse::Parse, punctuated::Punctuated, token, Expr, Ident, LitBool, PathSegment, Token, + Type, TypeParam, Visibility, +}; + +pub mod parser; + +pub struct GenerateInstructionType { + pub visibility: Option, + pub name: Ident, + pub type_parameters: Punctuated, + pub short_parameters: 
Punctuated, + pub variants: Punctuated, +} + +impl GenerateInstructionType { + pub fn emit_arg_types(&self, tokens: &mut TokenStream) { + for v in self.variants.iter() { + v.emit_type(&self.visibility, tokens); + } + } + + pub fn emit_instruction_type(&self, tokens: &mut TokenStream) { + let vis = &self.visibility; + let type_name = &self.name; + let type_parameters = &self.type_parameters; + let variants = self.variants.iter().map(|v| v.emit_variant()); + quote! { + #vis enum #type_name<#type_parameters> { + #(#variants),* + } + } + .to_tokens(tokens); + } + + pub fn emit_visit(&self, tokens: &mut TokenStream) { + self.emit_visit_impl(VisitKind::Ref, tokens, InstructionVariant::emit_visit) + } + + pub fn emit_visit_mut(&self, tokens: &mut TokenStream) { + self.emit_visit_impl( + VisitKind::RefMut, + tokens, + InstructionVariant::emit_visit_mut, + ) + } + + pub fn emit_visit_map(&self, tokens: &mut TokenStream) { + self.emit_visit_impl(VisitKind::Map, tokens, InstructionVariant::emit_visit_map) + } + + fn emit_visit_impl( + &self, + kind: VisitKind, + tokens: &mut TokenStream, + mut fn_: impl FnMut(&InstructionVariant, &Ident, &mut TokenStream), + ) { + let type_name = &self.name; + let type_parameters = &self.type_parameters; + let short_parameters = &self.short_parameters; + let mut inner_tokens = TokenStream::new(); + for v in self.variants.iter() { + fn_(v, type_name, &mut inner_tokens); + } + let visit_ref = kind.reference(); + let visitor_type = format_ident!("Visitor{}", kind.type_suffix()); + let visit_fn = format_ident!("visit{}", kind.fn_suffix()); + let (type_parameters, visitor_parameters, return_type) = if kind == VisitKind::Map { + ( + quote! { <#type_parameters, To: Operand, Err> }, + quote! { <#short_parameters, To, Err> }, + quote! { std::result::Result<#type_name, Err> }, + ) + } else { + ( + quote! { <#type_parameters, Err> }, + quote! { <#short_parameters, Err> }, + quote! { std::result::Result<(), Err> }, + ) + }; + quote! 
{ + pub fn #visit_fn #type_parameters (i: #visit_ref #type_name<#short_parameters>, visitor: &mut impl #visitor_type #visitor_parameters ) -> #return_type { + Ok(match i { + #inner_tokens + }) + } + }.to_tokens(tokens); + if kind == VisitKind::Map { + return; + } + } +} + +#[derive(Clone, Copy, PartialEq, Eq)] +enum VisitKind { + Ref, + RefMut, + Map, +} + +impl VisitKind { + fn fn_suffix(self) -> &'static str { + match self { + VisitKind::Ref => "", + VisitKind::RefMut => "_mut", + VisitKind::Map => "_map", + } + } + + fn type_suffix(self) -> &'static str { + match self { + VisitKind::Ref => "", + VisitKind::RefMut => "Mut", + VisitKind::Map => "Map", + } + } + + fn reference(self) -> Option { + match self { + VisitKind::Ref => Some(quote! { & }), + VisitKind::RefMut => Some(quote! { &mut }), + VisitKind::Map => None, + } + } +} + +impl Parse for GenerateInstructionType { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let visibility = if !input.peek(Token![enum]) { + Some(input.parse::()?) + } else { + None + }; + input.parse::()?; + let name = input.parse::()?; + input.parse::()?; + let type_parameters = Punctuated::parse_separated_nonempty(input)?; + let short_parameters = type_parameters + .iter() + .map(|p: &TypeParam| p.ident.clone()) + .collect(); + input.parse::]>()?; + let variants_buffer; + braced!(variants_buffer in input); + let variants = variants_buffer.parse_terminated(InstructionVariant::parse, Token![,])?; + Ok(Self { + visibility, + name, + type_parameters, + short_parameters, + variants, + }) + } +} + +pub struct InstructionVariant { + pub name: Ident, + pub type_: Option>, + pub space: Option, + pub data: Option, + pub arguments: Option, + pub visit: Option, + pub visit_mut: Option, + pub map: Option, +} + +impl InstructionVariant { + fn args_name(&self) -> Ident { + format_ident!("{}Args", self.name) + } + + fn emit_variant(&self) -> TokenStream { + let name = &self.name; + let data = match &self.data { + None => { + quote! 
{} + } + Some(data_type) => { + quote! { + data: #data_type, + } + } + }; + let arguments = match &self.arguments { + None => { + quote! {} + } + Some(args) => { + let args_name = self.args_name(); + match &args { + Arguments::Def(InstructionArguments { generic: None, .. }) => { + quote! { + arguments: #args_name, + } + } + Arguments::Def(InstructionArguments { + generic: Some(generics), + .. + }) => { + quote! { + arguments: #args_name <#generics>, + } + } + Arguments::Decl(type_) => quote! { + arguments: #type_, + }, + } + } + }; + quote! { + #name { #data #arguments } + } + } + + fn emit_visit(&self, enum_: &Ident, tokens: &mut TokenStream) { + self.emit_visit_impl(&self.visit, enum_, tokens, InstructionArguments::emit_visit) + } + + fn emit_visit_mut(&self, enum_: &Ident, tokens: &mut TokenStream) { + self.emit_visit_impl( + &self.visit_mut, + enum_, + tokens, + InstructionArguments::emit_visit_mut, + ) + } + + fn emit_visit_impl( + &self, + visit_fn: &Option, + enum_: &Ident, + tokens: &mut TokenStream, + mut fn_: impl FnMut(&InstructionArguments, &Option>, &Option) -> TokenStream, + ) { + let name = &self.name; + let arguments = match &self.arguments { + None => { + quote! { + #enum_ :: #name { .. } => { } + } + .to_tokens(tokens); + return; + } + Some(Arguments::Decl(_)) => { + quote! { + #enum_ :: #name { data, arguments } => { #visit_fn } + } + .to_tokens(tokens); + return; + } + Some(Arguments::Def(args)) => args, + }; + let data = &self.data.as_ref().map(|_| quote! { data,}); + let arg_calls = fn_(arguments, &self.type_, &self.space); + quote! { + #enum_ :: #name { #data arguments } => { + #arg_calls + } + } + .to_tokens(tokens); + } + + fn emit_visit_map(&self, enum_: &Ident, tokens: &mut TokenStream) { + let name = &self.name; + let data = &self.data.as_ref().map(|_| quote! { data,}); + let arguments = match self.arguments { + None => None, + Some(Arguments::Decl(_)) => { + let map = self.map.as_ref().unwrap(); + quote! 
{ + #enum_ :: #name { #data arguments } => { + #map + } + } + .to_tokens(tokens); + return; + } + Some(Arguments::Def(ref def)) => Some(def), + }; + let arguments_ident = &self.arguments.as_ref().map(|_| quote! { arguments,}); + let mut arg_calls = None; + let arguments_init = arguments.as_ref().map(|arguments| { + let arg_type = self.args_name(); + arg_calls = Some(arguments.emit_visit_map(&self.type_, &self.space)); + let arg_names = arguments.fields.iter().map(|arg| &arg.name); + quote! { + arguments: #arg_type { #(#arg_names),* } + } + }); + quote! { + #enum_ :: #name { #data #arguments_ident } => { + #arg_calls + #enum_ :: #name { #data #arguments_init } + } + } + .to_tokens(tokens); + } + + fn emit_type(&self, vis: &Option, tokens: &mut TokenStream) { + let arguments = match self.arguments { + Some(Arguments::Def(ref a)) => a, + Some(Arguments::Decl(_)) => return, + None => return, + }; + let name = self.args_name(); + let type_parameters = if arguments.generic.is_some() { + Some(quote! { }) + } else { + None + }; + let fields = arguments.fields.iter().map(|f| f.emit_field(vis)); + quote! 
{ + #vis struct #name #type_parameters { + #(#fields),* + } + } + .to_tokens(tokens); + } +} + +impl Parse for InstructionVariant { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let name = input.parse::()?; + let properties_buffer; + braced!(properties_buffer in input); + let properties = properties_buffer.parse_terminated(VariantProperty::parse, Token![,])?; + let mut type_ = None; + let mut space = None; + let mut data = None; + let mut arguments = None; + let mut visit = None; + let mut visit_mut = None; + let mut map = None; + for property in properties { + match property { + VariantProperty::Type(t) => type_ = Some(t), + VariantProperty::Space(s) => space = Some(s), + VariantProperty::Data(d) => data = Some(d), + VariantProperty::Arguments(a) => arguments = Some(a), + VariantProperty::Visit(e) => visit = Some(e), + VariantProperty::VisitMut(e) => visit_mut = Some(e), + VariantProperty::Map(e) => map = Some(e), + } + } + Ok(Self { + name, + type_, + space, + data, + arguments, + visit, + visit_mut, + map, + }) + } +} + +enum VariantProperty { + Type(Option), + Space(Expr), + Data(Type), + Arguments(Arguments), + Visit(Expr), + VisitMut(Expr), + Map(Expr), +} + +impl VariantProperty { + pub fn parse(input: syn::parse::ParseStream) -> syn::Result { + let lookahead = input.lookahead1(); + Ok(if lookahead.peek(Token![type]) { + input.parse::()?; + input.parse::()?; + VariantProperty::Type(if input.peek(Token![!]) { + input.parse::()?; + None + } else { + Some(input.parse::()?) + }) + } else if lookahead.peek(Ident) { + let key = input.parse::()?; + match &*key.to_string() { + "data" => { + input.parse::()?; + VariantProperty::Data(input.parse::()?) + } + "space" => { + input.parse::()?; + VariantProperty::Space(input.parse::()?) 
+ } + "arguments" => { + let generics = if input.peek(Token![<]) { + input.parse::()?; + let gen_params = + Punctuated::::parse_separated_nonempty(input)?; + input.parse::]>()?; + Some(gen_params) + } else { + None + }; + input.parse::()?; + if input.peek(token::Brace) { + let fields; + braced!(fields in input); + VariantProperty::Arguments(Arguments::Def(InstructionArguments::parse( + generics, &fields, + )?)) + } else { + VariantProperty::Arguments(Arguments::Decl(input.parse::()?)) + } + } + "visit" => { + input.parse::()?; + VariantProperty::Visit(input.parse::()?) + } + "visit_mut" => { + input.parse::()?; + VariantProperty::VisitMut(input.parse::()?) + } + "map" => { + input.parse::()?; + VariantProperty::Map(input.parse::()?) + } + x => { + return Err(syn::Error::new( + key.span(), + format!( + "Unexpected key `{}`. Expected `type`, `data`, `arguments`, `visit, `visit_mut` or `map`.", + x + ), + )) + } + } + } else { + return Err(lookahead.error()); + }) + } +} + +pub enum Arguments { + Decl(Type), + Def(InstructionArguments), +} + +pub struct InstructionArguments { + pub generic: Option>, + pub fields: Punctuated, +} + +impl InstructionArguments { + pub fn parse( + generic: Option>, + input: syn::parse::ParseStream, + ) -> syn::Result { + let fields = Punctuated::::parse_terminated_with( + input, + ArgumentField::parse, + )?; + Ok(Self { generic, fields }) + } + + fn emit_visit( + &self, + parent_type: &Option>, + parent_space: &Option, + ) -> TokenStream { + self.emit_visit_impl(parent_type, parent_space, ArgumentField::emit_visit) + } + + fn emit_visit_mut( + &self, + parent_type: &Option>, + parent_space: &Option, + ) -> TokenStream { + self.emit_visit_impl(parent_type, parent_space, ArgumentField::emit_visit_mut) + } + + fn emit_visit_map( + &self, + parent_type: &Option>, + parent_space: &Option, + ) -> TokenStream { + self.emit_visit_impl(parent_type, parent_space, ArgumentField::emit_visit_map) + } + + fn emit_visit_impl( + &self, + parent_type: 
&Option>, + parent_space: &Option, + mut fn_: impl FnMut(&ArgumentField, &Option>, &Option, bool) -> TokenStream, + ) -> TokenStream { + let is_ident = if let Some(ref generic) = self.generic { + generic.len() > 1 + } else { + false + }; + let field_calls = self + .fields + .iter() + .map(|f| fn_(f, parent_type, parent_space, is_ident)); + quote! { + #(#field_calls)* + } + } +} + +pub struct ArgumentField { + pub name: Ident, + pub is_dst: bool, + pub repr: Type, + pub space: Option, + pub type_: Option, + pub relaxed_type_check: bool, +} + +impl ArgumentField { + fn parse_block( + input: syn::parse::ParseStream, + ) -> syn::Result<(Type, Option, Option, Option, bool)> { + let content; + braced!(content in input); + let all_fields = + Punctuated::::parse_terminated_with(&content, |content| { + let lookahead = content.lookahead1(); + Ok(if lookahead.peek(Token![type]) { + content.parse::()?; + content.parse::()?; + ExprOrPath::Type(content.parse::()?) + } else if lookahead.peek(Ident) { + let name_ident = content.parse::()?; + content.parse::()?; + match &*name_ident.to_string() { + "relaxed_type_check" => { + ExprOrPath::RelaxedTypeCheck(content.parse::()?.value) + } + "repr" => ExprOrPath::Repr(content.parse::()?), + "space" => ExprOrPath::Space(content.parse::()?), + "dst" => { + let ident = content.parse::()?; + ExprOrPath::Dst(ident.value) + } + name => { + return Err(syn::Error::new( + name_ident.span(), + format!("Unexpected key `{}`, expected `repr` or `space", name), + )) + } + } + } else { + return Err(lookahead.error()); + }) + })?; + let mut repr = None; + let mut type_ = None; + let mut space = None; + let mut is_dst = None; + let mut relaxed_type_check = false; + for exp_or_path in all_fields { + match exp_or_path { + ExprOrPath::Repr(r) => repr = Some(r), + ExprOrPath::Type(t) => type_ = Some(t), + ExprOrPath::Space(s) => space = Some(s), + ExprOrPath::Dst(x) => is_dst = Some(x), + ExprOrPath::RelaxedTypeCheck(relaxed) => relaxed_type_check = relaxed, 
+ } + } + Ok((repr.unwrap(), type_, space, is_dst, relaxed_type_check)) + } + + fn parse_basic(input: &syn::parse::ParseBuffer) -> syn::Result { + input.parse::() + } + + fn emit_visit( + &self, + parent_type: &Option>, + parent_space: &Option, + is_ident: bool, + ) -> TokenStream { + self.emit_visit_impl(parent_type, parent_space, is_ident, false) + } + + fn emit_visit_mut( + &self, + parent_type: &Option>, + parent_space: &Option, + is_ident: bool, + ) -> TokenStream { + self.emit_visit_impl(parent_type, parent_space, is_ident, true) + } + + fn emit_visit_impl( + &self, + parent_type: &Option>, + parent_space: &Option, + is_ident: bool, + is_mut: bool, + ) -> TokenStream { + let (is_typeless, type_) = match (self.type_.as_ref(), parent_type) { + (Some(type_), _) => (false, Some(type_)), + (None, None) => panic!("No type set"), + (None, Some(None)) => (true, None), + (None, Some(Some(type_))) => (false, Some(type_)), + }; + let space = self + .space + .as_ref() + .or(parent_space.as_ref()) + .map(|space| quote! { #space }) + .unwrap_or_else(|| quote! { StateSpace::Reg }); + let is_dst = self.is_dst; + let relaxed_type_check = self.relaxed_type_check; + let name = &self.name; + let type_space = if is_typeless { + quote! { + let type_space = None; + } + } else { + quote! { + let type_ = #type_; + let space = #space; + let type_space = Some((std::borrow::Borrow::::borrow(&type_), space)); + } + }; + if is_ident { + if is_mut { + quote! { + { + #type_space + visitor.visit_ident(&mut arguments.#name, type_space, #is_dst, #relaxed_type_check)?; + } + } + } else { + quote! { + { + #type_space + visitor.visit_ident(& arguments.#name, type_space, #is_dst, #relaxed_type_check)?; + } + } + } + } else { + let (operand_fn, arguments_name) = if is_mut { + ( + quote! { + VisitOperand::visit_mut + }, + quote! { + &mut arguments.#name + }, + ) + } else { + ( + quote! { + VisitOperand::visit + }, + quote! { + & arguments.#name + }, + ) + }; + quote! 
{{ + #type_space + #operand_fn(#arguments_name, |x| visitor.visit(x, type_space, #is_dst, #relaxed_type_check))?; + }} + } + } + + fn emit_visit_map( + &self, + parent_type: &Option>, + parent_space: &Option, + is_ident: bool, + ) -> TokenStream { + let (is_typeless, type_) = match (self.type_.as_ref(), parent_type) { + (Some(type_), _) => (false, Some(type_)), + (None, None) => panic!("No type set"), + (None, Some(None)) => (true, None), + (None, Some(Some(type_))) => (false, Some(type_)), + }; + let space = self + .space + .as_ref() + .or(parent_space.as_ref()) + .map(|space| quote! { #space }) + .unwrap_or_else(|| quote! { StateSpace::Reg }); + let is_dst = self.is_dst; + let relaxed_type_check = self.relaxed_type_check; + let name = &self.name; + let type_space = if is_typeless { + quote! { + let type_space = None; + } + } else { + quote! { + let type_ = #type_; + let space = #space; + let type_space = Some((std::borrow::Borrow::::borrow(&type_), space)); + } + }; + let map_call = if is_ident { + quote! { + visitor.visit_ident(arguments.#name, type_space, #is_dst, #relaxed_type_check)? + } + } else { + quote! { + MapOperand::map(arguments.#name, |x| visitor.visit(x, type_space, #is_dst, #relaxed_type_check))? + } + }; + quote! { + let #name = { + #type_space + #map_call + }; + } + } + + fn is_dst(name: &Ident) -> syn::Result { + if name.to_string().starts_with("dst") { + Ok(true) + } else if name.to_string().starts_with("src") { + Ok(false) + } else { + return Err(syn::Error::new( + name.span(), + format!( + "Could not guess if `{}` is a read or write argument. Name should start with `dst` or `src`", + name + ), + )); + } + } + + fn emit_field(&self, vis: &Option) -> TokenStream { + let name = &self.name; + let type_ = &self.repr; + quote! 
{ + #vis #name: #type_ + } + } +} + +impl Parse for ArgumentField { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let name = input.parse::()?; + + input.parse::()?; + let lookahead = input.lookahead1(); + let (repr, type_, space, is_dst, relaxed_type_check) = if lookahead.peek(token::Brace) { + Self::parse_block(input)? + } else if lookahead.peek(syn::Ident) { + (Self::parse_basic(input)?, None, None, None, false) + } else { + return Err(lookahead.error()); + }; + let is_dst = match is_dst { + Some(x) => x, + None => Self::is_dst(&name)?, + }; + Ok(Self { + name, + is_dst, + repr, + type_, + space, + relaxed_type_check + }) + } +} + +enum ExprOrPath { + Repr(Type), + Type(Expr), + Space(Expr), + Dst(bool), + RelaxedTypeCheck(bool), +} + +#[cfg(test)] +mod tests { + use super::*; + use proc_macro2::Span; + use quote::{quote, ToTokens}; + + fn to_string(x: impl ToTokens) -> String { + quote! { #x }.to_string() + } + + #[test] + fn parse_argument_field_basic() { + let input = quote! { + dst: P::Operand + }; + let arg = syn::parse2::(input).unwrap(); + assert_eq!("dst", arg.name.to_string()); + assert_eq!("P :: Operand", to_string(arg.repr)); + assert!(matches!(arg.type_, None)); + } + + #[test] + fn parse_argument_field_block() { + let input = quote! { + dst: { + type: ScalarType::U32, + space: StateSpace::Global, + repr: P::Operand, + } + }; + let arg = syn::parse2::(input).unwrap(); + assert_eq!("dst", arg.name.to_string()); + assert_eq!("ScalarType :: U32", to_string(arg.type_.unwrap())); + assert_eq!("StateSpace :: Global", to_string(arg.space.unwrap())); + assert_eq!("P :: Operand", to_string(arg.repr)); + } + + #[test] + fn parse_argument_field_block_untyped() { + let input = quote! 
{ + dst: { + repr: P::Operand, + } + }; + let arg = syn::parse2::(input).unwrap(); + assert_eq!("dst", arg.name.to_string()); + assert_eq!("P :: Operand", to_string(arg.repr)); + assert!(matches!(arg.type_, None)); + } + + #[test] + fn parse_variant_complex() { + let input = quote! { + Ld { + type: ScalarType::U32, + space: StateSpace::Global, + data: LdDetails, + arguments

: { + dst: { + repr: P::Operand, + type: ScalarType::U32, + space: StateSpace::Shared, + }, + src: P::Operand, + }, + } + }; + let variant = syn::parse2::(input).unwrap(); + assert_eq!("Ld", variant.name.to_string()); + assert_eq!("ScalarType :: U32", to_string(variant.type_.unwrap())); + assert_eq!("StateSpace :: Global", to_string(variant.space.unwrap())); + assert_eq!("LdDetails", to_string(variant.data.unwrap())); + let arguments = if let Some(Arguments::Def(a)) = variant.arguments { + a + } else { + panic!() + }; + assert_eq!("P", to_string(arguments.generic)); + let mut fields = arguments.fields.into_iter(); + let dst = fields.next().unwrap(); + assert_eq!("P :: Operand", to_string(dst.repr)); + assert_eq!("ScalarType :: U32", to_string(dst.type_)); + assert_eq!("StateSpace :: Shared", to_string(dst.space)); + let src = fields.next().unwrap(); + assert_eq!("P :: Operand", to_string(src.repr)); + assert!(matches!(src.type_, None)); + assert!(matches!(src.space, None)); + } + + #[test] + fn visit_variant_empty() { + let input = quote! { + Ret { + data: RetData + } + }; + let variant = syn::parse2::(input).unwrap(); + let mut output = TokenStream::new(); + variant.emit_visit(&Ident::new("Instruction", Span::call_site()), &mut output); + assert_eq!(output.to_string(), "Instruction :: Ret { .. 
} => { }"); + } +} diff --git a/ptx_parser_macros_impl/src/parser.rs b/ptx_parser_macros_impl/src/parser.rs new file mode 100644 index 00000000..f1cd7383 --- /dev/null +++ b/ptx_parser_macros_impl/src/parser.rs @@ -0,0 +1,844 @@ +use proc_macro2::Span; +use proc_macro2::TokenStream; +use quote::quote; +use quote::ToTokens; +use rustc_hash::FxHashMap; +use std::fmt::Write; +use syn::bracketed; +use syn::parse::Peek; +use syn::punctuated::Punctuated; +use syn::spanned::Spanned; +use syn::LitInt; +use syn::Type; +use syn::{braced, parse::Parse, token, Ident, ItemEnum, Token}; + +pub struct ParseDefinitions { + pub token_type: ItemEnum, + pub additional_enums: FxHashMap, + pub definitions: Vec, +} + +impl Parse for ParseDefinitions { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let token_type = input.parse::()?; + let mut additional_enums = FxHashMap::default(); + while input.peek(Token![#]) { + let enum_ = input.parse::()?; + additional_enums.insert(enum_.ident.clone(), enum_); + } + let mut definitions = Vec::new(); + while !input.is_empty() { + definitions.push(input.parse::()?); + } + Ok(Self { + token_type, + additional_enums, + definitions, + }) + } +} + +pub struct OpcodeDefinition(pub Patterns, pub Vec); + +impl Parse for OpcodeDefinition { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let patterns = input.parse::()?; + let mut rules = Vec::new(); + while Rule::peek(input) { + rules.push(input.parse::()?); + input.parse::()?; + } + Ok(Self(patterns, rules)) + } +} + +pub struct Patterns(pub Vec<(OpcodeDecl, CodeBlock)>); + +impl Parse for Patterns { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let mut result = Vec::new(); + loop { + if !OpcodeDecl::peek(input) { + break; + } + let decl = input.parse::()?; + let code_block = input.parse::()?; + result.push((decl, code_block)) + } + Ok(Self(result)) + } +} + +pub struct OpcodeDecl(pub Instruction, pub Arguments); + +impl OpcodeDecl { + fn peek(input: 
syn::parse::ParseStream) -> bool { + Instruction::peek(input) && !input.peek2(Token![=]) + } +} + +impl Parse for OpcodeDecl { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + Ok(Self( + input.parse::()?, + input.parse::()?, + )) + } +} + +pub struct CodeBlock { + pub special: bool, + pub code: proc_macro2::Group, +} + +impl Parse for CodeBlock { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let lookahead = input.lookahead1(); + let (special, code) = if lookahead.peek(Token![<]) { + input.parse::()?; + input.parse::()?; + //input.parse::]>()?; + (true, input.parse::()?) + } else if lookahead.peek(Token![=]) { + input.parse::()?; + input.parse::]>()?; + (false, input.parse::()?) + } else { + return Err(lookahead.error()); + }; + Ok(Self { special, code }) + } +} + +pub struct Rule { + pub modifier: Option, + pub type_: Option, + pub alternatives: Vec, +} + +impl Rule { + fn peek(input: syn::parse::ParseStream) -> bool { + DotModifier::peek(input) + || (input.peek(Ident) && input.peek2(Token![=]) && !input.peek3(Token![>])) + } + + fn parse_alternatives(input: syn::parse::ParseStream) -> syn::Result> { + let mut result = Vec::new(); + Self::parse_with_alternative(input, &mut result)?; + loop { + if !input.peek(Token![,]) { + break; + } + input.parse::()?; + Self::parse_with_alternative(input, &mut result)?; + } + Ok(result) + } + + fn parse_with_alternative( + input: &syn::parse::ParseBuffer, + result: &mut Vec, + ) -> Result<(), syn::Error> { + input.parse::()?; + let part1 = input.parse::()?; + if input.peek(token::Brace) { + result.push(DotModifier { + part1: part1.clone(), + part2: None, + }); + let suffix_content; + braced!(suffix_content in input); + let suffixes = Punctuated::::parse_separated_nonempty( + &suffix_content, + )?; + for part2 in suffixes { + result.push(DotModifier { + part1: part1.clone(), + part2: Some(part2), + }); + } + } else if IdentOrTypeSuffix::peek(input) { + let part2 = 
Some(IdentOrTypeSuffix::parse(input)?); + result.push(DotModifier { part1, part2 }); + } else { + result.push(DotModifier { part1, part2: None }); + } + Ok(()) + } +} + +#[derive(PartialEq, Eq, Hash, Clone)] +struct IdentOrTypeSuffix(IdentLike); + +impl IdentOrTypeSuffix { + fn span(&self) -> Span { + self.0.span() + } + + fn peek(input: syn::parse::ParseStream) -> bool { + input.peek(Token![::]) + } +} + +impl ToTokens for IdentOrTypeSuffix { + fn to_tokens(&self, tokens: &mut TokenStream) { + let ident = &self.0; + quote! { :: #ident }.to_tokens(tokens) + } +} + +impl Parse for IdentOrTypeSuffix { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + input.parse::()?; + Ok(Self(input.parse::()?)) + } +} + +impl Parse for Rule { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let (modifier, type_) = if DotModifier::peek(input) { + let modifier = Some(input.parse::()?); + if input.peek(Token![:]) { + input.parse::()?; + (modifier, Some(input.parse::()?)) + } else { + (modifier, None) + } + } else { + (None, Some(input.parse::()?)) + }; + input.parse::()?; + let content; + braced!(content in input); + let alternatives = Self::parse_alternatives(&content)?; + Ok(Self { + modifier, + type_, + alternatives, + }) + } +} + +pub struct Instruction { + pub name: Ident, + pub modifiers: Vec, +} +impl Instruction { + fn peek(input: syn::parse::ParseStream) -> bool { + input.peek(Ident) + } +} + +impl Parse for Instruction { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let instruction = input.parse::()?; + let mut modifiers = Vec::new(); + loop { + if !MaybeDotModifier::peek(input) { + break; + } + modifiers.push(MaybeDotModifier::parse(input)?); + } + Ok(Self { + name: instruction, + modifiers, + }) + } +} + +pub struct MaybeDotModifier { + pub optional: bool, + pub modifier: DotModifier, +} + +impl MaybeDotModifier { + fn peek(input: syn::parse::ParseStream) -> bool { + input.peek(token::Brace) || DotModifier::peek(input) + } +} + 
+impl Parse for MaybeDotModifier { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + Ok(if input.peek(token::Brace) { + let content; + braced!(content in input); + let modifier = DotModifier::parse(&content)?; + Self { + modifier, + optional: true, + } + } else { + let modifier = DotModifier::parse(input)?; + Self { + modifier, + optional: false, + } + }) + } +} + +#[derive(PartialEq, Eq, Hash, Clone)] +pub struct DotModifier { + part1: IdentLike, + part2: Option, +} + +impl std::fmt::Display for DotModifier { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, ".")?; + self.part1.fmt(f)?; + if let Some(ref part2) = self.part2 { + write!(f, "::")?; + part2.0.fmt(f)?; + } + Ok(()) + } +} + +impl std::fmt::Debug for DotModifier { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + std::fmt::Display::fmt(&self, f) + } +} + +impl DotModifier { + pub fn span(&self) -> Span { + let part1 = self.part1.span(); + if let Some(ref part2) = self.part2 { + part1.join(part2.span()).unwrap_or(part1) + } else { + part1 + } + } + + pub fn ident(&self) -> Ident { + let mut result = String::new(); + write!(&mut result, "{}", self.part1).unwrap(); + if let Some(ref part2) = self.part2 { + write!(&mut result, "_{}", part2.0).unwrap(); + } else { + match self.part1 { + IdentLike::Type(_) | IdentLike::Const(_) => result.push('_'), + IdentLike::Ident(_) | IdentLike::Integer(_) => {} + } + } + Ident::new(&result.to_ascii_lowercase(), self.span()) + } + + pub fn variant_capitalized(&self) -> Ident { + self.capitalized_impl(String::new()) + } + + pub fn dot_capitalized(&self) -> Ident { + self.capitalized_impl("Dot".to_string()) + } + + fn capitalized_impl(&self, prefix: String) -> Ident { + let mut temp = String::new(); + write!(&mut temp, "{}", &self.part1).unwrap(); + if let Some(IdentOrTypeSuffix(ref part2)) = self.part2 { + write!(&mut temp, "_{}", part2).unwrap(); + } + let mut result = prefix; + let mut capitalize = true; 
+ for c in temp.chars() { + if c == '_' { + capitalize = true; + continue; + } + // Special hack to emit `BF16`` instead of `Bf16`` + let c = if capitalize || c == 'f' && result.ends_with('B') { + capitalize = false; + c.to_ascii_uppercase() + } else { + c + }; + result.push(c); + } + Ident::new(&result, self.span()) + } + + pub fn tokens(&self) -> TokenStream { + let part1 = &self.part1; + let part2 = &self.part2; + match self.part2 { + None => quote! { . #part1 }, + Some(_) => quote! { . #part1 #part2 }, + } + } + + fn peek(input: syn::parse::ParseStream) -> bool { + input.peek(Token![.]) + } +} + +impl Parse for DotModifier { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + input.parse::()?; + let part1 = input.parse::()?; + if IdentOrTypeSuffix::peek(input) { + let part2 = Some(IdentOrTypeSuffix::parse(input)?); + Ok(Self { part1, part2 }) + } else { + Ok(Self { part1, part2: None }) + } + } +} + +#[derive(PartialEq, Eq, Hash, Clone)] +enum IdentLike { + Type(Token![type]), + Const(Token![const]), + Ident(Ident), + Integer(LitInt), +} + +impl IdentLike { + fn span(&self) -> Span { + match self { + IdentLike::Type(c) => c.span(), + IdentLike::Const(t) => t.span(), + IdentLike::Ident(i) => i.span(), + IdentLike::Integer(l) => l.span(), + } + } +} + +impl std::fmt::Display for IdentLike { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + IdentLike::Type(_) => f.write_str("type"), + IdentLike::Const(_) => f.write_str("const"), + IdentLike::Ident(ident) => write!(f, "{}", ident), + IdentLike::Integer(integer) => write!(f, "{}", integer), + } + } +} + +impl ToTokens for IdentLike { + fn to_tokens(&self, tokens: &mut TokenStream) { + match self { + IdentLike::Type(_) => quote! { type }.to_tokens(tokens), + IdentLike::Const(_) => quote! { const }.to_tokens(tokens), + IdentLike::Ident(ident) => quote! { #ident }.to_tokens(tokens), + IdentLike::Integer(int) => quote! 
{ #int }.to_tokens(tokens), + } + } +} + +impl Parse for IdentLike { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let lookahead = input.lookahead1(); + Ok(if lookahead.peek(Token![const]) { + IdentLike::Const(input.parse::()?) + } else if lookahead.peek(Token![type]) { + IdentLike::Type(input.parse::()?) + } else if lookahead.peek(Ident) { + IdentLike::Ident(input.parse::()?) + } else if lookahead.peek(LitInt) { + IdentLike::Integer(input.parse::()?) + } else { + return Err(lookahead.error()); + }) + } +} + +// Arguments decalaration can loook like this: +// a{, b} +// That's why we don't parse Arguments as Punctuated +#[derive(PartialEq, Eq)] +pub struct Arguments(pub Vec); + +impl Parse for Arguments { + fn parse(input: syn::parse::ParseStream) -> syn::Result { + let mut result = Vec::new(); + loop { + if input.peek(Token![,]) { + input.parse::()?; + } + let mut optional = false; + let mut can_be_negated = false; + let mut pre_pipe = false; + let ident; + let lookahead = input.lookahead1(); + if lookahead.peek(token::Brace) { + let content; + braced!(content in input); + let lookahead = content.lookahead1(); + if lookahead.peek(Token![!]) { + content.parse::()?; + can_be_negated = true; + ident = input.parse::()?; + } else if lookahead.peek(Token![,]) { + optional = true; + content.parse::()?; + ident = content.parse::()?; + } else { + return Err(lookahead.error()); + } + } else if lookahead.peek(token::Bracket) { + let bracketed; + bracketed!(bracketed in input); + if bracketed.peek(Token![|]) { + optional = true; + bracketed.parse::()?; + pre_pipe = true; + ident = bracketed.parse::()?; + } else { + let mut sub_args = Self::parse(&bracketed)?; + sub_args.0.first_mut().unwrap().pre_bracket = true; + sub_args.0.last_mut().unwrap().post_bracket = true; + if peek_brace_token(input, Token![.]) { + let optional_suffix; + braced!(optional_suffix in input); + optional_suffix.parse::()?; + let unified_ident = optional_suffix.parse::()?; + if 
unified_ident.to_string() != "unified" { + return Err(syn::Error::new( + unified_ident.span(), + format!("Exptected `unified`, got `{}`", unified_ident), + )); + } + for a in sub_args.0.iter_mut() { + a.unified = true; + } + } + result.extend(sub_args.0); + continue; + } + } else if lookahead.peek(Ident) { + ident = input.parse::()?; + } else if lookahead.peek(Token![|]) { + input.parse::()?; + pre_pipe = true; + ident = input.parse::()?; + } else { + break; + } + result.push(Argument { + optional, + pre_pipe, + can_be_negated, + pre_bracket: false, + ident, + post_bracket: false, + unified: false, + }); + } + Ok(Self(result)) + } +} + +// This is effectively input.peek(token::Brace) && input.peek2(Token![.]) +// input.peek2 is supposed to skip over next token, but it skips over whole +// braced token group. Not sure if it's a bug +fn peek_brace_token(input: syn::parse::ParseStream, _t: T) -> bool { + use syn::token::Token; + let cursor = input.cursor(); + cursor + .group(proc_macro2::Delimiter::Brace) + .map_or(false, |(content, ..)| T::Token::peek(content)) +} + +#[derive(PartialEq, Eq)] +pub struct Argument { + pub optional: bool, + pub pre_bracket: bool, + pub pre_pipe: bool, + pub can_be_negated: bool, + pub ident: Ident, + pub post_bracket: bool, + pub unified: bool, +} + +#[cfg(test)] +mod tests { + use super::{Arguments, DotModifier, MaybeDotModifier}; + use quote::{quote, ToTokens}; + + #[test] + fn parse_modifier_complex() { + let input = quote! { + .level::eviction_priority + }; + let modifier = syn::parse2::(input).unwrap(); + assert_eq!( + ". level :: eviction_priority", + modifier.tokens().to_string() + ); + } + + #[test] + fn parse_modifier_optional() { + let input = quote! { + { .level::eviction_priority } + }; + let maybe_modifider = syn::parse2::(input).unwrap(); + assert_eq!( + ". 
level :: eviction_priority", + maybe_modifider.modifier.tokens().to_string() + ); + assert!(maybe_modifider.optional); + } + + #[test] + fn parse_type_token() { + let input = quote! { + . type + }; + let maybe_modifier = syn::parse2::(input).unwrap(); + assert_eq!(". type", maybe_modifier.modifier.tokens().to_string()); + assert!(!maybe_modifier.optional); + } + + #[test] + fn arguments_memory() { + let input = quote! { + [a], b + }; + let arguments = syn::parse2::(input).unwrap(); + let a = &arguments.0[0]; + assert!(!a.optional); + assert_eq!("a", a.ident.to_string()); + assert!(a.pre_bracket); + assert!(!a.pre_pipe); + assert!(a.post_bracket); + assert!(!a.can_be_negated); + let b = &arguments.0[1]; + assert!(!b.optional); + assert_eq!("b", b.ident.to_string()); + assert!(!b.pre_bracket); + assert!(!b.pre_pipe); + assert!(!b.post_bracket); + assert!(!b.can_be_negated); + } + + #[test] + fn arguments_optional() { + let input = quote! { + b{, cache_policy} + }; + let arguments = syn::parse2::(input).unwrap(); + let b = &arguments.0[0]; + assert!(!b.optional); + assert_eq!("b", b.ident.to_string()); + assert!(!b.pre_bracket); + assert!(!b.pre_pipe); + assert!(!b.post_bracket); + assert!(!b.can_be_negated); + let cache_policy = &arguments.0[1]; + assert!(cache_policy.optional); + assert_eq!("cache_policy", cache_policy.ident.to_string()); + assert!(!cache_policy.pre_bracket); + assert!(!cache_policy.pre_pipe); + assert!(!cache_policy.post_bracket); + assert!(!cache_policy.can_be_negated); + } + + #[test] + fn arguments_optional_pred() { + let input = quote! 
{ + p[|q], a + }; + let arguments = syn::parse2::(input).unwrap(); + assert_eq!(arguments.0.len(), 3); + let p = &arguments.0[0]; + assert!(!p.optional); + assert_eq!("p", p.ident.to_string()); + assert!(!p.pre_bracket); + assert!(!p.pre_pipe); + assert!(!p.post_bracket); + assert!(!p.can_be_negated); + let q = &arguments.0[1]; + assert!(q.optional); + assert_eq!("q", q.ident.to_string()); + assert!(!q.pre_bracket); + assert!(q.pre_pipe); + assert!(!q.post_bracket); + assert!(!q.can_be_negated); + let a = &arguments.0[2]; + assert!(!a.optional); + assert_eq!("a", a.ident.to_string()); + assert!(!a.pre_bracket); + assert!(!a.pre_pipe); + assert!(!a.post_bracket); + assert!(!a.can_be_negated); + } + + #[test] + fn arguments_optional_with_negate() { + let input = quote! { + b, {!}c + }; + let arguments = syn::parse2::(input).unwrap(); + assert_eq!(arguments.0.len(), 2); + let b = &arguments.0[0]; + assert!(!b.optional); + assert_eq!("b", b.ident.to_string()); + assert!(!b.pre_bracket); + assert!(!b.pre_pipe); + assert!(!b.post_bracket); + assert!(!b.can_be_negated); + let c = &arguments.0[1]; + assert!(!c.optional); + assert_eq!("c", c.ident.to_string()); + assert!(!c.pre_bracket); + assert!(!c.pre_pipe); + assert!(!c.post_bracket); + assert!(c.can_be_negated); + } + + #[test] + fn arguments_tex() { + let input = quote! 
{ + d[|p], [a{, b}, c], dpdx, dpdy {, e} + }; + let arguments = syn::parse2::(input).unwrap(); + assert_eq!(arguments.0.len(), 8); + { + let d = &arguments.0[0]; + assert!(!d.optional); + assert_eq!("d", d.ident.to_string()); + assert!(!d.pre_bracket); + assert!(!d.pre_pipe); + assert!(!d.post_bracket); + assert!(!d.can_be_negated); + } + { + let p = &arguments.0[1]; + assert!(p.optional); + assert_eq!("p", p.ident.to_string()); + assert!(!p.pre_bracket); + assert!(p.pre_pipe); + assert!(!p.post_bracket); + assert!(!p.can_be_negated); + } + { + let a = &arguments.0[2]; + assert!(!a.optional); + assert_eq!("a", a.ident.to_string()); + assert!(a.pre_bracket); + assert!(!a.pre_pipe); + assert!(!a.post_bracket); + assert!(!a.can_be_negated); + } + { + let b = &arguments.0[3]; + assert!(b.optional); + assert_eq!("b", b.ident.to_string()); + assert!(!b.pre_bracket); + assert!(!b.pre_pipe); + assert!(!b.post_bracket); + assert!(!b.can_be_negated); + } + { + let c = &arguments.0[4]; + assert!(!c.optional); + assert_eq!("c", c.ident.to_string()); + assert!(!c.pre_bracket); + assert!(!c.pre_pipe); + assert!(c.post_bracket); + assert!(!c.can_be_negated); + } + { + let dpdx = &arguments.0[5]; + assert!(!dpdx.optional); + assert_eq!("dpdx", dpdx.ident.to_string()); + assert!(!dpdx.pre_bracket); + assert!(!dpdx.pre_pipe); + assert!(!dpdx.post_bracket); + assert!(!dpdx.can_be_negated); + } + { + let dpdy = &arguments.0[6]; + assert!(!dpdy.optional); + assert_eq!("dpdy", dpdy.ident.to_string()); + assert!(!dpdy.pre_bracket); + assert!(!dpdy.pre_pipe); + assert!(!dpdy.post_bracket); + assert!(!dpdy.can_be_negated); + } + { + let e = &arguments.0[7]; + assert!(e.optional); + assert_eq!("e", e.ident.to_string()); + assert!(!e.pre_bracket); + assert!(!e.pre_pipe); + assert!(!e.post_bracket); + assert!(!e.can_be_negated); + } + } + + #[test] + fn rule_multi() { + let input = quote! 
{ + .ss: StateSpace = { .global, .local, .param{::func}, .shared{::cta, ::cluster} } + }; + let rule = syn::parse2::(input).unwrap(); + assert_eq!(". ss", rule.modifier.unwrap().tokens().to_string()); + assert_eq!( + "StateSpace", + rule.type_.unwrap().to_token_stream().to_string() + ); + let alts = rule + .alternatives + .iter() + .map(|m| m.tokens().to_string()) + .collect::>(); + assert_eq!( + vec![ + ". global", + ". local", + ". param", + ". param :: func", + ". shared", + ". shared :: cta", + ". shared :: cluster" + ], + alts + ); + } + + #[test] + fn rule_multi2() { + let input = quote! { + .cop: StCacheOperator = { .wb, .cg, .cs, .wt } + }; + let rule = syn::parse2::(input).unwrap(); + assert_eq!(". cop", rule.modifier.unwrap().tokens().to_string()); + assert_eq!( + "StCacheOperator", + rule.type_.unwrap().to_token_stream().to_string() + ); + let alts = rule + .alternatives + .iter() + .map(|m| m.tokens().to_string()) + .collect::>(); + assert_eq!(vec![". wb", ". cg", ". cs", ". wt",], alts); + } + + #[test] + fn args_unified() { + let input = quote! { + d, [a]{.unified}{, cache_policy} + }; + let args = syn::parse2::(input).unwrap(); + let a = &args.0[1]; + assert!(!a.optional); + assert_eq!("a", a.ident.to_string()); + assert!(a.pre_bracket); + assert!(!a.pre_pipe); + assert!(a.post_bracket); + assert!(!a.can_be_negated); + assert!(a.unified); + } + + #[test] + fn special_block() { + let input = quote! { + bra <= { bra(stream) } + }; + syn::parse2::(input).unwrap(); + } +}