Skip to content

Commit

Permalink
Work on more passes
Browse files Browse the repository at this point in the history
  • Loading branch information
vosen committed Aug 23, 2024
1 parent 12ef8db commit 7ea990e
Show file tree
Hide file tree
Showing 4 changed files with 222 additions and 15 deletions.
140 changes: 140 additions & 0 deletions ptx/src/pass/convert_to_typed.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
use super::*;
use ptx_parser as ast;

pub(crate) fn run(
func: Vec<UnconditionalStatement>,
fn_defs: &GlobalFnDeclResolver,
id_defs: &mut NumericIdResolver,
) -> Result<Vec<TypedStatement>, TranslateError> {
let mut result = Vec::<TypedStatement>::with_capacity(func.len());
for s in func {
match s {
Statement::Instruction(inst) => match inst {
ast::Instruction::Mov {
data,
arguments:
ast::MovArgs {
dst: ast::ParsedOperand::Reg(dst_reg),
src: ast::ParsedOperand::Reg(src_reg),
},
} if fn_defs.fns.contains_key(&src_reg) => {
if data.typ != ast::Type::Scalar(ast::ScalarType::U64) {
return Err(TranslateError::MismatchedType);
}
result.push(TypedStatement::FunctionPointer(FunctionPointerDetails {
dst: dst_reg,
src: src_reg,
}));
}
ast::Instruction::Call(call) => {
let resolver = fn_defs.get_fn_sig_resolver(call.func)?;
let resolved_call = resolver.resolve_in_spirv_repr(call)?;
let mut visitor = VectorRepackVisitor::new(&mut result, id_defs);
let reresolved_call = resolved_call.visit(&mut visitor)?;
visitor.func.push(reresolved_call);
visitor.func.extend(visitor.post_stmts);
}
inst => {
let mut visitor = VectorRepackVisitor::new(&mut result, id_defs);
let instruction = Statement::Instruction(inst.map(&mut visitor)?);
visitor.func.push(instruction);
visitor.func.extend(visitor.post_stmts);
}
},
Statement::Label(i) => result.push(Statement::Label(i)),
Statement::Variable(v) => result.push(Statement::Variable(v)),
Statement::Conditional(c) => result.push(Statement::Conditional(c)),
_ => return Err(error_unreachable()),
}
}
Ok(result)
}

struct VectorRepackVisitor<'a, 'b> {
func: &'b mut Vec<TypedStatement>,
id_def: &'b mut NumericIdResolver<'a>,
post_stmts: Option<TypedStatement>,
}

impl<'a, 'b> VectorRepackVisitor<'a, 'b> {
fn new(func: &'b mut Vec<TypedStatement>, id_def: &'b mut NumericIdResolver<'a>) -> Self {
VectorRepackVisitor {
func,
id_def,
post_stmts: None,
}
}

fn convert_vector(
&mut self,
is_dst: bool,
non_default_implicit_conversion: Option<
fn(
(ast::StateSpace, &ast::Type),
(ast::StateSpace, &ast::Type),
) -> Result<Option<ConversionKind>, TranslateError>,
>,
typ: &ast::Type,
state_space: ast::StateSpace,
idx: Vec<SpirvWord>,
) -> Result<SpirvWord, TranslateError> {
// mov.u32 foobar, {a,b};
let scalar_t = match typ {
ast::Type::Vector(scalar_t, _) => *scalar_t,
_ => return Err(TranslateError::MismatchedType),
};
let temp_vec = self
.id_def
.register_intermediate(Some((typ.clone(), state_space)));
let statement = Statement::RepackVector(RepackVectorDetails {
is_extract: is_dst,
typ: scalar_t,
packed: temp_vec,
unpacked: idx,
non_default_implicit_conversion,
});
if is_dst {
self.post_stmts = Some(statement);
} else {
self.func.push(statement);
}
Ok(temp_vec)
}
}

impl<'a, 'b> ast::VisitorMap<ast::ParsedOperand<SpirvWord>, TypedOperand, TranslateError>
for VectorRepackVisitor<'a, 'b>
{
fn visit_ident(
&mut self,
ident: SpirvWord,
_: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
_: bool,
) -> Result<SpirvWord, TranslateError> {
Ok(ident)
}

fn visit(
&mut self,
op: ast::ParsedOperand<SpirvWord>,
type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
is_dst: bool,
) -> Result<TypedOperand, TranslateError> {
Ok(match op {
ast::ParsedOperand::Reg(reg) => TypedOperand::Reg(reg),
ast::ParsedOperand::RegOffset(reg, offset) => TypedOperand::RegOffset(reg, offset),
ast::ParsedOperand::Imm(x) => TypedOperand::Imm(x),
ast::ParsedOperand::VecMember(vec, idx) => TypedOperand::VecMember(vec, idx),
ast::ParsedOperand::VecPack(vec) => {
let (type_, space) = type_space.ok_or(TranslateError::MismatchedType)?;
TypedOperand::Reg(self.convert_vector(
is_dst,
desc.non_default_implicit_conversion,
type_,
space,
vec,
)?)
}
})
}
}
45 changes: 38 additions & 7 deletions ptx/src/pass/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ use std::{
rc::Rc,
};

