From b4de21fbc5eaf33540f1121bfe7c6ba0acaff6c9 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Mon, 2 Aug 2021 01:04:05 +0200 Subject: Use calls to OpenCL builtins when translating sregs, do SPIRV->LLVM conversion on every build --- ptx/src/translate.rs | 162 +++++++++++++++++++++++++++++++-------------------- 1 file changed, 100 insertions(+), 62 deletions(-) (limited to 'ptx/src/translate.rs') diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 5fea075..6c2c594 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -448,7 +448,7 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result( &mut numeric_id_defs, &mut (*func_decl).borrow_mut(), )?; - let ssa_statements = fix_special_registers(ssa_statements, &mut numeric_id_defs)?; + let ssa_statements = + fix_special_registers(ptx_impl_imports, ssa_statements, &mut numeric_id_defs)?; let mut numeric_id_defs = numeric_id_defs.finish(); let expanded_statements = expand_arguments(ssa_statements, &mut numeric_id_defs)?; let expanded_statements = @@ -1269,6 +1270,7 @@ fn to_ssa<'input, 'b>( } fn fix_special_registers( + ptx_impl_imports: &mut HashMap, typed_statements: Vec, numeric_id_defs: &mut NumericIdResolver, ) -> Result, TranslateError> { @@ -1276,7 +1278,6 @@ fn fix_special_registers( for s in typed_statements { match s { Statement::LoadVar( - mut details @ LoadVarDetails { @@ -1285,48 +1286,53 @@ fn fix_special_registers( }, ) => { let index = details.member_index.unwrap().0; - if index == 3 { - result.push(Statement::Constant(ConstantDefinition { - dst: details.arg.dst, - typ: ast::ScalarType::U32, - value: ast::ImmediateValue::U64(0), - })); - } else { - let sreg_and_type = match numeric_id_defs.special_registers.get(details.arg.src) - { - Some(reg) => get_sreg_id_scalar_type(numeric_id_defs, reg), - None => None, - }; - let (sreg_src, scalar_typ, vector_width) = match sreg_and_type { - Some(sreg_and_type) => sreg_and_type, - None => { - result.push(Statement::LoadVar(details)); - continue; - } - }; - let temp_id = numeric_id_defs - .register_intermediate(Some((details.typ.clone(), details.state_space))); - let real_dst = details.arg.dst; - details.arg.dst = temp_id; - result.push(Statement::LoadVar(LoadVarDetails { - arg: Arg2 { - src: sreg_src, - dst: temp_id, - }, - state_space: ast::StateSpace::Sreg, - typ: ast::Type::Scalar(scalar_typ), - member_index: Some((index, Some(vector_width))), - })); - result.push(Statement::Conversion(ImplicitConversion { - src: temp_id, - dst: real_dst, - from_type: ast::Type::Scalar(scalar_typ), - from_space: ast::StateSpace::Sreg, - to_type: ast::Type::Scalar(ast::ScalarType::U32), - to_space: ast::StateSpace::Sreg, - kind: ConversionKind::Default, - })); - } + let sreg = numeric_id_defs + .special_registers + .get(details.arg.src) + .ok_or_else(|| error_unreachable())?; + let (ocl_name, ocl_type) = sreg.get_opencl_fn_type(); + let index_constant = numeric_id_defs.register_intermediate(Some(( + ast::Type::Scalar(ast::ScalarType::U32), + ast::StateSpace::Reg, + ))); + result.push(Statement::Constant(ConstantDefinition { + dst: index_constant, + typ: ast::ScalarType::U32, + value: ast::ImmediateValue::U64(index as u64), + })); + let fn_result = numeric_id_defs.register_intermediate(Some(( + ast::Type::Scalar(ocl_type), + ast::StateSpace::Reg, + ))); + let return_arguments = + vec![(fn_result, ast::Type::Scalar(ocl_type), ast::StateSpace::Reg)]; + let input_arguments = vec![( + TypedOperand::Reg(index_constant), + ast::Type::Scalar(ast::ScalarType::U32), + ast::StateSpace::Reg, + )]; + let fn_call = register_external_fn_call( + numeric_id_defs, + ptx_impl_imports, + ocl_name.to_string(), + return_arguments.iter().map(|(_, typ, space)| (typ, *space)), + input_arguments.iter().map(|(_, typ, space)| (typ, *space)), + )?; + result.push(Statement::Call(ResolvedCall { + uniform: false, + return_arguments, + name: fn_call, + input_arguments, + })); + result.push(Statement::Conversion(ImplicitConversion { + src: fn_result, + dst: details.arg.dst, + from_type: ast::Type::Scalar(ocl_type), + from_space: ast::StateSpace::Reg, + to_type: ast::Type::Scalar(ast::ScalarType::U32), + to_space: ast::StateSpace::Reg, + kind: ConversionKind::Default, + })); } s => result.push(s), } @@ -1721,8 +1727,8 @@ fn instruction_to_fn_call( id_defs, ptx_impl_imports, fn_name, - return_arguments, - input_arguments, + return_arguments.iter().map(|(_, typ, state)| (typ, *state)), + input_arguments.iter().map(|(_, typ, state)| (typ, *state)), )?; Ok(Statement::Call(ResolvedCall { uniform: false, @@ -1732,12 +1738,12 @@ fn instruction_to_fn_call( })) } -fn register_external_fn_call( +fn register_external_fn_call<'a>( id_defs: &mut NumericIdResolver, ptx_impl_imports: &mut HashMap, name: String, - return_arguments: &[(ArgumentDescriptor, ast::Type, ast::StateSpace)], - input_arguments: &[(ArgumentDescriptor, ast::Type, ast::StateSpace)], + return_arguments: impl Iterator, + input_arguments: impl Iterator, ) -> Result { match ptx_impl_imports.entry(name) { hash_map::Entry::Vacant(entry) => { @@ -1770,19 +1776,18 @@ fn register_external_fn_call( } } -fn fn_arguments_to_variables( +fn fn_arguments_to_variables<'a>( id_defs: &mut NumericIdResolver, - args: &[(ArgumentDescriptor, ast::Type, ast::StateSpace)], + args: impl Iterator, ) -> Vec> { - args.iter() - .map(|(_, typ, space)| ast::Variable { - align: None, - v_type: typ.clone(), - state_space: *space, - name: id_defs.register_intermediate(None), - array_init: Vec::new(), - }) - .collect::>() + args.map(|(typ, space)| ast::Variable { + align: None, + v_type: typ.clone(), + state_space: space, + name: id_defs.register_intermediate(None), + array_init: Vec::new(), + }) + .collect::>() } fn arguments_to_resolved_arguments( @@ -2226,7 +2231,7 @@ fn expand_arguments<'a, 'b>( Statement::StoreVar(details) => result.push(Statement::StoreVar(details)), Statement::RetValue(d, id) => result.push(Statement::RetValue(d, id)), Statement::Conversion(conv) => result.push(Statement::Conversion(conv)), - Statement::Constant(_) => return Err(error_unreachable()), + Statement::Constant(c) => result.push(Statement::Constant(c)), } } Ok(result) @@ -4686,6 +4691,19 @@ impl PtxSpecialRegister { } } + fn get_scalar_type(self) -> ast::ScalarType { + match self { + PtxSpecialRegister::Tid + | PtxSpecialRegister::Ntid + | PtxSpecialRegister::Ctaid + | PtxSpecialRegister::Nctaid => ast::ScalarType::U32, + PtxSpecialRegister::Tid64 + | PtxSpecialRegister::Ntid64 + | PtxSpecialRegister::Ctaid64 + | PtxSpecialRegister::Nctaid64 => ast::ScalarType::U64, + } + } + fn get_builtin(self) -> spirv::BuiltIn { match self { PtxSpecialRegister::Tid | PtxSpecialRegister::Tid64 => { @@ -4701,6 +4719,23 @@ impl PtxSpecialRegister { } } + fn get_opencl_fn_type(self) -> (&'static str, ast::ScalarType) { + match self { + PtxSpecialRegister::Tid | PtxSpecialRegister::Tid64 => { + ("_Z12get_local_idj", ast::ScalarType::U64) + } + PtxSpecialRegister::Ntid | PtxSpecialRegister::Ntid64 => { + ("_Z14get_local_sizej", ast::ScalarType::U64) + } + PtxSpecialRegister::Ctaid | PtxSpecialRegister::Ctaid64 => { + ("_Z12get_group_idj", ast::ScalarType::U64) + } + PtxSpecialRegister::Nctaid | PtxSpecialRegister::Nctaid64 => { + ("_Z14get_num_groupsj", ast::ScalarType::U64) + } + } + } + fn normalized_sreg_and_type(self) -> Option<(PtxSpecialRegister, ast::ScalarType, u8)> { match self { PtxSpecialRegister::Tid => Some((PtxSpecialRegister::Tid64, ast::ScalarType::U64, 3)), @@ -4743,6 +4778,8 @@ impl SpecialRegistersMap { } fn interface(&self) -> Vec { + return Vec::new(); + /* self.reg_to_id .iter() .filter_map(|(sreg, id)| { @@ -4753,6 +4790,7 @@ impl SpecialRegistersMap { } }) .collect::>() + */ } fn get(&self, id: spirv::Word) -> Option { -- cgit v1.2.3