From 7bd4179d1dd24f81b56e66fd13c16631b518495f Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Sun, 22 Sep 2024 19:47:08 +0200 Subject: Add more passes --- ptx/Cargo.toml | 2 + ptx/src/pass/deparamize_functions.rs | 141 ++++++++++++++ ptx/src/pass/emit_llvm.rs | 1 + ptx/src/pass/emit_spirv.rs | 1 + ptx/src/pass/expand_operands.rs | 289 ++++++++++++++++++++++++++++ ptx/src/pass/fix_special_registers2.rs | 209 ++++++++++++++++++++ ptx/src/pass/insert_explicit_load_store.rs | 273 ++++++++++++++++++++++++++ ptx/src/pass/insert_implicit_conversions.rs | 7 + ptx/src/pass/mod.rs | 272 +++++++++++++++++++++++++- ptx/src/pass/normalize_identifiers2.rs | 111 ++--------- ptx/src/pass/normalize_labels.rs | 1 + ptx/src/pass/normalize_predicates2.rs | 4 +- ptx/src/pass/resolve_function_pointers.rs | 2 +- 13 files changed, 1208 insertions(+), 105 deletions(-) create mode 100644 ptx/src/pass/deparamize_functions.rs create mode 100644 ptx/src/pass/expand_operands.rs create mode 100644 ptx/src/pass/fix_special_registers2.rs create mode 100644 ptx/src/pass/insert_explicit_load_store.rs diff --git a/ptx/Cargo.toml b/ptx/Cargo.toml index fd86f15..e2c4ff8 100644 --- a/ptx/Cargo.toml +++ b/ptx/Cargo.toml @@ -18,6 +18,8 @@ bit-vec = "0.6" half ="1.6" bitflags = "1.2" rustc-hash = "2.0.0" +strum = "0.26" +strum_macros = "0.26" [dependencies.lalrpop-util] version = "0.19.12" diff --git a/ptx/src/pass/deparamize_functions.rs b/ptx/src/pass/deparamize_functions.rs new file mode 100644 index 0000000..04c8831 --- /dev/null +++ b/ptx/src/pass/deparamize_functions.rs @@ -0,0 +1,141 @@ +use std::collections::BTreeMap; + +use super::*; + +pub(super) fn run<'a, 'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + directives: Vec, SpirvWord>>, +) -> Result, SpirvWord>>, TranslateError> { + directives + .into_iter() + .map(|directive| run_directive(resolver, directive)) + .collect::, _>>() +} + +fn run_directive<'input>( + resolver: &mut GlobalStringIdentResolver2, + directive: Directive2<'input, ast::Instruction, SpirvWord>, +) -> Result, SpirvWord>, TranslateError> { + Ok(match directive { + var @ Directive2::Variable(..) => var, + Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?), + }) +} + +fn run_method<'input>( + resolver: &mut GlobalStringIdentResolver2, + mut method: Function2<'input, ast::Instruction, SpirvWord>, +) -> Result, SpirvWord>, TranslateError> { + if method.func_decl.name.is_kernel() { + return Ok(method); + } + let is_declaration = method.body.is_none(); + let mut body = Vec::new(); + let mut remap_returns = Vec::new(); + for arg in method.func_decl.return_arguments.iter_mut() { + match arg.state_space { + ptx_parser::StateSpace::Param => { + arg.state_space = ptx_parser::StateSpace::Reg; + let old_name = arg.name; + arg.name = resolver.register_unnamed(Some((arg.v_type.clone(), arg.state_space))); + if is_declaration { + continue; + } + remap_returns.push((old_name, arg.name, arg.v_type.clone())); + body.push(Statement::Variable(ast::Variable { + align: None, + name: old_name, + v_type: arg.v_type.clone(), + state_space: ptx_parser::StateSpace::Param, + array_init: Vec::new(), + })); + } + ptx_parser::StateSpace::Reg => {} + _ => return Err(error_unreachable()), + } + } + for arg in method.func_decl.input_arguments.iter_mut() { + match arg.state_space { + ptx_parser::StateSpace::Param => { + arg.state_space = ptx_parser::StateSpace::Reg; + let old_name = arg.name; + arg.name = resolver.register_unnamed(Some((arg.v_type.clone(), arg.state_space))); + if is_declaration { + continue; + } + body.push(Statement::Variable(ast::Variable { + align: None, + name: old_name, + v_type: arg.v_type.clone(), + state_space: ptx_parser::StateSpace::Param, + array_init: Vec::new(), + })); + body.push(Statement::Instruction(ast::Instruction::St { + data: ast::StData { + qualifier: ast::LdStQualifier::Weak, + state_space: ast::StateSpace::Param, + caching: ast::StCacheOperator::Writethrough, + typ: arg.v_type.clone(), + }, + arguments: ast::StArgs { + src1: old_name, + src2: arg.name, + }, + })); + } + ptx_parser::StateSpace::Reg => {} + _ => return Err(error_unreachable()), + } + } + if remap_returns.is_empty() { + return Ok(method); + } + let body = method + .body + .map(|statements| { + for statement in statements { + run_statement(&remap_returns, &mut body, statement)?; + } + Ok::<_, TranslateError>(body) + }) + .transpose()?; + Ok(Function2 { + func_decl: method.func_decl, + globals: method.globals, + body, + import_as: method.import_as, + tuning: method.tuning, + linkage: method.linkage, + }) +} + +fn run_statement<'input>( + remap_returns: &Vec<(SpirvWord, SpirvWord, ast::Type)>, + result: &mut Vec, SpirvWord>>, + statement: Statement, SpirvWord>, +) -> Result<(), TranslateError> { + match statement { + Statement::Instruction(ast::Instruction::Ret { .. }) => { + for (old_name, new_name, type_) in remap_returns.iter().cloned() { + result.push(Statement::Instruction(ast::Instruction::Ld { + data: ast::LdDetails { + qualifier: ast::LdStQualifier::Weak, + state_space: ast::StateSpace::Reg, + caching: ast::LdCacheOperator::Cached, + typ: type_, + non_coherent: false, + }, + arguments: ast::LdArgs { + dst: new_name, + src: old_name, + }, + })); + } + result.push(statement); + } + statement => { + result.push(statement); + } + } + Ok(()) +} diff --git a/ptx/src/pass/emit_llvm.rs b/ptx/src/pass/emit_llvm.rs index de85efc..3060335 100644 --- a/ptx/src/pass/emit_llvm.rs +++ b/ptx/src/pass/emit_llvm.rs @@ -308,6 +308,7 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { Statement::PtrAccess(_) => todo!(), Statement::RepackVector(_) => todo!(), Statement::FunctionPointer(_) => todo!(), + Statement::VectorAccess(_) => todo!(), }) } diff --git a/ptx/src/pass/emit_spirv.rs b/ptx/src/pass/emit_spirv.rs index ae4dcfe..120a477 100644 --- a/ptx/src/pass/emit_spirv.rs +++ b/ptx/src/pass/emit_spirv.rs @@ -1561,6 +1561,7 @@ fn emit_function_body_ops<'input>( builder.copy_object(vector_type.0, Some(repack.packed.0), temp_vec)?; } } + Statement::VectorAccess(vector_access) => todo!(), } } Ok(()) diff --git a/ptx/src/pass/expand_operands.rs b/ptx/src/pass/expand_operands.rs new file mode 100644 index 0000000..3dabf40 --- /dev/null +++ b/ptx/src/pass/expand_operands.rs @@ -0,0 +1,289 @@ +use super::*; + +pub(super) fn run<'a, 'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + directives: Vec>, +) -> Result, SpirvWord>>, TranslateError> { + directives + .into_iter() + .map(|directive| run_directive(resolver, directive)) + .collect::, _>>() +} + +fn run_directive<'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + directive: Directive2< + 'input, + ast::Instruction>, + ast::ParsedOperand, + >, +) -> Result, SpirvWord>, TranslateError> { + Ok(match directive { + Directive2::Variable(linking, var) => Directive2::Variable(linking, var), + Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?), + }) +} + +fn run_method<'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + method: Function2< + 'input, + ast::Instruction>, + ast::ParsedOperand, + >, +) -> Result, SpirvWord>, TranslateError> { + let body = method + .body + .map(|statements| { + let mut result = Vec::with_capacity(statements.len()); + for statement in statements { + run_statement(resolver, &mut result, statement)?; + } + Ok::<_, TranslateError>(result) + }) + .transpose()?; + Ok(Function2 { + func_decl: method.func_decl, + globals: method.globals, + body, + import_as: method.import_as, + tuning: method.tuning, + linkage: method.linkage, + }) +} + +fn run_statement<'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + result: &mut Vec, SpirvWord>>, + statement: UnconditionalStatement, +) -> Result<(), TranslateError> { + let mut visitor = FlattenArguments::new(resolver, result); + let new_statement = statement.visit_map(&mut visitor)?; + visitor.result.push(new_statement); + Ok(()) +} + +struct FlattenArguments<'a, 'input> { + result: &'a mut Vec, + resolver: &'a mut GlobalStringIdentResolver2<'input>, + post_stmts: Vec, +} + +impl<'a, 'input> FlattenArguments<'a, 'input> { + fn new( + resolver: &'a mut GlobalStringIdentResolver2<'input>, + result: &'a mut Vec, + ) -> Self { + FlattenArguments { + result, + resolver, + post_stmts: Vec::new(), + } + } + + fn reg(&mut self, name: SpirvWord) -> Result { + Ok(name) + } + + fn reg_offset( + &mut self, + reg: SpirvWord, + offset: i32, + type_space: Option<(&ast::Type, ast::StateSpace)>, + _is_dst: bool, + ) -> Result { + let (type_, state_space) = if let Some((type_, state_space)) = type_space { + (type_, state_space) + } else { + return Err(TranslateError::UntypedSymbol); + }; + if state_space == ast::StateSpace::Reg { + let (reg_type, reg_space) = self.resolver.get_typed(reg)?; + if *reg_space != ast::StateSpace::Reg { + return Err(error_mismatched_type()); + } + let reg_scalar_type = match reg_type { + ast::Type::Scalar(underlying_type) => *underlying_type, + _ => return Err(error_mismatched_type()), + }; + let reg_type = reg_type.clone(); + let id_constant_stmt = self + .resolver + .register_unnamed(Some((reg_type.clone(), ast::StateSpace::Reg))); + self.result.push(Statement::Constant(ConstantDefinition { + dst: id_constant_stmt, + typ: reg_scalar_type, + value: ast::ImmediateValue::S64(offset as i64), + })); + let arith_details = match reg_scalar_type.kind() { + ast::ScalarKind::Signed => ast::ArithDetails::Integer(ast::ArithInteger { + type_: reg_scalar_type, + saturate: false, + }), + ast::ScalarKind::Unsigned | ast::ScalarKind::Bit => { + ast::ArithDetails::Integer(ast::ArithInteger { + type_: reg_scalar_type, + saturate: false, + }) + } + _ => return Err(error_unreachable()), + }; + let id_add_result = self + .resolver + .register_unnamed(Some((reg_type, state_space))); + self.result + .push(Statement::Instruction(ast::Instruction::Add { + data: arith_details, + arguments: ast::AddArgs { + dst: id_add_result, + src1: reg, + src2: id_constant_stmt, + }, + })); + Ok(id_add_result) + } else { + let id_constant_stmt = self.resolver.register_unnamed(Some(( + ast::Type::Scalar(ast::ScalarType::S64), + ast::StateSpace::Reg, + ))); + self.result.push(Statement::Constant(ConstantDefinition { + dst: id_constant_stmt, + typ: ast::ScalarType::S64, + value: ast::ImmediateValue::S64(offset as i64), + })); + let dst = self + .resolver + .register_unnamed(Some((type_.clone(), state_space))); + self.result.push(Statement::PtrAccess(PtrAccess { + underlying_type: type_.clone(), + state_space: state_space, + dst, + ptr_src: reg, + offset_src: id_constant_stmt, + })); + Ok(dst) + } + } + + fn immediate( + &mut self, + value: ast::ImmediateValue, + type_space: Option<(&ast::Type, ast::StateSpace)>, + ) -> Result { + let (scalar_t, state_space) = + if let Some((ast::Type::Scalar(scalar), state_space)) = type_space { + (*scalar, state_space) + } else { + return Err(TranslateError::UntypedSymbol); + }; + let id = self + .resolver + .register_unnamed(Some((ast::Type::Scalar(scalar_t), state_space))); + self.result.push(Statement::Constant(ConstantDefinition { + dst: id, + typ: scalar_t, + value, + })); + Ok(id) + } + + fn vec_member( + &mut self, + vector_src: SpirvWord, + member: u8, + _type_space: Option<(&ast::Type, ast::StateSpace)>, + is_dst: bool, + ) -> Result { + if is_dst { + return Err(error_mismatched_type()); + } + let (vector_width, scalar_type, space) = match self.resolver.get_typed(vector_src)? { + (ast::Type::Vector(vector_width, scalar_t), space) => { + (*vector_width, *scalar_t, *space) + } + _ => return Err(error_mismatched_type()), + }; + let temporary = self + .resolver + .register_unnamed(Some((scalar_type.into(), space))); + self.result.push(Statement::VectorAccess(VectorAccess { + scalar_type, + vector_width, + dst: temporary, + src: vector_src, + member: member, + })); + Ok(temporary) + } + + fn vec_pack( + &mut self, + vecs: Vec, + type_space: Option<(&ast::Type, ast::StateSpace)>, + is_dst: bool, + relaxed_type_check: bool, + ) -> Result { + let (scalar_t, state_space) = match type_space { + Some((ast::Type::Vector(_, scalar_t), space)) => (*scalar_t, space), + _ => return Err(error_mismatched_type()), + }; + let temp_vec = self + .resolver + .register_unnamed(Some((scalar_t.into(), state_space))); + let statement = Statement::RepackVector(RepackVectorDetails { + is_extract: is_dst, + typ: scalar_t, + packed: temp_vec, + unpacked: vecs, + relaxed_type_check, + }); + if is_dst { + self.post_stmts.push(statement); + } else { + self.result.push(statement); + } + Ok(temp_vec) + } +} + +impl<'a, 'b> ast::VisitorMap, SpirvWord, TranslateError> + for FlattenArguments<'a, 'b> +{ + fn visit( + &mut self, + args: ast::ParsedOperand, + type_space: Option<(&ast::Type, ast::StateSpace)>, + is_dst: bool, + relaxed_type_check: bool, + ) -> Result { + match args { + ast::ParsedOperand::Reg(r) => self.reg(r), + ast::ParsedOperand::Imm(x) => self.immediate(x, type_space), + ast::ParsedOperand::RegOffset(reg, offset) => { + self.reg_offset(reg, offset, type_space, is_dst) + } + ast::ParsedOperand::VecMember(vec, member) => { + self.vec_member(vec, member, type_space, is_dst) + } + ast::ParsedOperand::VecPack(vecs) => { + self.vec_pack(vecs, type_space, is_dst, relaxed_type_check) + } + } + } + + fn visit_ident( + &mut self, + name: ::Ident, + _type_space: Option<(&ast::Type, ast::StateSpace)>, + _is_dst: bool, + _relaxed_type_check: bool, + ) -> Result<::Ident, TranslateError> { + self.reg(name) + } +} + +impl Drop for FlattenArguments<'_, '_> { + fn drop(&mut self) { + self.result.extend(self.post_stmts.drain(..)); + } +} diff --git a/ptx/src/pass/fix_special_registers2.rs b/ptx/src/pass/fix_special_registers2.rs new file mode 100644 index 0000000..97f6356 --- /dev/null +++ b/ptx/src/pass/fix_special_registers2.rs @@ -0,0 +1,209 @@ +use super::*; + +pub(super) fn run<'a, 'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + special_registers: &'a SpecialRegistersMap2, + directives: Vec>, +) -> Result>, TranslateError> { + let declarations = SpecialRegistersMap2::generate_declarations(resolver); + let mut result = Vec::with_capacity(declarations.len() + directives.len()); + let mut sreg_to_function = + FxHashMap::with_capacity_and_hasher(declarations.len(), Default::default()); + for (sreg, declaration) in declarations { + let name = if let ast::MethodName::Func(name) = declaration.name { + name + } else { + return Err(error_unreachable()); + }; + result.push(UnconditionalDirective::Method(UnconditionalFunction { + func_decl: declaration, + globals: Vec::new(), + body: None, + import_as: None, + tuning: Vec::new(), + linkage: ast::LinkingDirective::EXTERN, + })); + sreg_to_function.insert(sreg, name); + } + let mut visitor = SpecialRegisterResolver { + resolver, + special_registers, + sreg_to_function, + result: Vec::new(), + }; + directives + .into_iter() + .map(|directive| run_directive(&mut visitor, directive)) + .collect::, _>>() +} + +fn run_directive<'a, 'input>( + visitor: &mut SpecialRegisterResolver<'a, 'input>, + directive: UnconditionalDirective<'input>, +) -> Result, TranslateError> { + Ok(match directive { + var @ Directive2::Variable(..) => var, + Directive2::Method(method) => Directive2::Method(run_method(visitor, method)?), + }) +} + +fn run_method<'a, 'input>( + visitor: &mut SpecialRegisterResolver<'a, 'input>, + method: UnconditionalFunction<'input>, +) -> Result, TranslateError> { + let body = method + .body + .map(|statements| { + let mut result = Vec::with_capacity(statements.len()); + for statement in statements { + run_statement(visitor, &mut result, statement)?; + } + Ok::<_, TranslateError>(result) + }) + .transpose()?; + Ok(Function2 { + func_decl: method.func_decl, + globals: method.globals, + body, + import_as: method.import_as, + tuning: method.tuning, + linkage: method.linkage, + }) +} + +fn run_statement<'a, 'input>( + visitor: &mut SpecialRegisterResolver<'a, 'input>, + result: &mut Vec, + statement: UnconditionalStatement, +) -> Result<(), TranslateError> { + let converted_statement = statement.visit_map(visitor)?; + result.extend(visitor.result.drain(..)); + result.push(converted_statement); + Ok(()) +} + +struct SpecialRegisterResolver<'a, 'input> { + resolver: &'a mut GlobalStringIdentResolver2<'input>, + special_registers: &'a SpecialRegistersMap2, + sreg_to_function: FxHashMap, + result: Vec, +} + +impl<'a, 'b, 'input> + ast::VisitorMap, ast::ParsedOperand, TranslateError> + for SpecialRegisterResolver<'a, 'input> +{ + fn visit( + &mut self, + operand: ast::ParsedOperand, + _type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>, + is_dst: bool, + _relaxed_type_check: bool, + ) -> Result, TranslateError> { + map_operand(operand, &mut |ident, vector_index| { + self.replace_sreg(ident, vector_index, is_dst) + }) + } + + fn visit_ident( + &mut self, + args: SpirvWord, + _type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>, + is_dst: bool, + _relaxed_type_check: bool, + ) -> Result { + self.replace_sreg(args, None, is_dst) + } +} + +impl<'a, 'b, 'input> SpecialRegisterResolver<'a, 'input> { + fn replace_sreg( + &mut self, + name: SpirvWord, + vector_index: Option, + is_dst: bool, + ) -> Result { + if let Some(sreg) = self.special_registers.get(name) { + if is_dst { + return Err(error_mismatched_type()); + } + let input_arguments = match (vector_index, sreg.get_function_input_type()) { + (Some(idx), Some(inp_type)) => { + if inp_type != ast::ScalarType::U8 { + return Err(TranslateError::Unreachable); + } + let constant = self.resolver.register_unnamed(Some(( + ast::Type::Scalar(inp_type), + ast::StateSpace::Reg, + ))); + self.result.push(Statement::Constant(ConstantDefinition { + dst: constant, + typ: inp_type, + value: ast::ImmediateValue::U64(idx as u64), + })); + vec![(constant, ast::Type::Scalar(inp_type), ast::StateSpace::Reg)] + } + (None, None) => Vec::new(), + _ => return Err(error_mismatched_type()), + }; + let return_type = sreg.get_function_return_type(); + let fn_result = self + .resolver + .register_unnamed(Some((ast::Type::Scalar(return_type), ast::StateSpace::Reg))); + let return_arguments = vec![( + fn_result, + ast::Type::Scalar(return_type), + ast::StateSpace::Reg, + )]; + let data = ast::CallDetails { + uniform: false, + return_arguments: return_arguments + .iter() + .map(|(_, typ, space)| (typ.clone(), *space)) + .collect(), + input_arguments: input_arguments + .iter() + .map(|(_, typ, space)| (typ.clone(), *space)) + .collect(), + }; + let arguments = ast::CallArgs::> { + return_arguments: return_arguments.iter().map(|(name, _, _)| *name).collect(), + func: self.sreg_to_function[&sreg], + input_arguments: input_arguments + .iter() + .map(|(name, _, _)| ast::ParsedOperand::Reg(*name)) + .collect(), + }; + self.result + .push(Statement::Instruction(ast::Instruction::Call { + data, + arguments, + })); + Ok(fn_result) + } else { + Ok(name) + } + } +} + +pub fn map_operand( + this: ast::ParsedOperand, + fn_: &mut impl FnMut(T, Option) -> Result, +) -> Result, Err> { + Ok(match this { + ast::ParsedOperand::Reg(ident) => ast::ParsedOperand::Reg(fn_(ident, None)?), + ast::ParsedOperand::RegOffset(ident, offset) => { + ast::ParsedOperand::RegOffset(fn_(ident, None)?, offset) + } + ast::ParsedOperand::Imm(imm) => ast::ParsedOperand::Imm(imm), + ast::ParsedOperand::VecMember(ident, member) => { + ast::ParsedOperand::Reg(fn_(ident, Some(member))?) + } + ast::ParsedOperand::VecPack(idents) => ast::ParsedOperand::VecPack( + idents + .into_iter() + .map(|ident| fn_(ident, None)) + .collect::, _>>()?, + ), + }) +} diff --git a/ptx/src/pass/insert_explicit_load_store.rs b/ptx/src/pass/insert_explicit_load_store.rs new file mode 100644 index 0000000..e8f01cd --- /dev/null +++ b/ptx/src/pass/insert_explicit_load_store.rs @@ -0,0 +1,273 @@ +use super::*; +use ptx_parser::VisitorMap; +use rustc_hash::FxHashSet; + +// This pass: +// * Turns all .local, .param and .reg in-body variables into .local variables +// (if _not_ an input method argument) +// * Inserts explicit `ld`/`st` for newly converted .reg variables +// * Fixup state space of all existing `ld`/`st` instructions into newly +// converted variables +// * Turns `.entry` input arguments into param::entry and all related `.param` +// loads into `param::entry` loads +// * All `.func` input arguments are turned into `.reg` arguments by another +// pass, so we do nothing there +pub(super) fn run<'a, 'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + directives: Vec, SpirvWord>>, +) -> Result, SpirvWord>>, TranslateError> { + directives + .into_iter() + .map(|directive| run_directive(resolver, directive)) + .collect::, _>>() +} + +fn run_directive<'a, 'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + directive: Directive2<'input, ast::Instruction, SpirvWord>, +) -> Result, SpirvWord>, TranslateError> { + Ok(match directive { + var @ Directive2::Variable(..) => var, + Directive2::Method(method) => { + let visitor = InsertMemSSAVisitor::new(resolver); + Directive2::Method(run_method(visitor, method)?) + } + }) +} + +fn run_method<'a, 'input>( + mut visitor: InsertMemSSAVisitor<'a, 'input>, + method: Function2<'input, ast::Instruction, SpirvWord>, +) -> Result, SpirvWord>, TranslateError> { + let mut func_decl = method.func_decl; + for arg in func_decl.return_arguments.iter_mut() { + visitor.visit_variable(arg); + } + let is_kernel = func_decl.name.is_kernel(); + // let mut prelude = Vec::with_capacity(method.body.as_ref().map(Vec::len).unwrap_or(0)); + if is_kernel { + for arg in func_decl.input_arguments.iter_mut() { + let old_name = arg.name; + let old_space = arg.state_space; + let new_space = ast::StateSpace::ParamEntry; + let new_name = visitor + .resolver + .register_unnamed(Some((arg.v_type.clone(), new_space))); + visitor.input_argument(old_name, new_name, old_space); + arg.name = new_name; + arg.state_space = new_space; + } + }; + let body = method + .body + .map(move |statements| { + let mut result = Vec::with_capacity(statements.len()); + for statement in statements { + run_statement(&mut visitor, &mut result, statement)?; + } + Ok::<_, TranslateError>(result) + }) + .transpose()?; + Ok(Function2 { + func_decl: func_decl, + globals: method.globals, + body, + import_as: method.import_as, + tuning: method.tuning, + linkage: method.linkage, + }) +} + +fn run_statement<'a, 'input>( + visitor: &mut InsertMemSSAVisitor<'a, 'input>, + result: &mut Vec, + statement: ExpandedStatement, +) -> Result<(), TranslateError> { + match statement { + Statement::Variable(mut var) => { + visitor.visit_variable(&mut var); + result.push(Statement::Variable(var)); + } + Statement::Instruction(ast::Instruction::Ld { data, arguments }) => { + let instruction = visitor.visit_ld(data, arguments)?; + let instruction = ast::visit_map(instruction, visitor)?; + result.push(Statement::Instruction(instruction)); + } + Statement::Instruction(ast::Instruction::St { + data, + mut arguments, + }) => { + let instruction = visitor.visit_st(data, arguments)?; + let instruction = ast::visit_map(instruction, visitor)?; + result.push(Statement::Instruction(instruction)); + } + s => result.push(s.visit_map(visitor)?), + } + Ok(()) +} + +struct InsertMemSSAVisitor<'a, 'input> { + resolver: &'a mut GlobalStringIdentResolver2<'input>, + variables: FxHashMap, +} + +impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { + fn new(resolver: &'a mut GlobalStringIdentResolver2<'input>) -> Self { + Self { + resolver, + variables: FxHashMap::default(), + } + } + + fn input_argument( + &mut self, + old_name: SpirvWord, + new_name: SpirvWord, + old_space: ast::StateSpace, + ) -> Result<(), TranslateError> { + if old_space != ast::StateSpace::Param { + return Err(error_unreachable()); + } + self.variables.insert( + old_name, + RemapAction::LDStSpaceChange { + name: new_name, + old_space, + new_space: ast::StateSpace::ParamEntry, + }, + ); + Ok(()) + } + + fn variable( + &mut self, + old_name: SpirvWord, + new_name: SpirvWord, + old_space: ast::StateSpace, + ) -> Result<(), TranslateError> { + Ok(match old_space { + ast::StateSpace::Reg => { + self.variables + .insert(old_name, RemapAction::PreLdPostSt(new_name)); + } + ast::StateSpace::Param => { + self.variables.insert( + old_name, + RemapAction::LDStSpaceChange { + old_space, + new_space: ast::StateSpace::Local, + name: new_name, + }, + ); + } + // Good as-is + ast::StateSpace::Local => {} + // Will be pulled into global scope later + ast::StateSpace::Generic + | ast::StateSpace::SharedCluster + | ast::StateSpace::Global + | ast::StateSpace::Const + | ast::StateSpace::SharedCta + | ast::StateSpace::Shared => {} + ast::StateSpace::ParamEntry | ast::StateSpace::ParamFunc => { + return Err(error_unreachable()) + } + }) + } + + fn visit_st( + &self, + mut data: ast::StData, + mut arguments: ast::StArgs, + ) -> Result, TranslateError> { + if let Some(remap) = self.variables.get(&arguments.src1) { + match remap { + RemapAction::PreLdPostSt(_) => return Err(error_mismatched_type()), + RemapAction::LDStSpaceChange { + old_space, + new_space, + name, + } => { + if data.state_space != *old_space { + return Err(error_mismatched_type()); + } + data.state_space = *new_space; + arguments.src1 = *name; + } + } + } + Ok(ast::Instruction::St { data, arguments }) + } + + fn visit_ld( + &self, + mut data: ast::LdDetails, + mut arguments: ast::LdArgs, + ) -> Result, TranslateError> { + if let Some(remap) = self.variables.get(&arguments.src) { + match remap { + RemapAction::PreLdPostSt(_) => return Err(error_mismatched_type()), + RemapAction::LDStSpaceChange { + old_space, + new_space, + name, + } => { + if data.state_space != *old_space { + return Err(error_mismatched_type()); + } + data.state_space = *new_space; + arguments.src = *name; + } + } + } + Ok(ast::Instruction::Ld { data, arguments }) + } + + fn visit_variable(&mut self, var: &mut ast::Variable) { + if var.state_space != ast::StateSpace::Local { + let old_name = var.name; + let old_space = var.state_space; + let new_space = ast::StateSpace::Local; + let new_name = self + .resolver + .register_unnamed(Some((var.v_type.clone(), new_space))); + self.variable(old_name, new_name, old_space); + var.name = new_name; + var.state_space = new_space; + } + } +} + +impl<'a, 'input> ast::VisitorMap + for InsertMemSSAVisitor<'a, 'input> +{ + fn visit( + &mut self, + args: SpirvWord, + type_space: Option<(&ast::Type, ast::StateSpace)>, + is_dst: bool, + relaxed_type_check: bool, + ) -> Result { + todo!() + } + + fn visit_ident( + &mut self, + args: SpirvWord, + type_space: Option<(&ast::Type, ast::StateSpace)>, + is_dst: bool, + relaxed_type_check: bool, + ) -> Result { + self.visit(args, type_space, is_dst, relaxed_type_check) + } +} + +#[derive(Clone, Copy)] +enum RemapAction { + PreLdPostSt(SpirvWord), + LDStSpaceChange { + old_space: ast::StateSpace, + new_space: ast::StateSpace, + name: SpirvWord, + }, +} diff --git a/ptx/src/pass/insert_implicit_conversions.rs b/ptx/src/pass/insert_implicit_conversions.rs index 3249b82..c04fa09 100644 --- a/ptx/src/pass/insert_implicit_conversions.rs +++ b/ptx/src/pass/insert_implicit_conversions.rs @@ -45,6 +45,13 @@ pub(super) fn run( Statement::RepackVector(repack), )?; } + Statement::VectorAccess(vector_access) => { + insert_implicit_conversions_impl( + &mut result, + id_def, + Statement::VectorAccess(vector_access), + )?; + } s @ Statement::Conditional(_) | s @ Statement::Conversion(_) | s @ Statement::Label(_) diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index 04d3e49..b82d3c5 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -13,15 +13,21 @@ use std::{ mem, rc::Rc, }; +use strum::IntoEnumIterator; +use strum_macros::EnumIter; mod convert_dynamic_shared_memory_usage; mod convert_to_stateful_memory_access; mod convert_to_typed; +mod deparamize_functions; pub(crate) mod emit_llvm; mod emit_spirv; mod expand_arguments; +mod expand_operands; mod extract_globals; mod fix_special_registers; +mod fix_special_registers2; +mod insert_explicit_load_store; mod insert_implicit_conversions; mod insert_mem_ssa_statements; mod normalize_identifiers; @@ -68,6 +74,20 @@ pub fn to_llvm_module<'input>(ast: ast::Module<'input>) -> Result(ast: ast::Module<'input>) -> Result { + let mut flat_resolver = GlobalStringIdentResolver2::<'input>::new(SpirvWord(1)); + let mut scoped_resolver = ScopedResolver::new(&mut flat_resolver); + let sreg_map = SpecialRegistersMap2::new(&mut scoped_resolver)?; + let directives = normalize_identifiers2::run(&mut scoped_resolver, ast.directives)?; + let directives = normalize_predicates2::run(&mut flat_resolver, directives)?; + let directives = resolve_function_pointers::run(directives)?; + let directives = fix_special_registers2::run(&mut flat_resolver, &sreg_map, directives)?; + let directives = expand_operands::run(&mut flat_resolver, directives)?; + let directives = deparamize_functions::run(&mut flat_resolver, directives)?; + let directives = insert_explicit_load_store::run(&mut flat_resolver, directives)?; + todo!() +} + fn translate_directive<'input, 'a>( id_defs: &'a mut GlobalStringIdResolver<'input>, ptx_impl_imports: &'a mut HashMap>, @@ -323,7 +343,7 @@ pub struct KernelInfo { pub uses_shared_mem: bool, } -#[derive(Ord, PartialOrd, Eq, PartialEq, Hash, Copy, Clone)] +#[derive(Ord, PartialOrd, Eq, PartialEq, Hash, Copy, Clone, EnumIter)] enum PtxSpecialRegister { Tid, Ntid, @@ -346,6 +366,17 @@ impl PtxSpecialRegister { } } + fn as_str(self) -> &'static str { + match self { + Self::Tid => "%tid", + Self::Ntid => "%ntid", + Self::Ctaid => "%ctaid", + Self::Nctaid => "%nctaid", + Self::Clock => "%clock", + Self::LanemaskLt => "%lanemask_lt", + } + } + fn get_type(self) -> ast::Type { match self { PtxSpecialRegister::Tid @@ -726,6 +757,7 @@ enum Statement { PtrAccess(PtrAccess

), RepackVector(RepackVectorDetails), FunctionPointer(FunctionPointerDetails), + VectorAccess(VectorAccess), } impl> Statement, T> { @@ -894,6 +926,36 @@ impl> Statement, T> { offset_src, }) } + Statement::VectorAccess(VectorAccess { + scalar_type, + vector_width, + dst, + src: vector_src, + member, + }) => { + let dst: SpirvWord = visitor.visit_ident( + dst, + Some((&scalar_type.into(), ast::StateSpace::Reg)), + true, + false, + )?; + let src = visitor.visit_ident( + vector_src, + Some(( + &ast::Type::Vector(vector_width, scalar_type), + ast::StateSpace::Reg, + )), + false, + false, + )?; + Statement::VectorAccess(VectorAccess { + scalar_type, + vector_width, + dst, + src, + member, + }) + } Statement::RepackVector(RepackVectorDetails { is_extract, typ, @@ -1448,6 +1510,7 @@ fn compute_denorm_information<'input>( Statement::Label(_) => {} Statement::Variable(_) => {} Statement::PtrAccess { .. } => {} + Statement::VectorAccess { .. } => {} Statement::RepackVector(_) => {} Statement::FunctionPointer(_) => {} } @@ -1668,7 +1731,7 @@ pub(crate) enum Directive2<'input, Instruction, Operand: ast::Operand> { } pub(crate) struct Function2<'input, Instruction, Operand: ast::Operand> { - pub func_decl: Rc>>, + pub func_decl: ast::MethodDeclaration<'input, SpirvWord>, pub globals: Vec>, pub body: Option>>, import_as: Option, @@ -1712,10 +1775,31 @@ struct GlobalStringIdentResolver2<'input> { } impl<'input> GlobalStringIdentResolver2<'input> { - fn register_intermediate( + fn new(spirv_word: SpirvWord) -> Self { + Self { + current_id: spirv_word, + ident_map: FxHashMap::default(), + } + } + + fn register_named( &mut self, + name: Cow<'input, str>, type_space: Option<(ast::Type, ast::StateSpace)>, ) -> SpirvWord { + let new_id = self.current_id; + self.ident_map.insert( + new_id, + IdentEntry { + name: Some(name), + type_space, + }, + ); + self.current_id.0 += 1; + new_id + } + + fn register_unnamed(&mut self, type_space: Option<(ast::Type, ast::StateSpace)>) -> SpirvWord { let new_id = self.current_id; self.ident_map.insert( new_id, @@ -1727,9 +1811,191 @@ impl<'input> GlobalStringIdentResolver2<'input> { self.current_id.0 += 1; new_id } + + fn get_typed(&self, id: SpirvWord) -> Result<&(ast::Type, ast::StateSpace), TranslateError> { + match self.ident_map.get(&id) { + Some(IdentEntry { + type_space: Some(type_space), + .. + }) => Ok(type_space), + _ => Err(error_unknown_symbol()), + } + } } struct IdentEntry<'input> { name: Option>, type_space: Option<(ast::Type, ast::StateSpace)>, } + +struct ScopedResolver<'input, 'b> { + flat_resolver: &'b mut GlobalStringIdentResolver2<'input>, + scopes: Vec>, +} + +impl<'input, 'b> ScopedResolver<'input, 'b> { + fn new(flat_resolver: &'b mut GlobalStringIdentResolver2<'input>) -> Self { + Self { + flat_resolver, + scopes: vec![ScopeMarker::new()], + } + } + + fn start_scope(&mut self) { + self.scopes.push(ScopeMarker::new()); + } + + fn end_scope(&mut self) { + let scope = self.scopes.pop().unwrap(); + scope.flush(self.flat_resolver); + } + + fn add( + &mut self, + name: Cow<'input, str>, + type_space: Option<(ast::Type, ast::StateSpace)>, + ) -> Result { + let result = self.flat_resolver.current_id; + self.flat_resolver.current_id.0 += 1; + let current_scope = self.scopes.last_mut().unwrap(); + if current_scope + .name_to_ident + .insert(name.clone(), result) + .is_some() + { + return Err(error_unknown_symbol()); + } + current_scope.ident_map.insert( + result, + IdentEntry { + name: Some(name), + type_space, + }, + ); + Ok(result) + } + + fn get(&mut self, name: &str) -> Result { + self.scopes + .iter() + .rev() + .find_map(|resolver| resolver.name_to_ident.get(name).copied()) + .ok_or_else(|| error_unreachable()) + } + + fn get_in_current_scope(&self, label: &'input str) -> Result { + let current_scope = self.scopes.last().unwrap(); + current_scope + .name_to_ident + .get(label) + .copied() + .ok_or_else(|| error_unreachable()) + } +} + +struct ScopeMarker<'input> { + ident_map: FxHashMap>, + name_to_ident: FxHashMap, SpirvWord>, +} + +impl<'input> ScopeMarker<'input> { + fn new() -> Self { + Self { + ident_map: FxHashMap::default(), + name_to_ident: FxHashMap::default(), + } + } + + fn flush(self, resolver: &mut GlobalStringIdentResolver2<'input>) { + resolver.ident_map.extend(self.ident_map); + } +} + +struct SpecialRegistersMap2 { + reg_to_id: FxHashMap, + id_to_reg: FxHashMap, +} + +impl SpecialRegistersMap2 { + fn new(resolver: &mut ScopedResolver) -> Result { + let mut result = SpecialRegistersMap2 { + reg_to_id: FxHashMap::default(), + id_to_reg: FxHashMap::default(), + }; + for sreg in PtxSpecialRegister::iter() { + let text = sreg.as_str(); + let id = resolver.add( + Cow::Borrowed(text), + Some((sreg.get_type(), ast::StateSpace::Reg)), + )?; + result.reg_to_id.insert(sreg, id); + result.id_to_reg.insert(id, sreg); + } + Ok(result) + } + + fn get(&self, id: SpirvWord) -> Option { + self.id_to_reg.get(&id).copied() + } + + fn get_or_add(&mut self, current_id: &mut SpirvWord, reg: PtxSpecialRegister) -> SpirvWord { + match self.reg_to_id.entry(reg) { + hash_map::Entry::Occupied(e) => *e.get(), + hash_map::Entry::Vacant(e) => { + let numeric_id = SpirvWord(current_id.0); + current_id.0 += 1; + e.insert(numeric_id); + self.id_to_reg.insert(numeric_id, reg); + numeric_id + } + } + } + + fn generate_declarations<'a, 'input>( + resolver: &'a mut GlobalStringIdentResolver2<'input>, + ) -> impl ExactSizeIterator< + Item = ( + PtxSpecialRegister, + ast::MethodDeclaration<'input, SpirvWord>, + ), + > + 'a { + PtxSpecialRegister::iter().map(|sreg| { + let external_fn_name = [ZLUDA_PTX_PREFIX, sreg.get_unprefixed_function_name()].concat(); + let name = + ast::MethodName::Func(resolver.register_named(Cow::Owned(external_fn_name), None)); + let return_type = sreg.get_function_return_type(); + let input_type = sreg.get_function_return_type(); + ( + sreg, + ast::MethodDeclaration { + return_arguments: vec![ast::Variable { + align: None, + v_type: return_type.into(), + state_space: ast::StateSpace::Reg, + name: resolver + .register_unnamed(Some((return_type.into(), ast::StateSpace::Reg))), + array_init: Vec::new(), + }], + name: name, + input_arguments: vec![ast::Variable { + align: None, + v_type: input_type.into(), + state_space: ast::StateSpace::Reg, + name: resolver + .register_unnamed(Some((input_type.into(), ast::StateSpace::Reg))), + array_init: Vec::new(), + }], + shared_mem: None, + }, + ) + }) + } +} + +pub struct VectorAccess { + scalar_type: ast::ScalarType, + vector_width: u8, + dst: SpirvWord, + src: SpirvWord, + member: u8, +} diff --git a/ptx/src/pass/normalize_identifiers2.rs b/ptx/src/pass/normalize_identifiers2.rs index e3fb88d..beaf08b 100644 --- a/ptx/src/pass/normalize_identifiers2.rs +++ b/ptx/src/pass/normalize_identifiers2.rs @@ -2,21 +2,21 @@ use super::*; use ptx_parser as ast; use rustc_hash::FxHashMap; -pub(crate) fn run<'input>( - fn_defs: &mut GlobalStringIdentResolver2<'input>, +pub(crate) fn run<'input, 'b>( + resolver: &mut ScopedResolver<'input, 'b>, directives: Vec>>, ) -> Result>, TranslateError> { - let mut resolver = NameResolver::new(fn_defs); + resolver.start_scope(); let result = directives .into_iter() - .map(|directive| run_directive(&mut resolver, directive)) + .map(|directive| run_directive(resolver, directive)) .collect::, _>>()?; resolver.end_scope(); Ok(result) } fn run_directive<'input, 'b>( - resolver: &mut NameResolver<'input, 'b>, + resolver: &mut ScopedResolver<'input, 'b>, directive: ast::Directive<'input, ast::ParsedOperand<&'input str>>, ) -> Result, TranslateError> { Ok(match directive { @@ -30,7 +30,7 @@ fn run_directive<'input, 'b>( } fn run_method<'input, 'b>( - resolver: &mut NameResolver<'input, 'b>, + resolver: &mut ScopedResolver<'input, 'b>, linkage: ast::LinkingDirective, method: ast::Function<'input, &'input str, ast::Statement>>, ) -> Result, TranslateError> { @@ -41,11 +41,7 @@ fn run_method<'input, 'b>( } }; resolver.start_scope(); - let func_decl = Rc::new(RefCell::new(run_function_decl( - resolver, - method.func_directive, - name, - )?)); + let func_decl = run_function_decl(resolver, method.func_directive, name)?; let body = method .body .map(|statements| { @@ -66,7 +62,7 @@ fn run_method<'input, 'b>( } fn run_function_decl<'input, 'b>( - resolver: &mut NameResolver<'input, 'b>, + resolver: &mut ScopedResolver<'input, 'b>, func_directive: ast::MethodDeclaration<'input, &'input str>, name: ast::MethodName<'input, SpirvWord>, ) -> Result, TranslateError> { @@ -90,7 +86,7 @@ fn run_function_decl<'input, 'b>( } fn run_variable<'input, 'b>( - resolver: &mut NameResolver<'input, 'b>, + resolver: &mut ScopedResolver<'input, 'b>, variable: ast::Variable<&'input str>, ) -> Result, TranslateError> { Ok(ast::Variable { @@ -106,7 +102,7 @@ fn run_variable<'input, 'b>( } fn run_statements<'input, 'b>( - resolver: &mut NameResolver<'input, 'b>, + resolver: &mut ScopedResolver<'input, 'b>, result: &mut Vec, statements: Vec>>, ) -> Result<(), TranslateError> { @@ -148,7 +144,7 @@ fn run_statements<'input, 'b>( } fn run_instruction<'input, 'b>( - resolver: &mut NameResolver<'input, 'b>, + resolver: &mut ScopedResolver<'input, 'b>, instruction: ast::Instruction>, ) -> Result>, TranslateError> { ast::visit_map(instruction, &mut |name: &'input str, @@ -163,7 +159,7 @@ fn run_instruction<'input, 'b>( } fn run_multivariable<'input, 'b>( - resolver: &mut NameResolver<'input, 'b>, + resolver: &mut ScopedResolver<'input, 'b>, result: &mut Vec, variable: ast::MultiVariable<&'input str>, ) -> Result<(), TranslateError> { @@ -201,86 +197,3 @@ fn run_multivariable<'input, 'b>( } Ok(()) } - -struct NameResolver<'input, 'b> { - flat_resolver: &'b mut GlobalStringIdentResolver2<'input>, - scopes: Vec>, -} - -impl<'input, 'b> NameResolver<'input, 'b> { - fn new(flat_resolver: &'b mut GlobalStringIdentResolver2<'input>) -> Self { - Self { - flat_resolver, - scopes: vec![ScopeStringIdentResolver::new()], - } - } - - fn start_scope(&mut self) { - self.scopes.push(ScopeStringIdentResolver::new()); - } - - fn end_scope(&mut self) { - let scope = self.scopes.pop().unwrap(); - scope.flush(self.flat_resolver); - } - - fn add( - &mut self, - name: Cow<'input, str>, - type_space: Option<(ast::Type, ast::StateSpace)>, - ) -> Result { - let result = self.flat_resolver.current_id; - self.flat_resolver.current_id.0 += 1; - let current_scope = self.scopes.last_mut().unwrap(); - if current_scope - .name_to_ident - .insert(name.clone(), result) - .is_some() - { - return Err(error_unknown_symbol()); - } - current_scope.ident_map.insert( - result, - IdentEntry { - name: Some(name), - type_space, - }, - ); - Ok(result) - } - - fn get(&mut self, name: &str) -> Result { - self.scopes - .iter() - .rev() - .find_map(|resolver| resolver.name_to_ident.get(name).copied()) - .ok_or_else(|| error_unreachable()) - } - - fn get_in_current_scope(&self, label: &'input str) -> Result { - let current_scope = self.scopes.last().unwrap(); - current_scope - .name_to_ident - .get(label) - .copied() - .ok_or_else(|| error_unreachable()) - } -} - -struct ScopeStringIdentResolver<'input> { - ident_map: FxHashMap>, - name_to_ident: FxHashMap, SpirvWord>, -} - -impl<'input> ScopeStringIdentResolver<'input> { - fn new() -> Self { - Self { - ident_map: FxHashMap::default(), - name_to_ident: FxHashMap::default(), - } - } - - fn flush(self, resolver: &mut GlobalStringIdentResolver2<'input>) { - resolver.ident_map.extend(self.ident_map); - } -} diff --git a/ptx/src/pass/normalize_labels.rs b/ptx/src/pass/normalize_labels.rs index 097d87c..037e918 100644 --- a/ptx/src/pass/normalize_labels.rs +++ b/ptx/src/pass/normalize_labels.rs @@ -26,6 +26,7 @@ pub(super) fn run( | Statement::Constant(..) | Statement::Label(..) | Statement::PtrAccess { .. } + | Statement::VectorAccess { .. } | Statement::RepackVector(..) | Statement::FunctionPointer(..) => {} } diff --git a/ptx/src/pass/normalize_predicates2.rs b/ptx/src/pass/normalize_predicates2.rs index 2d15bba..d91e23c 100644 --- a/ptx/src/pass/normalize_predicates2.rs +++ b/ptx/src/pass/normalize_predicates2.rs @@ -55,8 +55,8 @@ fn run_statement<'input>( Statement::Variable(var) => result.push(Statement::Variable(var)), Statement::Instruction((predicate, instruction)) => { if let Some(pred) = predicate { - let if_true = resolver.register_intermediate(None); - let if_false = resolver.register_intermediate(None); + let if_true = resolver.register_unnamed(None); + let if_false = resolver.register_unnamed(None); let folded_bra = match &instruction { ast::Instruction::Bra { arguments, .. } => Some(arguments.src), _ => None, diff --git a/ptx/src/pass/resolve_function_pointers.rs b/ptx/src/pass/resolve_function_pointers.rs index 9aaa694..eb7abb1 100644 --- a/ptx/src/pass/resolve_function_pointers.rs +++ b/ptx/src/pass/resolve_function_pointers.rs @@ -20,7 +20,7 @@ fn run_directive<'input>( var @ Directive2::Variable(..) => var, Directive2::Method(method) => { { - let func_decl = method.func_decl.borrow(); + let func_decl = &method.func_decl; match func_decl.name { ptx_parser::MethodName::Kernel(_) => {} ptx_parser::MethodName::Func(name) => { -- cgit v1.2.3