pub(crate) mod normalize;
mod convert_to_typed;
mod normalize_identifiers;
mod normalize_predicates;

static ZLUDA_PTX_IMPL_INTEL: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.spv");
static ZLUDA_PTX_IMPL_AMD: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.bc");
Expand Down Expand Up @@ -161,13 +163,13 @@ fn to_ssa<'input, 'b>(
})
}
};
let normalized_ids = normalize::run(&mut id_defs, &fn_defs, f_body)?;
todo!()
/*
let normalized_ids = normalize_identifiers::run(&mut id_defs, &fn_defs, f_body)?;
let mut numeric_id_defs = id_defs.finish();
let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs)?;
let unadorned_statements = normalize_predicates::run(normalized_ids, &mut numeric_id_defs)?;
let typed_statements =
convert_to_typed_statements(unadorned_statements, &fn_defs, &mut numeric_id_defs)?;
convert_to_typed::run(unadorned_statements, &fn_defs, &mut numeric_id_defs)?;
todo!()
/*
let typed_statements =
fix_special_registers2(ptx_impl_imports, typed_statements, &mut numeric_id_defs)?;
let (func_decl, typed_statements) =
Expand Down Expand Up @@ -856,4 +858,33 @@ pub(crate) struct Function<'input> {
linkage: ast::LinkingDirective,
}

type ExpandedStatement = Statement<ast::Instruction<SpirvWord>, SpirvWord>;
type ExpandedStatement = Statement<ast::Instruction<SpirvWord>, SpirvWord>;

type NormalizedStatement = Statement<
(
Option<ast::PredAt<SpirvWord>>,
ast::Instruction<ast::ParsedOperand<SpirvWord>>,
),
ast::ParsedOperand<SpirvWord>,
>;

type UnconditionalStatement =
Statement<ast::Instruction<ast::ParsedOperand<SpirvWord>>, ast::ParsedOperand<SpirvWord>>;

type TypedStatement = Statement<ast::Instruction<TypedOperand>, TypedOperand>;

#[derive(Copy, Clone)]
enum TypedOperand {
Reg(SpirvWord),
RegOffset(SpirvWord, i32),
Imm(ast::ImmediateValue),
VecMember(SpirvWord, u8),
}

impl ast::Operand for TypedOperand {
type Ident = SpirvWord;

fn from_ident(ident: Self::Ident) -> Self {
TypedOperand::Reg(ident)
}
}
Original file line number Diff line number Diff line change
@@ -1,14 +1,6 @@
use super::*;
use ptx_parser as ast;

type NormalizedStatement = Statement<
(
Option<ast::PredAt<SpirvWord>>,
ast::Instruction<ast::ParsedOperand<SpirvWord>>,
),
ast::ParsedOperand<SpirvWord>,
>;

pub(crate) fn run<'input, 'b>(
id_defs: &mut FnStringIdResolver<'input, 'b>,
fn_defs: &GlobalFnDeclResolver<'input, 'b>,
Expand Down
44 changes: 44 additions & 0 deletions ptx/src/pass/normalize_predicates.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
use super::*;
use ptx_parser as ast;

pub(crate) fn run(
func: Vec<NormalizedStatement>,
id_def: &mut NumericIdResolver,
) -> Result<Vec<UnconditionalStatement>, TranslateError> {
let mut result = Vec::with_capacity(func.len());
for s in func {
match s {
Statement::Label(id) => result.push(Statement::Label(id)),
Statement::Instruction((pred, inst)) => {
if let Some(pred) = pred {
let if_true = id_def.register_intermediate(None);
let if_false = id_def.register_intermediate(None);
let folded_bra = match &inst {
ast::Instruction::Bra { arguments, .. } => Some(arguments.src),
_ => None,
};
let mut branch = BrachCondition {
predicate: pred.label,
if_true: folded_bra.unwrap_or(if_true),
if_false,
};
if pred.not {
std::mem::swap(&mut branch.if_true, &mut branch.if_false);
}
result.push(Statement::Conditional(branch));
if folded_bra.is_none() {
result.push(Statement::Label(if_true));
result.push(Statement::Instruction(inst));
}
result.push(Statement::Label(if_false));
} else {
result.push(Statement::Instruction(inst));
}
}
Statement::Variable(var) => result.push(Statement::Variable(var)),
// Blocks are flattened when resolving ids
_ => return Err(error_unreachable()),
}
}
Ok(result)
}

0 comments on commit 7ea990e

Please sign in to comment.