diff --git a/ptx/src/pass/convert_to_typed.rs b/ptx/src/pass/convert_to_typed.rs index 3dfef55b..7ff52909 100644 --- a/ptx/src/pass/convert_to_typed.rs +++ b/ptx/src/pass/convert_to_typed.rs @@ -26,17 +26,9 @@ pub(crate) fn run( src: src_reg, })); } - ast::Instruction::Call(call) => { - let resolver = fn_defs.get_fn_sig_resolver(call.func)?; - let resolved_call = resolver.resolve_in_spirv_repr(call)?; - let mut visitor = VectorRepackVisitor::new(&mut result, id_defs); - let reresolved_call = resolved_call.visit(&mut visitor)?; - visitor.func.push(reresolved_call); - visitor.func.extend(visitor.post_stmts); - } inst => { let mut visitor = VectorRepackVisitor::new(&mut result, id_defs); - let instruction = Statement::Instruction(inst.map(&mut visitor)?); + let instruction = Statement::Instruction(ast::visit_map(inst, &mut visitor)?); visitor.func.push(instruction); visitor.func.extend(visitor.post_stmts); } @@ -68,12 +60,7 @@ impl<'a, 'b> VectorRepackVisitor<'a, 'b> { fn convert_vector( &mut self, is_dst: bool, - non_default_implicit_conversion: Option< - fn( - (ast::StateSpace, &ast::Type), - (ast::StateSpace, &ast::Type), - ) -> Result, TranslateError>, - >, + relaxed_type_check: bool, typ: &ast::Type, state_space: ast::StateSpace, idx: Vec, @@ -91,7 +78,7 @@ impl<'a, 'b> VectorRepackVisitor<'a, 'b> { typ: scalar_t, packed: temp_vec, unpacked: idx, - non_default_implicit_conversion, + relaxed_type_check, }); if is_dst { self.post_stmts = Some(statement); @@ -110,6 +97,7 @@ impl<'a, 'b> ast::VisitorMap, TypedOperand, Transl ident: SpirvWord, _: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>, _: bool, + _: bool, ) -> Result { Ok(ident) } @@ -119,6 +107,7 @@ impl<'a, 'b> ast::VisitorMap, TypedOperand, Transl op: ast::ParsedOperand, type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>, is_dst: bool, + relaxed_type_check: bool, ) -> Result { Ok(match op { ast::ParsedOperand::Reg(reg) => TypedOperand::Reg(reg), @@ -129,7 +118,7 @@ impl<'a, 'b> ast::VisitorMap, TypedOperand, Transl let (type_, space) = type_space.ok_or(TranslateError::MismatchedType)?; TypedOperand::Reg(self.convert_vector( is_dst, - desc.non_default_implicit_conversion, + relaxed_type_check, type_, space, vec, diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index bedf46ab..3968d3d5 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -798,12 +798,7 @@ struct RepackVectorDetails { typ: ast::ScalarType, packed: SpirvWord, unpacked: Vec, - non_default_implicit_conversion: Option< - fn( - (ast::StateSpace, &ast::Type), - (ast::StateSpace, &ast::Type), - ) -> Result, TranslateError>, - >, + relaxed_type_check: bool } struct FunctionPointerDetails { diff --git a/ptx/src/pass/normalize_identifiers.rs b/ptx/src/pass/normalize_identifiers.rs index 6588d637..b5983453 100644 --- a/ptx/src/pass/normalize_identifiers.rs +++ b/ptx/src/pass/normalize_identifiers.rs @@ -41,6 +41,7 @@ fn expand_map_variables<'a, 'b>( .transpose()?, ast::visit_map(i, &mut |id, _: Option<(&ast::Type, ast::StateSpace)>, + _: bool, _: bool| { id_defs.get_id(id) })?,