diff options
author | Andrzej Janik <[email protected]> | 2021-04-17 14:01:50 +0200 |
---|---|---|
committer | Andrzej Janik <[email protected]> | 2021-04-17 14:01:50 +0200 |
commit | d51aaaf5529dbfec0735c73768e468728112c26b (patch) | |
tree | 2420b0e35bbc93dd0d53f37f6828541e2e76e878 | |
parent | a55c851eaa4ded60d5f62aba1d7da850a63163f3 (diff) | |
download | ZLUDA-d51aaaf5529dbfec0735c73768e468728112c26b.tar.gz ZLUDA-d51aaaf5529dbfec0735c73768e468728112c26b.zip |
Throw away special variable types
-rw-r--r-- | ptx/src/ast.rs | 215 | ||||
-rw-r--r-- | ptx/src/ptx.lalrpop | 102 | ||||
-rw-r--r-- | ptx/src/translate.rs | 429 |
3 files changed, 256 insertions, 490 deletions
diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index 3e62cb1..c7b9563 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -1,6 +1,5 @@ use half::f16; use lalrpop_util::{lexer::Token, ParseError}; -use std::convert::TryInto; use std::{convert::From, convert::TryFrom, mem, num::ParseFloatError, str::FromStr}; use std::{marker::PhantomData, num::ParseIntError}; @@ -34,107 +33,12 @@ pub enum PtxError { NonExternPointer, } -macro_rules! sub_type { - ($type_name:ident { $($variant:ident ( $($field_type:ident),+ ) ),+ $(,)? } ) => { - sub_type! { $type_name : Type { - $( - $variant ($($field_type),+), - )+ - }} - }; - ($type_name:ident : $base_type:ident { $($variant:ident ( $($field_type:ident),+ ) ),+ $(,)? } ) => { - #[derive(PartialEq, Eq, Clone)] - pub enum $type_name { - $( - $variant ($($field_type),+), - )+ - } - - impl From<$type_name> for $base_type { - #[allow(non_snake_case)] - fn from(t: $type_name) -> $base_type { - match t { - $( - $type_name::$variant ( $($field_type),+ ) => <$base_type>::$variant ( $($field_type.into()),+), - )+ - } - } - } - - impl std::convert::TryFrom<$base_type> for $type_name { - type Error = (); - - #[allow(non_snake_case)] - #[allow(unreachable_patterns)] - fn try_from(t: $base_type) -> Result<Self, Self::Error> { - match t { - $( - $base_type::$variant ( $($field_type),+ ) => Ok($type_name::$variant ( $($field_type.try_into().map_err(|_| ())? ),+ )), - )+ - _ => Err(()), - } - } - } - }; -} - -sub_type! { - VariableRegType { - Scalar(ScalarType), - Vector(ScalarType, u8), - // Array type is used when emiting SSA statements at the start of a method - Array(ScalarType, VecU32), - // Pointer variant is used when passing around SLM pointer between - // function calls for dynamic SLM - Pointer(ScalarType, LdStateSpace) - } -} - -type VecU32 = Vec<u32>; - -sub_type! { - VariableLocalType { - Scalar(ScalarType), - Vector(ScalarType, u8), - Array(ScalarType, VecU32), - } -} - -impl TryFrom<VariableGlobalType> for VariableLocalType { - type Error = PtxError; - - fn try_from(value: VariableGlobalType) -> Result<Self, Self::Error> { - match value { - VariableGlobalType::Scalar(t) => Ok(VariableLocalType::Scalar(t)), - VariableGlobalType::Vector(t, len) => Ok(VariableLocalType::Vector(t, len)), - VariableGlobalType::Array(t, len) => Ok(VariableLocalType::Array(t, len)), - VariableGlobalType::Pointer(_, _) => Err(PtxError::ZeroDimensionArray), - } - } -} - -sub_type! { - VariableGlobalType { - Scalar(ScalarType), - Vector(ScalarType, u8), - Array(ScalarType, VecU32), - Pointer(ScalarType, LdStateSpace), - } -} - // For some weird reson this is illegal: // .param .f16x2 foobar; // but this is legal: // .param .f16x2 foobar[1]; // even more interestingly this is legal, but only in .func (not in .entry): // .param .b32 foobar[] -sub_type! { - VariableParamType { - Scalar(ScalarType), - Array(ScalarType, VecU32), - Pointer(ScalarType, LdStateSpace), - } -} #[derive(Copy, Clone, Eq, PartialEq)] pub enum BarDetails { @@ -178,7 +82,7 @@ pub struct Module<'a> { } pub enum Directive<'a, P: ArgParams> { - Variable(Variable<VariableType, P::Id>), + Variable(Variable<P::Id>), Method(Function<'a, &'a str, Statement<P>>), } @@ -190,8 +94,8 @@ pub enum MethodDecl<'a, ID> { }, } -pub type FnArgument<ID> = Variable<FnArgumentType, ID>; -pub type KernelArgument<ID> = Variable<KernelArgumentType, ID>; +pub type FnArgument<ID> = Variable<ID>; +pub type KernelArgument<ID> = Variable<ID>; pub struct Function<'a, ID, S> { pub func_directive: MethodDecl<'a, ID>, @@ -202,76 +106,6 @@ pub struct Function<'a, ID, S> { pub type ParsedFunction<'a> = Function<'a, &'a str, Statement<ParsedArgParams<'a>>>; #[derive(PartialEq, Eq, Clone)] -pub enum FnArgumentType { - Reg(VariableRegType), - Param(VariableParamType), - Shared, -} -#[derive(PartialEq, Eq, Clone)] -pub enum KernelArgumentType { - Normal(VariableParamType), - Shared, -} - -impl From<KernelArgumentType> for Type { - fn from(this: KernelArgumentType) -> Self { - match this { - KernelArgumentType::Normal(typ) => typ.into(), - KernelArgumentType::Shared => { - Type::Pointer(PointerType::Scalar(ScalarType::B8), LdStateSpace::Shared) - } - } - } -} - -impl FnArgumentType { - pub fn to_type(&self, is_kernel: bool) -> Type { - if is_kernel { - self.to_kernel_type() - } else { - self.to_func_type() - } - } - - pub fn to_kernel_type(&self) -> Type { - match self { - FnArgumentType::Reg(x) => x.clone().into(), - FnArgumentType::Param(x) => x.clone().into(), - FnArgumentType::Shared => { - Type::Pointer(PointerType::Scalar(ScalarType::B8), LdStateSpace::Shared) - } - } - } - - pub fn to_func_type(&self) -> Type { - match self { - FnArgumentType::Reg(x) => x.clone().into(), - FnArgumentType::Param(VariableParamType::Scalar(t)) => { - Type::Pointer(PointerType::Scalar((*t).into()), LdStateSpace::Param) - } - FnArgumentType::Param(VariableParamType::Array(t, dims)) => Type::Pointer( - PointerType::Array((*t).into(), dims.clone()), - LdStateSpace::Param, - ), - FnArgumentType::Param(VariableParamType::Pointer(t, space)) => Type::Pointer( - PointerType::Pointer((*t).into(), (*space).into()), - LdStateSpace::Param, - ), - FnArgumentType::Shared => { - Type::Pointer(PointerType::Scalar(ScalarType::B8), LdStateSpace::Shared) - } - } - } - - pub fn is_param(&self) -> bool { - match self { - FnArgumentType::Param(_) => true, - _ => false, - } - } -} - -#[derive(PartialEq, Eq, Clone)] pub enum Type { Scalar(ScalarType), Vector(ScalarType, u8), @@ -283,7 +117,7 @@ pub enum Type { pub enum PointerType { Scalar(ScalarType), Vector(ScalarType, u8), - Array(ScalarType, VecU32), + Array(ScalarType, Vec<u32>), // Instances of this variant are generated during stateful conversion Pointer(ScalarType, LdStateSpace), } @@ -366,51 +200,19 @@ pub enum Statement<P: ArgParams> { } pub struct MultiVariable<ID> { - pub var: Variable<VariableType, ID>, + pub var: Variable<ID>, pub count: Option<u32>, } #[derive(Clone)] -pub struct Variable<T, ID> { +pub struct Variable<ID> { pub align: Option<u32>, - pub v_type: T, + pub v_type: Type, + pub state_space: StateSpace, pub name: ID, pub array_init: Vec<u8>, } -#[derive(Eq, PartialEq, Clone)] -pub enum VariableType { - Reg(VariableRegType), - Local(VariableLocalType), - Param(VariableParamType), - Global(VariableGlobalType), - Shared(VariableGlobalType), -} - -impl VariableType { - pub fn to_type(&self) -> (StateSpace, Type) { - match self { - VariableType::Reg(t) => (StateSpace::Reg, t.clone().into()), - VariableType::Local(t) => (StateSpace::Local, t.clone().into()), - VariableType::Param(t) => (StateSpace::Param, t.clone().into()), - VariableType::Global(t) => (StateSpace::Global, t.clone().into()), - VariableType::Shared(t) => (StateSpace::Shared, t.clone().into()), - } - } -} - -impl From<VariableType> for Type { - fn from(t: VariableType) -> Self { - match t { - VariableType::Reg(t) => t.into(), - VariableType::Local(t) => t.into(), - VariableType::Param(t) => t.into(), - VariableType::Global(t) => t.into(), - VariableType::Shared(t) => t.into(), - } - } -} - #[derive(Copy, Clone, PartialEq, Eq)] pub enum StateSpace { Reg, @@ -419,6 +221,7 @@ pub enum StateSpace { Local, Shared, Param, + Generic, } pub struct PredAt<ID> { diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index 44852a2..dc439b7 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -404,28 +404,29 @@ FnArguments: Vec<ast::FnArgument<&'input str>> = { "(" <args:Comma<FnInput>> ")" => args }; -KernelInput: ast::Variable<ast::KernelArgumentType, &'input str> = { +KernelInput: ast::Variable<&'input str> = { <v:ParamDeclaration> => { let (align, v_type, name) = v; ast::Variable { align, - v_type: ast::KernelArgumentType::Normal(v_type), + v_type, + state_space: ast::StateSpace::Param, name, array_init: Vec::new() } } } -FnInput: ast::Variable<ast::FnArgumentType, &'input str> = { +FnInput: ast::Variable<&'input str> = { <v:RegVariable> => { let (align, v_type, name) = v; - let v_type = ast::FnArgumentType::Reg(v_type); - ast::Variable{ align, v_type, name, array_init: Vec::new() } + let state_space = ast::StateSpace::Reg; + ast::Variable{ align, v_type, state_space, name, array_init: Vec::new() } }, <v:ParamDeclaration> => { let (align, v_type, name) = v; - let v_type = ast::FnArgumentType::Param(v_type); - ast::Variable{ align, v_type, name, array_init: Vec::new() } + let state_space = ast::StateSpace::Param; + ast::Variable{ align, v_type, state_space, name, array_init: Vec::new() } } } @@ -508,102 +509,109 @@ VariableParam: u32 = { "<" <n:U32Num> ">" => n } -Variable: ast::Variable<ast::VariableType, &'input str> = { +Variable: ast::Variable<&'input str> = { <v:RegVariable> => { let (align, v_type, name) = v; - let v_type = ast::VariableType::Reg(v_type); - ast::Variable {align, v_type, name, array_init: Vec::new()} + let state_space = ast::StateSpace::Reg; + ast::Variable {align, v_type, state_space, name, array_init: Vec::new()} }, LocalVariable, <v:ParamVariable> => { let (align, array_init, v_type, name) = v; - let v_type = ast::VariableType::Param(v_type); - ast::Variable {align, v_type, name, array_init} + let state_space = ast::StateSpace::Param; + ast::Variable {align, v_type, state_space, name, array_init} }, SharedVariable, }; -RegVariable: (Option<u32>, ast::VariableRegType, &'input str) = { +RegVariable: (Option<u32>, ast::Type, &'input str) = { ".reg" <var:VariableScalar<ScalarType>> => { let (align, t, name) = var; - let v_type = ast::VariableRegType::Scalar(t); + let v_type = ast::Type::Scalar(t); (align, v_type, name) }, ".reg" <var:VariableVector<SizedScalarType>> => { let (align, v_len, t, name) = var; - let v_type = ast::VariableRegType::Vector(t, v_len); + let v_type = ast::Type::Vector(t, v_len); (align, v_type, name) } } -LocalVariable: ast::Variable<ast::VariableType, &'input str> = { +LocalVariable: ast::Variable<&'input str> = { ".local" <var:VariableScalar<SizedScalarType>> => { let (align, t, name) = var; - let v_type = ast::VariableType::Local(ast::VariableLocalType::Scalar(t)); - ast::Variable { align, v_type, name, array_init: Vec::new() } + let v_type = ast::Type::Scalar(t); + let state_space = ast::StateSpace::Local; + ast::Variable { align, v_type, state_space, name, array_init: Vec::new() } }, ".local" <var:VariableVector<SizedScalarType>> => { let (align, v_len, t, name) = var; - let v_type = ast::VariableType::Local(ast::VariableLocalType::Vector(t, v_len)); - ast::Variable { align, v_type, name, array_init: Vec::new() } + let v_type = ast::Type::Vector(t, v_len); + let state_space = ast::StateSpace::Local; + ast::Variable { align, v_type, state_space, name, array_init: Vec::new() } }, ".local" <var:VariableArrayOrPointer<SizedScalarType>> =>? { let (align, t, name, arr_or_ptr) = var; + let state_space = ast::StateSpace::Local; let (v_type, array_init) = match arr_or_ptr { ast::ArrayOrPointer::Array { dimensions, init } => { - (ast::VariableLocalType::Array(t, dimensions), init) + (ast::Type::Array(t, dimensions), init) } ast::ArrayOrPointer::Pointer => { return Err(ParseError::User { error: ast::PtxError::ZeroDimensionArray }); } }; - Ok(ast::Variable { align, v_type: ast::VariableType::Local(v_type), name, array_init }) + Ok(ast::Variable { align, v_type, state_space, name, array_init }) } } -SharedVariable: ast::Variable<ast::VariableType, &'input str> = { +SharedVariable: ast::Variable<&'input str> = { ".shared" <var:VariableScalar<SizedScalarType>> => { let (align, t, name) = var; - let v_type = ast::VariableGlobalType::Scalar(t); - ast::Variable { align, v_type: ast::VariableType::Shared(v_type), name, array_init: Vec::new() } + let state_space = ast::StateSpace::Shared; + let v_type = ast::Type::Scalar(t); + ast::Variable { align, v_type, state_space, name, array_init: Vec::new() } }, ".shared" <var:VariableVector<SizedScalarType>> => { let (align, v_len, t, name) = var; - let v_type = ast::VariableGlobalType::Vector(t, v_len); - ast::Variable { align, v_type: ast::VariableType::Shared(v_type), name, array_init: Vec::new() } + let state_space = ast::StateSpace::Shared; + let v_type = ast::Type::Vector(t, v_len); + ast::Variable { align, v_type, state_space, name, array_init: Vec::new() } }, ".shared" <var:VariableArrayOrPointer<SizedScalarType>> =>? { let (align, t, name, arr_or_ptr) = var; + let state_space = ast::StateSpace::Shared; let (v_type, array_init) = match arr_or_ptr { ast::ArrayOrPointer::Array { dimensions, init } => { - (ast::VariableGlobalType::Array(t, dimensions), init) + (ast::Type::Array(t, dimensions), init) } ast::ArrayOrPointer::Pointer => { return Err(ParseError::User { error: ast::PtxError::ZeroDimensionArray }); } }; - Ok(ast::Variable { align, v_type: ast::VariableType::Shared(v_type), name, array_init }) + Ok(ast::Variable { align, v_type, state_space, name, array_init }) } } - -ModuleVariable: ast::Variable<ast::VariableType, &'input str> = { +ModuleVariable: ast::Variable<&'input str> = { LinkingDirectives ".global" <def:GlobalVariableDefinitionNoArray> => { let (align, v_type, name, array_init) = def; - ast::Variable { align, v_type: ast::VariableType::Global(v_type), name, array_init } + let state_space = ast::StateSpace::Global; + ast::Variable { align, v_type, state_space, name, array_init } }, LinkingDirectives ".shared" <def:GlobalVariableDefinitionNoArray> => { let (align, v_type, name, array_init) = def; - ast::Variable { align, v_type: ast::VariableType::Shared(v_type), name, array_init: Vec::new() } + let state_space = ast::StateSpace::Shared; + ast::Variable { align, v_type, state_space, name, array_init: Vec::new() } }, <ldirs:LinkingDirectives> <space:Or<".global", ".shared">> <var:VariableArrayOrPointer<SizedScalarType>> =>? { let (align, t, name, arr_or_ptr) = var; - let (v_type, array_init) = match arr_or_ptr { + let (v_type, state_space, array_init) = match arr_or_ptr { ast::ArrayOrPointer::Array { dimensions, init } => { if space == ".global" { - (ast::VariableType::Global(ast::VariableGlobalType::Array(t, dimensions)), init) + (ast::Type::Array(t, dimensions), ast::StateSpace::Global, init) } else { - (ast::VariableType::Shared(ast::VariableGlobalType::Array(t, dimensions)), init) + (ast::Type::Array(t, dimensions), ast::StateSpace::Shared, init) } } ast::ArrayOrPointer::Pointer => { @@ -611,38 +619,38 @@ ModuleVariable: ast::Variable<ast::VariableType, &'input str> = { return Err(ParseError::User { error: ast::PtxError::NonExternPointer }); } if space == ".global" { - (ast::VariableType::Global(ast::VariableGlobalType::Pointer(t, ast::LdStateSpace::Global)), Vec::new()) + (ast::Type::Pointer(ast::PointerType::Scalar(t), ast::LdStateSpace::Global), ast::StateSpace::Global, Vec::new()) } else { - (ast::VariableType::Shared(ast::VariableGlobalType::Pointer(t, ast::LdStateSpace::Shared)), Vec::new()) + (ast::Type::Pointer(ast::PointerType::Scalar(t), ast::LdStateSpace::Shared), ast::StateSpace::Shared, Vec::new()) } } }; - Ok(ast::Variable{ align, array_init, v_type, name }) + Ok(ast::Variable{ align, v_type, state_space, name, array_init }) } } // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameter-state-space -ParamVariable: (Option<u32>, Vec<u8>, ast::VariableParamType, &'input str) = { +ParamVariable: (Option<u32>, Vec<u8>, ast::Type, &'input str) = { ".param" <var:VariableScalar<LdStScalarType>> => { let (align, t, name) = var; - let v_type = ast::VariableParamType::Scalar(t); + let v_type = ast::Type::Scalar(t); (align, Vec::new(), v_type, name) }, ".param" <var:VariableArrayOrPointer<SizedScalarType>> => { let (align, t, name, arr_or_ptr) = var; let (v_type, array_init) = match arr_or_ptr { ast::ArrayOrPointer::Array { dimensions, init } => { - (ast::VariableParamType::Array(t, dimensions), init) + (ast::Type::Array(t, dimensions), init) } ast::ArrayOrPointer::Pointer => { - (ast::VariableParamType::Pointer(t, ast::LdStateSpace::Param), Vec::new()) + (ast::Type::Pointer(ast::PointerType::Scalar(t), ast::LdStateSpace::Param), Vec::new()) } }; (align, array_init, v_type, name) } } -ParamDeclaration: (Option<u32>, ast::VariableParamType, &'input str) = { +ParamDeclaration: (Option<u32>, ast::Type, &'input str) = { <var:ParamVariable> =>? { let (align, array_init, v_type, name) = var; if array_init.len() > 0 { @@ -653,15 +661,15 @@ ParamDeclaration: (Option<u32>, ast::VariableParamType, &'input str) = { } } -GlobalVariableDefinitionNoArray: (Option<u32>, ast::VariableGlobalType, &'input str, Vec<u8>) = { +GlobalVariableDefinitionNoArray: (Option<u32>, ast::Type, &'input str, Vec<u8>) = { <scalar:VariableScalar<SizedScalarType>> => { let (align, t, name) = scalar; - let v_type = ast::VariableGlobalType::Scalar(t); + let v_type = ast::Type::Scalar(t); (align, v_type, name, Vec::new()) }, <var:VariableVector<SizedScalarType>> => { let (align, v_len, t, name) = var; - let v_type = ast::VariableGlobalType::Vector(t, v_len); + let v_type = ast::Type::Vector(t, v_len); (align, v_type, name, Vec::new()) }, } diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 1f647bd..4ba5729 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -714,12 +714,13 @@ fn convert_dynamic_shared_memory_usage<'input>( let mut extern_shared_decls = HashMap::new();
for dir in module.iter() {
match dir {
- Directive::Variable(var) => {
- if let ast::VariableType::Shared(ast::VariableGlobalType::Pointer(p_type, _)) =
- var.v_type
- {
- extern_shared_decls.insert(var.name, p_type);
- }
+ Directive::Variable(ast::Variable {
+ v_type: ast::Type::Pointer(p_type, ast::LdStateSpace::Shared),
+ state_space: ast::StateSpace::Shared,
+ name,
+ ..
+ }) => {
+ extern_shared_decls.insert(*name, p_type.clone());
}
_ => {}
}
@@ -796,25 +797,27 @@ fn convert_dynamic_shared_memory_usage<'input>( let shared_id_param = new_id();
spirv_decl.input.push({
ast::Variable {
+ name: shared_id_param,
align: None,
v_type: ast::Type::Pointer(
- ast::PointerType::Scalar(ast::ScalarType::U8),
+ ast::PointerType::Scalar(ast::ScalarType::B8),
ast::LdStateSpace::Shared,
),
+ state_space: ast::StateSpace::Param,
array_init: Vec::new(),
- name: shared_id_param,
}
});
spirv_decl.uses_shared_mem = true;
let shared_var_id = new_id();
let shared_var = ExpandedStatement::Variable(ast::Variable {
- align: None,
name: shared_var_id,
- array_init: Vec::new(),
- v_type: ast::VariableType::Reg(ast::VariableRegType::Pointer(
- ast::ScalarType::B8,
+ align: None,
+ v_type: ast::Type::Pointer(
+ ast::PointerType::Scalar(ast::ScalarType::B8),
ast::LdStateSpace::Shared,
- )),
+ ),
+ state_space: ast::StateSpace::Reg,
+ array_init: Vec::new(),
});
let shared_var_st = ExpandedStatement::StoreVar(StoreVarDetails {
arg: ast::Arg2St {
@@ -851,7 +854,7 @@ fn convert_dynamic_shared_memory_usage<'input>( fn replace_uses_of_shared_memory<'a>(
result: &mut Vec<ExpandedStatement>,
new_id: &mut impl FnMut() -> spirv::Word,
- extern_shared_decls: &HashMap<spirv::Word, ast::ScalarType>,
+ extern_shared_decls: &HashMap<spirv::Word, ast::PointerType>,
methods_using_extern_shared: &mut HashSet<MethodName<'a>>,
shared_id_param: spirv::Word,
shared_var_id: spirv::Word,
@@ -864,14 +867,17 @@ fn replace_uses_of_shared_memory<'a>( // because there's simply no way to pass shared ptr
// without converting it to .b64 first
if methods_using_extern_shared.contains(&MethodName::Func(call.func)) {
- call.param_list
- .push((shared_id_param, ast::FnArgumentType::Shared));
+ call.param_list.push((
+ shared_id_param,
+ ast::Type::Scalar(ast::ScalarType::B8),
+ ast::StateSpace::Shared,
+ ));
}
result.push(Statement::Call(call))
}
statement => {
let new_statement = statement.map_id(&mut |id, _| {
- if let Some(typ) = extern_shared_decls.get(&id) {
+ if let Some(ast::PointerType::Scalar(typ)) = extern_shared_decls.get(&id) {
if *typ == ast::ScalarType::B8 {
return shared_var_id;
}
@@ -1067,7 +1073,7 @@ fn emit_function_header<'a>( builder: &mut dr::Builder,
map: &mut TypeWordMap,
defined_globals: &GlobalStringIdResolver<'a>,
- synthetic_globals: &[ast::Variable<ast::VariableType, spirv::Word>],
+ synthetic_globals: &[ast::Variable<spirv::Word>],
func_decl: &SpirvMethodDecl<'a>,
_denorm_information: &HashMap<MethodName<'a>, HashMap<u8, (spirv::FPDenormMode, isize)>>,
call_map: &HashMap<&'a str, HashSet<spirv::Word>>,
@@ -1204,9 +1210,9 @@ fn translate_directive<'input>( fn translate_variable<'a>(
id_defs: &mut GlobalStringIdResolver<'a>,
- var: ast::Variable<ast::VariableType, &'a str>,
-) -> Result<ast::Variable<ast::VariableType, spirv::Word>, TranslateError> {
- let (space, var_type) = var.v_type.to_type();
+ var: ast::Variable<&'a str>,
+) -> Result<ast::Variable<spirv::Word>, TranslateError> {
+ let (space, var_type) = (var.state_space, var.v_type.clone());
let mut is_variable = false;
let var_type = match space {
ast::StateSpace::Reg => {
@@ -1226,10 +1232,12 @@ fn translate_variable<'a>( }
}
ast::StateSpace::Param => var_type.param_pointer_to(ast::LdStateSpace::Param)?,
+ ast::StateSpace::Generic => todo!(),
};
Ok(ast::Variable {
align: var.align,
v_type: var.v_type,
+ state_space: var.state_space,
name: id_defs.get_or_add_def_typed(var.name, var_type, is_variable),
array_init: var.array_init,
})
@@ -1279,6 +1287,7 @@ fn expand_kernel_params<'a, 'b>( false,
),
v_type: a.v_type.clone(),
+ state_space: a.state_space,
align: a.align,
array_init: Vec::new(),
})
@@ -1291,14 +1300,11 @@ fn expand_fn_params<'a, 'b>( args: impl Iterator<Item = &'b ast::FnArgument<&'a str>>,
) -> Result<Vec<ast::FnArgument<spirv::Word>>, TranslateError> {
args.map(|a| {
- let is_variable = match a.v_type {
- ast::FnArgumentType::Reg(_) => true,
- _ => false,
- };
- let var_type = a.v_type.to_func_type();
+ let is_variable = a.state_space == ast::StateSpace::Reg;
Ok(ast::FnArgument {
- name: fn_resolver.add_def(a.name, Some(var_type), is_variable),
+ name: fn_resolver.add_def(a.name, Some(a.v_type.clone()), is_variable),
v_type: a.v_type.clone(),
+ state_space: a.state_space,
align: a.align,
array_init: Vec::new(),
})
@@ -1444,10 +1450,7 @@ fn extract_globals<'input, 'b>( sorted_statements: Vec<ExpandedStatement>,
ptx_impl_imports: &mut HashMap<String, Directive>,
id_def: &mut NumericIdResolver,
-) -> (
- Vec<ExpandedStatement>,
- Vec<ast::Variable<ast::VariableType, spirv::Word>>,
-) {
+) -> (Vec<ExpandedStatement>, Vec<ast::Variable<spirv::Word>>) {
let mut local = Vec::with_capacity(sorted_statements.len());
let mut global = Vec::new();
for statement in sorted_statements {
@@ -1456,7 +1459,7 @@ fn extract_globals<'input, 'b>( var
@
ast::Variable {
- v_type: ast::VariableType::Shared(_),
+ state_space: ast::StateSpace::Shared,
..
},
)
@@ -1464,7 +1467,7 @@ fn extract_globals<'input, 'b>( var
@
ast::Variable {
- v_type: ast::VariableType::Global(_),
+ state_space: ast::StateSpace::Global,
..
},
) => global.push(var),
@@ -1592,10 +1595,10 @@ fn convert_to_typed_statements( let in_args = to_resolved_fn_args(call.param_list, &*fn_def.params);
let (out_params, out_non_params): (Vec<_>, Vec<_>) = out_args
.into_iter()
- .partition(|(_, arg_type)| arg_type.is_param());
+ .partition(|(_, _, space)| *space == ast::StateSpace::Param);
let normalized_input_args = out_params
.into_iter()
- .map(|(id, typ)| (ast::Operand::Reg(id), typ))
+ .map(|(id, typ, space)| (ast::Operand::Reg(id), typ, space))
.chain(in_args.into_iter())
.collect();
let resolved_call = ResolvedCall {
@@ -1744,7 +1747,8 @@ fn to_ptx_impl_atomic_call( let func_decl = ast::MethodDecl::Func::<spirv::Word>(
vec![ast::FnArgument {
align: None,
- v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(scalar_typ)),
+ v_type: ast::Type::Scalar(scalar_typ),
+ state_space: ast::StateSpace::Reg,
name: id_defs.new_non_variable(None),
array_init: Vec::new(),
}],
@@ -1752,15 +1756,15 @@ fn to_ptx_impl_atomic_call( vec![
ast::FnArgument {
align: None,
- v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Pointer(
- typ, ptr_space,
- )),
+ v_type: ast::Type::Pointer(ast::PointerType::Scalar(typ), ptr_space),
+ state_space: ast::StateSpace::Reg,
name: id_defs.new_non_variable(None),
array_init: Vec::new(),
},
ast::FnArgument {
align: None,
- v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(scalar_typ)),
+ v_type: ast::Type::Scalar(scalar_typ),
+ state_space: ast::StateSpace::Reg,
name: id_defs.new_non_variable(None),
array_init: Vec::new(),
},
@@ -1789,18 +1793,17 @@ fn to_ptx_impl_atomic_call( Statement::Call(ResolvedCall {
uniform: false,
func: fn_id,
- ret_params: vec![(
- arg.dst,
- ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(scalar_typ)),
- )],
+ ret_params: vec![(arg.dst, ast::Type::Scalar(scalar_typ), ast::StateSpace::Reg)],
param_list: vec![
(
arg.src1,
- ast::FnArgumentType::Reg(ast::VariableRegType::Pointer(typ, ptr_space)),
+ ast::Type::Pointer(ast::PointerType::Scalar(typ), ptr_space),
+ ast::StateSpace::Reg,
),
(
arg.src2,
- ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(scalar_typ)),
+ ast::Type::Scalar(scalar_typ),
+ ast::StateSpace::Reg,
),
],
})
@@ -1827,7 +1830,8 @@ fn to_ptx_impl_bfe_call( let func_decl = ast::MethodDecl::Func::<spirv::Word>(
vec![ast::FnArgument {
align: None,
- v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
+ v_type: ast::Type::Scalar(typ.into()),
+ state_space: ast::StateSpace::Reg,
name: id_defs.new_non_variable(None),
array_init: Vec::new(),
}],
@@ -1835,23 +1839,22 @@ fn to_ptx_impl_bfe_call( vec![
ast::FnArgument {
align: None,
- v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
+ v_type: ast::Type::Scalar(typ.into()),
+ state_space: ast::StateSpace::Reg,
name: id_defs.new_non_variable(None),
array_init: Vec::new(),
},
ast::FnArgument {
align: None,
- v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(
- ast::ScalarType::U32,
- )),
+ v_type: ast::Type::Scalar(ast::ScalarType::U32),
+ state_space: ast::StateSpace::Reg,
name: id_defs.new_non_variable(None),
array_init: Vec::new(),
},
ast::FnArgument {
align: None,
- v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(
- ast::ScalarType::U32,
- )),
+ v_type: ast::Type::Scalar(ast::ScalarType::U32),
+ state_space: ast::StateSpace::Reg,
name: id_defs.new_non_variable(None),
array_init: Vec::new(),
},
@@ -1880,22 +1883,22 @@ fn to_ptx_impl_bfe_call( Statement::Call(ResolvedCall {
uniform: false,
func: fn_id,
- ret_params: vec![(
- arg.dst,
- ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
- )],
+ ret_params: vec![(arg.dst, ast::Type::Scalar(typ.into()), ast::StateSpace::Reg)],
param_list: vec![
(
arg.src1,
- ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
+ ast::Type::Scalar(typ.into()),
+ ast::StateSpace::Reg,
),
(
arg.src2,
- ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(ast::ScalarType::U32)),
+ ast::Type::Scalar(ast::ScalarType::U32),
+ ast::StateSpace::Reg,
),
(
arg.src3,
- ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(ast::ScalarType::U32)),
+ ast::Type::Scalar(ast::ScalarType::U32),
+ ast::StateSpace::Reg,
),
],
})
@@ -1920,7 +1923,8 @@ fn to_ptx_impl_bfi_call( let func_decl = ast::MethodDecl::Func::<spirv::Word>(
vec![ast::FnArgument {
align: None,
- v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
+ v_type: ast::Type::Scalar(typ.into()),
+ state_space: ast::StateSpace::Reg,
name: id_defs.new_non_variable(None),
array_init: Vec::new(),
}],
@@ -1928,29 +1932,29 @@ fn to_ptx_impl_bfi_call( vec![
ast::FnArgument {
align: None,
- v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
+ v_type: ast::Type::Scalar(typ.into()),
+ state_space: ast::StateSpace::Reg,
name: id_defs.new_non_variable(None),
array_init: Vec::new(),
},
ast::FnArgument {
align: None,
- v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
+ v_type: ast::Type::Scalar(typ.into()),
+ state_space: ast::StateSpace::Reg,
name: id_defs.new_non_variable(None),
array_init: Vec::new(),
},
ast::FnArgument {
align: None,
- v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(
- ast::ScalarType::U32,
- )),
+ v_type: ast::Type::Scalar(ast::ScalarType::U32),
+ state_space: ast::StateSpace::Reg,
name: id_defs.new_non_variable(None),
array_init: Vec::new(),
},
ast::FnArgument {
align: None,
- v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(
- ast::ScalarType::U32,
- )),
+ v_type: ast::Type::Scalar(ast::ScalarType::U32),
+ state_space: ast::StateSpace::Reg,
name: id_defs.new_non_variable(None),
array_init: Vec::new(),
},
@@ -1979,26 +1983,27 @@ fn to_ptx_impl_bfi_call( Statement::Call(ResolvedCall {
uniform: false,
func: fn_id,
- ret_params: vec![(
- arg.dst,
- ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
- )],
+ ret_params: vec![(arg.dst, ast::Type::Scalar(typ.into()), ast::StateSpace::Reg)],
param_list: vec![
(
arg.src1,
- ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
+ ast::Type::Scalar(typ.into()),
+ ast::StateSpace::Reg,
),
(
arg.src2,
- ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
+ ast::Type::Scalar(typ.into()),
+ ast::StateSpace::Reg,
),
(
arg.src3,
- ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(ast::ScalarType::U32)),
+ ast::Type::Scalar(ast::ScalarType::U32),
+ ast::StateSpace::Reg,
),
(
arg.src4,
- ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(ast::ScalarType::U32)),
+ ast::Type::Scalar(ast::ScalarType::U32),
+ ast::StateSpace::Reg,
),
],
})
@@ -2006,12 +2011,12 @@ fn to_ptx_impl_bfi_call( fn to_resolved_fn_args<T>(
params: Vec<T>,
- params_decl: &[ast::FnArgumentType],
-) -> Vec<(T, ast::FnArgumentType)> {
+ params_decl: &[(ast::Type, ast::StateSpace)],
+) -> Vec<(T, ast::Type, ast::StateSpace)> {
params
.into_iter()
.zip(params_decl.iter())
- .map(|(id, typ)| (id, typ.clone()))
+ .map(|(id, (typ, space))| (id, typ.clone(), *space))
.collect::<Vec<_>>()
}
@@ -2096,50 +2101,38 @@ fn normalize_predicates( fn insert_mem_ssa_statements<'a, 'b>(
func: Vec<TypedStatement>,
id_def: &mut NumericIdResolver,
- ast_fn_decl: &'a ast::MethodDecl<'b, spirv::Word>,
+ _: &'a ast::MethodDecl<'b, spirv::Word>,
fn_decl: &mut SpirvMethodDecl,
) -> Result<Vec<TypedStatement>, TranslateError> {
- let is_func = match ast_fn_decl {
- ast::MethodDecl::Func(..) => true,
- ast::MethodDecl::Kernel { .. } => false,
- };
let mut result = Vec::with_capacity(func.len());
for arg in fn_decl.output.iter() {
- match type_to_variable_type(&arg.v_type, is_func)? {
- Some(var_type) => {
- result.push(Statement::Variable(ast::Variable {
- align: arg.align,
- v_type: var_type,
- name: arg.name,
- array_init: arg.array_init.clone(),
- }));
- }
- None => return Err(error_unreachable()),
- }
+ result.push(Statement::Variable(ast::Variable {
+ align: arg.align,
+ v_type: arg.v_type.clone(),
+ state_space: arg.state_space,
+ name: arg.name,
+ array_init: arg.array_init.clone(),
+ }));
}
for spirv_arg in fn_decl.input.iter_mut() {
- match type_to_variable_type(&spirv_arg.v_type, is_func)? {
- Some(var_type) => {
- let typ = spirv_arg.v_type.clone();
- let new_id = id_def.new_non_variable(Some(typ.clone()));
- result.push(Statement::Variable(ast::Variable {
- align: spirv_arg.align,
- v_type: var_type,
- name: spirv_arg.name,
- array_init: spirv_arg.array_init.clone(),
- }));
- result.push(Statement::StoreVar(StoreVarDetails {
- arg: ast::Arg2St {
- src1: spirv_arg.name,
- src2: new_id,
- },
- typ,
- member_index: None,
- }));
- spirv_arg.name = new_id;
- }
- None => {}
- }
+ let typ = spirv_arg.v_type.clone();
+ let new_id = id_def.new_non_variable(Some(typ.clone()));
+ result.push(Statement::Variable(ast::Variable {
+ align: spirv_arg.align,
+ v_type: spirv_arg.v_type.clone(),
+ state_space: spirv_arg.state_space,
+ name: spirv_arg.name,
+ array_init: spirv_arg.array_init.clone(),
+ }));
+ result.push(Statement::StoreVar(StoreVarDetails {
+ arg: ast::Arg2St {
+ src1: spirv_arg.name,
+ src2: new_id,
+ },
+ typ,
+ member_index: None,
+ }));
+ spirv_arg.name = new_id;
}
for s in func {
match s {
@@ -2197,41 +2190,6 @@ fn insert_mem_ssa_statements<'a, 'b>( Ok(result)
}
-fn type_to_variable_type(
- t: &ast::Type,
- is_func: bool,
-) -> Result<Option<ast::VariableType>, TranslateError> {
- Ok(match t {
- ast::Type::Scalar(typ) => Some(ast::VariableType::Reg(ast::VariableRegType::Scalar(*typ))),
- ast::Type::Vector(typ, len) => Some(ast::VariableType::Reg(ast::VariableRegType::Vector(
- (*typ)
- .try_into()
- .map_err(|_| TranslateError::MismatchedType)?,
- *len,
- ))),
- ast::Type::Array(typ, len) => Some(ast::VariableType::Reg(ast::VariableRegType::Array(
- (*typ)
- .try_into()
- .map_err(|_| TranslateError::MismatchedType)?,
- len.clone(),
- ))),
- ast::Type::Pointer(ast::PointerType::Scalar(scalar_type), space) => {
- if is_func {
- return Ok(None);
- }
- Some(ast::VariableType::Reg(ast::VariableRegType::Pointer(
- scalar_type
- .clone()
- .try_into()
- .map_err(|_| error_unreachable())?,
- (*space).try_into().map_err(|_| error_unreachable())?,
- )))
- }
- ast::Type::Pointer(_, ast::LdStateSpace::Shared) => None,
- _ => return Err(error_unreachable()),
- })
-}
-
trait Visitable<From: ArgParamsEx, To: ArgParamsEx>: Sized {
fn visit(
self,
@@ -2398,11 +2356,13 @@ fn expand_arguments<'a, 'b>( Statement::Variable(ast::Variable {
align,
v_type,
+ state_space,
name,
array_init,
}) => result.push(Statement::Variable(ast::Variable {
align,
v_type,
+ state_space,
name,
array_init,
})),
@@ -2784,8 +2744,8 @@ fn insert_implicit_conversions_impl( fn get_function_type(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
- spirv_input: &[ast::Variable<ast::Type, spirv::Word>],
- spirv_output: &[ast::Variable<ast::Type, spirv::Word>],
+ spirv_input: &[ast::Variable<spirv::Word>],
+ spirv_output: &[ast::Variable<spirv::Word>],
) -> (spirv::Word, spirv::Word) {
map.get_or_add_fn(
builder,
@@ -2822,8 +2782,8 @@ fn emit_function_body_ops( Statement::Label(_) => (),
Statement::Call(call) => {
let (result_type, result_id) = match &*call.ret_params {
- [(id, typ)] => (
- map.get_or_add(builder, SpirvType::from(typ.to_func_type())),
+ [(id, typ, _)] => (
+ map.get_or_add(builder, SpirvType::from(typ.clone())),
Some(*id),
),
[] => (map.void(), None),
@@ -2832,7 +2792,7 @@ fn emit_function_body_ops( let arg_list = call
.param_list
.iter()
- .map(|(id, _)| *id)
+ .map(|(id, _, _)| *id)
.collect::<Vec<_>>();
builder.function_call(result_type, result_id, call.func, arg_list)?;
}
@@ -3602,14 +3562,16 @@ fn vec_repr<T: Copy>(t: T) -> Vec<u8> { fn emit_variable(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
- var: &ast::Variable<ast::VariableType, spirv::Word>,
+ var: &ast::Variable<spirv::Word>,
) -> Result<(), TranslateError> {
- let (must_init, st_class) = match var.v_type {
- ast::VariableType::Reg(_) | ast::VariableType::Param(_) | ast::VariableType::Local(_) => {
+ let (must_init, st_class) = match var.state_space {
+ ast::StateSpace::Reg | ast::StateSpace::Param | ast::StateSpace::Local => {
(false, spirv::StorageClass::Function)
}
- ast::VariableType::Global(_) => (true, spirv::StorageClass::CrossWorkgroup),
- ast::VariableType::Shared(_) => (false, spirv::StorageClass::Workgroup),
+ ast::StateSpace::Global => (true, spirv::StorageClass::CrossWorkgroup),
+ ast::StateSpace::Shared => (false, spirv::StorageClass::Workgroup),
+ ast::StateSpace::Const => todo!(),
+ ast::StateSpace::Generic => todo!(),
};
let initalizer = if var.array_init.len() > 0 {
Some(map.get_or_add_constant(
@@ -4460,12 +4422,12 @@ fn expand_map_variables<'a, 'b>( ast::Statement::Variable(var) => {
let mut var_type = ast::Type::from(var.var.v_type.clone());
let mut is_variable = false;
- var_type = match var.var.v_type {
- ast::VariableType::Reg(_) => {
+ var_type = match var.var.state_space {
+ ast::StateSpace::Reg => {
is_variable = true;
var_type
}
- ast::VariableType::Shared(_) => {
+ ast::StateSpace::Shared => {
// If it's a pointer it will be translated to a method parameter later
if let ast::Type::Pointer(..) = var_type {
is_variable = true;
@@ -4474,15 +4436,11 @@ fn expand_map_variables<'a, 'b>( var_type.param_pointer_to(ast::LdStateSpace::Shared)?
}
}
- ast::VariableType::Global(_) => {
- var_type.param_pointer_to(ast::LdStateSpace::Global)?
- }
- ast::VariableType::Param(_) => {
- var_type.param_pointer_to(ast::LdStateSpace::Param)?
- }
- ast::VariableType::Local(_) => {
- var_type.param_pointer_to(ast::LdStateSpace::Local)?
- }
+ ast::StateSpace::Global => var_type.param_pointer_to(ast::LdStateSpace::Global)?,
+ ast::StateSpace::Param => var_type.param_pointer_to(ast::LdStateSpace::Param)?,
+ ast::StateSpace::Local => var_type.param_pointer_to(ast::LdStateSpace::Local)?,
+ ast::StateSpace::Const => todo!(),
+ ast::StateSpace::Generic => todo!(),
};
match var.count {
Some(count) => {
@@ -4490,6 +4448,7 @@ fn expand_map_variables<'a, 'b>( result.push(Statement::Variable(ast::Variable {
align: var.var.align,
v_type: var.var.v_type.clone(),
+ state_space: var.var.state_space,
name: new_id,
array_init: var.var.array_init.clone(),
}))
@@ -4500,6 +4459,7 @@ fn expand_map_variables<'a, 'b>( result.push(Statement::Variable(ast::Variable {
align: var.var.align,
v_type: var.var.v_type.clone(),
+ state_space: var.var.state_space,
name: new_id,
array_init: var.var.array_init,
}));
@@ -4659,10 +4619,11 @@ fn convert_to_stateful_memory_access<'a>( align: None,
name: new_id,
array_init: Vec::new(),
- v_type: ast::VariableType::Reg(ast::VariableRegType::Pointer(
- ast::ScalarType::U8,
+ v_type: ast::Type::Pointer(
+ ast::PointerType::Scalar(ast::ScalarType::U8),
ast::LdStateSpace::Global,
- )),
+ ),
+ state_space: ast::StateSpace::Reg,
}));
remapped_ids.insert(reg, new_id);
}
@@ -5052,8 +5013,8 @@ struct GlobalStringIdResolver<'input> { }
pub struct FnDecl {
- ret_vals: Vec<ast::FnArgumentType>,
- params: Vec<ast::FnArgumentType>,
+ ret_vals: Vec<(ast::Type, ast::StateSpace)>,
+ params: Vec<(ast::Type, ast::StateSpace)>,
}
impl<'a> GlobalStringIdResolver<'a> {
@@ -5137,8 +5098,14 @@ impl<'a> GlobalStringIdResolver<'a> { self.fns.insert(
name_id,
FnDecl {
- ret_vals: ret_params_ids.iter().map(|p| p.v_type.clone()).collect(),
- params: params_ids.iter().map(|p| p.v_type.clone()).collect(),
+ ret_vals: ret_params_ids
+ .iter()
+ .map(|p| (p.v_type.clone(), p.state_space))
+ .collect(),
+ params: params_ids
+ .iter()
+ .map(|p| (p.v_type.clone(), p.state_space))
+ .collect(),
},
);
ast::MethodDecl::Func(ret_params_ids, name_id, params_ids)
@@ -5314,7 +5281,7 @@ impl<'b> MutableNumericIdResolver<'b> { enum Statement<I, P: ast::ArgParams> {
Label(u32),
- Variable(ast::Variable<ast::VariableType, P::Id>),
+ Variable(ast::Variable<P::Id>),
Instruction(I),
// SPIR-V compatible replacement for PTX predicates
Conditional(BrachCondition),
@@ -5352,16 +5319,17 @@ impl ExpandedStatement { Statement::StoreVar(details)
}
Statement::Call(mut call) => {
- for (id, typ) in call.ret_params.iter_mut() {
- let is_dst = match typ {
- ast::FnArgumentType::Reg(_) => true,
- ast::FnArgumentType::Param(_) => false,
- ast::FnArgumentType::Shared => false,
+ for (id, _, space) in call.ret_params.iter_mut() {
+ let is_dst = match space {
+ ast::StateSpace::Reg => true,
+ ast::StateSpace::Param => false,
+ ast::StateSpace::Shared => false,
+ _ => todo!(),
};
*id = f(*id, is_dst);
}
call.func = f(call.func, false);
- for (id, _) in call.param_list.iter_mut() {
+ for (id, _, _) in call.param_list.iter_mut() {
*id = f(*id, false);
}
Statement::Call(call)
@@ -5502,9 +5470,9 @@ impl<T: ArgParamsEx<Id = spirv::Word>, U: ArgParamsEx<Id = spirv::Word>> Visitab struct ResolvedCall<P: ast::ArgParams> {
pub uniform: bool,
- pub ret_params: Vec<(P::Id, ast::FnArgumentType)>,
+ pub ret_params: Vec<(P::Id, ast::Type, ast::StateSpace)>,
pub func: P::Id,
- pub param_list: Vec<(P::Operand, ast::FnArgumentType)>,
+ pub param_list: Vec<(P::Operand, ast::Type, ast::StateSpace)>,
}
impl<T: ast::ArgParams> ResolvedCall<T> {
@@ -5526,16 +5494,16 @@ impl<From: ArgParamsEx<Id = spirv::Word>> ResolvedCall<From> { let ret_params = self
.ret_params
.into_iter()
- .map::<Result<_, TranslateError>, _>(|(id, typ)| {
+ .map::<Result<_, TranslateError>, _>(|(id, typ, space)| {
let new_id = visitor.id(
ArgumentDescriptor {
op: id,
- is_dst: !typ.is_param(),
- sema: typ.semantics(),
+ is_dst: space != ast::StateSpace::Param,
+ sema: space.semantics(),
},
- Some(&typ.to_func_type()),
+ Some(&typ),
)?;
- Ok((new_id, typ))
+ Ok((new_id, typ, space))
})
.collect::<Result<Vec<_>, _>>()?;
let func = visitor.id(
@@ -5549,16 +5517,16 @@ impl<From: ArgParamsEx<Id = spirv::Word>> ResolvedCall<From> { let param_list = self
.param_list
.into_iter()
- .map::<Result<_, TranslateError>, _>(|(id, typ)| {
+ .map::<Result<_, TranslateError>, _>(|(id, typ, space)| {
let new_id = visitor.operand(
ArgumentDescriptor {
op: id,
is_dst: false,
- sema: typ.semantics(),
+ sema: space.semantics(),
},
- &typ.to_func_type(),
+ &typ,
)?;
- Ok((new_id, typ))
+ Ok((new_id, typ, space))
})
.collect::<Result<Vec<_>, _>>()?;
Ok(ResolvedCall {
@@ -5738,14 +5706,14 @@ impl ArgParamsEx for ExpandedArgParams { }
enum Directive<'input> {
- Variable(ast::Variable<ast::VariableType, spirv::Word>),
+ Variable(ast::Variable<spirv::Word>),
Method(Function<'input>),
}
struct Function<'input> {
pub func_decl: ast::MethodDecl<'input, spirv::Word>,
pub spirv_decl: SpirvMethodDecl<'input>,
- pub globals: Vec<ast::Variable<ast::VariableType, spirv::Word>>,
+ pub globals: Vec<ast::Variable<spirv::Word>>,
pub body: Option<Vec<ExpandedStatement>>,
import_as: Option<String>,
tuning: Vec<ast::TuningDirective>,
@@ -7300,16 +7268,6 @@ impl ast::LdStateSpace { }
}
-impl From<ast::FnArgumentType> for ast::VariableType {
- fn from(t: ast::FnArgumentType) -> Self {
- match t {
- ast::FnArgumentType::Reg(t) => ast::VariableType::Reg(t),
- ast::FnArgumentType::Param(t) => ast::VariableType::Param(t),
- ast::FnArgumentType::Shared => todo!(),
- }
- }
-}
-
impl<T> ast::Operand<T> {
fn underlying(&self) -> Option<&T> {
match self {
@@ -7362,12 +7320,13 @@ impl ast::AtomSemantics { }
}
-impl ast::FnArgumentType {
- fn semantics(&self) -> ArgumentSemantics {
+impl ast::StateSpace {
+ fn semantics(self) -> ArgumentSemantics {
match self {
- ast::FnArgumentType::Reg(_) => ArgumentSemantics::Default,
- ast::FnArgumentType::Param(_) => ArgumentSemantics::RegisterPointer,
- ast::FnArgumentType::Shared => ArgumentSemantics::PhysicalPointer,
+ ast::StateSpace::Reg => ArgumentSemantics::Default,
+ ast::StateSpace::Param => ArgumentSemantics::RegisterPointer,
+ ast::StateSpace::Shared => ArgumentSemantics::PhysicalPointer,
+ _ => todo!(),
}
}
}
@@ -7677,8 +7636,8 @@ impl<'a> ast::MethodDecl<'a, &'a str> { }
struct SpirvMethodDecl<'input> {
- input: Vec<ast::Variable<ast::Type, spirv::Word>>,
- output: Vec<ast::Variable<ast::Type, spirv::Word>>,
+ input: Vec<ast::Variable<spirv::Word>>,
+ output: Vec<ast::Variable<spirv::Word>>,
name: MethodName<'input>,
uses_shared_mem: bool,
}
@@ -7689,33 +7648,28 @@ impl<'input> SpirvMethodDecl<'input> { ast::MethodDecl::Kernel { in_args, .. } => {
let spirv_input = in_args
.iter()
- .map(|var| {
- let v_type = match &var.v_type {
- ast::KernelArgumentType::Normal(t) => {
- ast::FnArgumentType::Param(t.clone())
- }
- ast::KernelArgumentType::Shared => ast::FnArgumentType::Shared,
- };
- ast::Variable {
- name: var.name,
- align: var.align,
- v_type: v_type.to_kernel_type(),
- array_init: var.array_init.clone(),
- }
+ .map(|var| ast::Variable {
+ name: var.name,
+ align: var.align,
+ v_type: var.v_type.clone(),
+ state_space: var.state_space,
+ array_init: var.array_init.clone(),
})
.collect();
(spirv_input, Vec::new())
}
ast::MethodDecl::Func(out_args, _, in_args) => {
- let (param_output, non_param_output): (Vec<_>, Vec<_>) =
- out_args.iter().partition(|var| var.v_type.is_param());
+ let (param_output, non_param_output): (Vec<_>, Vec<_>) = out_args
+ .iter()
+ .partition(|var| var.state_space == ast::StateSpace::Param);
let spirv_output = non_param_output
.into_iter()
.cloned()
.map(|var| ast::Variable {
name: var.name,
align: var.align,
- v_type: var.v_type.to_func_type(),
+ v_type: var.v_type.clone(),
+ state_space: var.state_space,
array_init: var.array_init.clone(),
})
.collect();
@@ -7726,7 +7680,8 @@ impl<'input> SpirvMethodDecl<'input> { .map(|var| ast::Variable {
name: var.name,
align: var.align,
- v_type: var.v_type.to_func_type(),
+ v_type: var.v_type.clone(),
+ state_space: var.state_space,
array_init: var.array_init.clone(),
})
.collect();
|