use super::{
    AtomSemantics, MemScope, RawRoundingMode, RawSetpCompareOp, ScalarType, SetpBoolPostOp,
    StateSpace, VectorPrefix,
};
use crate::{PtxError, PtxParserState};
use bitflags::bitflags;
use std::{cmp::Ordering, num::NonZeroU8};

/// One entry of a parsed function body.
pub enum Statement {
    /// A label definition.
    Label(P::Ident),
    /// A variable declaration.
    Variable(MultiVariable),
    /// An instruction, preceded by an optional prefix value (presumably its guard
    /// predicate — TODO confirm).
    Instruction(Option>, Instruction),
    /// A nested block (presumably a list of statements).
    Block(Vec>),
}

// We define the instruction enum through the macro instead of normally, because we have some
// special requirements for how we use this type in the compiler. Each instruction can be
// logically split into two parts: properties that define instruction semantics (e.g. is memory
// load volatile?) that don't change during compilation, and arguments (e.g. memory load source
// and destination) that evolve during compilation. To support compilation passes we need to be
// able to visit (and change) every argument in a generic way. This macro has visibility over all
// the fields. Consequently, we use it to generate visitor functions. There are three functions to
// support three different semantics: visit-by-ref, visit-by-mutable-ref, visit-and-map. In a
// previous version of the compiler it was done by hand and was very limiting (we supported only
// visit-and-map).
// The visitor must implement the appropriate visitor trait defined below this macro. For
// convenience, we implemented visitors for some corresponding FnMut(...) types.
// Properties in this macro are used to encode information about the instruction arguments (what
// Rust type is used for it post-parsing, what PTX type does it expect, what PTX address space
// does it expect, etc.).
// This information is then available to a visitor.
ptx_parser_macros::generate_instruction_type!( pub enum Instruction { Mov { type: { &data.typ }, data: MovDetails, arguments: { dst: T, src: T } }, Ld { type: { &data.typ }, data: LdDetails, arguments: { dst: { repr: T, relaxed_type_check: true, }, src: { repr: T, space: { data.state_space }, } } }, Add { type: { Type::from(data.type_()) }, data: ArithDetails, arguments: { dst: T, src1: T, src2: T, } }, St { type: { &data.typ }, data: StData, arguments: { src1: { repr: T, space: { data.state_space }, }, src2: { repr: T, relaxed_type_check: true, } } }, Mul { type: { Type::from(data.type_()) }, data: MulDetails, arguments: { dst: { repr: T, type: { Type::from(data.dst_type()) }, }, src1: T, src2: T, } }, Setp { data: SetpData, arguments: { dst1: { repr: T, type: Type::from(ScalarType::Pred) }, dst2: { repr: Option, type: Type::from(ScalarType::Pred) }, src1: { repr: T, type: Type::from(data.type_), }, src2: { repr: T, type: Type::from(data.type_), } } }, SetpBool { data: SetpBoolData, arguments: { dst1: { repr: T, type: Type::from(ScalarType::Pred) }, dst2: { repr: Option, type: Type::from(ScalarType::Pred) }, src1: { repr: T, type: Type::from(data.base.type_), }, src2: { repr: T, type: Type::from(data.base.type_), }, src3: { repr: T, type: Type::from(ScalarType::Pred) } } }, Not { data: ScalarType, type: { Type::Scalar(data.clone()) }, arguments: { dst: T, src: T, } }, Or { data: ScalarType, type: { Type::Scalar(data.clone()) }, arguments: { dst: T, src1: T, src2: T, } }, And { data: ScalarType, type: { Type::Scalar(data.clone()) }, arguments: { dst: T, src1: T, src2: T, } }, Bra { type: !, arguments: { src: T } }, Call { data: CallDetails, arguments: CallArgs, visit: arguments.visit(data, visitor)?, visit_mut: arguments.visit_mut(data, visitor)?, map: Instruction::Call{ arguments: arguments.map(&data, visitor)?, data } }, Cvt { data: CvtDetails, arguments: { dst: { repr: T, type: { Type::Scalar(data.to) }, // TODO: double check relaxed_type_check: true, }, src: { 
repr: T, type: { Type::Scalar(data.from) }, relaxed_type_check: true, }, } }, Shr { data: ShrData, type: { Type::Scalar(data.type_.clone()) }, arguments: { dst: T, src1: T, src2: { repr: T, type: { Type::Scalar(ScalarType::U32) }, }, } }, Shl { data: ScalarType, type: { Type::Scalar(data.clone()) }, arguments: { dst: T, src1: T, src2: { repr: T, type: { Type::Scalar(ScalarType::U32) }, }, } }, Ret { data: RetData }, Cvta { data: CvtaDetails, type: { Type::Scalar(ScalarType::B64) }, arguments: { dst: T, src: T, } }, Abs { data: TypeFtz, type: { Type::Scalar(data.type_) }, arguments: { dst: T, src: T, } }, Mad { type: { Type::from(data.type_()) }, data: MadDetails, arguments: { dst: { repr: T, type: { Type::from(data.dst_type()) }, }, src1: T, src2: T, src3: T, } }, Fma { type: { Type::from(data.type_) }, data: ArithFloat, arguments: { dst: T, src1: T, src2: T, src3: T, } }, Sub { type: { Type::from(data.type_()) }, data: ArithDetails, arguments: { dst: T, src1: T, src2: T, } }, Min { type: { Type::from(data.type_()) }, data: MinMaxDetails, arguments: { dst: T, src1: T, src2: T, } }, Max { type: { Type::from(data.type_()) }, data: MinMaxDetails, arguments: { dst: T, src1: T, src2: T, } }, Rcp { type: { Type::from(data.type_) }, data: RcpData, arguments: { dst: T, src: T, } }, Sqrt { type: { Type::from(data.type_) }, data: RcpData, arguments: { dst: T, src: T, } }, Rsqrt { type: { Type::from(data.type_) }, data: TypeFtz, arguments: { dst: T, src: T, } }, Selp { type: { Type::Scalar(data.clone()) }, data: ScalarType, arguments: { dst: T, src1: T, src2: T, src3: { repr: T, type: Type::Scalar(ScalarType::Pred) }, } }, Bar { type: Type::Scalar(ScalarType::U32), data: BarData, arguments: { src1: T, src2: Option, } }, Atom { type: &data.type_, data: AtomDetails, arguments: { dst: T, src1: { repr: T, space: { data.space }, }, src2: T, } }, AtomCas { type: Type::Scalar(data.type_), data: AtomCasDetails, arguments: { dst: T, src1: { repr: T, space: { data.space }, }, src2: T, 
src3: T, } }, Div { type: Type::Scalar(data.type_()), data: DivDetails, arguments: { dst: T, src1: T, src2: T, } }, Neg { type: Type::Scalar(data.type_), data: TypeFtz, arguments: { dst: T, src: T } }, Sin { type: Type::Scalar(ScalarType::F32), data: FlushToZero, arguments: { dst: T, src: T } }, Cos { type: Type::Scalar(ScalarType::F32), data: FlushToZero, arguments: { dst: T, src: T } }, Lg2 { type: Type::Scalar(ScalarType::F32), data: FlushToZero, arguments: { dst: T, src: T } }, Ex2 { type: Type::Scalar(ScalarType::F32), data: TypeFtz, arguments: { dst: T, src: T } }, Clz { type: Type::Scalar(data.clone()), data: ScalarType, arguments: { dst: { repr: T, type: Type::Scalar(ScalarType::U32) }, src: T } }, Brev { type: Type::Scalar(data.clone()), data: ScalarType, arguments: { dst: T, src: T } }, Popc { type: Type::Scalar(data.clone()), data: ScalarType, arguments: { dst: { repr: T, type: Type::Scalar(ScalarType::U32) }, src: T } }, Xor { type: Type::Scalar(data.clone()), data: ScalarType, arguments: { dst: T, src1: T, src2: T } }, Rem { type: Type::Scalar(data.clone()), data: ScalarType, arguments: { dst: T, src1: T, src2: T } }, Bfe { type: Type::Scalar(data.clone()), data: ScalarType, arguments: { dst: T, src1: T, src2: { repr: T, type: Type::Scalar(ScalarType::U32) }, src3: { repr: T, type: Type::Scalar(ScalarType::U32) }, } }, Bfi { type: Type::Scalar(data.clone()), data: ScalarType, arguments: { dst: T, src1: T, src2: T, src3: { repr: T, type: Type::Scalar(ScalarType::U32) }, src4: { repr: T, type: Type::Scalar(ScalarType::U32) }, } }, PrmtSlow { type: Type::Scalar(ScalarType::U32), arguments: { dst: T, src1: T, src2: T, src3: T } }, Prmt { type: Type::Scalar(ScalarType::B32), data: u16, arguments: { dst: T, src1: T, src2: T } }, Activemask { type: Type::Scalar(ScalarType::B32), arguments: { dst: T } }, Membar { data: MemScope }, Trap { } } ); pub trait Visitor { fn visit( &mut self, args: &T, type_space: Option<(&Type, StateSpace)>, is_dst: bool, 
relaxed_type_check: bool, ) -> Result<(), Err>; fn visit_ident( &mut self, args: &T::Ident, type_space: Option<(&Type, StateSpace)>, is_dst: bool, relaxed_type_check: bool, ) -> Result<(), Err>; } impl< T: Operand, Err, Fn: FnMut(&T, Option<(&Type, StateSpace)>, bool, bool) -> Result<(), Err>, > Visitor for Fn { fn visit( &mut self, args: &T, type_space: Option<(&Type, StateSpace)>, is_dst: bool, relaxed_type_check: bool, ) -> Result<(), Err> { (self)(args, type_space, is_dst, relaxed_type_check) } fn visit_ident( &mut self, args: &T::Ident, type_space: Option<(&Type, StateSpace)>, is_dst: bool, relaxed_type_check: bool, ) -> Result<(), Err> { (self)( &T::from_ident(*args), type_space, is_dst, relaxed_type_check, ) } } pub trait VisitorMut { fn visit( &mut self, args: &mut T, type_space: Option<(&Type, StateSpace)>, is_dst: bool, relaxed_type_check: bool, ) -> Result<(), Err>; fn visit_ident( &mut self, args: &mut T::Ident, type_space: Option<(&Type, StateSpace)>, is_dst: bool, relaxed_type_check: bool, ) -> Result<(), Err>; } pub trait VisitorMap { fn visit( &mut self, args: From, type_space: Option<(&Type, StateSpace)>, is_dst: bool, relaxed_type_check: bool, ) -> Result; fn visit_ident( &mut self, args: From::Ident, type_space: Option<(&Type, StateSpace)>, is_dst: bool, relaxed_type_check: bool, ) -> Result; } impl VisitorMap, ParsedOperand, Err> for Fn where Fn: FnMut(T, Option<(&Type, StateSpace)>, bool, bool) -> Result, { fn visit( &mut self, args: ParsedOperand, type_space: Option<(&Type, StateSpace)>, is_dst: bool, relaxed_type_check: bool, ) -> Result, Err> { Ok(match args { ParsedOperand::Reg(ident) => { ParsedOperand::Reg((self)(ident, type_space, is_dst, relaxed_type_check)?) 
} ParsedOperand::RegOffset(ident, imm) => ParsedOperand::RegOffset( (self)(ident, type_space, is_dst, relaxed_type_check)?, imm, ), ParsedOperand::Imm(imm) => ParsedOperand::Imm(imm), ParsedOperand::VecMember(ident, index) => ParsedOperand::VecMember( (self)(ident, type_space, is_dst, relaxed_type_check)?, index, ), ParsedOperand::VecPack(vec) => ParsedOperand::VecPack( vec.into_iter() .map(|ident| (self)(ident, type_space, is_dst, relaxed_type_check)) .collect::, _>>()?, ), }) } fn visit_ident( &mut self, args: T, type_space: Option<(&Type, StateSpace)>, is_dst: bool, relaxed_type_check: bool, ) -> Result { (self)(args, type_space, is_dst, relaxed_type_check) } } impl, U: Operand, Err, Fn> VisitorMap for Fn where Fn: FnMut(T, Option<(&Type, StateSpace)>, bool, bool) -> Result, { fn visit( &mut self, args: T, type_space: Option<(&Type, StateSpace)>, is_dst: bool, relaxed_type_check: bool, ) -> Result { (self)(args, type_space, is_dst, relaxed_type_check) } fn visit_ident( &mut self, args: T, type_space: Option<(&Type, StateSpace)>, is_dst: bool, relaxed_type_check: bool, ) -> Result { (self)(args, type_space, is_dst, relaxed_type_check) } } trait VisitOperand { type Operand: Operand; #[allow(unused)] // Used by generated code fn visit(&self, fn_: impl FnMut(&Self::Operand) -> Result<(), Err>) -> Result<(), Err>; #[allow(unused)] // Used by generated code fn visit_mut( &mut self, fn_: impl FnMut(&mut Self::Operand) -> Result<(), Err>, ) -> Result<(), Err>; } impl VisitOperand for T { type Operand = Self; fn visit(&self, mut fn_: impl FnMut(&Self::Operand) -> Result<(), Err>) -> Result<(), Err> { fn_(self) } fn visit_mut( &mut self, mut fn_: impl FnMut(&mut Self::Operand) -> Result<(), Err>, ) -> Result<(), Err> { fn_(self) } } impl VisitOperand for Option { type Operand = T; fn visit(&self, mut fn_: impl FnMut(&Self::Operand) -> Result<(), Err>) -> Result<(), Err> { if let Some(x) = self { fn_(x)?; } Ok(()) } fn visit_mut( &mut self, mut fn_: impl FnMut(&mut 
Self::Operand) -> Result<(), Err>, ) -> Result<(), Err> { if let Some(x) = self { fn_(x)?; } Ok(()) } } impl VisitOperand for Vec { type Operand = T; fn visit(&self, mut fn_: impl FnMut(&Self::Operand) -> Result<(), Err>) -> Result<(), Err> { for o in self { fn_(o)?; } Ok(()) } fn visit_mut( &mut self, mut fn_: impl FnMut(&mut Self::Operand) -> Result<(), Err>, ) -> Result<(), Err> { for o in self { fn_(o)?; } Ok(()) } } trait MapOperand: Sized { type Input; type Output; #[allow(unused)] // Used by generated code fn map( self, fn_: impl FnOnce(Self::Input) -> Result, ) -> Result, Err>; } impl MapOperand for T { type Input = Self; type Output = U; fn map(self, fn_: impl FnOnce(T) -> Result) -> Result { fn_(self) } } impl MapOperand for Option { type Input = T; type Output = Option; fn map(self, fn_: impl FnOnce(T) -> Result) -> Result, Err> { self.map(|x| fn_(x)).transpose() } } pub struct MultiVariable { pub var: Variable, pub count: Option, } #[derive(Clone)] pub struct Variable { pub align: Option, pub v_type: Type, pub state_space: StateSpace, pub name: ID, pub array_init: Vec, } pub struct PredAt { pub not: bool, pub label: ID, } #[derive(PartialEq, Eq, Clone, Hash)] pub enum Type { // .param.b32 foo; Scalar(ScalarType), // .param.v2.b32 foo; Vector(u8, ScalarType), // .param.b32 foo[4]; Array(Option, ScalarType, Vec), Pointer(ScalarType, StateSpace), } impl Type { pub(crate) fn maybe_vector(vector: Option, scalar: ScalarType) -> Self { match vector { Some(prefix) => Type::Vector(prefix.len().get(), scalar), None => Type::Scalar(scalar), } } pub(crate) fn maybe_vector_parsed(prefix: Option, scalar: ScalarType) -> Self { match prefix { Some(prefix) => Type::Vector(prefix.get(), scalar), None => Type::Scalar(scalar), } } pub(crate) fn maybe_array( prefix: Option, scalar: ScalarType, array: Option>, ) -> Self { match array { Some(dimensions) => Type::Array(prefix, scalar, dimensions), None => Self::maybe_vector_parsed(prefix, scalar), } } } impl ScalarType { pub 
fn size_of(self) -> u8 { match self { ScalarType::U8 | ScalarType::S8 | ScalarType::B8 => 1, ScalarType::U16 | ScalarType::S16 | ScalarType::B16 | ScalarType::F16 | ScalarType::BF16 => 2, ScalarType::U32 | ScalarType::S32 | ScalarType::B32 | ScalarType::F32 | ScalarType::U16x2 | ScalarType::S16x2 | ScalarType::F16x2 | ScalarType::BF16x2 => 4, ScalarType::U64 | ScalarType::S64 | ScalarType::B64 | ScalarType::F64 => 8, ScalarType::B128 => 16, ScalarType::Pred => 1, } } pub fn kind(self) -> ScalarKind { match self { ScalarType::U8 => ScalarKind::Unsigned, ScalarType::U16 => ScalarKind::Unsigned, ScalarType::U16x2 => ScalarKind::Unsigned, ScalarType::U32 => ScalarKind::Unsigned, ScalarType::U64 => ScalarKind::Unsigned, ScalarType::S8 => ScalarKind::Signed, ScalarType::S16 => ScalarKind::Signed, ScalarType::S16x2 => ScalarKind::Signed, ScalarType::S32 => ScalarKind::Signed, ScalarType::S64 => ScalarKind::Signed, ScalarType::B8 => ScalarKind::Bit, ScalarType::B16 => ScalarKind::Bit, ScalarType::B32 => ScalarKind::Bit, ScalarType::B64 => ScalarKind::Bit, ScalarType::B128 => ScalarKind::Bit, ScalarType::F16 => ScalarKind::Float, ScalarType::F16x2 => ScalarKind::Float, ScalarType::F32 => ScalarKind::Float, ScalarType::F64 => ScalarKind::Float, ScalarType::BF16 => ScalarKind::Float, ScalarType::BF16x2 => ScalarKind::Float, ScalarType::Pred => ScalarKind::Pred, } } } #[derive(Clone, Copy, PartialEq, Eq)] pub enum ScalarKind { Bit, Unsigned, Signed, Float, Pred, } impl From for Type { fn from(value: ScalarType) -> Self { Type::Scalar(value) } } #[derive(Clone)] pub struct MovDetails { pub typ: super::Type, pub src_is_address: bool, // two fields below are in use by member moves pub dst_width: u8, pub src_width: u8, // This is in use by auto-generated movs pub relaxed_src2_conv: bool, } impl MovDetails { pub(crate) fn new(vector: Option, scalar: ScalarType) -> Self { MovDetails { typ: Type::maybe_vector(vector, scalar), src_is_address: false, dst_width: 0, src_width: 0, 
relaxed_src2_conv: false, } } } #[derive(Clone)] pub enum ParsedOperand { Reg(Ident), RegOffset(Ident, i32), Imm(ImmediateValue), VecMember(Ident, u8), VecPack(Vec), } impl Operand for ParsedOperand { type Ident = Ident; fn from_ident(ident: Self::Ident) -> Self { ParsedOperand::Reg(ident) } } pub trait Operand: Sized { type Ident: Copy; fn from_ident(ident: Self::Ident) -> Self; } #[derive(Copy, Clone)] pub enum ImmediateValue { U64(u64), S64(i64), F32(f32), F64(f64), } #[derive(Copy, Clone, PartialEq, Eq)] pub enum StCacheOperator { Writeback, L2Only, Streaming, Writethrough, } #[derive(Copy, Clone, PartialEq, Eq)] pub enum LdCacheOperator { Cached, L2Only, Streaming, LastUse, Uncached, } #[derive(Copy, Clone)] pub enum ArithDetails { Integer(ArithInteger), Float(ArithFloat), } impl ArithDetails { pub fn type_(&self) -> ScalarType { match self { ArithDetails::Integer(t) => t.type_, ArithDetails::Float(arith) => arith.type_, } } } #[derive(Copy, Clone)] pub struct ArithInteger { pub type_: ScalarType, pub saturate: bool, } #[derive(Copy, Clone)] pub struct ArithFloat { pub type_: ScalarType, pub rounding: Option, pub flush_to_zero: Option, pub saturate: bool, } #[derive(Copy, Clone, PartialEq, Eq)] pub enum LdStQualifier { Weak, Volatile, Relaxed(MemScope), Acquire(MemScope), Release(MemScope), } #[derive(PartialEq, Eq, Copy, Clone)] pub enum RoundingMode { NearestEven, Zero, NegativeInf, PositiveInf, } pub struct LdDetails { pub qualifier: LdStQualifier, pub state_space: StateSpace, pub caching: LdCacheOperator, pub typ: Type, pub non_coherent: bool, } pub struct StData { pub qualifier: LdStQualifier, pub state_space: StateSpace, pub caching: StCacheOperator, pub typ: Type, } #[derive(Copy, Clone)] pub struct RetData { pub uniform: bool, } #[derive(Copy, Clone, PartialEq, Eq)] pub enum TuningDirective { MaxNReg(u32), MaxNtid(u32, u32, u32), ReqNtid(u32, u32, u32), MinNCtaPerSm(u32), } pub struct MethodDeclaration<'input, ID> { pub return_arguments: Vec>, pub 
name: MethodName<'input, ID>, pub input_arguments: Vec>, pub shared_mem: Option, } impl<'input> MethodDeclaration<'input, &'input str> { pub fn name(&self) -> &'input str { match self.name { MethodName::Kernel(n) => n, MethodName::Func(n) => n, } } } #[derive(Hash, PartialEq, Eq, Copy, Clone)] pub enum MethodName<'input, ID> { Kernel(&'input str), Func(ID), } impl<'input, ID> MethodName<'input, ID> { pub fn is_kernel(&self) -> bool { match self { MethodName::Kernel(_) => true, MethodName::Func(_) => false, } } } impl<'input> MethodName<'input, &'input str> { pub fn text(&self) -> &'input str { match self { MethodName::Kernel(name) => *name, MethodName::Func(name) => *name, } } } bitflags! { pub struct LinkingDirective: u8 { const NONE = 0b000; const EXTERN = 0b001; const VISIBLE = 0b10; const WEAK = 0b100; } } pub struct Function<'a, ID, S> { pub func_directive: MethodDeclaration<'a, ID>, pub tuning: Vec, pub body: Option>, } pub enum Directive<'input, O: Operand> { Variable(LinkingDirective, Variable), Method( LinkingDirective, Function<'input, &'input str, Statement>, ), } pub struct Module<'input> { pub version: (u8, u8), pub directives: Vec>>, } #[derive(Copy, Clone)] pub enum MulDetails { Integer { type_: ScalarType, control: MulIntControl, }, Float(ArithFloat), } impl MulDetails { pub fn type_(&self) -> ScalarType { match self { MulDetails::Integer { type_, .. 
} => *type_, MulDetails::Float(arith) => arith.type_, } } pub fn dst_type(&self) -> ScalarType { match self { MulDetails::Integer { type_, control: MulIntControl::Wide, } => match type_ { ScalarType::U16 => ScalarType::U32, ScalarType::S16 => ScalarType::S32, ScalarType::U32 => ScalarType::U64, ScalarType::S32 => ScalarType::S64, _ => unreachable!(), }, _ => self.type_(), } } } #[derive(Copy, Clone, PartialEq, Eq)] pub enum MulIntControl { Low, High, Wide, } pub struct SetpData { pub type_: ScalarType, pub flush_to_zero: Option, pub cmp_op: SetpCompareOp, } impl SetpData { pub(crate) fn try_parse( state: &mut PtxParserState, cmp_op: super::RawSetpCompareOp, ftz: bool, type_: ScalarType, ) -> Self { let flush_to_zero = match (ftz, type_) { (_, ScalarType::F32) => Some(ftz), (true, _) => { state.errors.push(PtxError::NonF32Ftz); None } _ => None, }; let type_kind = type_.kind(); let cmp_op = if type_kind == ScalarKind::Float { SetpCompareOp::Float(SetpCompareFloat::from(cmp_op)) } else { match SetpCompareInt::try_from((cmp_op, type_kind)) { Ok(op) => SetpCompareOp::Integer(op), Err(err) => { state.errors.push(err); SetpCompareOp::Integer(SetpCompareInt::Eq) } } }; Self { type_, flush_to_zero, cmp_op, } } } pub struct SetpBoolData { pub base: SetpData, pub bool_op: SetpBoolPostOp, pub negate_src3: bool, } #[derive(PartialEq, Eq, Copy, Clone)] pub enum SetpCompareOp { Integer(SetpCompareInt), Float(SetpCompareFloat), } #[derive(PartialEq, Eq, Copy, Clone)] pub enum SetpCompareInt { Eq, NotEq, UnsignedLess, UnsignedLessOrEq, UnsignedGreater, UnsignedGreaterOrEq, SignedLess, SignedLessOrEq, SignedGreater, SignedGreaterOrEq, } #[derive(PartialEq, Eq, Copy, Clone)] pub enum SetpCompareFloat { Eq, NotEq, Less, LessOrEq, Greater, GreaterOrEq, NanEq, NanNotEq, NanLess, NanLessOrEq, NanGreater, NanGreaterOrEq, IsNotNan, IsAnyNan, } impl TryFrom<(RawSetpCompareOp, ScalarKind)> for SetpCompareInt { type Error = PtxError; fn try_from((value, kind): (RawSetpCompareOp, ScalarKind)) 
-> Result { match (value, kind) { (RawSetpCompareOp::Eq, _) => Ok(SetpCompareInt::Eq), (RawSetpCompareOp::Ne, _) => Ok(SetpCompareInt::NotEq), (RawSetpCompareOp::Lt | RawSetpCompareOp::Lo, ScalarKind::Signed) => { Ok(SetpCompareInt::SignedLess) } (RawSetpCompareOp::Lt | RawSetpCompareOp::Lo, _) => Ok(SetpCompareInt::UnsignedLess), (RawSetpCompareOp::Le | RawSetpCompareOp::Ls, ScalarKind::Signed) => { Ok(SetpCompareInt::SignedLessOrEq) } (RawSetpCompareOp::Le | RawSetpCompareOp::Ls, _) => { Ok(SetpCompareInt::UnsignedLessOrEq) } (RawSetpCompareOp::Gt | RawSetpCompareOp::Hi, ScalarKind::Signed) => { Ok(SetpCompareInt::SignedGreater) } (RawSetpCompareOp::Gt | RawSetpCompareOp::Hi, _) => Ok(SetpCompareInt::UnsignedGreater), (RawSetpCompareOp::Ge | RawSetpCompareOp::Hs, ScalarKind::Signed) => { Ok(SetpCompareInt::SignedGreaterOrEq) } (RawSetpCompareOp::Ge | RawSetpCompareOp::Hs, _) => { Ok(SetpCompareInt::UnsignedGreaterOrEq) } (RawSetpCompareOp::Equ, _) => Err(PtxError::WrongType), (RawSetpCompareOp::Neu, _) => Err(PtxError::WrongType), (RawSetpCompareOp::Ltu, _) => Err(PtxError::WrongType), (RawSetpCompareOp::Leu, _) => Err(PtxError::WrongType), (RawSetpCompareOp::Gtu, _) => Err(PtxError::WrongType), (RawSetpCompareOp::Geu, _) => Err(PtxError::WrongType), (RawSetpCompareOp::Num, _) => Err(PtxError::WrongType), (RawSetpCompareOp::Nan, _) => Err(PtxError::WrongType), } } } impl From for SetpCompareFloat { fn from(value: RawSetpCompareOp) -> Self { match value { RawSetpCompareOp::Eq => SetpCompareFloat::Eq, RawSetpCompareOp::Ne => SetpCompareFloat::NotEq, RawSetpCompareOp::Lt => SetpCompareFloat::Less, RawSetpCompareOp::Le => SetpCompareFloat::LessOrEq, RawSetpCompareOp::Gt => SetpCompareFloat::Greater, RawSetpCompareOp::Ge => SetpCompareFloat::GreaterOrEq, RawSetpCompareOp::Lo => SetpCompareFloat::Less, RawSetpCompareOp::Ls => SetpCompareFloat::LessOrEq, RawSetpCompareOp::Hi => SetpCompareFloat::Greater, RawSetpCompareOp::Hs => SetpCompareFloat::GreaterOrEq, 
RawSetpCompareOp::Equ => SetpCompareFloat::NanEq, RawSetpCompareOp::Neu => SetpCompareFloat::NanNotEq, RawSetpCompareOp::Ltu => SetpCompareFloat::NanLess, RawSetpCompareOp::Leu => SetpCompareFloat::NanLessOrEq, RawSetpCompareOp::Gtu => SetpCompareFloat::NanGreater, RawSetpCompareOp::Geu => SetpCompareFloat::NanGreaterOrEq, RawSetpCompareOp::Num => SetpCompareFloat::IsNotNan, RawSetpCompareOp::Nan => SetpCompareFloat::IsAnyNan, } } } pub struct CallDetails { pub uniform: bool, pub return_arguments: Vec<(Type, StateSpace)>, pub input_arguments: Vec<(Type, StateSpace)>, } pub struct CallArgs { pub return_arguments: Vec, pub func: T::Ident, pub input_arguments: Vec, } impl CallArgs { #[allow(dead_code)] // Used by generated code fn visit( &self, details: &CallDetails, visitor: &mut impl Visitor, ) -> Result<(), Err> { for (param, (type_, space)) in self .return_arguments .iter() .zip(details.return_arguments.iter()) { visitor.visit_ident( param, Some((type_, *space)), *space == StateSpace::Reg, false, )?; } visitor.visit_ident(&self.func, None, false, false)?; for (param, (type_, space)) in self .input_arguments .iter() .zip(details.input_arguments.iter()) { visitor.visit(param, Some((type_, *space)), false, false)?; } Ok(()) } #[allow(dead_code)] // Used by generated code fn visit_mut( &mut self, details: &CallDetails, visitor: &mut impl VisitorMut, ) -> Result<(), Err> { for (param, (type_, space)) in self .return_arguments .iter_mut() .zip(details.return_arguments.iter()) { visitor.visit_ident( param, Some((type_, *space)), *space == StateSpace::Reg, false, )?; } visitor.visit_ident(&mut self.func, None, false, false)?; for (param, (type_, space)) in self .input_arguments .iter_mut() .zip(details.input_arguments.iter()) { visitor.visit(param, Some((type_, *space)), false, false)?; } Ok(()) } #[allow(dead_code)] // Used by generated code fn map( self, details: &CallDetails, visitor: &mut impl VisitorMap, ) -> Result, Err> { let return_arguments = self 
.return_arguments .into_iter() .zip(details.return_arguments.iter()) .map(|(param, (type_, space))| { visitor.visit_ident( param, Some((type_, *space)), *space == StateSpace::Reg, false, ) }) .collect::, _>>()?; let func = visitor.visit_ident(self.func, None, false, false)?; let input_arguments = self .input_arguments .into_iter() .zip(details.input_arguments.iter()) .map(|(param, (type_, space))| { visitor.visit(param, Some((type_, *space)), false, false) }) .collect::, _>>()?; Ok(CallArgs { return_arguments, func, input_arguments, }) } } pub struct CvtDetails { pub from: ScalarType, pub to: ScalarType, pub mode: CvtMode, } pub enum CvtMode { // int from int ZeroExtend, SignExtend, Truncate, Bitcast, SaturateUnsignedToSigned, SaturateSignedToUnsigned, // float from float FPExtend { flush_to_zero: Option, }, FPTruncate { // float rounding rounding: RoundingMode, flush_to_zero: Option, }, FPRound { integer_rounding: Option, flush_to_zero: Option, }, // int from float SignedFromFP { rounding: RoundingMode, flush_to_zero: Option, }, // integer rounding UnsignedFromFP { rounding: RoundingMode, flush_to_zero: Option, }, // integer rounding // float from int, ftz is allowed in the grammar, but clearly nonsensical FPFromSigned(RoundingMode), // float rounding FPFromUnsigned(RoundingMode), // float rounding } impl CvtDetails { pub(crate) fn new( errors: &mut Vec, rnd: Option, ftz: bool, saturate: bool, dst: ScalarType, src: ScalarType, ) -> Self { if saturate && dst.kind() == ScalarKind::Float { errors.push(PtxError::SyntaxError); } // Modifier .ftz can only be specified when either .dtype or .atype is .f32 and applies only to single precision (.f32) inputs and results. 
let flush_to_zero = match (dst, src) { (ScalarType::F32, _) | (_, ScalarType::F32) => Some(ftz), _ => { if ftz { errors.push(PtxError::NonF32Ftz); } None } }; let rounding = rnd.map(Into::into); let mut unwrap_rounding = || match rounding { Some(rnd) => rnd, None => { errors.push(PtxError::SyntaxError); RoundingMode::NearestEven } }; let mode = match (dst.kind(), src.kind()) { (ScalarKind::Float, ScalarKind::Float) => match dst.size_of().cmp(&src.size_of()) { Ordering::Less => CvtMode::FPTruncate { rounding: unwrap_rounding(), flush_to_zero, }, Ordering::Equal => CvtMode::FPRound { integer_rounding: rounding, flush_to_zero, }, Ordering::Greater => { if rounding.is_some() { errors.push(PtxError::SyntaxError); } CvtMode::FPExtend { flush_to_zero } } }, (ScalarKind::Unsigned, ScalarKind::Float) => CvtMode::UnsignedFromFP { rounding: unwrap_rounding(), flush_to_zero, }, (ScalarKind::Signed, ScalarKind::Float) => CvtMode::SignedFromFP { rounding: unwrap_rounding(), flush_to_zero, }, (ScalarKind::Float, ScalarKind::Unsigned) => CvtMode::FPFromUnsigned(unwrap_rounding()), (ScalarKind::Float, ScalarKind::Signed) => CvtMode::FPFromSigned(unwrap_rounding()), (ScalarKind::Signed, ScalarKind::Unsigned) if saturate => { CvtMode::SaturateUnsignedToSigned } (ScalarKind::Unsigned, ScalarKind::Signed) if saturate => { CvtMode::SaturateSignedToUnsigned } (ScalarKind::Unsigned, ScalarKind::Signed) | (ScalarKind::Signed, ScalarKind::Unsigned) if dst.size_of() == src.size_of() => { CvtMode::Bitcast } (ScalarKind::Unsigned, ScalarKind::Unsigned) | (ScalarKind::Signed, ScalarKind::Signed) => match dst.size_of().cmp(&src.size_of()) { Ordering::Less => CvtMode::Truncate, Ordering::Equal => CvtMode::Bitcast, Ordering::Greater => { if src.kind() == ScalarKind::Signed { CvtMode::SignExtend } else { CvtMode::ZeroExtend } } }, (ScalarKind::Unsigned, ScalarKind::Signed) => CvtMode::SaturateSignedToUnsigned, (_, _) => { errors.push(PtxError::SyntaxError); CvtMode::Bitcast } }; CvtDetails { mode, 
to: dst, from: src, } } } pub struct CvtIntToIntDesc { pub dst: ScalarType, pub src: ScalarType, pub saturate: bool, } pub struct CvtDesc { pub rounding: Option, pub flush_to_zero: Option, pub saturate: bool, pub dst: ScalarType, pub src: ScalarType, } pub struct ShrData { pub type_: ScalarType, pub kind: RightShiftKind, } pub enum RightShiftKind { Arithmetic, Logical, } pub struct CvtaDetails { pub state_space: StateSpace, pub direction: CvtaDirection, } pub enum CvtaDirection { GenericToExplicit, ExplicitToGeneric, } #[derive(Copy, Clone, PartialEq, Eq)] pub struct TypeFtz { pub flush_to_zero: Option, pub type_: ScalarType, } #[derive(Copy, Clone)] pub enum MadDetails { Integer { control: MulIntControl, saturate: bool, type_: ScalarType, }, Float(ArithFloat), } impl MadDetails { pub fn dst_type(&self) -> ScalarType { match self { MadDetails::Integer { type_, control: MulIntControl::Wide, .. } => match type_ { ScalarType::U16 => ScalarType::U32, ScalarType::S16 => ScalarType::S32, ScalarType::U32 => ScalarType::U64, ScalarType::S32 => ScalarType::S64, _ => unreachable!(), }, _ => self.type_(), } } fn type_(&self) -> ScalarType { match self { MadDetails::Integer { type_, .. 
} => *type_, MadDetails::Float(arith) => arith.type_, } } } #[derive(Copy, Clone)] pub enum MinMaxDetails { Signed(ScalarType), Unsigned(ScalarType), Float(MinMaxFloat), } impl MinMaxDetails { pub fn type_(&self) -> ScalarType { match self { MinMaxDetails::Signed(t) => *t, MinMaxDetails::Unsigned(t) => *t, MinMaxDetails::Float(float) => float.type_, } } } #[derive(Copy, Clone)] pub struct MinMaxFloat { pub flush_to_zero: Option, pub nan: bool, pub type_: ScalarType, } #[derive(Copy, Clone)] pub struct RcpData { pub kind: RcpKind, pub flush_to_zero: Option, pub type_: ScalarType, } #[derive(Copy, Clone, Eq, PartialEq)] pub enum RcpKind { Approx, Compliant(RoundingMode), } pub struct BarData { pub aligned: bool, } pub struct AtomDetails { pub type_: Type, pub semantics: AtomSemantics, pub scope: MemScope, pub space: StateSpace, pub op: AtomicOp, } #[derive(Copy, Clone)] pub enum AtomicOp { And, Or, Xor, Exchange, Add, IncrementWrap, DecrementWrap, SignedMin, UnsignedMin, SignedMax, UnsignedMax, FloatAdd, FloatMin, FloatMax, } impl AtomicOp { pub(crate) fn new(op: super::RawAtomicOp, kind: ScalarKind) -> Self { use super::RawAtomicOp; match (op, kind) { (RawAtomicOp::And, _) => Self::And, (RawAtomicOp::Or, _) => Self::Or, (RawAtomicOp::Xor, _) => Self::Xor, (RawAtomicOp::Exch, _) => Self::Exchange, (RawAtomicOp::Add, ScalarKind::Float) => Self::FloatAdd, (RawAtomicOp::Add, _) => Self::Add, (RawAtomicOp::Inc, _) => Self::IncrementWrap, (RawAtomicOp::Dec, _) => Self::DecrementWrap, (RawAtomicOp::Min, ScalarKind::Signed) => Self::SignedMin, (RawAtomicOp::Min, ScalarKind::Float) => Self::FloatMin, (RawAtomicOp::Min, _) => Self::UnsignedMin, (RawAtomicOp::Max, ScalarKind::Signed) => Self::SignedMax, (RawAtomicOp::Max, ScalarKind::Float) => Self::FloatMax, (RawAtomicOp::Max, _) => Self::UnsignedMax, } } } pub struct AtomCasDetails { pub type_: ScalarType, pub semantics: AtomSemantics, pub scope: MemScope, pub space: StateSpace, } #[derive(Copy, Clone)] pub enum DivDetails { 
Unsigned(ScalarType), Signed(ScalarType), Float(DivFloatDetails), } impl DivDetails { pub fn type_(&self) -> ScalarType { match self { DivDetails::Unsigned(t) => *t, DivDetails::Signed(t) => *t, DivDetails::Float(float) => float.type_, } } } #[derive(Copy, Clone)] pub struct DivFloatDetails { pub type_: ScalarType, pub flush_to_zero: Option, pub kind: DivFloatKind, } #[derive(Copy, Clone, Eq, PartialEq)] pub enum DivFloatKind { Approx, ApproxFull, Rounding(RoundingMode), } #[derive(Copy, Clone, Eq, PartialEq)] pub struct FlushToZero { pub flush_to_zero: bool, }