From dcaea507ba84e375d57c4ed051477439423ae8ef Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Sun, 20 Sep 2020 15:44:52 +0200 Subject: Add more tests --- ptx/src/test/spirv_run/local_align.spvtxt | 8 +- ptx/src/test/spirv_run/mod.rs | 11 +- ptx/src/test/spirv_run/mov_address.ptx | 15 ++ ptx/src/test/spirv_run/mov_address.spvtxt | 46 ++++++ ptx/src/test/spirv_run/reg_local.ptx | 24 ++++ ptx/src/test/spirv_run/reg_local.spvtxt | 46 ++++++ ptx/src/test/spirv_run/reg_slm.ptx | 26 ---- ptx/src/test/spirv_run/reg_slm.spvtxt | 46 ------ ptx/src/translate.rs | 224 +++++++++++++----------------- 9 files changed, 241 insertions(+), 205 deletions(-) create mode 100644 ptx/src/test/spirv_run/mov_address.ptx create mode 100644 ptx/src/test/spirv_run/mov_address.spvtxt create mode 100644 ptx/src/test/spirv_run/reg_local.ptx create mode 100644 ptx/src/test/spirv_run/reg_local.spvtxt delete mode 100644 ptx/src/test/spirv_run/reg_slm.ptx delete mode 100644 ptx/src/test/spirv_run/reg_slm.spvtxt diff --git a/ptx/src/test/spirv_run/local_align.spvtxt b/ptx/src/test/spirv_run/local_align.spvtxt index 09a3f92..2482a75 100644 --- a/ptx/src/test/spirv_run/local_align.spvtxt +++ b/ptx/src/test/spirv_run/local_align.spvtxt @@ -13,8 +13,10 @@ %25 = OpTypeFunction %void %ulong %ulong %_ptr_Function_ulong = OpTypePointer Function %ulong %uchar = OpTypeInt 8 0 -%_arr_uchar_8 = OpTypeArray %uchar %8 -%_ptr_Function__arr_uchar_8 = OpTypePointer Function %_arr_uchar_8 + %uint = OpTypeInt 32 0 + %uint_8 = OpConstant %uint 8 +%_arr_uchar_uint_8 = OpTypeArray %uchar %uint_8 +%_ptr_Function__arr_uchar_uint_8 = OpTypePointer Function %_arr_uchar_uint_8 %_ptr_Generic_ulong = OpTypePointer Generic %ulong %1 = OpFunction %void None %25 %8 = OpFunctionParameter %ulong @@ -22,7 +24,7 @@ %20 = OpLabel %2 = OpVariable %_ptr_Function_ulong Function %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function__arr_uchar_8 Workgroup + %4 = OpVariable %_ptr_Function__arr_uchar_uint_8 Function %5 = OpVariable %_ptr_Function_ulong Function %6 = OpVariable %_ptr_Function_ulong Function %7 = OpVariable %_ptr_Function_ulong Function diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index 4a793c4..4c9d779 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -8,7 +8,7 @@ use spirv_headers::Word; use spirv_tools_sys::{ spv_binary, spv_endianness_t, spv_parsed_instruction_t, spv_result_t, spv_target_env, }; -use std::collections::hash_map::Entry; +use std::{collections::hash_map::Entry, cmp}; use std::error; use std::ffi::{c_void, CStr, CString}; use std::fmt; @@ -59,8 +59,9 @@ test_ptx!(local_align, [1u64], [1u64]); test_ptx!(call, [1u64], [2u64]); test_ptx!(vector, [1u32, 2u32], [3u32, 3u32]); test_ptx!(ld_st_offset, [1u32, 2u32], [2u32, 1u32]); -//test_ptx!(ntid, [3u32], [4u32]); -//test_ptx!(reg_slm, [12u64], [12u64]); +test_ptx!(ntid, [3u32], [4u32]); +test_ptx!(reg_local, [12u64], [12u64]); +test_ptx!(mov_address, [0xDEADu64], [0u64]); struct DisplayError { err: T, @@ -123,8 +124,8 @@ fn run_spirv + ze::SafeRepr + Copy + Debug>( kernel.set_indirect_access( ze::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_DEVICE, )?; - let mut inp_b = ze::DeviceBuffer::::new(&mut ctx, &dev, input.len())?; - let mut out_b = ze::DeviceBuffer::::new(&mut ctx, &dev, output.len())?; + let mut inp_b = ze::DeviceBuffer::::new(&mut ctx, &dev, cmp::max(input.len(),1))?; + let mut out_b = ze::DeviceBuffer::::new(&mut ctx, &dev, cmp::max(output.len(), 1))?; let inp_b_ptr_mut: ze::BufferPtrMut = (&mut inp_b).into(); let event_pool = ze::EventPool::new(&mut ctx, 3, Some(&[&dev]))?; let ev0 = ze::Event::new(&event_pool, 0)?; diff --git a/ptx/src/test/spirv_run/mov_address.ptx b/ptx/src/test/spirv_run/mov_address.ptx new file mode 100644 index 0000000..433fc0e --- /dev/null +++ b/ptx/src/test/spirv_run/mov_address.ptx @@ -0,0 +1,15 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry mov_address( + .param .u64 input, + .param .u64 output +) +{ + .local .b8 __local_depot0[8]; + .reg .u64 temp; + + mov.u64 temp, __local_depot0; + ret; +} diff --git a/ptx/src/test/spirv_run/mov_address.spvtxt b/ptx/src/test/spirv_run/mov_address.spvtxt new file mode 100644 index 0000000..6810fec --- /dev/null +++ b/ptx/src/test/spirv_run/mov_address.spvtxt @@ -0,0 +1,46 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int64 + OpCapability Int8 + %25 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "add" + %void = OpTypeVoid + %ulong = OpTypeInt 64 0 + %28 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong +%_ptr_Generic_ulong = OpTypePointer Generic %ulong + %ulong_1 = OpConstant %ulong 1 + %1 = OpFunction %void None %28 + %8 = OpFunctionParameter %ulong + %9 = OpFunctionParameter %ulong + %23 = OpLabel + %2 = OpVariable %_ptr_Function_ulong Function + %3 = OpVariable %_ptr_Function_ulong Function + %4 = OpVariable %_ptr_Function_ulong Function + %5 = OpVariable %_ptr_Function_ulong Function + %6 = OpVariable %_ptr_Function_ulong Function + %7 = OpVariable %_ptr_Function_ulong Function + OpStore %2 %8 + OpStore %3 %9 + %11 = OpLoad %ulong %2 + %10 = OpCopyObject %ulong %11 + OpStore %4 %10 + %13 = OpLoad %ulong %3 + %12 = OpCopyObject %ulong %13 + OpStore %5 %12 + %15 = OpLoad %ulong %4 + %21 = OpConvertUToPtr %_ptr_Generic_ulong %15 + %14 = OpLoad %ulong %21 + OpStore %6 %14 + %17 = OpLoad %ulong %6 + %16 = OpIAdd %ulong %17 %ulong_1 + OpStore %7 %16 + %18 = OpLoad %ulong %5 + %19 = OpLoad %ulong %7 + %22 = OpConvertUToPtr %_ptr_Generic_ulong %18 + OpStore %22 %19 + OpReturn + OpFunctionEnd diff --git a/ptx/src/test/spirv_run/reg_local.ptx b/ptx/src/test/spirv_run/reg_local.ptx new file mode 100644 index 0000000..fb234d8 --- /dev/null +++ b/ptx/src/test/spirv_run/reg_local.ptx @@ -0,0 +1,24 @@ +.version 6.5 +.target sm_30 +.address_size 64 + +.visible .entry reg_local( + .param .u64 input, + .param .u64 output +) +{ + .local .align 8 .b8 local_x[8]; + .reg .u64 in_addr; + .reg .u64 out_addr; + .reg .b64 temp; + .reg .s64 unused; + + ld.param.u64 in_addr, [input]; + ld.param.u64 out_addr, [output]; + + ld.global.u64 temp, [in_addr]; + st.u64 [local_x], temp; + ld.u64 temp, [local_x]; + st.global.u64 [out_addr], temp; + ret; +} \ No newline at end of file diff --git a/ptx/src/test/spirv_run/reg_local.spvtxt b/ptx/src/test/spirv_run/reg_local.spvtxt new file mode 100644 index 0000000..6810fec --- /dev/null +++ b/ptx/src/test/spirv_run/reg_local.spvtxt @@ -0,0 +1,46 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int64 + OpCapability Int8 + %25 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "add" + %void = OpTypeVoid + %ulong = OpTypeInt 64 0 + %28 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong +%_ptr_Generic_ulong = OpTypePointer Generic %ulong + %ulong_1 = OpConstant %ulong 1 + %1 = OpFunction %void None %28 + %8 = OpFunctionParameter %ulong + %9 = OpFunctionParameter %ulong + %23 = OpLabel + %2 = OpVariable %_ptr_Function_ulong Function + %3 = OpVariable %_ptr_Function_ulong Function + %4 = OpVariable %_ptr_Function_ulong Function + %5 = OpVariable %_ptr_Function_ulong Function + %6 = OpVariable %_ptr_Function_ulong Function + %7 = OpVariable %_ptr_Function_ulong Function + OpStore %2 %8 + OpStore %3 %9 + %11 = OpLoad %ulong %2 + %10 = OpCopyObject %ulong %11 + OpStore %4 %10 + %13 = OpLoad %ulong %3 + %12 = OpCopyObject %ulong %13 + OpStore %5 %12 + %15 = OpLoad %ulong %4 + %21 = OpConvertUToPtr %_ptr_Generic_ulong %15 + %14 = OpLoad %ulong %21 + OpStore %6 %14 + %17 = OpLoad %ulong %6 + %16 = OpIAdd %ulong %17 %ulong_1 + OpStore %7 %16 + %18 = OpLoad %ulong %5 + %19 = OpLoad %ulong %7 + %22 = OpConvertUToPtr %_ptr_Generic_ulong %18 + OpStore %22 %19 + OpReturn + OpFunctionEnd diff --git a/ptx/src/test/spirv_run/reg_slm.ptx b/ptx/src/test/spirv_run/reg_slm.ptx deleted file mode 100644 index 929d116..0000000 --- a/ptx/src/test/spirv_run/reg_slm.ptx +++ /dev/null @@ -1,26 +0,0 @@ -.version 6.5 -.target sm_30 -.address_size 64 - -.visible .entry reg_slm( - .param .u64 input, - .param .u64 output -) -{ - .local .align 8 .b8 slm[8]; - .reg .u64 in_addr; - .reg .u64 out_addr; - .reg .b64 temp; - .reg .s64 unused; - - ld.param.u64 in_addr, [input]; - ld.param.u64 out_addr, [output]; - - mov.s64 unused, slm; - - ld.global.u64 temp, [in_addr]; - st.u64 [slm], temp; - ld.u64 temp, [slm]; - st.global.u64 [out_addr], temp; - ret; -} \ No newline at end of file diff --git a/ptx/src/test/spirv_run/reg_slm.spvtxt b/ptx/src/test/spirv_run/reg_slm.spvtxt deleted file mode 100644 index 6810fec..0000000 --- a/ptx/src/test/spirv_run/reg_slm.spvtxt +++ /dev/null @@ -1,46 +0,0 @@ - OpCapability GenericPointer - OpCapability Linkage - OpCapability Addresses - OpCapability Kernel - OpCapability Int64 - OpCapability Int8 - %25 = OpExtInstImport "OpenCL.std" - OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %1 "add" - %void = OpTypeVoid - %ulong = OpTypeInt 64 0 - %28 = OpTypeFunction %void %ulong %ulong -%_ptr_Function_ulong = OpTypePointer Function %ulong -%_ptr_Generic_ulong = OpTypePointer Generic %ulong - %ulong_1 = OpConstant %ulong 1 - %1 = OpFunction %void None %28 - %8 = OpFunctionParameter %ulong - %9 = OpFunctionParameter %ulong - %23 = OpLabel - %2 = OpVariable %_ptr_Function_ulong Function - %3 = OpVariable %_ptr_Function_ulong Function - %4 = OpVariable %_ptr_Function_ulong Function - %5 = OpVariable %_ptr_Function_ulong Function - %6 = OpVariable %_ptr_Function_ulong Function - %7 = OpVariable %_ptr_Function_ulong Function - OpStore %2 %8 - OpStore %3 %9 - %11 = OpLoad %ulong %2 - %10 = OpCopyObject %ulong %11 - OpStore %4 %10 - %13 = OpLoad %ulong %3 - %12 = OpCopyObject %ulong %13 - OpStore %5 %12 - %15 = OpLoad %ulong %4 - %21 = OpConvertUToPtr %_ptr_Generic_ulong %15 - %14 = OpLoad %ulong %21 - OpStore %6 %14 - %17 = OpLoad %ulong %6 - %16 = OpIAdd %ulong %17 %ulong_1 - OpStore %7 %16 - %18 = OpLoad %ulong %5 - %19 = OpLoad %ulong %7 - %22 = OpConvertUToPtr %_ptr_Generic_ulong %18 - OpStore %22 %19 - OpReturn - OpFunctionEnd diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 3f71286..3e6b495 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -217,11 +217,13 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result f, None => continue, }; + emit_function_body_ops(&mut builder, &mut map, opencl_id, &f.globals)?; emit_function_header(&mut builder, &mut map, &id_defs, f.func_directive)?; emit_function_body_ops(&mut builder, &mut map, opencl_id, &f_body)?; builder.end_function()?; @@ -229,6 +231,33 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result( builder: &mut dr::Builder, map: &mut TypeWordMap, @@ -239,7 +268,12 @@ fn emit_function_header<'a>( let fn_id = match func_directive { ast::MethodDecl::Kernel(name, _) => { let fn_id = global.get_id(name)?; - builder.entry_point(spirv::ExecutionModel::Kernel, fn_id, name, &[]); + let interface = global + .special_registers + .iter() + .map(|(_, id)| *id) + .collect::>(); + builder.entry_point(spirv::ExecutionModel::Kernel, fn_id, name, interface); fn_id } ast::MethodDecl::Func(_, name, _) => name, @@ -293,7 +327,7 @@ fn emit_memory_model(builder: &mut dr::Builder) { fn to_ssa_function<'a>( id_defs: &mut GlobalStringIdResolver<'a>, f: ast::ParsedFunction<'a>, -) -> Result, TranslateError> { +) -> Result, TranslateError> { let (str_resolver, fn_resolver, fn_decl) = id_defs.start_fn(&f.func_directive); to_ssa(str_resolver, fn_resolver, fn_decl, f.body) } @@ -333,13 +367,14 @@ fn to_ssa<'input, 'b>( fn_defs: GlobalFnDeclResolver<'input, 'b>, f_args: ast::MethodDecl<'input, ExpandedArgParams>, f_body: Option>>>, -) -> Result, TranslateError> { +) -> Result, TranslateError> { let f_body = match f_body { Some(vec) => vec, None => { - return Ok(ExpandedFunction { + return Ok(Function { func_directive: f_args, body: None, + globals: Vec::new(), }) } }; @@ -357,12 +392,21 @@ fn to_ssa<'input, 'b>( let mut numeric_id_defs = numeric_id_defs.unmut(); let labeled_statements = normalize_labels(expanded_statements, &mut numeric_id_defs); let sorted_statements = normalize_variable_decls(labeled_statements); - Ok(ExpandedFunction { + let (f_body, globals) = extract_globals(sorted_statements); + Ok(Function { func_directive: f_args, - body: Some(sorted_statements), + globals: globals, + body: Some(f_body), }) } +fn extract_globals( + sorted_statements: Vec, +) -> (Vec, Vec) { + // This fn will be used for SLM + (sorted_statements, Vec::new()) +} + fn normalize_variable_decls(mut func: Vec) -> Vec { func[1..].sort_by_key(|s| match s { Statement::Variable(_) => 0, @@ -477,7 +521,9 @@ fn add_types_to_statements( }, _ => dets, }; - Ok(Statement::Instruction(ast::Instruction::MovVector(new_dets, args))) + Ok(Statement::Instruction(ast::Instruction::MovVector( + new_dets, args, + ))) } s => Ok(s), } @@ -724,7 +770,7 @@ fn insert_mem_ssa_statement_default<'a, F: VisitVariable>( (_, ArgumentSemantics::Address) => return Ok(desc.op), (t, ArgumentSemantics::RegisterPointer) | (t, ArgumentSemantics::Default) - | (t, ArgumentSemantics::Ptr) => t, + | (t, ArgumentSemantics::PhysicalPointer) => t, }; let generated_id = id_def.new_id(id_type); if !desc.is_dst { @@ -873,7 +919,7 @@ impl<'a, 'b> ArgumentMapVisitor )); Ok(result_id) } - ArgumentSemantics::Ptr => { + ArgumentSemantics::PhysicalPointer => { let scalar_t = ast::ScalarType::U64; let id_constant_stmt = self.id_def.new_id(ast::Type::Scalar(scalar_t)); let result_id = self.id_def.new_id(typ); @@ -1137,7 +1183,7 @@ fn emit_function_body_ops( builder.begin_block(Some(*id))?; } _ => { - if builder.block.is_none() { + if builder.block.is_none() && builder.function.is_some() { builder.begin_block(None)?; } } @@ -1166,10 +1212,9 @@ fn emit_function_body_ops( name, }) => { let st_class = match v_type { - ast::VariableType::Reg(_) | ast::VariableType::Param(_) => { - spirv::StorageClass::Function - } - ast::VariableType::Local(_) => spirv::StorageClass::Workgroup, + ast::VariableType::Reg(_) + | ast::VariableType::Param(_) + | ast::VariableType::Local(_) => spirv::StorageClass::Function, }; let type_id = map.get_or_add( builder, @@ -1234,7 +1279,7 @@ fn emit_function_body_ops( ast::LdStateSpace::Generic | ast::LdStateSpace::Global => { builder.load(result_type, Some(arg.dst), arg.src, None, [])?; } - ast::LdStateSpace::Param => { + ast::LdStateSpace::Param | ast::LdStateSpace::Local => { let result_type = map.get_or_add(builder, SpirvType::from(data.typ)); builder.copy_object(result_type, Some(arg.dst), arg.src)?; } @@ -1242,18 +1287,20 @@ fn emit_function_body_ops( } } ast::Instruction::St(data, arg) => { - if data.qualifier != ast::LdStQualifier::Weak - || (data.state_space != ast::StStateSpace::Generic - && data.state_space != ast::StStateSpace::Param - && data.state_space != ast::StStateSpace::Global) - { + if data.qualifier != ast::LdStQualifier::Weak { todo!() } - if data.state_space == ast::StStateSpace::Param { + if data.state_space == ast::StStateSpace::Param + || data.state_space == ast::StStateSpace::Local + { let result_type = map.get_or_add(builder, SpirvType::from(data.typ)); builder.copy_object(result_type, Some(arg.src1), arg.src2)?; - } else { + } else if data.state_space == ast::StStateSpace::Generic + || data.state_space == ast::StStateSpace::Global + { builder.store(arg.src1, arg.src2, None, &[])?; + } else { + todo!() } } // SPIR-V does not support ret as guaranteed-converged @@ -1643,7 +1690,7 @@ fn emit_implicit_conversion( let from_parts = cv.from.to_parts(); let to_parts = cv.to.to_parts(); match (from_parts.kind, to_parts.kind, cv.kind) { - (_, _, ConversionKind::Ptr(space)) => { + (_, _, ConversionKind::BitToPtr(space)) => { let dst_type = map.get_or_add( builder, SpirvType::Pointer(Box::new(SpirvType::from(cv.to)), space.to_spirv()), @@ -1699,14 +1746,11 @@ fn emit_implicit_conversion( } (TypeKind::Scalar, TypeKind::Scalar, ConversionKind::SignExtend) => todo!(), (TypeKind::Vector, TypeKind::Scalar, ConversionKind::Default) - | (TypeKind::Scalar, TypeKind::Vector, ConversionKind::Default) => { + | (TypeKind::Scalar, TypeKind::Vector, ConversionKind::Default) + | (TypeKind::Array, TypeKind::Scalar, ConversionKind::Default) => { let into_type = map.get_or_add(builder, SpirvType::from(cv.to)); builder.bitcast(into_type, Some(cv.dst), cv.src)?; } - (TypeKind::Array, TypeKind::Scalar, ConversionKind::Default) => { - let into_type = map.get_or_add(builder, SpirvType::from(cv.to)); - builder.convert_ptr_to_u(into_type, Some(cv.dst), cv.src)?; - } _ => unreachable!(), } Ok(()) @@ -2181,7 +2225,7 @@ impl<'input> ArgParamsEx for ast::ParsedArgParams<'input> { decl.get_fn_decl_str(id) } - fn get_src_semantics(m: &Self::MovOperand) -> ArgumentSemantics { + fn get_src_semantics(_: &Self::MovOperand) -> ArgumentSemantics { ArgumentSemantics::Default } } @@ -2230,7 +2274,12 @@ pub enum StateSpace { enum ExpandedArgParams {} type ExpandedStatement = Statement, ExpandedArgParams>; -type ExpandedFunction<'a> = ast::Function<'a, ExpandedArgParams, ExpandedStatement>; + +struct Function<'input> { + pub func_directive: ast::MethodDecl<'input, ExpandedArgParams>, + pub globals: Vec, + pub body: Option>, +} impl ast::ArgParams for ExpandedArgParams { type ID = spirv::Word; @@ -2248,7 +2297,7 @@ impl ArgParamsEx for ExpandedArgParams { decl.get_fn_decl(*id) } - fn get_src_semantics(m: &Self::MovOperand) -> ArgumentSemantics { + fn get_src_semantics(_: &spirv::Word) -> ArgumentSemantics { ArgumentSemantics::Default } } @@ -2398,12 +2447,12 @@ struct ArgumentDescriptor { sema: ArgumentSemantics, } -#[derive(Copy, Clone, PartialEq, Eq)] +#[derive(Copy, Clone, PartialEq, Eq, Debug)] pub enum ArgumentSemantics { // normal register access Default, // st/ld global - Ptr, + PhysicalPointer, // st/ld .param, .local RegisterPointer, // mov of .local/.global variables @@ -2720,7 +2769,8 @@ enum ConversionKind { Default, // zero-extend/chop/bitcast depending on types SignExtend, - Ptr(ast::LdStateSpace), + BitToPtr(ast::LdStateSpace), + PtrToBit, } impl ast::PredAt { @@ -2831,7 +2881,7 @@ impl ast::Arg2 { sema: if is_param { ArgumentSemantics::RegisterPointer } else { - ArgumentSemantics::Ptr + ArgumentSemantics::PhysicalPointer }, }, t, @@ -2919,7 +2969,7 @@ impl ast::Arg2St { sema: if is_param { ArgumentSemantics::RegisterPointer } else { - ArgumentSemantics::Ptr + ArgumentSemantics::PhysicalPointer }, }, t, @@ -3518,7 +3568,7 @@ fn get_implicit_conversions_ld_src( ) -> Result, TranslateError> { let src_type = id_def.get_typed(src)?; match state_space { - ast::LdStateSpace::Param => { + ast::LdStateSpace::Param | ast::LdStateSpace::Local => { if src_type != instr_type { Ok(vec![ ImplicitConversion { @@ -3560,7 +3610,7 @@ fn get_implicit_conversions_ld_src( dst: u32::max_value(), from: src_type, to: instr_type, - kind: ConversionKind::Ptr(state_space), + kind: ConversionKind::BitToPtr(state_space), }); if result.len() == 2 { let new_id = id_def.new_id(new_src_type); @@ -3570,92 +3620,9 @@ fn get_implicit_conversions_ld_src( } Ok(result) } - _ => todo!(), + _ => Err(TranslateError::Todo), } } -fn insert_implicit_conversions_ld_src( - func: &mut Vec, - instr_type: ast::Type, - id_def: &mut MutableNumericIdResolver, - state_space: ast::LdStateSpace, - src: spirv::Word, -) -> Result { - match state_space { - ast::LdStateSpace::Param => insert_implicit_conversions_ld_src_impl( - func, - id_def, - instr_type, - src, - should_convert_ld_param_src, - ), - ast::LdStateSpace::Generic | ast::LdStateSpace::Global => { - let new_src_type = ast::Type::Scalar(ast::ScalarType::from_parts( - mem::size_of::() as u8, - ScalarKind::Bit, - )); - let new_src = insert_implicit_conversions_ld_src_impl( - func, - id_def, - new_src_type, - src, - should_convert_ld_generic_src_to_bitcast, - )?; - Ok(insert_conversion_src( - func, - id_def, - new_src, - new_src_type, - instr_type, - ConversionKind::Ptr(state_space), - )) - } - _ => todo!(), - } -} - -fn insert_implicit_conversions_ld_src_impl< - ShouldConvert: FnOnce(ast::Type, ast::Type) -> Option, ->( - func: &mut Vec, - id_def: &mut MutableNumericIdResolver, - instr_type: ast::Type, - src: spirv::Word, - should_convert: ShouldConvert, -) -> Result { - let src_type = id_def.get_typed(src)?; - if let Some(conv) = should_convert(src_type, instr_type) { - Ok(insert_conversion_src( - func, id_def, src, src_type, instr_type, conv, - )) - } else { - Ok(src) - } -} - -fn should_convert_ld_param_src( - src_type: ast::Type, - instr_type: ast::Type, -) -> Option { - if src_type != instr_type { - return Some(ConversionKind::Default); - } - None -} - -// HACK ALERT -// IGC currently segfaults if you bitcast integer -> ptr, that's why we emit an -// additional S64/U64 -> B64 conversion here, so the SPIR-V emission is easier -fn should_convert_ld_generic_src_to_bitcast( - src_type: ast::Type, - _instr_type: ast::Type, -) -> Option { - if let ast::Type::Scalar(src_type) = src_type { - if src_type.kind() == ScalarKind::Signed { - return Some(ConversionKind::Default); - } - } - None -} #[must_use] fn insert_conversion_src( @@ -3832,14 +3799,21 @@ fn insert_implicit_bitcasts( None => return Ok(desc.op), }; let id_actual_type = id_def.get_typed(desc.op)?; - if should_bitcast(id_type_from_instr, id_def.get_typed(desc.op)?) { + let conv_kind = if desc.sema == ArgumentSemantics::Address { + Some(ConversionKind::PtrToBit) + } else if should_bitcast(id_type_from_instr, id_def.get_typed(desc.op)?) { + Some(ConversionKind::Default) + } else { + None + }; + if let Some(conv_kind) = conv_kind { if desc.is_dst { dst_coercion = Some(get_conversion_dst( id_def, &mut desc.op, id_type_from_instr, id_actual_type, - ConversionKind::Default, + conv_kind, )); Ok(desc.op) } else { @@ -3849,7 +3823,7 @@ fn insert_implicit_bitcasts( desc.op, id_actual_type, id_type_from_instr, - ConversionKind::Default, + conv_kind, )) } } else { -- cgit v1.2.3