diff options
Diffstat (limited to 'ptx/src/ast.rs')
-rw-r--r-- | ptx/src/ast.rs | 84 |
1 files changed, 80 insertions, 4 deletions
diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index c6510da..1e90eba 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -1,4 +1,5 @@ -use std::{convert::From, mem, num::ParseFloatError, str::FromStr}; +use std::convert::TryInto; +use std::{convert::From, convert::TryFrom, mem, num::ParseFloatError, str::FromStr}; use std::{marker::PhantomData, num::ParseIntError}; use half::f16; @@ -22,6 +23,8 @@ quick_error! { WrongVectorElement {} MultiArrayVariable {} ZeroDimensionArray {} + ArrayInitalizer {} + NonExternPointer {} } } @@ -78,6 +81,21 @@ macro_rules! sub_type { } } } + + impl std::convert::TryFrom<Type> for $type_name { + type Error = (); + + #[allow(non_snake_case)] + #[allow(unreachable_patterns)] + fn try_from(t: Type) -> Result<Self, Self::Error> { + match t { + $( + Type::$variant ( $($field_type),+ ) => Ok($type_name::$variant ( $($field_type.try_into().map_err(|_| ())? ),+ )), + )+ + _ => Err(()), + } + } + } }; } @@ -98,14 +116,39 @@ sub_type! { } } +impl TryFrom<VariableGlobalType> for VariableLocalType { + type Error = PtxError; + + fn try_from(value: VariableGlobalType) -> Result<Self, Self::Error> { + match value { + VariableGlobalType::Scalar(t) => Ok(VariableLocalType::Scalar(t)), + VariableGlobalType::Vector(t, len) => Ok(VariableLocalType::Vector(t, len)), + VariableGlobalType::Array(t, len) => Ok(VariableLocalType::Array(t, len)), + VariableGlobalType::Pointer(_, _) => Err(PtxError::ZeroDimensionArray), + } + } +} + +sub_type! { + VariableGlobalType { + Scalar(SizedScalarType), + Vector(SizedScalarType, u8), + Array(SizedScalarType, VecU32), + Pointer(SizedScalarType, PointerStateSpace), + } +} + // For some weird reson this is illegal: // .param .f16x2 foobar; // but this is legal: // .param .f16x2 foobar[1]; +// even more interestingly this is legal, but only in .func (not in .entry): +// .param .b32 foobar[] sub_type! { VariableParamType { Scalar(ParamScalarType), Array(SizedScalarType, VecU32), + Pointer(SizedScalarType, PointerStateSpace), } } @@ -193,7 +236,7 @@ pub enum MethodDecl<'a, ID> { } pub type FnArgument<ID> = Variable<FnArgumentType, ID>; -pub type KernelArgument<ID> = Variable<VariableParamType, ID>; +pub type KernelArgument<ID> = Variable<KernelArgumentType, ID>; pub struct Function<'a, ID, S> { pub func_directive: MethodDecl<'a, ID>, @@ -206,6 +249,12 @@ pub type ParsedFunction<'a> = Function<'a, &'a str, Statement<ParsedArgParams<'a pub enum FnArgumentType { Reg(VariableRegType), Param(VariableParamType), + Shared, +} +#[derive(PartialEq, Eq, Clone)] +pub enum KernelArgumentType { + Normal(VariableParamType), + Shared, } impl From<FnArgumentType> for Type { @@ -213,15 +262,25 @@ impl From<FnArgumentType> for Type { match t { FnArgumentType::Reg(x) => x.into(), FnArgumentType::Param(x) => x.into(), + FnArgumentType::Shared => Type::Scalar(ScalarType::B64), } } } -#[derive(PartialEq, Eq, Hash, Clone)] +#[derive(PartialEq, Eq, Clone, Copy)] +pub enum PointerStateSpace { + Global, + Const, + Shared, + Param, +} + +#[derive(PartialEq, Eq, Clone)] pub enum Type { Scalar(ScalarType), Vector(ScalarType, u8), Array(ScalarType, Vec<u32>), + Pointer(ScalarType, PointerStateSpace), } #[derive(PartialEq, Eq, Hash, Clone, Copy)] @@ -343,7 +402,8 @@ pub enum VariableType { Reg(VariableRegType), Local(VariableLocalType), Param(VariableParamType), - Global(VariableLocalType), + Global(VariableGlobalType), + Shared(VariableGlobalType), } impl VariableType { @@ -353,6 +413,7 @@ impl VariableType { VariableType::Local(t) => (StateSpace::Local, t.clone().into()), VariableType::Param(t) => (StateSpace::Param, t.clone().into()), VariableType::Global(t) => (StateSpace::Global, t.clone().into()), + VariableType::Shared(t) => (StateSpace::Shared, t.clone().into()), } } } @@ -364,6 +425,7 @@ impl From<VariableType> for Type { VariableType::Local(t) => t.into(), VariableType::Param(t) => t.into(), VariableType::Global(t) => t.into(), + VariableType::Shared(t) => t.into(), } } } @@ -1039,6 +1101,20 @@ impl<'a> NumsOrArrays<'a> { } } +pub enum ArrayOrPointer { + Array { dimensions: Vec<u32>, init: Vec<u8> }, + Pointer, +} + +bitflags! { + pub struct LinkingDirective: u8 { + const NONE = 0b000; + const EXTERN = 0b001; + const VISIBLE = 0b10; + const WEAK = 0b100; + } +} + #[cfg(test)] mod tests { use super::*; |