aboutsummaryrefslogtreecommitdiffhomepage
path: root/ptx
diff options
context:
space:
mode:
authorAndrzej Janik <[email protected]>2020-09-20 15:44:52 +0200
committerAndrzej Janik <[email protected]>2020-09-20 15:44:52 +0200
commitdcaea507ba84e375d57c4ed051477439423ae8ef (patch)
tree01f06d36bd3fb0fe61516824d61dff3556b3af5f /ptx
parent17f2d09cc74a8744e96519ae04afbde7dde65705 (diff)
downloadZLUDA-dcaea507ba84e375d57c4ed051477439423ae8ef.tar.gz
ZLUDA-dcaea507ba84e375d57c4ed051477439423ae8ef.zip
Add more tests
Diffstat (limited to 'ptx')
-rw-r--r--ptx/src/test/spirv_run/local_align.spvtxt8
-rw-r--r--ptx/src/test/spirv_run/mod.rs11
-rw-r--r--ptx/src/test/spirv_run/mov_address.ptx15
-rw-r--r--ptx/src/test/spirv_run/mov_address.spvtxt (renamed from ptx/src/test/spirv_run/reg_slm.spvtxt)0
-rw-r--r--ptx/src/test/spirv_run/reg_local.ptx (renamed from ptx/src/test/spirv_run/reg_slm.ptx)10
-rw-r--r--ptx/src/test/spirv_run/reg_local.spvtxt46
-rw-r--r--ptx/src/translate.rs224
7 files changed, 175 insertions, 139 deletions
diff --git a/ptx/src/test/spirv_run/local_align.spvtxt b/ptx/src/test/spirv_run/local_align.spvtxt
index 09a3f92..2482a75 100644
--- a/ptx/src/test/spirv_run/local_align.spvtxt
+++ b/ptx/src/test/spirv_run/local_align.spvtxt
@@ -13,8 +13,10 @@
%25 = OpTypeFunction %void %ulong %ulong
%_ptr_Function_ulong = OpTypePointer Function %ulong
%uchar = OpTypeInt 8 0
-%_arr_uchar_8 = OpTypeArray %uchar %8
-%_ptr_Function__arr_uchar_8 = OpTypePointer Function %_arr_uchar_8
+ %uint = OpTypeInt 32 0
+ %uint_8 = OpConstant %uint 8
+%_arr_uchar_uint_8 = OpTypeArray %uchar %uint_8
+%_ptr_Function__arr_uchar_uint_8 = OpTypePointer Function %_arr_uchar_uint_8
%_ptr_Generic_ulong = OpTypePointer Generic %ulong
%1 = OpFunction %void None %25
%8 = OpFunctionParameter %ulong
@@ -22,7 +24,7 @@
%20 = OpLabel
%2 = OpVariable %_ptr_Function_ulong Function
%3 = OpVariable %_ptr_Function_ulong Function
- %4 = OpVariable %_ptr_Function__arr_uchar_8 Workgroup
+ %4 = OpVariable %_ptr_Function__arr_uchar_uint_8 Function
%5 = OpVariable %_ptr_Function_ulong Function
%6 = OpVariable %_ptr_Function_ulong Function
%7 = OpVariable %_ptr_Function_ulong Function
diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs
index 4a793c4..4c9d779 100644
--- a/ptx/src/test/spirv_run/mod.rs
+++ b/ptx/src/test/spirv_run/mod.rs
@@ -8,7 +8,7 @@ use spirv_headers::Word;
use spirv_tools_sys::{
spv_binary, spv_endianness_t, spv_parsed_instruction_t, spv_result_t, spv_target_env,
};
-use std::collections::hash_map::Entry;
+use std::{collections::hash_map::Entry, cmp};
use std::error;
use std::ffi::{c_void, CStr, CString};
use std::fmt;
@@ -59,8 +59,9 @@ test_ptx!(local_align, [1u64], [1u64]);
test_ptx!(call, [1u64], [2u64]);
test_ptx!(vector, [1u32, 2u32], [3u32, 3u32]);
test_ptx!(ld_st_offset, [1u32, 2u32], [2u32, 1u32]);
-//test_ptx!(ntid, [3u32], [4u32]);
-//test_ptx!(reg_slm, [12u64], [12u64]);
+test_ptx!(ntid, [3u32], [4u32]);
+test_ptx!(reg_local, [12u64], [12u64]);
+test_ptx!(mov_address, [0xDEADu64], [0u64]);
struct DisplayError<T: Debug> {
err: T,
@@ -123,8 +124,8 @@ fn run_spirv<T: From<u8> + ze::SafeRepr + Copy + Debug>(
kernel.set_indirect_access(
ze::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_DEVICE,
)?;
- let mut inp_b = ze::DeviceBuffer::<T>::new(&mut ctx, &dev, input.len())?;
- let mut out_b = ze::DeviceBuffer::<T>::new(&mut ctx, &dev, output.len())?;
+ let mut inp_b = ze::DeviceBuffer::<T>::new(&mut ctx, &dev, cmp::max(input.len(),1))?;
+ let mut out_b = ze::DeviceBuffer::<T>::new(&mut ctx, &dev, cmp::max(output.len(), 1))?;
let inp_b_ptr_mut: ze::BufferPtrMut<T> = (&mut inp_b).into();
let event_pool = ze::EventPool::new(&mut ctx, 3, Some(&[&dev]))?;
let ev0 = ze::Event::new(&event_pool, 0)?;
diff --git a/ptx/src/test/spirv_run/mov_address.ptx b/ptx/src/test/spirv_run/mov_address.ptx
new file mode 100644
index 0000000..433fc0e
--- /dev/null
+++ b/ptx/src/test/spirv_run/mov_address.ptx
@@ -0,0 +1,15 @@
+.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .entry mov_address(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .local .b8 __local_depot0[8];
+ .reg .u64 temp;
+
+ mov.u64 temp, __local_depot0;
+ ret;
+}
diff --git a/ptx/src/test/spirv_run/reg_slm.spvtxt b/ptx/src/test/spirv_run/mov_address.spvtxt
index 6810fec..6810fec 100644
--- a/ptx/src/test/spirv_run/reg_slm.spvtxt
+++ b/ptx/src/test/spirv_run/mov_address.spvtxt
diff --git a/ptx/src/test/spirv_run/reg_slm.ptx b/ptx/src/test/spirv_run/reg_local.ptx
index 929d116..fb234d8 100644
--- a/ptx/src/test/spirv_run/reg_slm.ptx
+++ b/ptx/src/test/spirv_run/reg_local.ptx
@@ -2,12 +2,12 @@
.target sm_30
.address_size 64
-.visible .entry reg_slm(
+.visible .entry reg_local(
.param .u64 input,
.param .u64 output
)
{
- .local .align 8 .b8 slm[8];
+ .local .align 8 .b8 local_x[8];
.reg .u64 in_addr;
.reg .u64 out_addr;
.reg .b64 temp;
@@ -16,11 +16,9 @@
ld.param.u64 in_addr, [input];
ld.param.u64 out_addr, [output];
- mov.s64 unused, slm;
-
ld.global.u64 temp, [in_addr];
- st.u64 [slm], temp;
- ld.u64 temp, [slm];
+ st.u64 [local_x], temp;
+ ld.u64 temp, [local_x];
st.global.u64 [out_addr], temp;
ret;
} \ No newline at end of file
diff --git a/ptx/src/test/spirv_run/reg_local.spvtxt b/ptx/src/test/spirv_run/reg_local.spvtxt
new file mode 100644
index 0000000..6810fec
--- /dev/null
+++ b/ptx/src/test/spirv_run/reg_local.spvtxt
@@ -0,0 +1,46 @@
+ OpCapability GenericPointer
+ OpCapability Linkage
+ OpCapability Addresses
+ OpCapability Kernel
+ OpCapability Int64
+ OpCapability Int8
+ %25 = OpExtInstImport "OpenCL.std"
+ OpMemoryModel Physical64 OpenCL
+ OpEntryPoint Kernel %1 "add"
+ %void = OpTypeVoid
+ %ulong = OpTypeInt 64 0
+ %28 = OpTypeFunction %void %ulong %ulong
+%_ptr_Function_ulong = OpTypePointer Function %ulong
+%_ptr_Generic_ulong = OpTypePointer Generic %ulong
+ %ulong_1 = OpConstant %ulong 1
+ %1 = OpFunction %void None %28
+ %8 = OpFunctionParameter %ulong
+ %9 = OpFunctionParameter %ulong
+ %23 = OpLabel
+ %2 = OpVariable %_ptr_Function_ulong Function
+ %3 = OpVariable %_ptr_Function_ulong Function
+ %4 = OpVariable %_ptr_Function_ulong Function
+ %5 = OpVariable %_ptr_Function_ulong Function
+ %6 = OpVariable %_ptr_Function_ulong Function
+ %7 = OpVariable %_ptr_Function_ulong Function
+ OpStore %2 %8
+ OpStore %3 %9
+ %11 = OpLoad %ulong %2
+ %10 = OpCopyObject %ulong %11
+ OpStore %4 %10
+ %13 = OpLoad %ulong %3
+ %12 = OpCopyObject %ulong %13
+ OpStore %5 %12
+ %15 = OpLoad %ulong %4
+ %21 = OpConvertUToPtr %_ptr_Generic_ulong %15
+ %14 = OpLoad %ulong %21
+ OpStore %6 %14
+ %17 = OpLoad %ulong %6
+ %16 = OpIAdd %ulong %17 %ulong_1
+ OpStore %7 %16
+ %18 = OpLoad %ulong %5
+ %19 = OpLoad %ulong %7
+ %22 = OpConvertUToPtr %_ptr_Generic_ulong %18
+ OpStore %22 %19
+ OpReturn
+ OpFunctionEnd
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index 3f71286..3e6b495 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -217,11 +217,13 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result<dr::Module, Translate
let opencl_id = emit_opencl_import(&mut builder);
emit_memory_model(&mut builder);
let mut map = TypeWordMap::new(&mut builder);
+ emit_builtins(&mut builder, &mut map, &id_defs);
for f in ssa_functions {
let f_body = match f.body {
Some(f) => f,
None => continue,
};
+ emit_function_body_ops(&mut builder, &mut map, opencl_id, &f.globals)?;
emit_function_header(&mut builder, &mut map, &id_defs, f.func_directive)?;
emit_function_body_ops(&mut builder, &mut map, opencl_id, &f_body)?;
builder.end_function()?;
@@ -229,6 +231,33 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result<dr::Module, Translate
Ok(builder.module())
}
+fn emit_builtins(
+ builder: &mut dr::Builder,
+ map: &mut TypeWordMap,
+ id_defs: &GlobalStringIdResolver,
+) {
+ for (reg, id) in id_defs.special_registers.iter() {
+ let result_type = map.get_or_add(
+ builder,
+ SpirvType::Pointer(
+ Box::new(SpirvType::from(reg.get_type())),
+ spirv::StorageClass::UniformConstant,
+ ),
+ );
+ builder.variable(
+ result_type,
+ Some(*id),
+ spirv::StorageClass::UniformConstant,
+ None,
+ );
+ builder.decorate(
+ *id,
+ spirv::Decoration::BuiltIn,
+ &[dr::Operand::BuiltIn(reg.get_builtin())],
+ );
+ }
+}
+
fn emit_function_header<'a>(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
@@ -239,7 +268,12 @@ fn emit_function_header<'a>(
let fn_id = match func_directive {
ast::MethodDecl::Kernel(name, _) => {
let fn_id = global.get_id(name)?;
- builder.entry_point(spirv::ExecutionModel::Kernel, fn_id, name, &[]);
+ let interface = global
+ .special_registers
+ .iter()
+ .map(|(_, id)| *id)
+ .collect::<Vec<_>>();
+ builder.entry_point(spirv::ExecutionModel::Kernel, fn_id, name, interface);
fn_id
}
ast::MethodDecl::Func(_, name, _) => name,
@@ -293,7 +327,7 @@ fn emit_memory_model(builder: &mut dr::Builder) {
fn to_ssa_function<'a>(
id_defs: &mut GlobalStringIdResolver<'a>,
f: ast::ParsedFunction<'a>,
-) -> Result<ExpandedFunction<'a>, TranslateError> {
+) -> Result<Function<'a>, TranslateError> {
let (str_resolver, fn_resolver, fn_decl) = id_defs.start_fn(&f.func_directive);
to_ssa(str_resolver, fn_resolver, fn_decl, f.body)
}
@@ -333,13 +367,14 @@ fn to_ssa<'input, 'b>(
fn_defs: GlobalFnDeclResolver<'input, 'b>,
f_args: ast::MethodDecl<'input, ExpandedArgParams>,
f_body: Option<Vec<ast::Statement<ast::ParsedArgParams<'input>>>>,
-) -> Result<ExpandedFunction<'input>, TranslateError> {
+) -> Result<Function<'input>, TranslateError> {
let f_body = match f_body {
Some(vec) => vec,
None => {
- return Ok(ExpandedFunction {
+ return Ok(Function {
func_directive: f_args,
body: None,
+ globals: Vec::new(),
})
}
};
@@ -357,12 +392,21 @@ fn to_ssa<'input, 'b>(
let mut numeric_id_defs = numeric_id_defs.unmut();
let labeled_statements = normalize_labels(expanded_statements, &mut numeric_id_defs);
let sorted_statements = normalize_variable_decls(labeled_statements);
- Ok(ExpandedFunction {
+ let (f_body, globals) = extract_globals(sorted_statements);
+ Ok(Function {
func_directive: f_args,
- body: Some(sorted_statements),
+ globals: globals,
+ body: Some(f_body),
})
}
+fn extract_globals(
+ sorted_statements: Vec<ExpandedStatement>,
+) -> (Vec<ExpandedStatement>, Vec<ExpandedStatement>) {
+ // This fn will be used for SLM
+ (sorted_statements, Vec::new())
+}
+
fn normalize_variable_decls(mut func: Vec<ExpandedStatement>) -> Vec<ExpandedStatement> {
func[1..].sort_by_key(|s| match s {
Statement::Variable(_) => 0,
@@ -477,7 +521,9 @@ fn add_types_to_statements(
},
_ => dets,
};
- Ok(Statement::Instruction(ast::Instruction::MovVector(new_dets, args)))
+ Ok(Statement::Instruction(ast::Instruction::MovVector(
+ new_dets, args,
+ )))
}
s => Ok(s),
}
@@ -724,7 +770,7 @@ fn insert_mem_ssa_statement_default<'a, F: VisitVariable>(
(_, ArgumentSemantics::Address) => return Ok(desc.op),
(t, ArgumentSemantics::RegisterPointer)
| (t, ArgumentSemantics::Default)
- | (t, ArgumentSemantics::Ptr) => t,
+ | (t, ArgumentSemantics::PhysicalPointer) => t,
};
let generated_id = id_def.new_id(id_type);
if !desc.is_dst {
@@ -873,7 +919,7 @@ impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, ExpandedArgParams>
));
Ok(result_id)
}
- ArgumentSemantics::Ptr => {
+ ArgumentSemantics::PhysicalPointer => {
let scalar_t = ast::ScalarType::U64;
let id_constant_stmt = self.id_def.new_id(ast::Type::Scalar(scalar_t));
let result_id = self.id_def.new_id(typ);
@@ -1137,7 +1183,7 @@ fn emit_function_body_ops(
builder.begin_block(Some(*id))?;
}
_ => {
- if builder.block.is_none() {
+ if builder.block.is_none() && builder.function.is_some() {
builder.begin_block(None)?;
}
}
@@ -1166,10 +1212,9 @@ fn emit_function_body_ops(
name,
}) => {
let st_class = match v_type {
- ast::VariableType::Reg(_) | ast::VariableType::Param(_) => {
- spirv::StorageClass::Function
- }
- ast::VariableType::Local(_) => spirv::StorageClass::Workgroup,
+ ast::VariableType::Reg(_)
+ | ast::VariableType::Param(_)
+ | ast::VariableType::Local(_) => spirv::StorageClass::Function,
};
let type_id = map.get_or_add(
builder,
@@ -1234,7 +1279,7 @@ fn emit_function_body_ops(
ast::LdStateSpace::Generic | ast::LdStateSpace::Global => {
builder.load(result_type, Some(arg.dst), arg.src, None, [])?;
}
- ast::LdStateSpace::Param => {
+ ast::LdStateSpace::Param | ast::LdStateSpace::Local => {
let result_type = map.get_or_add(builder, SpirvType::from(data.typ));
builder.copy_object(result_type, Some(arg.dst), arg.src)?;
}
@@ -1242,18 +1287,20 @@ fn emit_function_body_ops(
}
}
ast::Instruction::St(data, arg) => {
- if data.qualifier != ast::LdStQualifier::Weak
- || (data.state_space != ast::StStateSpace::Generic
- && data.state_space != ast::StStateSpace::Param
- && data.state_space != ast::StStateSpace::Global)
- {
+ if data.qualifier != ast::LdStQualifier::Weak {
todo!()
}
- if data.state_space == ast::StStateSpace::Param {
+ if data.state_space == ast::StStateSpace::Param
+ || data.state_space == ast::StStateSpace::Local
+ {
let result_type = map.get_or_add(builder, SpirvType::from(data.typ));
builder.copy_object(result_type, Some(arg.src1), arg.src2)?;
- } else {
+ } else if data.state_space == ast::StStateSpace::Generic
+ || data.state_space == ast::StStateSpace::Global
+ {
builder.store(arg.src1, arg.src2, None, &[])?;
+ } else {
+ todo!()
}
}
// SPIR-V does not support ret as guaranteed-converged
@@ -1643,7 +1690,7 @@ fn emit_implicit_conversion(
let from_parts = cv.from.to_parts();
let to_parts = cv.to.to_parts();
match (from_parts.kind, to_parts.kind, cv.kind) {
- (_, _, ConversionKind::Ptr(space)) => {
+ (_, _, ConversionKind::BitToPtr(space)) => {
let dst_type = map.get_or_add(
builder,
SpirvType::Pointer(Box::new(SpirvType::from(cv.to)), space.to_spirv()),
@@ -1699,14 +1746,11 @@ fn emit_implicit_conversion(
}
(TypeKind::Scalar, TypeKind::Scalar, ConversionKind::SignExtend) => todo!(),
(TypeKind::Vector, TypeKind::Scalar, ConversionKind::Default)
- | (TypeKind::Scalar, TypeKind::Vector, ConversionKind::Default) => {
+ | (TypeKind::Scalar, TypeKind::Vector, ConversionKind::Default)
+ | (TypeKind::Array, TypeKind::Scalar, ConversionKind::Default) => {
let into_type = map.get_or_add(builder, SpirvType::from(cv.to));
builder.bitcast(into_type, Some(cv.dst), cv.src)?;
}
- (TypeKind::Array, TypeKind::Scalar, ConversionKind::Default) => {
- let into_type = map.get_or_add(builder, SpirvType::from(cv.to));
- builder.convert_ptr_to_u(into_type, Some(cv.dst), cv.src)?;
- }
_ => unreachable!(),
}
Ok(())
@@ -2181,7 +2225,7 @@ impl<'input> ArgParamsEx for ast::ParsedArgParams<'input> {
decl.get_fn_decl_str(id)
}
- fn get_src_semantics(m: &Self::MovOperand) -> ArgumentSemantics {
+ fn get_src_semantics(_: &Self::MovOperand) -> ArgumentSemantics {
ArgumentSemantics::Default
}
}
@@ -2230,7 +2274,12 @@ pub enum StateSpace {
enum ExpandedArgParams {}
type ExpandedStatement = Statement<ast::Instruction<ExpandedArgParams>, ExpandedArgParams>;
-type ExpandedFunction<'a> = ast::Function<'a, ExpandedArgParams, ExpandedStatement>;
+
+struct Function<'input> {
+ pub func_directive: ast::MethodDecl<'input, ExpandedArgParams>,
+ pub globals: Vec<ExpandedStatement>,
+ pub body: Option<Vec<ExpandedStatement>>,
+}
impl ast::ArgParams for ExpandedArgParams {
type ID = spirv::Word;
@@ -2248,7 +2297,7 @@ impl ArgParamsEx for ExpandedArgParams {
decl.get_fn_decl(*id)
}
- fn get_src_semantics(m: &Self::MovOperand) -> ArgumentSemantics {
+ fn get_src_semantics(_: &spirv::Word) -> ArgumentSemantics {
ArgumentSemantics::Default
}
}
@@ -2398,12 +2447,12 @@ struct ArgumentDescriptor<Op> {
sema: ArgumentSemantics,
}
-#[derive(Copy, Clone, PartialEq, Eq)]
+#[derive(Copy, Clone, PartialEq, Eq, Debug)]
pub enum ArgumentSemantics {
// normal register access
Default,
// st/ld global
- Ptr,
+ PhysicalPointer,
// st/ld .param, .local
RegisterPointer,
// mov of .local/.global variables
@@ -2720,7 +2769,8 @@ enum ConversionKind {
Default,
// zero-extend/chop/bitcast depending on types
SignExtend,
- Ptr(ast::LdStateSpace),
+ BitToPtr(ast::LdStateSpace),
+ PtrToBit,
}
impl<T> ast::PredAt<T> {
@@ -2831,7 +2881,7 @@ impl<T: ArgParamsEx> ast::Arg2<T> {
sema: if is_param {
ArgumentSemantics::RegisterPointer
} else {
- ArgumentSemantics::Ptr
+ ArgumentSemantics::PhysicalPointer
},
},
t,
@@ -2919,7 +2969,7 @@ impl<T: ArgParamsEx> ast::Arg2St<T> {
sema: if is_param {
ArgumentSemantics::RegisterPointer
} else {
- ArgumentSemantics::Ptr
+ ArgumentSemantics::PhysicalPointer
},
},
t,
@@ -3518,7 +3568,7 @@ fn get_implicit_conversions_ld_src(
) -> Result<Vec<ImplicitConversion>, TranslateError> {
let src_type = id_def.get_typed(src)?;
match state_space {
- ast::LdStateSpace::Param => {
+ ast::LdStateSpace::Param | ast::LdStateSpace::Local => {
if src_type != instr_type {
Ok(vec![
ImplicitConversion {
@@ -3560,7 +3610,7 @@ fn get_implicit_conversions_ld_src(
dst: u32::max_value(),
from: src_type,
to: instr_type,
- kind: ConversionKind::Ptr(state_space),
+ kind: ConversionKind::BitToPtr(state_space),
});
if result.len() == 2 {
let new_id = id_def.new_id(new_src_type);
@@ -3570,92 +3620,9 @@ fn get_implicit_conversions_ld_src(
}
Ok(result)
}
- _ => todo!(),
+ _ => Err(TranslateError::Todo),
}
}
-fn insert_implicit_conversions_ld_src(
- func: &mut Vec<ExpandedStatement>,
- instr_type: ast::Type,
- id_def: &mut MutableNumericIdResolver,
- state_space: ast::LdStateSpace,
- src: spirv::Word,
-) -> Result<spirv::Word, TranslateError> {
- match state_space {
- ast::LdStateSpace::Param => insert_implicit_conversions_ld_src_impl(
- func,
- id_def,
- instr_type,
- src,
- should_convert_ld_param_src,
- ),
- ast::LdStateSpace::Generic | ast::LdStateSpace::Global => {
- let new_src_type = ast::Type::Scalar(ast::ScalarType::from_parts(
- mem::size_of::<usize>() as u8,
- ScalarKind::Bit,
- ));
- let new_src = insert_implicit_conversions_ld_src_impl(
- func,
- id_def,
- new_src_type,
- src,
- should_convert_ld_generic_src_to_bitcast,
- )?;
- Ok(insert_conversion_src(
- func,
- id_def,
- new_src,
- new_src_type,
- instr_type,
- ConversionKind::Ptr(state_space),
- ))
- }
- _ => todo!(),
- }
-}
-
-fn insert_implicit_conversions_ld_src_impl<
- ShouldConvert: FnOnce(ast::Type, ast::Type) -> Option<ConversionKind>,
->(
- func: &mut Vec<ExpandedStatement>,
- id_def: &mut MutableNumericIdResolver,
- instr_type: ast::Type,
- src: spirv::Word,
- should_convert: ShouldConvert,
-) -> Result<spirv::Word, TranslateError> {
- let src_type = id_def.get_typed(src)?;
- if let Some(conv) = should_convert(src_type, instr_type) {
- Ok(insert_conversion_src(
- func, id_def, src, src_type, instr_type, conv,
- ))
- } else {
- Ok(src)
- }
-}
-
-fn should_convert_ld_param_src(
- src_type: ast::Type,
- instr_type: ast::Type,
-) -> Option<ConversionKind> {
- if src_type != instr_type {
- return Some(ConversionKind::Default);
- }
- None
-}
-
-// HACK ALERT
-// IGC currently segfaults if you bitcast integer -> ptr, that's why we emit an
-// additional S64/U64 -> B64 conversion here, so the SPIR-V emission is easier
-fn should_convert_ld_generic_src_to_bitcast(
- src_type: ast::Type,
- _instr_type: ast::Type,
-) -> Option<ConversionKind> {
- if let ast::Type::Scalar(src_type) = src_type {
- if src_type.kind() == ScalarKind::Signed {
- return Some(ConversionKind::Default);
- }
- }
- None
-}
#[must_use]
fn insert_conversion_src(
@@ -3832,14 +3799,21 @@ fn insert_implicit_bitcasts(
None => return Ok(desc.op),
};
let id_actual_type = id_def.get_typed(desc.op)?;
- if should_bitcast(id_type_from_instr, id_def.get_typed(desc.op)?) {
+ let conv_kind = if desc.sema == ArgumentSemantics::Address {
+ Some(ConversionKind::PtrToBit)
+ } else if should_bitcast(id_type_from_instr, id_def.get_typed(desc.op)?) {
+ Some(ConversionKind::Default)
+ } else {
+ None
+ };
+ if let Some(conv_kind) = conv_kind {
if desc.is_dst {
dst_coercion = Some(get_conversion_dst(
id_def,
&mut desc.op,
id_type_from_instr,
id_actual_type,
- ConversionKind::Default,
+ conv_kind,
));
Ok(desc.op)
} else {
@@ -3849,7 +3823,7 @@ fn insert_implicit_bitcasts(
desc.op,
id_actual_type,
id_type_from_instr,
- ConversionKind::Default,
+ conv_kind,
))
}
} else {