aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorAndrzej Janik <[email protected]>2020-07-20 20:15:23 +0200
committerAndrzej Janik <[email protected]>2020-07-20 20:15:23 +0200
commit4e9a71ed3884e66db666b0413f5efd4ff9d97a3a (patch)
treeee639a15e97f46ee0123360b8aa2d24eb98c2217
parent872d69c714e647bab9192d6ae5105fe2638b4f77 (diff)
downloadZLUDA-4e9a71ed3884e66db666b0413f5efd4ff9d97a3a.tar.gz
ZLUDA-4e9a71ed3884e66db666b0413f5efd4ff9d97a3a.zip
Update type lookup map when emitting new instructions during translation
-rw-r--r--ptx/src/translate.rs137
1 files changed, 76 insertions, 61 deletions
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index 6620666..0d86066 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -180,11 +180,23 @@ fn to_ssa<'a>(
collect_arg_ids(&mut contant_ids, &mut type_check, &f_args);
collect_label_ids(&mut contant_ids, &f_body);
let registers = collect_var_definitions(&f_args, &f_body);
- let (normalized_ids, unique_ids) =
+ let (normalized_ids, mut unique_ids) =
normalize_identifiers(f_body, &contant_ids, &mut type_check, registers);
- let (normalized_stmts, unique_ids) = normalize_statements(normalized_ids, unique_ids);
- let (mut func_body, unique_ids) =
- insert_implicit_conversions(normalized_stmts, unique_ids, &|x| type_check[&x]);
+ let type_check = RefCell::new(type_check);
+ let new_id = &mut |typ: Option<ast::Type>| {
+ let to_insert = unique_ids;
+ {
+ let mut type_check = type_check.borrow_mut();
+ typ.map(|t| (*type_check).insert(to_insert, t));
+ }
+ unique_ids += 1;
+ to_insert
+ };
+ let normalized_stmts = normalize_statements(normalized_ids, new_id);
+ let mut func_body = insert_implicit_conversions(normalized_stmts, new_id, &|x| {
+ let type_check = type_check.borrow();
+ type_check[&x]
+ });
let bbs = get_basic_blocks(&func_body);
let rpostorder = to_reverse_postorder(&bbs);
let doms = immediate_dominators(&bbs, &rpostorder);
@@ -202,22 +214,16 @@ fn to_ssa<'a>(
fn normalize_statements(
func: Vec<ast::Statement<spirv::Word>>,
- unique_ids: spirv::Word,
-) -> (Vec<Statement>, spirv::Word) {
+ new_id: &mut impl FnMut(Option<ast::Type>) -> spirv::Word,
+) -> Vec<Statement> {
let mut result = Vec::with_capacity(func.len());
- let mut id = unique_ids;
- let new_id = &mut || {
- let to_insert = id;
- id += 1;
- to_insert
- };
for s in func {
match s {
ast::Statement::Label(id) => result.push(Statement::Label(id)),
ast::Statement::Instruction(pred, inst) => {
if let Some(pred) = pred {
- let mut if_true = new_id();
- let mut if_false = new_id();
+ let mut if_true = new_id(None);
+ let mut if_false = new_id(None);
if pred.not {
std::mem::swap(&mut if_true, &mut if_false);
}
@@ -245,13 +251,13 @@ fn normalize_statements(
ast::Statement::Variable(_) => unreachable!(),
}
}
- (result, id)
+ result
}
#[must_use]
fn normalize_insert_instruction(
func: &mut Vec<Statement>,
- new_id: &mut impl FnMut() -> spirv::Word,
+ new_id: &mut impl FnMut(Option<ast::Type>) -> spirv::Word,
instr: ast::Instruction<spirv::Word>,
) -> Instruction {
match instr {
@@ -302,7 +308,7 @@ fn normalize_insert_instruction(
fn normalize_expand_arg2(
func: &mut Vec<Statement>,
- new_id: &mut impl FnMut() -> spirv::Word,
+ new_id: &mut impl FnMut(Option<ast::Type>) -> spirv::Word,
inst_type: &impl Fn() -> Option<ast::ScalarType>,
a: ast::Arg2<spirv::Word>,
) -> Arg2 {
@@ -314,7 +320,7 @@ fn normalize_expand_arg2(
fn normalize_expand_arg2mov(
func: &mut Vec<Statement>,
- new_id: &mut impl FnMut() -> spirv::Word,
+ new_id: &mut impl FnMut(Option<ast::Type>) -> spirv::Word,
inst_type: &impl Fn() -> Option<ast::ScalarType>,
a: ast::Arg2Mov<spirv::Word>,
) -> Arg2 {
@@ -326,7 +332,7 @@ fn normalize_expand_arg2mov(
fn normalize_expand_arg2st(
func: &mut Vec<Statement>,
- new_id: &mut impl FnMut() -> spirv::Word,
+ new_id: &mut impl FnMut(Option<ast::Type>) -> spirv::Word,
inst_type: &impl Fn() -> Option<ast::ScalarType>,
a: ast::Arg2St<spirv::Word>,
) -> Arg2St {
@@ -338,7 +344,7 @@ fn normalize_expand_arg2st(
fn normalize_expand_arg3(
func: &mut Vec<Statement>,
- new_id: &mut impl FnMut() -> spirv::Word,
+ new_id: &mut impl FnMut(Option<ast::Type>) -> spirv::Word,
inst_type: &impl Fn() -> Option<ast::ScalarType>,
a: ast::Arg3<spirv::Word>,
) -> Arg3 {
@@ -351,7 +357,7 @@ fn normalize_expand_arg3(
fn normalize_expand_arg4(
func: &mut Vec<Statement>,
- new_id: &mut impl FnMut() -> spirv::Word,
+ new_id: &mut impl FnMut(Option<ast::Type>) -> spirv::Word,
inst_type: &impl Fn() -> Option<ast::ScalarType>,
a: ast::Arg4<spirv::Word>,
) -> Arg4 {
@@ -365,7 +371,7 @@ fn normalize_expand_arg4(
fn normalize_expand_arg5(
func: &mut Vec<Statement>,
- new_id: &mut impl FnMut() -> spirv::Word,
+ new_id: &mut impl FnMut(Option<ast::Type>) -> spirv::Word,
inst_type: &impl Fn() -> Option<ast::ScalarType>,
a: ast::Arg5<spirv::Word>,
) -> Arg5 {
@@ -380,7 +386,7 @@ fn normalize_expand_arg5(
fn normalize_expand_operand(
func: &mut Vec<Statement>,
- new_id: &mut impl FnMut() -> spirv::Word,
+ new_id: &mut impl FnMut(Option<ast::Type>) -> spirv::Word,
inst_type: &impl Fn() -> Option<ast::ScalarType>,
opr: ast::Operand<spirv::Word>,
) -> spirv::Word {
@@ -388,7 +394,7 @@ fn normalize_expand_operand(
ast::Operand::Reg(r) => r,
ast::Operand::Imm(x) => {
if let Some(typ) = inst_type() {
- let id = new_id();
+ let id = new_id(Some(ast::Type::Scalar(typ)));
func.push(Statement::Constant(ConstantDefinition {
dst: id,
typ: typ,
@@ -405,7 +411,7 @@ fn normalize_expand_operand(
fn normalize_expand_mov_operand(
func: &mut Vec<Statement>,
- new_id: &mut impl FnMut() -> spirv::Word,
+ new_id: &mut impl FnMut(Option<ast::Type>) -> spirv::Word,
inst_type: &impl Fn() -> Option<ast::ScalarType>,
opr: ast::MovOperand<spirv::Word>,
) -> spirv::Word {
@@ -456,15 +462,9 @@ fn collect_var_definitions<'a>(
*/
fn insert_implicit_conversions<TypeCheck: Fn(spirv::Word) -> ast::Type>(
normalized_ids: Vec<Statement>,
- unique_ids: spirv::Word,
+ new_id: &mut impl FnMut(Option<ast::Type>) -> spirv::Word,
type_check: &TypeCheck,
-) -> (Vec<Statement>, spirv::Word) {
- let mut id = unique_ids;
- let new_id = &mut || {
- let temp = id;
- id += 1;
- temp
- };
+) -> Vec<Statement> {
let mut result = Vec::with_capacity(normalized_ids.len());
for s in normalized_ids.into_iter() {
match s {
@@ -518,7 +518,7 @@ fn insert_implicit_conversions<TypeCheck: Fn(spirv::Word) -> ast::Type>(
Statement::Converison(_) => unreachable!(),
}
}
- (result, id)
+ result
}
fn get_function_type(
@@ -2007,14 +2007,11 @@ fn should_bitcast(instr: ast::Type, operand: ast::Type) -> bool {
}
}
-fn insert_implicit_conversions_ld_src<
- TypeCheck: Fn(spirv::Word) -> ast::Type,
- NewId: FnMut() -> spirv::Word,
->(
+fn insert_implicit_conversions_ld_src<TypeCheck: Fn(spirv::Word) -> ast::Type>(
func: &mut Vec<Statement>,
instr_type: ast::Type,
type_check: &TypeCheck,
- new_id: &mut NewId,
+ new_id: &mut impl FnMut(Option<ast::Type>) -> spirv::Word,
state_space: ast::LdStateSpace,
src: spirv::Word,
) -> spirv::Word {
@@ -2055,12 +2052,11 @@ fn insert_implicit_conversions_ld_src<
fn insert_implicit_conversions_ld_src_impl<
TypeCheck: Fn(spirv::Word) -> ast::Type,
- NewId: FnMut() -> spirv::Word,
ShouldConvert: FnOnce(ast::Type, ast::Type) -> Option<ConversionKind>,
>(
func: &mut Vec<Statement>,
type_check: &TypeCheck,
- new_id: &mut NewId,
+ new_id: &mut impl FnMut(Option<ast::Type>) -> spirv::Word,
instr_type: ast::Type,
src: spirv::Word,
should_convert: ShouldConvert,
@@ -2099,15 +2095,15 @@ fn should_convert_ld_generic_src_to_bitcast(
}
#[must_use]
-fn insert_conversion_src<NewId: FnMut() -> spirv::Word>(
+fn insert_conversion_src(
func: &mut Vec<Statement>,
- new_id: &mut NewId,
+ new_id: &mut impl FnMut(Option<ast::Type>) -> spirv::Word,
src: spirv::Word,
src_type: ast::Type,
instr_type: ast::Type,
conv: ConversionKind,
) -> spirv::Word {
- let temp_src = new_id();
+ let temp_src = new_id(Some(instr_type));
func.push(Statement::Converison(ImplicitConversion {
src: src,
dst: temp_src,
@@ -2121,7 +2117,6 @@ fn insert_conversion_src<NewId: FnMut() -> spirv::Word>(
fn insert_with_implicit_conversion_dst<
T,
TypeCheck: Fn(spirv::Word) -> ast::Type,
- NewId: FnMut() -> spirv::Word,
ShouldConvert: FnOnce(ast::Type, ast::ScalarType) -> Option<ConversionKind>,
Setter: Fn(&mut T) -> &mut spirv::Word,
ToInstruction: FnOnce(T) -> Instruction,
@@ -2129,7 +2124,7 @@ fn insert_with_implicit_conversion_dst<
func: &mut Vec<Statement>,
instr_type: ast::ScalarType,
type_check: &TypeCheck,
- new_id: &mut NewId,
+ new_id: &mut impl FnMut(Option<ast::Type>) -> spirv::Word,
should_convert: ShouldConvert,
mut t: T,
setter: Setter,
@@ -2146,15 +2141,15 @@ fn insert_with_implicit_conversion_dst<
}
#[must_use]
-fn get_conversion_dst<NewId: FnMut() -> spirv::Word>(
- new_id: &mut NewId,
+fn get_conversion_dst(
+ new_id: &mut impl FnMut(Option<ast::Type>) -> spirv::Word,
dst: &mut spirv::Word,
instr_type: ast::Type,
dst_type: ast::Type,
kind: ConversionKind,
) -> Statement {
let original_dst = *dst;
- let temp_dst = new_id();
+ let temp_dst = new_id(Some(instr_type));
*dst = temp_dst;
Statement::Converison(ImplicitConversion {
src: temp_dst,
@@ -2250,13 +2245,10 @@ fn should_convert_relaxed_dst(
}
}
-fn insert_implicit_bitcasts<
- TypeCheck: Fn(spirv::Word) -> ast::Type,
- NewId: FnMut() -> spirv::Word,
->(
+fn insert_implicit_bitcasts<TypeCheck: Fn(spirv::Word) -> ast::Type>(
func: &mut Vec<Statement>,
type_check: &TypeCheck,
- new_id: &mut NewId,
+ new_id: &mut impl FnMut(Option<ast::Type>) -> spirv::Word,
mut instr: Instruction,
) {
let mut dst_coercion = None;
@@ -2662,9 +2654,20 @@ mod tests {
let mut constant_ids = HashMap::new();
collect_label_ids(&mut constant_ids, &ast);
let registers = collect_var_definitions(&[], &ast);
- let (normalized_ids, unique_ids) =
- normalize_identifiers(ast, &constant_ids, &mut HashMap::new(), registers);
- let (normalized_stmts, _) = normalize_statements(normalized_ids, unique_ids);
+ let mut type_check = HashMap::new();
+ let (normalized_ids, mut unique_ids) =
+ normalize_identifiers(ast, &constant_ids, &mut type_check, registers);
+ let type_check = RefCell::new(type_check);
+ let new_id = &mut |typ: Option<ast::Type>| {
+ let to_insert = unique_ids;
+ {
+ let mut type_check = type_check.borrow_mut();
+ typ.map(|t| (*type_check).insert(to_insert, t));
+ }
+ unique_ids += 1;
+ to_insert
+ };
+ let normalized_stmts = normalize_statements(normalized_ids, new_id);
let mut bbs = get_basic_blocks(&normalized_stmts);
bbs.iter_mut().for_each(sort_pred_succ);
assert_eq!(
@@ -2811,10 +2814,22 @@ mod tests {
let mut constant_ids = HashMap::new();
collect_label_ids(&mut constant_ids, &fn_ast);
assert_eq!(constant_ids.len(), 4);
+
+ let mut type_check = HashMap::new();
let registers = collect_var_definitions(&[], &fn_ast);
- let (normalized_ids, unique_ids) =
- normalize_identifiers(fn_ast, &constant_ids, &mut HashMap::new(), registers);
- let (normalized_stmts, max_id) = normalize_statements(normalized_ids, unique_ids);
+ let (normalized_ids, mut unique_ids) =
+ normalize_identifiers(fn_ast, &constant_ids, &mut type_check, registers);
+ let type_check = RefCell::new(type_check);
+ let new_id = &mut |typ: Option<ast::Type>| {
+ let to_insert = unique_ids;
+ {
+ let mut type_check = type_check.borrow_mut();
+ typ.map(|t| (*type_check).insert(to_insert, t));
+ }
+ unique_ids += 1;
+ to_insert
+ };
+ let normalized_stmts = normalize_statements(normalized_ids, new_id);
let bbs = get_basic_blocks(&normalized_stmts);
let rpostorder = to_reverse_postorder(&bbs);
let doms = immediate_dominators(&bbs, &rpostorder);
@@ -2822,7 +2837,7 @@ mod tests {
let phi = gather_phi_sets(
&normalized_stmts,
constant_ids.len() as u32,
- max_id,
+ unique_ids,
&bbs,
&dom_fronts,
);