diff options
-rw-r--r-- | ptx/src/translate.rs | 159 |
1 files changed, 120 insertions, 39 deletions
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 0af3c07..ab53e74 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -1277,11 +1277,15 @@ fn to_ssa<'input, 'b>( let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs)?;
let typed_statements =
convert_to_typed_statements(unadorned_statements, &fn_defs, &numeric_id_defs)?;
- let typed_statements =
- convert_to_stateful_memory_access(&mut spirv_decl, typed_statements, &mut numeric_id_defs)?;
let mut numeric_id_defs = numeric_id_defs.finish();
- let ssa_statements =
- insert_mem_ssa_statements(typed_statements, &mut numeric_id_defs, &mut spirv_decl)?;
+ let (typed_statements, temporaries) =
+ convert_to_stateful_memory_access(&mut spirv_decl, typed_statements, &mut numeric_id_defs)?;
+ let ssa_statements = insert_mem_ssa_statements(
+ typed_statements,
+ &mut numeric_id_defs,
+ &mut spirv_decl,
+ temporaries,
+ )?;
let expanded_statements = expand_arguments(ssa_statements, &mut numeric_id_defs)?;
let expanded_statements =
insert_implicit_conversions(expanded_statements, &mut numeric_id_defs)?;
@@ -1938,11 +1942,11 @@ fn normalize_predicates( Ok(result)
}
-// TODO: Don't lift temporaries and move this pass to a later stage
fn insert_mem_ssa_statements<'a, 'b>(
func: Vec<TypedStatement>,
id_def: &mut MutableNumericIdResolver,
fn_decl: &mut SpirvMethodDecl,
+ temporaries: HashSet<spirv::Word>,
) -> Result<Vec<TypedStatement>, TranslateError> {
let mut result = Vec::with_capacity(func.len());
for arg in fn_decl.output.iter() {
@@ -1984,7 +1988,7 @@ fn insert_mem_ssa_statements<'a, 'b>( for s in func {
match s {
Statement::Call(call) => {
- insert_mem_ssa_statement_default(id_def, &mut result, call.cast())?
+ insert_mem_ssa_statement_default(id_def, &mut result, &temporaries, call.cast())?
}
Statement::Instruction(inst) => match inst {
ast::Instruction::Ret(d) => {
@@ -2004,7 +2008,7 @@ fn insert_mem_ssa_statements<'a, 'b>( result.push(Statement::Instruction(ast::Instruction::Ret(d)))
}
}
- inst => insert_mem_ssa_statement_default(id_def, &mut result, inst)?,
+ inst => insert_mem_ssa_statement_default(id_def, &mut result, &temporaries, inst)?,
},
Statement::Conditional(mut bra) => {
let generated_id = id_def.new_id(ast::Type::Scalar(ast::ScalarType::Pred));
@@ -2018,7 +2022,9 @@ fn insert_mem_ssa_statements<'a, 'b>( bra.predicate = generated_id;
result.push(Statement::Conditional(bra));
}
- Statement::Conversion(conv) => todo!(),
+ Statement::Conversion(conv) => {
+ insert_mem_ssa_statement_default(id_def, &mut result, &temporaries, conv)?
+ }
s @ Statement::Variable(_) | s @ Statement::Label(_) => result.push(s),
_ => return Err(TranslateError::Unreachable),
}
@@ -2041,7 +2047,19 @@ fn type_to_variable_type(t: &ast::Type) -> Result<Option<ast::VariableType>, Tra .map_err(|_| TranslateError::MismatchedType)?,
len.clone(),
))),
- ast::Type::Pointer(_, _) => None,
+ ast::Type::Pointer(ast::PointerType::Scalar(scalar_type), space) => {
+ Some(ast::VariableType::Reg(ast::VariableRegType::Pointer(
+ scalar_type
+ .clone()
+ .try_into()
+ .map_err(|_| TranslateError::Unreachable)?,
+ (*space)
+ .try_into()
+ .map_err(|_| TranslateError::Unreachable)?,
+ )))
+ }
+ ast::Type::Pointer(_, ast::LdStateSpace::Shared) => None,
+ _ => return Err(TranslateError::Unreachable),
})
}
@@ -2094,12 +2112,16 @@ impl<'a, Ctor: FnOnce(spirv::Word) -> ExpandedStatement> VisitVariableExpanded fn insert_mem_ssa_statement_default<'a, F: VisitVariable>(
id_def: &mut MutableNumericIdResolver,
result: &mut Vec<TypedStatement>,
+ temporaries: &HashSet<spirv::Word>,
stmt: F,
) -> Result<(), TranslateError> {
let mut post_statements = Vec::new();
let new_statement =
stmt.visit_variable(&mut |desc: ArgumentDescriptor<spirv::Word>, instr_type| {
- if instr_type.is_none() || desc.sema == ArgumentSemantics::RegisterPointer {
+ if temporaries.contains(&desc.op)
+ || instr_type.is_none()
+ || desc.sema == ArgumentSemantics::RegisterPointer
+ {
return Ok(desc.op);
}
let id_type = match (id_def.get_typed(desc.op)?, desc.sema) {
@@ -2222,10 +2244,10 @@ fn expand_arguments<'a, 'b>( Statement::LoadVar(arg, typ) => result.push(Statement::LoadVar(arg, typ)),
Statement::StoreVar(arg, typ) => result.push(Statement::StoreVar(arg, typ)),
Statement::RetValue(d, id) => result.push(Statement::RetValue(d, id)),
- Statement::Composite(_)
- | Statement::Conversion(_)
- | Statement::Constant(_)
- | Statement::Undef(_, _) => return Err(TranslateError::Unreachable),
+ Statement::Conversion(conv) => result.push(Statement::Conversion(conv)),
+ Statement::Composite(_) | Statement::Constant(_) | Statement::Undef(_, _) => {
+ return Err(TranslateError::Unreachable)
+ }
}
}
Ok(result)
@@ -2621,8 +2643,8 @@ fn insert_implicit_conversions( | s @ Statement::LoadVar(_, _)
| s @ Statement::StoreVar(_, _)
| s @ Statement::Undef(_, _)
+ | s @ Statement::Conversion(_)
| s @ Statement::RetValue(_, _) => result.push(s),
- Statement::Conversion(_) => unreachable!(),
}
}
Ok(result)
@@ -4235,8 +4257,8 @@ fn expand_map_variables<'a, 'b>( fn convert_to_stateful_memory_access<'a>(
func_args: &mut SpirvMethodDecl,
func_body: Vec<TypedStatement>,
- id_defs: &mut NumericIdResolver<'a>,
-) -> Result<Vec<TypedStatement>, TranslateError> {
+ id_defs: &mut MutableNumericIdResolver<'a>,
+) -> Result<(Vec<TypedStatement>, HashSet<spirv::Word>), TranslateError> {
let func_args_64bit = func_args
.input
.iter()
@@ -4362,13 +4384,10 @@ fn convert_to_stateful_memory_access<'a>( let mut remapped_ids = HashMap::new();
let mut result = Vec::with_capacity(regs_ptr_current.len() + func_body.len());
for reg in regs_ptr_current {
- let new_id = id_defs.new_id(Some((
- StateSpace::Reg,
- ast::Type::Pointer(
- ast::PointerType::Scalar(ast::ScalarType::U8),
- ast::LdStateSpace::Global,
- ),
- )));
+ let new_id = id_defs.new_id(ast::Type::Pointer(
+ ast::PointerType::Scalar(ast::ScalarType::U8),
+ ast::LdStateSpace::Global,
+ ));
result.push(Statement::Variable(ast::Variable {
align: None,
name: new_id,
@@ -4380,6 +4399,7 @@ fn convert_to_stateful_memory_access<'a>( }));
remapped_ids.insert(reg, new_id);
}
+ let mut temporaries = HashSet::new();
for statement in func_body {
match statement {
l @ Statement::Label(_) => result.push(l),
@@ -4391,24 +4411,27 @@ fn convert_to_stateful_memory_access<'a>( }
Statement::Instruction(inst) => {
let mut post_statements = Vec::new();
- let new_statement = inst
- .visit_variable(&mut |arg_desc: ArgumentDescriptor<spirv::Word>, typ| {
+ let new_statement =
+ inst.visit_variable(&mut |arg_desc: ArgumentDescriptor<spirv::Word>, typ| {
Ok(match remapped_ids.get(&arg_desc.op) {
Some(new_id) => {
let old_type_full = id_defs.get_typed(arg_desc.op)?;
- let old_type = old_type_full.1.clone();
- let converting_id = id_defs.new_id(Some(old_type_full));
+ let old_type = old_type_full.clone();
+ let converting_id = id_defs.new_id(old_type_full);
+ temporaries.insert(converting_id);
if arg_desc.is_dst {
post_statements.push(Statement::Conversion(
ImplicitConversion {
src: converting_id,
- dst: converting_id,
+ dst: *new_id,
from: old_type,
to: ast::Type::Pointer(
ast::PointerType::Scalar(ast::ScalarType::U8),
ast::LdStateSpace::Global,
),
- kind: ConversionKind::PtrToBit(ast::UIntType::U64),
+ kind: ConversionKind::BitToPtr(
+ ast::LdStateSpace::Global,
+ ),
},
));
converting_id
@@ -4421,15 +4444,36 @@ fn convert_to_stateful_memory_access<'a>( ast::LdStateSpace::Global,
),
to: old_type,
- kind: ConversionKind::BitToPtr(ast::LdStateSpace::Global),
+ kind: ConversionKind::PtrToBit(ast::UIntType::U64),
}));
converting_id
}
}
- None => arg_desc.op,
+ None => match func_args_ptr.get(&arg_desc.op) {
+ Some(new_id) => {
+ if arg_desc.is_dst {
+ return Err(TranslateError::Unreachable);
+ }
+ let old_type = id_defs.get_typed(arg_desc.op)?;
+ let old_type_clone = old_type.clone();
+ let converting_id = id_defs.new_id(old_type);
+ temporaries.insert(converting_id);
+ result.push(Statement::Conversion(ImplicitConversion {
+ src: *new_id,
+ dst: converting_id,
+ from: ast::Type::Pointer(
+ ast::PointerType::Scalar(ast::ScalarType::U8),
+ ast::LdStateSpace::Global,
+ ),
+ to: old_type_clone,
+ kind: ConversionKind::PtrToBit(ast::UIntType::U64),
+ }));
+ converting_id
+ }
+ None => arg_desc.op,
+ },
})
- })
- .unwrap();
+ })?;
result.push(new_statement);
for s in post_statements {
result.push(s);
@@ -4474,14 +4518,14 @@ fn convert_to_stateful_memory_access<'a>( );
}
}
- Ok(result)
+ Ok((result, temporaries))
}
-fn is_64_bit_integer(id_defs: &NumericIdResolver, id: spirv::Word) -> bool {
+fn is_64_bit_integer(id_defs: &MutableNumericIdResolver, id: spirv::Word) -> bool {
match id_defs.get_typed(id) {
- Ok((_, ast::Type::Scalar(ast::ScalarType::U64)))
- | Ok((_, ast::Type::Scalar(ast::ScalarType::S64)))
- | Ok((_, ast::Type::Scalar(ast::ScalarType::B64))) => true,
+ Ok(ast::Type::Scalar(ast::ScalarType::U64))
+ | Ok(ast::Type::Scalar(ast::ScalarType::S64))
+ | Ok(ast::Type::Scalar(ast::ScalarType::B64)) => true,
_ => false,
}
}
@@ -5541,6 +5585,43 @@ impl VisitVariable for ast::Instruction<TypedArgParams> { }
}
+impl VisitVariable for ImplicitConversion {
+ fn visit_variable<
+ 'a,
+ F: FnMut(
+ ArgumentDescriptor<spirv_headers::Word>,
+ Option<&ast::Type>,
+ ) -> Result<spirv_headers::Word, TranslateError>,
+ >(
+ self,
+ f: &mut F,
+ ) -> Result<TypedStatement, TranslateError> {
+ let new_src = f(
+ ArgumentDescriptor {
+ op: self.src,
+ is_dst: false,
+ sema: ArgumentSemantics::Default,
+ },
+ Some(&self.from),
+ )?;
+ let new_dst = f(
+ ArgumentDescriptor {
+ op: self.dst,
+ is_dst: true,
+ sema: ArgumentSemantics::Default,
+ },
+ Some(&self.from),
+ )?;
+ Ok(Statement::Conversion({
+ ImplicitConversion {
+ src: new_src,
+ dst: new_dst,
+ ..self
+ }
+ }))
+ }
+}
+
impl<T> ArgumentMapVisitor<TypedArgParams, TypedArgParams> for T
where
T: FnMut(
|