aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorAndrzej Janik <[email protected]>2021-04-17 14:01:50 +0200
committerAndrzej Janik <[email protected]>2021-04-17 14:01:50 +0200
commitd51aaaf5529dbfec0735c73768e468728112c26b (patch)
tree2420b0e35bbc93dd0d53f37f6828541e2e76e878
parenta55c851eaa4ded60d5f62aba1d7da850a63163f3 (diff)
downloadZLUDA-d51aaaf5529dbfec0735c73768e468728112c26b.tar.gz
ZLUDA-d51aaaf5529dbfec0735c73768e468728112c26b.zip
Throw away special variable types
-rw-r--r--ptx/src/ast.rs215
-rw-r--r--ptx/src/ptx.lalrpop102
-rw-r--r--ptx/src/translate.rs429
3 files changed, 256 insertions, 490 deletions
diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs
index 3e62cb1..c7b9563 100644
--- a/ptx/src/ast.rs
+++ b/ptx/src/ast.rs
@@ -1,6 +1,5 @@
use half::f16;
use lalrpop_util::{lexer::Token, ParseError};
-use std::convert::TryInto;
use std::{convert::From, convert::TryFrom, mem, num::ParseFloatError, str::FromStr};
use std::{marker::PhantomData, num::ParseIntError};
@@ -34,107 +33,12 @@ pub enum PtxError {
NonExternPointer,
}
-macro_rules! sub_type {
- ($type_name:ident { $($variant:ident ( $($field_type:ident),+ ) ),+ $(,)? } ) => {
- sub_type! { $type_name : Type {
- $(
- $variant ($($field_type),+),
- )+
- }}
- };
- ($type_name:ident : $base_type:ident { $($variant:ident ( $($field_type:ident),+ ) ),+ $(,)? } ) => {
- #[derive(PartialEq, Eq, Clone)]
- pub enum $type_name {
- $(
- $variant ($($field_type),+),
- )+
- }
-
- impl From<$type_name> for $base_type {
- #[allow(non_snake_case)]
- fn from(t: $type_name) -> $base_type {
- match t {
- $(
- $type_name::$variant ( $($field_type),+ ) => <$base_type>::$variant ( $($field_type.into()),+),
- )+
- }
- }
- }
-
- impl std::convert::TryFrom<$base_type> for $type_name {
- type Error = ();
-
- #[allow(non_snake_case)]
- #[allow(unreachable_patterns)]
- fn try_from(t: $base_type) -> Result<Self, Self::Error> {
- match t {
- $(
- $base_type::$variant ( $($field_type),+ ) => Ok($type_name::$variant ( $($field_type.try_into().map_err(|_| ())? ),+ )),
- )+
- _ => Err(()),
- }
- }
- }
- };
-}
-
-sub_type! {
- VariableRegType {
- Scalar(ScalarType),
- Vector(ScalarType, u8),
- // Array type is used when emiting SSA statements at the start of a method
- Array(ScalarType, VecU32),
- // Pointer variant is used when passing around SLM pointer between
- // function calls for dynamic SLM
- Pointer(ScalarType, LdStateSpace)
- }
-}
-
-type VecU32 = Vec<u32>;
-
-sub_type! {
- VariableLocalType {
- Scalar(ScalarType),
- Vector(ScalarType, u8),
- Array(ScalarType, VecU32),
- }
-}
-
-impl TryFrom<VariableGlobalType> for VariableLocalType {
- type Error = PtxError;
-
- fn try_from(value: VariableGlobalType) -> Result<Self, Self::Error> {
- match value {
- VariableGlobalType::Scalar(t) => Ok(VariableLocalType::Scalar(t)),
- VariableGlobalType::Vector(t, len) => Ok(VariableLocalType::Vector(t, len)),
- VariableGlobalType::Array(t, len) => Ok(VariableLocalType::Array(t, len)),
- VariableGlobalType::Pointer(_, _) => Err(PtxError::ZeroDimensionArray),
- }
- }
-}
-
-sub_type! {
- VariableGlobalType {
- Scalar(ScalarType),
- Vector(ScalarType, u8),
- Array(ScalarType, VecU32),
- Pointer(ScalarType, LdStateSpace),
- }
-}
-
// For some weird reson this is illegal:
// .param .f16x2 foobar;
// but this is legal:
// .param .f16x2 foobar[1];
// even more interestingly this is legal, but only in .func (not in .entry):
// .param .b32 foobar[]
-sub_type! {
- VariableParamType {
- Scalar(ScalarType),
- Array(ScalarType, VecU32),
- Pointer(ScalarType, LdStateSpace),
- }
-}
#[derive(Copy, Clone, Eq, PartialEq)]
pub enum BarDetails {
@@ -178,7 +82,7 @@ pub struct Module<'a> {
}
pub enum Directive<'a, P: ArgParams> {
- Variable(Variable<VariableType, P::Id>),
+ Variable(Variable<P::Id>),
Method(Function<'a, &'a str, Statement<P>>),
}
@@ -190,8 +94,8 @@ pub enum MethodDecl<'a, ID> {
},
}
-pub type FnArgument<ID> = Variable<FnArgumentType, ID>;
-pub type KernelArgument<ID> = Variable<KernelArgumentType, ID>;
+pub type FnArgument<ID> = Variable<ID>;
+pub type KernelArgument<ID> = Variable<ID>;
pub struct Function<'a, ID, S> {
pub func_directive: MethodDecl<'a, ID>,
@@ -202,76 +106,6 @@ pub struct Function<'a, ID, S> {
pub type ParsedFunction<'a> = Function<'a, &'a str, Statement<ParsedArgParams<'a>>>;
#[derive(PartialEq, Eq, Clone)]
-pub enum FnArgumentType {
- Reg(VariableRegType),
- Param(VariableParamType),
- Shared,
-}
-#[derive(PartialEq, Eq, Clone)]
-pub enum KernelArgumentType {
- Normal(VariableParamType),
- Shared,
-}
-
-impl From<KernelArgumentType> for Type {
- fn from(this: KernelArgumentType) -> Self {
- match this {
- KernelArgumentType::Normal(typ) => typ.into(),
- KernelArgumentType::Shared => {
- Type::Pointer(PointerType::Scalar(ScalarType::B8), LdStateSpace::Shared)
- }
- }
- }
-}
-
-impl FnArgumentType {
- pub fn to_type(&self, is_kernel: bool) -> Type {
- if is_kernel {
- self.to_kernel_type()
- } else {
- self.to_func_type()
- }
- }
-
- pub fn to_kernel_type(&self) -> Type {
- match self {
- FnArgumentType::Reg(x) => x.clone().into(),
- FnArgumentType::Param(x) => x.clone().into(),
- FnArgumentType::Shared => {
- Type::Pointer(PointerType::Scalar(ScalarType::B8), LdStateSpace::Shared)
- }
- }
- }
-
- pub fn to_func_type(&self) -> Type {
- match self {
- FnArgumentType::Reg(x) => x.clone().into(),
- FnArgumentType::Param(VariableParamType::Scalar(t)) => {
- Type::Pointer(PointerType::Scalar((*t).into()), LdStateSpace::Param)
- }
- FnArgumentType::Param(VariableParamType::Array(t, dims)) => Type::Pointer(
- PointerType::Array((*t).into(), dims.clone()),
- LdStateSpace::Param,
- ),
- FnArgumentType::Param(VariableParamType::Pointer(t, space)) => Type::Pointer(
- PointerType::Pointer((*t).into(), (*space).into()),
- LdStateSpace::Param,
- ),
- FnArgumentType::Shared => {
- Type::Pointer(PointerType::Scalar(ScalarType::B8), LdStateSpace::Shared)
- }
- }
- }
-
- pub fn is_param(&self) -> bool {
- match self {
- FnArgumentType::Param(_) => true,
- _ => false,
- }
- }
-}
-
-#[derive(PartialEq, Eq, Clone)]
pub enum Type {
Scalar(ScalarType),
Vector(ScalarType, u8),
@@ -283,7 +117,7 @@ pub enum Type {
pub enum PointerType {
Scalar(ScalarType),
Vector(ScalarType, u8),
- Array(ScalarType, VecU32),
+ Array(ScalarType, Vec<u32>),
// Instances of this variant are generated during stateful conversion
Pointer(ScalarType, LdStateSpace),
}
@@ -366,51 +200,19 @@ pub enum Statement<P: ArgParams> {
}
pub struct MultiVariable<ID> {
- pub var: Variable<VariableType, ID>,
+ pub var: Variable<ID>,
pub count: Option<u32>,
}
#[derive(Clone)]
-pub struct Variable<T, ID> {
+pub struct Variable<ID> {
pub align: Option<u32>,
- pub v_type: T,
+ pub v_type: Type,
+ pub state_space: StateSpace,
pub name: ID,
pub array_init: Vec<u8>,
}
-#[derive(Eq, PartialEq, Clone)]
-pub enum VariableType {
- Reg(VariableRegType),
- Local(VariableLocalType),
- Param(VariableParamType),
- Global(VariableGlobalType),
- Shared(VariableGlobalType),
-}
-
-impl VariableType {
- pub fn to_type(&self) -> (StateSpace, Type) {
- match self {
- VariableType::Reg(t) => (StateSpace::Reg, t.clone().into()),
- VariableType::Local(t) => (StateSpace::Local, t.clone().into()),
- VariableType::Param(t) => (StateSpace::Param, t.clone().into()),
- VariableType::Global(t) => (StateSpace::Global, t.clone().into()),
- VariableType::Shared(t) => (StateSpace::Shared, t.clone().into()),
- }
- }
-}
-
-impl From<VariableType> for Type {
- fn from(t: VariableType) -> Self {
- match t {
- VariableType::Reg(t) => t.into(),
- VariableType::Local(t) => t.into(),
- VariableType::Param(t) => t.into(),
- VariableType::Global(t) => t.into(),
- VariableType::Shared(t) => t.into(),
- }
- }
-}
-
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum StateSpace {
Reg,
@@ -419,6 +221,7 @@ pub enum StateSpace {
Local,
Shared,
Param,
+ Generic,
}
pub struct PredAt<ID> {
diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop
index 44852a2..dc439b7 100644
--- a/ptx/src/ptx.lalrpop
+++ b/ptx/src/ptx.lalrpop
@@ -404,28 +404,29 @@ FnArguments: Vec<ast::FnArgument<&'input str>> = {
"(" <args:Comma<FnInput>> ")" => args
};
-KernelInput: ast::Variable<ast::KernelArgumentType, &'input str> = {
+KernelInput: ast::Variable<&'input str> = {
<v:ParamDeclaration> => {
let (align, v_type, name) = v;
ast::Variable {
align,
- v_type: ast::KernelArgumentType::Normal(v_type),
+ v_type,
+ state_space: ast::StateSpace::Param,
name,
array_init: Vec::new()
}
}
}
-FnInput: ast::Variable<ast::FnArgumentType, &'input str> = {
+FnInput: ast::Variable<&'input str> = {
<v:RegVariable> => {
let (align, v_type, name) = v;
- let v_type = ast::FnArgumentType::Reg(v_type);
- ast::Variable{ align, v_type, name, array_init: Vec::new() }
+ let state_space = ast::StateSpace::Reg;
+ ast::Variable{ align, v_type, state_space, name, array_init: Vec::new() }
},
<v:ParamDeclaration> => {
let (align, v_type, name) = v;
- let v_type = ast::FnArgumentType::Param(v_type);
- ast::Variable{ align, v_type, name, array_init: Vec::new() }
+ let state_space = ast::StateSpace::Param;
+ ast::Variable{ align, v_type, state_space, name, array_init: Vec::new() }
}
}
@@ -508,102 +509,109 @@ VariableParam: u32 = {
"<" <n:U32Num> ">" => n
}
-Variable: ast::Variable<ast::VariableType, &'input str> = {
+Variable: ast::Variable<&'input str> = {
<v:RegVariable> => {
let (align, v_type, name) = v;
- let v_type = ast::VariableType::Reg(v_type);
- ast::Variable {align, v_type, name, array_init: Vec::new()}
+ let state_space = ast::StateSpace::Reg;
+ ast::Variable {align, v_type, state_space, name, array_init: Vec::new()}
},
LocalVariable,
<v:ParamVariable> => {
let (align, array_init, v_type, name) = v;
- let v_type = ast::VariableType::Param(v_type);
- ast::Variable {align, v_type, name, array_init}
+ let state_space = ast::StateSpace::Param;
+ ast::Variable {align, v_type, state_space, name, array_init}
},
SharedVariable,
};
-RegVariable: (Option<u32>, ast::VariableRegType, &'input str) = {
+RegVariable: (Option<u32>, ast::Type, &'input str) = {
".reg" <var:VariableScalar<ScalarType>> => {
let (align, t, name) = var;
- let v_type = ast::VariableRegType::Scalar(t);
+ let v_type = ast::Type::Scalar(t);
(align, v_type, name)
},
".reg" <var:VariableVector<SizedScalarType>> => {
let (align, v_len, t, name) = var;
- let v_type = ast::VariableRegType::Vector(t, v_len);
+ let v_type = ast::Type::Vector(t, v_len);
(align, v_type, name)
}
}
-LocalVariable: ast::Variable<ast::VariableType, &'input str> = {
+LocalVariable: ast::Variable<&'input str> = {
".local" <var:VariableScalar<SizedScalarType>> => {
let (align, t, name) = var;
- let v_type = ast::VariableType::Local(ast::VariableLocalType::Scalar(t));
- ast::Variable { align, v_type, name, array_init: Vec::new() }
+ let v_type = ast::Type::Scalar(t);
+ let state_space = ast::StateSpace::Local;
+ ast::Variable { align, v_type, state_space, name, array_init: Vec::new() }
},
".local" <var:VariableVector<SizedScalarType>> => {
let (align, v_len, t, name) = var;
- let v_type = ast::VariableType::Local(ast::VariableLocalType::Vector(t, v_len));
- ast::Variable { align, v_type, name, array_init: Vec::new() }
+ let v_type = ast::Type::Vector(t, v_len);
+ let state_space = ast::StateSpace::Local;
+ ast::Variable { align, v_type, state_space, name, array_init: Vec::new() }
},
".local" <var:VariableArrayOrPointer<SizedScalarType>> =>? {
let (align, t, name, arr_or_ptr) = var;
+ let state_space = ast::StateSpace::Local;
let (v_type, array_init) = match arr_or_ptr {
ast::ArrayOrPointer::Array { dimensions, init } => {
- (ast::VariableLocalType::Array(t, dimensions), init)
+ (ast::Type::Array(t, dimensions), init)
}
ast::ArrayOrPointer::Pointer => {
return Err(ParseError::User { error: ast::PtxError::ZeroDimensionArray });
}
};
- Ok(ast::Variable { align, v_type: ast::VariableType::Local(v_type), name, array_init })
+ Ok(ast::Variable { align, v_type, state_space, name, array_init })
}
}
-SharedVariable: ast::Variable<ast::VariableType, &'input str> = {
+SharedVariable: ast::Variable<&'input str> = {
".shared" <var:VariableScalar<SizedScalarType>> => {
let (align, t, name) = var;
- let v_type = ast::VariableGlobalType::Scalar(t);
- ast::Variable { align, v_type: ast::VariableType::Shared(v_type), name, array_init: Vec::new() }
+ let state_space = ast::StateSpace::Shared;
+ let v_type = ast::Type::Scalar(t);
+ ast::Variable { align, v_type, state_space, name, array_init: Vec::new() }
},
".shared" <var:VariableVector<SizedScalarType>> => {
let (align, v_len, t, name) = var;
- let v_type = ast::VariableGlobalType::Vector(t, v_len);
- ast::Variable { align, v_type: ast::VariableType::Shared(v_type), name, array_init: Vec::new() }
+ let state_space = ast::StateSpace::Shared;
+ let v_type = ast::Type::Vector(t, v_len);
+ ast::Variable { align, v_type, state_space, name, array_init: Vec::new() }
},
".shared" <var:VariableArrayOrPointer<SizedScalarType>> =>? {
let (align, t, name, arr_or_ptr) = var;
+ let state_space = ast::StateSpace::Shared;
let (v_type, array_init) = match arr_or_ptr {
ast::ArrayOrPointer::Array { dimensions, init } => {
- (ast::VariableGlobalType::Array(t, dimensions), init)
+ (ast::Type::Array(t, dimensions), init)
}
ast::ArrayOrPointer::Pointer => {
return Err(ParseError::User { error: ast::PtxError::ZeroDimensionArray });
}
};
- Ok(ast::Variable { align, v_type: ast::VariableType::Shared(v_type), name, array_init })
+ Ok(ast::Variable { align, v_type, state_space, name, array_init })
}
}
-
-ModuleVariable: ast::Variable<ast::VariableType, &'input str> = {
+ModuleVariable: ast::Variable<&'input str> = {
LinkingDirectives ".global" <def:GlobalVariableDefinitionNoArray> => {
let (align, v_type, name, array_init) = def;
- ast::Variable { align, v_type: ast::VariableType::Global(v_type), name, array_init }
+ let state_space = ast::StateSpace::Global;
+ ast::Variable { align, v_type, state_space, name, array_init }
},
LinkingDirectives ".shared" <def:GlobalVariableDefinitionNoArray> => {
let (align, v_type, name, array_init) = def;
- ast::Variable { align, v_type: ast::VariableType::Shared(v_type), name, array_init: Vec::new() }
+ let state_space = ast::StateSpace::Shared;
+ ast::Variable { align, v_type, state_space, name, array_init: Vec::new() }
},
<ldirs:LinkingDirectives> <space:Or<".global", ".shared">> <var:VariableArrayOrPointer<SizedScalarType>> =>? {
let (align, t, name, arr_or_ptr) = var;
- let (v_type, array_init) = match arr_or_ptr {
+ let (v_type, state_space, array_init) = match arr_or_ptr {
ast::ArrayOrPointer::Array { dimensions, init } => {
if space == ".global" {
- (ast::VariableType::Global(ast::VariableGlobalType::Array(t, dimensions)), init)
+ (ast::Type::Array(t, dimensions), ast::StateSpace::Global, init)
} else {
- (ast::VariableType::Shared(ast::VariableGlobalType::Array(t, dimensions)), init)
+ (ast::Type::Array(t, dimensions), ast::StateSpace::Shared, init)
}
}
ast::ArrayOrPointer::Pointer => {
@@ -611,38 +619,38 @@ ModuleVariable: ast::Variable<ast::VariableType, &'input str> = {
return Err(ParseError::User { error: ast::PtxError::NonExternPointer });
}
if space == ".global" {
- (ast::VariableType::Global(ast::VariableGlobalType::Pointer(t, ast::LdStateSpace::Global)), Vec::new())
+ (ast::Type::Pointer(ast::PointerType::Scalar(t), ast::LdStateSpace::Global), ast::StateSpace::Global, Vec::new())
} else {
- (ast::VariableType::Shared(ast::VariableGlobalType::Pointer(t, ast::LdStateSpace::Shared)), Vec::new())
+ (ast::Type::Pointer(ast::PointerType::Scalar(t), ast::LdStateSpace::Shared), ast::StateSpace::Shared, Vec::new())
}
}
};
- Ok(ast::Variable{ align, array_init, v_type, name })
+ Ok(ast::Variable{ align, v_type, state_space, name, array_init })
}
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameter-state-space
-ParamVariable: (Option<u32>, Vec<u8>, ast::VariableParamType, &'input str) = {
+ParamVariable: (Option<u32>, Vec<u8>, ast::Type, &'input str) = {
".param" <var:VariableScalar<LdStScalarType>> => {
let (align, t, name) = var;
- let v_type = ast::VariableParamType::Scalar(t);
+ let v_type = ast::Type::Scalar(t);
(align, Vec::new(), v_type, name)
},
".param" <var:VariableArrayOrPointer<SizedScalarType>> => {
let (align, t, name, arr_or_ptr) = var;
let (v_type, array_init) = match arr_or_ptr {
ast::ArrayOrPointer::Array { dimensions, init } => {
- (ast::VariableParamType::Array(t, dimensions), init)
+ (ast::Type::Array(t, dimensions), init)
}
ast::ArrayOrPointer::Pointer => {
- (ast::VariableParamType::Pointer(t, ast::LdStateSpace::Param), Vec::new())
+ (ast::Type::Pointer(ast::PointerType::Scalar(t), ast::LdStateSpace::Param), Vec::new())
}
};
(align, array_init, v_type, name)
}
}
-ParamDeclaration: (Option<u32>, ast::VariableParamType, &'input str) = {
+ParamDeclaration: (Option<u32>, ast::Type, &'input str) = {
<var:ParamVariable> =>? {
let (align, array_init, v_type, name) = var;
if array_init.len() > 0 {
@@ -653,15 +661,15 @@ ParamDeclaration: (Option<u32>, ast::VariableParamType, &'input str) = {
}
}
-GlobalVariableDefinitionNoArray: (Option<u32>, ast::VariableGlobalType, &'input str, Vec<u8>) = {
+GlobalVariableDefinitionNoArray: (Option<u32>, ast::Type, &'input str, Vec<u8>) = {
<scalar:VariableScalar<SizedScalarType>> => {
let (align, t, name) = scalar;
- let v_type = ast::VariableGlobalType::Scalar(t);
+ let v_type = ast::Type::Scalar(t);
(align, v_type, name, Vec::new())
},
<var:VariableVector<SizedScalarType>> => {
let (align, v_len, t, name) = var;
- let v_type = ast::VariableGlobalType::Vector(t, v_len);
+ let v_type = ast::Type::Vector(t, v_len);
(align, v_type, name, Vec::new())
},
}
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index 1f647bd..4ba5729 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -714,12 +714,13 @@ fn convert_dynamic_shared_memory_usage<'input>(
let mut extern_shared_decls = HashMap::new();
for dir in module.iter() {
match dir {
- Directive::Variable(var) => {
- if let ast::VariableType::Shared(ast::VariableGlobalType::Pointer(p_type, _)) =
- var.v_type
- {
- extern_shared_decls.insert(var.name, p_type);
- }
+ Directive::Variable(ast::Variable {
+ v_type: ast::Type::Pointer(p_type, ast::LdStateSpace::Shared),
+ state_space: ast::StateSpace::Shared,
+ name,
+ ..
+ }) => {
+ extern_shared_decls.insert(*name, p_type.clone());
}
_ => {}
}
@@ -796,25 +797,27 @@ fn convert_dynamic_shared_memory_usage<'input>(
let shared_id_param = new_id();
spirv_decl.input.push({
ast::Variable {
+ name: shared_id_param,
align: None,
v_type: ast::Type::Pointer(
- ast::PointerType::Scalar(ast::ScalarType::U8),
+ ast::PointerType::Scalar(ast::ScalarType::B8),
ast::LdStateSpace::Shared,
),
+ state_space: ast::StateSpace::Param,
array_init: Vec::new(),
- name: shared_id_param,
}
});
spirv_decl.uses_shared_mem = true;
let shared_var_id = new_id();
let shared_var = ExpandedStatement::Variable(ast::Variable {
- align: None,
name: shared_var_id,
- array_init: Vec::new(),
- v_type: ast::VariableType::Reg(ast::VariableRegType::Pointer(
- ast::ScalarType::B8,
+ align: None,
+ v_type: ast::Type::Pointer(
+ ast::PointerType::Scalar(ast::ScalarType::B8),
ast::LdStateSpace::Shared,
- )),
+ ),
+ state_space: ast::StateSpace::Reg,
+ array_init: Vec::new(),
});
let shared_var_st = ExpandedStatement::StoreVar(StoreVarDetails {
arg: ast::Arg2St {
@@ -851,7 +854,7 @@ fn convert_dynamic_shared_memory_usage<'input>(
fn replace_uses_of_shared_memory<'a>(
result: &mut Vec<ExpandedStatement>,
new_id: &mut impl FnMut() -> spirv::Word,
- extern_shared_decls: &HashMap<spirv::Word, ast::ScalarType>,
+ extern_shared_decls: &HashMap<spirv::Word, ast::PointerType>,
methods_using_extern_shared: &mut HashSet<MethodName<'a>>,
shared_id_param: spirv::Word,
shared_var_id: spirv::Word,
@@ -864,14 +867,17 @@ fn replace_uses_of_shared_memory<'a>(
// because there's simply no way to pass shared ptr
// without converting it to .b64 first
if methods_using_extern_shared.contains(&MethodName::Func(call.func)) {
- call.param_list
- .push((shared_id_param, ast::FnArgumentType::Shared));
+ call.param_list.push((
+ shared_id_param,
+ ast::Type::Scalar(ast::ScalarType::B8),
+ ast::StateSpace::Shared,
+ ));
}
result.push(Statement::Call(call))
}
statement => {
let new_statement = statement.map_id(&mut |id, _| {
- if let Some(typ) = extern_shared_decls.get(&id) {
+ if let Some(ast::PointerType::Scalar(typ)) = extern_shared_decls.get(&id) {
if *typ == ast::ScalarType::B8 {
return shared_var_id;
}
@@ -1067,7 +1073,7 @@ fn emit_function_header<'a>(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
defined_globals: &GlobalStringIdResolver<'a>,
- synthetic_globals: &[ast::Variable<ast::VariableType, spirv::Word>],
+ synthetic_globals: &[ast::Variable<spirv::Word>],
func_decl: &SpirvMethodDecl<'a>,
_denorm_information: &HashMap<MethodName<'a>, HashMap<u8, (spirv::FPDenormMode, isize)>>,
call_map: &HashMap<&'a str, HashSet<spirv::Word>>,
@@ -1204,9 +1210,9 @@ fn translate_directive<'input>(
fn translate_variable<'a>(
id_defs: &mut GlobalStringIdResolver<'a>,
- var: ast::Variable<ast::VariableType, &'a str>,
-) -> Result<ast::Variable<ast::VariableType, spirv::Word>, TranslateError> {
- let (space, var_type) = var.v_type.to_type();
+ var: ast::Variable<&'a str>,
+) -> Result<ast::Variable<spirv::Word>, TranslateError> {
+ let (space, var_type) = (var.state_space, var.v_type.clone());
let mut is_variable = false;
let var_type = match space {
ast::StateSpace::Reg => {
@@ -1226,10 +1232,12 @@ fn translate_variable<'a>(
}
}
ast::StateSpace::Param => var_type.param_pointer_to(ast::LdStateSpace::Param)?,
+ ast::StateSpace::Generic => todo!(),
};
Ok(ast::Variable {
align: var.align,
v_type: var.v_type,
+ state_space: var.state_space,
name: id_defs.get_or_add_def_typed(var.name, var_type, is_variable),
array_init: var.array_init,
})
@@ -1279,6 +1287,7 @@ fn expand_kernel_params<'a, 'b>(
false,
),
v_type: a.v_type.clone(),
+ state_space: a.state_space,
align: a.align,
array_init: Vec::new(),
})
@@ -1291,14 +1300,11 @@ fn expand_fn_params<'a, 'b>(
args: impl Iterator<Item = &'b ast::FnArgument<&'a str>>,
) -> Result<Vec<ast::FnArgument<spirv::Word>>, TranslateError> {
args.map(|a| {
- let is_variable = match a.v_type {
- ast::FnArgumentType::Reg(_) => true,
- _ => false,
- };
- let var_type = a.v_type.to_func_type();
+ let is_variable = a.state_space == ast::StateSpace::Reg;
Ok(ast::FnArgument {
- name: fn_resolver.add_def(a.name, Some(var_type), is_variable),
+ name: fn_resolver.add_def(a.name, Some(a.v_type.clone()), is_variable),
v_type: a.v_type.clone(),
+ state_space: a.state_space,
align: a.align,
array_init: Vec::new(),
})
@@ -1444,10 +1450,7 @@ fn extract_globals<'input, 'b>(
sorted_statements: Vec<ExpandedStatement>,
ptx_impl_imports: &mut HashMap<String, Directive>,
id_def: &mut NumericIdResolver,
-) -> (
- Vec<ExpandedStatement>,
- Vec<ast::Variable<ast::VariableType, spirv::Word>>,
-) {
+) -> (Vec<ExpandedStatement>, Vec<ast::Variable<spirv::Word>>) {
let mut local = Vec::with_capacity(sorted_statements.len());
let mut global = Vec::new();
for statement in sorted_statements {
@@ -1456,7 +1459,7 @@ fn extract_globals<'input, 'b>(
var
@
ast::Variable {
- v_type: ast::VariableType::Shared(_),
+ state_space: ast::StateSpace::Shared,
..
},
)
@@ -1464,7 +1467,7 @@ fn extract_globals<'input, 'b>(
var
@
ast::Variable {
- v_type: ast::VariableType::Global(_),
+ state_space: ast::StateSpace::Global,
..
},
) => global.push(var),
@@ -1592,10 +1595,10 @@ fn convert_to_typed_statements(
let in_args = to_resolved_fn_args(call.param_list, &*fn_def.params);
let (out_params, out_non_params): (Vec<_>, Vec<_>) = out_args
.into_iter()
- .partition(|(_, arg_type)| arg_type.is_param());
+ .partition(|(_, _, space)| *space == ast::StateSpace::Param);
let normalized_input_args = out_params
.into_iter()
- .map(|(id, typ)| (ast::Operand::Reg(id), typ))
+ .map(|(id, typ, space)| (ast::Operand::Reg(id), typ, space))
.chain(in_args.into_iter())
.collect();
let resolved_call = ResolvedCall {
@@ -1744,7 +1747,8 @@ fn to_ptx_impl_atomic_call(
let func_decl = ast::MethodDecl::Func::<spirv::Word>(
vec![ast::FnArgument {
align: None,
- v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(scalar_typ)),
+ v_type: ast::Type::Scalar(scalar_typ),
+ state_space: ast::StateSpace::Reg,
name: id_defs.new_non_variable(None),
array_init: Vec::new(),
}],
@@ -1752,15 +1756,15 @@ fn to_ptx_impl_atomic_call(
vec![
ast::FnArgument {
align: None,
- v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Pointer(
- typ, ptr_space,
- )),
+ v_type: ast::Type::Pointer(ast::PointerType::Scalar(typ), ptr_space),
+ state_space: ast::StateSpace::Reg,
name: id_defs.new_non_variable(None),
array_init: Vec::new(),
},
ast::FnArgument {
align: None,
- v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(scalar_typ)),
+ v_type: ast::Type::Scalar(scalar_typ),
+ state_space: ast::StateSpace::Reg,
name: id_defs.new_non_variable(None),
array_init: Vec::new(),
},
@@ -1789,18 +1793,17 @@ fn to_ptx_impl_atomic_call(
Statement::Call(ResolvedCall {
uniform: false,
func: fn_id,
- ret_params: vec![(
- arg.dst,
- ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(scalar_typ)),
- )],
+ ret_params: vec![(arg.dst, ast::Type::Scalar(scalar_typ), ast::StateSpace::Reg)],
param_list: vec![
(
arg.src1,
- ast::FnArgumentType::Reg(ast::VariableRegType::Pointer(typ, ptr_space)),
+ ast::Type::Pointer(ast::PointerType::Scalar(typ), ptr_space),
+ ast::StateSpace::Reg,
),
(
arg.src2,
- ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(scalar_typ)),
+ ast::Type::Scalar(scalar_typ),
+ ast::StateSpace::Reg,
),
],
})
@@ -1827,7 +1830,8 @@ fn to_ptx_impl_bfe_call(
let func_decl = ast::MethodDecl::Func::<spirv::Word>(
vec![ast::FnArgument {
align: None,
- v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
+ v_type: ast::Type::Scalar(typ.into()),
+ state_space: ast::StateSpace::Reg,
name: id_defs.new_non_variable(None),
array_init: Vec::new(),
}],
@@ -1835,23 +1839,22 @@ fn to_ptx_impl_bfe_call(
vec![
ast::FnArgument {
align: None,
- v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
+ v_type: ast::Type::Scalar(typ.into()),
+ state_space: ast::StateSpace::Reg,
name: id_defs.new_non_variable(None),
array_init: Vec::new(),
},
ast::FnArgument {
align: None,
- v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(
- ast::ScalarType::U32,
- )),
+ v_type: ast::Type::Scalar(ast::ScalarType::U32),
+ state_space: ast::StateSpace::Reg,
name: id_defs.new_non_variable(None),
array_init: Vec::new(),
},
ast::FnArgument {
align: None,
- v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(
- ast::ScalarType::U32,
- )),
+ v_type: ast::Type::Scalar(ast::ScalarType::U32),
+ state_space: ast::StateSpace::Reg,
name: id_defs.new_non_variable(None),
array_init: Vec::new(),
},
@@ -1880,22 +1883,22 @@ fn to_ptx_impl_bfe_call(
Statement::Call(ResolvedCall {
uniform: false,
func: fn_id,
- ret_params: vec![(
- arg.dst,
- ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
- )],
+ ret_params: vec![(arg.dst, ast::Type::Scalar(typ.into()), ast::StateSpace::Reg)],
param_list: vec![
(
arg.src1,
- ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
+ ast::Type::Scalar(typ.into()),
+ ast::StateSpace::Reg,
),
(
arg.src2,
- ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(ast::ScalarType::U32)),
+ ast::Type::Scalar(ast::ScalarType::U32),
+ ast::StateSpace::Reg,
),
(
arg.src3,
- ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(ast::ScalarType::U32)),
+ ast::Type::Scalar(ast::ScalarType::U32),
+ ast::StateSpace::Reg,
),
],
})
@@ -1920,7 +1923,8 @@ fn to_ptx_impl_bfi_call(
let func_decl = ast::MethodDecl::Func::<spirv::Word>(
vec![ast::FnArgument {
align: None,
- v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
+ v_type: ast::Type::Scalar(typ.into()),
+ state_space: ast::StateSpace::Reg,
name: id_defs.new_non_variable(None),
array_init: Vec::new(),
}],
@@ -1928,29 +1932,29 @@ fn to_ptx_impl_bfi_call(
vec![
ast::FnArgument {
align: None,
- v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
+ v_type: ast::Type::Scalar(typ.into()),
+ state_space: ast::StateSpace::Reg,
name: id_defs.new_non_variable(None),
array_init: Vec::new(),
},
ast::FnArgument {
align: None,
- v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
+ v_type: ast::Type::Scalar(typ.into()),
+ state_space: ast::StateSpace::Reg,
name: id_defs.new_non_variable(None),
array_init: Vec::new(),
},
ast::FnArgument {
align: None,
- v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(
- ast::ScalarType::U32,
- )),
+ v_type: ast::Type::Scalar(ast::ScalarType::U32),
+ state_space: ast::StateSpace::Reg,
name: id_defs.new_non_variable(None),
array_init: Vec::new(),
},
ast::FnArgument {
align: None,
- v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(
- ast::ScalarType::U32,
- )),
+ v_type: ast::Type::Scalar(ast::ScalarType::U32),
+ state_space: ast::StateSpace::Reg,
name: id_defs.new_non_variable(None),
array_init: Vec::new(),
},
@@ -1979,26 +1983,27 @@ fn to_ptx_impl_bfi_call(
Statement::Call(ResolvedCall {
uniform: false,
func: fn_id,
- ret_params: vec![(
- arg.dst,
- ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
- )],
+ ret_params: vec![(arg.dst, ast::Type::Scalar(typ.into()), ast::StateSpace::Reg)],
param_list: vec![
(
arg.src1,
- ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
+ ast::Type::Scalar(typ.into()),
+ ast::StateSpace::Reg,
),
(
arg.src2,
- ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
+ ast::Type::Scalar(typ.into()),
+ ast::StateSpace::Reg,
),
(
arg.src3,
- ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(ast::ScalarType::U32)),
+ ast::Type::Scalar(ast::ScalarType::U32),
+ ast::StateSpace::Reg,
),
(
arg.src4,
- ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(ast::ScalarType::U32)),
+ ast::Type::Scalar(ast::ScalarType::U32),
+ ast::StateSpace::Reg,
),
],
})
@@ -2006,12 +2011,12 @@ fn to_ptx_impl_bfi_call(
fn to_resolved_fn_args<T>(
params: Vec<T>,
- params_decl: &[ast::FnArgumentType],
-) -> Vec<(T, ast::FnArgumentType)> {
+ params_decl: &[(ast::Type, ast::StateSpace)],
+) -> Vec<(T, ast::Type, ast::StateSpace)> {
params
.into_iter()
.zip(params_decl.iter())
- .map(|(id, typ)| (id, typ.clone()))
+ .map(|(id, (typ, space))| (id, typ.clone(), *space))
.collect::<Vec<_>>()
}
@@ -2096,50 +2101,38 @@ fn normalize_predicates(
fn insert_mem_ssa_statements<'a, 'b>(
func: Vec<TypedStatement>,
id_def: &mut NumericIdResolver,
- ast_fn_decl: &'a ast::MethodDecl<'b, spirv::Word>,
+ _: &'a ast::MethodDecl<'b, spirv::Word>,
fn_decl: &mut SpirvMethodDecl,
) -> Result<Vec<TypedStatement>, TranslateError> {
- let is_func = match ast_fn_decl {
- ast::MethodDecl::Func(..) => true,
- ast::MethodDecl::Kernel { .. } => false,
- };
let mut result = Vec::with_capacity(func.len());
for arg in fn_decl.output.iter() {
- match type_to_variable_type(&arg.v_type, is_func)? {
- Some(var_type) => {
- result.push(Statement::Variable(ast::Variable {
- align: arg.align,
- v_type: var_type,
- name: arg.name,
- array_init: arg.array_init.clone(),
- }));
- }
- None => return Err(error_unreachable()),
- }
+ result.push(Statement::Variable(ast::Variable {
+ align: arg.align,
+ v_type: arg.v_type.clone(),
+ state_space: arg.state_space,
+ name: arg.name,
+ array_init: arg.array_init.clone(),
+ }));
}
for spirv_arg in fn_decl.input.iter_mut() {
- match type_to_variable_type(&spirv_arg.v_type, is_func)? {
- Some(var_type) => {
- let typ = spirv_arg.v_type.clone();
- let new_id = id_def.new_non_variable(Some(typ.clone()));
- result.push(Statement::Variable(ast::Variable {
- align: spirv_arg.align,
- v_type: var_type,
- name: spirv_arg.name,
- array_init: spirv_arg.array_init.clone(),
- }));
- result.push(Statement::StoreVar(StoreVarDetails {
- arg: ast::Arg2St {
- src1: spirv_arg.name,
- src2: new_id,
- },
- typ,
- member_index: None,
- }));
- spirv_arg.name = new_id;
- }
- None => {}
- }
+ let typ = spirv_arg.v_type.clone();
+ let new_id = id_def.new_non_variable(Some(typ.clone()));
+ result.push(Statement::Variable(ast::Variable {
+ align: spirv_arg.align,
+ v_type: spirv_arg.v_type.clone(),
+ state_space: spirv_arg.state_space,
+ name: spirv_arg.name,
+ array_init: spirv_arg.array_init.clone(),
+ }));
+ result.push(Statement::StoreVar(StoreVarDetails {
+ arg: ast::Arg2St {
+ src1: spirv_arg.name,
+ src2: new_id,
+ },
+ typ,
+ member_index: None,
+ }));
+ spirv_arg.name = new_id;
}
for s in func {
match s {
@@ -2197,41 +2190,6 @@ fn insert_mem_ssa_statements<'a, 'b>(
Ok(result)
}
-fn type_to_variable_type(
- t: &ast::Type,
- is_func: bool,
-) -> Result<Option<ast::VariableType>, TranslateError> {
- Ok(match t {
- ast::Type::Scalar(typ) => Some(ast::VariableType::Reg(ast::VariableRegType::Scalar(*typ))),
- ast::Type::Vector(typ, len) => Some(ast::VariableType::Reg(ast::VariableRegType::Vector(
- (*typ)
- .try_into()
- .map_err(|_| TranslateError::MismatchedType)?,
- *len,
- ))),
- ast::Type::Array(typ, len) => Some(ast::VariableType::Reg(ast::VariableRegType::Array(
- (*typ)
- .try_into()
- .map_err(|_| TranslateError::MismatchedType)?,
- len.clone(),
- ))),
- ast::Type::Pointer(ast::PointerType::Scalar(scalar_type), space) => {
- if is_func {
- return Ok(None);
- }
- Some(ast::VariableType::Reg(ast::VariableRegType::Pointer(
- scalar_type
- .clone()
- .try_into()
- .map_err(|_| error_unreachable())?,
- (*space).try_into().map_err(|_| error_unreachable())?,
- )))
- }
- ast::Type::Pointer(_, ast::LdStateSpace::Shared) => None,
- _ => return Err(error_unreachable()),
- })
-}
-
trait Visitable<From: ArgParamsEx, To: ArgParamsEx>: Sized {
fn visit(
self,
@@ -2398,11 +2356,13 @@ fn expand_arguments<'a, 'b>(
Statement::Variable(ast::Variable {
align,
v_type,
+ state_space,
name,
array_init,
}) => result.push(Statement::Variable(ast::Variable {
align,
v_type,
+ state_space,
name,
array_init,
})),
@@ -2784,8 +2744,8 @@ fn insert_implicit_conversions_impl(
fn get_function_type(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
- spirv_input: &[ast::Variable<ast::Type, spirv::Word>],
- spirv_output: &[ast::Variable<ast::Type, spirv::Word>],
+ spirv_input: &[ast::Variable<spirv::Word>],
+ spirv_output: &[ast::Variable<spirv::Word>],
) -> (spirv::Word, spirv::Word) {
map.get_or_add_fn(
builder,
@@ -2822,8 +2782,8 @@ fn emit_function_body_ops(
Statement::Label(_) => (),
Statement::Call(call) => {
let (result_type, result_id) = match &*call.ret_params {
- [(id, typ)] => (
- map.get_or_add(builder, SpirvType::from(typ.to_func_type())),
+ [(id, typ, _)] => (
+ map.get_or_add(builder, SpirvType::from(typ.clone())),
Some(*id),
),
[] => (map.void(), None),
@@ -2832,7 +2792,7 @@ fn emit_function_body_ops(
let arg_list = call
.param_list
.iter()
- .map(|(id, _)| *id)
+ .map(|(id, _, _)| *id)
.collect::<Vec<_>>();
builder.function_call(result_type, result_id, call.func, arg_list)?;
}
@@ -3602,14 +3562,16 @@ fn vec_repr<T: Copy>(t: T) -> Vec<u8> {
fn emit_variable(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
- var: &ast::Variable<ast::VariableType, spirv::Word>,
+ var: &ast::Variable<spirv::Word>,
) -> Result<(), TranslateError> {
- let (must_init, st_class) = match var.v_type {
- ast::VariableType::Reg(_) | ast::VariableType::Param(_) | ast::VariableType::Local(_) => {
+ let (must_init, st_class) = match var.state_space {
+ ast::StateSpace::Reg | ast::StateSpace::Param | ast::StateSpace::Local => {
(false, spirv::StorageClass::Function)
}
- ast::VariableType::Global(_) => (true, spirv::StorageClass::CrossWorkgroup),
- ast::VariableType::Shared(_) => (false, spirv::StorageClass::Workgroup),
+ ast::StateSpace::Global => (true, spirv::StorageClass::CrossWorkgroup),
+ ast::StateSpace::Shared => (false, spirv::StorageClass::Workgroup),
+ ast::StateSpace::Const => todo!(),
+ ast::StateSpace::Generic => todo!(),
};
let initalizer = if var.array_init.len() > 0 {
Some(map.get_or_add_constant(
@@ -4460,12 +4422,12 @@ fn expand_map_variables<'a, 'b>(
ast::Statement::Variable(var) => {
let mut var_type = ast::Type::from(var.var.v_type.clone());
let mut is_variable = false;
- var_type = match var.var.v_type {
- ast::VariableType::Reg(_) => {
+ var_type = match var.var.state_space {
+ ast::StateSpace::Reg => {
is_variable = true;
var_type
}
- ast::VariableType::Shared(_) => {
+ ast::StateSpace::Shared => {
// If it's a pointer it will be translated to a method parameter later
if let ast::Type::Pointer(..) = var_type {
is_variable = true;
@@ -4474,15 +4436,11 @@ fn expand_map_variables<'a, 'b>(
var_type.param_pointer_to(ast::LdStateSpace::Shared)?
}
}
- ast::VariableType::Global(_) => {
- var_type.param_pointer_to(ast::LdStateSpace::Global)?
- }
- ast::VariableType::Param(_) => {
- var_type.param_pointer_to(ast::LdStateSpace::Param)?
- }
- ast::VariableType::Local(_) => {
- var_type.param_pointer_to(ast::LdStateSpace::Local)?
- }
+ ast::StateSpace::Global => var_type.param_pointer_to(ast::LdStateSpace::Global)?,
+ ast::StateSpace::Param => var_type.param_pointer_to(ast::LdStateSpace::Param)?,
+ ast::StateSpace::Local => var_type.param_pointer_to(ast::LdStateSpace::Local)?,
+ ast::StateSpace::Const => todo!(),
+ ast::StateSpace::Generic => todo!(),
};
match var.count {
Some(count) => {
@@ -4490,6 +4448,7 @@ fn expand_map_variables<'a, 'b>(
result.push(Statement::Variable(ast::Variable {
align: var.var.align,
v_type: var.var.v_type.clone(),
+ state_space: var.var.state_space,
name: new_id,
array_init: var.var.array_init.clone(),
}))
@@ -4500,6 +4459,7 @@ fn expand_map_variables<'a, 'b>(
result.push(Statement::Variable(ast::Variable {
align: var.var.align,
v_type: var.var.v_type.clone(),
+ state_space: var.var.state_space,
name: new_id,
array_init: var.var.array_init,
}));
@@ -4659,10 +4619,11 @@ fn convert_to_stateful_memory_access<'a>(
align: None,
name: new_id,
array_init: Vec::new(),
- v_type: ast::VariableType::Reg(ast::VariableRegType::Pointer(
- ast::ScalarType::U8,
+ v_type: ast::Type::Pointer(
+ ast::PointerType::Scalar(ast::ScalarType::U8),
ast::LdStateSpace::Global,
- )),
+ ),
+ state_space: ast::StateSpace::Reg,
}));
remapped_ids.insert(reg, new_id);
}
@@ -5052,8 +5013,8 @@ struct GlobalStringIdResolver<'input> {
}
pub struct FnDecl {
- ret_vals: Vec<ast::FnArgumentType>,
- params: Vec<ast::FnArgumentType>,
+ ret_vals: Vec<(ast::Type, ast::StateSpace)>,
+ params: Vec<(ast::Type, ast::StateSpace)>,
}
impl<'a> GlobalStringIdResolver<'a> {
@@ -5137,8 +5098,14 @@ impl<'a> GlobalStringIdResolver<'a> {
self.fns.insert(
name_id,
FnDecl {
- ret_vals: ret_params_ids.iter().map(|p| p.v_type.clone()).collect(),
- params: params_ids.iter().map(|p| p.v_type.clone()).collect(),
+ ret_vals: ret_params_ids
+ .iter()
+ .map(|p| (p.v_type.clone(), p.state_space))
+ .collect(),
+ params: params_ids
+ .iter()
+ .map(|p| (p.v_type.clone(), p.state_space))
+ .collect(),
},
);
ast::MethodDecl::Func(ret_params_ids, name_id, params_ids)
@@ -5314,7 +5281,7 @@ impl<'b> MutableNumericIdResolver<'b> {
enum Statement<I, P: ast::ArgParams> {
Label(u32),
- Variable(ast::Variable<ast::VariableType, P::Id>),
+ Variable(ast::Variable<P::Id>),
Instruction(I),
// SPIR-V compatible replacement for PTX predicates
Conditional(BrachCondition),
@@ -5352,16 +5319,17 @@ impl ExpandedStatement {
Statement::StoreVar(details)
}
Statement::Call(mut call) => {
- for (id, typ) in call.ret_params.iter_mut() {
- let is_dst = match typ {
- ast::FnArgumentType::Reg(_) => true,
- ast::FnArgumentType::Param(_) => false,
- ast::FnArgumentType::Shared => false,
+ for (id, _, space) in call.ret_params.iter_mut() {
+ let is_dst = match space {
+ ast::StateSpace::Reg => true,
+ ast::StateSpace::Param => false,
+ ast::StateSpace::Shared => false,
+ _ => todo!(),
};
*id = f(*id, is_dst);
}
call.func = f(call.func, false);
- for (id, _) in call.param_list.iter_mut() {
+ for (id, _, _) in call.param_list.iter_mut() {
*id = f(*id, false);
}
Statement::Call(call)
@@ -5502,9 +5470,9 @@ impl<T: ArgParamsEx<Id = spirv::Word>, U: ArgParamsEx<Id = spirv::Word>> Visitab
struct ResolvedCall<P: ast::ArgParams> {
pub uniform: bool,
- pub ret_params: Vec<(P::Id, ast::FnArgumentType)>,
+ pub ret_params: Vec<(P::Id, ast::Type, ast::StateSpace)>,
pub func: P::Id,
- pub param_list: Vec<(P::Operand, ast::FnArgumentType)>,
+ pub param_list: Vec<(P::Operand, ast::Type, ast::StateSpace)>,
}
impl<T: ast::ArgParams> ResolvedCall<T> {
@@ -5526,16 +5494,16 @@ impl<From: ArgParamsEx<Id = spirv::Word>> ResolvedCall<From> {
let ret_params = self
.ret_params
.into_iter()
- .map::<Result<_, TranslateError>, _>(|(id, typ)| {
+ .map::<Result<_, TranslateError>, _>(|(id, typ, space)| {
let new_id = visitor.id(
ArgumentDescriptor {
op: id,
- is_dst: !typ.is_param(),
- sema: typ.semantics(),
+ is_dst: space != ast::StateSpace::Param,
+ sema: space.semantics(),
},
- Some(&typ.to_func_type()),
+ Some(&typ),
)?;
- Ok((new_id, typ))
+ Ok((new_id, typ, space))
})
.collect::<Result<Vec<_>, _>>()?;
let func = visitor.id(
@@ -5549,16 +5517,16 @@ impl<From: ArgParamsEx<Id = spirv::Word>> ResolvedCall<From> {
let param_list = self
.param_list
.into_iter()
- .map::<Result<_, TranslateError>, _>(|(id, typ)| {
+ .map::<Result<_, TranslateError>, _>(|(id, typ, space)| {
let new_id = visitor.operand(
ArgumentDescriptor {
op: id,
is_dst: false,
- sema: typ.semantics(),
+ sema: space.semantics(),
},
- &typ.to_func_type(),
+ &typ,
)?;
- Ok((new_id, typ))
+ Ok((new_id, typ, space))
})
.collect::<Result<Vec<_>, _>>()?;
Ok(ResolvedCall {
@@ -5738,14 +5706,14 @@ impl ArgParamsEx for ExpandedArgParams {
}
enum Directive<'input> {
- Variable(ast::Variable<ast::VariableType, spirv::Word>),
+ Variable(ast::Variable<spirv::Word>),
Method(Function<'input>),
}
struct Function<'input> {
pub func_decl: ast::MethodDecl<'input, spirv::Word>,
pub spirv_decl: SpirvMethodDecl<'input>,
- pub globals: Vec<ast::Variable<ast::VariableType, spirv::Word>>,
+ pub globals: Vec<ast::Variable<spirv::Word>>,
pub body: Option<Vec<ExpandedStatement>>,
import_as: Option<String>,
tuning: Vec<ast::TuningDirective>,
@@ -7300,16 +7268,6 @@ impl ast::LdStateSpace {
}
}
-impl From<ast::FnArgumentType> for ast::VariableType {
- fn from(t: ast::FnArgumentType) -> Self {
- match t {
- ast::FnArgumentType::Reg(t) => ast::VariableType::Reg(t),
- ast::FnArgumentType::Param(t) => ast::VariableType::Param(t),
- ast::FnArgumentType::Shared => todo!(),
- }
- }
-}
-
impl<T> ast::Operand<T> {
fn underlying(&self) -> Option<&T> {
match self {
@@ -7362,12 +7320,13 @@ impl ast::AtomSemantics {
}
}
-impl ast::FnArgumentType {
- fn semantics(&self) -> ArgumentSemantics {
+impl ast::StateSpace {
+ fn semantics(self) -> ArgumentSemantics {
match self {
- ast::FnArgumentType::Reg(_) => ArgumentSemantics::Default,
- ast::FnArgumentType::Param(_) => ArgumentSemantics::RegisterPointer,
- ast::FnArgumentType::Shared => ArgumentSemantics::PhysicalPointer,
+ ast::StateSpace::Reg => ArgumentSemantics::Default,
+ ast::StateSpace::Param => ArgumentSemantics::RegisterPointer,
+ ast::StateSpace::Shared => ArgumentSemantics::PhysicalPointer,
+ _ => todo!(),
}
}
}
@@ -7677,8 +7636,8 @@ impl<'a> ast::MethodDecl<'a, &'a str> {
}
struct SpirvMethodDecl<'input> {
- input: Vec<ast::Variable<ast::Type, spirv::Word>>,
- output: Vec<ast::Variable<ast::Type, spirv::Word>>,
+ input: Vec<ast::Variable<spirv::Word>>,
+ output: Vec<ast::Variable<spirv::Word>>,
name: MethodName<'input>,
uses_shared_mem: bool,
}
@@ -7689,33 +7648,28 @@ impl<'input> SpirvMethodDecl<'input> {
ast::MethodDecl::Kernel { in_args, .. } => {
let spirv_input = in_args
.iter()
- .map(|var| {
- let v_type = match &var.v_type {
- ast::KernelArgumentType::Normal(t) => {
- ast::FnArgumentType::Param(t.clone())
- }
- ast::KernelArgumentType::Shared => ast::FnArgumentType::Shared,
- };
- ast::Variable {
- name: var.name,
- align: var.align,
- v_type: v_type.to_kernel_type(),
- array_init: var.array_init.clone(),
- }
+ .map(|var| ast::Variable {
+ name: var.name,
+ align: var.align,
+ v_type: var.v_type.clone(),
+ state_space: var.state_space,
+ array_init: var.array_init.clone(),
})
.collect();
(spirv_input, Vec::new())
}
ast::MethodDecl::Func(out_args, _, in_args) => {
- let (param_output, non_param_output): (Vec<_>, Vec<_>) =
- out_args.iter().partition(|var| var.v_type.is_param());
+ let (param_output, non_param_output): (Vec<_>, Vec<_>) = out_args
+ .iter()
+ .partition(|var| var.state_space == ast::StateSpace::Param);
let spirv_output = non_param_output
.into_iter()
.cloned()
.map(|var| ast::Variable {
name: var.name,
align: var.align,
- v_type: var.v_type.to_func_type(),
+ v_type: var.v_type.clone(),
+ state_space: var.state_space,
array_init: var.array_init.clone(),
})
.collect();
@@ -7726,7 +7680,8 @@ impl<'input> SpirvMethodDecl<'input> {
.map(|var| ast::Variable {
name: var.name,
align: var.align,
- v_type: var.v_type.to_func_type(),
+ v_type: var.v_type.clone(),
+ state_space: var.state_space,
array_init: var.array_init.clone(),
})
.collect();