summaryrefslogtreecommitdiffhomepage
path: root/ptx/src/ast.rs
diff options
context:
space:
mode:
Diffstat (limited to 'ptx/src/ast.rs')
-rw-r--r--ptx/src/ast.rs84
1 files changed, 80 insertions, 4 deletions
diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs
index c6510da..1e90eba 100644
--- a/ptx/src/ast.rs
+++ b/ptx/src/ast.rs
@@ -1,4 +1,5 @@
-use std::{convert::From, mem, num::ParseFloatError, str::FromStr};
+use std::convert::TryInto;
+use std::{convert::From, convert::TryFrom, mem, num::ParseFloatError, str::FromStr};
use std::{marker::PhantomData, num::ParseIntError};
use half::f16;
@@ -22,6 +23,8 @@ quick_error! {
WrongVectorElement {}
MultiArrayVariable {}
ZeroDimensionArray {}
+ ArrayInitalizer {}
+ NonExternPointer {}
}
}
@@ -78,6 +81,21 @@ macro_rules! sub_type {
}
}
}
+
+ impl std::convert::TryFrom<Type> for $type_name {
+ type Error = ();
+
+ #[allow(non_snake_case)]
+ #[allow(unreachable_patterns)]
+ fn try_from(t: Type) -> Result<Self, Self::Error> {
+ match t {
+ $(
+ Type::$variant ( $($field_type),+ ) => Ok($type_name::$variant ( $($field_type.try_into().map_err(|_| ())? ),+ )),
+ )+
+ _ => Err(()),
+ }
+ }
+ }
};
}
@@ -98,14 +116,39 @@ sub_type! {
}
}
+impl TryFrom<VariableGlobalType> for VariableLocalType {
+ type Error = PtxError;
+
+ fn try_from(value: VariableGlobalType) -> Result<Self, Self::Error> {
+ match value {
+ VariableGlobalType::Scalar(t) => Ok(VariableLocalType::Scalar(t)),
+ VariableGlobalType::Vector(t, len) => Ok(VariableLocalType::Vector(t, len)),
+ VariableGlobalType::Array(t, len) => Ok(VariableLocalType::Array(t, len)),
+ VariableGlobalType::Pointer(_, _) => Err(PtxError::ZeroDimensionArray),
+ }
+ }
+}
+
+sub_type! {
+ VariableGlobalType {
+ Scalar(SizedScalarType),
+ Vector(SizedScalarType, u8),
+ Array(SizedScalarType, VecU32),
+ Pointer(SizedScalarType, PointerStateSpace),
+ }
+}
+
// For some weird reson this is illegal:
// .param .f16x2 foobar;
// but this is legal:
// .param .f16x2 foobar[1];
+// even more interestingly this is legal, but only in .func (not in .entry):
+// .param .b32 foobar[]
sub_type! {
VariableParamType {
Scalar(ParamScalarType),
Array(SizedScalarType, VecU32),
+ Pointer(SizedScalarType, PointerStateSpace),
}
}
@@ -193,7 +236,7 @@ pub enum MethodDecl<'a, ID> {
}
pub type FnArgument<ID> = Variable<FnArgumentType, ID>;
-pub type KernelArgument<ID> = Variable<VariableParamType, ID>;
+pub type KernelArgument<ID> = Variable<KernelArgumentType, ID>;
pub struct Function<'a, ID, S> {
pub func_directive: MethodDecl<'a, ID>,
@@ -206,6 +249,12 @@ pub type ParsedFunction<'a> = Function<'a, &'a str, Statement<ParsedArgParams<'a
pub enum FnArgumentType {
Reg(VariableRegType),
Param(VariableParamType),
+ Shared,
+}
+#[derive(PartialEq, Eq, Clone)]
+pub enum KernelArgumentType {
+ Normal(VariableParamType),
+ Shared,
}
impl From<FnArgumentType> for Type {
@@ -213,15 +262,25 @@ impl From<FnArgumentType> for Type {
match t {
FnArgumentType::Reg(x) => x.into(),
FnArgumentType::Param(x) => x.into(),
+ FnArgumentType::Shared => Type::Scalar(ScalarType::B64),
}
}
}
-#[derive(PartialEq, Eq, Hash, Clone)]
+#[derive(PartialEq, Eq, Clone, Copy)]
+pub enum PointerStateSpace {
+ Global,
+ Const,
+ Shared,
+ Param,
+}
+
+#[derive(PartialEq, Eq, Clone)]
pub enum Type {
Scalar(ScalarType),
Vector(ScalarType, u8),
Array(ScalarType, Vec<u32>),
+ Pointer(ScalarType, PointerStateSpace),
}
#[derive(PartialEq, Eq, Hash, Clone, Copy)]
@@ -343,7 +402,8 @@ pub enum VariableType {
Reg(VariableRegType),
Local(VariableLocalType),
Param(VariableParamType),
- Global(VariableLocalType),
+ Global(VariableGlobalType),
+ Shared(VariableGlobalType),
}
impl VariableType {
@@ -353,6 +413,7 @@ impl VariableType {
VariableType::Local(t) => (StateSpace::Local, t.clone().into()),
VariableType::Param(t) => (StateSpace::Param, t.clone().into()),
VariableType::Global(t) => (StateSpace::Global, t.clone().into()),
+ VariableType::Shared(t) => (StateSpace::Shared, t.clone().into()),
}
}
}
@@ -364,6 +425,7 @@ impl From<VariableType> for Type {
VariableType::Local(t) => t.into(),
VariableType::Param(t) => t.into(),
VariableType::Global(t) => t.into(),
+ VariableType::Shared(t) => t.into(),
}
}
}
@@ -1039,6 +1101,20 @@ impl<'a> NumsOrArrays<'a> {
}
}
+pub enum ArrayOrPointer {
+ Array { dimensions: Vec<u32>, init: Vec<u8> },
+ Pointer,
+}
+
+bitflags! {
+ pub struct LinkingDirective: u8 {
+ const NONE = 0b000;
+ const EXTERN = 0b001;
+ const VISIBLE = 0b10;
+ const WEAK = 0b100;
+ }
+}
+
#[cfg(test)]
mod tests {
use super::*;