diff options
Diffstat (limited to 'ptx/src/ast.rs')
-rw-r--r-- | ptx/src/ast.rs | 127 |
1 files changed, 93 insertions, 34 deletions
diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index 1e90eba..1cbe721 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -28,8 +28,11 @@ quick_error! { } } -macro_rules! sub_scalar_type { +macro_rules! sub_enum { ($name:ident { $($variant:ident),+ $(,)? }) => { + sub_enum!{ $name : ScalarType { $($variant),+ } } + }; + ($name:ident : $base_type:ident { $($variant:ident),+ $(,)? }) => { #[derive(PartialEq, Eq, Clone, Copy)] pub enum $name { $( @@ -37,23 +40,23 @@ macro_rules! sub_scalar_type { )+ } - impl From<$name> for ScalarType { - fn from(t: $name) -> ScalarType { + impl From<$name> for $base_type { + fn from(t: $name) -> $base_type { match t { $( - $name::$variant => ScalarType::$variant, + $name::$variant => $base_type::$variant, )+ } } } - impl std::convert::TryFrom<ScalarType> for $name { + impl std::convert::TryFrom<$base_type> for $name { type Error = (); - fn try_from(t: ScalarType) -> Result<Self, Self::Error> { + fn try_from(t: $base_type) -> Result<Self, Self::Error> { match t { $( - ScalarType::$variant => Ok($name::$variant), + $base_type::$variant => Ok($name::$variant), )+ _ => Err(()), } @@ -64,6 +67,13 @@ macro_rules! sub_scalar_type { macro_rules! sub_type { ($type_name:ident { $($variant:ident ( $($field_type:ident),+ ) ),+ $(,)? } ) => { + sub_type! { $type_name : Type { + $( + $variant ($($field_type),+), + )+ + }} + }; + ($type_name:ident : $base_type:ident { $($variant:ident ( $($field_type:ident),+ ) ),+ $(,)? } ) => { #[derive(PartialEq, Eq, Clone)] pub enum $type_name { $( @@ -71,26 +81,26 @@ macro_rules! sub_type { )+ } - impl From<$type_name> for Type { + impl From<$type_name> for $base_type { #[allow(non_snake_case)] - fn from(t: $type_name) -> Type { + fn from(t: $type_name) -> $base_type { match t { $( - $type_name::$variant ( $($field_type),+ ) => Type::$variant ( $($field_type.into()),+), + $type_name::$variant ( $($field_type),+ ) => <$base_type>::$variant ( $($field_type.into()),+), )+ } } } - impl std::convert::TryFrom<Type> for $type_name { + impl std::convert::TryFrom<$base_type> for $type_name { type Error = (); #[allow(non_snake_case)] #[allow(unreachable_patterns)] - fn try_from(t: Type) -> Result<Self, Self::Error> { + fn try_from(t: $base_type) -> Result<Self, Self::Error> { match t { $( - Type::$variant ( $($field_type),+ ) => Ok($type_name::$variant ( $($field_type.try_into().map_err(|_| ())? ),+ )), + $base_type::$variant ( $($field_type),+ ) => Ok($type_name::$variant ( $($field_type.try_into().map_err(|_| ())? ),+ )), )+ _ => Err(()), } @@ -99,10 +109,12 @@ macro_rules! sub_type { }; } +// Pointer is used when doing SLM converison to SPIRV sub_type! { VariableRegType { Scalar(ScalarType), Vector(SizedScalarType, u8), + Pointer(SizedScalarType, PointerStateSpace) } } @@ -146,13 +158,13 @@ sub_type! { // .param .b32 foobar[] sub_type! { VariableParamType { - Scalar(ParamScalarType), + Scalar(LdStScalarType), Array(SizedScalarType, VecU32), Pointer(SizedScalarType, PointerStateSpace), } } -sub_scalar_type!(SizedScalarType { +sub_enum!(SizedScalarType { B8, B16, B32, @@ -171,7 +183,7 @@ sub_scalar_type!(SizedScalarType { F64, }); -sub_scalar_type!(ParamScalarType { +sub_enum!(LdStScalarType { B8, B16, B32, @@ -232,7 +244,11 @@ pub enum Directive<'a, P: ArgParams> { pub enum MethodDecl<'a, ID> { Func(Vec<FnArgument<ID>>, ID, Vec<FnArgument<ID>>), - Kernel(&'a str, Vec<KernelArgument<ID>>), + Kernel { + name: &'a str, + in_args: Vec<KernelArgument<ID>>, + uses_shared_mem: bool, + }, } pub type FnArgument<ID> = Variable<FnArgumentType, ID>; @@ -262,25 +278,52 @@ impl From<FnArgumentType> for Type { match t { FnArgumentType::Reg(x) => x.into(), FnArgumentType::Param(x) => x.into(), - FnArgumentType::Shared => Type::Scalar(ScalarType::B64), + FnArgumentType::Shared => { + Type::Pointer(PointerType::Scalar(ScalarType::B8), LdStateSpace::Shared) + } } } } -#[derive(PartialEq, Eq, Clone, Copy)] -pub enum PointerStateSpace { - Global, - Const, - Shared, - Param, -} +sub_enum!( + PointerStateSpace : LdStateSpace { + Global, + Const, + Shared, + Param, + } +); #[derive(PartialEq, Eq, Clone)] pub enum Type { Scalar(ScalarType), Vector(ScalarType, u8), Array(ScalarType, Vec<u32>), - Pointer(ScalarType, PointerStateSpace), + Pointer(PointerType, LdStateSpace), +} + +sub_type! { + PointerType { + Scalar(ScalarType), + Vector(ScalarType, u8), + } +} + +impl From<SizedScalarType> for PointerType { + fn from(t: SizedScalarType) -> Self { + PointerType::Scalar(t.into()) + } +} + +impl TryFrom<PointerType> for SizedScalarType { + type Error = (); + + fn try_from(value: PointerType) -> Result<Self, Self::Error> { + match value { + PointerType::Scalar(t) => Ok(t.try_into()?), + PointerType::Vector(_, _) => Err(()), + } + } } #[derive(PartialEq, Eq, Hash, Clone, Copy)] @@ -304,7 +347,7 @@ pub enum ScalarType { Pred, } -sub_scalar_type!(IntType { +sub_enum!(IntType { U8, U16, U32, @@ -315,9 +358,9 @@ sub_scalar_type!(IntType { S64 }); -sub_scalar_type!(UIntType { U8, U16, U32, U64 }); +sub_enum!(UIntType { U8, U16, U32, U64 }); -sub_scalar_type!(SIntType { S8, S16, S32, S64 }); +sub_enum!(SIntType { S8, S16, S32, S64 }); impl IntType { pub fn is_signed(self) -> bool { @@ -341,7 +384,7 @@ impl IntType { } } -sub_scalar_type!(FloatType { +sub_enum!(FloatType { F16, F16x2, F32, @@ -615,7 +658,23 @@ pub struct LdDetails { pub qualifier: LdStQualifier, pub state_space: LdStateSpace, pub caching: LdCacheOperator, - pub typ: Type, + pub typ: LdStType, +} + +sub_type! { + LdStType { + Scalar(LdStScalarType), + Vector(LdStScalarType, u8), + } +} + +impl From<LdStType> for PointerType { + fn from(t: LdStType) -> Self { + match t { + LdStType::Scalar(t) => PointerType::Scalar(t.into()), + LdStType::Vector(t, len) => PointerType::Vector(t.into(), len), + } + } } #[derive(Copy, Clone, PartialEq, Eq)] @@ -860,7 +919,7 @@ pub enum ShlType { B64, } -sub_scalar_type!(ShrType { +sub_enum!(ShrType { B16, B32, B64, @@ -876,7 +935,7 @@ pub struct StData { pub qualifier: LdStQualifier, pub state_space: StStateSpace, pub caching: StCacheOperator, - pub typ: Type, + pub typ: LdStType, } #[derive(PartialEq, Eq, Copy, Clone)] @@ -900,7 +959,7 @@ pub struct RetData { pub uniform: bool, } -sub_scalar_type!(OrType { +sub_enum!(OrType { Pred, B16, B32, |