summaryrefslogtreecommitdiffhomepage
path: root/ptx/src/ast.rs
diff options
context:
space:
mode:
Diffstat (limited to 'ptx/src/ast.rs')
-rw-r--r--ptx/src/ast.rs127
1 files changed, 93 insertions, 34 deletions
diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs
index 1e90eba..1cbe721 100644
--- a/ptx/src/ast.rs
+++ b/ptx/src/ast.rs
@@ -28,8 +28,11 @@ quick_error! {
}
}
-macro_rules! sub_scalar_type {
+macro_rules! sub_enum {
($name:ident { $($variant:ident),+ $(,)? }) => {
+ sub_enum!{ $name : ScalarType { $($variant),+ } }
+ };
+ ($name:ident : $base_type:ident { $($variant:ident),+ $(,)? }) => {
#[derive(PartialEq, Eq, Clone, Copy)]
pub enum $name {
$(
@@ -37,23 +40,23 @@ macro_rules! sub_scalar_type {
)+
}
- impl From<$name> for ScalarType {
- fn from(t: $name) -> ScalarType {
+ impl From<$name> for $base_type {
+ fn from(t: $name) -> $base_type {
match t {
$(
- $name::$variant => ScalarType::$variant,
+ $name::$variant => $base_type::$variant,
)+
}
}
}
- impl std::convert::TryFrom<ScalarType> for $name {
+ impl std::convert::TryFrom<$base_type> for $name {
type Error = ();
- fn try_from(t: ScalarType) -> Result<Self, Self::Error> {
+ fn try_from(t: $base_type) -> Result<Self, Self::Error> {
match t {
$(
- ScalarType::$variant => Ok($name::$variant),
+ $base_type::$variant => Ok($name::$variant),
)+
_ => Err(()),
}
@@ -64,6 +67,13 @@ macro_rules! sub_scalar_type {
macro_rules! sub_type {
($type_name:ident { $($variant:ident ( $($field_type:ident),+ ) ),+ $(,)? } ) => {
+ sub_type! { $type_name : Type {
+ $(
+ $variant ($($field_type),+),
+ )+
+ }}
+ };
+ ($type_name:ident : $base_type:ident { $($variant:ident ( $($field_type:ident),+ ) ),+ $(,)? } ) => {
#[derive(PartialEq, Eq, Clone)]
pub enum $type_name {
$(
@@ -71,26 +81,26 @@ macro_rules! sub_type {
)+
}
- impl From<$type_name> for Type {
+ impl From<$type_name> for $base_type {
#[allow(non_snake_case)]
- fn from(t: $type_name) -> Type {
+ fn from(t: $type_name) -> $base_type {
match t {
$(
- $type_name::$variant ( $($field_type),+ ) => Type::$variant ( $($field_type.into()),+),
+ $type_name::$variant ( $($field_type),+ ) => <$base_type>::$variant ( $($field_type.into()),+),
)+
}
}
}
- impl std::convert::TryFrom<Type> for $type_name {
+ impl std::convert::TryFrom<$base_type> for $type_name {
type Error = ();
#[allow(non_snake_case)]
#[allow(unreachable_patterns)]
- fn try_from(t: Type) -> Result<Self, Self::Error> {
+ fn try_from(t: $base_type) -> Result<Self, Self::Error> {
match t {
$(
- Type::$variant ( $($field_type),+ ) => Ok($type_name::$variant ( $($field_type.try_into().map_err(|_| ())? ),+ )),
+ $base_type::$variant ( $($field_type),+ ) => Ok($type_name::$variant ( $($field_type.try_into().map_err(|_| ())? ),+ )),
)+
_ => Err(()),
}
@@ -99,10 +109,12 @@ macro_rules! sub_type {
};
}
+// Pointer is used when doing SLM converison to SPIRV
sub_type! {
VariableRegType {
Scalar(ScalarType),
Vector(SizedScalarType, u8),
+ Pointer(SizedScalarType, PointerStateSpace)
}
}
@@ -146,13 +158,13 @@ sub_type! {
// .param .b32 foobar[]
sub_type! {
VariableParamType {
- Scalar(ParamScalarType),
+ Scalar(LdStScalarType),
Array(SizedScalarType, VecU32),
Pointer(SizedScalarType, PointerStateSpace),
}
}
-sub_scalar_type!(SizedScalarType {
+sub_enum!(SizedScalarType {
B8,
B16,
B32,
@@ -171,7 +183,7 @@ sub_scalar_type!(SizedScalarType {
F64,
});
-sub_scalar_type!(ParamScalarType {
+sub_enum!(LdStScalarType {
B8,
B16,
B32,
@@ -232,7 +244,11 @@ pub enum Directive<'a, P: ArgParams> {
pub enum MethodDecl<'a, ID> {
Func(Vec<FnArgument<ID>>, ID, Vec<FnArgument<ID>>),
- Kernel(&'a str, Vec<KernelArgument<ID>>),
+ Kernel {
+ name: &'a str,
+ in_args: Vec<KernelArgument<ID>>,
+ uses_shared_mem: bool,
+ },
}
pub type FnArgument<ID> = Variable<FnArgumentType, ID>;
@@ -262,25 +278,52 @@ impl From<FnArgumentType> for Type {
match t {
FnArgumentType::Reg(x) => x.into(),
FnArgumentType::Param(x) => x.into(),
- FnArgumentType::Shared => Type::Scalar(ScalarType::B64),
+ FnArgumentType::Shared => {
+ Type::Pointer(PointerType::Scalar(ScalarType::B8), LdStateSpace::Shared)
+ }
}
}
}
-#[derive(PartialEq, Eq, Clone, Copy)]
-pub enum PointerStateSpace {
- Global,
- Const,
- Shared,
- Param,
-}
+sub_enum!(
+ PointerStateSpace : LdStateSpace {
+ Global,
+ Const,
+ Shared,
+ Param,
+ }
+);
#[derive(PartialEq, Eq, Clone)]
pub enum Type {
Scalar(ScalarType),
Vector(ScalarType, u8),
Array(ScalarType, Vec<u32>),
- Pointer(ScalarType, PointerStateSpace),
+ Pointer(PointerType, LdStateSpace),
+}
+
+sub_type! {
+ PointerType {
+ Scalar(ScalarType),
+ Vector(ScalarType, u8),
+ }
+}
+
+impl From<SizedScalarType> for PointerType {
+ fn from(t: SizedScalarType) -> Self {
+ PointerType::Scalar(t.into())
+ }
+}
+
+impl TryFrom<PointerType> for SizedScalarType {
+ type Error = ();
+
+ fn try_from(value: PointerType) -> Result<Self, Self::Error> {
+ match value {
+ PointerType::Scalar(t) => Ok(t.try_into()?),
+ PointerType::Vector(_, _) => Err(()),
+ }
+ }
}
#[derive(PartialEq, Eq, Hash, Clone, Copy)]
@@ -304,7 +347,7 @@ pub enum ScalarType {
Pred,
}
-sub_scalar_type!(IntType {
+sub_enum!(IntType {
U8,
U16,
U32,
@@ -315,9 +358,9 @@ sub_scalar_type!(IntType {
S64
});
-sub_scalar_type!(UIntType { U8, U16, U32, U64 });
+sub_enum!(UIntType { U8, U16, U32, U64 });
-sub_scalar_type!(SIntType { S8, S16, S32, S64 });
+sub_enum!(SIntType { S8, S16, S32, S64 });
impl IntType {
pub fn is_signed(self) -> bool {
@@ -341,7 +384,7 @@ impl IntType {
}
}
-sub_scalar_type!(FloatType {
+sub_enum!(FloatType {
F16,
F16x2,
F32,
@@ -615,7 +658,23 @@ pub struct LdDetails {
pub qualifier: LdStQualifier,
pub state_space: LdStateSpace,
pub caching: LdCacheOperator,
- pub typ: Type,
+ pub typ: LdStType,
+}
+
+sub_type! {
+ LdStType {
+ Scalar(LdStScalarType),
+ Vector(LdStScalarType, u8),
+ }
+}
+
+impl From<LdStType> for PointerType {
+ fn from(t: LdStType) -> Self {
+ match t {
+ LdStType::Scalar(t) => PointerType::Scalar(t.into()),
+ LdStType::Vector(t, len) => PointerType::Vector(t.into(), len),
+ }
+ }
}
#[derive(Copy, Clone, PartialEq, Eq)]
@@ -860,7 +919,7 @@ pub enum ShlType {
B64,
}
-sub_scalar_type!(ShrType {
+sub_enum!(ShrType {
B16,
B32,
B64,
@@ -876,7 +935,7 @@ pub struct StData {
pub qualifier: LdStQualifier,
pub state_space: StStateSpace,
pub caching: StCacheOperator,
- pub typ: Type,
+ pub typ: LdStType,
}
#[derive(PartialEq, Eq, Copy, Clone)]
@@ -900,7 +959,7 @@ pub struct RetData {
pub uniform: bool,
}
-sub_scalar_type!(OrType {
+sub_enum!(OrType {
Pred,
B16,
B32,