From 9f60990765301af9a359100b94192137f466d351 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Sun, 17 May 2020 18:45:22 +0200 Subject: Start introducing support for bitcast coercions in instructions --- ptx/src/ast.rs | 4 +- ptx/src/ptx.lalrpop | 23 +++--- ptx/src/translate.rs | 192 ++++++++++++++++++++++++++++++++++++++++++++++++--- 3 files changed, 201 insertions(+), 18 deletions(-) diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index f685b7d..ce9a596 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -238,7 +238,9 @@ pub struct MovData {} pub struct MulData {} -pub struct AddData {} +pub struct AddData { + pub typ: ScalarType, +} pub struct SetpData {} diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index 999d511..79290da 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -161,6 +161,7 @@ Variable: ast::Variable<&'input str> = { VariableName: (&'input str, Option) = { => (id, None), + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameterized-variable-names => { let left_angle = id.as_bytes().iter().copied().position(|x| x == b'<').unwrap(); let count = id[left_angle+1..id.len()-1].parse::(); @@ -270,9 +271,13 @@ RoundingMode = { ".rn", ".rz", ".rm", ".rp" }; -IntType = { - ".u16", ".u32", ".u64", - ".s16", ".s32", ".s64", +IntType : ast::ScalarType = { + ".u16" => ast::ScalarType::U16, + ".u32" => ast::ScalarType::U32, + ".u64" => ast::ScalarType::U64, + ".s16" => ast::ScalarType::S16, + ".s32" => ast::ScalarType::S32, + ".s64" => ast::ScalarType::S64, }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-add @@ -283,12 +288,12 @@ InstAdd: ast::Instruction<&'input str> = { }; InstAddMode: ast::AddData = { - IntType => ast::AddData{}, - ".sat" ".s32" => ast::AddData{}, - RoundingMode? ".ftz"? ".sat"? ".f32" => ast::AddData{}, - RoundingMode? ".f64" => ast::AddData{}, - ".rn"? ".ftz"? ".sat"? ".f16" => ast::AddData{}, - ".rn"? ".ftz"? ".sat"? ".f16x2" => ast::AddData{} + => ast::AddData{ typ: t }, + ".sat" ".s32" => ast::AddData{ typ: ast::ScalarType::S32 }, + RoundingMode? ".ftz"? ".sat"? ".f32" => ast::AddData{ typ: ast::ScalarType::F32 }, + RoundingMode? ".f64" => ast::AddData{ typ: ast::ScalarType::F64 }, + ".rn"? ".ftz"? ".sat"? ".f16" => ast::AddData{ typ: ast::ScalarType::F16 }, + ".rn"? ".ftz"? ".sat"? ".f16x2" => todo!() }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#comparison-and-selection-instructions-setp diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 63d7f7b..a186772 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -3,7 +3,7 @@ use bit_vec::BitVec; use rspirv::dr; use std::cell::RefCell; use std::collections::{BTreeMap, HashMap, HashSet}; -use std::fmt; +use std::{borrow::Cow, fmt}; use rspirv::binary::Assemble; @@ -125,13 +125,17 @@ fn emit_function<'a>( let mut contant_ids = HashMap::new(); collect_arg_ids(&mut contant_ids, &f.args); collect_label_ids(&mut contant_ids, &f.body); - let (mut normalized_ids, unique_ids) = normalize_identifiers(f.body, &contant_ids); - let bbs = get_basic_blocks(&normalized_ids); + let registers = collect_registers(&f.body); + let (normalized_ids, unique_ids, type_check) = + normalize_identifiers(f.body, &contant_ids, registers); + let (mut func_body, unique_ids) = + insert_implicit_conversion(normalized_ids, unique_ids, &|x| type_check[&x]); + let bbs = get_basic_blocks(&func_body); let rpostorder = to_reverse_postorder(&bbs); let doms = immediate_dominators(&bbs, &rpostorder); let dom_fronts = dominance_frontiers(&bbs, &doms); let (_, unique_ids) = ssa_legalize( - &mut normalized_ids, + &mut func_body, contant_ids.len() as u32, unique_ids, &bbs, @@ -140,11 +144,70 @@ fn emit_function<'a>( ); let id_offset = builder.reserve_ids(unique_ids); emit_function_args(builder, id_offset, map, &f.args); - emit_function_body_ops(builder, id_offset, map, &normalized_ids, &bbs)?; + emit_function_body_ops(builder, id_offset, map, &func_body, &bbs)?; builder.end_function()?; Ok(func_id) } +fn collect_registers<'a>(body: &[ast::Statement<&'a str>]) -> HashMap, ast::Type> { + let mut result = HashMap::new(); + for s in body { + match s { + ast::Statement::Variable(var) => match var.count { + Some(count) => { + for i in 0..count { + result.insert(Cow::Owned(format!("{}{}", var.name, i)), var.v_type); + } + } + None => { + result.insert(Cow::Borrowed(var.name), var.v_type); + } + }, + ast::Statement::Label(_) | ast::Statement::Instruction(_, _) => (), + } + } + result +} + +/* + There are three kinds of implicit conversions in PTX: + * auto-bitcast: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#type-information-for-instructions-and-operands + * special ld/st/cvt conversion rules: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size + * pointer dereference in st/ld: not documented, but for instruction `ld.. x, [y]` semantics are x = *(*)y +*/ +fn insert_implicit_conversion ast::Type>( + normalized_ids: Vec, + unique_ids: spirv::Word, + type_check: &TypeCheck, +) -> (Vec, spirv::Word) { + let mut id = unique_ids; + let new_id = &mut || { + let temp = id; + id += 1; + temp + }; + let mut result = Vec::with_capacity(normalized_ids.len()); + for s in normalized_ids.into_iter() { + match s { + Statement::Instruction(inst) => match inst { + ast::Instruction::Add(add, arg) => { + arg.insert_implicit_conversions( + &mut result, + ast::Type::Scalar(add.typ), + type_check, + new_id, + |arg| Statement::Instruction(ast::Instruction::Add(add, arg)), + ); + } + _ => todo!(), + }, + s @ Statement::Conditional(_) | s @ Statement::Label(_) => result.push(s), + Statement::Converison(_) => unreachable!(), + } + } + (result, id) +} + fn get_function_type( builder: &mut dr::Builder, map: &mut TypeWordMap, @@ -223,6 +286,7 @@ fn emit_function_body_ops( // If block startd with a label it has already been emitted, // all other labels in the block are unused Statement::Label(_) => (), + Statement::Converison(_) => todo!(), Statement::Conditional(bra) => { builder.branch_conditional(bra.predicate, bra.if_true, bra.if_false, [])?; } @@ -300,7 +364,8 @@ fn emit_function_body_ops( fn normalize_identifiers<'a>( func: Vec>, constant_identifiers: &HashMap<&'a str, spirv::Word>, // arguments and labels can't be redefined -) -> (Vec, spirv::Word) { + types: HashMap, ast::Type>, +) -> (Vec, spirv::Word, HashMap) { let mut result = Vec::with_capacity(func.len()); let mut id: u32 = constant_identifiers.len() as u32; let mut remapped_ids = HashMap::new(); @@ -324,7 +389,11 @@ fn normalize_identifiers<'a>( for s in func { Statement::from_ast(s, &mut result, &mut get_or_add); } - (result, id) + let mut type_map = HashMap::with_capacity(types.len()); + for (old_id, new_id) in remapped_ids { + type_map.insert(new_id, types[old_id]); + } + (result, id, type_map) } fn ssa_legalize( @@ -580,6 +649,7 @@ fn gather_phi_sets( match s { Statement::Instruction(inst) => inst.visit_id(&mut visitor), Statement::Conditional(brc) => visitor(false, &brc.predicate), + Statement::Converison(conv) => conv.visit_id(&mut visitor), // label redefinition is a compile-time error Statement::Label(_) => (), } @@ -630,6 +700,7 @@ fn get_basic_blocks(fun: &[Statement]) -> Vec { unresolved_bb_edge.push((StmtIndex(idx), bra.if_false)); unresolved_bb_edge.push((StmtIndex(idx), bra.if_true)); } + Statement::Converison(_) => (), }; } let mut bb_edge = HashSet::new(); @@ -647,7 +718,7 @@ fn get_basic_blocks(fun: &[Statement]) -> Vec { bb_edge.insert((StmtIndex(target.0 - 1), target)); } } - Statement::Label(_) => { + Statement::Converison(_) | Statement::Label(_) => { bb_edge.insert((StmtIndex(target.0 - 1), target)); } // This is already in `unresolved_bb_edge` @@ -816,6 +887,7 @@ enum Statement { Instruction(ast::Instruction), // SPIR-V compatible replacement for PTX predicates Conditional(BrachCondition), + Converison(ImplicitConversion), } struct BrachCondition { @@ -823,6 +895,7 @@ struct BrachCondition { if_true: spirv::Word, if_false: spirv::Word, } + impl BrachCondition { fn visit_id(&self, f: &mut F) { f(false, &self.predicate); @@ -837,6 +910,25 @@ impl BrachCondition { } } +struct ImplicitConversion { + dst: spirv::Word, + src: spirv::Word, + from: ast::Type, + to: ast::Type, +} + +impl ImplicitConversion { + fn visit_id(&self, f: &mut F) { + f(false, &self.src); + f(true, &self.dst); + } + + fn visit_id_mut(&mut self, f: &mut F) { + f(false, &mut self.src); + f(true, &mut self.dst); + } +} + impl Statement { fn from_ast<'a, F: FnMut(Option<&'a str>) -> u32>( s: ast::Statement<&'a str>, @@ -885,6 +977,7 @@ impl Statement { Statement::Label(id) => f(false, id), Statement::Instruction(inst) => inst.visit_id(f), Statement::Conditional(bra) => bra.visit_id(f), + Statement::Converison(conv) => conv.visit_id(f), } } @@ -895,6 +988,7 @@ impl Statement { Statement::Label(id) => f(false, id), Statement::Instruction(inst) => inst.visit_id_mut(f), Statement::Conditional(bra) => bra.visit_id_mut(f), + Statement::Converison(conv) => conv.visit_id_mut(f), } } } @@ -1068,6 +1162,31 @@ impl ast::Arg3 { } } +impl ast::Arg3 { + fn insert_implicit_conversions< + TypeCheck: Fn(spirv::Word) -> ast::Type, + NewId: FnMut() -> spirv::Word, + NewStatement: FnOnce(Self) -> Statement, + >( + self, + func: &mut Vec, + op_type: ast::Type, + type_check: &TypeCheck, + new_id: &mut NewId, + new_statement: NewStatement, + ) { + let src1 = self + .src1 + .insert_implicit_conversion(func, op_type, type_check, new_id); + let src2 = self + .src2 + .insert_implicit_conversion(func, op_type, type_check, new_id); + insert_implicit_conversion_dst(func, op_type, type_check, new_id, self.dst, |dst| { + new_statement(Self { dst, src1, src2 }) + }); + } +} + impl ast::Arg4 { fn map_id U>(self, f: &mut F) -> ast::Arg4 { ast::Arg4 { @@ -1147,6 +1266,37 @@ impl ast::Operand { } } +impl ast::Operand { + fn insert_implicit_conversion< + TypeCheck: Fn(spirv::Word) -> ast::Type, + NewId: FnMut() -> spirv::Word, + >( + self, + func: &mut Vec, + op_type: ast::Type, + type_check: &TypeCheck, + new_id: &mut NewId, + ) -> Self { + match self { + ast::Operand::Reg(src) => { + if type_check(src) == op_type { + return self; + } + let new_src = new_id(); + func.push(Statement::Converison(ImplicitConversion { + src: src, + dst: new_src, + from: type_check(src), + to: op_type, + })); + ast::Operand::Reg(new_src) + } + o @ ast::Operand::Imm(_) => o, + ast::Operand::RegOffset(_, _) => todo!(), + } + } +} + impl ast::MovOperand { fn map_id U>(self, f: &mut F) -> ast::MovOperand { match self { @@ -1170,6 +1320,32 @@ impl ast::MovOperand { } } +fn insert_implicit_conversion_dst< + TypeCheck: Fn(spirv::Word) -> ast::Type, + NewId: FnMut() -> spirv::Word, + NewStatement: FnOnce(spirv::Word) -> Statement, +>( + func: &mut Vec, + op_type: ast::Type, + type_check: &TypeCheck, + new_id: &mut NewId, + dst: spirv::Word, + new_statement: NewStatement, +) { + if type_check(dst) == op_type { + func.push(new_statement(dst)); + } else { + let new_dst = new_id(); + func.push(new_statement(new_dst)); + func.push(Statement::Converison(ImplicitConversion { + src: new_dst, + dst: dst, + from: type_check(new_dst), + to: op_type, + })); + } +} + // CFGs below taken from "Modern Compiler Implementation in Java" #[cfg(test)] mod tests { -- cgit v1.2.3