Skip to content

Commit

Permalink
Port ssa conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
vosen committed Aug 26, 2024
1 parent 3e0a15a commit cccd37f
Show file tree
Hide file tree
Showing 3 changed files with 286 additions and 9 deletions.
6 changes: 0 additions & 6 deletions ptx/src/pass/convert_to_stateful_memory_access.rs
Original file line number Diff line number Diff line change
Expand Up @@ -527,9 +527,3 @@ fn convert_to_stateful_memory_access_postprocess(
})
})
}

fn state_is_compatible(this: ast::StateSpace, other: ast::StateSpace) -> bool {
this == other
|| this == ast::StateSpace::Reg && other == ast::StateSpace::Sreg
|| this == ast::StateSpace::Sreg && other == ast::StateSpace::Reg
}
276 changes: 276 additions & 0 deletions ptx/src/pass/insert_mem_ssa_statements.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,276 @@
use super::*;
use ptx_parser as ast;

/*
How do we handle arguments:
- input .params in kernels
.param .b64 in_arg
get turned into this SPIR-V:
%1 = OpFunctionParameter %ulong
%2 = OpVariable %_ptr_Function_ulong Function
OpStore %2 %1
We do this for two reasons. One, common treatment for argument-declared
.param variables and .param variables inside function (we assume that
at SPIR-V level every .param is a pointer in Function storage class)
- input .params in functions
.param .b64 in_arg
get turned into this SPIR-V:
%1 = OpFunctionParameter %_ptr_Function_ulong
- input .regs
.reg .b64 in_arg
get turned into the same SPIR-V as kernel .params:
%1 = OpFunctionParameter %ulong
%2 = OpVariable %_ptr_Function_ulong Function
OpStore %2 %1
- output .regs
.reg .b64 out_arg
get just a variable declaration:
%2 = OpVariable %%_ptr_Function_ulong Function
- output .params don't exist, they have been moved to input positions
by an earlier pass
Distinguishing betweem kernel .params and function .params is not the
cleanest solution. Alternatively, we could "deparamize" all kernel .param
arguments by turning them into .reg arguments like this:
.param .b64 arg -> .reg ptr<.b64,.param> arg
This has the massive downside that this transformation would have to run
very early and would muddy up already difficult code. It's simpler to just
have an if here
*/
pub(super) fn run<'a, 'b>(
func: Vec<TypedStatement>,
id_def: &mut NumericIdResolver,
fn_decl: &'a mut ast::MethodDeclaration<'b, SpirvWord>,
) -> Result<Vec<TypedStatement>, TranslateError> {
let mut result = Vec::with_capacity(func.len());
for arg in fn_decl.input_arguments.iter_mut() {
insert_mem_ssa_argument(
id_def,
&mut result,
arg,
matches!(fn_decl.name, ast::MethodName::Kernel(_)),
);
}
for arg in fn_decl.return_arguments.iter() {
insert_mem_ssa_argument_reg_return(&mut result, arg);
}
for s in func {
match s {
Statement::Instruction(inst) => match inst {
ast::Instruction::Ret { data } => {
// TODO: handle multiple output args
match &fn_decl.return_arguments[..] {
[return_reg] => {
let new_id = id_def.register_intermediate(Some((
return_reg.v_type.clone(),
ast::StateSpace::Reg,
)));
result.push(Statement::LoadVar(LoadVarDetails {
arg: ast::LdArgs {
dst: new_id,
src: return_reg.name,
},
typ: return_reg.v_type.clone(),
member_index: None,
}));
result.push(Statement::RetValue(data, new_id));
}
[] => result.push(Statement::Instruction(ast::Instruction::Ret { data })),
_ => unimplemented!(),
}
}
inst => insert_mem_ssa_statement_default(
id_def,
&mut result,
Statement::Instruction(inst),
)?,
},
Statement::Conditional(bra) => {
insert_mem_ssa_statement_default(id_def, &mut result, Statement::Conditional(bra))?
}
Statement::Conversion(conv) => {
insert_mem_ssa_statement_default(id_def, &mut result, Statement::Conversion(conv))?
}
Statement::PtrAccess(ptr_access) => insert_mem_ssa_statement_default(
id_def,
&mut result,
Statement::PtrAccess(ptr_access),
)?,
Statement::RepackVector(repack) => insert_mem_ssa_statement_default(
id_def,
&mut result,
Statement::RepackVector(repack),
)?,
Statement::FunctionPointer(func_ptr) => insert_mem_ssa_statement_default(
id_def,
&mut result,
Statement::FunctionPointer(func_ptr),
)?,
s @ Statement::Variable(_) | s @ Statement::Label(_) | s @ Statement::Constant(..) => {
result.push(s)
}
_ => return Err(error_unreachable()),
}
}
Ok(result)
}

fn insert_mem_ssa_argument(
id_def: &mut NumericIdResolver,
func: &mut Vec<TypedStatement>,
arg: &mut ast::Variable<SpirvWord>,
is_kernel: bool,
) {
if !is_kernel && arg.state_space == ast::StateSpace::Param {
return;
}
let new_id = id_def.register_intermediate(Some((arg.v_type.clone(), arg.state_space)));
func.push(Statement::Variable(ast::Variable {
align: arg.align,
v_type: arg.v_type.clone(),
state_space: ast::StateSpace::Reg,
name: arg.name,
array_init: Vec::new(),
}));
func.push(Statement::StoreVar(StoreVarDetails {
arg: ast::StArgs {
src1: arg.name,
src2: new_id,
},
typ: arg.v_type.clone(),
member_index: None,
}));
arg.name = new_id;
}

