From 2d652ff9c81f81e4136316e15165cc6b6ebe96e4 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Wed, 18 Nov 2020 00:38:33 +0100 Subject: Fix various bugs --- ptx/src/test/spirv_run/ld_st_stateful.ptx | 25 --- ptx/src/test/spirv_run/ld_st_stateful.spvtxt | 57 ------- ptx/src/test/spirv_run/stateful_ld_st_simple.ptx | 25 +++ .../test/spirv_run/stateful_ld_st_simple.spvtxt | 57 +++++++ ptx/src/translate.rs | 168 +++++++++++---------- 5 files changed, 171 insertions(+), 161 deletions(-) delete mode 100644 ptx/src/test/spirv_run/ld_st_stateful.ptx delete mode 100644 ptx/src/test/spirv_run/ld_st_stateful.spvtxt create mode 100644 ptx/src/test/spirv_run/stateful_ld_st_simple.ptx create mode 100644 ptx/src/test/spirv_run/stateful_ld_st_simple.spvtxt diff --git a/ptx/src/test/spirv_run/ld_st_stateful.ptx b/ptx/src/test/spirv_run/ld_st_stateful.ptx deleted file mode 100644 index 859b169..0000000 --- a/ptx/src/test/spirv_run/ld_st_stateful.ptx +++ /dev/null @@ -1,25 +0,0 @@ -.version 6.5 -.target sm_30 -.address_size 64 - -.visible .entry ld_st_stateful( - .param .u64 input, - .param .u64 output -) -{ - .reg .u64 in_addr; - .reg .u64 out_addr; - .reg .u64 in_addr2; - .reg .u64 out_addr2; - .reg .u64 temp; - - ld.param.u64 in_addr, [input]; - ld.param.u64 out_addr, [output]; - - cvta.to.global.u64 in_addr2, in_addr; - cvta.to.global.u64 out_addr2, out_addr; - - ld.global.u64 temp, [in_addr2]; - st.global.u64 [out_addr2], temp; - ret; -} \ No newline at end of file diff --git a/ptx/src/test/spirv_run/ld_st_stateful.spvtxt b/ptx/src/test/spirv_run/ld_st_stateful.spvtxt deleted file mode 100644 index 963d88a..0000000 --- a/ptx/src/test/spirv_run/ld_st_stateful.spvtxt +++ /dev/null @@ -1,57 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int8 - OpCapability Int16 - OpCapability Int64 - OpCapability Float16 - OpCapability Float64 - %30 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "ld_st_offset" - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %33 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong - %uint = OpTypeInt 32 0 -%_ptr_Function_uint = OpTypePointer Function %uint -%_ptr_Generic_uint = OpTypePointer Generic %uint - %ulong_4 = OpConstant %ulong 4 - %ulong_4_0 = OpConstant %ulong 4 - %1 = OpFunction %void None %33 - %8 = OpFunctionParameter %ulong - %9 = OpFunctionParameter %ulong - %28 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_uint Function - %7 = OpVariable %_ptr_Function_uint Function - OpStore %2 %8 - OpStore %3 %9 - %10 = OpLoad %ulong %2 - OpStore %4 %10 - %11 = OpLoad %ulong %3 - OpStore %5 %11 - %13 = OpLoad %ulong %4 - %24 = OpConvertUToPtr %_ptr_Generic_uint %13 - %12 = OpLoad %uint %24 - OpStore %6 %12 - %15 = OpLoad %ulong %4 - %21 = OpIAdd %ulong %15 %ulong_4 - %25 = OpConvertUToPtr %_ptr_Generic_uint %21 - %14 = OpLoad %uint %25 - OpStore %7 %14 - %16 = OpLoad %ulong %5 - %17 = OpLoad %uint %7 - %26 = OpConvertUToPtr %_ptr_Generic_uint %16 - OpStore %26 %17 - %18 = OpLoad %ulong %5 - %19 = OpLoad %uint %6 - %23 = OpIAdd %ulong %18 %ulong_4_0 - %27 = OpConvertUToPtr %_ptr_Generic_uint %23 - OpStore %27 %19 - OpReturn - OpFunctionEnd diff --git a/ptx/src/test/spirv_run/stateful_ld_st_simple.ptx b/ptx/src/test/spirv_run/stateful_ld_st_simple.ptx new file mode 100644 index 0000000..5650ada --- /dev/null +++ b/ptx/src/test/spirv_run/stateful_ld_st_simple.ptx @@ -0,0 +1,25 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry stateful_ld_st_simple( + .param .u64 input, + .param .u64 output +) +{ + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .u64 in_addr2; + .reg .u64 out_addr2; + .reg .u64 temp; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + cvta.to.global.u64 in_addr2, in_addr; + cvta.to.global.u64 out_addr2, out_addr; + + ld.global.u64 temp, [in_addr2]; + st.global.u64 [out_addr2], temp; + ret; +} \ No newline at end of file diff --git a/ptx/src/test/spirv_run/stateful_ld_st_simple.spvtxt b/ptx/src/test/spirv_run/stateful_ld_st_simple.spvtxt new file mode 100644 index 0000000..963d88a --- /dev/null +++ b/ptx/src/test/spirv_run/stateful_ld_st_simple.spvtxt @@ -0,0 +1,57 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int8 + OpCapability Int16 + OpCapability Int64 + OpCapability Float16 + OpCapability Float64 + %30 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "ld_st_offset" + %void = OpTypeVoid + %ulong = OpTypeInt 64 0 + %33 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong + %uint = OpTypeInt 32 0 +%_ptr_Function_uint = OpTypePointer Function %uint +%_ptr_Generic_uint = OpTypePointer Generic %uint + %ulong_4 = OpConstant %ulong 4 + %ulong_4_0 = OpConstant %ulong 4 + %1 = OpFunction %void None %33 + %8 = OpFunctionParameter %ulong + %9 = OpFunctionParameter %ulong + %28 = OpLabel + %2 = OpVariable %_ptr_Function_ulong Function + %3 = OpVariable %_ptr_Function_ulong Function + %4 = OpVariable %_ptr_Function_ulong Function + %5 = OpVariable %_ptr_Function_ulong Function + %6 = OpVariable %_ptr_Function_uint Function + %7 = OpVariable %_ptr_Function_uint Function + OpStore %2 %8 + OpStore %3 %9 + %10 = OpLoad %ulong %2 + OpStore %4 %10 + %11 = OpLoad %ulong %3 + OpStore %5 %11 + %13 = OpLoad %ulong %4 + %24 = OpConvertUToPtr %_ptr_Generic_uint %13 + %12 = OpLoad %uint %24 + OpStore %6 %12 + %15 = OpLoad %ulong %4 + %21 = OpIAdd %ulong %15 %ulong_4 + %25 = OpConvertUToPtr %_ptr_Generic_uint %21 + %14 = OpLoad %uint %25 + OpStore %7 %14 + %16 = OpLoad %ulong %5 + %17 = OpLoad %uint %7 + %26 = OpConvertUToPtr %_ptr_Generic_uint %16 + OpStore %26 %17 + %18 = OpLoad %ulong %5 + %19 = OpLoad %uint %6 + %23 = OpIAdd %ulong %18 %ulong_4_0 + %27 = OpConvertUToPtr %_ptr_Generic_uint %23 + OpStore %27 %19 + OpReturn + OpFunctionEnd diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 6fc35d2..f644a27 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -857,6 +857,8 @@ fn replace_uses_of_shared_memory<'a>( ast::LdStateSpace::Shared, ), kind: ConversionKind::PtrToPtr { spirv_ptr: true }, + src_sema: ArgumentSemantics::Default, + dst_sema: ArgumentSemantics::Default, })); } replacement_id @@ -1297,15 +1299,11 @@ fn to_ssa<'input, 'b>( let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs)?; let typed_statements = convert_to_typed_statements(unadorned_statements, &fn_defs, &numeric_id_defs)?; - let mut numeric_id_defs = numeric_id_defs.finish(); - let (typed_statements, temporaries) = + let typed_statements = convert_to_stateful_memory_access(&mut spirv_decl, typed_statements, &mut numeric_id_defs)?; - let ssa_statements = insert_mem_ssa_statements( - typed_statements, - &mut numeric_id_defs, - &mut spirv_decl, - temporaries, - )?; + let ssa_statements = + insert_mem_ssa_statements(typed_statements, &mut numeric_id_defs, &mut spirv_decl)?; + let mut numeric_id_defs = numeric_id_defs.finish(); let expanded_statements = expand_arguments(ssa_statements, &mut numeric_id_defs)?; let expanded_statements = insert_implicit_conversions(expanded_statements, &mut numeric_id_defs)?; @@ -1927,7 +1925,6 @@ fn insert_mem_ssa_statements<'a, 'b>( func: Vec, id_def: &mut NumericIdResolver, fn_decl: &mut SpirvMethodDecl, - temporaries: HashSet, ) -> Result, TranslateError> { let mut result = Vec::with_capacity(func.len()); for arg in fn_decl.output.iter() { @@ -1969,7 +1966,7 @@ fn insert_mem_ssa_statements<'a, 'b>( for s in func { match s { Statement::Call(call) => { - insert_mem_ssa_statement_default(id_def, &mut result, &temporaries, call.cast())? + insert_mem_ssa_statement_default(id_def, &mut result, call.cast())? } Statement::Instruction(inst) => match inst { ast::Instruction::Ret(d) => { @@ -1989,7 +1986,7 @@ fn insert_mem_ssa_statements<'a, 'b>( result.push(Statement::Instruction(ast::Instruction::Ret(d))) } } - inst => insert_mem_ssa_statement_default(id_def, &mut result, &temporaries, inst)?, + inst => insert_mem_ssa_statement_default(id_def, &mut result, inst)?, }, Statement::Conditional(mut bra) => { let generated_id = @@ -2005,7 +2002,7 @@ fn insert_mem_ssa_statements<'a, 'b>( result.push(Statement::Conditional(bra)); } Statement::Conversion(conv) => { - insert_mem_ssa_statement_default(id_def, &mut result, &temporaries, conv)? + insert_mem_ssa_statement_default(id_def, &mut result, conv)? } s @ Statement::Variable(_) | s @ Statement::Label(_) => result.push(s), _ => return Err(TranslateError::Unreachable), @@ -2094,7 +2091,6 @@ impl<'a, Ctor: FnOnce(spirv::Word) -> ExpandedStatement> VisitVariableExpanded fn insert_mem_ssa_statement_default<'a, F: VisitVariable>( id_def: &mut NumericIdResolver, result: &mut Vec, - temporaries: &HashSet, stmt: F, ) -> Result<(), TranslateError> { let mut post_statements = Vec::new(); @@ -2612,13 +2608,13 @@ fn insert_implicit_conversions( )?; } s @ Statement::Conditional(_) + | s @ Statement::Conversion(_) | s @ Statement::Label(_) | s @ Statement::Constant(_) | s @ Statement::Variable(_) | s @ Statement::LoadVar(_, _) | s @ Statement::StoreVar(_, _) | s @ Statement::Undef(_, _) - | s @ Statement::Conversion(_) | s @ Statement::RetValue(_, _) => result.push(s), } } @@ -2685,6 +2681,8 @@ fn insert_implicit_conversions_impl( from, to, kind: conv_kind, + src_sema: ArgumentSemantics::Default, + dst_sema: ArgumentSemantics::Default, })); result } @@ -3741,6 +3739,8 @@ fn emit_cvt( src_t.kind(), )), kind: ConversionKind::Default, + src_sema: ArgumentSemantics::Default, + dst_sema: ArgumentSemantics::Default, }; emit_implicit_conversion(builder, map, &cv)?; new_dst @@ -4113,6 +4113,8 @@ fn emit_implicit_conversion( from: wide_bit_type, to: cv.to.clone(), kind: ConversionKind::Default, + src_sema: cv.src_sema, + dst_sema: cv.dst_sema, }, )?; } @@ -4244,8 +4246,8 @@ fn expand_map_variables<'a, 'b>( fn convert_to_stateful_memory_access<'a>( func_args: &mut SpirvMethodDecl, func_body: Vec, - id_defs: &mut MutableNumericIdResolver<'a>, -) -> Result<(Vec, HashSet), TranslateError> { + id_defs: &mut NumericIdResolver<'a>, +) -> Result, TranslateError> { let func_args_64bit = func_args .input .iter() @@ -4371,7 +4373,7 @@ fn convert_to_stateful_memory_access<'a>( let mut remapped_ids = HashMap::new(); let mut result = Vec::with_capacity(regs_ptr_current.len() + func_body.len()); for reg in regs_ptr_current { - let new_id = id_defs.new_id(ast::Type::Pointer( + let new_id = id_defs.new_variable(ast::Type::Pointer( ast::PointerType::Scalar(ast::ScalarType::U8), ast::LdStateSpace::Global, )); @@ -4386,7 +4388,6 @@ fn convert_to_stateful_memory_access<'a>( })); remapped_ids.insert(reg, new_id); } - let mut temporaries = HashSet::new(); for statement in func_body { match statement { l @ Statement::Label(_) => result.push(l), @@ -4402,10 +4403,9 @@ fn convert_to_stateful_memory_access<'a>( inst.visit_variable(&mut |arg_desc: ArgumentDescriptor, typ| { Ok(match remapped_ids.get(&arg_desc.op) { Some(new_id) => { - let old_type_full = id_defs.get_typed(arg_desc.op)?; + let (old_type_full, _) = id_defs.get_typed(arg_desc.op)?; let old_type = old_type_full.clone(); - let converting_id = id_defs.new_id(old_type_full); - temporaries.insert(converting_id); + let converting_id = id_defs.new_non_variable(Some(old_type_full)); if arg_desc.is_dst { post_statements.push(Statement::Conversion( ImplicitConversion { @@ -4419,6 +4419,8 @@ fn convert_to_stateful_memory_access<'a>( kind: ConversionKind::BitToPtr( ast::LdStateSpace::Global, ), + src_sema: ArgumentSemantics::Default, + dst_sema: arg_desc.sema, }, )); converting_id @@ -4432,6 +4434,8 @@ fn convert_to_stateful_memory_access<'a>( ), to: old_type, kind: ConversionKind::PtrToBit(ast::UIntType::U64), + src_sema: arg_desc.sema, + dst_sema: ArgumentSemantics::Default, })); converting_id } @@ -4441,19 +4445,23 @@ fn convert_to_stateful_memory_access<'a>( if arg_desc.is_dst { return Err(TranslateError::Unreachable); } - let old_type = id_defs.get_typed(arg_desc.op)?; + let (old_type, _) = id_defs.get_typed(arg_desc.op)?; let old_type_clone = old_type.clone(); - let converting_id = id_defs.new_id(old_type); - temporaries.insert(converting_id); + let converting_id = id_defs.new_non_variable(Some(old_type)); result.push(Statement::Conversion(ImplicitConversion { src: *new_id, dst: converting_id, from: ast::Type::Pointer( - ast::PointerType::Scalar(ast::ScalarType::U8), - ast::LdStateSpace::Global, + ast::PointerType::Pointer( + ast::ScalarType::U8, + ast::LdStateSpace::Global, + ), + ast::LdStateSpace::Param, ), to: old_type_clone, - kind: ConversionKind::PtrToBit(ast::UIntType::U64), + kind: ConversionKind::PtrToPtr { spirv_ptr: false }, + src_sema: arg_desc.sema, + dst_sema: ArgumentSemantics::Default, })); converting_id } @@ -4469,33 +4477,6 @@ fn convert_to_stateful_memory_access<'a>( Statement::Call(call) => todo!(), _ => return Err(TranslateError::Unreachable), } - /* - match statement { - statement. - /* - Statement::Instruction(inst) => result.push( - inst.visit_variable_extended( - &mut |id_desc: ArgumentDescriptor, typ| { - if let Some(new_id) = remapped_ids.get(&id_desc.op) { - if id_desc.is_dst { - panic!() - } else { - result.push(Statement::Conversion(ImplicitConversion { - src - })); - Ok(*new_id) - } - } else { - Ok(id_desc.op) - } - }, - ) - .unwrap(), - ), - s => result.push(s),\ - */ - } - */ } for arg in func_args.input.iter_mut() { if func_args_ptr.contains(&arg.name) { @@ -4505,14 +4486,14 @@ fn convert_to_stateful_memory_access<'a>( ); } } - Ok((result, temporaries)) + Ok(result) } -fn is_64_bit_integer(id_defs: &MutableNumericIdResolver, id: spirv::Word) -> bool { +fn is_64_bit_integer(id_defs: &NumericIdResolver, id: spirv::Word) -> bool { match id_defs.get_typed(id) { - Ok(ast::Type::Scalar(ast::ScalarType::U64)) - | Ok(ast::Type::Scalar(ast::ScalarType::S64)) - | Ok(ast::Type::Scalar(ast::ScalarType::B64)) => true, + Ok((ast::Type::Scalar(ast::ScalarType::U64), _)) + | Ok((ast::Type::Scalar(ast::ScalarType::S64), _)) + | Ok((ast::Type::Scalar(ast::ScalarType::B64), _)) => true, _ => false, } } @@ -5572,9 +5553,9 @@ impl VisitVariable for ast::Instruction { fn visit_variable< 'a, F: FnMut( - ArgumentDescriptor, + ArgumentDescriptor, Option<&ast::Type>, - ) -> Result, + ) -> Result, >( self, f: &mut F, @@ -5583,30 +5564,28 @@ impl VisitVariable for ast::Instruction { } } -impl VisitVariable for ImplicitConversion { - fn visit_variable< - 'a, - F: FnMut( - ArgumentDescriptor, - Option<&ast::Type>, - ) -> Result, +impl ImplicitConversion { + fn map< + T: ArgParamsEx, + U: ArgParamsEx, + V: ArgumentMapVisitor, >( self, - f: &mut F, - ) -> Result { - let new_src = f( + visitor: &mut V, + ) -> Result, U>, TranslateError> { + let new_dst = visitor.id( ArgumentDescriptor { - op: self.src, - is_dst: false, - sema: ArgumentSemantics::Default, + op: self.dst, + is_dst: true, + sema: self.dst_sema, }, - Some(&self.from), + Some(&self.to), )?; - let new_dst = f( + let new_src = visitor.id( ArgumentDescriptor { - op: self.dst, - is_dst: true, - sema: ArgumentSemantics::Default, + op: self.src, + is_dst: false, + sema: self.src_sema, }, Some(&self.from), )?; @@ -5620,6 +5599,35 @@ impl VisitVariable for ImplicitConversion { } } +impl VisitVariable for ImplicitConversion { + fn visit_variable< + 'a, + F: FnMut( + ArgumentDescriptor, + Option<&ast::Type>, + ) -> Result, + >( + self, + f: &mut F, + ) -> Result { + self.map(f) + } +} + +impl VisitVariableExpanded for ImplicitConversion { + fn visit_variable_extended< + F: FnMut( + ArgumentDescriptor, + Option<&ast::Type>, + ) -> Result, + >( + self, + f: &mut F, + ) -> Result { + self.map(f) + } +} + impl ArgumentMapVisitor for T where T: FnMut( @@ -6051,6 +6059,8 @@ struct ImplicitConversion { from: ast::Type, to: ast::Type, kind: ConversionKind, + src_sema: ArgumentSemantics, + dst_sema: ArgumentSemantics, } #[derive(PartialEq, Copy, Clone)] -- cgit v1.2.3