summaryrefslogtreecommitdiffhomepage
path: root/ptx/src/translate.rs
diff options
context:
space:
mode:
Diffstat (limited to 'ptx/src/translate.rs')
-rw-r--r--ptx/src/translate.rs2070
1 files changed, 515 insertions, 1555 deletions
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index 0d86066..7cce63c 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -151,16 +151,16 @@ fn emit_function<'a>(
if f.kernel {
builder.entry_point(spirv::ExecutionModel::Kernel, func_id, f.name, &[]);
}
- let (mut func_body, bbs, _, unique_ids) = to_ssa(&f.args, f.body);
+ let (mut func_body, unique_ids) = to_ssa(&f.args, f.body);
let id_offset = builder.reserve_ids(unique_ids);
emit_function_args(builder, id_offset, map, &f.args);
apply_id_offset(&mut func_body, id_offset);
- emit_function_body_ops(builder, map, &func_body, &bbs)?;
+ emit_function_body_ops(builder, map, &func_body)?;
builder.end_function()?;
Ok(func_id)
}
-fn apply_id_offset(func_body: &mut Vec<Statement>, id_offset: u32) {
+fn apply_id_offset(func_body: &mut Vec<ExpandedStatement>, id_offset: u32) {
for s in func_body {
s.visit_id_mut(&mut |_, id| *id += id_offset);
}
@@ -169,61 +169,27 @@ fn apply_id_offset(func_body: &mut Vec<Statement>, id_offset: u32) {
fn to_ssa<'a>(
f_args: &[ast::Argument],
f_body: Vec<ast::Statement<&'a str>>,
-) -> (
- Vec<Statement>,
- Vec<BasicBlock>,
- Vec<Vec<PhiDef>>,
- spirv::Word,
-) {
- let mut contant_ids = HashMap::new();
- let mut type_check = HashMap::new();
- collect_arg_ids(&mut contant_ids, &mut type_check, &f_args);
- collect_label_ids(&mut contant_ids, &f_body);
- let registers = collect_var_definitions(&f_args, &f_body);
- let (normalized_ids, mut unique_ids) =
- normalize_identifiers(f_body, &contant_ids, &mut type_check, registers);
- let type_check = RefCell::new(type_check);
- let new_id = &mut |typ: Option<ast::Type>| {
- let to_insert = unique_ids;
- {
- let mut type_check = type_check.borrow_mut();
- typ.map(|t| (*type_check).insert(to_insert, t));
- }
- unique_ids += 1;
- to_insert
- };
- let normalized_stmts = normalize_statements(normalized_ids, new_id);
- let mut func_body = insert_implicit_conversions(normalized_stmts, new_id, &|x| {
- let type_check = type_check.borrow();
- type_check[&x]
- });
- let bbs = get_basic_blocks(&func_body);
- let rpostorder = to_reverse_postorder(&bbs);
- let doms = immediate_dominators(&bbs, &rpostorder);
- let dom_fronts = dominance_frontiers(&bbs, &doms);
- let (phis, unique_ids) = ssa_legalize(
- &mut func_body,
- contant_ids.len() as u32,
- unique_ids,
- &bbs,
- &doms,
- &dom_fronts,
- );
- (func_body, bbs, phis, unique_ids)
+) -> (Vec<ExpandedStatement>, spirv::Word) {
+ let (normalized_ids, mut id_def) = normalize_identifiers(&f_args, f_body);
+ let normalized_statements = normalize_predicates(normalized_ids, &mut id_def);
+ let ssa_statements = insert_mem_ssa_statements(normalized_statements, &mut id_def);
+ let expanded_statements = expand_arguments(ssa_statements, &mut id_def);
+ let expanded_statements = insert_implicit_conversions(expanded_statements, &mut id_def);
+ (expanded_statements, id_def.ids_count())
}
-fn normalize_statements(
+fn normalize_predicates(
func: Vec<ast::Statement<spirv::Word>>,
- new_id: &mut impl FnMut(Option<ast::Type>) -> spirv::Word,
-) -> Vec<Statement> {
+ id_def: &mut NumericIdResolver,
+) -> Vec<Statement<NormalizedArgs>> {
let mut result = Vec::with_capacity(func.len());
for s in func {
match s {
ast::Statement::Label(id) => result.push(Statement::Label(id)),
ast::Statement::Instruction(pred, inst) => {
if let Some(pred) = pred {
- let mut if_true = new_id(None);
- let mut if_false = new_id(None);
+ let mut if_true = id_def.new_id(None);
+ let mut if_false = id_def.new_id(None);
if pred.not {
std::mem::swap(&mut if_true, &mut if_false);
}
@@ -239,16 +205,82 @@ fn normalize_statements(
result.push(Statement::Conditional(branch));
if folded_bra.is_none() {
result.push(Statement::Label(if_true));
- let instr = normalize_insert_instruction(&mut result, new_id, inst);
- result.push(Statement::Instruction(instr));
+ result.push(Statement::Instruction(Instruction::from_ast(inst)));
}
result.push(Statement::Label(if_false));
} else {
- let instr = normalize_insert_instruction(&mut result, new_id, inst);
- result.push(Statement::Instruction(instr));
+ result.push(Statement::Instruction(Instruction::from_ast(inst)));
}
}
- ast::Statement::Variable(_) => unreachable!(),
+ ast::Statement::Variable(var) => result.push(Statement::Variable(var.name, var.v_type)),
+ }
+ }
+ result
+}
+
+fn insert_mem_ssa_statements(
+ func: Vec<Statement<NormalizedArgs>>,
+ id_def: &mut NumericIdResolver,
+) -> Vec<Statement<NormalizedArgs>> {
+ let mut result = Vec::with_capacity(func.len());
+ for s in func {
+ match s {
+ Statement::Instruction(mut inst) => {
+ let inst_type = inst.get_type();
+ let mut post_statements = Vec::new();
+ inst.visit_id_mut(&mut |is_dst, id| {
+ let inst_type = inst_type.unwrap();
+ let generated_id = id_def.new_id(Some(inst_type));
+ if !is_dst {
+ result.push(Statement::LoadVar(
+ Arg2 {
+ dst: generated_id,
+ src: *id,
+ },
+ inst_type,
+ ));
+ } else {
+ post_statements.push(Statement::StoreVar(
+ Arg2St {
+ src1: *id,
+ src2: generated_id,
+ },
+ inst_type,
+ ));
+ }
+ *id = generated_id;
+ });
+ result.push(Statement::Instruction(inst));
+ result.append(&mut post_statements);
+ }
+ s @ Statement::Variable(_, _)
+ | s @ Statement::Label(_)
+ | s @ Statement::Conditional(_) => result.push(s),
+ Statement::LoadVar(_, _)
+ | Statement::StoreVar(_, _)
+ | Statement::Converison(_)
+ | Statement::Constant(_) => unreachable!(),
+ }
+ }
+ result
+}
+
+fn expand_arguments(
+ func: Vec<Statement<NormalizedArgs>>,
+ id_def: &mut NumericIdResolver,
+) -> Vec<ExpandedStatement> {
+ let mut result = Vec::with_capacity(func.len());
+ for s in func {
+ match s {
+ Statement::Instruction(inst) => {
+ normalize_insert_instruction(&mut result, id_def, inst);
+ }
+ Statement::Variable(id, typ) => result.push(Statement::Variable(id, typ)),
+ Statement::Label(id) => result.push(Statement::Label(id)),
+ Statement::Conditional(bra) => result.push(Statement::Conditional(bra)),
+ Statement::LoadVar(arg, typ) => result.push(Statement::LoadVar(arg, typ)),
+ Statement::StoreVar(arg, typ) => result.push(Statement::StoreVar(arg, typ)),
+ Statement::Converison(_) | Statement::Constant(_) => unreachable!(),
}
}
result
@@ -256,137 +288,137 @@ fn normalize_statements(
#[must_use]
fn normalize_insert_instruction(
- func: &mut Vec<Statement>,
- new_id: &mut impl FnMut(Option<ast::Type>) -> spirv::Word,
- instr: ast::Instruction<spirv::Word>,
-) -> Instruction {
+ func: &mut Vec<ExpandedStatement>,
+ id_def: &mut NumericIdResolver,
+ instr: Instruction<NormalizedArgs>,
+) -> Instruction<ExpandedArgs> {
match instr {
- ast::Instruction::Ld(d, a) => {
- let arg = normalize_expand_arg2(func, new_id, &|| Some(d.typ), a);
+ Instruction::Ld(d, a) => {
+ let arg = normalize_expand_arg2(func, id_def, &|| Some(d.typ), a);
Instruction::Ld(d, arg)
}
- ast::Instruction::Mov(d, a) => {
- let arg = normalize_expand_arg2mov(func, new_id, &|| d.typ.try_as_scalar(), a);
+ Instruction::Mov(d, a) => {
+ let arg = normalize_expand_arg2mov(func, id_def, &|| d.typ.try_as_scalar(), a);
Instruction::Mov(d, arg)
}
- ast::Instruction::Mul(d, a) => {
- let arg = normalize_expand_arg3(func, new_id, &|| d.typ.try_as_scalar(), a);
+ Instruction::Mul(d, a) => {
+ let arg = normalize_expand_arg3(func, id_def, &|| d.typ.try_as_scalar(), a);
Instruction::Mul(d, arg)
}
- ast::Instruction::Add(d, a) => {
- let arg = normalize_expand_arg3(func, new_id, &|| Some(d.typ), a);
+ Instruction::Add(d, a) => {
+ let arg = normalize_expand_arg3(func, id_def, &|| Some(d.typ), a);
Instruction::Add(d, arg)
}
- ast::Instruction::Setp(d, a) => {
- let arg = normalize_expand_arg4(func, new_id, &|| Some(d.typ), a);
+ Instruction::Setp(d, a) => {
+ let arg = normalize_expand_arg4(func, id_def, &|| Some(d.typ), a);
Instruction::Setp(d, arg)
}
- ast::Instruction::SetpBool(d, a) => {
- let arg = normalize_expand_arg5(func, new_id, &|| Some(d.typ), a);
+ Instruction::SetpBool(d, a) => {
+ let arg = normalize_expand_arg5(func, id_def, &|| Some(d.typ), a);
Instruction::SetpBool(d, arg)
}
- ast::Instruction::Not(d, a) => {
- let arg = normalize_expand_arg2(func, new_id, &|| todo!(), a);
+ Instruction::Not(d, a) => {
+ let arg = normalize_expand_arg2(func, id_def, &|| todo!(), a);
Instruction::Not(d, arg)
}
- ast::Instruction::Bra(d, a) => Instruction::Bra(d, Arg1 { src: a.src }),
- ast::Instruction::Cvt(d, a) => {
- let arg = normalize_expand_arg2(func, new_id, &|| todo!(), a);
+ Instruction::Bra(d, a) => Instruction::Bra(d, Arg1 { src: a.src }),
+ Instruction::Cvt(d, a) => {
+ let arg = normalize_expand_arg2(func, id_def, &|| todo!(), a);
Instruction::Cvt(d, arg)
}
- ast::Instruction::Shl(d, a) => {
- let arg = normalize_expand_arg3(func, new_id, &|| todo!(), a);
+ Instruction::Shl(d, a) => {
+ let arg = normalize_expand_arg3(func, id_def, &|| todo!(), a);
Instruction::Shl(d, arg)
}
- ast::Instruction::St(d, a) => {
- let arg = normalize_expand_arg2st(func, new_id, &|| todo!(), a);
+ Instruction::St(d, a) => {
+ let arg = normalize_expand_arg2st(func, id_def, &|| todo!(), a);
Instruction::St(d, arg)
}
- ast::Instruction::Ret(d) => Instruction::Ret(d),
+ Instruction::Ret(d) => Instruction::Ret(d),
}
}
fn normalize_expand_arg2(
- func: &mut Vec<Statement>,
- new_id: &mut impl FnMut(Option<ast::Type>) -> spirv::Word,
+ func: &mut Vec<ExpandedStatement>,
+ id_def: &mut NumericIdResolver,
inst_type: &impl Fn() -> Option<ast::ScalarType>,
a: ast::Arg2<spirv::Word>,
) -> Arg2 {
Arg2 {
dst: a.dst,
- src: normalize_expand_operand(func, new_id, inst_type, a.src),
+ src: normalize_expand_operand(func, id_def, inst_type, a.src),
}
}
fn normalize_expand_arg2mov(
- func: &mut Vec<Statement>,
- new_id: &mut impl FnMut(Option<ast::Type>) -> spirv::Word,
+ func: &mut Vec<ExpandedStatement>,
+ id_def: &mut NumericIdResolver,
inst_type: &impl Fn() -> Option<ast::ScalarType>,
a: ast::Arg2Mov<spirv::Word>,
) -> Arg2 {
Arg2 {
dst: a.dst,
- src: normalize_expand_mov_operand(func, new_id, inst_type, a.src),
+ src: normalize_expand_mov_operand(func, id_def, inst_type, a.src),
}
}
fn normalize_expand_arg2st(
- func: &mut Vec<Statement>,
- new_id: &mut impl FnMut(Option<ast::Type>) -> spirv::Word,
+ func: &mut Vec<ExpandedStatement>,
+ id_def: &mut NumericIdResolver,
inst_type: &impl Fn() -> Option<ast::ScalarType>,
a: ast::Arg2St<spirv::Word>,
) -> Arg2St {
Arg2St {
- src1: normalize_expand_operand(func, new_id, inst_type, a.src1),
- src2: normalize_expand_operand(func, new_id, inst_type, a.src2),
+ src1: normalize_expand_operand(func, id_def, inst_type, a.src1),
+ src2: normalize_expand_operand(func, id_def, inst_type, a.src2),
}
}
fn normalize_expand_arg3(
- func: &mut Vec<Statement>,
- new_id: &mut impl FnMut(Option<ast::Type>) -> spirv::Word,
+ func: &mut Vec<ExpandedStatement>,
+ id_def: &mut NumericIdResolver,
inst_type: &impl Fn() -> Option<ast::ScalarType>,
a: ast::Arg3<spirv::Word>,
) -> Arg3 {
Arg3 {
dst: a.dst,
- src1: normalize_expand_operand(func, new_id, inst_type, a.src1),
- src2: normalize_expand_operand(func, new_id, inst_type, a.src2),
+ src1: normalize_expand_operand(func, id_def, inst_type, a.src1),
+ src2: normalize_expand_operand(func, id_def, inst_type, a.src2),
}
}
fn normalize_expand_arg4(
- func: &mut Vec<Statement>,
- new_id: &mut impl FnMut(Option<ast::Type>) -> spirv::Word,
+ func: &mut Vec<ExpandedStatement>,
+ id_def: &mut NumericIdResolver,
inst_type: &impl Fn() -> Option<ast::ScalarType>,
a: ast::Arg4<spirv::Word>,
) -> Arg4 {
Arg4 {
dst1: a.dst1,
dst2: a.dst2,
- src1: normalize_expand_operand(func, new_id, inst_type, a.src1),
- src2: normalize_expand_operand(func, new_id, inst_type, a.src2),
+ src1: normalize_expand_operand(func, id_def, inst_type, a.src1),
+ src2: normalize_expand_operand(func, id_def, inst_type, a.src2),
}
}
fn normalize_expand_arg5(
- func: &mut Vec<Statement>,
- new_id: &mut impl FnMut(Option<ast::Type>) -> spirv::Word,
+ func: &mut Vec<ExpandedStatement>,
+ id_def: &mut NumericIdResolver,
inst_type: &impl Fn() -> Option<ast::ScalarType>,
a: ast::Arg5<spirv::Word>,
) -> Arg5 {
Arg5 {
dst1: a.dst1,
dst2: a.dst2,
- src1: normalize_expand_operand(func, new_id, inst_type, a.src1),
- src2: normalize_expand_operand(func, new_id, inst_type, a.src2),
- src3: normalize_expand_operand(func, new_id, inst_type, a.src3),
+ src1: normalize_expand_operand(func, id_def, inst_type, a.src1),
+ src2: normalize_expand_operand(func, id_def, inst_type, a.src2),
+ src3: normalize_expand_operand(func, id_def, inst_type, a.src3),
}
}
fn normalize_expand_operand(
- func: &mut Vec<Statement>,
- new_id: &mut impl FnMut(Option<ast::Type>) -> spirv::Word,
+ func: &mut Vec<ExpandedStatement>,
+ id_def: &mut NumericIdResolver,
inst_type: &impl Fn() -> Option<ast::ScalarType>,
opr: ast::Operand<spirv::Word>,
) -> spirv::Word {
@@ -394,7 +426,7 @@ fn normalize_expand_operand(
ast::Operand::Reg(r) => r,
ast::Operand::Imm(x) => {
if let Some(typ) = inst_type() {
- let id = new_id(Some(ast::Type::Scalar(typ)));
+ let id = id_def.new_id(Some(ast::Type::Scalar(typ)));
func.push(Statement::Constant(ConstantDefinition {
dst: id,
typ: typ,
@@ -410,43 +442,17 @@ fn normalize_expand_operand(
}
fn normalize_expand_mov_operand(
- func: &mut Vec<Statement>,
- new_id: &mut impl FnMut(Option<ast::Type>) -> spirv::Word,
+ func: &mut Vec<ExpandedStatement>,
+ id_def: &mut NumericIdResolver,
inst_type: &impl Fn() -> Option<ast::ScalarType>,
opr: ast::MovOperand<spirv::Word>,
) -> spirv::Word {
match opr {
- ast::MovOperand::Op(opr) => normalize_expand_operand(func, new_id, inst_type, opr),
+ ast::MovOperand::Op(opr) => normalize_expand_operand(func, id_def, inst_type, opr),
_ => todo!(),
}
}
-fn collect_var_definitions<'a>(
- args: &[ast::Argument<'a>],
- body: &[ast::Statement<&'a str>],
-) -> HashMap<Cow<'a, str>, ast::Type> {
- let mut result = HashMap::new();
- for param in args {
- result.insert(Cow::Borrowed(param.name), ast::Type::Scalar(param.a_type));
- }
- for s in body {
- match s {
- ast::Statement::Variable(var) => match var.count {
- Some(count) => {
- for i in 0..count {
- result.insert(Cow::Owned(format!("{}{}", var.name, i)), var.v_type);
- }
- }
- None => {
- result.insert(Cow::Borrowed(var.name), var.v_type);
- }
- },
- ast::Statement::Label(_) | ast::Statement::Instruction(_, _) => (),
- }
- }
- result
-}
-
/*
There are several kinds of implicit conversions in PTX:
* auto-bitcast: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#type-information-for-instructions-and-operands
@@ -460,11 +466,10 @@ fn collect_var_definitions<'a>(
- generic st: for instruction `st [x], y`, x must be of type b64/u64/s64,
which is bitcast to a pointer
*/
-fn insert_implicit_conversions<TypeCheck: Fn(spirv::Word) -> ast::Type>(
- normalized_ids: Vec<Statement>,
- new_id: &mut impl FnMut(Option<ast::Type>) -> spirv::Word,
- type_check: &TypeCheck,
-) -> Vec<Statement> {
+fn insert_implicit_conversions(
+ normalized_ids: Vec<ExpandedStatement>,
+ id_def: &mut NumericIdResolver,
+) -> Vec<ExpandedStatement> {
let mut result = Vec::with_capacity(normalized_ids.len());
for s in normalized_ids.into_iter() {
match s {
@@ -473,16 +478,14 @@ fn insert_implicit_conversions<TypeCheck: Fn(spirv::Word) -> ast::Type>(
arg.src = insert_implicit_conversions_ld_src(
&mut result,
ast::Type::Scalar(ld.typ),
- type_check,
- new_id,
+ id_def,
ld.state_space,
arg.src,
);
insert_with_implicit_conversion_dst(
&mut result,
ld.typ,
- type_check,
- new_id,
+ id_def,
should_convert_relaxed_dst,
arg,
|arg| &mut arg.dst,
@@ -490,11 +493,11 @@ fn insert_implicit_conversions<TypeCheck: Fn(spirv::Word) -> ast::Type>(
);
}
Instruction::St(st, mut arg) => {
- let arg_src2_type = type_check(arg.src2);
+ let arg_src2_type = id_def.get_type(arg.src2);
if let Some(conv) = should_convert_relaxed_src(arg_src2_type, st.typ) {
arg.src2 = insert_conversion_src(
&mut result,
- new_id,
+ id_def,
arg.src2,
arg_src2_type,
ast::Type::Scalar(st.typ),
@@ -504,17 +507,19 @@ fn insert_implicit_conversions<TypeCheck: Fn(spirv::Word) -> ast::Type>(
arg.src1 = insert_implicit_conversions_ld_src(
&mut result,
ast::Type::Scalar(st.typ),
- type_check,
- new_id,
+ id_def,
st.state_space.to_ld_ss(),
arg.src1,
);
result.push(Statement::Instruction(Instruction::St(st, arg)));
}
- inst @ _ => insert_implicit_bitcasts(&mut result, type_check, new_id, inst),
+ inst @ _ => insert_implicit_bitcasts(&mut result, id_def, inst),
},
s @ Statement::Conditional(_) | s @ Statement::Label(_) => result.push(s),
- Statement::Constant(_) => (),
+ Statement::Constant(_)
+ | Statement::Variable(_, _)
+ | Statement::LoadVar(_, _)
+ | Statement::StoreVar(_, _) => (),
Statement::Converison(_) => unreachable!(),
}
}
@@ -582,76 +587,62 @@ fn collect_label_ids<'a>(
fn emit_function_body_ops(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
- func: &[Statement],
- cfg: &[BasicBlock],
+ func: &[ExpandedStatement],
) -> Result<(), dr::Error> {
- // TODO: entry basic block can't be target of jumps,
- // we need to emit additional BB for this purpose
- for bb_idx in 0..cfg.len() {
- let body = get_bb_body(func, cfg, BBIndex(bb_idx));
- if body.len() == 0 {
- continue;
- }
- let header_id = if let Statement::Label(id) = body[0] {
- Some(id)
- } else {
- None
- };
- builder.begin_block(header_id)?;
- for s in body {
- match s {
- // If block starts with a label it has already been emitted,
- // all other labels in the block are unused
- Statement::Label(_) => (),
- Statement::Constant(_) => todo!(),
- Statement::Converison(cv) => emit_implicit_conversion(builder, map, cv)?,
- Statement::Conditional(bra) => {
- builder.branch_conditional(bra.predicate, bra.if_true, bra.if_false, [])?;
+ for s in func {
+ match s {
+ // If block starts with a label it has already been emitted,
+ // all other labels in the block are unused
+ Statement::Label(_) => (),
+ Statement::Constant(_) => todo!(),
+ Statement::Converison(cv) => emit_implicit_conversion(builder, map, cv)?,
+ Statement::Conditional(bra) => {
+ builder.branch_conditional(bra.predicate, bra.if_true, bra.if_false, [])?;
+ }
+ Statement::Instruction(inst) => match inst {
+ // SPIR-V does not support marking jumps as guaranteed-converged
+ Instruction::Bra(_, arg) => {
+ builder.branch(arg.src)?;
}
- Statement::Instruction(inst) => match inst {
- // SPIR-V does not support marking jumps as guaranteed-converged
- Instruction::Bra(_, arg) => {
- builder.branch(arg.src)?;
+ Instruction::Ld(data, arg) => {
+ if data.qualifier != ast::LdStQualifier::Weak || data.vector.is_some() {
+ todo!()
}
- Instruction::Ld(data, arg) => {
- if data.qualifier != ast::LdStQualifier::Weak || data.vector.is_some() {
- todo!()
+ let result_type = map.get_or_add_scalar(builder, data.typ);
+ match data.state_space {
+ ast::LdStateSpace::Generic => {
+ builder.load(result_type, Some(arg.dst), arg.src, None, [])?;
}
- let result_type = map.get_or_add_scalar(builder, data.typ);
- match data.state_space {
- ast::LdStateSpace::Generic => {
- builder.load(result_type, Some(arg.dst), arg.src, None, [])?;
- }
- ast::LdStateSpace::Param => {
- builder.copy_object(result_type, Some(arg.dst), arg.src)?;
- }
- _ => todo!(),
+ ast::LdStateSpace::Param => {
+ builder.copy_object(result_type, Some(arg.dst), arg.src)?;
}
+ _ => todo!(),
}
- Instruction::St(data, arg) => {
- if data.qualifier != ast::LdStQualifier::Weak
- || data.vector.is_some()
- || data.state_space != ast::StStateSpace::Generic
- {
- todo!()
- }
- builder.store(arg.src1, arg.src2, None, &[])?;
+ }
+ Instruction::St(data, arg) => {
+ if data.qualifier != ast::LdStQualifier::Weak
+ || data.vector.is_some()
+ || data.state_space != ast::StStateSpace::Generic
+ {
+ todo!()
}
- // SPIR-V does not support ret as guaranteed-converged
- Instruction::Ret(_) => builder.ret()?,
- Instruction::Mov(mov, arg) => {
- let result_type = map.get_or_add(builder, SpirvType::from(mov.typ));
- builder.copy_object(result_type, Some(arg.dst), arg.src)?;
+ builder.store(arg.src1, arg.src2, None, &[])?;
+ }
+ // SPIR-V does not support ret as guaranteed-converged
+ Instruction::Ret(_) => builder.ret()?,
+ Instruction::Mov(mov, arg) => {
+ let result_type = map.get_or_add(builder, SpirvType::from(mov.typ));
+ builder.copy_object(result_type, Some(arg.dst), arg.src)?;
+ }
+ Instruction::Mul(mul, arg) => match mul.desc {
+ ast::MulDescriptor::Int(ref ctr) => {
+ emit_mul_int(builder, map, mul.typ, ctr, arg)
}
- Instruction::Mul(mul, arg) => match mul.desc {
- ast::MulDescriptor::Int(ref ctr) => {
- emit_mul_int(builder, map, mul.typ, ctr, arg)
- }
- ast::MulDescriptor::Float(_) => todo!(),
- },
- _ => todo!(),
+ ast::MulDescriptor::Float(_) => todo!(),
},
- }
+ _ => todo!(),
+ },
+ _ => todo!(),
}
}
Ok(())
@@ -734,567 +725,225 @@ fn emit_implicit_conversion(
// TODO: support scopes
fn normalize_identifiers<'a>(
+ args: &'a [ast::Argument<'a>],
func: Vec<ast::Statement<&'a str>>,
- constant_identifiers: &HashMap<&'a str, spirv::Word>, // arguments and labels can't be redefined
- type_map: &mut HashMap<spirv::Word, ast::Type>,
- types: HashMap<Cow<'a, str>, ast::Type>,
-) -> (Vec<ast::Statement<spirv::Word>>, spirv::Word) {
- let mut id: u32 = constant_identifiers.len() as u32;
- let mut remapped_ids = HashMap::new();
- let mut get_or_add = |key| {
- constant_identifiers.get(key).map_or_else(
- || {
- *remapped_ids.entry(key).or_insert_with(|| {
- let to_insert = id;
- id += 1;
- to_insert
- })
- },
- |id| *id,
- )
- };
- let result = func
- .into_iter()
- .filter_map(|s| Statement::from_ast(s, &mut get_or_add))
- .collect::<Vec<_>>();
- type_map.extend(
- remapped_ids
- .into_iter()
- .map(|(old_id, new_id)| (new_id, types[old_id])),
- );
- (result, id)
-}
-
-fn ssa_legalize(
- func: &mut [Statement],
- constant_ids: spirv::Word,
- unique_ids: spirv::Word,
- bbs: &[BasicBlock],
- doms: &[BBIndex],
- dom_fronts: &[HashSet<BBIndex>],
-) -> (Vec<Vec<PhiDef>>, spirv::Word) {
- let phis = gather_phi_sets(&func, constant_ids, unique_ids, &bbs, dom_fronts);
- apply_ssa_renaming(func, &bbs, doms, constant_ids, unique_ids, &phis)
-}
-
-/* "Modern Compiler Implementation in Java" - Algorithm 19.7
- * This algorithm modifies passed function body in-place by renumbering ids,
- * result ids can be divided into following categories
- * - if id < constant_ids
- * it's a non-redefinable id
- * - if id >= constant_ids && id < all_ids
- * then it's an undefined id (a0, b0, c0)
- * - if id >= all_ids
- * then it's a normally redefined id
- */
-fn apply_ssa_renaming(
- func: &mut [Statement],
- bbs: &[BasicBlock],
- doms: &[BBIndex],
- constant_ids: spirv::Word,
- all_ids: spirv::Word,
- old_phi: &[HashSet<spirv::Word>],
-) -> (Vec<Vec<PhiDef>>, spirv::Word) {
- let mut dom_tree = vec![Vec::new(); bbs.len()];
- for (bb, idom) in doms.iter().enumerate().skip(1) {
- dom_tree[idom.0].push(BBIndex(bb));
- }
- let mut old_dst_id = vec![Vec::new(); bbs.len()];
- for bb in 0..bbs.len() {
- for s in get_bb_body(func, bbs, BBIndex(bb)) {
- s.visit_id(&mut |is_dst, id| {
- if is_dst {
- old_dst_id[bb].push(id)
- }
- });
- }
- }
- let mut new_phi = old_phi
- .iter()
- .map(|ids| {
- ids.iter()
- .map(|id| (*id, (u32::max_value(), HashSet::new())))
- .collect::<HashMap<_, _>>()
- })
- .collect::<Vec<_>>();
- let mut ssa_state = SSARewriteState::new(constant_ids, all_ids);
- // once again, we do explicit stack
- let mut state = Vec::new();
- state.push((BBIndex(0), 0));
- loop {
- if let Some((BBIndex(bb), dom_succ_idx)) = state.last_mut() {
- let bb = *bb;
- if *dom_succ_idx == 0 {
- rename_phi_dst(&mut ssa_state, &mut new_phi[bb]);
- rename_bb_body(&mut ssa_state, func, bbs, BBIndex(bb));
- for BBIndex(succ_idx) in bbs[bb].succ.iter() {
- rename_succesor_phi_src(&ssa_state, &mut new_phi[*succ_idx]);
- }
- }
- if let Some(s) = dom_tree[bb].get(*dom_succ_idx) {
- *dom_succ_idx += 1;
- state.push((*s, 0));
- } else {
- state.pop();
- pop_stacks(&mut ssa_state, &old_phi[bb], &old_dst_id[bb]);
- }
- } else {
- break;
- }
+) -> (Vec<ast::Statement<spirv::Word>>, NumericIdResolver) {
+ let mut id_defs = StringIdResolver::new();
+ for arg in args {
+ id_defs.add_def(arg.name, Some(ast::Type::Scalar(arg.a_type)));
}
- let phi = new_phi
- .into_iter()
- .map(|map| {
- map.into_iter()
- .map(|(_, (new_id, defs))| PhiDef {
- dst: new_id,
- src: defs,
- })
- .collect::<Vec<_>>()
- })
- .collect::<Vec<_>>();
- (phi, ssa_state.next_id())
-}
-
-// before ssa-renaming every phi is x <- phi(x,x,x,x)
-#[derive(Debug, PartialEq)]
-struct PhiDef {
- dst: spirv::Word,
- src: HashSet<spirv::Word>,
-}
-
-fn rename_phi_dst(
- rewriter: &mut SSARewriteState,
- phi: &mut HashMap<spirv::Word, (spirv::Word, HashSet<spirv::Word>)>,
-) {
- for (old_k, (new_k, _)) in phi.iter_mut() {
- *new_k = rewriter.redefine(*old_k);
+ let mut result = Vec::new();
+ for s in func {
+ expand_map_ids(&mut id_defs, &mut result, s);
}
+ (result, id_defs.finish())
}
-fn rename_bb_body(
- ssa_state: &mut SSARewriteState,
- func: &mut [Statement],
- all_bb: &[BasicBlock],
- bb: BBIndex,
+fn expand_map_ids<'a>(
+ id_defs: &mut StringIdResolver<'a>,
+ result: &mut Vec<ast::Statement<spirv::Word>>,
+ s: ast::Statement<&'a str>,
) {
- for s in get_bb_body_mut(func, all_bb, bb) {
- s.visit_id_mut(&mut |is_dst, id| {
- if is_dst {
- *id = ssa_state.redefine(*id);
- } else {
- *id = ssa_state.get(*id);
+ match s {
+ ast::Statement::Label(name) => {
+ result.push(ast::Statement::Label(id_defs.add_def(name, None)))
+ }
+ ast::Statement::Instruction(p, i) => result.push(ast::Statement::Instruction(
+ p.map(|p| p.map_id(&mut |id| id_defs.get_id(id))),
+ i.map_id1(&mut |id| id_defs.get_id(id)),
+ )),
+ ast::Statement::Variable(var) => match var.count {
+ Some(count) => {
+ for new_id in id_defs.add_defs(var.name, count, var.v_type) {
+ result.push(ast::Statement::Variable(ast::Variable {
+ space: var.space,
+ v_type: var.v_type,
+ name: new_id,
+ count: None,
+ }))
+ }
}
- });
+ None => {
+ let new_id = id_defs.add_def(var.name, Some(var.v_type));
+ result.push(ast::Statement::Variable(ast::Variable {
+ space: var.space,
+ v_type: var.v_type,
+ name: new_id,
+ count: None,
+ }));
+ }
+ },
}
}
-fn rename_succesor_phi_src(
- ssa_state: &SSARewriteState,
- phi: &mut HashMap<spirv::Word, (spirv::Word, HashSet<spirv::Word>)>,
-) {
- for (id, (_, v)) in phi.iter_mut() {
- v.insert(ssa_state.get(*id));
- }
+struct StringIdResolver<'a> {
+ current_id: spirv::Word,
+ variables: HashMap<Cow<'a, str>, spirv::Word>,
+ type_check: HashMap<u32, ast::Type>,
}
-fn pop_stacks(
- ssa_state: &mut SSARewriteState,
- old_phi: &HashSet<spirv::Word>,
- old_ids: &[spirv::Word],
-) {
- for id in old_phi.iter().chain(old_ids) {
- ssa_state.pop(*id);
+impl<'a> StringIdResolver<'a> {
+ fn new() -> Self {
+ StringIdResolver {
+ current_id: 0u32,
+ variables: HashMap::new(),
+ type_check: HashMap::new(),
+ }
}
-}
-fn get_bb_body_mut<'a>(
- func: &'a mut [Statement],
- all_bb: &[BasicBlock],
- bb: BBIndex,
-) -> &'a mut [Statement] {
- let (start, end) = get_bb_body_idx(func, all_bb, bb);
- &mut func[start..end]
-}
-
-fn get_bb_body<'a>(func: &'a [Statement], all_bb: &[BasicBlock], bb: BBIndex) -> &'a [Statement] {
- let (start, end) = get_bb_body_idx(func, all_bb, bb);
- &func[start..end]
-}
-
-fn get_bb_body_idx(func: &[Statement], all_bb: &[BasicBlock], bb: BBIndex) -> (usize, usize) {
- let BBIndex(bb_idx) = bb;
- let start = all_bb[bb_idx].start.0;
- let end = if bb_idx == all_bb.len() - 1 {
- func.len()
- } else {
- all_bb[bb_idx + 1].start.0
- };
- (start, end)
-}
-
-// We assume here that the variables are defined in the dense sequence 0..max
-struct SSARewriteState {
- next: spirv::Word,
- constant_ids: spirv::Word,
- stack: Vec<Vec<spirv::Word>>,
-}
-
-impl<'a> SSARewriteState {
- fn new(constant_ids: spirv::Word, all_ids: spirv::Word) -> Self {
- let to_redefine = all_ids - constant_ids;
- let stack = (0..to_redefine)
- .map(|x| vec![x + constant_ids])
- .collect::<Vec<_>>();
- SSARewriteState {
- next: all_ids,
- constant_ids: constant_ids,
- stack,
+ fn finish(self) -> NumericIdResolver {
+ NumericIdResolver {
+ current_id: self.current_id,
+ type_check: self.type_check,
}
}
- fn get(&self, x: spirv::Word) -> spirv::Word {
- if x < self.constant_ids {
- x
- } else {
- *self.stack[(x - self.constant_ids) as usize].last().unwrap()
- }
+ fn get_id(&self, id: &'a str) -> spirv::Word {
+ self.variables[id]
}
- fn redefine(&mut self, x: spirv::Word) -> spirv::Word {
- if x < self.constant_ids {
- x
- } else {
- let result = self.next;
- self.next += 1;
- self.stack[(x - self.constant_ids) as usize].push(result);
- result
+ #[must_use]
+ fn add_def(&mut self, id: &'a str, typ: Option<ast::Type>) -> spirv::Word {
+ let numeric_id = self.current_id;
+ self.variables.insert(Cow::Borrowed(id), numeric_id);
+ if let Some(typ) = typ {
+ self.type_check.insert(numeric_id, typ);
}
+ self.current_id += 1;
+ numeric_id
}
- fn pop(&mut self, x: spirv::Word) {
- if x >= self.constant_ids {
- self.stack[(x - self.constant_ids) as usize].pop();
+ #[must_use]
+ fn add_defs(
+ &mut self,
+ base_id: &'a str,
+ count: u32,
+ typ: ast::Type,
+ ) -> impl Iterator<Item = spirv::Word> {
+ let numeric_id = self.current_id;
+ for i in 0..count {
+ self.variables
+ .insert(Cow::Owned(format!("{}{}", base_id, i)), numeric_id + i);
+ self.type_check.insert(numeric_id + i, typ);
}
+ self.current_id += count;
+ (0..count).into_iter().map(move |i| i + numeric_id)
}
+}
- fn next_id(&self) -> spirv::Word {
- self.next
- }
+struct NumericIdResolver {
+ current_id: spirv::Word,
+ type_check: HashMap<u32, ast::Type>,
}
-// "Engineering a Compiler" - Figure 9.9
-// Calculates semi-pruned phis
-fn gather_phi_sets(
- func: &[Statement],
- constant_ids: spirv::Word,
- all_ids: spirv::Word,
- cfg: &[BasicBlock],
- dom_fronts: &[HashSet<BBIndex>],
-) -> Vec<HashSet<spirv::Word>> {
- let mut result = vec![HashSet::new(); cfg.len()];
- let mut globals = HashSet::new();
- let mut blocks = vec![(Vec::new(), HashSet::new()); (all_ids - constant_ids) as usize];
- for bb in 0..cfg.len() {
- let mut var_kill = HashSet::new();
- let mut visitor = |is_dst, id: spirv::Word| {
- if id >= constant_ids {
- let id = id - constant_ids;
- if is_dst {
- var_kill.insert(id);
- let (ref mut stack, ref mut set) = blocks[id as usize];
- stack.push(BBIndex(bb));
- set.insert(BBIndex(bb));
- } else {
- if !var_kill.contains(&id) {
- globals.insert(id);
- }
- }
- }
- };
- // We try to avoid adding labels to the global-visbility set.
- // We are not 100% precise (we add jump targets in bra), but it shouldn't be a problem
- for s in get_bb_body(func, cfg, BBIndex(bb)) {
- match s {
- Statement::Instruction(inst) => inst.visit_id(&mut visitor),
- Statement::Conditional(brc) => visitor(false, brc.predicate),
- Statement::Converison(conv) => conv.visit_id(&mut visitor),
- Statement::Constant(cons) => cons.visit_id(&mut visitor),
- // label redefinition is a compile-time error
- Statement::Label(_) => (),
- }
- }
+impl NumericIdResolver {
+ fn get_type(&self, id: spirv::Word) -> ast::Type {
+ self.type_check[&id]
}
- for id in globals {
- let (ref mut work_stack, ref mut work_set) = &mut blocks[id as usize];
- loop {
- if let Some(bb) = work_stack.pop() {
- work_set.remove(&bb);
- for d_bb in &dom_fronts[bb.0] {
- if result[d_bb.0].insert(id + constant_ids) {
- if work_set.insert(*d_bb) {
- work_stack.push(*d_bb);
- }
- }
- }
- } else {
- break;
- }
- }
- }
- result
-}
-fn get_basic_blocks(fun: &[Statement]) -> Vec<BasicBlock> {
- // edge signify pred/succ relationship between bbs
- let mut unresolved_bb_edge = Vec::new();
- // bb start means that a bb is starting at this statement, but there's no predecessor
- let mut bb_start = Vec::new();
- let mut labels = HashMap::new();
- for (idx, s) in fun.iter().enumerate() {
- match s {
- Statement::Instruction(i) => {
- if let Some(id) = i.jump_target() {
- unresolved_bb_edge.push((StmtIndex(idx), id));
- if idx + 1 < fun.len() {
- bb_start.push(StmtIndex(idx + 1));
- }
- } else if i.is_terminal() && idx + 1 < fun.len() {
- bb_start.push(StmtIndex(idx + 1));
- }
- }
- Statement::Label(id) => {
- labels.insert(id, StmtIndex(idx));
- }
- Statement::Conditional(bra) => {
- unresolved_bb_edge.push((StmtIndex(idx), bra.if_false));
- unresolved_bb_edge.push((StmtIndex(idx), bra.if_true));
- }
- Statement::Constant(_) => (),
- Statement::Converison(_) => (),
- };
- }
- let mut bb_edge = HashSet::new();
- // Resolve every <jump into label> into <jump into statement index>
- // TODO: handle jumps into nowhere
- for (idx, id) in unresolved_bb_edge {
- let target = labels[&id];
- bb_edge.insert((idx, target));
- bb_start.push(target);
- // now check if there is an edge target-1 -> target
- if target != StmtIndex(0) {
- match &fun[target.0 - 1] {
- Statement::Instruction(i) => {
- if !(i.jump_target().is_some() || i.is_terminal()) {
- bb_edge.insert((StmtIndex(target.0 - 1), target));
- }
- }
- Statement::Converison(_) | Statement::Constant(_) | Statement::Label(_) => {
- bb_edge.insert((StmtIndex(target.0 - 1), target));
- }
- // This is already in `unresolved_bb_edge`
- Statement::Conditional(_) => (),
- }
+ fn new_id(&mut self, typ: Option<ast::Type>) -> spirv::Word {
+ let new_id = self.current_id;
+ if let Some(typ) = typ {
+ self.type_check.insert(new_id, typ);
}
+ self.current_id += 1;
+ new_id
}
- // Create list of bbs without succ/pred
- let mut bbs_map = BTreeMap::new();
- bbs_map.insert(
- StmtIndex(0),
- BasicBlock {
- start: StmtIndex(0),
- pred: Vec::new(),
- succ: Vec::new(),
- },
- );
- for bb_first_stmt in bb_start {
- bbs_map.entry(bb_first_stmt).or_insert_with(|| BasicBlock {
- start: bb_first_stmt,
- pred: Vec::new(),
- succ: Vec::new(),
- });
- }
- // Populate succ/pred
- let indexed_bbs_map = bbs_map
- .into_iter()
- .enumerate()
- .map(|(idx, (key, val))| (key, (BBIndex(idx), RefCell::new(val))))
- .collect::<BTreeMap<_, _>>();
- for (from, to) in bb_edge {
- let (_, (from_idx, from_ref)) = indexed_bbs_map.range(..=from).next_back().unwrap();
- let (to_idx, to_ref) = indexed_bbs_map.get(&to).unwrap();
- {
- from_ref.borrow_mut().succ.push(*to_idx);
- }
- {
- to_ref.borrow_mut().pred.push(*from_idx);
- }
- }
- indexed_bbs_map
- .into_iter()
- .map(|(_, (_, bb))| bb.into_inner())
- .collect::<Vec<_>>()
-}
-
-// "A Simple, Fast Dominance Algorithm" - Keith D. Cooper, Timothy J. Harvey, and Ken Kennedy
-// https://www.cs.rice.edu/~keith/EMBED/dom.pdf
-fn dominance_frontiers(bbs: &[BasicBlock], doms: &[BBIndex]) -> Vec<HashSet<BBIndex>> {
- let mut result = vec![HashSet::new(); bbs.len()];
- for (bb_idx, b) in bbs.iter().enumerate() {
- if b.pred.len() < 2 {
- continue;
- }
- for p in b.pred.iter() {
- let mut runner = *p;
- while runner != doms[bb_idx] {
- result[runner.0].insert(BBIndex(bb_idx));
- runner = doms[runner.0];
- }
- }
+
+ fn ids_count(&self) -> spirv::Word {
+ self.current_id
}
- result
}
-fn immediate_dominators(bbs: &Vec<BasicBlock>, order: &Vec<BBIndex>) -> Vec<BBIndex> {
- let undefined = BBIndex(usize::max_value());
- let mut doms = vec![undefined; bbs.len()];
- doms[0] = BBIndex(0);
- let mut changed = true;
- while changed {
- changed = false;
- for BBIndex(bb_idx) in order.iter().skip(1) {
- let bb = &bbs[*bb_idx];
- if let Some(first_pred) = bb.pred.iter().find(|bb| doms[bb.0] != undefined) {
- let mut new_idom = *first_pred;
- for BBIndex(p_idx) in bb.pred.iter().copied().filter(|bb| bb != first_pred) {
- if doms[p_idx] != BBIndex(usize::max_value()) {
- new_idom = intersect(&mut doms, BBIndex(p_idx), new_idom);
- }
- }
- if doms[*bb_idx] != new_idom {
- doms[*bb_idx] = new_idom;
- changed = true;
- }
- }
- }
- }
- return doms;
+enum Statement<A: Args> {
+ Variable(spirv::Word, ast::Type),
+ LoadVar(Arg2, ast::Type),
+ StoreVar(Arg2St, ast::Type),
+ Label(u32),
+ Instruction(Instruction<A>),
+ // SPIR-V compatible replacement for PTX predicates
+ Conditional(BrachCondition),
+ Converison(ImplicitConversion),
+ Constant(ConstantDefinition),
}
-// Original paper uses reverse indexing: their entry node has index n,
-// that's why the compares are reversed
-fn intersect(doms: &mut Vec<BBIndex>, b1: BBIndex, b2: BBIndex) -> BBIndex {
- let mut finger1 = b1;
- let mut finger2 = b2;
- while finger1 != finger2 {
- while finger1 > finger2 {
- finger1 = doms[finger1.0];
- }
- while finger2 > finger1 {
- finger2 = doms[finger2.0];
+impl<A: Args> Statement<A> {
+ fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) {
+ match self {
+ Statement::Variable(id, _) => f(true, id),
+ Statement::LoadVar(a, _) => a.visit_id_mut(f),
+ Statement::StoreVar(a, _) => a.visit_id_mut(f),
+ Statement::Label(id) => f(false, id),
+ Statement::Instruction(inst) => inst.visit_id_mut(f),
+ Statement::Conditional(bra) => bra.visit_id_mut(f),
+ Statement::Converison(conv) => conv.visit_id_mut(f),
+ Statement::Constant(cons) => cons.visit_id_mut(f),
}
}
- finger1
-}
-// "A Simple Algorithm for Global Data Flow Analysis Problems" - Hecht, M. S., & Ullman, J. D. (1975)
-fn to_reverse_postorder(input: &Vec<BasicBlock>) -> Vec<BBIndex> {
- let mut i = input.len();
- let mut old = BitVec::from_elem(input.len(), false);
- let mut result = vec![BBIndex(usize::max_value()); input.len()];
- // original uses recursion and implicit stack, we do it explictly
- let mut state = Vec::new();
- state.push((BBIndex(0), 0usize));
- loop {
- if let Some((BBIndex(bb), succ_iter_idx)) = state.last_mut() {
- let bb = *bb;
- if *succ_iter_idx == 0 {
- old.set(bb, true);
- }
- if let Some(BBIndex(succ)) = &input[bb].succ.get(*succ_iter_idx) {
- *succ_iter_idx += 1;
- if !old.get(*succ).unwrap() {
- state.push((BBIndex(*succ), 0));
- }
- } else {
- state.pop();
- i = i - 1;
- result[i] = BBIndex(bb);
- }
- } else {
- break;
- }
+ fn get_type(&self) -> Option<ast::Type> {
+ todo!()
}
- result
}
-#[derive(Eq, PartialEq, Debug, Clone)]
-struct BasicBlock {
- start: StmtIndex,
- pred: Vec<BBIndex>,
- succ: Vec<BBIndex>,
+trait Args {
+ type Arg1: Arg;
+ type Arg2: Arg;
+ type Arg2St: Arg;
+ type Arg2Mov: Arg;
+ type Arg3: Arg;
+ type Arg4: Arg;
+ type Arg5: Arg;
}
-#[derive(Eq, PartialEq, Debug, Copy, Clone, PartialOrd, Ord, Hash)]
-struct StmtIndex(pub usize);
-
-impl fmt::Display for StmtIndex {
- fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
- self.0.fmt(f)
- }
+trait Arg {
+ fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F);
+ fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F);
}
-#[derive(Eq, PartialEq, Debug, Copy, Clone, PartialOrd, Ord, Hash)]
-struct BBIndex(pub usize);
+enum NormalizedArgs {}
-impl fmt::Display for BBIndex {
- fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
- self.0.fmt(f)
- }
+impl Args for NormalizedArgs {
+ type Arg1 = ast::Arg1<spirv::Word>;
+ type Arg2 = ast::Arg2<spirv::Word>;
+ type Arg2St = ast::Arg2St<spirv::Word>;
+ type Arg2Mov = ast::Arg2Mov<spirv::Word>;
+ type Arg3 = ast::Arg3<spirv::Word>;
+ type Arg4 = ast::Arg4<spirv::Word>;
+ type Arg5 = ast::Arg5<spirv::Word>;
}
-enum Statement {
- Label(u32),
- Instruction(Instruction),
- // SPIR-V compatible replacement for PTX predicates
- Conditional(BrachCondition),
- Converison(ImplicitConversion),
- Constant(ConstantDefinition),
+enum ExpandedArgs {}
+
+impl Args for ExpandedArgs {
+ type Arg1 = Arg1;
+ type Arg2 = Arg2;
+ type Arg2St = Arg2St;
+ type Arg2Mov = Arg2;
+ type Arg3 = Arg3;
+ type Arg4 = Arg4;
+ type Arg5 = Arg5;
}
-enum Instruction {
- Ld(ast::LdData, Arg2),
- Mov(ast::MovData, Arg2),
- Mul(ast::MulData, Arg3),
- Add(ast::AddData, Arg3),
- Setp(ast::SetpData, Arg4),
- SetpBool(ast::SetpBoolData, Arg5),
- Not(ast::NotData, Arg2),
- Bra(ast::BraData, Arg1),
- Cvt(ast::CvtData, Arg2),
- Shl(ast::ShlData, Arg3),
- St(ast::StData, Arg2St),
+type NormalizedStatement = Statement<NormalizedArgs>;
+type ExpandedStatement = Statement<ExpandedArgs>;
+
+enum Instruction<A: Args> {
+ Ld(ast::LdData, A::Arg2),
+ Mov(ast::MovData, A::Arg2Mov),
+ Mul(ast::MulData, A::Arg3),
+ Add(ast::AddData, A::Arg3),
+ Setp(ast::SetpData, A::Arg4),
+ SetpBool(ast::SetpBoolData, A::Arg5),
+ Not(ast::NotData, A::Arg2),
+ Bra(ast::BraData, A::Arg1),
+ Cvt(ast::CvtData, A::Arg2),
+ Shl(ast::ShlData, A::Arg3),
+ St(ast::StData, A::Arg2St),
Ret(ast::RetData),
}
-impl Instruction {
- fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F) {
- match self {
- Instruction::Ld(_, a) => a.visit_id(f),
- Instruction::Mov(_, a) => a.visit_id(f),
- Instruction::Mul(_, a) => a.visit_id(f),
- Instruction::Add(_, a) => a.visit_id(f),
- Instruction::Setp(_, a) => a.visit_id(f),
- Instruction::SetpBool(_, a) => a.visit_id(f),
- Instruction::Not(_, a) => a.visit_id(f),
- Instruction::Cvt(_, a) => a.visit_id(f),
- Instruction::Shl(_, a) => a.visit_id(f),
- Instruction::St(_, a) => a.visit_id(f),
- Instruction::Bra(_, a) => a.visit_id(f),
- Instruction::Ret(_) => (),
- }
- }
-
+impl<A: Args> Instruction<A> {
fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) {
match self {
Instruction::Ld(_, a) => a.visit_id_mut(f),
@@ -1324,9 +973,9 @@ impl Instruction {
}
}
- fn jump_target(&self) -> Option<spirv::Word> {
+ fn is_terminal(&self) -> bool {
match self {
- Instruction::Bra(_, a) => Some(a.src),
+ Instruction::Ret(_) => true,
Instruction::Ld(_, _)
| Instruction::Mov(_, _)
| Instruction::Mul(_, _)
@@ -1337,13 +986,51 @@ impl Instruction {
| Instruction::Cvt(_, _)
| Instruction::Shl(_, _)
| Instruction::St(_, _)
- | Instruction::Ret(_) => None,
+ | Instruction::Bra(_, _) => false,
}
}
+}
- fn is_terminal(&self) -> bool {
+impl Instruction<NormalizedArgs> {
+ fn from_ast(s: ast::Instruction<spirv::Word>) -> Self {
+ match s {
+ ast::Instruction::Ld(d, a) => Instruction::Ld(d, a),
+ ast::Instruction::Mov(d, a) => Instruction::Mov(d, a),
+ ast::Instruction::Mul(d, a) => Instruction::Mul(d, a),
+ ast::Instruction::Add(d, a) => Instruction::Add(d, a),
+ ast::Instruction::Setp(d, a) => Instruction::Setp(d, a),
+ ast::Instruction::SetpBool(d, a) => Instruction::SetpBool(d, a),
+ ast::Instruction::Not(d, a) => Instruction::Not(d, a),
+ ast::Instruction::Cvt(d, a) => Instruction::Cvt(d, a),
+ ast::Instruction::Shl(d, a) => Instruction::Shl(d, a),
+ ast::Instruction::St(d, a) => Instruction::St(d, a),
+ ast::Instruction::Bra(d, a) => Instruction::Bra(d, a),
+ ast::Instruction::Ret(d) => Instruction::Ret(d),
+ }
+ }
+}
+
+impl Instruction<ExpandedArgs> {
+ fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F) {
match self {
- Instruction::Ret(_) => true,
+ Instruction::Ld(_, a) => a.visit_id(f),
+ Instruction::Mov(_, a) => a.visit_id(f),
+ Instruction::Mul(_, a) => a.visit_id(f),
+ Instruction::Add(_, a) => a.visit_id(f),
+ Instruction::Setp(_, a) => a.visit_id(f),
+ Instruction::SetpBool(_, a) => a.visit_id(f),
+ Instruction::Not(_, a) => a.visit_id(f),
+ Instruction::Cvt(_, a) => a.visit_id(f),
+ Instruction::Shl(_, a) => a.visit_id(f),
+ Instruction::St(_, a) => a.visit_id(f),
+ Instruction::Bra(_, a) => a.visit_id(f),
+ Instruction::Ret(_) => (),
+ }
+ }
+
+ fn jump_target(&self) -> Option<spirv::Word> {
+ match self {
+ Instruction::Bra(_, a) => Some(a.src),
Instruction::Ld(_, _)
| Instruction::Mov(_, _)
| Instruction::Mul(_, _)
@@ -1354,7 +1041,7 @@ impl Instruction {
| Instruction::Cvt(_, _)
| Instruction::Shl(_, _)
| Instruction::St(_, _)
- | Instruction::Bra(_, _) => false,
+ | Instruction::Ret(_) => None,
}
}
}
@@ -1363,7 +1050,7 @@ struct Arg1 {
pub src: spirv::Word,
}
-impl Arg1 {
+impl Arg for Arg1 {
fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F) {
f(false, self.src);
}
@@ -1378,7 +1065,7 @@ struct Arg2 {
pub src: spirv::Word,
}
-impl Arg2 {
+impl Arg for Arg2 {
fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F) {
f(true, self.dst);
f(false, self.src);
@@ -1395,7 +1082,7 @@ pub struct Arg2St {
pub src2: spirv::Word,
}
-impl Arg2St {
+impl Arg for Arg2St {
fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F) {
f(false, self.src1);
f(false, self.src2);
@@ -1413,7 +1100,7 @@ struct Arg3 {
pub src2: spirv::Word,
}
-impl Arg3 {
+impl Arg for Arg3 {
fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F) {
f(true, self.dst);
f(false, self.src1);
@@ -1434,7 +1121,7 @@ struct Arg4 {
pub src2: spirv::Word,
}
-impl Arg4 {
+impl Arg for Arg4 {
fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F) {
f(true, self.dst1);
self.dst2.map(|dst2| f(true, dst2));
@@ -1458,7 +1145,7 @@ struct Arg5 {
pub src3: spirv::Word,
}
-impl Arg5 {
+impl Arg for Arg5 {
fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F) {
f(true, self.dst1);
self.dst2.map(|dst2| f(true, dst2));
@@ -1540,44 +1227,6 @@ impl ImplicitConversion {
}
}
-impl Statement {
- fn from_ast<'a, F: FnMut(&'a str) -> u32>(
- s: ast::Statement<&'a str>,
- get_id: &mut F,
- ) -> Option<ast::Statement<spirv::Word>> {
- match s {
- ast::Statement::Label(name) => Some(ast::Statement::Label(get_id(name))),
- ast::Statement::Instruction(p, i) => Some(ast::Statement::Instruction(
- p.map(|p| p.map_id(get_id)),
- i.map_id(get_id),
- )),
- ast::Statement::Variable(_) => None,
- }
- }
-
- fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F) {
- match self {
- Statement::Label(id) => f(false, *id),
- Statement::Instruction(inst) => inst.visit_id(f),
- Statement::Conditional(bra) => bra.visit_id(f),
- Statement::Converison(conv) => conv.visit_id(f),
- Statement::Constant(cons) => cons.visit_id(f),
- }
- }
-
- // WARNING: It is very important to first visit src operands and then dst operands,
- // otherwise SSA renaming will yield weird results
- fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) {
- match self {
- Statement::Label(id) => f(false, id),
- Statement::Instruction(inst) => inst.visit_id_mut(f),
- Statement::Conditional(bra) => bra.visit_id_mut(f),
- Statement::Converison(conv) => conv.visit_id_mut(f),
- Statement::Constant(cons) => cons.visit_id_mut(f),
- }
- }
-}
-
impl<T> ast::PredAt<T> {
fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::PredAt<U> {
ast::PredAt {
@@ -1588,7 +1237,7 @@ impl<T> ast::PredAt<T> {
}
impl<T> ast::Instruction<T> {
- fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::Instruction<U> {
+ fn map_id1<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::Instruction<U> {
match self {
ast::Instruction::Ld(d, a) => ast::Instruction::Ld(d, a.map_id(f)),
ast::Instruction::Mov(d, a) => ast::Instruction::Mov(d, a.map_id(f)),
@@ -1605,9 +1254,28 @@ impl<T> ast::Instruction<T> {
}
}
- fn visit_id<F: FnMut(bool, &T)>(&self, f: &mut F) {
+ fn map_id<F: FnMut(T) -> spirv::Word>(self, f: &mut F) -> Instruction<NormalizedArgs> {
match self {
- ast::Instruction::Ld(_, a) => a.visit_id(f),
+ ast::Instruction::Ld(d, a) => Instruction::Ld(d, a.map_id(f)),
+ ast::Instruction::Mov(d, a) => Instruction::Mov(d, a.map_id(f)),
+ ast::Instruction::Mul(d, a) => Instruction::Mul(d, a.map_id(f)),
+ ast::Instruction::Add(d, a) => Instruction::Add(d, a.map_id(f)),
+ ast::Instruction::Setp(d, a) => Instruction::Setp(d, a.map_id(f)),
+ ast::Instruction::SetpBool(d, a) => Instruction::SetpBool(d, a.map_id(f)),
+ ast::Instruction::Not(d, a) => Instruction::Not(d, a.map_id(f)),
+ ast::Instruction::Bra(d, a) => Instruction::Bra(d, a.map_id(f)),
+ ast::Instruction::Cvt(d, a) => Instruction::Cvt(d, a.map_id(f)),
+ ast::Instruction::Shl(d, a) => Instruction::Shl(d, a.map_id(f)),
+ ast::Instruction::St(d, a) => Instruction::St(d, a.map_id(f)),
+ ast::Instruction::Ret(d) => Instruction::Ret(d),
+ }
+ }
+}
+
+impl ast::Instruction<spirv::Word> {
+ fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F) {
+ match self {
+ ast::Instruction::Ld(_, a) => Arg::visit_id(a, f),
ast::Instruction::Mov(_, a) => a.visit_id(f),
ast::Instruction::Mul(_, a) => a.visit_id(f),
ast::Instruction::Add(_, a) => a.visit_id(f),
@@ -1622,7 +1290,7 @@ impl<T> ast::Instruction<T> {
}
}
- fn visit_id_mut<F: FnMut(bool, &mut T)>(&mut self, f: &mut F) {
+ fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) {
match self {
ast::Instruction::Ld(_, a) => a.visit_id_mut(f),
ast::Instruction::Mov(_, a) => a.visit_id_mut(f),
@@ -1692,12 +1360,14 @@ impl<T> ast::Arg1<T> {
fn map_id<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::Arg1<U> {
ast::Arg1 { src: f(self.src) }
}
+}
- fn visit_id<F: FnMut(bool, &T)>(&self, f: &mut F) {
- f(false, &self.src);
+impl Arg for ast::Arg1<spirv::Word> {
+ fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F) {
+ f(false, self.src);
}
- fn visit_id_mut<F: FnMut(bool, &mut T)>(&mut self, f: &mut F) {
+ fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) {
f(false, &mut self.src);
}
}
@@ -1709,13 +1379,15 @@ impl<T> ast::Arg2<T> {
src: self.src.map_id(f),
}
}
+}
- fn visit_id<F: FnMut(bool, &T)>(&self, f: &mut F) {
- f(true, &self.dst);
+impl Arg for ast::Arg2<spirv::Word> {
+ fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F) {
+ f(true, self.dst);
self.src.visit_id(f);
}
- fn visit_id_mut<F: FnMut(bool, &mut T)>(&mut self, f: &mut F) {
+ fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) {
self.src.visit_id_mut(f);
f(true, &mut self.dst);
}
@@ -1728,13 +1400,15 @@ impl<T> ast::Arg2St<T> {
src2: self.src2.map_id(f),
}
}
+}
- fn visit_id<F: FnMut(bool, &T)>(&self, f: &mut F) {
+impl Arg for ast::Arg2St<spirv::Word> {
+ fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F) {
self.src1.visit_id(f);
self.src2.visit_id(f);
}
- fn visit_id_mut<F: FnMut(bool, &mut T)>(&mut self, f: &mut F) {
+ fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) {
self.src1.visit_id_mut(f);
self.src2.visit_id_mut(f);
}
@@ -1747,13 +1421,15 @@ impl<T> ast::Arg2Mov<T> {
src: self.src.map_id(f),
}
}
+}
- fn visit_id<F: FnMut(bool, &T)>(&self, f: &mut F) {
- f(true, &self.dst);
+impl Arg for ast::Arg2Mov<spirv::Word> {
+ fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F) {
+ f(true, self.dst);
self.src.visit_id(f);
}
- fn visit_id_mut<F: FnMut(bool, &mut T)>(&mut self, f: &mut F) {
+ fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) {
self.src.visit_id_mut(f);
f(true, &mut self.dst);
}
@@ -1767,14 +1443,16 @@ impl<T> ast::Arg3<T> {
src2: self.src2.map_id(f),
}
}
+}
- fn visit_id<F: FnMut(bool, &T)>(&self, f: &mut F) {
- f(true, &self.dst);
+impl Arg for ast::Arg3<spirv::Word> {
+ fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F) {
+ f(true, self.dst);
self.src1.visit_id(f);
self.src2.visit_id(f);
}
- fn visit_id_mut<F: FnMut(bool, &mut T)>(&mut self, f: &mut F) {
+ fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) {
self.src1.visit_id_mut(f);
self.src2.visit_id_mut(f);
f(true, &mut self.dst);
@@ -1790,15 +1468,17 @@ impl<T> ast::Arg4<T> {
src2: self.src2.map_id(f),
}
}
+}
- fn visit_id<F: FnMut(bool, &T)>(&self, f: &mut F) {
- f(true, &self.dst1);
- self.dst2.as_ref().map(|i| f(true, i));
+impl Arg for ast::Arg4<spirv::Word> {
+ fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F) {
+ f(true, self.dst1);
+ self.dst2.map(|i| f(true, i));
self.src1.visit_id(f);
self.src2.visit_id(f);
}
- fn visit_id_mut<F: FnMut(bool, &mut T)>(&mut self, f: &mut F) {
+ fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) {
self.src1.visit_id_mut(f);
self.src2.visit_id_mut(f);
f(true, &mut self.dst1);
@@ -1816,16 +1496,18 @@ impl<T> ast::Arg5<T> {
src3: self.src3.map_id(f),
}
}
+}
- fn visit_id<F: FnMut(bool, &T)>(&self, f: &mut F) {
- f(true, &self.dst1);
- self.dst2.as_ref().map(|i| f(true, i));
+impl Arg for ast::Arg5<spirv::Word> {
+ fn visit_id<F: FnMut(bool, spirv::Word)>(&self, f: &mut F) {
+ f(true, self.dst1);
+ self.dst2.map(|i| f(true, i));
self.src1.visit_id(f);
self.src2.visit_id(f);
self.src3.visit_id(f);
}
- fn visit_id_mut<F: FnMut(bool, &mut T)>(&mut self, f: &mut F) {
+ fn visit_id_mut<F: FnMut(bool, &mut spirv::Word)>(&mut self, f: &mut F) {
self.src1.visit_id_mut(f);
self.src2.visit_id_mut(f);
self.src3.visit_id_mut(f);
@@ -1842,11 +1524,13 @@ impl<T> ast::Operand<T> {
ast::Operand::Imm(v) => ast::Operand::Imm(v),
}
}
+}
- fn visit_id<F: FnMut(bool, &T)>(&self, f: &mut F) {
+impl<T: Copy> ast::Operand<T> {
+ fn visit_id<F: FnMut(bool, T)>(&self, f: &mut F) {
match self {
- ast::Operand::Reg(i) => f(false, i),
- ast::Operand::RegOffset(i, _) => f(false, i),
+ ast::Operand::Reg(i) => f(false, *i),
+ ast::Operand::RegOffset(i, _) => f(false, *i),
ast::Operand::Imm(_) => (),
}
}
@@ -1867,18 +1551,20 @@ impl<T> ast::MovOperand<T> {
ast::MovOperand::Vec(s1, s2) => ast::MovOperand::Vec(s1, s2),
}
}
+}
- fn visit_id<F: FnMut(bool, &T)>(&self, f: &mut F) {
+impl<T: Copy> ast::MovOperand<T> {
+ fn visit_id<F: FnMut(bool, T)>(&self, f: &mut F) {
match self {
ast::MovOperand::Op(o) => o.visit_id(f),
- ast::MovOperand::Vec(_, _) => (),
+ ast::MovOperand::Vec(_, _) => todo!(),
}
}
fn visit_id_mut<F: FnMut(bool, &mut T)>(&mut self, f: &mut F) {
match self {
ast::MovOperand::Op(o) => o.visit_id_mut(f),
- ast::MovOperand::Vec(_, _) => (),
+ ast::MovOperand::Vec(_, _) => todo!(),
}
}
}
@@ -2007,19 +1693,17 @@ fn should_bitcast(instr: ast::Type, operand: ast::Type) -> bool {
}
}
-fn insert_implicit_conversions_ld_src<TypeCheck: Fn(spirv::Word) -> ast::Type>(
- func: &mut Vec<Statement>,
+fn insert_implicit_conversions_ld_src(
+ func: &mut Vec<ExpandedStatement>,
instr_type: ast::Type,
- type_check: &TypeCheck,
- new_id: &mut impl FnMut(Option<ast::Type>) -> spirv::Word,
+ id_def: &mut NumericIdResolver,
state_space: ast::LdStateSpace,
src: spirv::Word,
) -> spirv::Word {
match state_space {
ast::LdStateSpace::Param => insert_implicit_conversions_ld_src_impl(
func,
- type_check,
- new_id,
+ id_def,
instr_type,
src,
should_convert_ld_param_src,
@@ -2031,15 +1715,14 @@ fn insert_implicit_conversions_ld_src<TypeCheck: Fn(spirv::Word) -> ast::Type>(
));
let new_src = insert_implicit_conversions_ld_src_impl(
func,
- type_check,
- new_id,
+ id_def,
new_src_type,
src,
should_convert_ld_generic_src_to_bitcast,
);
insert_conversion_src(
func,
- new_id,
+ id_def,
new_src,
new_src_type,
instr_type,
@@ -2051,19 +1734,17 @@ fn insert_implicit_conversions_ld_src<TypeCheck: Fn(spirv::Word) -> ast::Type>(
}
fn insert_implicit_conversions_ld_src_impl<
- TypeCheck: Fn(spirv::Word) -> ast::Type,
ShouldConvert: FnOnce(ast::Type, ast::Type) -> Option<ConversionKind>,
>(
- func: &mut Vec<Statement>,
- type_check: &TypeCheck,
- new_id: &mut impl FnMut(Option<ast::Type>) -> spirv::Word,
+ func: &mut Vec<ExpandedStatement>,
+ id_def: &mut NumericIdResolver,
instr_type: ast::Type,
src: spirv::Word,
should_convert: ShouldConvert,
) -> spirv::Word {
- let src_type = type_check(src);
+ let src_type = id_def.get_type(src);
if let Some(conv) = should_convert(src_type, instr_type) {
- insert_conversion_src(func, new_id, src, src_type, instr_type, conv)
+ insert_conversion_src(func, id_def, src, src_type, instr_type, conv)
} else {
src
}
@@ -2096,14 +1777,14 @@ fn should_convert_ld_generic_src_to_bitcast(
#[must_use]
fn insert_conversion_src(
- func: &mut Vec<Statement>,
- new_id: &mut impl FnMut(Option<ast::Type>) -> spirv::Word,
+ func: &mut Vec<ExpandedStatement>,
+ id_def: &mut NumericIdResolver,
src: spirv::Word,
src_type: ast::Type,
instr_type: ast::Type,
conv: ConversionKind,
) -> spirv::Word {
- let temp_src = new_id(Some(instr_type));
+ let temp_src = id_def.new_id(Some(instr_type));
func.push(Statement::Converison(ImplicitConversion {
src: src,
dst: temp_src,
@@ -2116,24 +1797,22 @@ fn insert_conversion_src(
fn insert_with_implicit_conversion_dst<
T,
- TypeCheck: Fn(spirv::Word) -> ast::Type,
ShouldConvert: FnOnce(ast::Type, ast::ScalarType) -> Option<ConversionKind>,
Setter: Fn(&mut T) -> &mut spirv::Word,
- ToInstruction: FnOnce(T) -> Instruction,
+ ToInstruction: FnOnce(T) -> Instruction<ExpandedArgs>,
>(
- func: &mut Vec<Statement>,
+ func: &mut Vec<ExpandedStatement>,
instr_type: ast::ScalarType,
- type_check: &TypeCheck,
- new_id: &mut impl FnMut(Option<ast::Type>) -> spirv::Word,
+ id_def: &mut NumericIdResolver,
should_convert: ShouldConvert,
mut t: T,
setter: Setter,
to_inst: ToInstruction,
) {
let dst = setter(&mut t);
- let dst_type = type_check(*dst);
+ let dst_type = id_def.get_type(*dst);
let dst_coercion = should_convert(dst_type, instr_type)
- .map(|conv| get_conversion_dst(new_id, dst, ast::Type::Scalar(instr_type), dst_type, conv));
+ .map(|conv| get_conversion_dst(id_def, dst, ast::Type::Scalar(instr_type), dst_type, conv));
func.push(Statement::Instruction(to_inst(t)));
if let Some(conv) = dst_coercion {
func.push(conv);
@@ -2142,14 +1821,14 @@ fn insert_with_implicit_conversion_dst<
#[must_use]
fn get_conversion_dst(
- new_id: &mut impl FnMut(Option<ast::Type>) -> spirv::Word,
+ id_def: &mut NumericIdResolver,
dst: &mut spirv::Word,
instr_type: ast::Type,
dst_type: ast::Type,
kind: ConversionKind,
-) -> Statement {
+) -> ExpandedStatement {
let original_dst = *dst;
- let temp_dst = new_id(Some(instr_type));
+ let temp_dst = id_def.new_id(Some(instr_type));
*dst = temp_dst;
Statement::Converison(ImplicitConversion {
src: temp_dst,
@@ -2245,20 +1924,19 @@ fn should_convert_relaxed_dst(
}
}
-fn insert_implicit_bitcasts<TypeCheck: Fn(spirv::Word) -> ast::Type>(
- func: &mut Vec<Statement>,
- type_check: &TypeCheck,
- new_id: &mut impl FnMut(Option<ast::Type>) -> spirv::Word,
- mut instr: Instruction,
+fn insert_implicit_bitcasts(
+ func: &mut Vec<ExpandedStatement>,
+ id_def: &mut NumericIdResolver,
+ mut instr: Instruction<ExpandedArgs>,
) {
let mut dst_coercion = None;
if let Some(instr_type) = instr.get_type() {
instr.visit_id_mut(&mut |is_dst, id| {
- let id_type = type_check(*id);
- if should_bitcast(instr_type, type_check(*id)) {
+ let id_type = id_def.get_type(*id);
+ if should_bitcast(instr_type, id_def.get_type(*id)) {
if is_dst {
dst_coercion = Some(get_conversion_dst(
- new_id,
+ id_def,
id,
instr_type,
id_type,
@@ -2267,7 +1945,7 @@ fn insert_implicit_bitcasts<TypeCheck: Fn(spirv::Word) -> ast::Type>(
} else {
*id = insert_conversion_src(
func,
- new_id,
+ id_def,
*id,
id_type,
instr_type,
@@ -2290,724 +1968,6 @@ mod tests {
use crate::ast;
use crate::ptx;
- // page 411
- #[test]
- fn to_reverse_postorder1() {
- let input = vec![
- BasicBlock {
- // A
- start: StmtIndex(0),
- pred: vec![],
- succ: vec![BBIndex(1), BBIndex(2)],
- },
- BasicBlock {
- // B
- start: StmtIndex(1),
- pred: vec![BBIndex(0), BBIndex(11)],
- succ: vec![BBIndex(3), BBIndex(6)],
- },
- BasicBlock {
- // C
- start: StmtIndex(2),
- pred: vec![BBIndex(0), BBIndex(4)],
- succ: vec![BBIndex(4), BBIndex(7)],
- },
- BasicBlock {
- // D
- start: StmtIndex(3),
- pred: vec![BBIndex(1)],
- succ: vec![BBIndex(5), BBIndex(6)],
- },
- BasicBlock {
- // E
- start: StmtIndex(4),
- pred: vec![BBIndex(2)],
- succ: vec![BBIndex(2), BBIndex(7)],
- },
- BasicBlock {
- // F
- start: StmtIndex(5),
- pred: vec![BBIndex(3)],
- succ: vec![BBIndex(8), BBIndex(10)],
- },
- BasicBlock {
- // G
- start: StmtIndex(6),
- pred: vec![BBIndex(1), BBIndex(3)],
- succ: vec![BBIndex(9)],
- },
- BasicBlock {
- // H
- start: StmtIndex(7),
- pred: vec![BBIndex(2), BBIndex(4)],
- succ: vec![BBIndex(12)],
- },
- BasicBlock {
- // I
- start: StmtIndex(8),
- pred: vec![BBIndex(5), BBIndex(9)],
- succ: vec![BBIndex(11)],
- },
- BasicBlock {
- // J
- start: StmtIndex(9),
- pred: vec![BBIndex(6)],
- succ: vec![BBIndex(8)],
- },
- BasicBlock {
- // K
- start: StmtIndex(10),
- pred: vec![BBIndex(5)],
- succ: vec![BBIndex(11)],
- },
- BasicBlock {
- // L
- start: StmtIndex(11),
- pred: vec![BBIndex(8), BBIndex(10)],
- succ: vec![BBIndex(1), BBIndex(12)],
- },
- BasicBlock {
- // M
- start: StmtIndex(12),
- pred: vec![BBIndex(7), BBIndex(11)],
- succ: vec![],
- },
- ];
- let rpostord = to_reverse_postorder(&input);
- assert_eq!(
- rpostord,
- vec![
- BBIndex(0), // A
- BBIndex(2), // C
- BBIndex(4), // E
- BBIndex(7), // H
- BBIndex(1), // B
- BBIndex(3), // D
- BBIndex(6), // G
- BBIndex(9), // J
- BBIndex(5), // F
- BBIndex(10), // K
- BBIndex(8), // I
- BBIndex(11), // L
- BBIndex(12), // M
- ]
- );
- }
-
- #[test]
- fn get_basic_blocks_empty() {
- let func = Vec::new();
- let bbs = get_basic_blocks(&func);
- assert_eq!(
- bbs,
- vec![BasicBlock {
- start: StmtIndex(0),
- pred: vec![],
- succ: vec![],
- }]
- );
- }
-
- #[test]
- fn get_basic_blocks_miniloop() {
- let func = vec![
- Statement::Label(12),
- Statement::Instruction(Instruction::Bra(
- ast::BraData { uniform: false },
- Arg1 { src: 12 },
- )),
- ];
- let bbs = get_basic_blocks(&func);
- assert_eq!(
- bbs,
- vec![BasicBlock {
- start: StmtIndex(0),
- pred: vec![BBIndex(0)],
- succ: vec![BBIndex(0)],
- }]
- );
- }
-
- // "A Simple, Fast Dominance Algorithm" - Fig. 4
- fn simple_fast_dom_fig4() -> Vec<BasicBlock> {
- vec![
- BasicBlock {
- start: StmtIndex(6),
- pred: vec![],
- succ: vec![BBIndex(1), BBIndex(2)],
- },
- BasicBlock {
- start: StmtIndex(5),
- pred: vec![BBIndex(0)],
- succ: vec![BBIndex(5)],
- },
- BasicBlock {
- start: StmtIndex(4),
- pred: vec![BBIndex(0)],
- succ: vec![BBIndex(3), BBIndex(4)],
- },
- BasicBlock {
- start: StmtIndex(3),
- pred: vec![BBIndex(2), BBIndex(4)],
- succ: vec![BBIndex(4)],
- },
- BasicBlock {
- start: StmtIndex(2),
- pred: vec![BBIndex(2), BBIndex(3), BBIndex(5)],
- succ: vec![BBIndex(3), BBIndex(5)],
- },
- BasicBlock {
- start: StmtIndex(1),
- pred: vec![BBIndex(1), BBIndex(4)],
- succ: vec![BBIndex(4)],
- },
- ]
- }
-
- #[test]
- fn immediate_dominators1() {
- let input = simple_fast_dom_fig4();
- let reverse_postorder = vec![
- BBIndex(0),
- BBIndex(1),
- BBIndex(2),
- BBIndex(3),
- BBIndex(4),
- BBIndex(5),
- ];
- let imm_dominators = immediate_dominators(&input, &reverse_postorder);
- assert_eq!(
- imm_dominators,
- vec![
- BBIndex(0),
- BBIndex(0),
- BBIndex(0),
- BBIndex(0),
- BBIndex(0),
- BBIndex(0)
- ]
- );
- }
-
- // page 411
- #[test]
- fn immediate_dominators2() {
- let input = vec![
- BasicBlock {
- // A
- start: StmtIndex(0),
- pred: vec![],
- succ: vec![BBIndex(1), BBIndex(2)],
- },
- BasicBlock {
- // B
- start: StmtIndex(1),
- pred: vec![BBIndex(0), BBIndex(11)],
- succ: vec![BBIndex(3), BBIndex(6)],
- },
- BasicBlock {
- // C
- start: StmtIndex(2),
- pred: vec![BBIndex(0), BBIndex(4)],
- succ: vec![BBIndex(4), BBIndex(7)],
- },
- BasicBlock {
- // D
- start: StmtIndex(3),
- pred: vec![BBIndex(1)],
- succ: vec![BBIndex(5), BBIndex(6)],
- },
- BasicBlock {
- // E
- start: StmtIndex(4),
- pred: vec![BBIndex(2)],
- succ: vec![BBIndex(2), BBIndex(7)],
- },
- BasicBlock {
- // F
- start: StmtIndex(5),
- pred: vec![BBIndex(3)],
- succ: vec![BBIndex(8), BBIndex(10)],
- },
- BasicBlock {
- // G
- start: StmtIndex(6),
- pred: vec![BBIndex(1), BBIndex(3)],
- succ: vec![BBIndex(9)],
- },
- BasicBlock {
- // H
- start: StmtIndex(7),
- pred: vec![BBIndex(2), BBIndex(4)],
- succ: vec![BBIndex(12)],
- },
- BasicBlock {
- // I
- start: StmtIndex(8),
- pred: vec![BBIndex(5), BBIndex(9)],
- succ: vec![BBIndex(11)],
- },
- BasicBlock {
- // J
- start: StmtIndex(9),
- pred: vec![BBIndex(6)],
- succ: vec![BBIndex(8)],
- },
- BasicBlock {
- // K
- start: StmtIndex(10),
- pred: vec![BBIndex(5)],
- succ: vec![BBIndex(11)],
- },
- BasicBlock {
- // L
- start: StmtIndex(11),
- pred: vec![BBIndex(8), BBIndex(10)],
- succ: vec![BBIndex(1), BBIndex(12)],
- },
- BasicBlock {
- // M
- start: StmtIndex(12),
- pred: vec![BBIndex(7), BBIndex(11)],
- succ: vec![],
- },
- ];
- let reverse_postorder = vec![
- BBIndex(0), // A
- BBIndex(2), // C
- BBIndex(4), // E
- BBIndex(7), // H
- BBIndex(1), // B
- BBIndex(3), // D
- BBIndex(6), // G
- BBIndex(9), // J
- BBIndex(5), // F
- BBIndex(10), // K
- BBIndex(8), // I
- BBIndex(11), // L
- BBIndex(12), // M
- ];
- let imm_dominators = immediate_dominators(&input, &reverse_postorder);
- assert_eq!(
- imm_dominators,
- vec![
- BBIndex(0),
- BBIndex(0),
- BBIndex(0),
- BBIndex(1),
- BBIndex(2),
- BBIndex(3),
- BBIndex(1),
- BBIndex(2),
- BBIndex(1),
- BBIndex(6),
- BBIndex(5),
- BBIndex(1),
- BBIndex(0)
- ]
- );
- }
-
- fn sort_pred_succ(bb: &mut BasicBlock) {
- bb.pred.sort();
- bb.succ.sort();
- }
-
- // page 403
- const FIG_19_4: &'static str = "{
- .reg.u32 i;
- .reg.u32 j;
- .reg.u32 k;
- .reg.pred p;
- .reg.pred q;
-
- mov.u32 i, 1;
- mov.u32 j, 1;
- mov.u32 k, 0;
- block_2:
- setp.ge.u32 p, k, 100;
- @p bra block_4; // conditional p block_4 if_false1
- // if_false1:
- setp.ge.u32 q, j, 20;
- @q bra block_6; // conditional q block_6 if_false2
- // if_false2:
- mov.u32 j, i;
- add.u32 k, k, 1;
- bra block_7;
- block_6:
- mov.u32 j, k;
- add.u32 k, k, 2;
- block_7:
- bra block_2;
- block_4:
- ret;
- }";
-
- #[test]
- fn get_basic_blocks_fig_19_4() {
- let func = FIG_19_4;
- let mut errors = Vec::new();
- let ast = ptx::FunctionBodyParser::new()
- .parse(&mut errors, func)
- .unwrap();
- assert_eq!(errors.len(), 0);
- let mut constant_ids = HashMap::new();
- collect_label_ids(&mut constant_ids, &ast);
- let registers = collect_var_definitions(&[], &ast);
- let mut type_check = HashMap::new();
- let (normalized_ids, mut unique_ids) =
- normalize_identifiers(ast, &constant_ids, &mut type_check, registers);
- let type_check = RefCell::new(type_check);
- let new_id = &mut |typ: Option<ast::Type>| {
- let to_insert = unique_ids;
- {
- let mut type_check = type_check.borrow_mut();
- typ.map(|t| (*type_check).insert(to_insert, t));
- }
- unique_ids += 1;
- to_insert
- };
- let normalized_stmts = normalize_statements(normalized_ids, new_id);
- let mut bbs = get_basic_blocks(&normalized_stmts);
- bbs.iter_mut().for_each(sort_pred_succ);
- assert_eq!(
- bbs,
- vec![
- BasicBlock {
- start: StmtIndex(0),
- pred: vec![],
- succ: vec![BBIndex(1)],
- },
- BasicBlock {
- start: StmtIndex(6),
- pred: vec![BBIndex(0), BBIndex(5)],
- succ: vec![BBIndex(2), BBIndex(6)],
- },
- BasicBlock {
- start: StmtIndex(10),
- pred: vec![BBIndex(1)],
- succ: vec![BBIndex(3), BBIndex(4)],
- },
- BasicBlock {
- start: StmtIndex(14),
- pred: vec![BBIndex(2)],
- succ: vec![BBIndex(5)],
- },
- BasicBlock {
- start: StmtIndex(19),
- pred: vec![BBIndex(2)],
- succ: vec![BBIndex(5)],
- },
- BasicBlock {
- start: StmtIndex(23),
- pred: vec![BBIndex(3), BBIndex(4)],
- succ: vec![BBIndex(1)],
- },
- BasicBlock {
- start: StmtIndex(25),
- pred: vec![BBIndex(1)],
- succ: vec![],
- },
- ]
- );
- }
-
- fn cfg_fig_19_4() -> Vec<BasicBlock> {
- vec![
- BasicBlock {
- start: StmtIndex(0),
- pred: vec![],
- succ: vec![BBIndex(1)],
- },
- BasicBlock {
- start: StmtIndex(3),
- pred: vec![BBIndex(0), BBIndex(5)],
- succ: vec![BBIndex(2), BBIndex(6)],
- },
- BasicBlock {
- start: StmtIndex(6),
- pred: vec![BBIndex(1)],
- succ: vec![BBIndex(3), BBIndex(4)],
- },
- BasicBlock {
- start: StmtIndex(9),
- pred: vec![BBIndex(2)],
- succ: vec![BBIndex(5)],
- },
- BasicBlock {
- start: StmtIndex(13),
- pred: vec![BBIndex(2)],
- succ: vec![BBIndex(5)],
- },
- BasicBlock {
- start: StmtIndex(16),
- pred: vec![BBIndex(3), BBIndex(4)],
- succ: vec![BBIndex(1)],
- },
- BasicBlock {
- start: StmtIndex(18),
- pred: vec![BBIndex(1)],
- succ: vec![],
- },
- ]
- }
-
- // cfg from 19.4 with slighlty shuffled order of succ/pred
- #[test]
- fn reverse_postorder_fig_19_4() {
- let mut cfg = cfg_fig_19_4();
- cfg[1].pred.swap(0, 1);
- cfg[2].succ.swap(0, 1);
- let rpostorder = vec![
- BBIndex(0),
- BBIndex(1),
- BBIndex(6),
- BBIndex(2),
- BBIndex(3),
- BBIndex(4),
- BBIndex(5),
- ];
- let doms = immediate_dominators(&cfg, &rpostorder);
- assert_eq!(
- doms,
- vec![
- BBIndex(0),
- BBIndex(0),
- BBIndex(1),
- BBIndex(2),
- BBIndex(2),
- BBIndex(2),
- BBIndex(1)
- ]
- );
- }
-
- #[test]
- fn dominance_frontiers_fig_19_4() {
- let cfg = cfg_fig_19_4();
- let order = to_reverse_postorder(&cfg);
- let doms = immediate_dominators(&cfg, &order);
- let dom_fronts = dominance_frontiers(&cfg, &doms)
- .into_iter()
- .map(|hs| hs.into_iter().collect::<Vec<_>>())
- .collect::<Vec<_>>();
- let should = vec![
- vec![],
- vec![BBIndex(1)],
- vec![BBIndex(1)],
- vec![BBIndex(5)],
- vec![BBIndex(5)],
- vec![BBIndex(1)],
- vec![],
- ];
- assert_eq!(dom_fronts, should);
- }
-
- #[test]
- fn gather_phi_sets_fig_19_4() {
- let func = FIG_19_4;
- let mut errors = Vec::new();
- let fn_ast = ptx::FunctionBodyParser::new()
- .parse(&mut errors, func)
- .unwrap();
- assert_eq!(errors.len(), 0);
- let mut constant_ids = HashMap::new();
- collect_label_ids(&mut constant_ids, &fn_ast);
- assert_eq!(constant_ids.len(), 4);
-
- let mut type_check = HashMap::new();
- let registers = collect_var_definitions(&[], &fn_ast);
- let (normalized_ids, mut unique_ids) =
- normalize_identifiers(fn_ast, &constant_ids, &mut type_check, registers);
- let type_check = RefCell::new(type_check);
- let new_id = &mut |typ: Option<ast::Type>| {
- let to_insert = unique_ids;
- {
- let mut type_check = type_check.borrow_mut();
- typ.map(|t| (*type_check).insert(to_insert, t));
- }
- unique_ids += 1;
- to_insert
- };
- let normalized_stmts = normalize_statements(normalized_ids, new_id);
- let bbs = get_basic_blocks(&normalized_stmts);
- let rpostorder = to_reverse_postorder(&bbs);
- let doms = immediate_dominators(&bbs, &rpostorder);
- let dom_fronts = dominance_frontiers(&bbs, &doms);
- let phi = gather_phi_sets(
- &normalized_stmts,
- constant_ids.len() as u32,
- unique_ids,
- &bbs,
- &dom_fronts,
- );
- assert_eq!(
- phi,
- vec![
- HashSet::new(),
- to_hashset(vec![5, 6]),
- HashSet::new(),
- HashSet::new(),
- HashSet::new(),
- to_hashset(vec![5, 6]),
- HashSet::new()
- ]
- );
- }
-
- fn to_hashset<T: std::hash::Hash + Eq>(v: Vec<T>) -> HashSet<T> {
- v.into_iter().collect::<HashSet<T>>()
- }
-
- #[test]
- fn ssa_rename_19_4() {
- let func = FIG_19_4;
- let mut errors = Vec::new();
- let fn_ast = ptx::FunctionBodyParser::new()
- .parse(&mut errors, func)
- .unwrap();
- assert_eq!(errors.len(), 0);
- let (func, _, mut ssa_phis, unique_ids) = to_ssa(&[], fn_ast);
- assert_phi_dst_id(unique_ids, &ssa_phis);
- assert_dst_unique(&func, &ssa_phis);
- sort_phi(&mut ssa_phis);
-
- let i1 = unique_ids;
- let j1 = unique_ids + 1;
- let j2 = get_dst_from_src(&ssa_phis[1], j1);
- let j3 = get_dst(&func[10]);
- let j4 = get_dst_from_src(&ssa_phis[5], j3);
- let j5 = get_dst(&func[14]);
- let k1 = unique_ids + 2;
- let k2 = get_dst_from_src(&ssa_phis[1], k1);
- let k3 = get_dst(&func[11]);
- let k4 = get_dst_from_src(&ssa_phis[5], k3);
- let k5 = get_dst(&func[15]);
- let p1 = get_dst(&func[4]);
- let q1 = get_dst(&func[7]);
- let block_2 = get_label(&func[3]);
- let if_false1 = get_label(&func[6]);
- let if_false2 = get_label(&func[9]);
- let block_6 = get_label(&func[13]);
- let block_7 = get_label(&func[16]);
- let block_4 = get_label(&func[18]);
-
- {
- assert_eq!(get_ids(&func[0]), vec![i1]);
- assert_eq!(get_ids(&func[1]), vec![j1]);
- assert_eq!(get_ids(&func[2]), vec![k1]);
-
- assert_eq!(
- ssa_phis[1],
- to_phi(vec![(j2, vec![j4, j1]), (k2, vec![k4, k1])])
- );
- assert_eq!(get_ids(&func[3]), vec![block_2]);
- assert_eq!(get_ids(&func[4]), vec![p1, k2]);
- assert_eq!(get_ids(&func[5]), vec![p1, block_4, if_false1]);
-
- assert_eq!(get_ids(&func[6]), vec![if_false1]);
- assert_eq!(get_ids(&func[7]), vec![q1, j2]);
- assert_eq!(get_ids(&func[8]), vec![q1, block_6, if_false2]);
-
- assert_eq!(get_ids(&func[9]), vec![if_false2]);
- assert_eq!(get_ids(&func[10]), vec![j3, i1]);
- assert_eq!(get_ids(&func[11]), vec![k3, k2]);
- assert_eq!(get_ids(&func[12]), vec![block_7]);
-
- assert_eq!(get_ids(&func[13]), vec![block_6]);
- assert_eq!(get_ids(&func[14]), vec![j5, k2]);
- assert_eq!(get_ids(&func[15]), vec![k5, k2]);
-
- assert_eq!(
- ssa_phis[5],
- to_phi(vec![(j4, vec![j3, j5]), (k4, vec![k3, k5])])
- );
- assert_eq!(get_ids(&func[16]), vec![block_7]);
- assert_eq!(get_ids(&func[17]), vec![block_2]);
-
- assert_eq!(get_ids(&func[18]), vec![block_4]);
- assert_eq!(get_ids(&func[19]), vec![]);
- }
- }
-
- fn assert_phi_dst_id(max_id: spirv::Word, phis: &[Vec<PhiDef>]) {
- for phi_set in phis {
- for phi in phi_set {
- assert!(phi.dst > max_id);
- }
- }
- }
-
- fn assert_dst_unique(func: &[Statement], phis: &[Vec<PhiDef>]) {
- let mut seen = HashSet::new();
- for s in func {
- s.visit_id(&mut |is_dst, id| {
- if is_dst {
- assert!(seen.insert(id));
- }
- });
- }
- for phi_set in phis {
- for phi in phi_set {
- assert!(seen.insert(phi.dst));
- }
- }
- }
-
- fn get_ids(s: &Statement) -> Vec<spirv::Word> {
- let mut result = Vec::new();
- s.visit_id(&mut |_, id| {
- result.push(id);
- });
- result
- }
-
- fn sort_phi(phis: &mut [Vec<PhiDef>]) {
- for phi_set in phis {
- phi_set.sort_by_key(|phi| phi.dst);
- }
- }
-
- fn to_phi(raw: Vec<(spirv::Word, Vec<spirv::Word>)>) -> Vec<PhiDef> {
- let result = raw
- .into_iter()
- .map(|(dst, src)| PhiDef {
- dst: dst,
- src: src.into_iter().collect::<HashSet<_>>(),
- })
- .collect::<Vec<_>>();
- let mut result = [result];
- sort_phi(&mut result);
- let [result] = result;
- result
- }
-
- fn get_dst(s: &Statement) -> spirv::Word {
- let mut result = None;
- s.visit_id(&mut |is_dst, id| {
- if is_dst {
- assert_eq!(result.replace(id), None);
- }
- });
- result.unwrap()
- }
-
- fn get_label(s: &Statement) -> spirv::Word {
- match s {
- Statement::Label(id) => *id,
- _ => panic!(),
- }
- }
-
- fn get_dst_from_src(phi: &[PhiDef], src: spirv::Word) -> spirv::Word {
- for phi_set in phi {
- if phi_set.src.contains(&src) {
- return phi_set.dst;
- }
- }
- panic!()
- }
-
static SCALAR_TYPES: [ast::ScalarType; 15] = [
ast::ScalarType::B8,
ast::ScalarType::B16,