fn insert_mem_ssa_argument_reg_return(
func: &mut Vec<TypedStatement>,
arg: &ast::Variable<SpirvWord>,
) {
func.push(Statement::Variable(ast::Variable {
align: arg.align,
v_type: arg.v_type.clone(),
state_space: arg.state_space,
name: arg.name,
array_init: arg.array_init.clone(),
}));
}

fn insert_mem_ssa_statement_default<'a, 'input>(
id_def: &'a mut NumericIdResolver<'input>,
func: &'a mut Vec<TypedStatement>,
stmt: TypedStatement,
) -> Result<(), TranslateError> {
let mut visitor = InsertMemSSAVisitor {
id_def,
func,
post_statements: Vec::new(),
};
let new_stmt = stmt.visit_map(&mut visitor)?;
visitor.func.push(new_stmt);
visitor.func.extend(visitor.post_statements);
Ok(())
}

struct InsertMemSSAVisitor<'a, 'input> {
id_def: &'a mut NumericIdResolver<'input>,
func: &'a mut Vec<TypedStatement>,
post_statements: Vec<TypedStatement>,
}

impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
fn symbol(
&mut self,
symbol: SpirvWord,
member_index: Option<u8>,
expected: Option<(&ast::Type, ast::StateSpace)>,
is_dst: bool,
) -> Result<SpirvWord, TranslateError> {
if expected.is_none() {
return Ok(symbol);
};
let (mut var_type, var_space, is_variable) = self.id_def.get_typed(symbol)?;
if !state_is_compatible(var_space, ast::StateSpace::Reg) || !is_variable {
return Ok(symbol);
};
let member_index = match member_index {
Some(idx) => {
let vector_width = match var_type {
ast::Type::Vector(scalar_t, width) => {
var_type = ast::Type::Scalar(scalar_t);
width
}
_ => return Err(TranslateError::MismatchedType),
};
Some((
idx,
if self.id_def.special_registers.get(symbol).is_some() {
Some(vector_width)
} else {
None
},
))
}
None => None,
};
let generated_id = self
.id_def
.register_intermediate(Some((var_type.clone(), ast::StateSpace::Reg)));
if !is_dst {
self.func.push(Statement::LoadVar(LoadVarDetails {
arg: ast::LdArgs {
dst: generated_id,
src: symbol,
},
typ: var_type,
member_index,
}));
} else {
self.post_statements
.push(Statement::StoreVar(StoreVarDetails {
arg: ast::StArgs {
src1: symbol,
src2: generated_id,
},
typ: var_type,
member_index: member_index.map(|(idx, _)| idx),
}));
}
Ok(generated_id)
}
}

impl<'a, 'input> ast::VisitorMap<TypedOperand, TypedOperand, TranslateError>
for InsertMemSSAVisitor<'a, 'input>
{
fn visit(
&mut self,
operand: TypedOperand,
type_space: Option<(&ast::Type, ast::StateSpace)>,
is_dst: bool,
_relaxed_type_check: bool,
) -> Result<TypedOperand, TranslateError> {
Ok(match operand {
TypedOperand::Reg(reg) => {
TypedOperand::Reg(self.symbol(reg, None, type_space, is_dst)?)
}
TypedOperand::RegOffset(reg, offset) => {
TypedOperand::RegOffset(self.symbol(reg, None, type_space, is_dst)?, offset)
}
op @ TypedOperand::Imm(..) => op,
TypedOperand::VecMember(symbol, index) => TypedOperand::VecMember(
self.symbol(symbol, Some(index), type_space, is_dst)?,
index,
),
})
}

fn visit_ident(
&mut self,
args: SpirvWord,
type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
is_dst: bool,
relaxed_type_check: bool,
) -> Result<SpirvWord, TranslateError> {
self.symbol(args, None, type_space, is_dst)
}
}
13 changes: 10 additions & 3 deletions ptx/src/pass/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use std::{
mod convert_to_stateful_memory_access;
mod convert_to_typed;
mod fix_special_registers;
mod insert_mem_ssa_statements;
mod normalize_identifiers;
mod normalize_predicates;

Expand Down Expand Up @@ -175,13 +176,13 @@ fn to_ssa<'input, 'b>(
fix_special_registers::run(ptx_impl_imports, typed_statements, &mut numeric_id_defs)?;
let (func_decl, typed_statements) =
convert_to_stateful_memory_access::run(func_decl, typed_statements, &mut numeric_id_defs)?;
todo!()
/*
let ssa_statements = insert_mem_ssa_statements(
let ssa_statements = insert_mem_ssa_statements::run(
typed_statements,
&mut numeric_id_defs,
&mut (*func_decl).borrow_mut(),
)?;
todo!()
/*
let mut numeric_id_defs = numeric_id_defs.finish();
let expanded_statements = expand_arguments(ssa_statements, &mut numeric_id_defs)?;
let expanded_statements =
Expand Down Expand Up @@ -1206,3 +1207,9 @@ impl<
}
}
}

fn state_is_compatible(this: ast::StateSpace, other: ast::StateSpace) -> bool {
this == other
|| this == ast::StateSpace::Reg && other == ast::StateSpace::Sreg
|| this == ast::StateSpace::Sreg && other == ast::StateSpace::Reg
}

0 comments on commit cccd37f

Please sign in to comment.