diff options
Diffstat (limited to 'ptx/src/ast.rs')
-rw-r--r-- | ptx/src/ast.rs | 1074 |
1 files changed, 0 insertions, 1074 deletions
diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs deleted file mode 100644 index 358b8ce..0000000 --- a/ptx/src/ast.rs +++ /dev/null @@ -1,1074 +0,0 @@ -use half::f16; -use lalrpop_util::{lexer::Token, ParseError}; -use std::{convert::From, mem, num::ParseFloatError, str::FromStr}; -use std::{marker::PhantomData, num::ParseIntError}; - -#[derive(Debug, thiserror::Error)] -pub enum PtxError { - #[error("{source}")] - ParseInt { - #[from] - source: ParseIntError, - }, - #[error("{source}")] - ParseFloat { - #[from] - source: ParseFloatError, - }, - #[error("")] - Unsupported32Bit, - #[error("")] - SyntaxError, - #[error("")] - NonF32Ftz, - #[error("")] - WrongArrayType, - #[error("")] - WrongVectorElement, - #[error("")] - MultiArrayVariable, - #[error("")] - ZeroDimensionArray, - #[error("")] - ArrayInitalizer, - #[error("")] - NonExternPointer, - #[error("{start}:{end}")] - UnrecognizedStatement { start: usize, end: usize }, - #[error("{start}:{end}")] - UnrecognizedDirective { start: usize, end: usize }, -} - -// For some weird reson this is illegal: -// .param .f16x2 foobar; -// but this is legal: -// .param .f16x2 foobar[1]; -// even more interestingly this is legal, but only in .func (not in .entry): -// .param .b32 foobar[] - -#[derive(Copy, Clone, Eq, PartialEq)] -pub enum BarDetails { - SyncAligned, -} - -pub trait UnwrapWithVec<E, To> { - fn unwrap_with(self, errs: &mut Vec<E>) -> To; -} - -impl<R: Default, EFrom: std::convert::Into<EInto>, EInto> UnwrapWithVec<EInto, R> - for Result<R, EFrom> -{ - fn unwrap_with(self, errs: &mut Vec<EInto>) -> R { - self.unwrap_or_else(|e| { - errs.push(e.into()); - R::default() - }) - } -} - -impl< - R1: Default, - EFrom1: std::convert::Into<EInto>, - R2: Default, - EFrom2: std::convert::Into<EInto>, - EInto, - > UnwrapWithVec<EInto, (R1, R2)> for (Result<R1, EFrom1>, Result<R2, EFrom2>) -{ - fn unwrap_with(self, errs: &mut Vec<EInto>) -> (R1, R2) { - let (x, y) = self; - let r1 = x.unwrap_with(errs); - let r2 = y.unwrap_with(errs); - (r1, r2) - } -} - -pub struct Module<'a> { - pub version: (u8, u8), - pub directives: Vec<Directive<'a, ParsedArgParams<'a>>>, -} - -pub enum Directive<'a, P: ArgParams> { - Variable(LinkingDirective, Variable<P::Id>), - Method(LinkingDirective, Function<'a, &'a str, Statement<P>>), -} - -#[derive(Hash, PartialEq, Eq, Copy, Clone)] -pub enum MethodName<'input, ID> { - Kernel(&'input str), - Func(ID), -} - -pub struct MethodDeclaration<'input, ID> { - pub return_arguments: Vec<Variable<ID>>, - pub name: MethodName<'input, ID>, - pub input_arguments: Vec<Variable<ID>>, - pub shared_mem: Option<ID>, -} - -pub struct Function<'a, ID, S> { - pub func_directive: MethodDeclaration<'a, ID>, - pub tuning: Vec<TuningDirective>, - pub body: Option<Vec<S>>, -} - -pub type ParsedFunction<'a> = Function<'a, &'a str, Statement<ParsedArgParams<'a>>>; - -#[derive(PartialEq, Eq, Clone)] -pub enum Type { - // .param.b32 foo; - // -> OpTypeInt - Scalar(ScalarType), - // .param.v2.b32 foo; - // -> OpTypeVector - Vector(ScalarType, u8), - // .param.b32 foo[4]; - // -> OpTypeArray - Array(ScalarType, Vec<u32>), - /* - Variables of this type almost never exist in the original .ptx and are - usually artificially created. Some examples below: - - extern pointers to the .shared memory in the form: - .extern .shared .b32 shared_mem[]; - which we first parse as - .extern .shared .b32 shared_mem; - and then convert to an additional function parameter: - .param .ptr<.b32.shared> shared_mem; - and do a load at the start of the function (and renames inside fn): - .reg .ptr<.b32.shared> temp; - ld.param.ptr<.b32.shared> temp, [shared_mem]; - note, we don't support non-.shared extern pointers, because there's - zero use for them in the ptxas - - artifical pointers created by stateful conversion, which work - similiarly to the above - - function parameters: - foobar(.param .align 4 .b8 numbers[]) - which get parsed to - foobar(.param .align 4 .b8 numbers) - and then converted to - foobar(.reg .align 4 .ptr<.b8.param> numbers) - - ld/st with offset: - .reg.b32 x; - .param.b64 arg0; - st.param.b32 [arg0+4], x; - Yes, this code is legal and actually emitted by the NV compiler! - We convert the st to: - .reg ptr<.b64.param> temp = ptr_offset(arg0, 4); - st.param.b32 [temp], x; - */ - // .reg ptr<.b64.param> - // -> OpTypePointer Function - Pointer(ScalarType, StateSpace), -} - -#[derive(PartialEq, Eq, Hash, Clone, Copy)] -pub enum ScalarType { - B8, - B16, - B32, - B64, - U8, - U16, - U32, - U64, - S8, - S16, - S32, - S64, - F16, - F32, - F64, - F16x2, - Pred, -} - -impl ScalarType { - pub fn size_of(self) -> u8 { - match self { - ScalarType::U8 => 1, - ScalarType::S8 => 1, - ScalarType::B8 => 1, - ScalarType::U16 => 2, - ScalarType::S16 => 2, - ScalarType::B16 => 2, - ScalarType::F16 => 2, - ScalarType::U32 => 4, - ScalarType::S32 => 4, - ScalarType::B32 => 4, - ScalarType::F32 => 4, - ScalarType::U64 => 8, - ScalarType::S64 => 8, - ScalarType::B64 => 8, - ScalarType::F64 => 8, - ScalarType::F16x2 => 4, - ScalarType::Pred => 1, - } - } -} - -impl Default for ScalarType { - fn default() -> Self { - ScalarType::B8 - } -} - -pub enum Statement<P: ArgParams> { - Label(P::Id), - Variable(MultiVariable<P::Id>), - Instruction(Option<PredAt<P::Id>>, Instruction<P>), - Block(Vec<Statement<P>>), -} - -pub struct MultiVariable<ID> { - pub var: Variable<ID>, - pub count: Option<u32>, -} - -#[derive(Clone)] -pub struct Variable<ID> { - pub align: Option<u32>, - pub v_type: Type, - pub state_space: StateSpace, - pub name: ID, - pub array_init: Vec<u8>, -} - -#[derive(Copy, Clone, PartialEq, Eq)] -pub enum StateSpace { - Reg, - Const, - Global, - Local, - Shared, - Param, - Generic, - Sreg, -} - -pub struct PredAt<ID> { - pub not: bool, - pub label: ID, -} - -pub enum Instruction<P: ArgParams> { - Ld(LdDetails, Arg2Ld<P>), - Mov(MovDetails, Arg2Mov<P>), - Mul(MulDetails, Arg3<P>), - Add(ArithDetails, Arg3<P>), - Setp(SetpData, Arg4Setp<P>), - SetpBool(SetpBoolData, Arg5Setp<P>), - Not(ScalarType, Arg2<P>), - Bra(BraData, Arg1<P>), - Cvt(CvtDetails, Arg2<P>), - Cvta(CvtaDetails, Arg2<P>), - Shl(ScalarType, Arg3<P>), - Shr(ScalarType, Arg3<P>), - St(StData, Arg2St<P>), - Ret(RetData), - Call(CallInst<P>), - Abs(AbsDetails, Arg2<P>), - Mad(MulDetails, Arg4<P>), - Fma(ArithFloat, Arg4<P>), - Or(ScalarType, Arg3<P>), - Sub(ArithDetails, Arg3<P>), - Min(MinMaxDetails, Arg3<P>), - Max(MinMaxDetails, Arg3<P>), - Rcp(RcpDetails, Arg2<P>), - And(ScalarType, Arg3<P>), - Selp(ScalarType, Arg4<P>), - Bar(BarDetails, Arg1Bar<P>), - Atom(AtomDetails, Arg3<P>), - AtomCas(AtomCasDetails, Arg4<P>), - Div(DivDetails, Arg3<P>), - Sqrt(SqrtDetails, Arg2<P>), - Rsqrt(RsqrtDetails, Arg2<P>), - Neg(NegDetails, Arg2<P>), - Sin { flush_to_zero: bool, arg: Arg2<P> }, - Cos { flush_to_zero: bool, arg: Arg2<P> }, - Lg2 { flush_to_zero: bool, arg: Arg2<P> }, - Ex2 { flush_to_zero: bool, arg: Arg2<P> }, - Clz { typ: ScalarType, arg: Arg2<P> }, - Brev { typ: ScalarType, arg: Arg2<P> }, - Popc { typ: ScalarType, arg: Arg2<P> }, - Xor { typ: ScalarType, arg: Arg3<P> }, - Bfe { typ: ScalarType, arg: Arg4<P> }, - Bfi { typ: ScalarType, arg: Arg5<P> }, - Rem { typ: ScalarType, arg: Arg3<P> }, - Prmt { control: u16, arg: Arg3<P> }, - Activemask { arg: Arg1<P> }, - Membar { level: MemScope }, -} - -#[derive(Copy, Clone)] -pub struct MadFloatDesc {} - -#[derive(Copy, Clone)] -pub struct AbsDetails { - pub flush_to_zero: Option<bool>, - pub typ: ScalarType, -} -#[derive(Copy, Clone)] -pub struct RcpDetails { - pub rounding: Option<RoundingMode>, - pub flush_to_zero: Option<bool>, - pub is_f64: bool, -} - -pub struct CallInst<P: ArgParams> { - pub uniform: bool, - pub ret_params: Vec<P::Id>, - pub func: P::Id, - pub param_list: Vec<P::Operand>, -} - -pub trait ArgParams { - type Id; - type Operand; -} - -pub struct ParsedArgParams<'a> { - _marker: PhantomData<&'a ()>, -} - -impl<'a> ArgParams for ParsedArgParams<'a> { - type Id = &'a str; - type Operand = Operand<&'a str>; -} - -pub struct Arg1<P: ArgParams> { - pub src: P::Id, // it is a jump destination, but in terms of operands it is a source operand -} - -pub struct Arg1Bar<P: ArgParams> { - pub src: P::Operand, -} - -pub struct Arg2<P: ArgParams> { - pub dst: P::Operand, - pub src: P::Operand, -} -pub struct Arg2Ld<P: ArgParams> { - pub dst: P::Operand, - pub src: P::Operand, -} - -pub struct Arg2St<P: ArgParams> { - pub src1: P::Operand, - pub src2: P::Operand, -} - -pub struct Arg2Mov<P: ArgParams> { - pub dst: P::Operand, - pub src: P::Operand, -} - -pub struct Arg3<P: ArgParams> { - pub dst: P::Operand, - pub src1: P::Operand, - pub src2: P::Operand, -} - -pub struct Arg4<P: ArgParams> { - pub dst: P::Operand, - pub src1: P::Operand, - pub src2: P::Operand, - pub src3: P::Operand, -} - -pub struct Arg4Setp<P: ArgParams> { - pub dst1: P::Id, - pub dst2: Option<P::Id>, - pub src1: P::Operand, - pub src2: P::Operand, -} - -pub struct Arg5<P: ArgParams> { - pub dst: P::Operand, - pub src1: P::Operand, - pub src2: P::Operand, - pub src3: P::Operand, - pub src4: P::Operand, -} - -pub struct Arg5Setp<P: ArgParams> { - pub dst1: P::Id, - pub dst2: Option<P::Id>, - pub src1: P::Operand, - pub src2: P::Operand, - pub src3: P::Operand, -} - -#[derive(Copy, Clone)] -pub enum ImmediateValue { - U64(u64), - S64(i64), - F32(f32), - F64(f64), -} - -#[derive(Clone)] -pub enum Operand<Id> { - Reg(Id), - RegOffset(Id, i32), - Imm(ImmediateValue), - VecMember(Id, u8), - VecPack(Vec<Id>), -} - -pub enum VectorPrefix { - V2, - V4, -} - -pub struct LdDetails { - pub qualifier: LdStQualifier, - pub state_space: StateSpace, - pub caching: LdCacheOperator, - pub typ: Type, - pub non_coherent: bool, -} - -#[derive(Copy, Clone, PartialEq, Eq)] -pub enum LdStQualifier { - Weak, - Volatile, - Relaxed(MemScope), - Acquire(MemScope), -} - -#[derive(Copy, Clone, PartialEq, Eq)] -pub enum MemScope { - Cta, - Gpu, - Sys, -} - -#[derive(Copy, Clone, PartialEq, Eq)] -pub enum LdCacheOperator { - Cached, - L2Only, - Streaming, - LastUse, - Uncached, -} - -#[derive(Clone)] -pub struct MovDetails { - pub typ: Type, - pub src_is_address: bool, - // two fields below are in use by member moves - pub dst_width: u8, - pub src_width: u8, - // This is in use by auto-generated movs - pub relaxed_src2_conv: bool, -} - -impl MovDetails { - pub fn new(typ: Type) -> Self { - MovDetails { - typ, - src_is_address: false, - dst_width: 0, - src_width: 0, - relaxed_src2_conv: false, - } - } -} - -#[derive(Copy, Clone)] -pub struct MulIntDesc { - pub typ: ScalarType, - pub control: MulIntControl, -} - -#[derive(Copy, Clone, PartialEq, Eq)] -pub enum MulIntControl { - Low, - High, - Wide, -} - -#[derive(PartialEq, Eq, Copy, Clone)] -pub enum RoundingMode { - NearestEven, - Zero, - NegativeInf, - PositiveInf, -} - -pub struct AddIntDesc { - pub typ: ScalarType, - pub saturate: bool, -} - -pub struct SetpData { - pub typ: ScalarType, - pub flush_to_zero: Option<bool>, - pub cmp_op: SetpCompareOp, -} - -#[derive(PartialEq, Eq, Copy, Clone)] -pub enum SetpCompareOp { - Eq, - NotEq, - Less, - LessOrEq, - Greater, - GreaterOrEq, - NanEq, - NanNotEq, - NanLess, - NanLessOrEq, - NanGreater, - NanGreaterOrEq, - IsNotNan, - IsAnyNan, -} - -pub enum SetpBoolPostOp { - And, - Or, - Xor, -} - -pub struct SetpBoolData { - pub typ: ScalarType, - pub flush_to_zero: Option<bool>, - pub cmp_op: SetpCompareOp, - pub bool_op: SetpBoolPostOp, -} - -pub struct BraData { - pub uniform: bool, -} - -pub enum CvtDetails { - IntFromInt(CvtIntToIntDesc), - FloatFromFloat(CvtDesc), - IntFromFloat(CvtDesc), - FloatFromInt(CvtDesc), -} - -pub struct CvtIntToIntDesc { - pub dst: ScalarType, - pub src: ScalarType, - pub saturate: bool, -} - -pub struct CvtDesc { - pub rounding: Option<RoundingMode>, - pub flush_to_zero: Option<bool>, - pub saturate: bool, - pub dst: ScalarType, - pub src: ScalarType, -} - -impl CvtDetails { - pub fn new_int_from_int_checked<'err, 'input>( - saturate: bool, - dst: ScalarType, - src: ScalarType, - err: &'err mut Vec<ParseError<usize, Token<'input>, PtxError>>, - ) -> Self { - if saturate { - if src.kind() == ScalarKind::Signed { - if dst.kind() == ScalarKind::Signed && dst.size_of() >= src.size_of() { - err.push(ParseError::User { - error: PtxError::SyntaxError, - }); - } - } else { - if dst == src || dst.size_of() >= src.size_of() { - err.push(ParseError::User { - error: PtxError::SyntaxError, - }); - } - } - } - CvtDetails::IntFromInt(CvtIntToIntDesc { dst, src, saturate }) - } - - pub fn new_float_from_int_checked<'err, 'input>( - rounding: RoundingMode, - flush_to_zero: bool, - saturate: bool, - dst: ScalarType, - src: ScalarType, - err: &'err mut Vec<ParseError<usize, Token<'input>, PtxError>>, - ) -> Self { - if flush_to_zero && dst != ScalarType::F32 { - err.push(ParseError::from(lalrpop_util::ParseError::User { - error: PtxError::NonF32Ftz, - })); - } - CvtDetails::FloatFromInt(CvtDesc { - dst, - src, - saturate, - flush_to_zero: Some(flush_to_zero), - rounding: Some(rounding), - }) - } - - pub fn new_int_from_float_checked<'err, 'input>( - rounding: RoundingMode, - flush_to_zero: bool, - saturate: bool, - dst: ScalarType, - src: ScalarType, - err: &'err mut Vec<ParseError<usize, Token<'input>, PtxError>>, - ) -> Self { - if flush_to_zero && src != ScalarType::F32 { - err.push(ParseError::from(lalrpop_util::ParseError::User { - error: PtxError::NonF32Ftz, - })); - } - CvtDetails::IntFromFloat(CvtDesc { - dst, - src, - saturate, - flush_to_zero: Some(flush_to_zero), - rounding: Some(rounding), - }) - } -} - -pub struct CvtaDetails { - pub to: StateSpace, - pub from: StateSpace, - pub size: CvtaSize, -} - -pub enum CvtaSize { - U32, - U64, -} - -pub struct StData { - pub qualifier: LdStQualifier, - pub state_space: StateSpace, - pub caching: StCacheOperator, - pub typ: Type, -} - -#[derive(PartialEq, Eq)] -pub enum StCacheOperator { - Writeback, - L2Only, - Streaming, - Writethrough, -} - -pub struct RetData { - pub uniform: bool, -} - -#[derive(Copy, Clone)] -pub enum MulDetails { - Unsigned(MulUInt), - Signed(MulSInt), - Float(ArithFloat), -} - -#[derive(Copy, Clone)] -pub struct MulUInt { - pub typ: ScalarType, - pub control: MulIntControl, -} - -#[derive(Copy, Clone)] -pub struct MulSInt { - pub typ: ScalarType, - pub control: MulIntControl, -} - -#[derive(Copy, Clone)] -pub enum ArithDetails { - Unsigned(ScalarType), - Signed(ArithSInt), - Float(ArithFloat), -} - -#[derive(Copy, Clone)] -pub struct ArithSInt { - pub typ: ScalarType, - pub saturate: bool, -} - -#[derive(Copy, Clone)] -pub struct ArithFloat { - pub typ: ScalarType, - pub rounding: Option<RoundingMode>, - pub flush_to_zero: Option<bool>, - pub saturate: bool, -} - -#[derive(Copy, Clone)] -pub enum MinMaxDetails { - Signed(ScalarType), - Unsigned(ScalarType), - Float(MinMaxFloat), -} - -#[derive(Copy, Clone)] -pub struct MinMaxFloat { - pub flush_to_zero: Option<bool>, - pub nan: bool, - pub typ: ScalarType, -} - -#[derive(Copy, Clone)] -pub struct AtomDetails { - pub semantics: AtomSemantics, - pub scope: MemScope, - pub space: StateSpace, - pub inner: AtomInnerDetails, -} - -#[derive(Copy, Clone)] -pub enum AtomSemantics { - Relaxed, - Acquire, - Release, - AcquireRelease, -} - -#[derive(Copy, Clone)] -pub enum AtomInnerDetails { - Bit { op: AtomBitOp, typ: ScalarType }, - Unsigned { op: AtomUIntOp, typ: ScalarType }, - Signed { op: AtomSIntOp, typ: ScalarType }, - Float { op: AtomFloatOp, typ: ScalarType }, -} - -#[derive(Copy, Clone, Eq, PartialEq)] -pub enum AtomBitOp { - And, - Or, - Xor, - Exchange, -} - -#[derive(Copy, Clone, Eq, PartialEq)] -pub enum AtomUIntOp { - Add, - Inc, - Dec, - Min, - Max, -} - -#[derive(Copy, Clone, Eq, PartialEq)] -pub enum AtomSIntOp { - Add, - Min, - Max, -} - -#[derive(Copy, Clone, Eq, PartialEq)] -pub enum AtomFloatOp { - Add, -} - -#[derive(Copy, Clone)] -pub struct AtomCasDetails { - pub semantics: AtomSemantics, - pub scope: MemScope, - pub space: StateSpace, - pub typ: ScalarType, -} - -#[derive(Copy, Clone)] -pub enum DivDetails { - Unsigned(ScalarType), - Signed(ScalarType), - Float(DivFloatDetails), -} - -#[derive(Copy, Clone)] -pub struct DivFloatDetails { - pub typ: ScalarType, - pub flush_to_zero: Option<bool>, - pub kind: DivFloatKind, -} - -#[derive(Copy, Clone, Eq, PartialEq)] -pub enum DivFloatKind { - Approx, - Full, - Rounding(RoundingMode), -} - -pub enum NumsOrArrays<'a> { - Nums(Vec<(&'a str, u32)>), - Arrays(Vec<NumsOrArrays<'a>>), -} - -#[derive(Copy, Clone)] -pub struct SqrtDetails { - pub typ: ScalarType, - pub flush_to_zero: Option<bool>, - pub kind: SqrtKind, -} - -#[derive(Copy, Clone, Eq, PartialEq)] -pub enum SqrtKind { - Approx, - Rounding(RoundingMode), -} - -#[derive(Copy, Clone, Eq, PartialEq)] -pub struct RsqrtDetails { - pub typ: ScalarType, - pub flush_to_zero: bool, -} - -#[derive(Copy, Clone, Eq, PartialEq)] -pub struct NegDetails { - pub typ: ScalarType, - pub flush_to_zero: Option<bool>, -} - -impl<'a> NumsOrArrays<'a> { - pub fn to_vec(self, typ: ScalarType, dimensions: &mut [u32]) -> Result<Vec<u8>, PtxError> { - self.normalize_dimensions(dimensions)?; - let sizeof_t = ScalarType::from(typ).size_of() as usize; - let result_size = dimensions.iter().fold(sizeof_t, |x, y| x * (*y as usize)); - let mut result = vec![0; result_size]; - self.parse_and_copy(typ, sizeof_t, dimensions, &mut result)?; - Ok(result) - } - - fn normalize_dimensions(&self, dimensions: &mut [u32]) -> Result<(), PtxError> { - match dimensions.first_mut() { - Some(first) => { - if *first == 0 { - *first = match self { - NumsOrArrays::Nums(v) => v.len() as u32, - NumsOrArrays::Arrays(v) => v.len() as u32, - }; - } - } - None => return Err(PtxError::ZeroDimensionArray), - } - for dim in dimensions { - if *dim == 0 { - return Err(PtxError::ZeroDimensionArray); - } - } - Ok(()) - } - - fn parse_and_copy( - &self, - t: ScalarType, - size_of_t: usize, - dimensions: &[u32], - result: &mut [u8], - ) -> Result<(), PtxError> { - match dimensions { - [] => unreachable!(), - [dim] => match self { - NumsOrArrays::Nums(vec) => { - if vec.len() > *dim as usize { - return Err(PtxError::ZeroDimensionArray); - } - for (idx, (val, radix)) in vec.iter().enumerate() { - Self::parse_and_copy_single(t, idx, val, *radix, result)?; - } - } - NumsOrArrays::Arrays(_) => return Err(PtxError::ZeroDimensionArray), - }, - [first_dim, rest @ ..] => match self { - NumsOrArrays::Arrays(vec) => { - if vec.len() > *first_dim as usize { - return Err(PtxError::ZeroDimensionArray); - } - let size_of_element = rest.iter().fold(size_of_t, |x, y| x * (*y as usize)); - for (idx, this) in vec.iter().enumerate() { - this.parse_and_copy( - t, - size_of_t, - rest, - &mut result[(size_of_element * idx)..], - )?; - } - } - NumsOrArrays::Nums(_) => return Err(PtxError::ZeroDimensionArray), - }, - } - Ok(()) - } - - fn parse_and_copy_single( - t: ScalarType, - idx: usize, - str_val: &str, - radix: u32, - output: &mut [u8], - ) -> Result<(), PtxError> { - match t { - ScalarType::B8 | ScalarType::U8 => { - Self::parse_and_copy_single_t::<u8>(idx, str_val, radix, output)?; - } - ScalarType::B16 | ScalarType::U16 => { - Self::parse_and_copy_single_t::<u16>(idx, str_val, radix, output)?; - } - ScalarType::B32 | ScalarType::U32 => { - Self::parse_and_copy_single_t::<u32>(idx, str_val, radix, output)?; - } - ScalarType::B64 | ScalarType::U64 => { - Self::parse_and_copy_single_t::<u64>(idx, str_val, radix, output)?; - } - ScalarType::S8 => { - Self::parse_and_copy_single_t::<i8>(idx, str_val, radix, output)?; - } - ScalarType::S16 => { - Self::parse_and_copy_single_t::<i16>(idx, str_val, radix, output)?; - } - ScalarType::S32 => { - Self::parse_and_copy_single_t::<i32>(idx, str_val, radix, output)?; - } - ScalarType::S64 => { - Self::parse_and_copy_single_t::<i64>(idx, str_val, radix, output)?; - } - ScalarType::F16 => { - Self::parse_and_copy_single_t::<f16>(idx, str_val, radix, output)?; - } - ScalarType::F16x2 => todo!(), - ScalarType::F32 => { - Self::parse_and_copy_single_t::<f32>(idx, str_val, radix, output)?; - } - ScalarType::F64 => { - Self::parse_and_copy_single_t::<f64>(idx, str_val, radix, output)?; - } - ScalarType::Pred => todo!(), - } - Ok(()) - } - - fn parse_and_copy_single_t<T: Copy + FromStr>( - idx: usize, - str_val: &str, - _radix: u32, // TODO: use this to properly support hex literals - output: &mut [u8], - ) -> Result<(), PtxError> - where - T::Err: Into<PtxError>, - { - let typed_output = unsafe { - std::slice::from_raw_parts_mut::<T>( - output.as_mut_ptr() as *mut _, - output.len() / mem::size_of::<T>(), - ) - }; - typed_output[idx] = str_val.parse::<T>().map_err(|e| e.into())?; - Ok(()) - } -} - -pub enum ArrayOrPointer { - Array { dimensions: Vec<u32>, init: Vec<u8> }, - Pointer, -} - -bitflags! { - pub struct LinkingDirective: u8 { - const NONE = 0b000; - const EXTERN = 0b001; - const VISIBLE = 0b10; - const WEAK = 0b100; - } -} - -#[derive(Copy, Clone, PartialEq, Eq)] -pub enum TuningDirective { - MaxNReg(u32), - MaxNtid(u32, u32, u32), - ReqNtid(u32, u32, u32), - MinNCtaPerSm(u32), -} - -#[derive(Clone, Copy, PartialEq, Eq)] -pub enum ScalarKind { - Bit, - Unsigned, - Signed, - Float, - Float2, - Pred, -} - -impl ScalarType { - pub fn kind(self) -> ScalarKind { - match self { - ScalarType::U8 => ScalarKind::Unsigned, - ScalarType::U16 => ScalarKind::Unsigned, - ScalarType::U32 => ScalarKind::Unsigned, - ScalarType::U64 => ScalarKind::Unsigned, - ScalarType::S8 => ScalarKind::Signed, - ScalarType::S16 => ScalarKind::Signed, - ScalarType::S32 => ScalarKind::Signed, - ScalarType::S64 => ScalarKind::Signed, - ScalarType::B8 => ScalarKind::Bit, - ScalarType::B16 => ScalarKind::Bit, - ScalarType::B32 => ScalarKind::Bit, - ScalarType::B64 => ScalarKind::Bit, - ScalarType::F16 => ScalarKind::Float, - ScalarType::F32 => ScalarKind::Float, - ScalarType::F64 => ScalarKind::Float, - ScalarType::F16x2 => ScalarKind::Float2, - ScalarType::Pred => ScalarKind::Pred, - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn array_fails_multiple_0_dmiensions() { - let inp = NumsOrArrays::Nums(Vec::new()); - assert!(inp.to_vec(ScalarType::B8, &mut vec![0, 0]).is_err()); - } - - #[test] - fn array_fails_on_empty() { - let inp = NumsOrArrays::Nums(Vec::new()); - assert!(inp.to_vec(ScalarType::B8, &mut vec![0]).is_err()); - } - - #[test] - fn array_auto_sizes_0_dimension() { - let inp = NumsOrArrays::Arrays(vec![ - NumsOrArrays::Nums(vec![("1", 10), ("2", 10)]), - NumsOrArrays::Nums(vec![("3", 10), ("4", 10)]), - ]); - let mut dimensions = vec![0u32, 2]; - assert_eq!( - vec![1u8, 2, 3, 4], - inp.to_vec(ScalarType::B8, &mut dimensions).unwrap() - ); - assert_eq!(dimensions, vec![2u32, 2]); - } - - #[test] - fn array_fails_wrong_structure() { - let inp = NumsOrArrays::Arrays(vec![ - NumsOrArrays::Nums(vec![("1", 10), ("2", 10)]), - NumsOrArrays::Arrays(vec![NumsOrArrays::Nums(vec![("1", 10)])]), - ]); - let mut dimensions = vec![0u32, 2]; - assert!(inp.to_vec(ScalarType::B8, &mut dimensions).is_err()); - } - - #[test] - fn array_fails_too_long_component() { - let inp = NumsOrArrays::Arrays(vec![ - NumsOrArrays::Nums(vec![("1", 10), ("2", 10), ("3", 10)]), - NumsOrArrays::Nums(vec![("4", 10), ("5", 10)]), - ]); - let mut dimensions = vec![0u32, 2]; - assert!(inp.to_vec(ScalarType::B8, &mut dimensions).is_err()); - } -} |