-
Notifications
You must be signed in to change notification settings - Fork 680
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
222 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
use super::*; | ||
use ptx_parser as ast; | ||
|
||
pub(crate) fn run( | ||
func: Vec<UnconditionalStatement>, | ||
fn_defs: &GlobalFnDeclResolver, | ||
id_defs: &mut NumericIdResolver, | ||
) -> Result<Vec<TypedStatement>, TranslateError> { | ||
let mut result = Vec::<TypedStatement>::with_capacity(func.len()); | ||
for s in func { | ||
match s { | ||
Statement::Instruction(inst) => match inst { | ||
ast::Instruction::Mov { | ||
data, | ||
arguments: | ||
ast::MovArgs { | ||
dst: ast::ParsedOperand::Reg(dst_reg), | ||
src: ast::ParsedOperand::Reg(src_reg), | ||
}, | ||
} if fn_defs.fns.contains_key(&src_reg) => { | ||
if data.typ != ast::Type::Scalar(ast::ScalarType::U64) { | ||
return Err(TranslateError::MismatchedType); | ||
} | ||
result.push(TypedStatement::FunctionPointer(FunctionPointerDetails { | ||
dst: dst_reg, | ||
src: src_reg, | ||
})); | ||
} | ||
ast::Instruction::Call(call) => { | ||
let resolver = fn_defs.get_fn_sig_resolver(call.func)?; | ||
let resolved_call = resolver.resolve_in_spirv_repr(call)?; | ||
let mut visitor = VectorRepackVisitor::new(&mut result, id_defs); | ||
let reresolved_call = resolved_call.visit(&mut visitor)?; | ||
visitor.func.push(reresolved_call); | ||
visitor.func.extend(visitor.post_stmts); | ||
} | ||
inst => { | ||
let mut visitor = VectorRepackVisitor::new(&mut result, id_defs); | ||
let instruction = Statement::Instruction(inst.map(&mut visitor)?); | ||
visitor.func.push(instruction); | ||
visitor.func.extend(visitor.post_stmts); | ||
} | ||
}, | ||
Statement::Label(i) => result.push(Statement::Label(i)), | ||
Statement::Variable(v) => result.push(Statement::Variable(v)), | ||
Statement::Conditional(c) => result.push(Statement::Conditional(c)), | ||
_ => return Err(error_unreachable()), | ||
} | ||
} | ||
Ok(result) | ||
} | ||
|
||
struct VectorRepackVisitor<'a, 'b> { | ||
func: &'b mut Vec<TypedStatement>, | ||
id_def: &'b mut NumericIdResolver<'a>, | ||
post_stmts: Option<TypedStatement>, | ||
} | ||
|
||
impl<'a, 'b> VectorRepackVisitor<'a, 'b> { | ||
fn new(func: &'b mut Vec<TypedStatement>, id_def: &'b mut NumericIdResolver<'a>) -> Self { | ||
VectorRepackVisitor { | ||
func, | ||
id_def, | ||
post_stmts: None, | ||
} | ||
} | ||
|
||
fn convert_vector( | ||
&mut self, | ||
is_dst: bool, | ||
non_default_implicit_conversion: Option< | ||
fn( | ||
(ast::StateSpace, &ast::Type), | ||
(ast::StateSpace, &ast::Type), | ||
) -> Result<Option<ConversionKind>, TranslateError>, | ||
>, | ||
typ: &ast::Type, | ||
state_space: ast::StateSpace, | ||
idx: Vec<SpirvWord>, | ||
) -> Result<SpirvWord, TranslateError> { | ||
// mov.u32 foobar, {a,b}; | ||
let scalar_t = match typ { | ||
ast::Type::Vector(scalar_t, _) => *scalar_t, | ||
_ => return Err(TranslateError::MismatchedType), | ||
}; | ||
let temp_vec = self | ||
.id_def | ||
.register_intermediate(Some((typ.clone(), state_space))); | ||
let statement = Statement::RepackVector(RepackVectorDetails { | ||
is_extract: is_dst, | ||
typ: scalar_t, | ||
packed: temp_vec, | ||
unpacked: idx, | ||
non_default_implicit_conversion, | ||
}); | ||
if is_dst { | ||
self.post_stmts = Some(statement); | ||
} else { | ||
self.func.push(statement); | ||
} | ||
Ok(temp_vec) | ||
} | ||
} | ||
|
||
impl<'a, 'b> ast::VisitorMap<ast::ParsedOperand<SpirvWord>, TypedOperand, TranslateError> | ||
for VectorRepackVisitor<'a, 'b> | ||
{ | ||
fn visit_ident( | ||
&mut self, | ||
ident: SpirvWord, | ||
_: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>, | ||
_: bool, | ||
) -> Result<SpirvWord, TranslateError> { | ||
Ok(ident) | ||
} | ||
|
||
fn visit( | ||
&mut self, | ||
op: ast::ParsedOperand<SpirvWord>, | ||
type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>, | ||
is_dst: bool, | ||
) -> Result<TypedOperand, TranslateError> { | ||
Ok(match op { | ||
ast::ParsedOperand::Reg(reg) => TypedOperand::Reg(reg), | ||
ast::ParsedOperand::RegOffset(reg, offset) => TypedOperand::RegOffset(reg, offset), | ||
ast::ParsedOperand::Imm(x) => TypedOperand::Imm(x), | ||
ast::ParsedOperand::VecMember(vec, idx) => TypedOperand::VecMember(vec, idx), | ||
ast::ParsedOperand::VecPack(vec) => { | ||
let (type_, space) = type_space.ok_or(TranslateError::MismatchedType)?; | ||
TypedOperand::Reg(self.convert_vector( | ||
is_dst, | ||
desc.non_default_implicit_conversion, | ||
type_, | ||
space, | ||
vec, | ||
)?) | ||
} | ||
}) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
8 changes: 0 additions & 8 deletions
8
ptx/src/pass/normalize.rs → ptx/src/pass/normalize_identifiers.rs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
use super::*; | ||
use ptx_parser as ast; | ||
|
||
pub(crate) fn run( | ||
func: Vec<NormalizedStatement>, | ||
id_def: &mut NumericIdResolver, | ||
) -> Result<Vec<UnconditionalStatement>, TranslateError> { | ||
let mut result = Vec::with_capacity(func.len()); | ||
for s in func { | ||
match s { | ||
Statement::Label(id) => result.push(Statement::Label(id)), | ||
Statement::Instruction((pred, inst)) => { | ||
if let Some(pred) = pred { | ||
let if_true = id_def.register_intermediate(None); | ||
let if_false = id_def.register_intermediate(None); | ||
let folded_bra = match &inst { | ||
ast::Instruction::Bra { arguments, .. } => Some(arguments.src), | ||
_ => None, | ||
}; | ||
let mut branch = BrachCondition { | ||
predicate: pred.label, | ||
if_true: folded_bra.unwrap_or(if_true), | ||
if_false, | ||
}; | ||
if pred.not { | ||
std::mem::swap(&mut branch.if_true, &mut branch.if_false); | ||
} | ||
result.push(Statement::Conditional(branch)); | ||
if folded_bra.is_none() { | ||
result.push(Statement::Label(if_true)); | ||
result.push(Statement::Instruction(inst)); | ||
} | ||
result.push(Statement::Label(if_false)); | ||
} else { | ||
result.push(Statement::Instruction(inst)); | ||
} | ||
} | ||
Statement::Variable(var) => result.push(Statement::Variable(var)), | ||
// Blocks are flattened when resolving ids | ||
_ => return Err(error_unreachable()), | ||
} | ||
} | ||
Ok(result) | ||
} |