summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorAndrzej Janik <[email protected]>2020-11-18 00:38:33 +0100
committerAndrzej Janik <[email protected]>2020-11-18 00:38:33 +0100
commit2d652ff9c81f81e4136316e15165cc6b6ebe96e4 (patch)
tree17356c656eb01fa41b1ee1b9bdf284b03409bdc7
parent3fd1ca9b53d9abb68cd08c638c24670cd79ae443 (diff)
downloadZLUDA-2d652ff9c81f81e4136316e15165cc6b6ebe96e4.tar.gz
ZLUDA-2d652ff9c81f81e4136316e15165cc6b6ebe96e4.zip
Fix various bugs
-rw-r--r--ptx/src/test/spirv_run/stateful_ld_st_simple.ptx (renamed from ptx/src/test/spirv_run/ld_st_stateful.ptx)2
-rw-r--r--ptx/src/test/spirv_run/stateful_ld_st_simple.spvtxt (renamed from ptx/src/test/spirv_run/ld_st_stateful.spvtxt)0
-rw-r--r--ptx/src/translate.rs168
3 files changed, 90 insertions, 80 deletions
diff --git a/ptx/src/test/spirv_run/ld_st_stateful.ptx b/ptx/src/test/spirv_run/stateful_ld_st_simple.ptx
index 859b169..5650ada 100644
--- a/ptx/src/test/spirv_run/ld_st_stateful.ptx
+++ b/ptx/src/test/spirv_run/stateful_ld_st_simple.ptx
@@ -2,7 +2,7 @@
.target sm_30
.address_size 64
-.visible .entry ld_st_stateful(
+.visible .entry stateful_ld_st_simple(
.param .u64 input,
.param .u64 output
)
diff --git a/ptx/src/test/spirv_run/ld_st_stateful.spvtxt b/ptx/src/test/spirv_run/stateful_ld_st_simple.spvtxt
index 963d88a..963d88a 100644
--- a/ptx/src/test/spirv_run/ld_st_stateful.spvtxt
+++ b/ptx/src/test/spirv_run/stateful_ld_st_simple.spvtxt
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index 6fc35d2..f644a27 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -857,6 +857,8 @@ fn replace_uses_of_shared_memory<'a>(
ast::LdStateSpace::Shared,
),
kind: ConversionKind::PtrToPtr { spirv_ptr: true },
+ src_sema: ArgumentSemantics::Default,
+ dst_sema: ArgumentSemantics::Default,
}));
}
replacement_id
@@ -1297,15 +1299,11 @@ fn to_ssa<'input, 'b>(
let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs)?;
let typed_statements =
convert_to_typed_statements(unadorned_statements, &fn_defs, &numeric_id_defs)?;
- let mut numeric_id_defs = numeric_id_defs.finish();
- let (typed_statements, temporaries) =
+ let typed_statements =
convert_to_stateful_memory_access(&mut spirv_decl, typed_statements, &mut numeric_id_defs)?;
- let ssa_statements = insert_mem_ssa_statements(
- typed_statements,
- &mut numeric_id_defs,
- &mut spirv_decl,
- temporaries,
- )?;
+ let ssa_statements =
+ insert_mem_ssa_statements(typed_statements, &mut numeric_id_defs, &mut spirv_decl)?;
+ let mut numeric_id_defs = numeric_id_defs.finish();
let expanded_statements = expand_arguments(ssa_statements, &mut numeric_id_defs)?;
let expanded_statements =
insert_implicit_conversions(expanded_statements, &mut numeric_id_defs)?;
@@ -1927,7 +1925,6 @@ fn insert_mem_ssa_statements<'a, 'b>(
func: Vec<TypedStatement>,
id_def: &mut NumericIdResolver,
fn_decl: &mut SpirvMethodDecl,
- temporaries: HashSet<spirv::Word>,
) -> Result<Vec<TypedStatement>, TranslateError> {
let mut result = Vec::with_capacity(func.len());
for arg in fn_decl.output.iter() {
@@ -1969,7 +1966,7 @@ fn insert_mem_ssa_statements<'a, 'b>(
for s in func {
match s {
Statement::Call(call) => {
- insert_mem_ssa_statement_default(id_def, &mut result, &temporaries, call.cast())?
+ insert_mem_ssa_statement_default(id_def, &mut result, call.cast())?
}
Statement::Instruction(inst) => match inst {
ast::Instruction::Ret(d) => {
@@ -1989,7 +1986,7 @@ fn insert_mem_ssa_statements<'a, 'b>(
result.push(Statement::Instruction(ast::Instruction::Ret(d)))
}
}
- inst => insert_mem_ssa_statement_default(id_def, &mut result, &temporaries, inst)?,
+ inst => insert_mem_ssa_statement_default(id_def, &mut result, inst)?,
},
Statement::Conditional(mut bra) => {
let generated_id =
@@ -2005,7 +2002,7 @@ fn insert_mem_ssa_statements<'a, 'b>(
result.push(Statement::Conditional(bra));
}
Statement::Conversion(conv) => {
- insert_mem_ssa_statement_default(id_def, &mut result, &temporaries, conv)?
+ insert_mem_ssa_statement_default(id_def, &mut result, conv)?
}
s @ Statement::Variable(_) | s @ Statement::Label(_) => result.push(s),
_ => return Err(TranslateError::Unreachable),
@@ -2094,7 +2091,6 @@ impl<'a, Ctor: FnOnce(spirv::Word) -> ExpandedStatement> VisitVariableExpanded
fn insert_mem_ssa_statement_default<'a, F: VisitVariable>(
id_def: &mut NumericIdResolver,
result: &mut Vec<TypedStatement>,
- temporaries: &HashSet<spirv::Word>,
stmt: F,
) -> Result<(), TranslateError> {
let mut post_statements = Vec::new();
@@ -2612,13 +2608,13 @@ fn insert_implicit_conversions(
)?;
}
s @ Statement::Conditional(_)
+ | s @ Statement::Conversion(_)
| s @ Statement::Label(_)
| s @ Statement::Constant(_)
| s @ Statement::Variable(_)
| s @ Statement::LoadVar(_, _)
| s @ Statement::StoreVar(_, _)
| s @ Statement::Undef(_, _)
- | s @ Statement::Conversion(_)
| s @ Statement::RetValue(_, _) => result.push(s),
}
}
@@ -2685,6 +2681,8 @@ fn insert_implicit_conversions_impl(
from,
to,
kind: conv_kind,
+ src_sema: ArgumentSemantics::Default,
+ dst_sema: ArgumentSemantics::Default,
}));
result
}
@@ -3741,6 +3739,8 @@ fn emit_cvt(
src_t.kind(),
)),
kind: ConversionKind::Default,
+ src_sema: ArgumentSemantics::Default,
+ dst_sema: ArgumentSemantics::Default,
};
emit_implicit_conversion(builder, map, &cv)?;
new_dst
@@ -4113,6 +4113,8 @@ fn emit_implicit_conversion(
from: wide_bit_type,
to: cv.to.clone(),
kind: ConversionKind::Default,
+ src_sema: cv.src_sema,
+ dst_sema: cv.dst_sema,
},
)?;
}
@@ -4244,8 +4246,8 @@ fn expand_map_variables<'a, 'b>(
fn convert_to_stateful_memory_access<'a>(
func_args: &mut SpirvMethodDecl,
func_body: Vec<TypedStatement>,
- id_defs: &mut MutableNumericIdResolver<'a>,
-) -> Result<(Vec<TypedStatement>, HashSet<spirv::Word>), TranslateError> {
+ id_defs: &mut NumericIdResolver<'a>,
+) -> Result<Vec<TypedStatement>, TranslateError> {
let func_args_64bit = func_args
.input
.iter()
@@ -4371,7 +4373,7 @@ fn convert_to_stateful_memory_access<'a>(
let mut remapped_ids = HashMap::new();
let mut result = Vec::with_capacity(regs_ptr_current.len() + func_body.len());
for reg in regs_ptr_current {
- let new_id = id_defs.new_id(ast::Type::Pointer(
+ let new_id = id_defs.new_variable(ast::Type::Pointer(
ast::PointerType::Scalar(ast::ScalarType::U8),
ast::LdStateSpace::Global,
));
@@ -4386,7 +4388,6 @@ fn convert_to_stateful_memory_access<'a>(
}));
remapped_ids.insert(reg, new_id);
}
- let mut temporaries = HashSet::new();
for statement in func_body {
match statement {
l @ Statement::Label(_) => result.push(l),
@@ -4402,10 +4403,9 @@ fn convert_to_stateful_memory_access<'a>(
inst.visit_variable(&mut |arg_desc: ArgumentDescriptor<spirv::Word>, typ| {
Ok(match remapped_ids.get(&arg_desc.op) {
Some(new_id) => {
- let old_type_full = id_defs.get_typed(arg_desc.op)?;
+ let (old_type_full, _) = id_defs.get_typed(arg_desc.op)?;
let old_type = old_type_full.clone();
- let converting_id = id_defs.new_id(old_type_full);
- temporaries.insert(converting_id);
+ let converting_id = id_defs.new_non_variable(Some(old_type_full));
if arg_desc.is_dst {
post_statements.push(Statement::Conversion(
ImplicitConversion {
@@ -4419,6 +4419,8 @@ fn convert_to_stateful_memory_access<'a>(
kind: ConversionKind::BitToPtr(
ast::LdStateSpace::Global,
),
+ src_sema: ArgumentSemantics::Default,
+ dst_sema: arg_desc.sema,
},
));
converting_id
@@ -4432,6 +4434,8 @@ fn convert_to_stateful_memory_access<'a>(
),
to: old_type,
kind: ConversionKind::PtrToBit(ast::UIntType::U64),
+ src_sema: arg_desc.sema,
+ dst_sema: ArgumentSemantics::Default,
}));
converting_id
}
@@ -4441,19 +4445,23 @@ fn convert_to_stateful_memory_access<'a>(
if arg_desc.is_dst {
return Err(TranslateError::Unreachable);
}
- let old_type = id_defs.get_typed(arg_desc.op)?;
+ let (old_type, _) = id_defs.get_typed(arg_desc.op)?;
let old_type_clone = old_type.clone();
- let converting_id = id_defs.new_id(old_type);
- temporaries.insert(converting_id);
+ let converting_id = id_defs.new_non_variable(Some(old_type));
result.push(Statement::Conversion(ImplicitConversion {
src: *new_id,
dst: converting_id,
from: ast::Type::Pointer(
- ast::PointerType::Scalar(ast::ScalarType::U8),
- ast::LdStateSpace::Global,
+ ast::PointerType::Pointer(
+ ast::ScalarType::U8,
+ ast::LdStateSpace::Global,
+ ),
+ ast::LdStateSpace::Param,
),
to: old_type_clone,
- kind: ConversionKind::PtrToBit(ast::UIntType::U64),
+ kind: ConversionKind::PtrToPtr { spirv_ptr: false },
+ src_sema: arg_desc.sema,
+ dst_sema: ArgumentSemantics::Default,
}));
converting_id
}
@@ -4469,33 +4477,6 @@ fn convert_to_stateful_memory_access<'a>(
Statement::Call(call) => todo!(),
_ => return Err(TranslateError::Unreachable),
}
- /*
- match statement {
- statement.
- /*
- Statement::Instruction(inst) => result.push(
- inst.visit_variable_extended(
- &mut |id_desc: ArgumentDescriptor<spirv::Word>, typ| {
- if let Some(new_id) = remapped_ids.get(&id_desc.op) {
- if id_desc.is_dst {
- panic!()
- } else {
- result.push(Statement::Conversion(ImplicitConversion {
- src
- }));
- Ok(*new_id)
- }
- } else {
- Ok(id_desc.op)
- }
- },
- )
- .unwrap(),
- ),
- s => result.push(s),\
- */
- }
- */
}
for arg in func_args.input.iter_mut() {
if func_args_ptr.contains(&arg.name) {
@@ -4505,14 +4486,14 @@ fn convert_to_stateful_memory_access<'a>(
);
}
}
- Ok((result, temporaries))
+ Ok(result)
}
-fn is_64_bit_integer(id_defs: &MutableNumericIdResolver, id: spirv::Word) -> bool {
+fn is_64_bit_integer(id_defs: &NumericIdResolver, id: spirv::Word) -> bool {
match id_defs.get_typed(id) {
- Ok(ast::Type::Scalar(ast::ScalarType::U64))
- | Ok(ast::Type::Scalar(ast::ScalarType::S64))
- | Ok(ast::Type::Scalar(ast::ScalarType::B64)) => true,
+ Ok((ast::Type::Scalar(ast::ScalarType::U64), _))
+ | Ok((ast::Type::Scalar(ast::ScalarType::S64), _))
+ | Ok((ast::Type::Scalar(ast::ScalarType::B64), _)) => true,
_ => false,
}
}
@@ -5572,9 +5553,9 @@ impl VisitVariable for ast::Instruction<TypedArgParams> {
fn visit_variable<
'a,
F: FnMut(
- ArgumentDescriptor<spirv_headers::Word>,
+ ArgumentDescriptor<spirv::Word>,
Option<&ast::Type>,
- ) -> Result<spirv_headers::Word, TranslateError>,
+ ) -> Result<spirv::Word, TranslateError>,
>(
self,
f: &mut F,
@@ -5583,30 +5564,28 @@ impl VisitVariable for ast::Instruction<TypedArgParams> {
}
}
-impl VisitVariable for ImplicitConversion {
- fn visit_variable<
- 'a,
- F: FnMut(
- ArgumentDescriptor<spirv_headers::Word>,
- Option<&ast::Type>,
- ) -> Result<spirv_headers::Word, TranslateError>,
+impl ImplicitConversion {
+ fn map<
+ T: ArgParamsEx<Id = spirv::Word>,
+ U: ArgParamsEx<Id = spirv::Word>,
+ V: ArgumentMapVisitor<T, U>,
>(
self,
- f: &mut F,
- ) -> Result<TypedStatement, TranslateError> {
- let new_src = f(
+ visitor: &mut V,
+ ) -> Result<Statement<ast::Instruction<U>, U>, TranslateError> {
+ let new_dst = visitor.id(
ArgumentDescriptor {
- op: self.src,
- is_dst: false,
- sema: ArgumentSemantics::Default,
+ op: self.dst,
+ is_dst: true,
+ sema: self.dst_sema,
},
- Some(&self.from),
+ Some(&self.to),
)?;
- let new_dst = f(
+ let new_src = visitor.id(
ArgumentDescriptor {
- op: self.dst,
- is_dst: true,
- sema: ArgumentSemantics::Default,
+ op: self.src,
+ is_dst: false,
+ sema: self.src_sema,
},
Some(&self.from),
)?;
@@ -5620,6 +5599,35 @@ impl VisitVariable for ImplicitConversion {
}
}
+impl VisitVariable for ImplicitConversion {
+ fn visit_variable<
+ 'a,
+ F: FnMut(
+ ArgumentDescriptor<spirv_headers::Word>,
+ Option<&ast::Type>,
+ ) -> Result<spirv_headers::Word, TranslateError>,
+ >(
+ self,
+ f: &mut F,
+ ) -> Result<TypedStatement, TranslateError> {
+ self.map(f)
+ }
+}
+
+impl VisitVariableExpanded for ImplicitConversion {
+ fn visit_variable_extended<
+ F: FnMut(
+ ArgumentDescriptor<spirv_headers::Word>,
+ Option<&ast::Type>,
+ ) -> Result<spirv_headers::Word, TranslateError>,
+ >(
+ self,
+ f: &mut F,
+ ) -> Result<ExpandedStatement, TranslateError> {
+ self.map(f)
+ }
+}
+
impl<T> ArgumentMapVisitor<TypedArgParams, TypedArgParams> for T
where
T: FnMut(
@@ -6051,6 +6059,8 @@ struct ImplicitConversion {
from: ast::Type,
to: ast::Type,
kind: ConversionKind,
+ src_sema: ArgumentSemantics,
+ dst_sema: ArgumentSemantics,
}
#[derive(PartialEq, Copy, Clone)]