diff options
Diffstat (limited to 'ptx/src/ast.rs')
-rw-r--r-- | ptx/src/ast.rs | 2844 |
1 files changed, 1438 insertions, 1406 deletions
diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index d81cd66..0281961 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -1,1406 +1,1438 @@ -use std::convert::TryInto; -use std::{convert::From, convert::TryFrom, mem, num::ParseFloatError, str::FromStr}; -use std::{marker::PhantomData, num::ParseIntError}; - -use half::f16; - -quick_error! { - #[derive(Debug)] - pub enum PtxError { - ParseInt (err: ParseIntError) { - from() - display("{}", err) - cause(err) - } - ParseFloat (err: ParseFloatError) { - from() - display("{}", err) - cause(err) - } - SyntaxError {} - NonF32Ftz {} - WrongArrayType {} - WrongVectorElement {} - MultiArrayVariable {} - ZeroDimensionArray {} - ArrayInitalizer {} - NonExternPointer {} - } -} - -macro_rules! sub_enum { - ($name:ident { $($variant:ident),+ $(,)? }) => { - sub_enum!{ $name : ScalarType { $($variant),+ } } - }; - ($name:ident : $base_type:ident { $($variant:ident),+ $(,)? }) => { - #[derive(PartialEq, Eq, Clone, Copy)] - pub enum $name { - $( - $variant, - )+ - } - - impl From<$name> for $base_type { - fn from(t: $name) -> $base_type { - match t { - $( - $name::$variant => $base_type::$variant, - )+ - } - } - } - - impl std::convert::TryFrom<$base_type> for $name { - type Error = (); - - fn try_from(t: $base_type) -> Result<Self, Self::Error> { - match t { - $( - $base_type::$variant => Ok($name::$variant), - )+ - _ => Err(()), - } - } - } - }; -} - -macro_rules! sub_type { - ($type_name:ident { $($variant:ident ( $($field_type:ident),+ ) ),+ $(,)? } ) => { - sub_type! { $type_name : Type { - $( - $variant ($($field_type),+), - )+ - }} - }; - ($type_name:ident : $base_type:ident { $($variant:ident ( $($field_type:ident),+ ) ),+ $(,)? } ) => { - #[derive(PartialEq, Eq, Clone)] - pub enum $type_name { - $( - $variant ($($field_type),+), - )+ - } - - impl From<$type_name> for $base_type { - #[allow(non_snake_case)] - fn from(t: $type_name) -> $base_type { - match t { - $( - $type_name::$variant ( $($field_type),+ ) => <$base_type>::$variant ( $($field_type.into()),+), - )+ - } - } - } - - impl std::convert::TryFrom<$base_type> for $type_name { - type Error = (); - - #[allow(non_snake_case)] - #[allow(unreachable_patterns)] - fn try_from(t: $base_type) -> Result<Self, Self::Error> { - match t { - $( - $base_type::$variant ( $($field_type),+ ) => Ok($type_name::$variant ( $($field_type.try_into().map_err(|_| ())? ),+ )), - )+ - _ => Err(()), - } - } - } - }; -} - -sub_type! { - VariableRegType { - Scalar(ScalarType), - Vector(SizedScalarType, u8), - // Array type is used when emiting SSA statements at the start of a method - Array(ScalarType, VecU32), - // Pointer variant is used when passing around SLM pointer between - // function calls for dynamic SLM - Pointer(SizedScalarType, PointerStateSpace) - } -} - -type VecU32 = Vec<u32>; - -sub_type! { - VariableLocalType { - Scalar(SizedScalarType), - Vector(SizedScalarType, u8), - Array(SizedScalarType, VecU32), - } -} - -impl TryFrom<VariableGlobalType> for VariableLocalType { - type Error = PtxError; - - fn try_from(value: VariableGlobalType) -> Result<Self, Self::Error> { - match value { - VariableGlobalType::Scalar(t) => Ok(VariableLocalType::Scalar(t)), - VariableGlobalType::Vector(t, len) => Ok(VariableLocalType::Vector(t, len)), - VariableGlobalType::Array(t, len) => Ok(VariableLocalType::Array(t, len)), - VariableGlobalType::Pointer(_, _) => Err(PtxError::ZeroDimensionArray), - } - } -} - -sub_type! { - VariableGlobalType { - Scalar(SizedScalarType), - Vector(SizedScalarType, u8), - Array(SizedScalarType, VecU32), - Pointer(SizedScalarType, PointerStateSpace), - } -} - -// For some weird reson this is illegal: -// .param .f16x2 foobar; -// but this is legal: -// .param .f16x2 foobar[1]; -// even more interestingly this is legal, but only in .func (not in .entry): -// .param .b32 foobar[] -sub_type! { - VariableParamType { - Scalar(LdStScalarType), - Array(SizedScalarType, VecU32), - Pointer(SizedScalarType, PointerStateSpace), - } -} - -sub_enum!(SizedScalarType { - B8, - B16, - B32, - B64, - U8, - U16, - U32, - U64, - S8, - S16, - S32, - S64, - F16, - F16x2, - F32, - F64, -}); - -sub_enum!(LdStScalarType { - B8, - B16, - B32, - B64, - U8, - U16, - U32, - U64, - S8, - S16, - S32, - S64, - F16, - F32, - F64, -}); - -sub_enum!(SelpType { - B16, - B32, - B64, - U16, - U32, - U64, - S16, - S32, - S64, - F32, - F64, -}); - -#[derive(Copy, Clone, Eq, PartialEq)] -pub enum BarDetails { - SyncAligned, -} - -pub trait UnwrapWithVec<E, To> { - fn unwrap_with(self, errs: &mut Vec<E>) -> To; -} - -impl<R: Default, EFrom: std::convert::Into<EInto>, EInto> UnwrapWithVec<EInto, R> - for Result<R, EFrom> -{ - fn unwrap_with(self, errs: &mut Vec<EInto>) -> R { - self.unwrap_or_else(|e| { - errs.push(e.into()); - R::default() - }) - } -} - -impl< - R1: Default, - EFrom1: std::convert::Into<EInto>, - R2: Default, - EFrom2: std::convert::Into<EInto>, - EInto, - > UnwrapWithVec<EInto, (R1, R2)> for (Result<R1, EFrom1>, Result<R2, EFrom2>) -{ - fn unwrap_with(self, errs: &mut Vec<EInto>) -> (R1, R2) { - let (x, y) = self; - let r1 = x.unwrap_with(errs); - let r2 = y.unwrap_with(errs); - (r1, r2) - } -} - -pub struct Module<'a> { - pub version: (u8, u8), - pub directives: Vec<Directive<'a, ParsedArgParams<'a>>>, -} - -pub enum Directive<'a, P: ArgParams> { - Variable(Variable<VariableType, P::Id>), - Method(Function<'a, &'a str, Statement<P>>), -} - -pub enum MethodDecl<'a, ID> { - Func(Vec<FnArgument<ID>>, ID, Vec<FnArgument<ID>>), - Kernel { - name: &'a str, - in_args: Vec<KernelArgument<ID>>, - }, -} - -pub type FnArgument<ID> = Variable<FnArgumentType, ID>; -pub type KernelArgument<ID> = Variable<KernelArgumentType, ID>; - -pub struct Function<'a, ID, S> { - pub func_directive: MethodDecl<'a, ID>, - pub body: Option<Vec<S>>, -} - -pub type ParsedFunction<'a> = Function<'a, &'a str, Statement<ParsedArgParams<'a>>>; - -#[derive(PartialEq, Eq, Clone)] -pub enum FnArgumentType { - Reg(VariableRegType), - Param(VariableParamType), - Shared, -} -#[derive(PartialEq, Eq, Clone)] -pub enum KernelArgumentType { - Normal(VariableParamType), - Shared, -} - -impl From<KernelArgumentType> for Type { - fn from(this: KernelArgumentType) -> Self { - match this { - KernelArgumentType::Normal(typ) => typ.into(), - KernelArgumentType::Shared => { - Type::Pointer(PointerType::Scalar(ScalarType::B8), LdStateSpace::Shared) - } - } - } -} - -impl FnArgumentType { - pub fn to_type(&self, is_kernel: bool) -> Type { - if is_kernel { - self.to_kernel_type() - } else { - self.to_func_type() - } - } - - pub fn to_kernel_type(&self) -> Type { - match self { - FnArgumentType::Reg(x) => x.clone().into(), - FnArgumentType::Param(x) => x.clone().into(), - FnArgumentType::Shared => { - Type::Pointer(PointerType::Scalar(ScalarType::B8), LdStateSpace::Shared) - } - } - } - - pub fn to_func_type(&self) -> Type { - match self { - FnArgumentType::Reg(x) => x.clone().into(), - FnArgumentType::Param(VariableParamType::Scalar(t)) => { - Type::Pointer(PointerType::Scalar((*t).into()), LdStateSpace::Param) - } - FnArgumentType::Param(VariableParamType::Array(t, dims)) => Type::Pointer( - PointerType::Array((*t).into(), dims.clone()), - LdStateSpace::Param, - ), - FnArgumentType::Param(VariableParamType::Pointer(t, space)) => Type::Pointer( - PointerType::Pointer((*t).into(), (*space).into()), - LdStateSpace::Param, - ), - FnArgumentType::Shared => { - Type::Pointer(PointerType::Scalar(ScalarType::B8), LdStateSpace::Shared) - } - } - } - - pub fn is_param(&self) -> bool { - match self { - FnArgumentType::Param(_) => true, - _ => false, - } - } -} - -sub_enum!( - PointerStateSpace : LdStateSpace { - Generic, - Global, - Const, - Shared, - Param, - } -); - -#[derive(PartialEq, Eq, Clone)] -pub enum Type { - Scalar(ScalarType), - Vector(ScalarType, u8), - Array(ScalarType, Vec<u32>), - Pointer(PointerType, LdStateSpace), -} - -#[derive(PartialEq, Eq, Clone)] -pub enum PointerType { - Scalar(ScalarType), - Vector(ScalarType, u8), - Array(ScalarType, VecU32), - Pointer(ScalarType, LdStateSpace), -} - -impl From<SizedScalarType> for PointerType { - fn from(t: SizedScalarType) -> Self { - PointerType::Scalar(t.into()) - } -} - -impl TryFrom<PointerType> for SizedScalarType { - type Error = (); - - fn try_from(value: PointerType) -> Result<Self, Self::Error> { - match value { - PointerType::Scalar(t) => Ok(t.try_into()?), - PointerType::Vector(_, _) => Err(()), - PointerType::Array(_, _) => Err(()), - PointerType::Pointer(_, _) => Err(()), - } - } -} - -#[derive(PartialEq, Eq, Hash, Clone, Copy)] -pub enum ScalarType { - B8, - B16, - B32, - B64, - U8, - U16, - U32, - U64, - S8, - S16, - S32, - S64, - F16, - F32, - F64, - F16x2, - Pred, -} - -sub_enum!(IntType { - U8, - U16, - U32, - U64, - S8, - S16, - S32, - S64 -}); - -sub_enum!(BitType { B8, B16, B32, B64 }); - -sub_enum!(UIntType { U8, U16, U32, U64 }); - -sub_enum!(SIntType { S8, S16, S32, S64 }); - -impl IntType { - pub fn is_signed(self) -> bool { - match self { - IntType::U8 | IntType::U16 | IntType::U32 | IntType::U64 => false, - IntType::S8 | IntType::S16 | IntType::S32 | IntType::S64 => true, - } - } - - pub fn width(self) -> u8 { - match self { - IntType::U8 => 1, - IntType::U16 => 2, - IntType::U32 => 4, - IntType::U64 => 8, - IntType::S8 => 1, - IntType::S16 => 2, - IntType::S32 => 4, - IntType::S64 => 8, - } - } -} - -sub_enum!(FloatType { - F16, - F16x2, - F32, - F64 -}); - -impl ScalarType { - pub fn size_of(self) -> u8 { - match self { - ScalarType::U8 => 1, - ScalarType::S8 => 1, - ScalarType::B8 => 1, - ScalarType::U16 => 2, - ScalarType::S16 => 2, - ScalarType::B16 => 2, - ScalarType::F16 => 2, - ScalarType::U32 => 4, - ScalarType::S32 => 4, - ScalarType::B32 => 4, - ScalarType::F32 => 4, - ScalarType::U64 => 8, - ScalarType::S64 => 8, - ScalarType::B64 => 8, - ScalarType::F64 => 8, - ScalarType::F16x2 => 4, - ScalarType::Pred => 1, - } - } -} - -impl Default for ScalarType { - fn default() -> Self { - ScalarType::B8 - } -} - -pub enum Statement<P: ArgParams> { - Label(P::Id), - Variable(MultiVariable<P::Id>), - Instruction(Option<PredAt<P::Id>>, Instruction<P>), - Block(Vec<Statement<P>>), -} - -pub struct MultiVariable<ID> { - pub var: Variable<VariableType, ID>, - pub count: Option<u32>, -} - -#[derive(Clone)] -pub struct Variable<T, ID> { - pub align: Option<u32>, - pub v_type: T, - pub name: ID, - pub array_init: Vec<u8>, -} - -#[derive(Eq, PartialEq, Clone)] -pub enum VariableType { - Reg(VariableRegType), - Local(VariableLocalType), - Param(VariableParamType), - Global(VariableGlobalType), - Shared(VariableGlobalType), -} - -impl VariableType { - pub fn to_type(&self) -> (StateSpace, Type) { - match self { - VariableType::Reg(t) => (StateSpace::Reg, t.clone().into()), - VariableType::Local(t) => (StateSpace::Local, t.clone().into()), - VariableType::Param(t) => (StateSpace::Param, t.clone().into()), - VariableType::Global(t) => (StateSpace::Global, t.clone().into()), - VariableType::Shared(t) => (StateSpace::Shared, t.clone().into()), - } - } -} - -impl From<VariableType> for Type { - fn from(t: VariableType) -> Self { - match t { - VariableType::Reg(t) => t.into(), - VariableType::Local(t) => t.into(), - VariableType::Param(t) => t.into(), - VariableType::Global(t) => t.into(), - VariableType::Shared(t) => t.into(), - } - } -} - -#[derive(Copy, Clone, PartialEq, Eq)] -pub enum StateSpace { - Reg, - Const, - Global, - Local, - Shared, - Param, -} - -pub struct PredAt<ID> { - pub not: bool, - pub label: ID, -} - -pub enum Instruction<P: ArgParams> { - Ld(LdDetails, Arg2Ld<P>), - Mov(MovDetails, Arg2Mov<P>), - Mul(MulDetails, Arg3<P>), - Add(ArithDetails, Arg3<P>), - Setp(SetpData, Arg4Setp<P>), - SetpBool(SetpBoolData, Arg5Setp<P>), - Not(BooleanType, Arg2<P>), - Bra(BraData, Arg1<P>), - Cvt(CvtDetails, Arg2<P>), - Cvta(CvtaDetails, Arg2<P>), - Shl(ShlType, Arg3<P>), - Shr(ShrType, Arg3<P>), - St(StData, Arg2St<P>), - Ret(RetData), - Call(CallInst<P>), - Abs(AbsDetails, Arg2<P>), - Mad(MulDetails, Arg4<P>), - Or(BooleanType, Arg3<P>), - Sub(ArithDetails, Arg3<P>), - Min(MinMaxDetails, Arg3<P>), - Max(MinMaxDetails, Arg3<P>), - Rcp(RcpDetails, Arg2<P>), - And(BooleanType, Arg3<P>), - Selp(SelpType, Arg4<P>), - Bar(BarDetails, Arg1Bar<P>), - Atom(AtomDetails, Arg3<P>), - AtomCas(AtomCasDetails, Arg4<P>), - Div(DivDetails, Arg3<P>), - Sqrt(SqrtDetails, Arg2<P>), - Rsqrt(RsqrtDetails, Arg2<P>), - Neg(NegDetails, Arg2<P>), - Sin { flush_to_zero: bool, arg: Arg2<P> }, - Cos { flush_to_zero: bool, arg: Arg2<P> }, - Lg2 { flush_to_zero: bool, arg: Arg2<P> }, - Ex2 { flush_to_zero: bool, arg: Arg2<P> }, - Clz { typ: BitType, arg: Arg2<P> }, - Brev { typ: BitType, arg: Arg2<P> }, - Popc { typ: BitType, arg: Arg2<P> }, - Xor { typ: BooleanType, arg: Arg3<P> }, - Bfe { typ: IntType, arg: Arg4<P> }, - Rem { typ: IntType, arg: Arg3<P> }, -} - -#[derive(Copy, Clone)] -pub struct MadFloatDesc {} - -#[derive(Copy, Clone)] -pub struct AbsDetails { - pub flush_to_zero: Option<bool>, - pub typ: ScalarType, -} -#[derive(Copy, Clone)] -pub struct RcpDetails { - pub rounding: Option<RoundingMode>, - pub flush_to_zero: Option<bool>, - pub is_f64: bool, -} - -pub struct CallInst<P: ArgParams> { - pub uniform: bool, - pub ret_params: Vec<P::Id>, - pub func: P::Id, - pub param_list: Vec<P::Operand>, -} - -pub trait ArgParams { - type Id; - type Operand; -} - -pub struct ParsedArgParams<'a> { - _marker: PhantomData<&'a ()>, -} - -impl<'a> ArgParams for ParsedArgParams<'a> { - type Id = &'a str; - type Operand = Operand<&'a str>; -} - -pub struct Arg1<P: ArgParams> { - pub src: P::Id, // it is a jump destination, but in terms of operands it is a source operand -} - -pub struct Arg1Bar<P: ArgParams> { - pub src: P::Operand, -} - -pub struct Arg2<P: ArgParams> { - pub dst: P::Operand, - pub src: P::Operand, -} -pub struct Arg2Ld<P: ArgParams> { - pub dst: P::Operand, - pub src: P::Operand, -} - -pub struct Arg2St<P: ArgParams> { - pub src1: P::Operand, - pub src2: P::Operand, -} - -pub struct Arg2Mov<P: ArgParams> { - pub dst: P::Operand, - pub src: P::Operand, -} - -pub struct Arg3<P: ArgParams> { - pub dst: P::Operand, - pub src1: P::Operand, - pub src2: P::Operand, -} - -pub struct Arg4<P: ArgParams> { - pub dst: P::Operand, - pub src1: P::Operand, - pub src2: P::Operand, - pub src3: P::Operand, -} - -pub struct Arg4Setp<P: ArgParams> { - pub dst1: P::Id, - pub dst2: Option<P::Id>, - pub src1: P::Operand, - pub src2: P::Operand, -} - -pub struct Arg5Setp<P: ArgParams> { - pub dst1: P::Id, - pub dst2: Option<P::Id>, - pub src1: P::Operand, - pub src2: P::Operand, - pub src3: P::Operand, -} - -#[derive(Copy, Clone)] -pub enum ImmediateValue { - U64(u64), - S64(i64), - F32(f32), - F64(f64), -} - -#[derive(Clone)] -pub enum Operand<Id> { - Reg(Id), - RegOffset(Id, i32), - Imm(ImmediateValue), - VecMember(Id, u8), - VecPack(Vec<Id>), -} - -pub enum VectorPrefix { - V2, - V4, -} - -pub struct LdDetails { - pub qualifier: LdStQualifier, - pub state_space: LdStateSpace, - pub caching: LdCacheOperator, - pub typ: LdStType, -} - -sub_type! { - LdStType { - Scalar(LdStScalarType), - Vector(LdStScalarType, u8), - // Used in generated code - Pointer(PointerType, LdStateSpace), - } -} - -impl From<LdStType> for PointerType { - fn from(t: LdStType) -> Self { - match t { - LdStType::Scalar(t) => PointerType::Scalar(t.into()), - LdStType::Vector(t, len) => PointerType::Vector(t.into(), len), - LdStType::Pointer(PointerType::Scalar(scalar_type), space) => { - PointerType::Pointer(scalar_type, space) - } - LdStType::Pointer(..) => unreachable!(), - } - } -} - -#[derive(Copy, Clone, PartialEq, Eq)] -pub enum LdStQualifier { - Weak, - Volatile, - Relaxed(MemScope), - Acquire(MemScope), -} - -#[derive(Copy, Clone, PartialEq, Eq)] -pub enum MemScope { - Cta, - Gpu, - Sys, -} - -#[derive(Copy, Clone, PartialEq, Eq, Debug)] -#[repr(u8)] -pub enum LdStateSpace { - Generic, - Const, - Global, - Local, - Param, - Shared, -} - -#[derive(Copy, Clone, PartialEq, Eq)] -pub enum LdCacheOperator { - Cached, - L2Only, - Streaming, - LastUse, - Uncached, -} - -#[derive(Clone)] -pub struct MovDetails { - pub typ: Type, - pub src_is_address: bool, - // two fields below are in use by member moves - pub dst_width: u8, - pub src_width: u8, - // This is in use by auto-generated movs - pub relaxed_src2_conv: bool, -} - -impl MovDetails { - pub fn new(typ: Type) -> Self { - MovDetails { - typ, - src_is_address: false, - dst_width: 0, - src_width: 0, - relaxed_src2_conv: false, - } - } -} - -#[derive(Copy, Clone)] -pub struct MulIntDesc { - pub typ: IntType, - pub control: MulIntControl, -} - -#[derive(Copy, Clone, PartialEq, Eq)] -pub enum MulIntControl { - Low, - High, - Wide, -} - -#[derive(PartialEq, Eq, Copy, Clone)] -pub enum RoundingMode { - NearestEven, - Zero, - NegativeInf, - PositiveInf, -} - -pub struct AddIntDesc { - pub typ: IntType, - pub saturate: bool, -} - -pub struct SetpData { - pub typ: ScalarType, - pub flush_to_zero: Option<bool>, - pub cmp_op: SetpCompareOp, -} - -#[derive(PartialEq, Eq, Copy, Clone)] -pub enum SetpCompareOp { - Eq, - NotEq, - Less, - LessOrEq, - Greater, - GreaterOrEq, - NanEq, - NanNotEq, - NanLess, - NanLessOrEq, - NanGreater, - NanGreaterOrEq, - IsNotNan, - IsNan, -} - -pub enum SetpBoolPostOp { - And, - Or, - Xor, -} - -pub struct SetpBoolData { - pub typ: ScalarType, - pub flush_to_zero: Option<bool>, - pub cmp_op: SetpCompareOp, - pub bool_op: SetpBoolPostOp, -} - -pub struct BraData { - pub uniform: bool, -} - -pub enum CvtDetails { - IntFromInt(CvtIntToIntDesc), - FloatFromFloat(CvtDesc<FloatType, FloatType>), - IntFromFloat(CvtDesc<IntType, FloatType>), - FloatFromInt(CvtDesc<FloatType, IntType>), -} - -pub struct CvtIntToIntDesc { - pub dst: IntType, - pub src: IntType, - pub saturate: bool, -} - -pub struct CvtDesc<Dst, Src> { - pub rounding: Option<RoundingMode>, - pub flush_to_zero: Option<bool>, - pub saturate: bool, - pub dst: Dst, - pub src: Src, -} - -impl CvtDetails { - pub fn new_int_from_int_checked( - saturate: bool, - dst: IntType, - src: IntType, - err: &mut Vec<PtxError>, - ) -> Self { - if saturate { - if src.is_signed() { - if dst.is_signed() && dst.width() >= src.width() { - err.push(PtxError::SyntaxError); - } - } else { - if dst == src || dst.width() >= src.width() { - err.push(PtxError::SyntaxError); - } - } - } - CvtDetails::IntFromInt(CvtIntToIntDesc { dst, src, saturate }) - } - - pub fn new_float_from_int_checked( - rounding: RoundingMode, - flush_to_zero: bool, - saturate: bool, - dst: FloatType, - src: IntType, - err: &mut Vec<PtxError>, - ) -> Self { - if flush_to_zero && dst != FloatType::F32 { - err.push(PtxError::NonF32Ftz); - } - CvtDetails::FloatFromInt(CvtDesc { - dst, - src, - saturate, - flush_to_zero: Some(flush_to_zero), - rounding: Some(rounding), - }) - } - - pub fn new_int_from_float_checked( - rounding: RoundingMode, - flush_to_zero: bool, - saturate: bool, - dst: IntType, - src: FloatType, - err: &mut Vec<PtxError>, - ) -> Self { - if flush_to_zero && src != FloatType::F32 { - err.push(PtxError::NonF32Ftz); - } - CvtDetails::IntFromFloat(CvtDesc { - dst, - src, - saturate, - flush_to_zero: Some(flush_to_zero), - rounding: Some(rounding), - }) - } -} - -pub struct CvtaDetails { - pub to: CvtaStateSpace, - pub from: CvtaStateSpace, - pub size: CvtaSize, -} - -#[derive(Copy, Clone, PartialEq, Eq)] -pub enum CvtaStateSpace { - Generic, - Const, - Global, - Local, - Shared, -} - -pub enum CvtaSize { - U32, - U64, -} - -#[derive(PartialEq, Eq, Copy, Clone)] -pub enum ShlType { - B16, - B32, - B64, -} - -sub_enum!(ShrType { - B16, - B32, - B64, - U16, - U32, - U64, - S16, - S32, - S64, -}); - -pub struct StData { - pub qualifier: LdStQualifier, - pub state_space: StStateSpace, - pub caching: StCacheOperator, - pub typ: LdStType, -} - -#[derive(PartialEq, Eq, Copy, Clone)] -pub enum StStateSpace { - Generic, - Global, - Local, - Param, - Shared, -} - -#[derive(PartialEq, Eq)] -pub enum StCacheOperator { - Writeback, - L2Only, - Streaming, - Writethrough, -} - -pub struct RetData { - pub uniform: bool, -} - -sub_enum!(BooleanType { - Pred, - B16, - B32, - B64, -}); - -#[derive(Copy, Clone)] -pub enum MulDetails { - Unsigned(MulUInt), - Signed(MulSInt), - Float(ArithFloat), -} - -#[derive(Copy, Clone)] -pub struct MulUInt { - pub typ: UIntType, - pub control: MulIntControl, -} - -#[derive(Copy, Clone)] -pub struct MulSInt { - pub typ: SIntType, - pub control: MulIntControl, -} - -#[derive(Copy, Clone)] -pub enum ArithDetails { - Unsigned(UIntType), - Signed(ArithSInt), - Float(ArithFloat), -} - -#[derive(Copy, Clone)] -pub struct ArithSInt { - pub typ: SIntType, - pub saturate: bool, -} - -#[derive(Copy, Clone)] -pub struct ArithFloat { - pub typ: FloatType, - pub rounding: Option<RoundingMode>, - pub flush_to_zero: Option<bool>, - pub saturate: bool, -} - -#[derive(Copy, Clone)] -pub enum MinMaxDetails { - Signed(SIntType), - Unsigned(UIntType), - Float(MinMaxFloat), -} - -#[derive(Copy, Clone)] -pub struct MinMaxFloat { - pub flush_to_zero: Option<bool>, - pub nan: bool, - pub typ: FloatType, -} - -#[derive(Copy, Clone)] -pub struct AtomDetails { - pub semantics: AtomSemantics, - pub scope: MemScope, - pub space: AtomSpace, - pub inner: AtomInnerDetails, -} - -#[derive(Copy, Clone)] -pub enum AtomSemantics { - Relaxed, - Acquire, - Release, - AcquireRelease, -} - -#[derive(Copy, Clone)] -pub enum AtomSpace { - Generic, - Global, - Shared, -} - -#[derive(Copy, Clone)] -pub enum AtomInnerDetails { - Bit { op: AtomBitOp, typ: BitType }, - Unsigned { op: AtomUIntOp, typ: UIntType }, - Signed { op: AtomSIntOp, typ: SIntType }, - Float { op: AtomFloatOp, typ: FloatType }, -} - -#[derive(Copy, Clone, Eq, PartialEq)] -pub enum AtomBitOp { - And, - Or, - Xor, - Exchange, -} - -#[derive(Copy, Clone, Eq, PartialEq)] -pub enum AtomUIntOp { - Add, - Inc, - Dec, - Min, - Max, -} - -#[derive(Copy, Clone, Eq, PartialEq)] -pub enum AtomSIntOp { - Add, - Min, - Max, -} - -#[derive(Copy, Clone, Eq, PartialEq)] -pub enum AtomFloatOp { - Add, -} - -#[derive(Copy, Clone)] -pub struct AtomCasDetails { - pub semantics: AtomSemantics, - pub scope: MemScope, - pub space: AtomSpace, - pub typ: BitType, -} - -#[derive(Copy, Clone)] -pub enum DivDetails { - Unsigned(UIntType), - Signed(SIntType), - Float(DivFloatDetails), -} - -#[derive(Copy, Clone)] -pub struct DivFloatDetails { - pub typ: FloatType, - pub flush_to_zero: Option<bool>, - pub kind: DivFloatKind, -} - -#[derive(Copy, Clone, Eq, PartialEq)] -pub enum DivFloatKind { - Approx, - Full, - Rounding(RoundingMode), -} - -pub enum NumsOrArrays<'a> { - Nums(Vec<(&'a str, u32)>), - Arrays(Vec<NumsOrArrays<'a>>), -} - -#[derive(Copy, Clone)] -pub struct SqrtDetails { - pub typ: FloatType, - pub flush_to_zero: Option<bool>, - pub kind: SqrtKind, -} - -#[derive(Copy, Clone, Eq, PartialEq)] -pub enum SqrtKind { - Approx, - Rounding(RoundingMode), -} - -#[derive(Copy, Clone, Eq, PartialEq)] -pub struct RsqrtDetails { - pub typ: FloatType, - pub flush_to_zero: bool, -} - -#[derive(Copy, Clone, Eq, PartialEq)] -pub struct NegDetails { - pub typ: ScalarType, - pub flush_to_zero: Option<bool>, -} - -impl<'a> NumsOrArrays<'a> { - pub fn to_vec(self, typ: SizedScalarType, dimensions: &mut [u32]) -> Result<Vec<u8>, PtxError> { - self.normalize_dimensions(dimensions)?; - let sizeof_t = ScalarType::from(typ).size_of() as usize; - let result_size = dimensions.iter().fold(sizeof_t, |x, y| x * (*y as usize)); - let mut result = vec![0; result_size]; - self.parse_and_copy(typ, sizeof_t, dimensions, &mut result)?; - Ok(result) - } - - fn normalize_dimensions(&self, dimensions: &mut [u32]) -> Result<(), PtxError> { - match dimensions.first_mut() { - Some(first) => { - if *first == 0 { - *first = match self { - NumsOrArrays::Nums(v) => v.len() as u32, - NumsOrArrays::Arrays(v) => v.len() as u32, - }; - } - } - None => return Err(PtxError::ZeroDimensionArray), - } - for dim in dimensions { - if *dim == 0 { - return Err(PtxError::ZeroDimensionArray); - } - } - Ok(()) - } - - fn parse_and_copy( - &self, - t: SizedScalarType, - size_of_t: usize, - dimensions: &[u32], - result: &mut [u8], - ) -> Result<(), PtxError> { - match dimensions { - [] => unreachable!(), - [dim] => match self { - NumsOrArrays::Nums(vec) => { - if vec.len() > *dim as usize { - return Err(PtxError::ZeroDimensionArray); - } - for (idx, (val, radix)) in vec.iter().enumerate() { - Self::parse_and_copy_single(t, idx, val, *radix, result)?; - } - } - NumsOrArrays::Arrays(_) => return Err(PtxError::ZeroDimensionArray), - }, - [first_dim, rest @ ..] => match self { - NumsOrArrays::Arrays(vec) => { - if vec.len() > *first_dim as usize { - return Err(PtxError::ZeroDimensionArray); - } - let size_of_element = rest.iter().fold(size_of_t, |x, y| x * (*y as usize)); - for (idx, this) in vec.iter().enumerate() { - this.parse_and_copy( - t, - size_of_t, - rest, - &mut result[(size_of_element * idx)..], - )?; - } - } - NumsOrArrays::Nums(_) => return Err(PtxError::ZeroDimensionArray), - }, - } - Ok(()) - } - - fn parse_and_copy_single( - t: SizedScalarType, - idx: usize, - str_val: &str, - radix: u32, - output: &mut [u8], - ) -> Result<(), PtxError> { - match t { - SizedScalarType::B8 | SizedScalarType::U8 => { - Self::parse_and_copy_single_t::<u8>(idx, str_val, radix, output)?; - } - SizedScalarType::B16 | SizedScalarType::U16 => { - Self::parse_and_copy_single_t::<u16>(idx, str_val, radix, output)?; - } - SizedScalarType::B32 | SizedScalarType::U32 => { - Self::parse_and_copy_single_t::<u32>(idx, str_val, radix, output)?; - } - SizedScalarType::B64 | SizedScalarType::U64 => { - Self::parse_and_copy_single_t::<u64>(idx, str_val, radix, output)?; - } - SizedScalarType::S8 => { - Self::parse_and_copy_single_t::<i8>(idx, str_val, radix, output)?; - } - SizedScalarType::S16 => { - Self::parse_and_copy_single_t::<i16>(idx, str_val, radix, output)?; - } - SizedScalarType::S32 => { - Self::parse_and_copy_single_t::<i32>(idx, str_val, radix, output)?; - } - SizedScalarType::S64 => { - Self::parse_and_copy_single_t::<i64>(idx, str_val, radix, output)?; - } - SizedScalarType::F16 => { - Self::parse_and_copy_single_t::<f16>(idx, str_val, radix, output)?; - } - SizedScalarType::F16x2 => todo!(), - SizedScalarType::F32 => { - Self::parse_and_copy_single_t::<f32>(idx, str_val, radix, output)?; - } - SizedScalarType::F64 => { - Self::parse_and_copy_single_t::<f64>(idx, str_val, radix, output)?; - } - } - Ok(()) - } - - fn parse_and_copy_single_t<T: Copy + FromStr>( - idx: usize, - str_val: &str, - _radix: u32, // TODO: use this to properly support hex literals - output: &mut [u8], - ) -> Result<(), PtxError> - where - T::Err: Into<PtxError>, - { - let typed_output = unsafe { - std::slice::from_raw_parts_mut::<T>( - output.as_mut_ptr() as *mut _, - output.len() / mem::size_of::<T>(), - ) - }; - typed_output[idx] = str_val.parse::<T>().map_err(|e| e.into())?; - Ok(()) - } -} - -pub enum ArrayOrPointer { - Array { dimensions: Vec<u32>, init: Vec<u8> }, - Pointer, -} - -bitflags! { - pub struct LinkingDirective: u8 { - const NONE = 0b000; - const EXTERN = 0b001; - const VISIBLE = 0b10; - const WEAK = 0b100; - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn array_fails_multiple_0_dmiensions() { - let inp = NumsOrArrays::Nums(Vec::new()); - assert!(inp.to_vec(SizedScalarType::B8, &mut vec![0, 0]).is_err()); - } - - #[test] - fn array_fails_on_empty() { - let inp = NumsOrArrays::Nums(Vec::new()); - assert!(inp.to_vec(SizedScalarType::B8, &mut vec![0]).is_err()); - } - - #[test] - fn array_auto_sizes_0_dimension() { - let inp = NumsOrArrays::Arrays(vec![ - NumsOrArrays::Nums(vec![("1", 10), ("2", 10)]), - NumsOrArrays::Nums(vec![("3", 10), ("4", 10)]), - ]); - let mut dimensions = vec![0u32, 2]; - assert_eq!( - vec![1u8, 2, 3, 4], - inp.to_vec(SizedScalarType::B8, &mut dimensions).unwrap() - ); - assert_eq!(dimensions, vec![2u32, 2]); - } - - #[test] - fn array_fails_wrong_structure() { - let inp = NumsOrArrays::Arrays(vec![ - NumsOrArrays::Nums(vec![("1", 10), ("2", 10)]), - NumsOrArrays::Arrays(vec![NumsOrArrays::Nums(vec![("1", 10)])]), - ]); - let mut dimensions = vec![0u32, 2]; - assert!(inp.to_vec(SizedScalarType::B8, &mut dimensions).is_err()); - } - - #[test] - fn array_fails_too_long_component() { - let inp = NumsOrArrays::Arrays(vec![ - NumsOrArrays::Nums(vec![("1", 10), ("2", 10), ("3", 10)]), - NumsOrArrays::Nums(vec![("4", 10), ("5", 10)]), - ]); - let mut dimensions = vec![0u32, 2]; - assert!(inp.to_vec(SizedScalarType::B8, &mut dimensions).is_err()); - } -} +use half::f16;
+use lalrpop_util::{lexer::Token, ParseError};
+use std::alloc::Layout;
+use std::{convert::From, mem, num::ParseFloatError, str::FromStr};
+use std::{marker::PhantomData, num::ParseIntError};
+
+#[derive(Debug, thiserror::Error)]
+pub enum PtxError {
+ #[error("{source}")]
+ ParseInt {
+ #[from]
+ source: ParseIntError,
+ },
+ #[error("{source}")]
+ ParseFloat {
+ #[from]
+ source: ParseFloatError,
+ },
+ #[error("")]
+ SyntaxError,
+ #[error("")]
+ NonF32Ftz,
+ #[error("")]
+ WrongArrayType,
+ #[error("")]
+ WrongVectorElement,
+ #[error("")]
+ MultiArrayVariable,
+ #[error("")]
+ ZeroDimensionArray,
+ #[error("")]
+ ArrayInitializer,
+ #[error("")]
+ ScalarInitalizer,
+ #[error("")]
+ NonScalarArray,
+ #[error("")]
+ InvalidStateSpace,
+ #[error("")]
+ BlankVariableName,
+ #[error("")]
+ NonRegPredVariable,
+ #[error("")]
+ InitializerTypeMismatch,
+ #[error("")]
+ NonExternPointer,
+ #[error("{start}:{end}")]
+ UnrecognizedStatement { start: usize, end: usize },
+ #[error("{start}:{end}")]
+ UnrecognizedDirective { start: usize, end: usize },
+ #[error("")]
+ NoSmVersion,
+ #[error("")]
+ UnexpectedMultivariable,
+ #[error("")]
+ ExternDefinition,
+}
+
+// For some weird reson this is illegal:
+// .param .f16x2 foobar;
+// but this is legal:
+// .param .f16x2 foobar[1];
+// even more interestingly this is legal, but only in .func (not in .entry):
+// .param .b32 foobar[]
+
+#[derive(Copy, Clone, Eq, PartialEq)]
+pub enum BarDetails {
+ SyncAligned,
+}
+
+#[derive(Eq, PartialEq, Clone, Copy)]
+pub enum ReductionOp {
+ And,
+ Or,
+ Popc,
+}
+
+pub trait UnwrapWithVec<E, To> {
+ fn unwrap_with(self, errs: &mut Vec<E>) -> To;
+}
+
+impl<R: Default, EFrom: std::convert::Into<EInto>, EInto> UnwrapWithVec<EInto, R>
+ for Result<R, EFrom>
+{
+ fn unwrap_with(self, errs: &mut Vec<EInto>) -> R {
+ self.unwrap_or_else(|e| {
+ errs.push(e.into());
+ R::default()
+ })
+ }
+}
+
+impl<
+ R1: Default,
+ EFrom1: std::convert::Into<EInto>,
+ R2: Default,
+ EFrom2: std::convert::Into<EInto>,
+ EInto,
+ > UnwrapWithVec<EInto, (R1, R2)> for (Result<R1, EFrom1>, Result<R2, EFrom2>)
+{
+ fn unwrap_with(self, errs: &mut Vec<EInto>) -> (R1, R2) {
+ let (x, y) = self;
+ let r1 = x.unwrap_with(errs);
+ let r2 = y.unwrap_with(errs);
+ (r1, r2)
+ }
+}
+
+pub struct Module<'a> {
+ pub sm_version: u32,
+ pub directives: Vec<Directive<'a, ParsedArgParams<'a>>>,
+}
+
+pub enum Directive<'a, P: ArgParams> {
+ Variable(LinkingDirective, MultiVariableDefinition<P::Id>),
+ Method(LinkingDirective, Function<'a, &'a str, Statement<P>>),
+}
+
+#[derive(Hash, PartialEq, Eq, Copy, Clone)]
+pub enum MethodName<'input, ID> {
+ Kernel(&'input str),
+ Func(ID),
+}
+
+impl<'input, ID> MethodName<'input, ID> {
+ pub fn is_kernel(&self) -> bool {
+ match self {
+ MethodName::Kernel(..) => true,
+ MethodName::Func(..) => false,
+ }
+ }
+}
+
+#[derive(Clone)]
+pub struct MethodDeclaration<'input, ID> {
+ pub return_arguments: Vec<VariableDeclaration<ID>>,
+ pub name: MethodName<'input, ID>,
+ pub input_arguments: Vec<VariableDeclaration<ID>>,
+}
+
+pub struct Function<'a, ID, S> {
+ pub func_directive: MethodDeclaration<'a, ID>,
+ pub tuning: Vec<TuningDirective>,
+ pub body: Option<Vec<S>>,
+}
+
+#[derive(PartialEq, Eq, Clone, Hash)]
+pub enum Type {
+ // .param.b32 foo;
+ // -> OpTypeInt
+ Scalar(ScalarType),
+ // .param.v2.b32 foo;
+ // -> OpTypeVector
+ Vector(ScalarType, u8),
+ // .param.b32 foo[4];
+ // -> OpTypeArray
+ Array(ScalarType, Vec<u32>),
+ /*
+ Variables of this type almost never exist in the original .ptx and are
+ usually artificially created. Some examples below:
+ - extern pointers to the .shared memory in the form:
+ .extern .shared .b32 shared_mem[];
+ which we first parse as
+ .extern .shared .b32 shared_mem;
+ and then convert to an additional function parameter:
+ .param .ptr<.b32.shared> shared_mem;
+ and do a load at the start of the function (and renames inside fn):
+ .reg .ptr<.b32.shared> temp;
+ ld.param.ptr<.b32.shared> temp, [shared_mem];
+ note, we don't support non-.shared extern pointers, because there's
+ zero use for them in the ptxas
+ - artifical pointers created by stateful conversion, which work
+ similiarly to the above
+ - function parameters:
+ foobar(.param .align 4 .b8 numbers[])
+ which get parsed to
+ foobar(.param .align 4 .b8 numbers)
+ and then converted to
+ foobar(.reg .align 4 .ptr<.b8.param> numbers)
+ - ld/st with offset:
+ .reg.b32 x;
+ .param.b64 arg0;
+ st.param.b32 [arg0+4], x;
+ Yes, this code is legal and actually emitted by the NV compiler!
+ We convert the st to:
+ .reg ptr<.b64.param> temp = ptr_offset(arg0, 4);
+ st.param.b32 [temp], x;
+ */
+ // .reg ptr<.b64.param>
+ // -> OpTypePointer Function
+ Pointer(ScalarType, StateSpace),
+ Texref,
+ Surfref,
+ // Structs exist only to support certain internal, compiler-generated patterns
+ Struct(Vec<StructField>),
+}
+
+#[derive(PartialEq, Eq, Hash, Clone, Copy)]
+pub enum ScalarType {
+ B8,
+ B16,
+ B32,
+ B64,
+ U8,
+ U16,
+ U32,
+ U64,
+ S8,
+ S16,
+ S32,
+ S64,
+ F16,
+ F32,
+ F64,
+ F16x2,
+ Pred,
+}
+
+impl ScalarType {
+ pub(crate) fn to_ptx_name(self) -> &'static str {
+ match self {
+ ScalarType::B8 => "b8",
+ ScalarType::B16 => "b16",
+ ScalarType::B32 => "b32",
+ ScalarType::B64 => "b64",
+ ScalarType::U8 => "u8",
+ ScalarType::U16 => "u16",
+ ScalarType::U32 => "u32",
+ ScalarType::U64 => "u64",
+ ScalarType::S8 => "s8",
+ ScalarType::S16 => "s16",
+ ScalarType::S32 => "s32",
+ ScalarType::S64 => "s64",
+ ScalarType::F16 => "f16",
+ ScalarType::F32 => "f32",
+ ScalarType::F64 => "f64",
+ ScalarType::F16x2 => "f16x2",
+ ScalarType::Pred => "pred",
+ }
+ }
+}
+
+impl ScalarType {
+ pub fn size_of(self) -> u8 {
+ match self {
+ ScalarType::U8 => 1,
+ ScalarType::S8 => 1,
+ ScalarType::B8 => 1,
+ ScalarType::U16 => 2,
+ ScalarType::S16 => 2,
+ ScalarType::B16 => 2,
+ ScalarType::F16 => 2,
+ ScalarType::U32 => 4,
+ ScalarType::S32 => 4,
+ ScalarType::B32 => 4,
+ ScalarType::F32 => 4,
+ ScalarType::U64 => 8,
+ ScalarType::S64 => 8,
+ ScalarType::B64 => 8,
+ ScalarType::F64 => 8,
+ ScalarType::F16x2 => 4,
+ ScalarType::Pred => 1,
+ }
+ }
+}
+
+impl Default for ScalarType {
+ fn default() -> Self {
+ ScalarType::B8
+ }
+}
+
+#[derive(PartialEq, Eq, Hash, Clone, Copy)]
+pub enum StructField {
+ Scalar(ScalarType),
+ Vector(ScalarType, u8),
+}
+
+impl StructField {
+ pub fn to_type(self) -> Type {
+ match self {
+ Self::Scalar(type_) => Type::Scalar(type_),
+ Self::Vector(type_, size) => Type::Vector(type_, size),
+ }
+ }
+}
+
+pub enum Statement<P: ArgParams> {
+ Label(P::Id),
+ Callprototype(Callprototype<P::Id>),
+ Variable(Vec<MultiVariableDefinition<P::Id>>),
+ Instruction(Option<PredAt<P::Id>>, Instruction<P>),
+ Block(Vec<Statement<P>>),
+}
+
+#[derive(Clone)]
+pub struct Callprototype<ID> {
+ pub name: ID,
+ pub return_arguments: Vec<(Type, StateSpace)>,
+ pub input_arguments: Vec<(Type, StateSpace)>,
+}
+
+#[derive(Clone)]
+pub struct VariableDeclaration<ID> {
+ pub align: Option<u32>,
+ pub type_: Type,
+ pub state_space: StateSpace,
+ pub name: ID,
+}
+
+impl<ID> VariableDeclaration<ID> {
+ pub fn layout(&self) -> Layout {
+ let layout = self.type_.layout();
+ match self.align.map(|a| layout.align_to(a as usize)) {
+ Some(Ok(aligned_layout)) => aligned_layout,
+ _ => layout,
+ }
+ }
+}
+
+#[derive(Clone)]
+pub struct MultiVariableDefinition<ID> {
+ pub variable: VariableDeclaration<ID>,
+ pub suffix: Option<DeclarationSuffix<ID>>,
+}
+
+#[derive(Clone)]
+pub enum DeclarationSuffix<ID> {
+ Count(u32),
+ Initializer(Initializer<ID>),
+}
+
+#[derive(Copy, Clone, PartialEq, Eq, Hash)]
+pub enum StateSpace {
+ Reg,
+ Const,
+ Global,
+ Local,
+ Shared,
+ Param,
+ Generic,
+ Sreg,
+}
+
+pub struct PredAt<ID> {
+ pub not: bool,
+ pub label: ID,
+}
+
+pub struct BfindDetails {
+ pub shift: bool,
+ pub type_: ScalarType,
+}
+
+pub enum Instruction<P: ArgParams> {
+ Ld(LdDetails, Arg2Ld<P>),
+ Mov(MovDetails, Arg2Mov<P>),
+ Mul(MulDetails, Arg3<P>),
+ Add(ArithDetails, Arg3<P>),
+ AddC(CarryInDetails, Arg3<P>),
+ AddCC(ScalarType, Arg3<P>),
+ Setp(SetpData, Arg4Setp<P>),
+ SetpBool(SetpBoolData, Arg5Setp<P>),
+ Not(ScalarType, Arg2<P>),
+ Bra(BraData, Arg1<P>),
+ Cvt(CvtDetails, Arg2<P>),
+ Cvta(CvtaDetails, Arg2<P>),
+ Shl(ScalarType, Arg3<P>),
+ Shr(ScalarType, Arg3<P>),
+ St(StData, Arg2St<P>),
+ Ret(RetData),
+ Call(CallInst<P>),
+ Abs(AbsDetails, Arg2<P>),
+ Mad(MulDetails, Arg4<P>),
+ MadC {
+ type_: ScalarType,
+ carry_out: bool,
+ is_hi: bool,
+ arg: Arg4<P>,
+ },
+ MadCC {
+ type_: ScalarType,
+ arg: Arg4<P>,
+ },
+ Fma(ArithFloat, Arg4<P>),
+ Or(ScalarType, Arg3<P>),
+ Sub(ArithDetails, Arg3<P>),
+ SubC(CarryInDetails, Arg3<P>),
+ SubCC(ScalarType, Arg3<P>),
+ Min(MinMaxDetails, Arg3<P>),
+ Max(MinMaxDetails, Arg3<P>),
+ Rcp(RcpSqrtDetails, Arg2<P>),
+ Sqrt(RcpSqrtDetails, Arg2<P>),
+ And(ScalarType, Arg3<P>),
+ Selp(ScalarType, Arg4<P>),
+ Bar(BarDetails, Arg1Bar<P>),
+ BarWarp(BarDetails, Arg1Bar<P>),
+ BarRed(ReductionOp, Arg3<P>),
+ Atom(AtomDetails, Arg3<P>),
+ AtomCas(AtomCasDetails, Arg4<P>),
+ Div(DivDetails, Arg3<P>),
+ Rsqrt(RsqrtDetails, Arg2<P>),
+ Neg(NegDetails, Arg2<P>),
+ Sin {
+ flush_to_zero: bool,
+ arg: Arg2<P>,
+ },
+ Cos {
+ flush_to_zero: bool,
+ arg: Arg2<P>,
+ },
+ Lg2 {
+ flush_to_zero: bool,
+ arg: Arg2<P>,
+ },
+ Ex2 {
+ flush_to_zero: bool,
+ arg: Arg2<P>,
+ },
+ Clz {
+ typ: ScalarType,
+ arg: Arg2<P>,
+ },
+ Brev {
+ typ: ScalarType,
+ arg: Arg2<P>,
+ },
+ Popc {
+ typ: ScalarType,
+ arg: Arg2<P>,
+ },
+ Xor {
+ typ: ScalarType,
+ arg: Arg3<P>,
+ },
+ Bfe {
+ typ: ScalarType,
+ arg: Arg4<P>,
+ },
+ Bfi {
+ typ: ScalarType,
+ arg: Arg5<P>,
+ },
+ Rem {
+ typ: ScalarType,
+ arg: Arg3<P>,
+ },
+ Prmt {
+ control: u16,
+ arg: Arg3<P>,
+ },
+ PrmtSlow {
+ control: P::Id,
+ arg: Arg3<P>,
+ },
+ Activemask {
+ arg: Arg1<P>,
+ },
+ Membar {
+ level: MemScope,
+ },
+ Tex(TexDetails, Arg4Tex<P>),
+ Suld(SurfaceDetails, Arg4Tex<P>),
+ Sust(SurfaceDetails, Arg4Sust<P>),
+ Shfl(ShflMode, Arg5Shfl<P>),
+ Shf(FunnelShift, Arg4<P>),
+ Vote(VoteDetails, Arg3<P>),
+ Exit,
+ Trap,
+ Brkpt,
+ Vshr(Arg4<P>),
+ Bfind(BfindDetails, Arg2<P>),
+ Set(SetData, Arg3<P>),
+ Dp4a(ScalarType, Arg4<P>),
+ MatchAny(Arg3<P>),
+ Red(AtomDetails, Arg2St<P>),
+ Nanosleep(Arg1<P>),
+}
+
+#[derive(Copy, Clone)]
+
+pub struct CarryInDetails {
+ pub type_: ScalarType,
+ pub carry_out: bool,
+}
+
+#[derive(Copy, Clone)]
+pub enum ShflMode {
+ Up,
+ Down,
+ Bfly,
+ Idx,
+}
+
+#[derive(Copy, Clone)]
+pub struct VoteDetails {
+ pub mode: VoteMode,
+ pub negate_pred: bool,
+}
+
+#[derive(Copy, Clone, Eq, PartialEq)]
+pub enum VoteMode {
+ Ballot,
+ All,
+ Any,
+ Uni,
+}
+
+#[derive(Copy, Clone)]
+pub struct MadFloatDesc {}
+
+#[derive(Copy, Clone)]
+pub struct AbsDetails {
+ pub flush_to_zero: Option<bool>,
+ pub typ: ScalarType,
+}
+
+#[derive(Copy, Clone)]
+pub struct RcpSqrtDetails {
+ pub kind: RcpSqrtKind,
+ pub flush_to_zero: Option<bool>,
+ pub type_: ScalarType,
+}
+
+#[derive(Copy, Clone, Eq, PartialEq)]
+pub enum RcpSqrtKind {
+ Approx,
+ Rounding(RoundingMode),
+}
+
+#[derive(Copy, Clone, Eq, PartialEq)]
+pub struct FunnelShift {
+ pub direction: FunnelDirection,
+ pub mode: ShiftNormalization,
+}
+
+#[derive(Copy, Clone, Eq, PartialEq)]
+pub enum FunnelDirection {
+ Left,
+ Right,
+}
+
+#[derive(Copy, Clone, Eq, PartialEq)]
+pub enum ShiftNormalization {
+ Wrap,
+ Clamp,
+}
+
+pub struct CallInst<P: ArgParams> {
+ pub uniform: bool,
+ pub ret_params: Vec<P::Id>,
+ pub func: P::Id,
+ pub param_list: Vec<P::Operand>,
+ pub prototype: Option<P::Id>,
+}
+
+pub trait ArgParams {
+ type Id;
+ type Operand;
+}
+
+pub struct ParsedArgParams<'a> {
+ _marker: PhantomData<&'a ()>,
+}
+
+impl<'a> ArgParams for ParsedArgParams<'a> {
+ type Id = &'a str;
+ type Operand = Operand<&'a str>;
+}
+
+pub struct Arg1<P: ArgParams> {
+ pub src: P::Id, // it is a jump destination, but in terms of operands it is a source operand
+}
+
+pub struct Arg1Bar<P: ArgParams> {
+ pub src: P::Operand,
+}
+
+pub struct Arg2<P: ArgParams> {
+ pub dst: P::Operand,
+ pub src: P::Operand,
+}
+pub struct Arg2Ld<P: ArgParams> {
+ pub dst: P::Operand,
+ pub src: P::Operand,
+}
+
+pub struct Arg2St<P: ArgParams> {
+ pub src1: P::Operand,
+ pub src2: P::Operand,
+}
+
+pub struct Arg2Mov<P: ArgParams> {
+ pub dst: P::Operand,
+ pub src: P::Operand,
+}
+
+pub struct Arg3<P: ArgParams> {
+ pub dst: P::Operand,
+ pub src1: P::Operand,
+ pub src2: P::Operand,
+}
+
+pub struct Arg4<P: ArgParams> {
+ pub dst: P::Operand,
+ pub src1: P::Operand,
+ pub src2: P::Operand,
+ pub src3: P::Operand,
+}
+
+pub struct Arg4Setp<P: ArgParams> {
+ pub dst1: P::Id,
+ pub dst2: Option<P::Id>,
+ pub src1: P::Operand,
+ pub src2: P::Operand,
+}
+
+pub struct Arg4Tex<P: ArgParams> {
+ pub dst: P::Operand,
+ pub image: P::Operand,
+ pub layer: Option<P::Operand>,
+ pub coordinates: P::Operand,
+}
+
+pub struct Arg4Sust<P: ArgParams> {
+ pub image: P::Operand,
+ pub coordinates: P::Operand,
+ pub layer: Option<P::Operand>,
+ pub value: P::Operand,
+}
+
+pub struct Arg5<P: ArgParams> {
+ pub dst: P::Operand,
+ pub src1: P::Operand,
+ pub src2: P::Operand,
+ pub src3: P::Operand,
+ pub src4: P::Operand,
+}
+
+pub struct Arg5Setp<P: ArgParams> {
+ pub dst1: P::Id,
+ pub dst2: Option<P::Id>,
+ pub src1: P::Operand,
+ pub src2: P::Operand,
+ pub src3: P::Operand,
+}
+
+pub struct Arg5Shfl<P: ArgParams> {
+ pub dst1: P::Id,
+ pub dst2: Option<P::Id>,
+ pub src1: P::Operand,
+ pub src2: P::Operand,
+ pub src3: P::Operand,
+}
+
+#[derive(Copy, Clone)]
+pub enum ImmediateValue {
+ U64(u64),
+ S64(i64),
+ F32(f32),
+ F64(f64),
+}
+
+impl ImmediateValue {
+ pub fn to_bytes(self) -> Vec<u8> {
+ match self {
+ ImmediateValue::U64(x) => x.to_ne_bytes().to_vec(),
+ ImmediateValue::S64(x) => x.to_ne_bytes().to_vec(),
+ ImmediateValue::F32(x) => x.to_ne_bytes().to_vec(),
+ ImmediateValue::F64(x) => x.to_ne_bytes().to_vec(),
+ }
+ }
+
+ pub fn as_u8(self) -> Option<u8> {
+ match self {
+ ImmediateValue::U64(x) => Some(x as u8),
+ ImmediateValue::S64(x) => Some(x as u8),
+ ImmediateValue::F32(_) | ImmediateValue::F64(_) => None,
+ }
+ }
+
+ pub fn as_i8(self) -> Option<i8> {
+ match self {
+ ImmediateValue::U64(x) => Some(x as i8),
+ ImmediateValue::S64(x) => Some(x as i8),
+ ImmediateValue::F32(_) | ImmediateValue::F64(_) => None,
+ }
+ }
+
+ pub fn as_u16(self) -> Option<u16> {
+ match self {
+ ImmediateValue::U64(x) => Some(x as u16),
+ ImmediateValue::S64(x) => Some(x as u16),
+ ImmediateValue::F32(_) | ImmediateValue::F64(_) => None,
+ }
+ }
+
+ pub fn as_i16(self) -> Option<i16> {
+ match self {
+ ImmediateValue::U64(x) => Some(x as i16),
+ ImmediateValue::S64(x) => Some(x as i16),
+ ImmediateValue::F32(_) | ImmediateValue::F64(_) => None,
+ }
+ }
+
+ pub fn as_u32(self) -> Option<u32> {
+ match self {
+ ImmediateValue::U64(x) => Some(x as u32),
+ ImmediateValue::S64(x) => Some(x as u32),
+ ImmediateValue::F32(_) | ImmediateValue::F64(_) => None,
+ }
+ }
+
+ pub fn as_i32(self) -> Option<i32> {
+ match self {
+ ImmediateValue::U64(x) => Some(x as i32),
+ ImmediateValue::S64(x) => Some(x as i32),
+ ImmediateValue::F32(_) | ImmediateValue::F64(_) => None,
+ }
+ }
+
+ pub fn as_u64(self) -> Option<u64> {
+ match self {
+ ImmediateValue::U64(x) => Some(x),
+ ImmediateValue::S64(x) => Some(x as u64),
+ ImmediateValue::F32(_) | ImmediateValue::F64(_) => None,
+ }
+ }
+
+ pub fn as_i64(self) -> Option<i64> {
+ match self {
+ ImmediateValue::U64(x) => Some(x as i64),
+ ImmediateValue::S64(x) => Some(x),
+ ImmediateValue::F32(_) | ImmediateValue::F64(_) => None,
+ }
+ }
+
+ pub fn as_f32(self) -> Option<f32> {
+ match self {
+ ImmediateValue::F32(x) => Some(x),
+ ImmediateValue::F64(_) | ImmediateValue::U64(_) | ImmediateValue::S64(_) => None,
+ }
+ }
+
+ pub fn as_f64(self) -> Option<f64> {
+ match self {
+ ImmediateValue::F32(x) => Some(x as f64),
+ ImmediateValue::F64(x) => Some(x),
+ ImmediateValue::U64(_) | ImmediateValue::S64(_) => None,
+ }
+ }
+}
+
+#[derive(Clone)]
+pub enum Operand<Id> {
+ Reg(Id),
+ RegOffset(Id, i64),
+ Imm(ImmediateValue),
+ VecMember(Id, u8),
+ VecPack(Vec<RegOrImmediate<Id>>),
+}
+
+#[derive(Clone)]
+pub enum RegOrImmediate<Id> {
+ Reg(Id),
+ Imm(ImmediateValue),
+}
+
+pub struct LdDetails {
+ pub qualifier: LdStQualifier,
+ pub state_space: StateSpace,
+ pub caching: LdCacheOperator,
+ pub typ: Type,
+ pub non_coherent: bool,
+}
+
+#[derive(Copy, Clone, PartialEq, Eq)]
+pub enum LdStQualifier {
+ Weak,
+ Volatile,
+ Relaxed(MemScope),
+ Acquire(MemScope),
+ Release(MemScope),
+}
+
+#[derive(Copy, Clone, PartialEq, Eq)]
+pub enum MemScope {
+ Cta,
+ Gpu,
+ Sys,
+}
+
+#[derive(Copy, Clone, PartialEq, Eq)]
+pub enum LdCacheOperator {
+ Cached,
+ L2Only,
+ Streaming,
+ LastUse,
+ Uncached,
+}
+
+#[derive(Clone)]
+pub struct MovDetails {
+ pub typ: Type,
+ // two fields below are in use by member moves
+ pub dst_width: u8,
+ pub src_width: u8,
+ // This is in use by auto-generated movs
+ pub relaxed_src2_conv: bool,
+}
+
+impl MovDetails {
+ pub fn new(typ: Type) -> Self {
+ MovDetails {
+ typ,
+ dst_width: 0,
+ src_width: 0,
+ relaxed_src2_conv: false,
+ }
+ }
+}
+
+#[derive(Copy, Clone)]
+pub struct MulIntDesc {
+ pub typ: ScalarType,
+ pub control: MulIntControl,
+}
+
+#[derive(Copy, Clone, PartialEq, Eq)]
+pub enum MulIntControl {
+ Low,
+ High,
+ Wide,
+}
+
+#[derive(PartialEq, Eq, Copy, Clone)]
+pub enum RoundingMode {
+ NearestEven,
+ Zero,
+ NegativeInf,
+ PositiveInf,
+}
+
+pub struct AddIntDesc {
+ pub typ: ScalarType,
+ pub saturate: bool,
+}
+
+pub struct SetpData {
+ pub typ: ScalarType,
+ pub flush_to_zero: Option<bool>,
+ pub cmp_op: SetpCompareOp,
+}
+
+pub struct SetData {
+ pub dst_type: ScalarType,
+ pub src_type: ScalarType,
+ pub flush_to_zero: bool,
+ pub cmp_op: SetpCompareOp,
+}
+
+#[derive(PartialEq, Eq, Copy, Clone)]
+pub enum SetpCompareOp {
+ Eq,
+ NotEq,
+ Less,
+ LessOrEq,
+ Greater,
+ GreaterOrEq,
+ NanEq,
+ NanNotEq,
+ NanLess,
+ NanLessOrEq,
+ NanGreater,
+ NanGreaterOrEq,
+ IsNotNan,
+ IsAnyNan,
+}
+
+pub struct SetpBoolData {
+ pub base: SetpData,
+ pub bool_op: SetpBoolPostOp,
+ pub negate_src3: bool,
+}
+
+#[derive(Clone, Copy)]
+pub enum SetpBoolPostOp {
+ And,
+ Or,
+ Xor,
+}
+
+pub struct BraData {
+ pub uniform: bool,
+}
+
+pub enum CvtDetails {
+ IntFromInt(CvtIntToIntDesc),
+ FloatFromFloat(CvtDesc),
+ IntFromFloat(CvtDesc),
+ FloatFromInt(CvtDesc),
+}
+
+pub struct CvtIntToIntDesc {
+ pub dst: ScalarType,
+ pub src: ScalarType,
+ pub saturate: bool,
+}
+
+#[derive(Clone)]
+pub struct CvtDesc {
+ pub rounding: Option<RoundingMode>,
+ pub flush_to_zero: Option<bool>,
+ pub saturate: bool,
+ pub dst: ScalarType,
+ pub src: ScalarType,
+}
+
+impl CvtDetails {
+ pub fn new_int_from_int_checked<'err, 'input>(
+ saturate: bool,
+ dst: ScalarType,
+ src: ScalarType,
+ err: &'err mut Vec<ParseError<usize, Token<'input>, PtxError>>,
+ ) -> Self {
+ if saturate {
+ if src.kind() == ScalarKind::Signed {
+ if dst.kind() == ScalarKind::Signed && dst.size_of() >= src.size_of() {
+ err.push(ParseError::from(PtxError::SyntaxError));
+ }
+ } else {
+ if dst == src || dst.size_of() >= src.size_of() {
+ err.push(ParseError::from(PtxError::SyntaxError));
+ }
+ }
+ }
+ CvtDetails::IntFromInt(CvtIntToIntDesc { dst, src, saturate })
+ }
+
+ pub fn new_float_from_int_checked<'err, 'input>(
+ rounding: RoundingMode,
+ flush_to_zero: bool,
+ saturate: bool,
+ dst: ScalarType,
+ src: ScalarType,
+ err: &'err mut Vec<ParseError<usize, Token<'input>, PtxError>>,
+ ) -> Self {
+ if flush_to_zero && dst != ScalarType::F32 {
+ err.push(ParseError::from(PtxError::NonF32Ftz));
+ }
+ CvtDetails::FloatFromInt(CvtDesc {
+ dst,
+ src,
+ saturate,
+ flush_to_zero: Some(flush_to_zero),
+ rounding: Some(rounding),
+ })
+ }
+
+ pub fn new_int_from_float_checked<'err, 'input>(
+ rounding: RoundingMode,
+ flush_to_zero: bool,
+ saturate: bool,
+ dst: ScalarType,
+ src: ScalarType,
+ err: &'err mut Vec<ParseError<usize, Token<'input>, PtxError>>,
+ ) -> Self {
+ if flush_to_zero && src != ScalarType::F32 {
+ err.push(ParseError::from(PtxError::NonF32Ftz));
+ }
+ CvtDetails::IntFromFloat(CvtDesc {
+ dst,
+ src,
+ saturate,
+ flush_to_zero: Some(flush_to_zero),
+ rounding: Some(rounding),
+ })
+ }
+}
+
+pub struct CvtaDetails {
+ pub to: StateSpace,
+ pub from: StateSpace,
+ pub size: CvtaSize,
+}
+
+#[derive(Clone, Copy)]
+pub enum CvtaSize {
+ U32,
+ U64,
+}
+
+pub struct StData {
+ pub qualifier: LdStQualifier,
+ pub state_space: StateSpace,
+ pub caching: StCacheOperator,
+ pub typ: Type,
+}
+
+#[derive(PartialEq, Eq)]
+pub enum StCacheOperator {
+ Writeback,
+ L2Only,
+ Streaming,
+ Writethrough,
+}
+
+#[derive(Copy, Clone)]
+pub struct RetData {
+ pub uniform: bool,
+}
+
+#[derive(Copy, Clone)]
+pub enum MulDetails {
+ Unsigned(MulInt),
+ Signed(MulInt),
+ Float(ArithFloat),
+}
+
+#[derive(Copy, Clone)]
+pub struct MulInt {
+ pub typ: ScalarType,
+ pub control: MulIntControl,
+}
+
+#[derive(Copy, Clone)]
+pub enum ArithDetails {
+ Unsigned(ScalarType),
+ Signed(ArithSInt),
+ Float(ArithFloat),
+}
+
+#[derive(Copy, Clone)]
+pub struct ArithSInt {
+ pub typ: ScalarType,
+ pub saturate: bool,
+}
+
+#[derive(Copy, Clone)]
+pub struct ArithFloat {
+ pub typ: ScalarType,
+ pub rounding: Option<RoundingMode>,
+ pub flush_to_zero: Option<bool>,
+ pub saturate: bool,
+}
+
+#[derive(Copy, Clone)]
+pub enum MinMaxDetails {
+ Signed(ScalarType),
+ Unsigned(ScalarType),
+ Float(MinMaxFloat),
+}
+
+#[derive(Copy, Clone)]
+pub struct MinMaxFloat {
+ pub flush_to_zero: Option<bool>,
+ pub nan: bool,
+ pub typ: ScalarType,
+}
+
+#[derive(Copy, Clone)]
+pub struct AtomDetails {
+ pub semantics: AtomSemantics,
+ pub scope: MemScope,
+ pub space: StateSpace,
+ pub inner: AtomInnerDetails,
+}
+
+#[derive(Copy, Clone)]
+pub enum AtomSemantics {
+ Relaxed,
+ Acquire,
+ Release,
+ AcquireRelease,
+}
+
+#[derive(Copy, Clone)]
+pub enum AtomInnerDetails {
+ Bit { op: AtomBitOp, typ: ScalarType },
+ Unsigned { op: AtomUIntOp, typ: ScalarType },
+ Signed { op: AtomSIntOp, typ: ScalarType },
+ Float { op: AtomFloatOp, typ: ScalarType },
+}
+
+#[derive(Copy, Clone, Eq, PartialEq)]
+pub enum AtomBitOp {
+ And,
+ Or,
+ Xor,
+ Exchange,
+}
+
+#[derive(Copy, Clone, Eq, PartialEq)]
+pub enum AtomUIntOp {
+ Add,
+ Inc,
+ Dec,
+ Min,
+ Max,
+}
+
+#[derive(Copy, Clone, Eq, PartialEq)]
+pub enum AtomSIntOp {
+ Add,
+ Min,
+ Max,
+}
+
+#[derive(Copy, Clone, Eq, PartialEq)]
+pub enum AtomFloatOp {
+ Add,
+}
+
+#[derive(Copy, Clone)]
+pub struct AtomCasDetails {
+ pub semantics: AtomSemantics,
+ pub scope: MemScope,
+ pub space: StateSpace,
+ pub typ: ScalarType,
+}
+
+#[derive(Copy, Clone)]
+pub enum DivDetails {
+ Unsigned(ScalarType),
+ Signed(ScalarType),
+ Float(DivFloatDetails),
+}
+
+#[derive(Copy, Clone)]
+pub struct DivFloatDetails {
+ pub typ: ScalarType,
+ pub flush_to_zero: Option<bool>,
+ pub kind: DivFloatKind,
+}
+
+#[derive(Copy, Clone, Eq, PartialEq)]
+pub enum DivFloatKind {
+ Approx,
+ Full,
+ Rounding(RoundingMode),
+}
+
+pub enum NumsOrArrays<'a> {
+ Nums(Vec<(&'a str, u32)>),
+ Arrays(Vec<NumsOrArrays<'a>>),
+}
+
+#[derive(Copy, Clone, Eq, PartialEq)]
+pub struct RsqrtDetails {
+ pub typ: ScalarType,
+ pub flush_to_zero: bool,
+}
+
+#[derive(Copy, Clone, Eq, PartialEq)]
+pub struct NegDetails {
+ pub typ: ScalarType,
+ pub flush_to_zero: Option<bool>,
+}
+
+impl<'a> NumsOrArrays<'a> {
+ pub fn to_vec(self, typ: ScalarType, dimensions: &mut [u32]) -> Result<Vec<u8>, PtxError> {
+ self.normalize_dimensions(dimensions)?;
+ let sizeof_t = ScalarType::from(typ).size_of() as usize;
+ let result_size = dimensions.iter().fold(sizeof_t, |x, y| x * (*y as usize));
+ let mut result = vec![0; result_size];
+ self.parse_and_copy(typ, sizeof_t, dimensions, &mut result)?;
+ Ok(result)
+ }
+
+ fn normalize_dimensions(&self, dimensions: &mut [u32]) -> Result<(), PtxError> {
+ match dimensions.first_mut() {
+ Some(first) => {
+ if *first == 0 {
+ *first = match self {
+ NumsOrArrays::Nums(v) => v.len() as u32,
+ NumsOrArrays::Arrays(v) => v.len() as u32,
+ };
+ }
+ }
+ None => return Err(PtxError::ZeroDimensionArray),
+ }
+ for dim in dimensions {
+ if *dim == 0 {
+ return Err(PtxError::ZeroDimensionArray);
+ }
+ }
+ Ok(())
+ }
+
+ fn parse_and_copy(
+ &self,
+ t: ScalarType,
+ size_of_t: usize,
+ dimensions: &[u32],
+ result: &mut [u8],
+ ) -> Result<(), PtxError> {
+ match dimensions {
+ [] => unreachable!(),
+ [dim] => match self {
+ NumsOrArrays::Nums(vec) => {
+ if vec.len() > *dim as usize {
+ return Err(PtxError::ZeroDimensionArray);
+ }
+ for (idx, (val, radix)) in vec.iter().enumerate() {
+ Self::parse_and_copy_single(t, idx, val, *radix, result)?;
+ }
+ }
+ NumsOrArrays::Arrays(_) => return Err(PtxError::ZeroDimensionArray),
+ },
+ [first_dim, rest @ ..] => match self {
+ NumsOrArrays::Arrays(vec) => {
+ if vec.len() > *first_dim as usize {
+ return Err(PtxError::ZeroDimensionArray);
+ }
+ let size_of_element = rest.iter().fold(size_of_t, |x, y| x * (*y as usize));
+ for (idx, this) in vec.iter().enumerate() {
+ this.parse_and_copy(
+ t,
+ size_of_t,
+ rest,
+ &mut result[(size_of_element * idx)..],
+ )?;
+ }
+ }
+ NumsOrArrays::Nums(_) => return Err(PtxError::ZeroDimensionArray),
+ },
+ }
+ Ok(())
+ }
+
+ fn parse_and_copy_single(
+ t: ScalarType,
+ idx: usize,
+ str_val: &str,
+ radix: u32,
+ output: &mut [u8],
+ ) -> Result<(), PtxError> {
+ match t {
+ ScalarType::B8 | ScalarType::U8 => {
+ Self::parse_and_copy_single_t::<u8>(idx, str_val, radix, output)?;
+ }
+ ScalarType::B16 | ScalarType::U16 => {
+ Self::parse_and_copy_single_t::<u16>(idx, str_val, radix, output)?;
+ }
+ ScalarType::B32 | ScalarType::U32 => {
+ Self::parse_and_copy_single_t::<u32>(idx, str_val, radix, output)?;
+ }
+ ScalarType::B64 | ScalarType::U64 => {
+ Self::parse_and_copy_single_t::<u64>(idx, str_val, radix, output)?;
+ }
+ ScalarType::S8 => {
+ Self::parse_and_copy_single_t::<i8>(idx, str_val, radix, output)?;
+ }
+ ScalarType::S16 => {
+ Self::parse_and_copy_single_t::<i16>(idx, str_val, radix, output)?;
+ }
+ ScalarType::S32 => {
+ Self::parse_and_copy_single_t::<i32>(idx, str_val, radix, output)?;
+ }
+ ScalarType::S64 => {
+ Self::parse_and_copy_single_t::<i64>(idx, str_val, radix, output)?;
+ }
+ ScalarType::F16 => {
+ Self::parse_and_copy_single_t::<f16>(idx, str_val, radix, output)?;
+ }
+ ScalarType::F16x2 => todo!(),
+ ScalarType::F32 => {
+ Self::parse_and_copy_single_t::<f32>(idx, str_val, radix, output)?;
+ }
+ ScalarType::F64 => {
+ Self::parse_and_copy_single_t::<f64>(idx, str_val, radix, output)?;
+ }
+ ScalarType::Pred => todo!(),
+ }
+ Ok(())
+ }
+
+ fn parse_and_copy_single_t<T: Copy + FromStr>(
+ idx: usize,
+ str_val: &str,
+ _radix: u32, // TODO: use this to properly support hex literals
+ output: &mut [u8],
+ ) -> Result<(), PtxError>
+ where
+ T::Err: Into<PtxError>,
+ {
+ let typed_output = unsafe {
+ std::slice::from_raw_parts_mut::<T>(
+ output.as_mut_ptr() as *mut _,
+ output.len() / mem::size_of::<T>(),
+ )
+ };
+ typed_output[idx] = str_val.parse::<T>().map_err(|e| e.into())?;
+ Ok(())
+ }
+}
+
+pub enum ArrayOrPointer {
+ Array { dimensions: Vec<u32>, init: Vec<u8> },
+ Pointer,
+}
+
+#[derive(Copy, Clone, PartialEq, Eq, Debug)]
+pub enum LinkingDirective {
+ None,
+ Extern,
+ Visible,
+ Weak,
+ Common,
+}
+
+#[derive(Copy, Clone, PartialEq, Eq)]
+pub enum TuningDirective {
+ MaxNReg(u32),
+ MaxNtid(u32, u32, u32),
+ ReqNtid(u32, u32, u32),
+ MinNCtaPerSm(u32),
+}
+
+#[repr(u8)]
+#[derive(Clone, Copy, PartialEq, Eq)]
+pub enum ScalarKind {
+ Bit,
+ Unsigned,
+ Signed,
+ Float,
+ Float2,
+ Pred,
+}
+
+impl ScalarType {
+ pub fn kind(self) -> ScalarKind {
+ match self {
+ ScalarType::U8 => ScalarKind::Unsigned,
+ ScalarType::U16 => ScalarKind::Unsigned,
+ ScalarType::U32 => ScalarKind::Unsigned,
+ ScalarType::U64 => ScalarKind::Unsigned,
+ ScalarType::S8 => ScalarKind::Signed,
+ ScalarType::S16 => ScalarKind::Signed,
+ ScalarType::S32 => ScalarKind::Signed,
+ ScalarType::S64 => ScalarKind::Signed,
+ ScalarType::B8 => ScalarKind::Bit,
+ ScalarType::B16 => ScalarKind::Bit,
+ ScalarType::B32 => ScalarKind::Bit,
+ ScalarType::B64 => ScalarKind::Bit,
+ ScalarType::F16 => ScalarKind::Float,
+ ScalarType::F32 => ScalarKind::Float,
+ ScalarType::F64 => ScalarKind::Float,
+ ScalarType::F16x2 => ScalarKind::Float2,
+ ScalarType::Pred => ScalarKind::Pred,
+ }
+ }
+}
+
+pub struct TexDetails {
+ pub geometry: TextureGeometry,
+ pub channel_type: ScalarType,
+ pub coordinate_type: ScalarType,
+ // direct = takes .texref, indirect = takes .u64
+ pub direct: bool,
+}
+
+pub struct SurfaceDetails {
+ pub geometry: TextureGeometry,
+ pub vector: Option<u8>,
+ pub type_: ScalarType,
+ // direct = takes .texref, indirect = takes .u64
+ pub direct: bool,
+}
+
+#[derive(Clone, Copy, PartialEq, Eq)]
+pub enum TextureGeometry {
+ OneD,
+ TwoD,
+ ThreeD,
+ Array1D,
+ Array2D,
+}
+
+#[derive(Clone)]
+pub enum Initializer<ID> {
+ Constant(ImmediateValue),
+ Global(ID, Type),
+ GenericGlobal(ID, Type),
+ Add(Box<(Initializer<ID>, Initializer<ID>)>),
+ Array(Vec<Initializer<ID>>),
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn array_fails_multiple_0_dmiensions() {
+ let inp = NumsOrArrays::Nums(Vec::new());
+ assert!(inp.to_vec(ScalarType::B8, &mut vec![0, 0]).is_err());
+ }
+
+ #[test]
+ fn array_fails_on_empty() {
+ let inp = NumsOrArrays::Nums(Vec::new());
+ assert!(inp.to_vec(ScalarType::B8, &mut vec![0]).is_err());
+ }
+
+ #[test]
+ fn array_auto_sizes_0_dimension() {
+ let inp = NumsOrArrays::Arrays(vec![
+ NumsOrArrays::Nums(vec![("1", 10), ("2", 10)]),
+ NumsOrArrays::Nums(vec![("3", 10), ("4", 10)]),
+ ]);
+ let mut dimensions = vec![0u32, 2];
+ assert_eq!(
+ vec![1u8, 2, 3, 4],
+ inp.to_vec(ScalarType::B8, &mut dimensions).unwrap()
+ );
+ assert_eq!(dimensions, vec![2u32, 2]);
+ }
+
+ #[test]
+ fn array_fails_wrong_structure() {
+ let inp = NumsOrArrays::Arrays(vec![
+ NumsOrArrays::Nums(vec![("1", 10), ("2", 10)]),
+ NumsOrArrays::Arrays(vec![NumsOrArrays::Nums(vec![("1", 10)])]),
+ ]);
+ let mut dimensions = vec![0u32, 2];
+ assert!(inp.to_vec(ScalarType::B8, &mut dimensions).is_err());
+ }
+
+ #[test]
+ fn array_fails_too_long_component() {
+ let inp = NumsOrArrays::Arrays(vec![
+ NumsOrArrays::Nums(vec![("1", 10), ("2", 10), ("3", 10)]),
+ NumsOrArrays::Nums(vec![("4", 10), ("5", 10)]),
+ ]);
+ let mut dimensions = vec![0u32, 2];
+ assert!(inp.to_vec(ScalarType::B8, &mut dimensions).is_err());
+ }
+}
|