summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorAndrzej Janik <[email protected]>2020-11-17 22:05:22 +0100
committerAndrzej Janik <[email protected]>2020-11-17 22:05:22 +0100
commit3fd1ca9b53d9abb68cd08c638c24670cd79ae443 (patch)
tree5bb486343bb874eb8b2e12808aa71fecc0dbdaf0
parentf3aba1746443c4dba06ce2eb8634f16600acdea9 (diff)
parent9e820d62b7d52c1748c2f8c46c263e843e3b3855 (diff)
downloadZLUDA-3fd1ca9b53d9abb68cd08c638c24670cd79ae443.tar.gz
ZLUDA-3fd1ca9b53d9abb68cd08c638c24670cd79ae443.zip
Merge branch 'stateful_try1' into stateful_try2
-rw-r--r--ptx/src/ptx.lalrpop12
-rw-r--r--ptx/src/test/spirv_run/ld_st_stateful.ptx25
-rw-r--r--ptx/src/test/spirv_run/ld_st_stateful.spvtxt57
-rw-r--r--ptx/src/test/spirv_run/mod.rs1
-rw-r--r--ptx/src/translate.rs494
5 files changed, 507 insertions, 82 deletions
diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop
index 6c231b2..d2c235a 100644
--- a/ptx/src/ptx.lalrpop
+++ b/ptx/src/ptx.lalrpop
@@ -1237,18 +1237,18 @@ InstRet: ast::Instruction<ast::ParsedArgParams<'input>> = {
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvta
InstCvta: ast::Instruction<ast::ParsedArgParams<'input>> = {
- "cvta" <to:CvtaStateSpace> <s:CvtaSize> <a:Arg2> => {
+ "cvta" <from:CvtaStateSpace> <s:CvtaSize> <a:Arg2> => {
ast::Instruction::Cvta(ast::CvtaDetails {
- to: to,
- from: ast::CvtaStateSpace::Generic,
+ to: ast::CvtaStateSpace::Generic,
+ from,
size: s
},
a)
},
- "cvta" ".to" <from:CvtaStateSpace> <s:CvtaSize> <a:Arg2> => {
+ "cvta" ".to" <to:CvtaStateSpace> <s:CvtaSize> <a:Arg2> => {
ast::Instruction::Cvta(ast::CvtaDetails {
- to: ast::CvtaStateSpace::Generic,
- from: from,
+ to,
+ from: ast::CvtaStateSpace::Generic,
size: s
},
a)
diff --git a/ptx/src/test/spirv_run/ld_st_stateful.ptx b/ptx/src/test/spirv_run/ld_st_stateful.ptx
new file mode 100644
index 0000000..859b169
--- /dev/null
+++ b/ptx/src/test/spirv_run/ld_st_stateful.ptx
@@ -0,0 +1,25 @@
+.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .entry ld_st_stateful(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .u64 in_addr2;
+ .reg .u64 out_addr2;
+ .reg .u64 temp;
+
+ ld.param.u64 in_addr, [input];
+ ld.param.u64 out_addr, [output];
+
+ cvta.to.global.u64 in_addr2, in_addr;
+ cvta.to.global.u64 out_addr2, out_addr;
+
+ ld.global.u64 temp, [in_addr2];
+ st.global.u64 [out_addr2], temp;
+ ret;
+} \ No newline at end of file
diff --git a/ptx/src/test/spirv_run/ld_st_stateful.spvtxt b/ptx/src/test/spirv_run/ld_st_stateful.spvtxt
new file mode 100644
index 0000000..963d88a
--- /dev/null
+++ b/ptx/src/test/spirv_run/ld_st_stateful.spvtxt
@@ -0,0 +1,57 @@
+ OpCapability GenericPointer
+ OpCapability Linkage
+ OpCapability Addresses
+ OpCapability Kernel
+ OpCapability Int8
+ OpCapability Int16
+ OpCapability Int64
+ OpCapability Float16
+ OpCapability Float64
+ %30 = OpExtInstImport "OpenCL.std"
+ OpMemoryModel Physical64 OpenCL
+ OpEntryPoint Kernel %1 "ld_st_offset"
+ %void = OpTypeVoid
+ %ulong = OpTypeInt 64 0
+ %33 = OpTypeFunction %void %ulong %ulong
+%_ptr_Function_ulong = OpTypePointer Function %ulong
+ %uint = OpTypeInt 32 0
+%_ptr_Function_uint = OpTypePointer Function %uint
+%_ptr_Generic_uint = OpTypePointer Generic %uint
+ %ulong_4 = OpConstant %ulong 4
+ %ulong_4_0 = OpConstant %ulong 4
+ %1 = OpFunction %void None %33
+ %8 = OpFunctionParameter %ulong
+ %9 = OpFunctionParameter %ulong
+ %28 = OpLabel
+ %2 = OpVariable %_ptr_Function_ulong Function
+ %3 = OpVariable %_ptr_Function_ulong Function
+ %4 = OpVariable %_ptr_Function_ulong Function
+ %5 = OpVariable %_ptr_Function_ulong Function
+ %6 = OpVariable %_ptr_Function_uint Function
+ %7 = OpVariable %_ptr_Function_uint Function
+ OpStore %2 %8
+ OpStore %3 %9
+ %10 = OpLoad %ulong %2
+ OpStore %4 %10
+ %11 = OpLoad %ulong %3
+ OpStore %5 %11
+ %13 = OpLoad %ulong %4
+ %24 = OpConvertUToPtr %_ptr_Generic_uint %13
+ %12 = OpLoad %uint %24
+ OpStore %6 %12
+ %15 = OpLoad %ulong %4
+ %21 = OpIAdd %ulong %15 %ulong_4
+ %25 = OpConvertUToPtr %_ptr_Generic_uint %21
+ %14 = OpLoad %uint %25
+ OpStore %7 %14
+ %16 = OpLoad %ulong %5
+ %17 = OpLoad %uint %7
+ %26 = OpConvertUToPtr %_ptr_Generic_uint %16
+ OpStore %26 %17
+ %18 = OpLoad %ulong %5
+ %19 = OpLoad %uint %6
+ %23 = OpIAdd %ulong %18 %ulong_4_0
+ %27 = OpConvertUToPtr %_ptr_Generic_uint %23
+ OpStore %27 %19
+ OpReturn
+ OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs
index bd74508..e014231 100644
--- a/ptx/src/test/spirv_run/mod.rs
+++ b/ptx/src/test/spirv_run/mod.rs
@@ -133,6 +133,7 @@ test_ptx!(
[0b11111000_11000001_00100010_10100000u32, 16u32, 8u32],
[0b11000001u32]
);
+test_ptx!(ld_st_stateful, [1u64], [1u64]);
struct DisplayError<T: Debug> {
err: T,
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index 86a7e73..6fc35d2 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -724,7 +724,7 @@ fn convert_dynamic_shared_memory_usage<'input>(
multi_hash_map_append(&mut directly_called_by, call.func, call_key);
Statement::Call(call)
}
- statement => statement.map_id(&mut |id| {
+ statement => statement.map_id(&mut |id, _| {
if extern_shared_decls.contains_key(&id) {
methods_using_extern_shared.insert(call_key);
}
@@ -841,7 +841,7 @@ fn replace_uses_of_shared_memory<'a>(
result.push(Statement::Call(call))
}
statement => {
- let new_statement = statement.map_id(&mut |id| {
+ let new_statement = statement.map_id(&mut |id, _| {
if let Some(typ) = extern_shared_decls.get(&id) {
let replacement_id = new_id();
if *typ != ast::SizedScalarType::B8 {
@@ -1294,12 +1294,18 @@ fn to_ssa<'input, 'b>(
};
let normalized_ids = normalize_identifiers(&mut id_defs, &fn_defs, f_body)?;
let mut numeric_id_defs = id_defs.finish();
- let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs);
+ let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs)?;
let typed_statements =
convert_to_typed_statements(unadorned_statements, &fn_defs, &numeric_id_defs)?;
- let ssa_statements =
- insert_mem_ssa_statements(typed_statements, &mut numeric_id_defs, &mut spirv_decl)?;
let mut numeric_id_defs = numeric_id_defs.finish();
+ let (typed_statements, temporaries) =
+ convert_to_stateful_memory_access(&mut spirv_decl, typed_statements, &mut numeric_id_defs)?;
+ let ssa_statements = insert_mem_ssa_statements(
+ typed_statements,
+ &mut numeric_id_defs,
+ &mut spirv_decl,
+ temporaries,
+ )?;
let expanded_statements = expand_arguments(ssa_statements, &mut numeric_id_defs)?;
let expanded_statements =
insert_implicit_conversions(expanded_statements, &mut numeric_id_defs)?;
@@ -1631,17 +1637,7 @@ fn convert_to_typed_statements(
},
Statement::Label(i) => result.push(Statement::Label(i)),
Statement::Variable(v) => result.push(Statement::Variable(v)),
- Statement::LoadVar(a, t) => result.push(Statement::LoadVar(a, t)),
- Statement::StoreVar(a, t) => result.push(Statement::StoreVar(a, t)),
- Statement::Call(c) => result.push(Statement::Call(c.cast())),
- Statement::Composite(c) => result.push(Statement::Composite(c)),
- Statement::Conditional(c) => result.push(Statement::Conditional(c)),
- Statement::Conversion(c) => result.push(Statement::Conversion(c)),
- Statement::Constant(c) => result.push(Statement::Constant(c)),
- Statement::RetValue(d, id) => result.push(Statement::RetValue(d, id)),
- Statement::Undef(_, _) | Statement::PtrAdd { .. } => {
- return Err(TranslateError::Unreachable)
- }
+ _ => return Err(TranslateError::Unreachable),
}
}
Ok(result)
@@ -1888,7 +1884,7 @@ fn normalize_labels(
fn normalize_predicates(
func: Vec<NormalizedStatement>,
id_def: &mut NumericIdResolver,
-) -> Vec<UnconditionalStatement> {
+) -> Result<Vec<UnconditionalStatement>, TranslateError> {
let mut result = Vec::with_capacity(func.len());
for s in func {
match s {
@@ -1921,16 +1917,17 @@ fn normalize_predicates(
}
Statement::Variable(var) => result.push(Statement::Variable(var)),
// Blocks are flattened when resolving ids
- _ => unreachable!(),
+ _ => return Err(TranslateError::Unreachable),
}
}
- result
+ Ok(result)
}
fn insert_mem_ssa_statements<'a, 'b>(
func: Vec<TypedStatement>,
id_def: &mut NumericIdResolver,
fn_decl: &mut SpirvMethodDecl,
+ temporaries: HashSet<spirv::Word>,
) -> Result<Vec<TypedStatement>, TranslateError> {
let mut result = Vec::with_capacity(func.len());
for arg in fn_decl.output.iter() {
@@ -1972,7 +1969,7 @@ fn insert_mem_ssa_statements<'a, 'b>(
for s in func {
match s {
Statement::Call(call) => {
- insert_mem_ssa_statement_default(id_def, &mut result, call.cast())?
+ insert_mem_ssa_statement_default(id_def, &mut result, &temporaries, call.cast())?
}
Statement::Instruction(inst) => match inst {
ast::Instruction::Ret(d) => {
@@ -1992,7 +1989,7 @@ fn insert_mem_ssa_statements<'a, 'b>(
result.push(Statement::Instruction(ast::Instruction::Ret(d)))
}
}
- inst => insert_mem_ssa_statement_default(id_def, &mut result, inst)?,
+ inst => insert_mem_ssa_statement_default(id_def, &mut result, &temporaries, inst)?,
},
Statement::Conditional(mut bra) => {
let generated_id =
@@ -2007,15 +2004,11 @@ fn insert_mem_ssa_statements<'a, 'b>(
bra.predicate = generated_id;
result.push(Statement::Conditional(bra));
}
+ Statement::Conversion(conv) => {
+ insert_mem_ssa_statement_default(id_def, &mut result, &temporaries, conv)?
+ }
s @ Statement::Variable(_) | s @ Statement::Label(_) => result.push(s),
- Statement::LoadVar(_, _)
- | Statement::StoreVar(_, _)
- | Statement::Conversion(_)
- | Statement::RetValue(_, _)
- | Statement::Constant(_)
- | Statement::Undef(_, _)
- | Statement::PtrAdd { .. } => {}
- Statement::Composite(_) => todo!(),
+ _ => return Err(TranslateError::Unreachable),
}
}
Ok(result)
@@ -2036,7 +2029,19 @@ fn type_to_variable_type(t: &ast::Type) -> Result<Option<ast::VariableType>, Tra
.map_err(|_| TranslateError::MismatchedType)?,
len.clone(),
))),
- ast::Type::Pointer(_, _) => None,
+ ast::Type::Pointer(ast::PointerType::Scalar(scalar_type), space) => {
+ Some(ast::VariableType::Reg(ast::VariableRegType::Pointer(
+ scalar_type
+ .clone()
+ .try_into()
+ .map_err(|_| TranslateError::Unreachable)?,
+ (*space)
+ .try_into()
+ .map_err(|_| TranslateError::Unreachable)?,
+ )))
+ }
+ ast::Type::Pointer(_, ast::LdStateSpace::Shared) => None,
+ _ => return Err(TranslateError::Unreachable),
})
}
@@ -2089,6 +2094,7 @@ impl<'a, Ctor: FnOnce(spirv::Word) -> ExpandedStatement> VisitVariableExpanded
fn insert_mem_ssa_statement_default<'a, F: VisitVariable>(
id_def: &mut NumericIdResolver,
result: &mut Vec<TypedStatement>,
+ temporaries: &HashSet<spirv::Word>,
stmt: F,
) -> Result<(), TranslateError> {
let mut post_statements = Vec::new();
@@ -2162,7 +2168,7 @@ fn expand_arguments<'a, 'b>(
state_space,
dst,
ptr_src,
- constant_src,
+ offset_src: constant_src,
} => {
let mut visitor = FlattenArguments::new(&mut result, id_def);
let sema = match state_space {
@@ -2204,7 +2210,7 @@ fn expand_arguments<'a, 'b>(
state_space,
dst: new_dst,
ptr_src: new_ptr_src,
- constant_src: new_constant_src,
+ offset_src: new_constant_src,
})
}
Statement::Label(id) => result.push(Statement::Label(id)),
@@ -2212,10 +2218,10 @@ fn expand_arguments<'a, 'b>(
Statement::LoadVar(arg, typ) => result.push(Statement::LoadVar(arg, typ)),
Statement::StoreVar(arg, typ) => result.push(Statement::StoreVar(arg, typ)),
Statement::RetValue(d, id) => result.push(Statement::RetValue(d, id)),
- Statement::Composite(_)
- | Statement::Conversion(_)
- | Statement::Constant(_)
- | Statement::Undef(_, _) => unreachable!(),
+ Statement::Conversion(conv) => result.push(Statement::Conversion(conv)),
+ Statement::Composite(_) | Statement::Constant(_) | Statement::Undef(_, _) => {
+ return Err(TranslateError::Unreachable)
+ }
}
}
Ok(result)
@@ -2291,7 +2297,7 @@ impl<'a, 'b> FlattenArguments<'a, 'b> {
state_space: *state_space,
dst,
ptr_src: reg,
- constant_src: id_constant_stmt,
+ offset_src: id_constant_stmt,
});
return Ok(dst);
} else {
@@ -2580,7 +2586,7 @@ fn insert_implicit_conversions(
state_space,
dst,
ptr_src,
- constant_src,
+ offset_src: constant_src,
} => {
let visit_desc = VisitArgumentDescriptor {
desc: ArgumentDescriptor {
@@ -2594,7 +2600,7 @@ fn insert_implicit_conversions(
state_space,
dst,
ptr_src: new_ptr_src,
- constant_src,
+ offset_src: constant_src,
},
};
insert_implicit_conversions_impl(
@@ -2612,8 +2618,8 @@ fn insert_implicit_conversions(
| s @ Statement::LoadVar(_, _)
| s @ Statement::StoreVar(_, _)
| s @ Statement::Undef(_, _)
+ | s @ Statement::Conversion(_)
| s @ Statement::RetValue(_, _) => result.push(s),
- Statement::Conversion(_) => unreachable!(),
}
}
Ok(result)
@@ -3225,16 +3231,28 @@ fn emit_function_body_ops(
state_space,
dst,
ptr_src,
- constant_src,
+ offset_src,
} => {
- let s64_type = map.get_or_add_scalar(builder, ast::ScalarType::S64);
- let ptr_as_s64 = builder.bitcast(s64_type, None, *ptr_src)?;
- let added_ptr = builder.i_add(s64_type, None, ptr_as_s64, *constant_src)?;
+ let u8_pointer = map.get_or_add(
+ builder,
+ SpirvType::from(ast::Type::Pointer(
+ ast::PointerType::Scalar(ast::ScalarType::U8),
+ *state_space,
+ )),
+ );
let result_type = map.get_or_add(
builder,
SpirvType::from(ast::Type::Pointer(underlying_type.clone(), *state_space)),
);
- builder.bitcast(result_type, Some(*dst), added_ptr)?;
+ let ptr_src_u8 = builder.bitcast(u8_pointer, None, *ptr_src)?;
+ let temp = builder.in_bounds_ptr_access_chain(
+ u8_pointer,
+ None,
+ ptr_src_u8,
+ *offset_src,
+ &[],
+ )?;
+ builder.bitcast(result_type, Some(*dst), temp)?;
}
}
}
@@ -4219,6 +4237,286 @@ fn expand_map_variables<'a, 'b>(
Ok(())
}
+// TODO: detect more patterns (mov, call via reg, call via param)
+// TODO: don't convert to ptr if the register is not ultimately used for ld/st
+// TODO: once insert_mem_ssa_statements is moved to later, move this pass after
+// argument expansion
+fn convert_to_stateful_memory_access<'a>(
+ func_args: &mut SpirvMethodDecl,
+ func_body: Vec<TypedStatement>,
+ id_defs: &mut MutableNumericIdResolver<'a>,
+) -> Result<(Vec<TypedStatement>, HashSet<spirv::Word>), TranslateError> {
+ let func_args_64bit = func_args
+ .input
+ .iter()
+ .filter_map(|arg| match arg.v_type {
+ ast::Type::Scalar(ast::ScalarType::U64)
+ | ast::Type::Scalar(ast::ScalarType::B64)
+ | ast::Type::Scalar(ast::ScalarType::S64) => Some(arg.name),
+ _ => None,
+ })
+ .collect::<HashSet<_>>();
+ let mut stateful_markers = Vec::new();
+ let mut stateful_init_reg = MultiHashMap::new();
+ for statement in func_body.iter() {
+ match statement {
+ Statement::Instruction(ast::Instruction::Cvta(
+ ast::CvtaDetails {
+ to: ast::CvtaStateSpace::Global,
+ size: ast::CvtaSize::U64,
+ from: ast::CvtaStateSpace::Generic,
+ },
+ arg,
+ )) => {
+ if let Some(src) = arg.src.underlying() {
+ if is_64_bit_integer(id_defs, *src) && is_64_bit_integer(id_defs, arg.dst) {
+ stateful_markers.push((arg.dst, *src));
+ }
+ }
+ }
+ Statement::Instruction(ast::Instruction::Ld(
+ ast::LdDetails {
+ state_space: ast::LdStateSpace::Param,
+ typ: ast::LdStType::Scalar(ast::LdStScalarType::U64),
+ ..
+ },
+ arg,
+ ))
+ | Statement::Instruction(ast::Instruction::Ld(
+ ast::LdDetails {
+ state_space: ast::LdStateSpace::Param,
+ typ: ast::LdStType::Scalar(ast::LdStScalarType::S64),
+ ..
+ },
+ arg,
+ ))
+ | Statement::Instruction(ast::Instruction::Ld(
+ ast::LdDetails {
+ state_space: ast::LdStateSpace::Param,
+ typ: ast::LdStType::Scalar(ast::LdStScalarType::B64),
+ ..
+ },
+ arg,
+ )) => {
+ if let (ast::IdOrVector::Reg(dst), Some(src)) = (&arg.dst, arg.src.underlying()) {
+ if func_args_64bit.contains(src) {
+ multi_hash_map_append(&mut stateful_init_reg, *dst, *src);
+ }
+ }
+ }
+ _ => {}
+ }
+ }
+ let mut func_args_ptr = HashSet::new();
+ let mut regs_ptr_current = HashSet::new();
+ for (dst, src) in stateful_markers {
+ if let Some(func_args) = stateful_init_reg.get(&src) {
+ for a in func_args {
+ func_args_ptr.insert(*a);
+ regs_ptr_current.insert(src);
+ regs_ptr_current.insert(dst);
+ }
+ }
+ }
+ // We don't need to propagate this further, because it's usually handled by
+ // IGC pass StatelessToStatefull
+ /*
+ let mut regs_ptr_seen = HashSet::new();
+ while regs_ptr_current.len() > 0 {
+ let mut regs_ptr_new = HashSet::new();
+ for statement in func_body.iter() {
+ match statement {
+ Statement::Instruction(ast::Instruction::Add(
+ ast::ArithDetails::Unsigned(ast::UIntType::U64),
+ arg,
+ ))
+ | Statement::Instruction(ast::Instruction::Add(
+ ast::ArithDetails::Signed(ast::ArithSInt {
+ typ: ast::SIntType::S64,
+ saturate: false,
+ }),
+ arg,
+ ))
+ | Statement::Instruction(ast::Instruction::Sub(
+ ast::ArithDetails::Unsigned(ast::UIntType::U64),
+ arg,
+ ))
+ | Statement::Instruction(ast::Instruction::Sub(
+ ast::ArithDetails::Signed(ast::ArithSInt {
+ typ: ast::SIntType::S64,
+ saturate: false,
+ }),
+ arg,
+ )) => {
+ if let Some(src1) = arg.src1.underlying() {
+ if regs_ptr_current.contains(src1) {
+ regs_ptr_new.insert(arg.dst);
+ }
+ } else if let Some(src2) = arg.src2.underlying() {
+ if regs_ptr_current.contains(src2) {
+ regs_ptr_new.insert(arg.dst);
+ }
+ }
+ }
+ // We don't care about PtrAdd, because it gets produced later
+ _ => {}
+ }
+ }
+ for id in regs_ptr_current {
+ regs_ptr_seen.insert(id);
+ }
+ regs_ptr_current = regs_ptr_new;
+ }
+ */
+ let mut remapped_ids = HashMap::new();
+ let mut result = Vec::with_capacity(regs_ptr_current.len() + func_body.len());
+ for reg in regs_ptr_current {
+ let new_id = id_defs.new_id(ast::Type::Pointer(
+ ast::PointerType::Scalar(ast::ScalarType::U8),
+ ast::LdStateSpace::Global,
+ ));
+ result.push(Statement::Variable(ast::Variable {
+ align: None,
+ name: new_id,
+ array_init: Vec::new(),
+ v_type: ast::VariableType::Reg(ast::VariableRegType::Pointer(
+ ast::SizedScalarType::U8,
+ ast::PointerStateSpace::Global,
+ )),
+ }));
+ remapped_ids.insert(reg, new_id);
+ }
+ let mut temporaries = HashSet::new();
+ for statement in func_body {
+ match statement {
+ l @ Statement::Label(_) => result.push(l),
+ c @ Statement::Conditional(_) => result.push(c),
+ Statement::Variable(var) => {
+ if !remapped_ids.contains_key(&var.name) {
+ result.push(Statement::Variable(var));
+ }
+ }
+ Statement::Instruction(inst) => {
+ let mut post_statements = Vec::new();
+ let new_statement =
+ inst.visit_variable(&mut |arg_desc: ArgumentDescriptor<spirv::Word>, typ| {
+ Ok(match remapped_ids.get(&arg_desc.op) {
+ Some(new_id) => {
+ let old_type_full = id_defs.get_typed(arg_desc.op)?;
+ let old_type = old_type_full.clone();
+ let converting_id = id_defs.new_id(old_type_full);
+ temporaries.insert(converting_id);
+ if arg_desc.is_dst {
+ post_statements.push(Statement::Conversion(
+ ImplicitConversion {
+ src: converting_id,
+ dst: *new_id,
+ from: old_type,
+ to: ast::Type::Pointer(
+ ast::PointerType::Scalar(ast::ScalarType::U8),
+ ast::LdStateSpace::Global,
+ ),
+ kind: ConversionKind::BitToPtr(
+ ast::LdStateSpace::Global,
+ ),
+ },
+ ));
+ converting_id
+ } else {
+ result.push(Statement::Conversion(ImplicitConversion {
+ src: *new_id,
+ dst: converting_id,
+ from: ast::Type::Pointer(
+ ast::PointerType::Scalar(ast::ScalarType::U8),
+ ast::LdStateSpace::Global,
+ ),
+ to: old_type,
+ kind: ConversionKind::PtrToBit(ast::UIntType::U64),
+ }));
+ converting_id
+ }
+ }
+ None => match func_args_ptr.get(&arg_desc.op) {
+ Some(new_id) => {
+ if arg_desc.is_dst {
+ return Err(TranslateError::Unreachable);
+ }
+ let old_type = id_defs.get_typed(arg_desc.op)?;
+ let old_type_clone = old_type.clone();
+ let converting_id = id_defs.new_id(old_type);
+ temporaries.insert(converting_id);
+ result.push(Statement::Conversion(ImplicitConversion {
+ src: *new_id,
+ dst: converting_id,
+ from: ast::Type::Pointer(
+ ast::PointerType::Scalar(ast::ScalarType::U8),
+ ast::LdStateSpace::Global,
+ ),
+ to: old_type_clone,
+ kind: ConversionKind::PtrToBit(ast::UIntType::U64),
+ }));
+ converting_id
+ }
+ None => arg_desc.op,
+ },
+ })
+ })?;
+ result.push(new_statement);
+ for s in post_statements {
+ result.push(s);
+ }
+ }
+ Statement::Call(call) => todo!(),
+ _ => return Err(TranslateError::Unreachable),
+ }
+ /*
+ match statement {
+ statement.
+ /*
+ Statement::Instruction(inst) => result.push(
+ inst.visit_variable_extended(
+ &mut |id_desc: ArgumentDescriptor<spirv::Word>, typ| {
+ if let Some(new_id) = remapped_ids.get(&id_desc.op) {
+ if id_desc.is_dst {
+ panic!()
+ } else {
+ result.push(Statement::Conversion(ImplicitConversion {
+ src
+ }));
+ Ok(*new_id)
+ }
+ } else {
+ Ok(id_desc.op)
+ }
+ },
+ )
+ .unwrap(),
+ ),
+ s => result.push(s),\
+ */
+ }
+ */
+ }
+ for arg in func_args.input.iter_mut() {
+ if func_args_ptr.contains(&arg.name) {
+ arg.v_type = ast::Type::Pointer(
+ ast::PointerType::Scalar(ast::ScalarType::U8),
+ ast::LdStateSpace::Global,
+ );
+ }
+ }
+ Ok((result, temporaries))
+}
+
+fn is_64_bit_integer(id_defs: &MutableNumericIdResolver, id: spirv::Word) -> bool {
+ match id_defs.get_typed(id) {
+ Ok(ast::Type::Scalar(ast::ScalarType::U64))
+ | Ok(ast::Type::Scalar(ast::ScalarType::S64))
+ | Ok(ast::Type::Scalar(ast::ScalarType::B64)) => true,
+ _ => false,
+ }
+}
+
#[derive(Ord, PartialOrd, Eq, PartialEq, Hash, Copy, Clone)]
enum PtxSpecialRegister {
Tid,
@@ -4542,12 +4840,12 @@ enum Statement<I, P: ast::ArgParams> {
Label(u32),
Variable(ast::Variable<ast::VariableType, P::Id>),
Instruction(I),
+ // SPIR-V compatible replacement for PTX predicates
+ Conditional(BrachCondition),
+ Call(ResolvedCall<P>),
LoadVar(ast::Arg2<ExpandedArgParams>, ast::Type),
StoreVar(ast::Arg2St<ExpandedArgParams>, ast::Type),
- Call(ResolvedCall<P>),
Composite(CompositeRead),
- // SPIR-V compatible replacement for PTX predicates
- Conditional(BrachCondition),
Conversion(ImplicitConversion),
Constant(ConstantDefinition),
RetValue(ast::RetData, spirv::Word),
@@ -4557,67 +4855,74 @@ enum Statement<I, P: ast::ArgParams> {
state_space: ast::LdStateSpace,
dst: spirv::Word,
ptr_src: spirv::Word,
- constant_src: spirv::Word,
+ offset_src: spirv::Word,
},
}
impl ExpandedStatement {
- fn map_id(self, f: &mut impl FnMut(spirv::Word) -> spirv::Word) -> ExpandedStatement {
+ fn map_id(self, f: &mut impl FnMut(spirv::Word, bool) -> spirv::Word) -> ExpandedStatement {
match self {
- Statement::Label(id) => Statement::Label(f(id)),
+ Statement::Label(id) => Statement::Label(f(id, false)),
Statement::Variable(mut var) => {
- var.name = f(var.name);
+ var.name = f(var.name, true);
Statement::Variable(var)
}
Statement::Instruction(inst) => inst
- .visit_variable_extended(&mut |arg: ArgumentDescriptor<_>, _| Ok(f(arg.op)))
+ .visit_variable_extended(&mut |arg: ArgumentDescriptor<_>, _| {
+ Ok(f(arg.op, arg.is_dst))
+ })
.unwrap(),
Statement::LoadVar(mut arg, typ) => {
- arg.dst = f(arg.dst);
- arg.src = f(arg.src);
+ arg.dst = f(arg.dst, true);
+ arg.src = f(arg.src, false);
Statement::LoadVar(arg, typ)
}
Statement::StoreVar(mut arg, typ) => {
- arg.src1 = f(arg.src1);
- arg.src2 = f(arg.src2);
+ arg.src1 = f(arg.src1, false);
+ arg.src2 = f(arg.src2, false);
Statement::StoreVar(arg, typ)
}
Statement::Call(mut call) => {
- for (id, _) in call.ret_params.iter_mut() {
- *id = f(*id);
+ for (id, typ) in call.ret_params.iter_mut() {
+ let is_dst = match typ {
+ ast::FnArgumentType::Reg(_) => true,
+ ast::FnArgumentType::Param(_) => false,
+ ast::FnArgumentType::Shared => false,
+ };
+ *id = f(*id, is_dst);
}
- call.func = f(call.func);
+ call.func = f(call.func, false);
for (id, _) in call.param_list.iter_mut() {
- *id = f(*id);
+ *id = f(*id, false);
}
Statement::Call(call)
}
Statement::Composite(mut composite) => {
- composite.dst = f(composite.dst);
- composite.src_composite = f(composite.src_composite);
+ composite.dst = f(composite.dst, true);
+ composite.src_composite = f(composite.src_composite, false);
Statement::Composite(composite)
}
Statement::Conditional(mut conditional) => {
- conditional.predicate = f(conditional.predicate);
- conditional.if_true = f(conditional.if_true);
- conditional.if_false = f(conditional.if_false);
+ conditional.predicate = f(conditional.predicate, false);
+ conditional.if_true = f(conditional.if_true, false);
+ conditional.if_false = f(conditional.if_false, false);
Statement::Conditional(conditional)
}
Statement::Conversion(mut conv) => {
- conv.dst = f(conv.dst);
- conv.src = f(conv.src);
+ conv.dst = f(conv.dst, true);
+ conv.src = f(conv.src, false);
Statement::Conversion(conv)
}
Statement::Constant(mut constant) => {
- constant.dst = f(constant.dst);
+ constant.dst = f(constant.dst, true);
Statement::Constant(constant)
}
Statement::RetValue(data, id) => {
- let id = f(id);
+ let id = f(id, false);
Statement::RetValue(data, id)
}
Statement::Undef(typ, id) => {
- let id = f(id);
+ let id = f(id, true);
Statement::Undef(typ, id)
}
Statement::PtrAdd {
@@ -4625,17 +4930,17 @@ impl ExpandedStatement {
state_space,
dst,
ptr_src,
- constant_src,
+ offset_src: constant_src,
} => {
- let dst = f(dst);
- let ptr_src = f(ptr_src);
- let constant_src = f(constant_src);
+ let dst = f(dst, true);
+ let ptr_src = f(ptr_src, false);
+ let constant_src = f(constant_src, false);
Statement::PtrAdd {
underlying_type,
state_space,
dst,
ptr_src,
- constant_src,
+ offset_src: constant_src,
}
}
}
@@ -5278,6 +5583,43 @@ impl VisitVariable for ast::Instruction<TypedArgParams> {
}
}
+impl VisitVariable for ImplicitConversion {
+ fn visit_variable<
+ 'a,
+ F: FnMut(
+ ArgumentDescriptor<spirv_headers::Word>,
+ Option<&ast::Type>,
+ ) -> Result<spirv_headers::Word, TranslateError>,
+ >(
+ self,
+ f: &mut F,
+ ) -> Result<TypedStatement, TranslateError> {
+ let new_src = f(
+ ArgumentDescriptor {
+ op: self.src,
+ is_dst: false,
+ sema: ArgumentSemantics::Default,
+ },
+ Some(&self.from),
+ )?;
+ let new_dst = f(
+ ArgumentDescriptor {
+ op: self.dst,
+ is_dst: true,
+ sema: ArgumentSemantics::Default,
+ },
+ Some(&self.from),
+ )?;
+ Ok(Statement::Conversion({
+ ImplicitConversion {
+ src: new_src,
+ dst: new_dst,
+ ..self
+ }
+ }))
+ }
+}
+
impl<T> ArgumentMapVisitor<TypedArgParams, TypedArgParams> for T
where
T: FnMut(