diff options
author | Andrzej Janik <[email protected]> | 2020-11-17 22:03:10 +0100 |
---|---|---|
committer | Andrzej Janik <[email protected]> | 2020-11-17 22:03:10 +0100 |
commit | f3aba1746443c4dba06ce2eb8634f16600acdea9 (patch) | |
tree | 7ed87d577626f21a8afffb6a3cfd5cb96cd198c0 | |
parent | a99be72c8b01a2edc7d958971a53e12e99dc0c2e (diff) | |
download | ZLUDA-f3aba1746443c4dba06ce2eb8634f16600acdea9.tar.gz ZLUDA-f3aba1746443c4dba06ce2eb8634f16600acdea9.zip |
Add missing changes
-rw-r--r-- | ptx/src/translate.rs | 94 |
1 files changed, 66 insertions, 28 deletions
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 728d641..86a7e73 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -72,15 +72,13 @@ impl From<ast::PointerType> for ast::Type { }
impl ast::Type {
- fn pointer_to(self, space: ast::LdStateSpace) -> Result<Self, TranslateError> {
+ fn param_pointer_to(self, space: ast::LdStateSpace) -> Result<Self, TranslateError> {
Ok(match self {
ast::Type::Scalar(t) => ast::Type::Pointer(ast::PointerType::Scalar(t), space),
ast::Type::Vector(t, len) => {
ast::Type::Pointer(ast::PointerType::Vector(t, len), space)
}
- ast::Type::Array(t, dims) => {
- ast::Type::Pointer(ast::PointerType::Array(t, dims), space)
- }
+ ast::Type::Array(t, _) => ast::Type::Pointer(ast::PointerType::Scalar(t), space),
ast::Type::Pointer(ast::PointerType::Scalar(t), space) => {
ast::Type::Pointer(ast::PointerType::Pointer(t, space), space)
}
@@ -1192,11 +1190,31 @@ fn translate_variable<'a>( id_defs: &mut GlobalStringIdResolver<'a>,
var: ast::Variable<ast::VariableType, &'a str>,
) -> Result<ast::Variable<ast::VariableType, spirv::Word>, TranslateError> {
- let (_, typ) = var.v_type.to_type();
+ let (space, var_type) = var.v_type.to_type();
+ let mut is_variable = false;
+ let var_type = match space {
+ ast::StateSpace::Reg => {
+ is_variable = true;
+ var_type
+ }
+ ast::StateSpace::Const => var_type.param_pointer_to(ast::LdStateSpace::Const)?,
+ ast::StateSpace::Global => var_type.param_pointer_to(ast::LdStateSpace::Global)?,
+ ast::StateSpace::Local => var_type.param_pointer_to(ast::LdStateSpace::Local)?,
+ ast::StateSpace::Shared => {
+ // If it's a pointer it will be translated to a method parameter later
+ if let ast::Type::Pointer(..) = var_type {
+ is_variable = true;
+ var_type.param_pointer_to(ast::LdStateSpace::Param)?
+ } else {
+ var_type.param_pointer_to(ast::LdStateSpace::Shared)?
+ }
+ }
+ ast::StateSpace::Param => var_type.param_pointer_to(ast::LdStateSpace::Param)?,
+ };
Ok(ast::Variable {
align: var.align,
v_type: var.v_type,
- name: id_defs.get_or_add_def_typed(var.name, typ),
+ name: id_defs.get_or_add_def_typed(var.name, var_type, is_variable),
array_init: var.array_init,
})
}
@@ -1218,8 +1236,8 @@ fn expand_kernel_params<'a, 'b>( Ok(ast::KernelArgument {
name: fn_resolver.add_def(
a.name,
- Some(ast::Type::from(a.v_type.clone()).pointer_to(ast::LdStateSpace::Param)?),
- false, // This is debatable if should be true or false
+ Some(ast::Type::from(a.v_type.clone()).param_pointer_to(ast::LdStateSpace::Param)?),
+ false,
),
v_type: a.v_type.clone(),
align: a.align,
@@ -1234,11 +1252,15 @@ fn expand_fn_params<'a, 'b>( args: impl Iterator<Item = &'b ast::FnArgument<&'a str>>,
) -> Result<Vec<ast::FnArgument<spirv::Word>>, TranslateError> {
args.map(|a| {
- let var_type = a.v_type.to_func_type();
- let is_variable = match a.v_type {
- ast::FnArgumentType::Reg(_) => true,
- ast::FnArgumentType::Param(_) => false,
- ast::FnArgumentType::Shared => false,
+ let mut var_type = a.v_type.to_func_type();
+ let mut is_variable = false;
+ var_type = match a.v_type {
+ ast::FnArgumentType::Reg(_) => {
+ is_variable = true;
+ var_type
+ }
+ ast::FnArgumentType::Shared => var_type.param_pointer_to(ast::LdStateSpace::Shared)?,
+ ast::FnArgumentType::Param(_) => var_type.param_pointer_to(ast::LdStateSpace::Param)?,
};
Ok(ast::FnArgument {
name: fn_resolver.add_def(a.name, Some(var_type), is_variable),
@@ -2079,10 +2101,6 @@ fn insert_mem_ssa_statement_default<'a, F: VisitVariable>( if !is_variable {
return Ok(desc.op);
}
- match var_type {
- ast::Type::Array(..) => return Ok(desc.op),
- _ => {}
- }
let generated_id = id_def.new_non_variable(Some(var_type.clone()));
if !desc.is_dst {
result.push(Statement::LoadVar(
@@ -4152,13 +4170,28 @@ fn expand_map_variables<'a, 'b>( let mut var_type = ast::Type::from(var.var.v_type.clone());
let mut is_variable = false;
var_type = match var.var.v_type {
- ast::VariableType::Reg(_) | ast::VariableType::Shared(_) => {
+ ast::VariableType::Reg(_) => {
is_variable = true;
var_type
}
- ast::VariableType::Global(_) => var_type.pointer_to(ast::LdStateSpace::Global)?,
- ast::VariableType::Param(_) => var_type.pointer_to(ast::LdStateSpace::Param)?,
- ast::VariableType::Local(_) => var_type.pointer_to(ast::LdStateSpace::Local)?,
+ ast::VariableType::Shared(_) => {
+ // If it's a pointer it will be translated to a method parameter later
+ if let ast::Type::Pointer(..) = var_type {
+ is_variable = true;
+ var_type.param_pointer_to(ast::LdStateSpace::Param)?
+ } else {
+ var_type.param_pointer_to(ast::LdStateSpace::Shared)?
+ }
+ }
+ ast::VariableType::Global(_) => {
+ var_type.param_pointer_to(ast::LdStateSpace::Global)?
+ }
+ ast::VariableType::Param(_) => {
+ var_type.param_pointer_to(ast::LdStateSpace::Param)?
+ }
+ ast::VariableType::Local(_) => {
+ var_type.param_pointer_to(ast::LdStateSpace::Local)?
+ }
};
match var.count {
Some(count) => {
@@ -4227,7 +4260,7 @@ impl PtxSpecialRegister { struct GlobalStringIdResolver<'input> {
current_id: spirv::Word,
variables: HashMap<Cow<'input, str>, spirv::Word>,
- variables_type_check: HashMap<u32, Option<ast::Type>>,
+ variables_type_check: HashMap<u32, Option<(ast::Type, bool)>>,
special_registers: HashMap<PtxSpecialRegister, spirv::Word>,
fns: HashMap<spirv::Word, FnDecl>,
}
@@ -4252,11 +4285,16 @@ impl<'a> GlobalStringIdResolver<'a> { self.get_or_add_impl(id, None)
}
- fn get_or_add_def_typed(&mut self, id: &'a str, typ: ast::Type) -> spirv::Word {
- self.get_or_add_impl(id, Some(typ))
+ fn get_or_add_def_typed(
+ &mut self,
+ id: &'a str,
+ typ: ast::Type,
+ is_variable: bool,
+ ) -> spirv::Word {
+ self.get_or_add_impl(id, Some((typ, is_variable)))
}
- fn get_or_add_impl(&mut self, id: &'a str, typ: Option<ast::Type>) -> spirv::Word {
+ fn get_or_add_impl(&mut self, id: &'a str, typ: Option<(ast::Type, bool)>) -> spirv::Word {
let id = match self.variables.entry(Cow::Borrowed(id)) {
hash_map::Entry::Occupied(e) => *(e.get()),
hash_map::Entry::Vacant(e) => {
@@ -4352,7 +4390,7 @@ impl<'input, 'a> GlobalFnDeclResolver<'input, 'a> { struct FnStringIdResolver<'input, 'b> {
current_id: &'b mut spirv::Word,
global_variables: &'b HashMap<Cow<'input, str>, spirv::Word>,
- global_type_check: &'b HashMap<u32, Option<ast::Type>>,
+ global_type_check: &'b HashMap<u32, Option<(ast::Type, bool)>>,
special_registers: &'b mut HashMap<PtxSpecialRegister, spirv::Word>,
variables: Vec<HashMap<Cow<'input, str>, spirv::Word>>,
type_check: HashMap<u32, Option<(ast::Type, bool)>>,
@@ -4441,7 +4479,7 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> { struct NumericIdResolver<'b> {
current_id: &'b mut spirv::Word,
- global_type_check: &'b HashMap<u32, Option<ast::Type>>,
+ global_type_check: &'b HashMap<u32, Option<(ast::Type, bool)>>,
type_check: HashMap<u32, Option<(ast::Type, bool)>>,
special_registers: HashMap<spirv::Word, PtxSpecialRegister>,
}
@@ -4458,7 +4496,7 @@ impl<'b> NumericIdResolver<'b> { None => match self.special_registers.get(&id) {
Some(x) => Ok((x.get_type(), true)),
None => match self.global_type_check.get(&id) {
- Some(Some(x)) => Ok((x.clone(), true)),
+ Some(Some(result)) => Ok(result.clone()),
Some(None) | None => Err(TranslateError::UntypedSymbol),
},
},
|