diff options
Diffstat (limited to 'ptx_parser')
-rw-r--r-- | ptx_parser/Cargo.toml | 17 | ||||
-rw-r--r-- | ptx_parser/src/ast.rs | 1695 | ||||
-rw-r--r-- | ptx_parser/src/check_args.py | 69 | ||||
-rw-r--r-- | ptx_parser/src/lib.rs | 3269 |
4 files changed, 5050 insertions, 0 deletions
diff --git a/ptx_parser/Cargo.toml b/ptx_parser/Cargo.toml new file mode 100644 index 0000000..9032de5 --- /dev/null +++ b/ptx_parser/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "ptx_parser" +version = "0.0.0" +authors = ["Andrzej Janik <[email protected]>"] +edition = "2021" + +[lib] + +[dependencies] +logos = "0.14" +winnow = { version = "0.6.18" } +#winnow = { version = "0.6.18", features = ["debug"] } +ptx_parser_macros = { path = "../ptx_parser_macros" } +thiserror = "1.0" +bitflags = "1.2" +rustc-hash = "2.0.0" +derive_more = { version = "1", features = ["display"] } diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs new file mode 100644 index 0000000..d0dc303 --- /dev/null +++ b/ptx_parser/src/ast.rs @@ -0,0 +1,1695 @@ +use super::{
+ AtomSemantics, MemScope, RawRoundingMode, RawSetpCompareOp, ScalarType, SetpBoolPostOp,
+ StateSpace, VectorPrefix,
+};
+use crate::{PtxError, PtxParserState};
+use bitflags::bitflags;
+use std::{cmp::Ordering, num::NonZeroU8};
+
+pub enum Statement<P: Operand> {
+ Label(P::Ident),
+ Variable(MultiVariable<P::Ident>),
+ Instruction(Option<PredAt<P::Ident>>, Instruction<P>),
+ Block(Vec<Statement<P>>),
+}
+
+// We define the instruction enum through the macro instead of normally, because we have some of how
+// we use this type in the compilee. Each instruction can be logically split into two parts:
+// properties that define instruction semantics (e.g. is memory load volatile?) that don't change
+// during compilation and arguments (e.g. memory load source and destination) that evolve during
+// compilation. To support compilation passes we need to be able to visit (and change) every
+// argument in a generic way. This macro has visibility over all the fields. Consequently, we use it
+// to generate visitor functions. There re three functions to support three different semantics:
+// visit-by-ref, visit-by-mutable-ref, visit-and-map. In a previous version of the compiler it was
+// done by hand and was very limiting (we supported only visit-and-map).
+// The visitor must implement appropriate visitor trait defined below this macro. For convenience,
+// we implemented visitors for some corresponding FnMut(...) types.
+// Properties in this macro are used to encode information about the instruction arguments (what
+// Rust type is used for it post-parsing, what PTX type does it expect, what PTX address space does
+// it expect, etc.).
+// This information is then available to a visitor.
+ptx_parser_macros::generate_instruction_type!(
+ pub enum Instruction<T: Operand> {
+ Mov {
+ type: { &data.typ },
+ data: MovDetails,
+ arguments<T>: {
+ dst: T,
+ src: T
+ }
+ },
+ Ld {
+ type: { &data.typ },
+ data: LdDetails,
+ arguments<T>: {
+ dst: {
+ repr: T,
+ relaxed_type_check: true,
+ },
+ src: {
+ repr: T,
+ space: { data.state_space },
+ }
+ }
+ },
+ Add {
+ type: { Type::from(data.type_()) },
+ data: ArithDetails,
+ arguments<T>: {
+ dst: T,
+ src1: T,
+ src2: T,
+ }
+ },
+ St {
+ type: { &data.typ },
+ data: StData,
+ arguments<T>: {
+ src1: {
+ repr: T,
+ space: { data.state_space },
+ },
+ src2: {
+ repr: T,
+ relaxed_type_check: true,
+ }
+ }
+ },
+ Mul {
+ type: { Type::from(data.type_()) },
+ data: MulDetails,
+ arguments<T>: {
+ dst: {
+ repr: T,
+ type: { Type::from(data.dst_type()) },
+ },
+ src1: T,
+ src2: T,
+ }
+ },
+ Setp {
+ data: SetpData,
+ arguments<T>: {
+ dst1: {
+ repr: T,
+ type: Type::from(ScalarType::Pred)
+ },
+ dst2: {
+ repr: Option<T>,
+ type: Type::from(ScalarType::Pred)
+ },
+ src1: {
+ repr: T,
+ type: Type::from(data.type_),
+ },
+ src2: {
+ repr: T,
+ type: Type::from(data.type_),
+ }
+ }
+ },
+ SetpBool {
+ data: SetpBoolData,
+ arguments<T>: {
+ dst1: {
+ repr: T,
+ type: Type::from(ScalarType::Pred)
+ },
+ dst2: {
+ repr: Option<T>,
+ type: Type::from(ScalarType::Pred)
+ },
+ src1: {
+ repr: T,
+ type: Type::from(data.base.type_),
+ },
+ src2: {
+ repr: T,
+ type: Type::from(data.base.type_),
+ },
+ src3: {
+ repr: T,
+ type: Type::from(ScalarType::Pred)
+ }
+ }
+ },
+ Not {
+ data: ScalarType,
+ type: { Type::Scalar(data.clone()) },
+ arguments<T>: {
+ dst: T,
+ src: T,
+ }
+ },
+ Or {
+ data: ScalarType,
+ type: { Type::Scalar(data.clone()) },
+ arguments<T>: {
+ dst: T,
+ src1: T,
+ src2: T,
+ }
+ },
+ And {
+ data: ScalarType,
+ type: { Type::Scalar(data.clone()) },
+ arguments<T>: {
+ dst: T,
+ src1: T,
+ src2: T,
+ }
+ },
+ Bra {
+ type: !,
+ arguments<T::Ident>: {
+ src: T
+ }
+ },
+ Call {
+ data: CallDetails,
+ arguments: CallArgs<T>,
+ visit: arguments.visit(data, visitor)?,
+ visit_mut: arguments.visit_mut(data, visitor)?,
+ map: Instruction::Call{ arguments: arguments.map(&data, visitor)?, data }
+ },
+ Cvt {
+ data: CvtDetails,
+ arguments<T>: {
+ dst: {
+ repr: T,
+ type: { Type::Scalar(data.to) },
+ // TODO: double check
+ relaxed_type_check: true,
+ },
+ src: {
+ repr: T,
+ type: { Type::Scalar(data.from) },
+ relaxed_type_check: true,
+ },
+ }
+ },
+ Shr {
+ data: ShrData,
+ type: { Type::Scalar(data.type_.clone()) },
+ arguments<T>: {
+ dst: T,
+ src1: T,
+ src2: {
+ repr: T,
+ type: { Type::Scalar(ScalarType::U32) },
+ },
+ }
+ },
+ Shl {
+ data: ScalarType,
+ type: { Type::Scalar(data.clone()) },
+ arguments<T>: {
+ dst: T,
+ src1: T,
+ src2: {
+ repr: T,
+ type: { Type::Scalar(ScalarType::U32) },
+ },
+ }
+ },
+ Ret {
+ data: RetData
+ },
+ Cvta {
+ data: CvtaDetails,
+ type: { Type::Scalar(ScalarType::B64) },
+ arguments<T>: {
+ dst: T,
+ src: T,
+ }
+ },
+ Abs {
+ data: TypeFtz,
+ type: { Type::Scalar(data.type_) },
+ arguments<T>: {
+ dst: T,
+ src: T,
+ }
+ },
+ Mad {
+ type: { Type::from(data.type_()) },
+ data: MadDetails,
+ arguments<T>: {
+ dst: {
+ repr: T,
+ type: { Type::from(data.dst_type()) },
+ },
+ src1: T,
+ src2: T,
+ src3: T,
+ }
+ },
+ Fma {
+ type: { Type::from(data.type_) },
+ data: ArithFloat,
+ arguments<T>: {
+ dst: T,
+ src1: T,
+ src2: T,
+ src3: T,
+ }
+ },
+ Sub {
+ type: { Type::from(data.type_()) },
+ data: ArithDetails,
+ arguments<T>: {
+ dst: T,
+ src1: T,
+ src2: T,
+ }
+ },
+ Min {
+ type: { Type::from(data.type_()) },
+ data: MinMaxDetails,
+ arguments<T>: {
+ dst: T,
+ src1: T,
+ src2: T,
+ }
+ },
+ Max {
+ type: { Type::from(data.type_()) },
+ data: MinMaxDetails,
+ arguments<T>: {
+ dst: T,
+ src1: T,
+ src2: T,
+ }
+ },
+ Rcp {
+ type: { Type::from(data.type_) },
+ data: RcpData,
+ arguments<T>: {
+ dst: T,
+ src: T,
+ }
+ },
+ Sqrt {
+ type: { Type::from(data.type_) },
+ data: RcpData,
+ arguments<T>: {
+ dst: T,
+ src: T,
+ }
+ },
+ Rsqrt {
+ type: { Type::from(data.type_) },
+ data: TypeFtz,
+ arguments<T>: {
+ dst: T,
+ src: T,
+ }
+ },
+ Selp {
+ type: { Type::Scalar(data.clone()) },
+ data: ScalarType,
+ arguments<T>: {
+ dst: T,
+ src1: T,
+ src2: T,
+ src3: {
+ repr: T,
+ type: Type::Scalar(ScalarType::Pred)
+ },
+ }
+ },
+ Bar {
+ type: Type::Scalar(ScalarType::U32),
+ data: BarData,
+ arguments<T>: {
+ src1: T,
+ src2: Option<T>,
+ }
+ },
+ Atom {
+ type: &data.type_,
+ data: AtomDetails,
+ arguments<T>: {
+ dst: T,
+ src1: {
+ repr: T,
+ space: { data.space },
+ },
+ src2: T,
+ }
+ },
+ AtomCas {
+ type: Type::Scalar(data.type_),
+ data: AtomCasDetails,
+ arguments<T>: {
+ dst: T,
+ src1: {
+ repr: T,
+ space: { data.space },
+ },
+ src2: T,
+ src3: T,
+ }
+ },
+ Div {
+ type: Type::Scalar(data.type_()),
+ data: DivDetails,
+ arguments<T>: {
+ dst: T,
+ src1: T,
+ src2: T,
+ }
+ },
+ Neg {
+ type: Type::Scalar(data.type_),
+ data: TypeFtz,
+ arguments<T>: {
+ dst: T,
+ src: T
+ }
+ },
+ Sin {
+ type: Type::Scalar(ScalarType::F32),
+ data: FlushToZero,
+ arguments<T>: {
+ dst: T,
+ src: T
+ }
+ },
+ Cos {
+ type: Type::Scalar(ScalarType::F32),
+ data: FlushToZero,
+ arguments<T>: {
+ dst: T,
+ src: T
+ }
+ },
+ Lg2 {
+ type: Type::Scalar(ScalarType::F32),
+ data: FlushToZero,
+ arguments<T>: {
+ dst: T,
+ src: T
+ }
+ },
+ Ex2 {
+ type: Type::Scalar(ScalarType::F32),
+ data: TypeFtz,
+ arguments<T>: {
+ dst: T,
+ src: T
+ }
+ },
+ Clz {
+ type: Type::Scalar(data.clone()),
+ data: ScalarType,
+ arguments<T>: {
+ dst: {
+ repr: T,
+ type: Type::Scalar(ScalarType::U32)
+ },
+ src: T
+ }
+ },
+ Brev {
+ type: Type::Scalar(data.clone()),
+ data: ScalarType,
+ arguments<T>: {
+ dst: T,
+ src: T
+ }
+ },
+ Popc {
+ type: Type::Scalar(data.clone()),
+ data: ScalarType,
+ arguments<T>: {
+ dst: {
+ repr: T,
+ type: Type::Scalar(ScalarType::U32)
+ },
+ src: T
+ }
+ },
+ Xor {
+ type: Type::Scalar(data.clone()),
+ data: ScalarType,
+ arguments<T>: {
+ dst: T,
+ src1: T,
+ src2: T
+ }
+ },
+ Rem {
+ type: Type::Scalar(data.clone()),
+ data: ScalarType,
+ arguments<T>: {
+ dst: T,
+ src1: T,
+ src2: T
+ }
+ },
+ Bfe {
+ type: Type::Scalar(data.clone()),
+ data: ScalarType,
+ arguments<T>: {
+ dst: T,
+ src1: T,
+ src2: {
+ repr: T,
+ type: Type::Scalar(ScalarType::U32)
+ },
+ src3: {
+ repr: T,
+ type: Type::Scalar(ScalarType::U32)
+ },
+ }
+ },
+ Bfi {
+ type: Type::Scalar(data.clone()),
+ data: ScalarType,
+ arguments<T>: {
+ dst: T,
+ src1: T,
+ src2: T,
+ src3: {
+ repr: T,
+ type: Type::Scalar(ScalarType::U32)
+ },
+ src4: {
+ repr: T,
+ type: Type::Scalar(ScalarType::U32)
+ },
+ }
+ },
+ PrmtSlow {
+ type: Type::Scalar(ScalarType::U32),
+ arguments<T>: {
+ dst: T,
+ src1: T,
+ src2: T,
+ src3: T
+ }
+ },
+ Prmt {
+ type: Type::Scalar(ScalarType::B32),
+ data: u16,
+ arguments<T>: {
+ dst: T,
+ src1: T,
+ src2: T
+ }
+ },
+ Activemask {
+ type: Type::Scalar(ScalarType::B32),
+ arguments<T>: {
+ dst: T
+ }
+ },
+ Membar {
+ data: MemScope
+ },
+ Trap { }
+ }
+);
+
+pub trait Visitor<T: Operand, Err> {
+ fn visit(
+ &mut self,
+ args: &T,
+ type_space: Option<(&Type, StateSpace)>,
+ is_dst: bool,
+ relaxed_type_check: bool,
+ ) -> Result<(), Err>;
+ fn visit_ident(
+ &mut self,
+ args: &T::Ident,
+ type_space: Option<(&Type, StateSpace)>,
+ is_dst: bool,
+ relaxed_type_check: bool,
+ ) -> Result<(), Err>;
+}
+
+impl<
+ T: Operand,
+ Err,
+ Fn: FnMut(&T, Option<(&Type, StateSpace)>, bool, bool) -> Result<(), Err>,
+ > Visitor<T, Err> for Fn
+{
+ fn visit(
+ &mut self,
+ args: &T,
+ type_space: Option<(&Type, StateSpace)>,
+ is_dst: bool,
+ relaxed_type_check: bool,
+ ) -> Result<(), Err> {
+ (self)(args, type_space, is_dst, relaxed_type_check)
+ }
+
+ fn visit_ident(
+ &mut self,
+ args: &T::Ident,
+ type_space: Option<(&Type, StateSpace)>,
+ is_dst: bool,
+ relaxed_type_check: bool,
+ ) -> Result<(), Err> {
+ (self)(
+ &T::from_ident(*args),
+ type_space,
+ is_dst,
+ relaxed_type_check,
+ )
+ }
+}
+
+pub trait VisitorMut<T: Operand, Err> {
+ fn visit(
+ &mut self,
+ args: &mut T,
+ type_space: Option<(&Type, StateSpace)>,
+ is_dst: bool,
+ relaxed_type_check: bool,
+ ) -> Result<(), Err>;
+ fn visit_ident(
+ &mut self,
+ args: &mut T::Ident,
+ type_space: Option<(&Type, StateSpace)>,
+ is_dst: bool,
+ relaxed_type_check: bool,
+ ) -> Result<(), Err>;
+}
+
+pub trait VisitorMap<From: Operand, To: Operand, Err> {
+ fn visit(
+ &mut self,
+ args: From,
+ type_space: Option<(&Type, StateSpace)>,
+ is_dst: bool,
+ relaxed_type_check: bool,
+ ) -> Result<To, Err>;
+ fn visit_ident(
+ &mut self,
+ args: From::Ident,
+ type_space: Option<(&Type, StateSpace)>,
+ is_dst: bool,
+ relaxed_type_check: bool,
+ ) -> Result<To::Ident, Err>;
+}
+
+impl<T: Copy, U: Copy, Err, Fn> VisitorMap<ParsedOperand<T>, ParsedOperand<U>, Err> for Fn
+where
+ Fn: FnMut(T, Option<(&Type, StateSpace)>, bool, bool) -> Result<U, Err>,
+{
+ fn visit(
+ &mut self,
+ args: ParsedOperand<T>,
+ type_space: Option<(&Type, StateSpace)>,
+ is_dst: bool,
+ relaxed_type_check: bool,
+ ) -> Result<ParsedOperand<U>, Err> {
+ Ok(match args {
+ ParsedOperand::Reg(ident) => {
+ ParsedOperand::Reg((self)(ident, type_space, is_dst, relaxed_type_check)?)
+ }
+ ParsedOperand::RegOffset(ident, imm) => ParsedOperand::RegOffset(
+ (self)(ident, type_space, is_dst, relaxed_type_check)?,
+ imm,
+ ),
+ ParsedOperand::Imm(imm) => ParsedOperand::Imm(imm),
+ ParsedOperand::VecMember(ident, index) => ParsedOperand::VecMember(
+ (self)(ident, type_space, is_dst, relaxed_type_check)?,
+ index,
+ ),
+ ParsedOperand::VecPack(vec) => ParsedOperand::VecPack(
+ vec.into_iter()
+ .map(|ident| (self)(ident, type_space, is_dst, relaxed_type_check))
+ .collect::<Result<Vec<_>, _>>()?,
+ ),
+ })
+ }
+
+ fn visit_ident(
+ &mut self,
+ args: T,
+ type_space: Option<(&Type, StateSpace)>,
+ is_dst: bool,
+ relaxed_type_check: bool,
+ ) -> Result<U, Err> {
+ (self)(args, type_space, is_dst, relaxed_type_check)
+ }
+}
+
+impl<T: Operand<Ident = T>, U: Operand<Ident = U>, Err, Fn> VisitorMap<T, U, Err> for Fn
+where
+ Fn: FnMut(T, Option<(&Type, StateSpace)>, bool, bool) -> Result<U, Err>,
+{
+ fn visit(
+ &mut self,
+ args: T,
+ type_space: Option<(&Type, StateSpace)>,
+ is_dst: bool,
+ relaxed_type_check: bool,
+ ) -> Result<U, Err> {
+ (self)(args, type_space, is_dst, relaxed_type_check)
+ }
+
+ fn visit_ident(
+ &mut self,
+ args: T,
+ type_space: Option<(&Type, StateSpace)>,
+ is_dst: bool,
+ relaxed_type_check: bool,
+ ) -> Result<U, Err> {
+ (self)(args, type_space, is_dst, relaxed_type_check)
+ }
+}
+
+trait VisitOperand<Err> {
+ type Operand: Operand;
+ #[allow(unused)] // Used by generated code
+ fn visit(&self, fn_: impl FnMut(&Self::Operand) -> Result<(), Err>) -> Result<(), Err>;
+ #[allow(unused)] // Used by generated code
+ fn visit_mut(
+ &mut self,
+ fn_: impl FnMut(&mut Self::Operand) -> Result<(), Err>,
+ ) -> Result<(), Err>;
+}
+
+impl<T: Operand, Err> VisitOperand<Err> for T {
+ type Operand = Self;
+ fn visit(&self, mut fn_: impl FnMut(&Self::Operand) -> Result<(), Err>) -> Result<(), Err> {
+ fn_(self)
+ }
+ fn visit_mut(
+ &mut self,
+ mut fn_: impl FnMut(&mut Self::Operand) -> Result<(), Err>,
+ ) -> Result<(), Err> {
+ fn_(self)
+ }
+}
+
+impl<T: Operand, Err> VisitOperand<Err> for Option<T> {
+ type Operand = T;
+ fn visit(&self, mut fn_: impl FnMut(&Self::Operand) -> Result<(), Err>) -> Result<(), Err> {
+ if let Some(x) = self {
+ fn_(x)?;
+ }
+ Ok(())
+ }
+ fn visit_mut(
+ &mut self,
+ mut fn_: impl FnMut(&mut Self::Operand) -> Result<(), Err>,
+ ) -> Result<(), Err> {
+ if let Some(x) = self {
+ fn_(x)?;
+ }
+ Ok(())
+ }
+}
+
+impl<T: Operand, Err> VisitOperand<Err> for Vec<T> {
+ type Operand = T;
+ fn visit(&self, mut fn_: impl FnMut(&Self::Operand) -> Result<(), Err>) -> Result<(), Err> {
+ for o in self {
+ fn_(o)?;
+ }
+ Ok(())
+ }
+ fn visit_mut(
+ &mut self,
+ mut fn_: impl FnMut(&mut Self::Operand) -> Result<(), Err>,
+ ) -> Result<(), Err> {
+ for o in self {
+ fn_(o)?;
+ }
+ Ok(())
+ }
+}
+
+trait MapOperand<Err>: Sized {
+ type Input;
+ type Output<U>;
+ #[allow(unused)] // Used by generated code
+ fn map<U>(
+ self,
+ fn_: impl FnOnce(Self::Input) -> Result<U, Err>,
+ ) -> Result<Self::Output<U>, Err>;
+}
+
+impl<T: Operand, Err> MapOperand<Err> for T {
+ type Input = Self;
+ type Output<U> = U;
+ fn map<U>(self, fn_: impl FnOnce(T) -> Result<U, Err>) -> Result<U, Err> {
+ fn_(self)
+ }
+}
+
+impl<T: Operand, Err> MapOperand<Err> for Option<T> {
+ type Input = T;
+ type Output<U> = Option<U>;
+ fn map<U>(self, fn_: impl FnOnce(T) -> Result<U, Err>) -> Result<Option<U>, Err> {
+ self.map(|x| fn_(x)).transpose()
+ }
+}
+
+pub struct MultiVariable<ID> {
+ pub var: Variable<ID>,
+ pub count: Option<u32>,
+}
+
+#[derive(Clone)]
+pub struct Variable<ID> {
+ pub align: Option<u32>,
+ pub v_type: Type,
+ pub state_space: StateSpace,
+ pub name: ID,
+ pub array_init: Vec<u8>,
+}
+
+pub struct PredAt<ID> {
+ pub not: bool,
+ pub label: ID,
+}
+
+#[derive(PartialEq, Eq, Clone, Hash)]
+pub enum Type {
+ // .param.b32 foo;
+ Scalar(ScalarType),
+ // .param.v2.b32 foo;
+ Vector(u8, ScalarType),
+ // .param.b32 foo[4];
+ Array(Option<NonZeroU8>, ScalarType, Vec<u32>),
+ Pointer(ScalarType, StateSpace),
+}
+
+impl Type {
+ pub(crate) fn maybe_vector(vector: Option<VectorPrefix>, scalar: ScalarType) -> Self {
+ match vector {
+ Some(prefix) => Type::Vector(prefix.len().get(), scalar),
+ None => Type::Scalar(scalar),
+ }
+ }
+
+ pub(crate) fn maybe_vector_parsed(prefix: Option<NonZeroU8>, scalar: ScalarType) -> Self {
+ match prefix {
+ Some(prefix) => Type::Vector(prefix.get(), scalar),
+ None => Type::Scalar(scalar),
+ }
+ }
+
+ pub(crate) fn maybe_array(
+ prefix: Option<NonZeroU8>,
+ scalar: ScalarType,
+ array: Option<Vec<u32>>,
+ ) -> Self {
+ match array {
+ Some(dimensions) => Type::Array(prefix, scalar, dimensions),
+ None => Self::maybe_vector_parsed(prefix, scalar),
+ }
+ }
+}
+
+impl ScalarType {
+ pub fn size_of(self) -> u8 {
+ match self {
+ ScalarType::U8 | ScalarType::S8 | ScalarType::B8 => 1,
+ ScalarType::U16
+ | ScalarType::S16
+ | ScalarType::B16
+ | ScalarType::F16
+ | ScalarType::BF16 => 2,
+ ScalarType::U32
+ | ScalarType::S32
+ | ScalarType::B32
+ | ScalarType::F32
+ | ScalarType::U16x2
+ | ScalarType::S16x2
+ | ScalarType::F16x2
+ | ScalarType::BF16x2 => 4,
+ ScalarType::U64 | ScalarType::S64 | ScalarType::B64 | ScalarType::F64 => 8,
+ ScalarType::B128 => 16,
+ ScalarType::Pred => 1,
+ }
+ }
+
+ pub fn kind(self) -> ScalarKind {
+ match self {
+ ScalarType::U8 => ScalarKind::Unsigned,
+ ScalarType::U16 => ScalarKind::Unsigned,
+ ScalarType::U16x2 => ScalarKind::Unsigned,
+ ScalarType::U32 => ScalarKind::Unsigned,
+ ScalarType::U64 => ScalarKind::Unsigned,
+ ScalarType::S8 => ScalarKind::Signed,
+ ScalarType::S16 => ScalarKind::Signed,
+ ScalarType::S16x2 => ScalarKind::Signed,
+ ScalarType::S32 => ScalarKind::Signed,
+ ScalarType::S64 => ScalarKind::Signed,
+ ScalarType::B8 => ScalarKind::Bit,
+ ScalarType::B16 => ScalarKind::Bit,
+ ScalarType::B32 => ScalarKind::Bit,
+ ScalarType::B64 => ScalarKind::Bit,
+ ScalarType::B128 => ScalarKind::Bit,
+ ScalarType::F16 => ScalarKind::Float,
+ ScalarType::F16x2 => ScalarKind::Float,
+ ScalarType::F32 => ScalarKind::Float,
+ ScalarType::F64 => ScalarKind::Float,
+ ScalarType::BF16 => ScalarKind::Float,
+ ScalarType::BF16x2 => ScalarKind::Float,
+ ScalarType::Pred => ScalarKind::Pred,
+ }
+ }
+}
+
+#[derive(Clone, Copy, PartialEq, Eq)]
+pub enum ScalarKind {
+ Bit,
+ Unsigned,
+ Signed,
+ Float,
+ Pred,
+}
+impl From<ScalarType> for Type {
+ fn from(value: ScalarType) -> Self {
+ Type::Scalar(value)
+ }
+}
+
+#[derive(Clone)]
+pub struct MovDetails {
+ pub typ: super::Type,
+ pub src_is_address: bool,
+ // two fields below are in use by member moves
+ pub dst_width: u8,
+ pub src_width: u8,
+ // This is in use by auto-generated movs
+ pub relaxed_src2_conv: bool,
+}
+
+impl MovDetails {
+ pub(crate) fn new(vector: Option<VectorPrefix>, scalar: ScalarType) -> Self {
+ MovDetails {
+ typ: Type::maybe_vector(vector, scalar),
+ src_is_address: false,
+ dst_width: 0,
+ src_width: 0,
+ relaxed_src2_conv: false,
+ }
+ }
+}
+
+#[derive(Clone)]
+pub enum ParsedOperand<Ident> {
+ Reg(Ident),
+ RegOffset(Ident, i32),
+ Imm(ImmediateValue),
+ VecMember(Ident, u8),
+ VecPack(Vec<Ident>),
+}
+
+impl<Ident: Copy> Operand for ParsedOperand<Ident> {
+ type Ident = Ident;
+
+ fn from_ident(ident: Self::Ident) -> Self {
+ ParsedOperand::Reg(ident)
+ }
+}
+
+pub trait Operand: Sized {
+ type Ident: Copy;
+
+ fn from_ident(ident: Self::Ident) -> Self;
+}
+
+#[derive(Copy, Clone)]
+pub enum ImmediateValue {
+ U64(u64),
+ S64(i64),
+ F32(f32),
+ F64(f64),
+}
+
+#[derive(Copy, Clone, PartialEq, Eq)]
+pub enum StCacheOperator {
+ Writeback,
+ L2Only,
+ Streaming,
+ Writethrough,
+}
+
+#[derive(Copy, Clone, PartialEq, Eq)]
+pub enum LdCacheOperator {
+ Cached,
+ L2Only,
+ Streaming,
+ LastUse,
+ Uncached,
+}
+
+#[derive(Copy, Clone)]
+pub enum ArithDetails {
+ Integer(ArithInteger),
+ Float(ArithFloat),
+}
+
+impl ArithDetails {
+ pub fn type_(&self) -> ScalarType {
+ match self {
+ ArithDetails::Integer(t) => t.type_,
+ ArithDetails::Float(arith) => arith.type_,
+ }
+ }
+}
+
+#[derive(Copy, Clone)]
+pub struct ArithInteger {
+ pub type_: ScalarType,
+ pub saturate: bool,
+}
+
+#[derive(Copy, Clone)]
+pub struct ArithFloat {
+ pub type_: ScalarType,
+ pub rounding: Option<RoundingMode>,
+ pub flush_to_zero: Option<bool>,
+ pub saturate: bool,
+}
+
+#[derive(Copy, Clone, PartialEq, Eq)]
+pub enum LdStQualifier {
+ Weak,
+ Volatile,
+ Relaxed(MemScope),
+ Acquire(MemScope),
+ Release(MemScope),
+}
+
+#[derive(PartialEq, Eq, Copy, Clone)]
+pub enum RoundingMode {
+ NearestEven,
+ Zero,
+ NegativeInf,
+ PositiveInf,
+}
+
+pub struct LdDetails {
+ pub qualifier: LdStQualifier,
+ pub state_space: StateSpace,
+ pub caching: LdCacheOperator,
+ pub typ: Type,
+ pub non_coherent: bool,
+}
+
+pub struct StData {
+ pub qualifier: LdStQualifier,
+ pub state_space: StateSpace,
+ pub caching: StCacheOperator,
+ pub typ: Type,
+}
+
+#[derive(Copy, Clone)]
+pub struct RetData {
+ pub uniform: bool,
+}
+
+#[derive(Copy, Clone, PartialEq, Eq)]
+pub enum TuningDirective {
+ MaxNReg(u32),
+ MaxNtid(u32, u32, u32),
+ ReqNtid(u32, u32, u32),
+ MinNCtaPerSm(u32),
+}
+
+pub struct MethodDeclaration<'input, ID> {
+ pub return_arguments: Vec<Variable<ID>>,
+ pub name: MethodName<'input, ID>,
+ pub input_arguments: Vec<Variable<ID>>,
+ pub shared_mem: Option<ID>,
+}
+
+impl<'input> MethodDeclaration<'input, &'input str> {
+ pub fn name(&self) -> &'input str {
+ match self.name {
+ MethodName::Kernel(n) => n,
+ MethodName::Func(n) => n,
+ }
+ }
+}
+
+#[derive(Hash, PartialEq, Eq, Copy, Clone)]
+pub enum MethodName<'input, ID> {
+ Kernel(&'input str),
+ Func(ID),
+}
+
+bitflags! {
+ pub struct LinkingDirective: u8 {
+ const NONE = 0b000;
+ const EXTERN = 0b001;
+ const VISIBLE = 0b10;
+ const WEAK = 0b100;
+ }
+}
+
+pub struct Function<'a, ID, S> {
+ pub func_directive: MethodDeclaration<'a, ID>,
+ pub tuning: Vec<TuningDirective>,
+ pub body: Option<Vec<S>>,
+}
+
+pub enum Directive<'input, O: Operand> {
+ Variable(LinkingDirective, Variable<O::Ident>),
+ Method(
+ LinkingDirective,
+ Function<'input, &'input str, Statement<O>>,
+ ),
+}
+
+pub struct Module<'input> {
+ pub version: (u8, u8),
+ pub directives: Vec<Directive<'input, ParsedOperand<&'input str>>>,
+}
+
+#[derive(Copy, Clone)]
+pub enum MulDetails {
+ Integer {
+ type_: ScalarType,
+ control: MulIntControl,
+ },
+ Float(ArithFloat),
+}
+
+impl MulDetails {
+ pub fn type_(&self) -> ScalarType {
+ match self {
+ MulDetails::Integer { type_, .. } => *type_,
+ MulDetails::Float(arith) => arith.type_,
+ }
+ }
+
+ pub fn dst_type(&self) -> ScalarType {
+ match self {
+ MulDetails::Integer {
+ type_,
+ control: MulIntControl::Wide,
+ } => match type_ {
+ ScalarType::U16 => ScalarType::U32,
+ ScalarType::S16 => ScalarType::S32,
+ ScalarType::U32 => ScalarType::U64,
+ ScalarType::S32 => ScalarType::S64,
+ _ => unreachable!(),
+ },
+ _ => self.type_(),
+ }
+ }
+}
+
+#[derive(Copy, Clone, PartialEq, Eq)]
+pub enum MulIntControl {
+ Low,
+ High,
+ Wide,
+}
+
+pub struct SetpData {
+ pub type_: ScalarType,
+ pub flush_to_zero: Option<bool>,
+ pub cmp_op: SetpCompareOp,
+}
+
+impl SetpData {
+ pub(crate) fn try_parse(
+ state: &mut PtxParserState,
+ cmp_op: super::RawSetpCompareOp,
+ ftz: bool,
+ type_: ScalarType,
+ ) -> Self {
+ let flush_to_zero = match (ftz, type_) {
+ (_, ScalarType::F32) => Some(ftz),
+ (true, _) => {
+ state.errors.push(PtxError::NonF32Ftz);
+ None
+ }
+ _ => None
+ };
+ let type_kind = type_.kind();
+ let cmp_op = if type_kind == ScalarKind::Float {
+ SetpCompareOp::Float(SetpCompareFloat::from(cmp_op))
+ } else {
+ match SetpCompareInt::try_from((cmp_op, type_kind)) {
+ Ok(op) => SetpCompareOp::Integer(op),
+ Err(err) => {
+ state.errors.push(err);
+ SetpCompareOp::Integer(SetpCompareInt::Eq)
+ }
+ }
+ };
+ Self {
+ type_,
+ flush_to_zero,
+ cmp_op,
+ }
+ }
+}
+
+pub struct SetpBoolData {
+ pub base: SetpData,
+ pub bool_op: SetpBoolPostOp,
+ pub negate_src3: bool,
+}
+
+#[derive(PartialEq, Eq, Copy, Clone)]
+pub enum SetpCompareOp {
+ Integer(SetpCompareInt),
+ Float(SetpCompareFloat),
+}
+
+#[derive(PartialEq, Eq, Copy, Clone)]
+pub enum SetpCompareInt {
+ Eq,
+ NotEq,
+ UnsignedLess,
+ UnsignedLessOrEq,
+ UnsignedGreater,
+ UnsignedGreaterOrEq,
+ SignedLess,
+ SignedLessOrEq,
+ SignedGreater,
+ SignedGreaterOrEq,
+}
+
+#[derive(PartialEq, Eq, Copy, Clone)]
+pub enum SetpCompareFloat {
+ Eq,
+ NotEq,
+ Less,
+ LessOrEq,
+ Greater,
+ GreaterOrEq,
+ NanEq,
+ NanNotEq,
+ NanLess,
+ NanLessOrEq,
+ NanGreater,
+ NanGreaterOrEq,
+ IsNotNan,
+ IsAnyNan,
+}
+
+impl TryFrom<(RawSetpCompareOp, ScalarKind)> for SetpCompareInt {
+ type Error = PtxError;
+
+ fn try_from((value, kind): (RawSetpCompareOp, ScalarKind)) -> Result<Self, PtxError> {
+ match (value, kind) {
+ (RawSetpCompareOp::Eq, _) => Ok(SetpCompareInt::Eq),
+ (RawSetpCompareOp::Ne, _) => Ok(SetpCompareInt::NotEq),
+ (RawSetpCompareOp::Lt | RawSetpCompareOp::Lo, ScalarKind::Signed) => {
+ Ok(SetpCompareInt::SignedLess)
+ }
+ (RawSetpCompareOp::Lt | RawSetpCompareOp::Lo, _) => Ok(SetpCompareInt::UnsignedLess),
+ (RawSetpCompareOp::Le | RawSetpCompareOp::Ls, ScalarKind::Signed) => {
+ Ok(SetpCompareInt::SignedLessOrEq)
+ }
+ (RawSetpCompareOp::Le | RawSetpCompareOp::Ls, _) => {
+ Ok(SetpCompareInt::UnsignedLessOrEq)
+ }
+ (RawSetpCompareOp::Gt | RawSetpCompareOp::Hi, ScalarKind::Signed) => {
+ Ok(SetpCompareInt::SignedGreater)
+ }
+ (RawSetpCompareOp::Gt | RawSetpCompareOp::Hi, _) => Ok(SetpCompareInt::UnsignedGreater),
+ (RawSetpCompareOp::Ge | RawSetpCompareOp::Hs, ScalarKind::Signed) => {
+ Ok(SetpCompareInt::SignedGreaterOrEq)
+ }
+ (RawSetpCompareOp::Ge | RawSetpCompareOp::Hs, _) => {
+ Ok(SetpCompareInt::UnsignedGreaterOrEq)
+ }
+ (RawSetpCompareOp::Equ, _) => Err(PtxError::WrongType),
+ (RawSetpCompareOp::Neu, _) => Err(PtxError::WrongType),
+ (RawSetpCompareOp::Ltu, _) => Err(PtxError::WrongType),
+ (RawSetpCompareOp::Leu, _) => Err(PtxError::WrongType),
+ (RawSetpCompareOp::Gtu, _) => Err(PtxError::WrongType),
+ (RawSetpCompareOp::Geu, _) => Err(PtxError::WrongType),
+ (RawSetpCompareOp::Num, _) => Err(PtxError::WrongType),
+ (RawSetpCompareOp::Nan, _) => Err(PtxError::WrongType),
+ }
+ }
+}
+
+impl From<RawSetpCompareOp> for SetpCompareFloat {
+ fn from(value: RawSetpCompareOp) -> Self {
+ match value {
+ RawSetpCompareOp::Eq => SetpCompareFloat::Eq,
+ RawSetpCompareOp::Ne => SetpCompareFloat::NotEq,
+ RawSetpCompareOp::Lt => SetpCompareFloat::Less,
+ RawSetpCompareOp::Le => SetpCompareFloat::LessOrEq,
+ RawSetpCompareOp::Gt => SetpCompareFloat::Greater,
+ RawSetpCompareOp::Ge => SetpCompareFloat::GreaterOrEq,
+ RawSetpCompareOp::Lo => SetpCompareFloat::Less,
+ RawSetpCompareOp::Ls => SetpCompareFloat::LessOrEq,
+ RawSetpCompareOp::Hi => SetpCompareFloat::Greater,
+ RawSetpCompareOp::Hs => SetpCompareFloat::GreaterOrEq,
+ RawSetpCompareOp::Equ => SetpCompareFloat::NanEq,
+ RawSetpCompareOp::Neu => SetpCompareFloat::NanNotEq,
+ RawSetpCompareOp::Ltu => SetpCompareFloat::NanLess,
+ RawSetpCompareOp::Leu => SetpCompareFloat::NanLessOrEq,
+ RawSetpCompareOp::Gtu => SetpCompareFloat::NanGreater,
+ RawSetpCompareOp::Geu => SetpCompareFloat::NanGreaterOrEq,
+ RawSetpCompareOp::Num => SetpCompareFloat::IsNotNan,
+ RawSetpCompareOp::Nan => SetpCompareFloat::IsAnyNan,
+ }
+ }
+}
+
+pub struct CallDetails {
+ pub uniform: bool,
+ pub return_arguments: Vec<(Type, StateSpace)>,
+ pub input_arguments: Vec<(Type, StateSpace)>,
+}
+
+pub struct CallArgs<T: Operand> {
+ pub return_arguments: Vec<T::Ident>,
+ pub func: T::Ident,
+ pub input_arguments: Vec<T>,
+}
+
+impl<T: Operand> CallArgs<T> {
+ #[allow(dead_code)] // Used by generated code
+ fn visit<Err>(
+ &self,
+ details: &CallDetails,
+ visitor: &mut impl Visitor<T, Err>,
+ ) -> Result<(), Err> {
+ for (param, (type_, space)) in self
+ .return_arguments
+ .iter()
+ .zip(details.return_arguments.iter())
+ {
+ visitor.visit_ident(param, Some((type_, *space)), true, false)?;
+ }
+ visitor.visit_ident(&self.func, None, false, false)?;
+ for (param, (type_, space)) in self
+ .input_arguments
+ .iter()
+ .zip(details.input_arguments.iter())
+ {
+ visitor.visit(param, Some((type_, *space)), false, false)?;
+ }
+ Ok(())
+ }
+
+ #[allow(dead_code)] // Used by generated code
+ fn visit_mut<Err>(
+ &mut self,
+ details: &CallDetails,
+ visitor: &mut impl VisitorMut<T, Err>,
+ ) -> Result<(), Err> {
+ for (param, (type_, space)) in self
+ .return_arguments
+ .iter_mut()
+ .zip(details.return_arguments.iter())
+ {
+ visitor.visit_ident(param, Some((type_, *space)), true, false)?;
+ }
+ visitor.visit_ident(&mut self.func, None, false, false)?;
+ for (param, (type_, space)) in self
+ .input_arguments
+ .iter_mut()
+ .zip(details.input_arguments.iter())
+ {
+ visitor.visit(param, Some((type_, *space)), false, false)?;
+ }
+ Ok(())
+ }
+
+ #[allow(dead_code)] // Used by generated code
+ fn map<U: Operand, Err>(
+ self,
+ details: &CallDetails,
+ visitor: &mut impl VisitorMap<T, U, Err>,
+ ) -> Result<CallArgs<U>, Err> {
+ let return_arguments = self
+ .return_arguments
+ .into_iter()
+ .zip(details.return_arguments.iter())
+ .map(|(param, (type_, space))| {
+ visitor.visit_ident(param, Some((type_, *space)), true, false)
+ })
+ .collect::<Result<Vec<_>, _>>()?;
+ let func = visitor.visit_ident(self.func, None, false, false)?;
+ let input_arguments = self
+ .input_arguments
+ .into_iter()
+ .zip(details.input_arguments.iter())
+ .map(|(param, (type_, space))| {
+ visitor.visit(param, Some((type_, *space)), false, false)
+ })
+ .collect::<Result<Vec<_>, _>>()?;
+ Ok(CallArgs {
+ return_arguments,
+ func,
+ input_arguments,
+ })
+ }
+}
+
+pub struct CvtDetails {
+ pub from: ScalarType,
+ pub to: ScalarType,
+ pub mode: CvtMode,
+}
+
+pub enum CvtMode {
+ // int from int
+ ZeroExtend,
+ SignExtend,
+ Truncate,
+ Bitcast,
+ SaturateUnsignedToSigned,
+ SaturateSignedToUnsigned,
+ // float from float
+ FPExtend {
+ flush_to_zero: Option<bool>,
+ },
+ FPTruncate {
+ // float rounding
+ rounding: RoundingMode,
+ flush_to_zero: Option<bool>,
+ },
+ FPRound {
+ integer_rounding: Option<RoundingMode>,
+ flush_to_zero: Option<bool>,
+ },
+ // int from float
+ SignedFromFP {
+ rounding: RoundingMode,
+ flush_to_zero: Option<bool>,
+ }, // integer rounding
+ UnsignedFromFP {
+ rounding: RoundingMode,
+ flush_to_zero: Option<bool>,
+ }, // integer rounding
+ // float from int, ftz is allowed in the grammar, but clearly nonsensical
+ FPFromSigned(RoundingMode), // float rounding
+ FPFromUnsigned(RoundingMode), // float rounding
+}
+
+impl CvtDetails {
+ pub(crate) fn new(
+ errors: &mut Vec<PtxError>,
+ rnd: Option<RawRoundingMode>,
+ ftz: bool,
+ saturate: bool,
+ dst: ScalarType,
+ src: ScalarType,
+ ) -> Self {
+ if saturate && dst.kind() == ScalarKind::Float {
+ errors.push(PtxError::SyntaxError);
+ }
+ // Modifier .ftz can only be specified when either .dtype or .atype is .f32 and applies only to single precision (.f32) inputs and results.
+ let flush_to_zero = match (dst, src) {
+ (ScalarType::F32, _) | (_, ScalarType::F32) => Some(ftz),
+ _ => {
+ if ftz {
+ errors.push(PtxError::NonF32Ftz);
+ }
+ None
+ }
+ };
+ let rounding = rnd.map(Into::into);
+ let mut unwrap_rounding = || match rounding {
+ Some(rnd) => rnd,
+ None => {
+ errors.push(PtxError::SyntaxError);
+ RoundingMode::NearestEven
+ }
+ };
+ let mode = match (dst.kind(), src.kind()) {
+ (ScalarKind::Float, ScalarKind::Float) => match dst.size_of().cmp(&src.size_of()) {
+ Ordering::Less => CvtMode::FPTruncate {
+ rounding: unwrap_rounding(),
+ flush_to_zero,
+ },
+ Ordering::Equal => CvtMode::FPRound {
+ integer_rounding: rounding,
+ flush_to_zero,
+ },
+ Ordering::Greater => {
+ if rounding.is_some() {
+ errors.push(PtxError::SyntaxError);
+ }
+ CvtMode::FPExtend { flush_to_zero }
+ }
+ },
+ (ScalarKind::Unsigned, ScalarKind::Float) => CvtMode::UnsignedFromFP {
+ rounding: unwrap_rounding(),
+ flush_to_zero,
+ },
+ (ScalarKind::Signed, ScalarKind::Float) => CvtMode::SignedFromFP {
+ rounding: unwrap_rounding(),
+ flush_to_zero,
+ },
+ (ScalarKind::Float, ScalarKind::Unsigned) => CvtMode::FPFromUnsigned(unwrap_rounding()),
+ (ScalarKind::Float, ScalarKind::Signed) => CvtMode::FPFromSigned(unwrap_rounding()),
+ (ScalarKind::Signed, ScalarKind::Unsigned) if saturate => {
+ CvtMode::SaturateUnsignedToSigned
+ }
+ (ScalarKind::Unsigned, ScalarKind::Signed) if saturate => {
+ CvtMode::SaturateSignedToUnsigned
+ }
+ (ScalarKind::Unsigned, ScalarKind::Signed)
+ | (ScalarKind::Signed, ScalarKind::Unsigned)
+ if dst.size_of() == src.size_of() =>
+ {
+ CvtMode::Bitcast
+ }
+ (ScalarKind::Unsigned, ScalarKind::Unsigned)
+ | (ScalarKind::Signed, ScalarKind::Signed) => match dst.size_of().cmp(&src.size_of()) {
+ Ordering::Less => CvtMode::Truncate,
+ Ordering::Equal => CvtMode::Bitcast,
+ Ordering::Greater => {
+ if src.kind() == ScalarKind::Signed {
+ CvtMode::SignExtend
+ } else {
+ CvtMode::ZeroExtend
+ }
+ }
+ },
+ (ScalarKind::Unsigned, ScalarKind::Signed) => CvtMode::SaturateSignedToUnsigned,
+ (_, _) => {
+ errors.push(PtxError::SyntaxError);
+ CvtMode::Bitcast
+ }
+ };
+ CvtDetails {
+ mode,
+ to: dst,
+ from: src,
+ }
+ }
+}
+
+pub struct CvtIntToIntDesc {
+ pub dst: ScalarType,
+ pub src: ScalarType,
+ pub saturate: bool,
+}
+
+pub struct CvtDesc {
+ pub rounding: Option<RoundingMode>,
+ pub flush_to_zero: Option<bool>,
+ pub saturate: bool,
+ pub dst: ScalarType,
+ pub src: ScalarType,
+}
+
+pub struct ShrData {
+ pub type_: ScalarType,
+ pub kind: RightShiftKind,
+}
+
+pub enum RightShiftKind {
+ Arithmetic,
+ Logical,
+}
+
+pub struct CvtaDetails {
+ pub state_space: StateSpace,
+ pub direction: CvtaDirection,
+}
+
+pub enum CvtaDirection {
+ GenericToExplicit,
+ ExplicitToGeneric,
+}
+
+#[derive(Copy, Clone, PartialEq, Eq)]
+pub struct TypeFtz {
+ pub flush_to_zero: Option<bool>,
+ pub type_: ScalarType,
+}
+
+#[derive(Copy, Clone)]
+pub enum MadDetails {
+ Integer {
+ control: MulIntControl,
+ saturate: bool,
+ type_: ScalarType,
+ },
+ Float(ArithFloat),
+}
+
+impl MadDetails {
+ pub fn dst_type(&self) -> ScalarType {
+ match self {
+ MadDetails::Integer {
+ type_,
+ control: MulIntControl::Wide,
+ ..
+ } => match type_ {
+ ScalarType::U16 => ScalarType::U32,
+ ScalarType::S16 => ScalarType::S32,
+ ScalarType::U32 => ScalarType::U64,
+ ScalarType::S32 => ScalarType::S64,
+ _ => unreachable!(),
+ },
+ _ => self.type_(),
+ }
+ }
+
+ fn type_(&self) -> ScalarType {
+ match self {
+ MadDetails::Integer { type_, .. } => *type_,
+ MadDetails::Float(arith) => arith.type_,
+ }
+ }
+}
+
+#[derive(Copy, Clone)]
+pub enum MinMaxDetails {
+ Signed(ScalarType),
+ Unsigned(ScalarType),
+ Float(MinMaxFloat),
+}
+
+impl MinMaxDetails {
+ pub fn type_(&self) -> ScalarType {
+ match self {
+ MinMaxDetails::Signed(t) => *t,
+ MinMaxDetails::Unsigned(t) => *t,
+ MinMaxDetails::Float(float) => float.type_,
+ }
+ }
+}
+
+#[derive(Copy, Clone)]
+pub struct MinMaxFloat {
+ pub flush_to_zero: Option<bool>,
+ pub nan: bool,
+ pub type_: ScalarType,
+}
+
+#[derive(Copy, Clone)]
+pub struct RcpData {
+ pub kind: RcpKind,
+ pub flush_to_zero: Option<bool>,
+ pub type_: ScalarType,
+}
+
+#[derive(Copy, Clone, Eq, PartialEq)]
+pub enum RcpKind {
+ Approx,
+ Compliant(RoundingMode),
+}
+
+pub struct BarData {
+ pub aligned: bool,
+}
+
+pub struct AtomDetails {
+ pub type_: Type,
+ pub semantics: AtomSemantics,
+ pub scope: MemScope,
+ pub space: StateSpace,
+ pub op: AtomicOp,
+}
+
+#[derive(Copy, Clone)]
+pub enum AtomicOp {
+ And,
+ Or,
+ Xor,
+ Exchange,
+ Add,
+ IncrementWrap,
+ DecrementWrap,
+ SignedMin,
+ UnsignedMin,
+ SignedMax,
+ UnsignedMax,
+ FloatAdd,
+ FloatMin,
+ FloatMax,
+}
+
+impl AtomicOp {
+ pub(crate) fn new(op: super::RawAtomicOp, kind: ScalarKind) -> Self {
+ use super::RawAtomicOp;
+ match (op, kind) {
+ (RawAtomicOp::And, _) => Self::And,
+ (RawAtomicOp::Or, _) => Self::Or,
+ (RawAtomicOp::Xor, _) => Self::Xor,
+ (RawAtomicOp::Exch, _) => Self::Exchange,
+ (RawAtomicOp::Add, ScalarKind::Float) => Self::FloatAdd,
+ (RawAtomicOp::Add, _) => Self::Add,
+ (RawAtomicOp::Inc, _) => Self::IncrementWrap,
+ (RawAtomicOp::Dec, _) => Self::DecrementWrap,
+ (RawAtomicOp::Min, ScalarKind::Signed) => Self::SignedMin,
+ (RawAtomicOp::Min, ScalarKind::Float) => Self::FloatMin,
+ (RawAtomicOp::Min, _) => Self::UnsignedMin,
+ (RawAtomicOp::Max, ScalarKind::Signed) => Self::SignedMax,
+ (RawAtomicOp::Max, ScalarKind::Float) => Self::FloatMax,
+ (RawAtomicOp::Max, _) => Self::UnsignedMax,
+ }
+ }
+}
+
+pub struct AtomCasDetails {
+ pub type_: ScalarType,
+ pub semantics: AtomSemantics,
+ pub scope: MemScope,
+ pub space: StateSpace,
+}
+
+#[derive(Copy, Clone)]
+pub enum DivDetails {
+ Unsigned(ScalarType),
+ Signed(ScalarType),
+ Float(DivFloatDetails),
+}
+
+impl DivDetails {
+ pub fn type_(&self) -> ScalarType {
+ match self {
+ DivDetails::Unsigned(t) => *t,
+ DivDetails::Signed(t) => *t,
+ DivDetails::Float(float) => float.type_,
+ }
+ }
+}
+
+#[derive(Copy, Clone)]
+pub struct DivFloatDetails {
+ pub type_: ScalarType,
+ pub flush_to_zero: Option<bool>,
+ pub kind: DivFloatKind,
+}
+
+#[derive(Copy, Clone, Eq, PartialEq)]
+pub enum DivFloatKind {
+ Approx,
+ ApproxFull,
+ Rounding(RoundingMode),
+}
+
+#[derive(Copy, Clone, Eq, PartialEq)]
+pub struct FlushToZero {
+ pub flush_to_zero: bool,
+}
diff --git a/ptx_parser/src/check_args.py b/ptx_parser/src/check_args.py new file mode 100644 index 0000000..04ffdb9 --- /dev/null +++ b/ptx_parser/src/check_args.py @@ -0,0 +1,69 @@ +import os, sys, subprocess
+
+
+SPACE = [".reg", ".sreg", ".param", ".param::entry", ".param::func", ".local", ".global", ".const", ".shared", ".shared::cta", ".shared::cluster"]
+TYPE_AND_INIT = ["", " = 1", "[1]", "[1] = {1}"]
+MULTIVAR = ["", "<1>" ]
+VECTOR = ["", ".v2" ]
+
+HEADER = """
+ .version 8.5
+ .target sm_90
+ .address_size 64
+"""
+
+
+def directive(space, variable, multivar, vector):
+ return """{3}
+ {0} {4} .b32 variable{2} {1};
+ """.format(space, variable, multivar, HEADER, vector)
+
+def entry_arg(space, variable, multivar, vector):
+ return """{3}
+ .entry foobar ( {0} {4} .b32 variable{2} {1})
+ {{
+ ret;
+ }}
+ """.format(space, variable, multivar, HEADER, vector)
+
+
+def fn_arg(space, variable, multivar, vector):
+ return """{3}
+ .func foobar ( {0} {4} .b32 variable{2} {1})
+ {{
+ ret;
+ }}
+ """.format(space, variable, multivar, HEADER, vector)
+
+
+def fn_body(space, variable, multivar, vector):
+ return """{3}
+ .func foobar ()
+ {{
+ {0} {4} .b32 variable{2} {1};
+ ret;
+ }}
+ """.format(space, variable, multivar, HEADER, vector)
+
+
+def generate(generator):
+ legal = []
+ for space in SPACE:
+ for init in TYPE_AND_INIT:
+ for multi in MULTIVAR:
+ for vector in VECTOR:
+ ptx = generator(space, init, multi, vector)
+ if 0 == subprocess.call(["C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v12.6\\bin\\ptxas.exe", "-arch", "sm_90", "-ias", ptx], stdout = subprocess.DEVNULL): #
+ legal.append((space, vector, init, multi))
+ print(generator.__name__)
+ print(legal)
+
+
+def main():
+ generate(directive)
+ generate(entry_arg)
+ generate(fn_arg)
+ generate(fn_body)
+
+if __name__ == "__main__":
+ main()
diff --git a/ptx_parser/src/lib.rs b/ptx_parser/src/lib.rs new file mode 100644 index 0000000..f842ace --- /dev/null +++ b/ptx_parser/src/lib.rs @@ -0,0 +1,3269 @@ +use derive_more::Display; +use logos::Logos; +use ptx_parser_macros::derive_parser; +use rustc_hash::FxHashMap; +use std::fmt::Debug; +use std::iter; +use std::num::{NonZeroU8, ParseFloatError, ParseIntError}; +use winnow::ascii::dec_uint; +use winnow::combinator::*; +use winnow::error::{ErrMode, ErrorKind}; +use winnow::stream::Accumulate; +use winnow::token::any; +use winnow::{ + error::{ContextError, ParserError}, + stream::{Offset, Stream, StreamIsPartial}, + PResult, +}; +use winnow::{prelude::*, Stateful}; + +mod ast; +pub use ast::*; + +impl From<RawMulIntControl> for ast::MulIntControl { + fn from(value: RawMulIntControl) -> Self { + match value { + RawMulIntControl::Lo => ast::MulIntControl::Low, + RawMulIntControl::Hi => ast::MulIntControl::High, + RawMulIntControl::Wide => ast::MulIntControl::Wide, + } + } +} + +impl From<RawStCacheOperator> for ast::StCacheOperator { + fn from(value: RawStCacheOperator) -> Self { + match value { + RawStCacheOperator::Wb => ast::StCacheOperator::Writeback, + RawStCacheOperator::Cg => ast::StCacheOperator::L2Only, + RawStCacheOperator::Cs => ast::StCacheOperator::Streaming, + RawStCacheOperator::Wt => ast::StCacheOperator::Writethrough, + } + } +} + +impl From<RawLdCacheOperator> for ast::LdCacheOperator { + fn from(value: RawLdCacheOperator) -> Self { + match value { + RawLdCacheOperator::Ca => ast::LdCacheOperator::Cached, + RawLdCacheOperator::Cg => ast::LdCacheOperator::L2Only, + RawLdCacheOperator::Cs => ast::LdCacheOperator::Streaming, + RawLdCacheOperator::Lu => ast::LdCacheOperator::LastUse, + RawLdCacheOperator::Cv => ast::LdCacheOperator::Uncached, + } + } +} + +impl From<RawLdStQualifier> for ast::LdStQualifier { + fn from(value: RawLdStQualifier) -> Self { + match value { + RawLdStQualifier::Weak => ast::LdStQualifier::Weak, + RawLdStQualifier::Volatile => ast::LdStQualifier::Volatile, + } + } +} + +impl From<RawRoundingMode> for ast::RoundingMode { + fn from(value: RawRoundingMode) -> Self { + match value { + RawRoundingMode::Rn | RawRoundingMode::Rni => ast::RoundingMode::NearestEven, + RawRoundingMode::Rz | RawRoundingMode::Rzi => ast::RoundingMode::Zero, + RawRoundingMode::Rm | RawRoundingMode::Rmi => ast::RoundingMode::NegativeInf, + RawRoundingMode::Rp | RawRoundingMode::Rpi => ast::RoundingMode::PositiveInf, + } + } +} + +impl VectorPrefix { + pub(crate) fn len(self) -> NonZeroU8 { + unsafe { + match self { + VectorPrefix::V2 => NonZeroU8::new_unchecked(2), + VectorPrefix::V4 => NonZeroU8::new_unchecked(4), + VectorPrefix::V8 => NonZeroU8::new_unchecked(8), + } + } + } +} + +struct PtxParserState<'a, 'input> { + errors: &'a mut Vec<PtxError>, + function_declarations: + FxHashMap<&'input str, (Vec<(ast::Type, StateSpace)>, Vec<(ast::Type, StateSpace)>)>, +} + +impl<'a, 'input> PtxParserState<'a, 'input> { + fn new(errors: &'a mut Vec<PtxError>) -> Self { + Self { + errors, + function_declarations: FxHashMap::default(), + } + } + + fn record_function(&mut self, function_decl: &MethodDeclaration<'input, &'input str>) { + let name = match function_decl.name { + MethodName::Kernel(name) => name, + MethodName::Func(name) => name, + }; + let return_arguments = Self::get_type_space(&*function_decl.return_arguments); + let input_arguments = Self::get_type_space(&*function_decl.input_arguments); + // TODO: check if declarations match + self.function_declarations + .insert(name, (return_arguments, input_arguments)); + } + + fn get_type_space(input_arguments: &[Variable<&str>]) -> Vec<(Type, StateSpace)> { + input_arguments + .iter() + .map(|var| (var.v_type.clone(), var.state_space)) + .collect::<Vec<_>>() + } +} + +impl<'a, 'input> Debug for PtxParserState<'a, 'input> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("PtxParserState") + .field("errors", &self.errors) /* .field("function_decl", &self.function_decl) */ + .finish() + } +} + +type PtxParser<'a, 'input> = Stateful<&'a [Token<'input>], PtxParserState<'a, 'input>>; + +fn ident<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<&'input str> { + any.verify_map(|t| { + if let Token::Ident(text) = t { + Some(text) + } else if let Some(text) = t.opcode_text() { + Some(text) + } else { + None + } + }) + .parse_next(stream) +} + +fn dot_ident<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<&'input str> { + any.verify_map(|t| { + if let Token::DotIdent(text) = t { + Some(text) + } else { + None + } + }) + .parse_next(stream) +} + +fn num<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<(&'input str, u32, bool)> { + any.verify_map(|t| { + Some(match t { + Token::Hex(s) => { + if s.ends_with('U') { + (&s[2..s.len() - 1], 16, true) + } else { + (&s[2..], 16, false) + } + } + Token::Decimal(s) => { + let radix = if s.starts_with('0') { 8 } else { 10 }; + if s.ends_with('U') { + (&s[..s.len() - 1], radix, true) + } else { + (s, radix, false) + } + } + _ => return None, + }) + }) + .parse_next(stream) +} + +fn take_error<'a, 'input: 'a, O, E>( + mut parser: impl Parser<PtxParser<'a, 'input>, Result<O, (O, PtxError)>, E>, +) -> impl Parser<PtxParser<'a, 'input>, O, E> { + move |input: &mut PtxParser<'a, 'input>| { + Ok(match parser.parse_next(input)? { + Ok(x) => x, + Err((x, err)) => { + input.state.errors.push(err); + x + } + }) + } +} + +fn int_immediate<'a, 'input>(input: &mut PtxParser<'a, 'input>) -> PResult<ast::ImmediateValue> { + take_error((opt(Token::Minus), num).map(|(neg, x)| { + let (num, radix, is_unsigned) = x; + if neg.is_some() { + match i64::from_str_radix(num, radix) { + Ok(x) => Ok(ast::ImmediateValue::S64(-x)), + Err(err) => Err((ast::ImmediateValue::S64(0), PtxError::from(err))), + } + } else if is_unsigned { + match u64::from_str_radix(num, radix) { + Ok(x) => Ok(ast::ImmediateValue::U64(x)), + Err(err) => Err((ast::ImmediateValue::U64(0), PtxError::from(err))), + } + } else { + match i64::from_str_radix(num, radix) { + Ok(x) => Ok(ast::ImmediateValue::S64(x)), + Err(_) => match u64::from_str_radix(num, radix) { + Ok(x) => Ok(ast::ImmediateValue::U64(x)), + Err(err) => Err((ast::ImmediateValue::U64(0), PtxError::from(err))), + }, + } + } + })) + .parse_next(input) +} + +fn f32<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<f32> { + take_error(any.verify_map(|t| match t { + Token::F32(f) => Some(match u32::from_str_radix(&f[2..], 16) { + Ok(x) => Ok(f32::from_bits(x)), + Err(err) => Err((0.0, PtxError::from(err))), + }), + _ => None, + })) + .parse_next(stream) +} + +fn f64<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<f64> { + take_error(any.verify_map(|t| match t { + Token::F64(f) => Some(match u64::from_str_radix(&f[2..], 16) { + Ok(x) => Ok(f64::from_bits(x)), + Err(err) => Err((0.0, PtxError::from(err))), + }), + _ => None, + })) + .parse_next(stream) +} + +fn s32<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<i32> { + take_error((opt(Token::Minus), num).map(|(sign, x)| { + let (text, radix, _) = x; + match i32::from_str_radix(text, radix) { + Ok(x) => Ok(if sign.is_some() { -x } else { x }), + Err(err) => Err((0, PtxError::from(err))), + } + })) + .parse_next(stream) +} + +fn u8<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<u8> { + take_error(num.map(|x| { + let (text, radix, _) = x; + match u8::from_str_radix(text, radix) { + Ok(x) => Ok(x), + Err(err) => Err((0, PtxError::from(err))), + } + })) + .parse_next(stream) +} + +fn u32<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<u32> { + take_error(num.map(|x| { + let (text, radix, _) = x; + match u32::from_str_radix(text, radix) { + Ok(x) => Ok(x), + Err(err) => Err((0, PtxError::from(err))), + } + })) + .parse_next(stream) +} + +fn immediate_value<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<ast::ImmediateValue> { + alt(( + int_immediate, + f32.map(ast::ImmediateValue::F32), + f64.map(ast::ImmediateValue::F64), + )) + .parse_next(stream) +} + +pub fn parse_module_unchecked<'input>(text: &'input str) -> Option<ast::Module<'input>> { + let lexer = Token::lexer(text); + let input = lexer.collect::<Result<Vec<_>, _>>().ok()?; + let mut errors = Vec::new(); + let state = PtxParserState::new(&mut errors); + let parser = PtxParser { + state, + input: &input[..], + }; + let parsing_result = module.parse(parser).ok(); + if !errors.is_empty() { + None + } else { + parsing_result + } +} + +pub fn parse_module_checked<'input>( + text: &'input str, +) -> Result<ast::Module<'input>, Vec<PtxError>> { + let mut lexer = Token::lexer(text); + let mut errors = Vec::new(); + let mut tokens = Vec::new(); + loop { + let maybe_token = match lexer.next() { + Some(maybe_token) => maybe_token, + None => break, + }; + match maybe_token { + Ok(token) => tokens.push(token), + Err(mut err) => { + err.0 = lexer.span(); + errors.push(PtxError::from(err)) + } + } + } + if !errors.is_empty() { + return Err(errors); + } + let parse_result = { + let state = PtxParserState::new(&mut errors); + let parser = PtxParser { + state, + input: &tokens[..], + }; + module + .parse(parser) + .map_err(|err| PtxError::Parser(err.into_inner())) + }; + match parse_result { + Ok(result) if errors.is_empty() => Ok(result), + Ok(_) => Err(errors), + Err(err) => { + errors.push(err); + Err(errors) + } + } +} + +fn module<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<ast::Module<'input>> { + ( + version, + target, + opt(address_size), + repeat_without_none(directive), + eof, + ) + .map(|(version, _, _, directives, _)| ast::Module { + version, + directives, + }) + .parse_next(stream) +} + +fn address_size<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<()> { + (Token::DotAddressSize, u8_literal(64)) + .void() + .parse_next(stream) +} + +fn version<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<(u8, u8)> { + (Token::DotVersion, u8, Token::Dot, u8) + .map(|(_, major, _, minor)| (major, minor)) + .parse_next(stream) +} + +fn target<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<(u32, Option<char>)> { + preceded(Token::DotTarget, ident.and_then(shader_model)).parse_next(stream) +} + +fn shader_model<'a>(stream: &mut &str) -> PResult<(u32, Option<char>)> { + ( + "sm_", + dec_uint, + opt(any.verify(|c: &char| c.is_ascii_lowercase())), + eof, + ) + .map(|(_, digits, arch_variant, _)| (digits, arch_variant)) + .parse_next(stream) +} + +fn directive<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult<Option<ast::Directive<'input, ast::ParsedOperand<&'input str>>>> { + alt(( + function.map(|(linking, func)| Some(ast::Directive::Method(linking, func))), + file.map(|_| None), + section.map(|_| None), + (module_variable, Token::Semicolon) + .map(|((linking, var), _)| Some(ast::Directive::Variable(linking, var))), + )) + .parse_next(stream) +} + +fn module_variable<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult<(ast::LinkingDirective, ast::Variable<&'input str>)> { + let linking = linking_directives.parse_next(stream)?; + let var = global_space + .flat_map(|space| multi_variable(linking.contains(LinkingDirective::EXTERN), space)) + // TODO: support multi var in globals + .map(|multi_var| multi_var.var) + .parse_next(stream)?; + Ok((linking, var)) +} + +fn file<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<()> { + ( + Token::DotFile, + u32, + Token::String, + opt((Token::Comma, u32, Token::Comma, u32)), + ) + .void() + .parse_next(stream) +} + +fn section<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<()> { + ( + Token::DotSection.void(), + dot_ident.void(), + Token::LBrace.void(), + repeat::<_, _, (), _, _>(0.., section_dwarf_line), + Token::RBrace.void(), + ) + .void() + .parse_next(stream) +} + +fn section_dwarf_line<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<()> { + alt(( + (section_label, Token::Colon).void(), + (Token::DotB32, section_label, opt((Token::Add, u32))).void(), + (Token::DotB64, section_label, opt((Token::Add, u32))).void(), + ( + any_bit_type, + separated::<_, _, (), _, _, _, _>(1.., u32, Token::Comma), + ) + .void(), + )) + .parse_next(stream) +} + +fn any_bit_type<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<()> { + alt((Token::DotB8, Token::DotB16, Token::DotB32, Token::DotB64)) + .void() + .parse_next(stream) +} + +fn section_label<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<()> { + alt((ident, dot_ident)).void().parse_next(stream) +} + +fn function<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult<( + ast::LinkingDirective, + ast::Function<'input, &'input str, ast::Statement<ParsedOperand<&'input str>>>, +)> { + let (linking, function) = ( + linking_directives, + method_declaration, + repeat(0.., tuning_directive), + function_body, + ) + .map(|(linking, func_directive, tuning, body)| { + ( + linking, + ast::Function { + func_directive, + tuning, + body, + }, + ) + }) + .parse_next(stream)?; + stream.state.record_function(&function.func_directive); + Ok((linking, function)) +} + +fn linking_directives<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult<ast::LinkingDirective> { + repeat( + 0.., + dispatch! { any; + Token::DotExtern => empty.value(ast::LinkingDirective::EXTERN), + Token::DotVisible => empty.value(ast::LinkingDirective::VISIBLE), + Token::DotWeak => empty.value(ast::LinkingDirective::WEAK), + _ => fail + }, + ) + .fold(|| ast::LinkingDirective::NONE, |x, y| x | y) + .parse_next(stream) +} + +fn tuning_directive<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult<ast::TuningDirective> { + dispatch! {any; + Token::DotMaxnreg => u32.map(ast::TuningDirective::MaxNReg), + Token::DotMaxntid => tuple1to3_u32.map(|(nx, ny, nz)| ast::TuningDirective::MaxNtid(nx, ny, nz)), + Token::DotReqntid => tuple1to3_u32.map(|(nx, ny, nz)| ast::TuningDirective::ReqNtid(nx, ny, nz)), + Token::DotMinnctapersm => u32.map(ast::TuningDirective::MinNCtaPerSm), + _ => fail + } + .parse_next(stream) +} + +fn method_declaration<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult<ast::MethodDeclaration<'input, &'input str>> { + dispatch! {any; + Token::DotEntry => (ident, kernel_arguments).map(|(name, input_arguments)| ast::MethodDeclaration{ + return_arguments: Vec::new(), name: ast::MethodName::Kernel(name), input_arguments, shared_mem: None + }), + Token::DotFunc => (opt(fn_arguments), ident, fn_arguments).map(|(return_arguments, name,input_arguments)| { + let return_arguments = return_arguments.unwrap_or_else(|| Vec::new()); + let name = ast::MethodName::Func(name); + ast::MethodDeclaration{ return_arguments, name, input_arguments, shared_mem: None } + }), + _ => fail + } + .parse_next(stream) +} + +fn fn_arguments<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult<Vec<ast::Variable<&'input str>>> { + delimited( + Token::LParen, + separated(0.., fn_input, Token::Comma), + Token::RParen, + ) + .parse_next(stream) +} + +fn kernel_arguments<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult<Vec<ast::Variable<&'input str>>> { + delimited( + Token::LParen, + separated(0.., kernel_input, Token::Comma), + Token::RParen, + ) + .parse_next(stream) +} + +fn kernel_input<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult<ast::Variable<&'input str>> { + preceded(Token::DotParam, method_parameter(StateSpace::Param)).parse_next(stream) +} + +fn fn_input<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<ast::Variable<&'input str>> { + dispatch! { any; + Token::DotParam => method_parameter(StateSpace::Param), + Token::DotReg => method_parameter(StateSpace::Reg), + _ => fail + } + .parse_next(stream) +} + +fn tuple1to3_u32<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<(u32, u32, u32)> { + struct Tuple3AccumulateU32 { + index: usize, + value: (u32, u32, u32), + } + + impl Accumulate<u32> for Tuple3AccumulateU32 { + fn initial(_: Option<usize>) -> Self { + Self { + index: 0, + value: (1, 1, 1), + } + } + + fn accumulate(&mut self, value: u32) { + match self.index { + 0 => { + self.value = (value, self.value.1, self.value.2); + self.index = 1; + } + 1 => { + self.value = (self.value.0, value, self.value.2); + self.index = 2; + } + 2 => { + self.value = (self.value.0, self.value.1, value); + self.index = 3; + } + _ => unreachable!(), + } + } + } + + separated::<_, _, Tuple3AccumulateU32, _, _, _, _>(1..=3, u32, Token::Comma) + .map(|acc| acc.value) + .parse_next(stream) +} + +fn function_body<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult<Option<Vec<ast::Statement<ParsedOperandStr<'input>>>>> { + dispatch! {any; + Token::LBrace => terminated(repeat_without_none(statement), Token::RBrace).map(Some), + Token::Semicolon => empty.map(|_| None), + _ => fail + } + .parse_next(stream) +} + +fn statement<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult<Option<Statement<ParsedOperandStr<'input>>>> { + alt(( + label.map(Some), + debug_directive.map(|_| None), + terminated( + method_space + .flat_map(|space| multi_variable(false, space)) + .map(|var| Some(Statement::Variable(var))), + Token::Semicolon, + ), + predicated_instruction.map(Some), + pragma.map(|_| None), + block_statement.map(Some), + )) + .parse_next(stream) +} + +fn pragma<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<()> { + (Token::DotPragma, Token::String, Token::Semicolon) + .void() + .parse_next(stream) +} + +fn method_parameter<'a, 'input: 'a>( + state_space: StateSpace, +) -> impl Parser<PtxParser<'a, 'input>, Variable<&'input str>, ContextError> { + move |stream: &mut PtxParser<'a, 'input>| { + let (align, vector, type_, name) = variable_declaration.parse_next(stream)?; + let array_dimensions = if state_space != StateSpace::Reg { + opt(array_dimensions).parse_next(stream)? + } else { + None + }; + // TODO: push this check into array_dimensions(...) + if let Some(ref dims) = array_dimensions { + if dims[0] == 0 { + return Err(ErrMode::from_error_kind(stream, ErrorKind::Verify)); + } + } + Ok(Variable { + align, + v_type: Type::maybe_array(vector, type_, array_dimensions), + state_space, + name, + array_init: Vec::new(), + }) + } +} + +// TODO: split to a separate type +fn variable_declaration<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult<(Option<u32>, Option<NonZeroU8>, ScalarType, &'input str)> { + ( + opt(align.verify(|x| x.count_ones() == 1)), + vector_prefix, + scalar_type, + ident, + ) + .parse_next(stream) +} + +fn multi_variable<'a, 'input: 'a>( + extern_: bool, + state_space: StateSpace, +) -> impl Parser<PtxParser<'a, 'input>, MultiVariable<&'input str>, ContextError> { + move |stream: &mut PtxParser<'a, 'input>| { + let ((align, vector, type_, name), count) = ( + variable_declaration, + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameterized-variable-names + opt(delimited(Token::Lt, u32.verify(|x| *x != 0), Token::Gt)), + ) + .parse_next(stream)?; + if count.is_some() { + return Ok(MultiVariable { + var: Variable { + align, + v_type: Type::maybe_vector_parsed(vector, type_), + state_space, + name, + array_init: Vec::new(), + }, + count, + }); + } + let mut array_dimensions = if state_space != StateSpace::Reg { + opt(array_dimensions).parse_next(stream)? + } else { + None + }; + let initializer = match state_space { + StateSpace::Global | StateSpace::Const => match array_dimensions { + Some(ref mut dimensions) => { + opt(array_initializer(vector, type_, dimensions)).parse_next(stream)? + } + None => opt(value_initializer(vector, type_)).parse_next(stream)?, + }, + _ => None, + }; + if let Some(ref dims) = array_dimensions { + if !extern_ && dims[0] == 0 { + return Err(ErrMode::from_error_kind(stream, ErrorKind::Verify)); + } + } + Ok(MultiVariable { + var: Variable { + align, + v_type: Type::maybe_array(vector, type_, array_dimensions), + state_space, + name, + array_init: initializer.unwrap_or(Vec::new()), + }, + count, + }) + } +} + +fn array_initializer<'a, 'input: 'a>( + vector: Option<NonZeroU8>, + type_: ScalarType, + array_dimensions: &mut Vec<u32>, +) -> impl Parser<PtxParser<'a, 'input>, Vec<u8>, ContextError> + '_ { + move |stream: &mut PtxParser<'a, 'input>| { + Token::Eq.parse_next(stream)?; + let mut result = Vec::new(); + // TODO: vector constants and multi dim arrays + if vector.is_some() || array_dimensions[0] == 0 || array_dimensions.len() > 1 { + return Err(ErrMode::from_error_kind(stream, ErrorKind::Verify)); + } + delimited( + Token::LBrace, + separated( + 0..=array_dimensions[0] as usize, + single_value_append(&mut result, type_), + Token::Comma, + ), + Token::RBrace, + ) + .parse_next(stream)?; + // pad with zeros + let result_size = type_.size_of() as usize * array_dimensions[0] as usize; + result.extend(iter::repeat(0u8).take(result_size - result.len())); + Ok(result) + } +} + +fn value_initializer<'a, 'input: 'a>( + vector: Option<NonZeroU8>, + type_: ScalarType, +) -> impl Parser<PtxParser<'a, 'input>, Vec<u8>, ContextError> { + move |stream: &mut PtxParser<'a, 'input>| { + Token::Eq.parse_next(stream)?; + let mut result = Vec::new(); + // TODO: vector constants + if vector.is_some() { + return Err(ErrMode::from_error_kind(stream, ErrorKind::Verify)); + } + single_value_append(&mut result, type_).parse_next(stream)?; + Ok(result) + } +} + +fn single_value_append<'a, 'input: 'a>( + accumulator: &mut Vec<u8>, + type_: ScalarType, +) -> impl Parser<PtxParser<'a, 'input>, (), ContextError> + '_ { + move |stream: &mut PtxParser<'a, 'input>| { + let value = immediate_value.parse_next(stream)?; + match (type_, value) { + (ScalarType::U8 | ScalarType::B8, ImmediateValue::U64(x)) => { + accumulator.extend_from_slice(&(x as u8).to_le_bytes()) + } + (ScalarType::U8 | ScalarType::B8, ImmediateValue::S64(x)) => { + accumulator.extend_from_slice(&(x as u8).to_le_bytes()) + } + (ScalarType::U16 | ScalarType::B16, ImmediateValue::U64(x)) => { + accumulator.extend_from_slice(&(x as u16).to_le_bytes()) + } + (ScalarType::U16 | ScalarType::B16, ImmediateValue::S64(x)) => { + accumulator.extend_from_slice(&(x as u16).to_le_bytes()) + } + (ScalarType::U32 | ScalarType::B32, ImmediateValue::U64(x)) => { + accumulator.extend_from_slice(&(x as u32).to_le_bytes()) + } + (ScalarType::U32 | ScalarType::B32, ImmediateValue::S64(x)) => { + accumulator.extend_from_slice(&(x as u32).to_le_bytes()) + } + (ScalarType::U64 | ScalarType::B64, ImmediateValue::U64(x)) => { + accumulator.extend_from_slice(&(x as u64).to_le_bytes()) + } + (ScalarType::U64 | ScalarType::B64, ImmediateValue::S64(x)) => { + accumulator.extend_from_slice(&(x as u64).to_le_bytes()) + } + (ScalarType::S8, ImmediateValue::U64(x)) => { + accumulator.extend_from_slice(&(x as i8).to_le_bytes()) + } + (ScalarType::S8, ImmediateValue::S64(x)) => { + accumulator.extend_from_slice(&(x as i8).to_le_bytes()) + } + (ScalarType::S16, ImmediateValue::U64(x)) => { + accumulator.extend_from_slice(&(x as i16).to_le_bytes()) + } + (ScalarType::S16, ImmediateValue::S64(x)) => { + accumulator.extend_from_slice(&(x as i16).to_le_bytes()) + } + (ScalarType::S32, ImmediateValue::U64(x)) => { + accumulator.extend_from_slice(&(x as i32).to_le_bytes()) + } + (ScalarType::S32, ImmediateValue::S64(x)) => { + accumulator.extend_from_slice(&(x as i32).to_le_bytes()) + } + (ScalarType::S64, ImmediateValue::U64(x)) => { + accumulator.extend_from_slice(&(x as i64).to_le_bytes()) + } + (ScalarType::S64, ImmediateValue::S64(x)) => { + accumulator.extend_from_slice(&(x as i64).to_le_bytes()) + } + (ScalarType::F32, ImmediateValue::F32(x)) => { + accumulator.extend_from_slice(&x.to_le_bytes()) + } + (ScalarType::F64, ImmediateValue::F64(x)) => { + accumulator.extend_from_slice(&x.to_le_bytes()) + } + _ => return Err(ErrMode::from_error_kind(stream, ErrorKind::Verify)), + } + Ok(()) + } +} + +fn array_dimensions<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<Vec<u32>> { + let dimension = delimited( + Token::LBracket, + opt(u32).verify(|dim| *dim != Some(0)), + Token::RBracket, + ) + .parse_next(stream)?; + let result = vec![dimension.unwrap_or(0)]; + repeat_fold_0_or_more( + delimited( + Token::LBracket, + u32.verify(|dim| *dim != 0), + Token::RBracket, + ), + move || result, + |mut result: Vec<u32>, x| { + result.push(x); + result + }, + stream, + ) +} + +// Copied and fixed from Winnow sources (fold_repeat0_) +// Winnow Repeat::fold takes FnMut() -> Result to initalize accumulator, +// this really should be FnOnce() -> Result +fn repeat_fold_0_or_more<I, O, E, F, G, H, R>( + mut f: F, + init: H, + mut g: G, + input: &mut I, +) -> PResult<R, E> +where + I: Stream, + F: Parser<I, O, E>, + G: FnMut(R, O) -> R, + H: FnOnce() -> R, + E: ParserError<I>, +{ + use winnow::error::ErrMode; + let mut res = init(); + loop { + let start = input.checkpoint(); + match f.parse_next(input) { + Ok(o) => { + res = g(res, o); + } + Err(ErrMode::Backtrack(_)) => { + input.reset(&start); + return Ok(res); + } + Err(e) => { + return Err(e); + } + } + } +} + +fn global_space<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<StateSpace> { + alt(( + Token::DotGlobal.value(StateSpace::Global), + Token::DotConst.value(StateSpace::Const), + Token::DotShared.value(StateSpace::Shared), + )) + .parse_next(stream) +} + +fn method_space<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<StateSpace> { + alt(( + Token::DotReg.value(StateSpace::Reg), + Token::DotLocal.value(StateSpace::Local), + Token::DotParam.value(StateSpace::Param), + global_space, + )) + .parse_next(stream) +} + +fn align<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<u32> { + preceded(Token::DotAlign, u32).parse_next(stream) +} + +fn vector_prefix<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<Option<NonZeroU8>> { + opt(alt(( + Token::DotV2.value(unsafe { NonZeroU8::new_unchecked(2) }), + Token::DotV4.value(unsafe { NonZeroU8::new_unchecked(4) }), + Token::DotV8.value(unsafe { NonZeroU8::new_unchecked(8) }), + ))) + .parse_next(stream) +} + +fn scalar_type<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<ScalarType> { + any.verify_map(|t| { + Some(match t { + Token::DotS8 => ScalarType::S8, + Token::DotS16 => ScalarType::S16, + Token::DotS16x2 => ScalarType::S16x2, + Token::DotS32 => ScalarType::S32, + Token::DotS64 => ScalarType::S64, + Token::DotU8 => ScalarType::U8, + Token::DotU16 => ScalarType::U16, + Token::DotU16x2 => ScalarType::U16x2, + Token::DotU32 => ScalarType::U32, + Token::DotU64 => ScalarType::U64, + Token::DotB8 => ScalarType::B8, + Token::DotB16 => ScalarType::B16, + Token::DotB32 => ScalarType::B32, + Token::DotB64 => ScalarType::B64, + Token::DotB128 => ScalarType::B128, + Token::DotPred => ScalarType::Pred, + Token::DotF16 => ScalarType::F16, + Token::DotF16x2 => ScalarType::F16x2, + Token::DotF32 => ScalarType::F32, + Token::DotF64 => ScalarType::F64, + Token::DotBF16 => ScalarType::BF16, + Token::DotBF16x2 => ScalarType::BF16x2, + _ => return None, + }) + }) + .parse_next(stream) +} + +fn predicated_instruction<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult<ast::Statement<ParsedOperandStr<'input>>> { + (opt(pred_at), parse_instruction, Token::Semicolon) + .map(|(p, i, _)| ast::Statement::Instruction(p, i)) + .parse_next(stream) +} + +fn pred_at<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<ast::PredAt<&'input str>> { + (Token::At, opt(Token::Exclamation), ident) + .map(|(_, not, label)| ast::PredAt { + not: not.is_some(), + label, + }) + .parse_next(stream) +} + +fn label<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult<ast::Statement<ParsedOperandStr<'input>>> { + terminated(ident, Token::Colon) + .map(|l| ast::Statement::Label(l)) + .parse_next(stream) +} + +fn debug_directive<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<()> { + ( + Token::DotLoc, + u32, + u32, + u32, + opt(( + Token::Comma, + ident_literal("function_name"), + ident, + dispatch! { any; + Token::Comma => (ident_literal("inlined_at"), u32, u32, u32).void(), + Token::Plus => (u32, Token::Comma, ident_literal("inlined_at"), u32, u32, u32).void(), + _ => fail + }, + )), + ) + .void() + .parse_next(stream) +} + +fn block_statement<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult<ast::Statement<ParsedOperandStr<'input>>> { + delimited(Token::LBrace, repeat_without_none(statement), Token::RBrace) + .map(|s| ast::Statement::Block(s)) + .parse_next(stream) +} + +fn repeat_without_none<Input: Stream, Output, Error: ParserError<Input>>( + parser: impl Parser<Input, Option<Output>, Error>, +) -> impl Parser<Input, Vec<Output>, Error> { + repeat(0.., parser).fold(Vec::new, |mut acc: Vec<_>, item| { + if let Some(item) = item { + acc.push(item); + } + acc + }) +} + +fn ident_literal< + 'a, + 'input, + I: Stream<Token = Token<'input>> + StreamIsPartial, + E: ParserError<I>, +>( + s: &'input str, +) -> impl Parser<I, (), E> + 'input { + move |stream: &mut I| { + any.verify(|t| matches!(t, Token::Ident(text) if *text == s)) + .void() + .parse_next(stream) + } +} + +fn u8_literal<'a, 'input>(x: u8) -> impl Parser<PtxParser<'a, 'input>, (), ContextError> { + move |stream: &mut PtxParser| u8.verify(|t| *t == x).void().parse_next(stream) +} + +impl<Ident> ast::ParsedOperand<Ident> { + fn parse<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, + ) -> PResult<ast::ParsedOperand<&'input str>> { + use winnow::combinator::*; + use winnow::token::any; + fn vector_index<'input>(inp: &'input str) -> Result<u8, PtxError> { + match inp { + ".x" | ".r" => Ok(0), + ".y" | ".g" => Ok(1), + ".z" | ".b" => Ok(2), + ".w" | ".a" => Ok(3), + _ => Err(PtxError::WrongVectorElement), + } + } + fn ident_operands<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, + ) -> PResult<ast::ParsedOperand<&'input str>> { + let main_ident = ident.parse_next(stream)?; + alt(( + preceded(Token::Plus, s32) + .map(move |offset| ast::ParsedOperand::RegOffset(main_ident, offset)), + take_error(dot_ident.map(move |suffix| { + let vector_index = vector_index(suffix) + .map_err(move |e| (ast::ParsedOperand::VecMember(main_ident, 0), e))?; + Ok(ast::ParsedOperand::VecMember(main_ident, vector_index)) + })), + empty.value(ast::ParsedOperand::Reg(main_ident)), + )) + .parse_next(stream) + } + fn vector_operand<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, + ) -> PResult<Vec<&'input str>> { + let (_, r1, _, r2) = (Token::LBrace, ident, Token::Comma, ident).parse_next(stream)?; + // TODO: parse .v8 literals + dispatch! {any; + Token::RBrace => empty.map(|_| vec![r1, r2]), + Token::Comma => (ident, Token::Comma, ident, Token::RBrace).map(|(r3, _, r4, _)| vec![r1, r2, r3, r4]), + _ => fail + } + .parse_next(stream) + } + alt(( + ident_operands, + immediate_value.map(ast::ParsedOperand::Imm), + vector_operand.map(ast::ParsedOperand::VecPack), + )) + .parse_next(stream) + } +} + +#[derive(Debug, thiserror::Error)] +pub enum PtxError { + #[error("{source}")] + ParseInt { + #[from] + source: ParseIntError, + }, + #[error("{source}")] + ParseFloat { + #[from] + source: ParseFloatError, + }, + #[error("{source}")] + Lexer { + #[from] + source: TokenError, + }, + #[error("")] + Parser(ContextError), + #[error("")] + Todo, + #[error("")] + SyntaxError, + #[error("")] + NonF32Ftz, + #[error("")] + Unsupported32Bit, + #[error("")] + WrongType, + #[error("")] + UnknownFunction, + #[error("")] + MalformedCall, + #[error("")] + WrongArrayType, + #[error("")] + WrongVectorElement, + #[error("")] + MultiArrayVariable, + #[error("")] + ZeroDimensionArray, + #[error("")] + ArrayInitalizer, + #[error("")] + NonExternPointer, + #[error("{start}:{end}")] + UnrecognizedStatement { start: usize, end: usize }, + #[error("{start}:{end}")] + UnrecognizedDirective { start: usize, end: usize }, +} + +#[derive(Debug)] +struct ReverseStream<'a, T>(pub &'a [T]); + +impl<'i, T> Stream for ReverseStream<'i, T> +where + T: Clone + ::std::fmt::Debug, +{ + type Token = T; + type Slice = &'i [T]; + + type IterOffsets = + std::iter::Enumerate<std::iter::Cloned<std::iter::Rev<std::slice::Iter<'i, T>>>>; + + type Checkpoint = &'i [T]; + + fn iter_offsets(&self) -> Self::IterOffsets { + self.0.iter().rev().cloned().enumerate() + } + + fn eof_offset(&self) -> usize { + self.0.len() + } + + fn next_token(&mut self) -> Option<Self::Token> { + let (token, next) = self.0.split_last()?; + self.0 = next; + Some(token.clone()) + } + + fn offset_for<P>(&self, predicate: P) -> Option<usize> + where + P: Fn(Self::Token) -> bool, + { + self.0.iter().rev().position(|b| predicate(b.clone())) + } + + fn offset_at(&self, tokens: usize) -> Result<usize, winnow::error::Needed> { + if let Some(needed) = tokens + .checked_sub(self.0.len()) + .and_then(std::num::NonZeroUsize::new) + { + Err(winnow::error::Needed::Size(needed)) + } else { + Ok(tokens) + } + } + + fn next_slice(&mut self, offset: usize) -> Self::Slice { + let offset = self.0.len() - offset; + let (next, slice) = self.0.split_at(offset); + self.0 = next; + slice + } + + fn checkpoint(&self) -> Self::Checkpoint { + self.0 + } + + fn reset(&mut self, checkpoint: &Self::Checkpoint) { + self.0 = checkpoint; + } + + fn raw(&self) -> &dyn std::fmt::Debug { + self + } +} + +impl<'a, T> Offset<&'a [T]> for ReverseStream<'a, T> { + fn offset_from(&self, start: &&'a [T]) -> usize { + let fst = start.as_ptr(); + let snd = self.0.as_ptr(); + + debug_assert!( + snd <= fst, + "`Offset::offset_from({snd:?}, {fst:?})` only accepts slices of `self`" + ); + (fst as usize - snd as usize) / std::mem::size_of::<T>() + } +} + +impl<'a, T> StreamIsPartial for ReverseStream<'a, T> { + type PartialState = (); + + fn complete(&mut self) -> Self::PartialState {} + + fn restore_partial(&mut self, _state: Self::PartialState) {} + + fn is_partial_supported() -> bool { + false + } +} + +impl<'input, I: Stream<Token = Self> + StreamIsPartial, E: ParserError<I>> Parser<I, Self, E> + for Token<'input> +{ + fn parse_next(&mut self, input: &mut I) -> PResult<Self, E> { + any.verify(|t| t == self).parse_next(input) + } +} + +fn bra<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult<ast::Instruction<ParsedOperandStr<'input>>> { + preceded( + opt(Token::DotUni), + any.verify_map(|t| match t { + Token::Ident(ident) => Some(ast::Instruction::Bra { + arguments: BraArgs { src: ident }, + }), + _ => None, + }), + ) + .parse_next(stream) +} + +fn call<'a, 'input>( + stream: &mut PtxParser<'a, 'input>, +) -> PResult<ast::Instruction<ParsedOperandStr<'input>>> { + let (uni, return_arguments, name, input_arguments) = ( + opt(Token::DotUni), + opt(( + Token::LParen, + separated(1.., ident, Token::Comma).map(|x: Vec<_>| x), + Token::RParen, + Token::Comma, + ) + .map(|(_, arguments, _, _)| arguments)), + ident, + opt(( + Token::Comma.void(), + Token::LParen.void(), + separated(1.., ParsedOperand::<&'input str>::parse, Token::Comma).map(|x: Vec<_>| x), + Token::RParen.void(), + ) + .map(|(_, _, arguments, _)| arguments)), + ) + .parse_next(stream)?; + let uniform = uni.is_some(); + let recorded_fn = match stream.state.function_declarations.get(name) { + Some(decl) => decl, + None => { + stream.state.errors.push(PtxError::UnknownFunction); + return Ok(empty_call(uniform, name)); + } + }; + let return_arguments = return_arguments.unwrap_or(Vec::new()); + let input_arguments = input_arguments.unwrap_or(Vec::new()); + if recorded_fn.0.len() != return_arguments.len() || recorded_fn.1.len() != input_arguments.len() + { + stream.state.errors.push(PtxError::MalformedCall); + return Ok(empty_call(uniform, name)); + } + let data = CallDetails { + uniform, + return_arguments: recorded_fn.0.clone(), + input_arguments: recorded_fn.1.clone(), + }; + let arguments = CallArgs { + return_arguments, + func: name, + input_arguments, + }; + Ok(ast::Instruction::Call { data, arguments }) +} + +fn empty_call<'input>( + uniform: bool, + name: &'input str, +) -> ast::Instruction<ParsedOperandStr<'input>> { + ast::Instruction::Call { + data: CallDetails { + uniform, + return_arguments: Vec::new(), + input_arguments: Vec::new(), + }, + arguments: CallArgs { + return_arguments: Vec::new(), + func: name, + input_arguments: Vec::new(), + }, + } +} + +type ParsedOperandStr<'input> = ast::ParsedOperand<&'input str>; + +#[derive(Clone, PartialEq, Default, Debug, Display)] +#[display("({}:{})", _0.start, _0.end)] +pub struct TokenError(std::ops::Range<usize>); + +impl std::error::Error for TokenError {} + +// This macro is responsible for generating parser code for instruction parser. +// Instruction parsing is by far the most complex part of parsing PTX code: +// * There are tens of instruction kinds, each with slightly different parsing rules +// * After parsing, each instruction needs to do some early validation and generate a specific, +// strongly-typed object. We want strong-typing because we have a single PTX parser frontend, but +// there can be multiple different code emitter backends +// * Most importantly, instruction modifiers can come in aby order, so e.g. both +// `ld.relaxed.global.u32 a, b` and `ld.global.relaxed.u32 a, b` are equally valid. This makes +// classic parsing generators fail: if we tried to generate parsing rules that cover every possible +// ordering we'd need thousands of rules. This is not a purely theoretical problem. NVCC and Clang +// will always emit modifiers in the correct order, but people who write inline assembly usually +// get it wrong (even first party developers) +// +// This macro exists purely to generate repetitive code for parsing each instruction. It is +// _not_ self-contained and is _not_ general-purpose: it relies on certain types and functions from +// the enclosing module +// +// derive_parser!(...) input is split into three parts: +// * Token type definition +// * Partial enums +// * Parsing definitions +// +// Token type definition: +// This is the enum type that will be usesby the instruction parser. For every instruction and +// modifier, derive_parser!(...) will add appropriate variant into this type. So e.g. if there is a +// rule for for `bar.sync` then those two variants wil be appended to the Token enum: +// #[token("bar")] Bar, +// #[token(".sync")] DotSync, +// +// Partial enums: +// With proper annotations, derive_parser!(...) parsing definitions are able to interpret +// instruction modifiers as variants of a single enum type. So e.g. for definitions `ld.u32` and +// `ld.u64` the macro can generate `enum ScalarType { U32, U64 }`. The problem is that for some +// (but not all) of those generated enum types we want to add some attributes and additional +// variants. In order to do so, you need to define this enum and derive_parser!(...) will append to +// the type instead of creating a new type. This is sort of replacement for partial classes known +// from C# +// +// Parsing definitions: +// Parsing definitions consist of a list of patterns and rules: +// * Pattern consists of: +// * Opcode: `ld` +// * Modifiers, always start with a dot: `.global`, `.relaxed`. Optionals are enclosed in braces +// * Arguments: `a`, `b`. Optionals are enclosed in braces +// * Code block: => { <code expression> }. Code blocks implictly take all modifiers ansd arguments +// as parameters. All modifiers and arguments are passed to the code block: +// * If it is an alternative (as defined in rules list later): +// * If it is mandatory then its type is Foo (as defined by the relevant rule) +// * If it is optional then its type is Option<Foo> +// * Otherwise: +// * If it is mandatory then it is skipped +// * If it is optional then its type is `bool` +// * List of rules. They are associated with the preceding patterns (until different opcode or +// different rules). Rules are used to resolve modifiers. There are two types of rules: +// * Normal rule: `.foobar: FoobarEnum => { .a, .b, .c }`. This means that instead of `.foobar` we +// expecte one of `.a`, `.b`, `.c` and will emit value FoobarEnum::DotA, FoobarEnum::DotB, +// FoobarEnum::DotC appropriately +// * Type-only rule: `FoobarEnum => { .a, .b, .c }` this means that all the occurences of `.a` will +// emit FoobarEnum::DotA to the code block. This helps to avoid copy-paste errors +// Additionally, you can opt out from the usual parsing rule generation with a special `<=` pattern. +// See `call` instruction to see it in action +derive_parser!( + #[derive(Logos, PartialEq, Eq, Debug, Clone, Copy)] + #[logos(skip r"(?:\s+)|(?://[^\n\r]*[\n\r]*)|(?:/\*[^*]*\*+(?:[^/*][^*]*\*+)*/)")] + #[logos(error = TokenError)] + enum Token<'input> { + #[token(",")] + Comma, + #[token(".")] + Dot, + #[token(":")] + Colon, + #[token(";")] + Semicolon, + #[token("@")] + At, + #[regex(r"[a-zA-Z][a-zA-Z0-9_$]*|[_$%][a-zA-Z0-9_$]+", |lex| lex.slice(), priority = 0)] + Ident(&'input str), + #[regex(r"\.[a-zA-Z][a-zA-Z0-9_$]*|\.[_$%][a-zA-Z0-9_$]+", |lex| lex.slice(), priority = 0)] + DotIdent(&'input str), + #[regex(r#""[^"]*""#)] + String, + #[token("|")] + Pipe, + #[token("!")] + Exclamation, + #[token("(")] + LParen, + #[token(")")] + RParen, + #[token("[")] + LBracket, + #[token("]")] + RBracket, + #[token("{")] + LBrace, + #[token("}")] + RBrace, + #[token("<")] + Lt, + #[token(">")] + Gt, + #[regex(r"0[fF][0-9a-zA-Z]{8}", |lex| lex.slice())] + F32(&'input str), + #[regex(r"0[dD][0-9a-zA-Z]{16}", |lex| lex.slice())] + F64(&'input str), + #[regex(r"0[xX][0-9a-zA-Z]+U?", |lex| lex.slice())] + Hex(&'input str), + #[regex(r"[0-9]+U?", |lex| lex.slice())] + Decimal(&'input str), + #[token("-")] + Minus, + #[token("+")] + Plus, + #[token("=")] + Eq, + #[token(".version")] + DotVersion, + #[token(".loc")] + DotLoc, + #[token(".reg")] + DotReg, + #[token(".align")] + DotAlign, + #[token(".pragma")] + DotPragma, + #[token(".maxnreg")] + DotMaxnreg, + #[token(".maxntid")] + DotMaxntid, + #[token(".reqntid")] + DotReqntid, + #[token(".minnctapersm")] + DotMinnctapersm, + #[token(".entry")] + DotEntry, + #[token(".func")] + DotFunc, + #[token(".extern")] + DotExtern, + #[token(".visible")] + DotVisible, + #[token(".target")] + DotTarget, + #[token(".address_size")] + DotAddressSize, + #[token(".action")] + DotSection, + #[token(".file")] + DotFile + } + + #[derive(Copy, Clone, PartialEq, Eq, Hash)] + pub enum StateSpace { + Reg, + Generic, + Sreg, + } + + #[derive(Copy, Clone, PartialEq, Eq, Hash)] + pub enum MemScope { } + + #[derive(Copy, Clone, PartialEq, Eq, Hash)] + pub enum ScalarType { } + + #[derive(Copy, Clone, PartialEq, Eq, Hash)] + pub enum SetpBoolPostOp { } + + #[derive(Copy, Clone, PartialEq, Eq, Hash)] + pub enum AtomSemantics { } + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov + mov{.vec}.type d, a => { + Instruction::Mov { + data: ast::MovDetails::new(vec, type_), + arguments: MovArgs { dst: d, src: a }, + } + } + .vec: VectorPrefix = { .v2, .v4 }; + .type: ScalarType = { .pred, + .b16, .b32, .b64, + .u16, .u32, .u64, + .s16, .s32, .s64, + .f32, .f64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-st + st{.weak}{.ss}{.cop}{.level::eviction_priority}{.level::cache_hint}{.vec}.type [a], b{, cache_policy} => { + if level_eviction_priority.is_some() || level_cache_hint || cache_policy.is_some() { + state.errors.push(PtxError::Todo); + } + Instruction::St { + data: StData { + qualifier: weak.unwrap_or(RawLdStQualifier::Weak).into(), + state_space: ss.unwrap_or(StateSpace::Generic), + caching: cop.unwrap_or(RawStCacheOperator::Wb).into(), + typ: ast::Type::maybe_vector(vec, type_) + }, + arguments: StArgs { src1:a, src2:b } + } + } + st.volatile{.ss}{.vec}.type [a], b => { + Instruction::St { + data: StData { + qualifier: volatile.into(), + state_space: ss.unwrap_or(StateSpace::Generic), + caching: ast::StCacheOperator::Writeback, + typ: ast::Type::maybe_vector(vec, type_) + }, + arguments: StArgs { src1:a, src2:b } + } + } + st.relaxed.scope{.ss}{.level::eviction_priority}{.level::cache_hint}{.vec}.type [a], b{, cache_policy} => { + if level_eviction_priority.is_some() || level_cache_hint || cache_policy.is_some() { + state.errors.push(PtxError::Todo); + } + Instruction::St { + data: StData { + qualifier: ast::LdStQualifier::Relaxed(scope), + state_space: ss.unwrap_or(StateSpace::Generic), + caching: ast::StCacheOperator::Writeback, + typ: ast::Type::maybe_vector(vec, type_) + }, + arguments: StArgs { src1:a, src2:b } + } + } + st.release.scope{.ss}{.level::eviction_priority}{.level::cache_hint}{.vec}.type [a], b{, cache_policy} => { + if level_eviction_priority.is_some() || level_cache_hint || cache_policy.is_some() { + state.errors.push(PtxError::Todo); + } + Instruction::St { + data: StData { + qualifier: ast::LdStQualifier::Release(scope), + state_space: ss.unwrap_or(StateSpace::Generic), + caching: ast::StCacheOperator::Writeback, + typ: ast::Type::maybe_vector(vec, type_) + }, + arguments: StArgs { src1:a, src2:b } + } + } + st.mmio.relaxed.sys{.global}.type [a], b => { + state.errors.push(PtxError::Todo); + Instruction::St { + data: ast::StData { + qualifier: ast::LdStQualifier::Relaxed(MemScope::Sys), + state_space: global.unwrap_or(StateSpace::Generic), + caching: ast::StCacheOperator::Writeback, + typ: type_.into() + }, + arguments: ast::StArgs { src1:a, src2:b } + } + } + .ss: StateSpace = { .global, .local, .param{::func}, .shared{::cta, ::cluster} }; + .level::eviction_priority: EvictionPriority = + { .L1::evict_normal, .L1::evict_unchanged, .L1::evict_first, .L1::evict_last, .L1::no_allocate }; + .level::cache_hint = { .L2::cache_hint }; + .cop: RawStCacheOperator = { .wb, .cg, .cs, .wt }; + .scope: MemScope = { .cta, .cluster, .gpu, .sys }; + .vec: VectorPrefix = { .v2, .v4 }; + .type: ScalarType = { .b8, .b16, .b32, .b64, .b128, + .u8, .u16, .u32, .u64, + .s8, .s16, .s32, .s64, + .f32, .f64 }; + RawLdStQualifier = { .weak, .volatile }; + StateSpace = { .global }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-ld + ld{.weak}{.ss}{.cop}{.level::eviction_priority}{.level::cache_hint}{.level::prefetch_size}{.vec}.type d, [a]{.unified}{, cache_policy} => { + let (a, unified) = a; + if level_eviction_priority.is_some() || level_cache_hint || level_prefetch_size.is_some() || unified || cache_policy.is_some() { + state.errors.push(PtxError::Todo); + } + Instruction::Ld { + data: LdDetails { + qualifier: weak.unwrap_or(RawLdStQualifier::Weak).into(), + state_space: ss.unwrap_or(StateSpace::Generic), + caching: cop.unwrap_or(RawLdCacheOperator::Ca).into(), + typ: ast::Type::maybe_vector(vec, type_), + non_coherent: false + }, + arguments: LdArgs { dst:d, src:a } + } + } + ld.volatile{.ss}{.level::prefetch_size}{.vec}.type d, [a] => { + if level_prefetch_size.is_some() { + state.errors.push(PtxError::Todo); + } + Instruction::Ld { + data: LdDetails { + qualifier: volatile.into(), + state_space: ss.unwrap_or(StateSpace::Generic), + caching: ast::LdCacheOperator::Cached, + typ: ast::Type::maybe_vector(vec, type_), + non_coherent: false + }, + arguments: LdArgs { dst:d, src:a } + } + } + ld.relaxed.scope{.ss}{.level::eviction_priority}{.level::cache_hint}{.level::prefetch_size}{.vec}.type d, [a]{, cache_policy} => { + if level_eviction_priority.is_some() || level_cache_hint || level_prefetch_size.is_some() || cache_policy.is_some() { + state.errors.push(PtxError::Todo); + } + Instruction::Ld { + data: LdDetails { + qualifier: ast::LdStQualifier::Relaxed(scope), + state_space: ss.unwrap_or(StateSpace::Generic), + caching: ast::LdCacheOperator::Cached, + typ: ast::Type::maybe_vector(vec, type_), + non_coherent: false + }, + arguments: LdArgs { dst:d, src:a } + } + } + ld.acquire.scope{.ss}{.level::eviction_priority}{.level::cache_hint}{.level::prefetch_size}{.vec}.type d, [a]{, cache_policy} => { + if level_eviction_priority.is_some() || level_cache_hint || level_prefetch_size.is_some() || cache_policy.is_some() { + state.errors.push(PtxError::Todo); + } + Instruction::Ld { + data: LdDetails { + qualifier: ast::LdStQualifier::Acquire(scope), + state_space: ss.unwrap_or(StateSpace::Generic), + caching: ast::LdCacheOperator::Cached, + typ: ast::Type::maybe_vector(vec, type_), + non_coherent: false + }, + arguments: LdArgs { dst:d, src:a } + } + } + ld.mmio.relaxed.sys{.global}.type d, [a] => { + state.errors.push(PtxError::Todo); + Instruction::Ld { + data: LdDetails { + qualifier: ast::LdStQualifier::Relaxed(MemScope::Sys), + state_space: global.unwrap_or(StateSpace::Generic), + caching: ast::LdCacheOperator::Cached, + typ: type_.into(), + non_coherent: false + }, + arguments: LdArgs { dst:d, src:a } + } + } + .ss: StateSpace = { .const, .global, .local, .param{::entry, ::func}, .shared{::cta, ::cluster} }; + .cop: RawLdCacheOperator = { .ca, .cg, .cs, .lu, .cv }; + .level::eviction_priority: EvictionPriority = + { .L1::evict_normal, .L1::evict_unchanged, .L1::evict_first, .L1::evict_last, .L1::no_allocate }; + .level::cache_hint = { .L2::cache_hint }; + .level::prefetch_size: PrefetchSize = { .L2::64B, .L2::128B, .L2::256B }; + .scope: MemScope = { .cta, .cluster, .gpu, .sys }; + .vec: VectorPrefix = { .v2, .v4 }; + .type: ScalarType = { .b8, .b16, .b32, .b64, .b128, + .u8, .u16, .u32, .u64, + .s8, .s16, .s32, .s64, + .f32, .f64 }; + RawLdStQualifier = { .weak, .volatile }; + StateSpace = { .global }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld-global-nc + ld.global{.cop}.nc{.level::eviction_priority}{.level::cache_hint}{.level::prefetch_size}{.vec}.type d, [a]{, cache_policy} => { + if cop.is_some() && level_eviction_priority.is_some() { + state.errors.push(PtxError::SyntaxError); + } + if level_eviction_priority.is_some() || level_cache_hint || level_prefetch_size.is_some() || cache_policy.is_some() { + state.errors.push(PtxError::Todo); + } + Instruction::Ld { + data: LdDetails { + qualifier: ast::LdStQualifier::Weak, + state_space: global, + caching: cop.unwrap_or(RawLdCacheOperator::Ca).into(), + typ: Type::maybe_vector(vec, type_), + non_coherent: true + }, + arguments: LdArgs { dst:d, src:a } + } + } + .cop: RawLdCacheOperator = { .ca, .cg, .cs }; + .level::eviction_priority: EvictionPriority = + { .L1::evict_normal, .L1::evict_unchanged, + .L1::evict_first, .L1::evict_last, .L1::no_allocate}; + .level::cache_hint = { .L2::cache_hint }; + .level::prefetch_size: PrefetchSize = { .L2::64B, .L2::128B, .L2::256B }; + .vec: VectorPrefix = { .v2, .v4 }; + .type: ScalarType = { .b8, .b16, .b32, .b64, .b128, + .u8, .u16, .u32, .u64, + .s8, .s16, .s32, .s64, + .f32, .f64 }; + StateSpace = { .global }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/#integer-arithmetic-instructions-add + add.type d, a, b => { + Instruction::Add { + data: ast::ArithDetails::Integer( + ast::ArithInteger { + type_, + saturate: false + } + ), + arguments: AddArgs { + dst: d, src1: a, src2: b + } + } + } + add{.sat}.s32 d, a, b => { + Instruction::Add { + data: ast::ArithDetails::Integer( + ast::ArithInteger { + type_: s32, + saturate: sat + } + ), + arguments: AddArgs { + dst: d, src1: a, src2: b + } + } + } + .type: ScalarType = { .u16, .u32, .u64, + .s16, .s64, + .u16x2, .s16x2 }; + ScalarType = { .s32 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/#floating-point-instructions-add + add{.rnd}{.ftz}{.sat}.f32 d, a, b => { + Instruction::Add { + data: ast::ArithDetails::Float( + ast::ArithFloat { + type_: f32, + rounding: rnd.map(Into::into), + flush_to_zero: Some(ftz), + saturate: sat + } + ), + arguments: AddArgs { + dst: d, src1: a, src2: b + } + } + } + add{.rnd}.f64 d, a, b => { + Instruction::Add { + data: ast::ArithDetails::Float( + ast::ArithFloat { + type_: f64, + rounding: rnd.map(Into::into), + flush_to_zero: None, + saturate: false + } + ), + arguments: AddArgs { + dst: d, src1: a, src2: b + } + } + } + .rnd: RawRoundingMode = { .rn, .rz, .rm, .rp }; + ScalarType = { .f32, .f64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/#half-precision-floating-point-instructions-add + add{.rnd}{.ftz}{.sat}.f16 d, a, b => { + Instruction::Add { + data: ast::ArithDetails::Float( + ast::ArithFloat { + type_: f16, + rounding: rnd.map(Into::into), + flush_to_zero: Some(ftz), + saturate: sat + } + ), + arguments: AddArgs { + dst: d, src1: a, src2: b + } + } + } + add{.rnd}{.ftz}{.sat}.f16x2 d, a, b => { + Instruction::Add { + data: ast::ArithDetails::Float( + ast::ArithFloat { + type_: f16x2, + rounding: rnd.map(Into::into), + flush_to_zero: Some(ftz), + saturate: sat + } + ), + arguments: AddArgs { + dst: d, src1: a, src2: b + } + } + } + add{.rnd}.bf16 d, a, b => { + Instruction::Add { + data: ast::ArithDetails::Float( + ast::ArithFloat { + type_: bf16, + rounding: rnd.map(Into::into), + flush_to_zero: None, + saturate: false + } + ), + arguments: AddArgs { + dst: d, src1: a, src2: b + } + } + } + add{.rnd}.bf16x2 d, a, b => { + Instruction::Add { + data: ast::ArithDetails::Float( + ast::ArithFloat { + type_: bf16x2, + rounding: rnd.map(Into::into), + flush_to_zero: None, + saturate: false + } + ), + arguments: AddArgs { + dst: d, src1: a, src2: b + } + } + } + .rnd: RawRoundingMode = { .rn }; + ScalarType = { .f16, .f16x2, .bf16, .bf16x2 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-mul + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-mul + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-mul + mul.mode.type d, a, b => { + ast::Instruction::Mul { + data: ast::MulDetails::Integer { + type_, + control: mode.into() + }, + arguments: MulArgs { dst: d, src1: a, src2: b } + } + } + .mode: RawMulIntControl = { .hi, .lo }; + .type: ScalarType = { .u16, .u32, .u64, + .s16, .s32, .s64 }; + // "The .wide suffix is supported only for 16- and 32-bit integer types" + mul.wide.type d, a, b => { + ast::Instruction::Mul { + data: ast::MulDetails::Integer { + type_, + control: wide.into() + }, + arguments: MulArgs { dst: d, src1: a, src2: b } + } + } + .type: ScalarType = { .u16, .u32, + .s16, .s32 }; + RawMulIntControl = { .wide }; + + mul{.rnd}{.ftz}{.sat}.f32 d, a, b => { + ast::Instruction::Mul { + data: ast::MulDetails::Float ( + ast::ArithFloat { + type_: f32, + rounding: rnd.map(Into::into), + flush_to_zero: Some(ftz), + saturate: sat, + } + ), + arguments: MulArgs { dst: d, src1: a, src2: b } + } + } + mul{.rnd}.f64 d, a, b => { + ast::Instruction::Mul { + data: ast::MulDetails::Float ( + ast::ArithFloat { + type_: f64, + rounding: rnd.map(Into::into), + flush_to_zero: None, + saturate: false, + } + ), + arguments: MulArgs { dst: d, src1: a, src2: b } + } + } + .rnd: RawRoundingMode = { .rn, .rz, .rm, .rp }; + ScalarType = { .f32, .f64 }; + + mul{.rnd}{.ftz}{.sat}.f16 d, a, b => { + ast::Instruction::Mul { + data: ast::MulDetails::Float ( + ast::ArithFloat { + type_: f16, + rounding: rnd.map(Into::into), + flush_to_zero: Some(ftz), + saturate: sat, + } + ), + arguments: MulArgs { dst: d, src1: a, src2: b } + } + } + mul{.rnd}{.ftz}{.sat}.f16x2 d, a, b => { + ast::Instruction::Mul { + data: ast::MulDetails::Float ( + ast::ArithFloat { + type_: f16x2, + rounding: rnd.map(Into::into), + flush_to_zero: Some(ftz), + saturate: sat, + } + ), + arguments: MulArgs { dst: d, src1: a, src2: b } + } + } + mul{.rnd}.bf16 d, a, b => { + ast::Instruction::Mul { + data: ast::MulDetails::Float ( + ast::ArithFloat { + type_: bf16, + rounding: rnd.map(Into::into), + flush_to_zero: None, + saturate: false, + } + ), + arguments: MulArgs { dst: d, src1: a, src2: b } + } + } + mul{.rnd}.bf16x2 d, a, b => { + ast::Instruction::Mul { + data: ast::MulDetails::Float ( + ast::ArithFloat { + type_: bf16x2, + rounding: rnd.map(Into::into), + flush_to_zero: None, + saturate: false, + } + ), + arguments: MulArgs { dst: d, src1: a, src2: b } + } + } + .rnd: RawRoundingMode = { .rn }; + ScalarType = { .f16, .f16x2, .bf16, .bf16x2 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#comparison-and-selection-instructions-setp + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-comparison-instructions-setp + setp.CmpOp{.ftz}.type p[|q], a, b => { + let data = ast::SetpData::try_parse(state, cmpop, ftz, type_); + ast::Instruction::Setp { + data, + arguments: SetpArgs { dst1: p, dst2: q, src1: a, src2: b } + } + } + setp.CmpOp.BoolOp{.ftz}.type p[|q], a, b, {!}c => { + let (negate_src3, c) = c; + let base = ast::SetpData::try_parse(state, cmpop, ftz, type_); + let data = ast::SetpBoolData { + base, + bool_op: boolop, + negate_src3 + }; + ast::Instruction::SetpBool { + data, + arguments: SetpBoolArgs { dst1: p, dst2: q, src1: a, src2: b, src3: c } + } + } + .CmpOp: RawSetpCompareOp = { .eq, .ne, .lt, .le, .gt, .ge, + .lo, .ls, .hi, .hs, // signed + .equ, .neu, .ltu, .leu, .gtu, .geu, .num, .nan }; // float-only + .BoolOp: SetpBoolPostOp = { .and, .or, .xor }; + .type: ScalarType = { .b16, .b32, .b64, + .u16, .u32, .u64, + .s16, .s32, .s64, + .f32, .f64, + .f16, .f16x2, .bf16, .bf16x2 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-not + not.type d, a => { + ast::Instruction::Not { + data: type_, + arguments: NotArgs { dst: d, src: a } + } + } + .type: ScalarType = { .pred, .b16, .b32, .b64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-or + or.type d, a, b => { + ast::Instruction::Or { + data: type_, + arguments: OrArgs { dst: d, src1: a, src2: b } + } + } + .type: ScalarType = { .pred, .b16, .b32, .b64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-and + and.type d, a, b => { + ast::Instruction::And { + data: type_, + arguments: AndArgs { dst: d, src1: a, src2: b } + } + } + .type: ScalarType = { .pred, .b16, .b32, .b64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-bra + bra <= { bra(stream) } + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-call + call <= { call(stream) } + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt + cvt{.ifrnd}{.ftz}{.sat}.dtype.atype d, a => { + let data = ast::CvtDetails::new(&mut state.errors, ifrnd, ftz, sat, dtype, atype); + let arguments = ast::CvtArgs { dst: d, src: a }; + ast::Instruction::Cvt { + data, arguments + } + } + // cvt.frnd2{.relu}{.satfinite}.f16.f32 d, a; + // cvt.frnd2{.relu}{.satfinite}.f16x2.f32 d, a, b; + // cvt.frnd2{.relu}{.satfinite}.bf16.f32 d, a; + // cvt.frnd2{.relu}{.satfinite}.bf16x2.f32 d, a, b; + // cvt.rna{.satfinite}.tf32.f32 d, a; + // cvt.frnd2{.relu}.tf32.f32 d, a; + // cvt.rn.satfinite{.relu}.f8x2type.f32 d, a, b; + // cvt.rn.satfinite{.relu}.f8x2type.f16x2 d, a; + // cvt.rn.{.relu}.f16x2.f8x2type d, a; + + .ifrnd: RawRoundingMode = { .rn, .rz, .rm, .rp, .rni, .rzi, .rmi, .rpi }; + .frnd2: RawRoundingMode = { .rn, .rz }; + .dtype: ScalarType = { .u8, .u16, .u32, .u64, + .s8, .s16, .s32, .s64, + .bf16, .f16, .f32, .f64 }; + .atype: ScalarType = { .u8, .u16, .u32, .u64, + .s8, .s16, .s32, .s64, + .bf16, .f16, .f32, .f64 }; + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-shl + shl.type d, a, b => { + ast::Instruction::Shl { data: type_, arguments: ShlArgs { dst: d, src1: a, src2: b } } + } + .type: ScalarType = { .b16, .b32, .b64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-shr + shr.type d, a, b => { + let kind = if type_.kind() == ast::ScalarKind::Signed { RightShiftKind::Arithmetic} else { RightShiftKind::Logical }; + ast::Instruction::Shr { + data: ast::ShrData { type_, kind }, + arguments: ShrArgs { dst: d, src1: a, src2: b } + } + } + .type: ScalarType = { .b16, .b32, .b64, + .u16, .u32, .u64, + .s16, .s32, .s64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvta + cvta.space.size p, a => { + if size != ScalarType::U64 { + state.errors.push(PtxError::Unsupported32Bit); + } + let data = ast::CvtaDetails { + state_space: space, + direction: ast::CvtaDirection::ExplicitToGeneric + }; + let arguments = ast::CvtaArgs { + dst: p, src: a + }; + ast::Instruction::Cvta { + data, arguments + } + } + cvta.to.space.size p, a => { + if size != ScalarType::U64 { + state.errors.push(PtxError::Unsupported32Bit); + } + let data = ast::CvtaDetails { + state_space: space, + direction: ast::CvtaDirection::GenericToExplicit + }; + let arguments = ast::CvtaArgs { + dst: p, src: a + }; + ast::Instruction::Cvta { + data, arguments + } + } + .space: StateSpace = { .const, .global, .local, .shared{::cta, ::cluster}, .param{::entry} }; + .size: ScalarType = { .u32, .u64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-abs + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-abs + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-abs + abs.type d, a => { + ast::Instruction::Abs { + data: ast::TypeFtz { + flush_to_zero: None, + type_ + }, + arguments: ast::AbsArgs { + dst: d, src: a + } + } + } + abs{.ftz}.f32 d, a => { + ast::Instruction::Abs { + data: ast::TypeFtz { + flush_to_zero: Some(ftz), + type_: f32 + }, + arguments: ast::AbsArgs { + dst: d, src: a + } + } + } + abs.f64 d, a => { + ast::Instruction::Abs { + data: ast::TypeFtz { + flush_to_zero: None, + type_: f64 + }, + arguments: ast::AbsArgs { + dst: d, src: a + } + } + } + abs{.ftz}.f16 d, a => { + ast::Instruction::Abs { + data: ast::TypeFtz { + flush_to_zero: Some(ftz), + type_: f16 + }, + arguments: ast::AbsArgs { + dst: d, src: a + } + } + } + abs{.ftz}.f16x2 d, a => { + ast::Instruction::Abs { + data: ast::TypeFtz { + flush_to_zero: Some(ftz), + type_: f16x2 + }, + arguments: ast::AbsArgs { + dst: d, src: a + } + } + } + abs.bf16 d, a => { + ast::Instruction::Abs { + data: ast::TypeFtz { + flush_to_zero: None, + type_: bf16 + }, + arguments: ast::AbsArgs { + dst: d, src: a + } + } + } + abs.bf16x2 d, a => { + ast::Instruction::Abs { + data: ast::TypeFtz { + flush_to_zero: None, + type_: bf16x2 + }, + arguments: ast::AbsArgs { + dst: d, src: a + } + } + } + .type: ScalarType = { .s16, .s32, .s64 }; + ScalarType = { .f32, .f64, .f16, .f16x2, .bf16, .bf16x2 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-mad + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-mad + mad.mode.type d, a, b, c => { + ast::Instruction::Mad { + data: ast::MadDetails::Integer { + type_, + control: mode.into(), + saturate: false + }, + arguments: MadArgs { dst: d, src1: a, src2: b, src3: c } + } + } + .type: ScalarType = { .u16, .u32, .u64, + .s16, .s32, .s64 }; + .mode: RawMulIntControl = { .hi, .lo }; + + // The .wide suffix is supported only for 16-bit and 32-bit integer types. + mad.wide.type d, a, b, c => { + ast::Instruction::Mad { + data: ast::MadDetails::Integer { + type_, + control: wide.into(), + saturate: false + }, + arguments: MadArgs { dst: d, src1: a, src2: b, src3: c } + } + } + .type: ScalarType = { .u16, .u32, + .s16, .s32 }; + RawMulIntControl = { .wide }; + + mad.hi.sat.s32 d, a, b, c => { + ast::Instruction::Mad { + data: ast::MadDetails::Integer { + type_: s32, + control: hi.into(), + saturate: true + }, + arguments: MadArgs { dst: d, src1: a, src2: b, src3: c } + } + } + RawMulIntControl = { .hi }; + ScalarType = { .s32 }; + + mad{.ftz}{.sat}.f32 d, a, b, c => { + ast::Instruction::Mad { + data: ast::MadDetails::Float( + ast::ArithFloat { + type_: f32, + rounding: None, + flush_to_zero: Some(ftz), + saturate: sat + } + ), + arguments: MadArgs { dst: d, src1: a, src2: b, src3: c } + } + } + mad.rnd{.ftz}{.sat}.f32 d, a, b, c => { + ast::Instruction::Mad { + data: ast::MadDetails::Float( + ast::ArithFloat { + type_: f32, + rounding: Some(rnd.into()), + flush_to_zero: Some(ftz), + saturate: sat + } + ), + arguments: MadArgs { dst: d, src1: a, src2: b, src3: c } + } + } + mad.rnd.f64 d, a, b, c => { + ast::Instruction::Mad { + data: ast::MadDetails::Float( + ast::ArithFloat { + type_: f64, + rounding: Some(rnd.into()), + flush_to_zero: None, + saturate: false + } + ), + arguments: MadArgs { dst: d, src1: a, src2: b, src3: c } + } + } + .rnd: RawRoundingMode = { .rn, .rz, .rm, .rp }; + ScalarType = { .f32, .f64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-fma + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-fma + fma.rnd{.ftz}{.sat}.f32 d, a, b, c => { + ast::Instruction::Fma { + data: ast::ArithFloat { + type_: f32, + rounding: Some(rnd.into()), + flush_to_zero: Some(ftz), + saturate: sat + }, + arguments: FmaArgs { dst: d, src1: a, src2: b, src3: c } + } + } + fma.rnd.f64 d, a, b, c => { + ast::Instruction::Fma { + data: ast::ArithFloat { + type_: f64, + rounding: Some(rnd.into()), + flush_to_zero: None, + saturate: false + }, + arguments: FmaArgs { dst: d, src1: a, src2: b, src3: c } + } + } + .rnd: RawRoundingMode = { .rn, .rz, .rm, .rp }; + ScalarType = { .f32, .f64 }; + + fma.rnd{.ftz}{.sat}.f16 d, a, b, c => { + ast::Instruction::Fma { + data: ast::ArithFloat { + type_: f16, + rounding: Some(rnd.into()), + flush_to_zero: Some(ftz), + saturate: sat + }, + arguments: FmaArgs { dst: d, src1: a, src2: b, src3: c } + } + } + //fma.rnd{.ftz}{.sat}.f16x2 d, a, b, c; + //fma.rnd{.ftz}.relu.f16 d, a, b, c; + //fma.rnd{.ftz}.relu.f16x2 d, a, b, c; + //fma.rnd{.relu}.bf16 d, a, b, c; + //fma.rnd{.relu}.bf16x2 d, a, b, c; + //fma.rnd.oob.{relu}.type d, a, b, c; + .rnd: RawRoundingMode = { .rn }; + ScalarType = { .f16 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-sub + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-sub + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-sub + sub.type d, a, b => { + ast::Instruction::Sub { + data: ast::ArithDetails::Integer( + ArithInteger { + type_, + saturate: false + } + ), + arguments: SubArgs { dst: d, src1: a, src2: b } + } + } + sub.sat.s32 d, a, b => { + ast::Instruction::Sub { + data: ast::ArithDetails::Integer( + ArithInteger { + type_: s32, + saturate: true + } + ), + arguments: SubArgs { dst: d, src1: a, src2: b } + } + } + .type: ScalarType = { .u16, .u32, .u64, + .s16, .s32, .s64 }; + ScalarType = { .s32 }; + + sub{.rnd}{.ftz}{.sat}.f32 d, a, b => { + ast::Instruction::Sub { + data: ast::ArithDetails::Float( + ast::ArithFloat { + type_: f32, + rounding: rnd.map(Into::into), + flush_to_zero: Some(ftz), + saturate: sat + } + ), + arguments: SubArgs { dst: d, src1: a, src2: b } + } + } + sub{.rnd}.f64 d, a, b => { + ast::Instruction::Sub { + data: ast::ArithDetails::Float( + ast::ArithFloat { + type_: f64, + rounding: rnd.map(Into::into), + flush_to_zero: None, + saturate: false + } + ), + arguments: SubArgs { dst: d, src1: a, src2: b } + } + } + .rnd: RawRoundingMode = { .rn, .rz, .rm, .rp }; + ScalarType = { .f32, .f64 }; + + sub{.rnd}{.ftz}{.sat}.f16 d, a, b => { + ast::Instruction::Sub { + data: ast::ArithDetails::Float( + ast::ArithFloat { + type_: f16, + rounding: rnd.map(Into::into), + flush_to_zero: Some(ftz), + saturate: sat + } + ), + arguments: SubArgs { dst: d, src1: a, src2: b } + } + } + sub{.rnd}{.ftz}{.sat}.f16x2 d, a, b => { + ast::Instruction::Sub { + data: ast::ArithDetails::Float( + ast::ArithFloat { + type_: f16x2, + rounding: rnd.map(Into::into), + flush_to_zero: Some(ftz), + saturate: sat + } + ), + arguments: SubArgs { dst: d, src1: a, src2: b } + } + } + sub{.rnd}.bf16 d, a, b => { + ast::Instruction::Sub { + data: ast::ArithDetails::Float( + ast::ArithFloat { + type_: bf16, + rounding: rnd.map(Into::into), + flush_to_zero: None, + saturate: false + } + ), + arguments: SubArgs { dst: d, src1: a, src2: b } + } + } + sub{.rnd}.bf16x2 d, a, b => { + ast::Instruction::Sub { + data: ast::ArithDetails::Float( + ast::ArithFloat { + type_: bf16x2, + rounding: rnd.map(Into::into), + flush_to_zero: None, + saturate: false + } + ), + arguments: SubArgs { dst: d, src1: a, src2: b } + } + } + .rnd: RawRoundingMode = { .rn }; + ScalarType = { .f16, .f16x2, .bf16, .bf16x2 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-min + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-min + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-min + min.atype d, a, b => { + ast::Instruction::Min { + data: if atype.kind() == ast::ScalarKind::Signed { + ast::MinMaxDetails::Signed(atype) + } else { + ast::MinMaxDetails::Unsigned(atype) + }, + arguments: MinArgs { dst: d, src1: a, src2: b } + } + } + //min{.relu}.btype d, a, b => { todo!() } + min.btype d, a, b => { + ast::Instruction::Min { + data: ast::MinMaxDetails::Signed(btype), + arguments: MinArgs { dst: d, src1: a, src2: b } + } + } + .atype: ScalarType = { .u16, .u32, .u64, + .u16x2, .s16, .s64 }; + .btype: ScalarType = { .s16x2, .s32 }; + + //min{.ftz}{.NaN}{.xorsign.abs}.f32 d, a, b; + min{.ftz}{.NaN}.f32 d, a, b => { + ast::Instruction::Min { + data: ast::MinMaxDetails::Float( + MinMaxFloat { + flush_to_zero: Some(ftz), + nan, + type_: f32 + } + ), + arguments: MinArgs { dst: d, src1: a, src2: b } + } + } + min.f64 d, a, b => { + ast::Instruction::Min { + data: ast::MinMaxDetails::Float( + MinMaxFloat { + flush_to_zero: None, + nan: false, + type_: f64 + } + ), + arguments: MinArgs { dst: d, src1: a, src2: b } + } + } + ScalarType = { .f32, .f64 }; + + //min{.ftz}{.NaN}{.xorsign.abs}.f16 d, a, b; + //min{.ftz}{.NaN}{.xorsign.abs}.f16x2 d, a, b; + //min{.NaN}{.xorsign.abs}.bf16 d, a, b; + //min{.NaN}{.xorsign.abs}.bf16x2 d, a, b; + min{.ftz}{.NaN}.f16 d, a, b => { + ast::Instruction::Min { + data: ast::MinMaxDetails::Float( + MinMaxFloat { + flush_to_zero: Some(ftz), + nan, + type_: f16 + } + ), + arguments: MinArgs { dst: d, src1: a, src2: b } + } + } + min{.ftz}{.NaN}.f16x2 d, a, b => { + ast::Instruction::Min { + data: ast::MinMaxDetails::Float( + MinMaxFloat { + flush_to_zero: Some(ftz), + nan, + type_: f16x2 + } + ), + arguments: MinArgs { dst: d, src1: a, src2: b } + } + } + min{.NaN}.bf16 d, a, b => { + ast::Instruction::Min { + data: ast::MinMaxDetails::Float( + MinMaxFloat { + flush_to_zero: None, + nan, + type_: bf16 + } + ), + arguments: MinArgs { dst: d, src1: a, src2: b } + } + } + min{.NaN}.bf16x2 d, a, b => { + ast::Instruction::Min { + data: ast::MinMaxDetails::Float( + MinMaxFloat { + flush_to_zero: None, + nan, + type_: bf16x2 + } + ), + arguments: MinArgs { dst: d, src1: a, src2: b } + } + } + ScalarType = { .f16, .f16x2, .bf16, .bf16x2 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-max + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-max + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-max + max.atype d, a, b => { + ast::Instruction::Max { + data: if atype.kind() == ast::ScalarKind::Signed { + ast::MinMaxDetails::Signed(atype) + } else { + ast::MinMaxDetails::Unsigned(atype) + }, + arguments: MaxArgs { dst: d, src1: a, src2: b } + } + } + //max{.relu}.btype d, a, b => { todo!() } + max.btype d, a, b => { + ast::Instruction::Max { + data: ast::MinMaxDetails::Signed(btype), + arguments: MaxArgs { dst: d, src1: a, src2: b } + } + } + .atype: ScalarType = { .u16, .u32, .u64, + .u16x2, .s16, .s64 }; + .btype: ScalarType = { .s16x2, .s32 }; + + //max{.ftz}{.NaN}{.xorsign.abs}.f32 d, a, b; + max{.ftz}{.NaN}.f32 d, a, b => { + ast::Instruction::Max { + data: ast::MinMaxDetails::Float( + MinMaxFloat { + flush_to_zero: Some(ftz), + nan, + type_: f32 + } + ), + arguments: MaxArgs { dst: d, src1: a, src2: b } + } + } + max.f64 d, a, b => { + ast::Instruction::Max { + data: ast::MinMaxDetails::Float( + MinMaxFloat { + flush_to_zero: None, + nan: false, + type_: f64 + } + ), + arguments: MaxArgs { dst: d, src1: a, src2: b } + } + } + ScalarType = { .f32, .f64 }; + + //max{.ftz}{.NaN}{.xorsign.abs}.f16 d, a, b; + //max{.ftz}{.NaN}{.xorsign.abs}.f16x2 d, a, b; + //max{.NaN}{.xorsign.abs}.bf16 d, a, b; + //max{.NaN}{.xorsign.abs}.bf16x2 d, a, b; + max{.ftz}{.NaN}.f16 d, a, b => { + ast::Instruction::Max { + data: ast::MinMaxDetails::Float( + MinMaxFloat { + flush_to_zero: Some(ftz), + nan, + type_: f16 + } + ), + arguments: MaxArgs { dst: d, src1: a, src2: b } + } + } + max{.ftz}{.NaN}.f16x2 d, a, b => { + ast::Instruction::Max { + data: ast::MinMaxDetails::Float( + MinMaxFloat { + flush_to_zero: Some(ftz), + nan, + type_: f16x2 + } + ), + arguments: MaxArgs { dst: d, src1: a, src2: b } + } + } + max{.NaN}.bf16 d, a, b => { + ast::Instruction::Max { + data: ast::MinMaxDetails::Float( + MinMaxFloat { + flush_to_zero: None, + nan, + type_: bf16 + } + ), + arguments: MaxArgs { dst: d, src1: a, src2: b } + } + } + max{.NaN}.bf16x2 d, a, b => { + ast::Instruction::Max { + data: ast::MinMaxDetails::Float( + MinMaxFloat { + flush_to_zero: None, + nan, + type_: bf16x2 + } + ), + arguments: MaxArgs { dst: d, src1: a, src2: b } + } + } + ScalarType = { .f16, .f16x2, .bf16, .bf16x2 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-rcp + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-rcp-approx-ftz-f64 + rcp.approx{.ftz}.type d, a => { + ast::Instruction::Rcp { + data: ast::RcpData { + kind: ast::RcpKind::Approx, + flush_to_zero: Some(ftz), + type_ + }, + arguments: RcpArgs { dst: d, src: a } + } + } + rcp.rnd{.ftz}.f32 d, a => { + ast::Instruction::Rcp { + data: ast::RcpData { + kind: ast::RcpKind::Compliant(rnd.into()), + flush_to_zero: Some(ftz), + type_: f32 + }, + arguments: RcpArgs { dst: d, src: a } + } + } + rcp.rnd.f64 d, a => { + ast::Instruction::Rcp { + data: ast::RcpData { + kind: ast::RcpKind::Compliant(rnd.into()), + flush_to_zero: None, + type_: f64 + }, + arguments: RcpArgs { dst: d, src: a } + } + } + .type: ScalarType = { .f32, .f64 }; + .rnd: RawRoundingMode = { .rn, .rz, .rm, .rp }; + ScalarType = { .f32, .f64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-sqrt + sqrt.approx{.ftz}.f32 d, a => { + ast::Instruction::Sqrt { + data: ast::RcpData { + kind: ast::RcpKind::Approx, + flush_to_zero: Some(ftz), + type_: f32 + }, + arguments: SqrtArgs { dst: d, src: a } + } + } + sqrt.rnd{.ftz}.f32 d, a => { + ast::Instruction::Sqrt { + data: ast::RcpData { + kind: ast::RcpKind::Compliant(rnd.into()), + flush_to_zero: Some(ftz), + type_: f32 + }, + arguments: SqrtArgs { dst: d, src: a } + } + } + sqrt.rnd.f64 d, a => { + ast::Instruction::Sqrt { + data: ast::RcpData { + kind: ast::RcpKind::Compliant(rnd.into()), + flush_to_zero: None, + type_: f64 + }, + arguments: SqrtArgs { dst: d, src: a } + } + } + .rnd: RawRoundingMode = { .rn, .rz, .rm, .rp }; + ScalarType = { .f32, .f64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-rsqrt + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-rsqrt-approx-ftz-f64 + rsqrt.approx{.ftz}.f32 d, a => { + ast::Instruction::Rsqrt { + data: ast::TypeFtz { + flush_to_zero: Some(ftz), + type_: f32 + }, + arguments: RsqrtArgs { dst: d, src: a } + } + } + rsqrt.approx.f64 d, a => { + ast::Instruction::Rsqrt { + data: ast::TypeFtz { + flush_to_zero: None, + type_: f64 + }, + arguments: RsqrtArgs { dst: d, src: a } + } + } + rsqrt.approx.ftz.f64 d, a => { + ast::Instruction::Rsqrt { + data: ast::TypeFtz { + flush_to_zero: None, + type_: f64 + }, + arguments: RsqrtArgs { dst: d, src: a } + } + } + ScalarType = { .f32, .f64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#comparison-and-selection-instructions-selp + selp.type d, a, b, c => { + ast::Instruction::Selp { + data: type_, + arguments: SelpArgs { dst: d, src1: a, src2: b, src3: c } + } + } + .type: ScalarType = { .b16, .b32, .b64, + .u16, .u32, .u64, + .s16, .s32, .s64, + .f32, .f64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-bar + barrier{.cta}.sync{.aligned} a{, b} => { + let _ = cta; + ast::Instruction::Bar { + data: ast::BarData { aligned }, + arguments: BarArgs { src1: a, src2: b } + } + } + //barrier{.cta}.arrive{.aligned} a, b; + //barrier{.cta}.red.popc{.aligned}.u32 d, a{, b}, {!}c; + //barrier{.cta}.red.op{.aligned}.pred p, a{, b}, {!}c; + bar{.cta}.sync a{, b} => { + let _ = cta; + ast::Instruction::Bar { + data: ast::BarData { aligned: true }, + arguments: BarArgs { src1: a, src2: b } + } + } + //bar{.cta}.arrive a, b; + //bar{.cta}.red.popc.u32 d, a{, b}, {!}c; + //bar{.cta}.red.op.pred p, a{, b}, {!}c; + //.op = { .and, .or }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-atom + atom{.sem}{.scope}{.space}.op{.level::cache_hint}.type d, [a], b{, cache_policy} => { + if level_cache_hint || cache_policy.is_some() { + state.errors.push(PtxError::Todo); + } + ast::Instruction::Atom { + data: AtomDetails { + semantics: sem.map(Into::into).unwrap_or(AtomSemantics::Relaxed), + scope: scope.unwrap_or(MemScope::Gpu), + space: space.unwrap_or(StateSpace::Generic), + op: ast::AtomicOp::new(op, type_.kind()), + type_: type_.into() + }, + arguments: AtomArgs { dst: d, src1: a, src2: b } + } + } + atom{.sem}{.scope}{.space}.cas.cas_type d, [a], b, c => { + ast::Instruction::AtomCas { + data: AtomCasDetails { + semantics: sem.map(Into::into).unwrap_or(AtomSemantics::Relaxed), + scope: scope.unwrap_or(MemScope::Gpu), + space: space.unwrap_or(StateSpace::Generic), + type_: cas_type + }, + arguments: AtomCasArgs { dst: d, src1: a, src2: b, src3: c } + } + } + atom{.sem}{.scope}{.space}.exch{.level::cache_hint}.b128 d, [a], b{, cache_policy} => { + if level_cache_hint || cache_policy.is_some() { + state.errors.push(PtxError::Todo); + } + ast::Instruction::Atom { + data: AtomDetails { + semantics: sem.map(Into::into).unwrap_or(AtomSemantics::Relaxed), + scope: scope.unwrap_or(MemScope::Gpu), + space: space.unwrap_or(StateSpace::Generic), + op: ast::AtomicOp::new(exch, b128.kind()), + type_: b128.into() + }, + arguments: AtomArgs { dst: d, src1: a, src2: b } + } + } + atom{.sem}{.scope}{.global}.float_op{.level::cache_hint}.vec_32_bit.f32 d, [a], b{, cache_policy} => { + if level_cache_hint || cache_policy.is_some() { + state.errors.push(PtxError::Todo); + } + ast::Instruction::Atom { + data: AtomDetails { + semantics: sem.map(Into::into).unwrap_or(AtomSemantics::Relaxed), + scope: scope.unwrap_or(MemScope::Gpu), + space: global.unwrap_or(StateSpace::Generic), + op: ast::AtomicOp::new(float_op, f32.kind()), + type_: ast::Type::Vector(vec_32_bit.len().get(), f32) + }, + arguments: AtomArgs { dst: d, src1: a, src2: b } + } + } + atom{.sem}{.scope}{.global}.float_op.noftz{.level::cache_hint}{.vec_16_bit}.half_word_type d, [a], b{, cache_policy} => { + if level_cache_hint || cache_policy.is_some() { + state.errors.push(PtxError::Todo); + } + ast::Instruction::Atom { + data: AtomDetails { + semantics: sem.map(Into::into).unwrap_or(AtomSemantics::Relaxed), + scope: scope.unwrap_or(MemScope::Gpu), + space: global.unwrap_or(StateSpace::Generic), + op: ast::AtomicOp::new(float_op, half_word_type.kind()), + type_: ast::Type::maybe_vector(vec_16_bit, half_word_type) + }, + arguments: AtomArgs { dst: d, src1: a, src2: b } + } + } + atom{.sem}{.scope}{.global}.float_op.noftz{.level::cache_hint}{.vec_32_bit}.packed_type d, [a], b{, cache_policy} => { + if level_cache_hint || cache_policy.is_some() { + state.errors.push(PtxError::Todo); + } + ast::Instruction::Atom { + data: AtomDetails { + semantics: sem.map(Into::into).unwrap_or(AtomSemantics::Relaxed), + scope: scope.unwrap_or(MemScope::Gpu), + space: global.unwrap_or(StateSpace::Generic), + op: ast::AtomicOp::new(float_op, packed_type.kind()), + type_: ast::Type::maybe_vector(vec_32_bit, packed_type) + }, + arguments: AtomArgs { dst: d, src1: a, src2: b } + } + } + .space: StateSpace = { .global, .shared{::cta, ::cluster} }; + .sem: AtomSemantics = { .relaxed, .acquire, .release, .acq_rel }; + .scope: MemScope = { .cta, .cluster, .gpu, .sys }; + .op: RawAtomicOp = { .and, .or, .xor, + .exch, + .add, .inc, .dec, + .min, .max }; + .level::cache_hint = { .L2::cache_hint }; + .type: ScalarType = { .b32, .b64, .u32, .u64, .s32, .s64, .f32, .f64 }; + .cas_type: ScalarType = { .b32, .b64, .u32, .u64, .s32, .s64, .f32, .f64, .b16, .b128 }; + .half_word_type: ScalarType = { .f16, .bf16 }; + .packed_type: ScalarType = { .f16x2, .bf16x2 }; + .vec_16_bit: VectorPrefix = { .v2, .v4, .v8 }; + .vec_32_bit: VectorPrefix = { .v2, .v4 }; + .float_op: RawAtomicOp = { .add, .min, .max }; + ScalarType = { .b16, .b128, .f32 }; + StateSpace = { .global }; + RawAtomicOp = { .exch }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-div + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-div + div.type d, a, b => { + ast::Instruction::Div { + data: if type_.kind() == ast::ScalarKind::Signed { + ast::DivDetails::Signed(type_) + } else { + ast::DivDetails::Unsigned(type_) + }, + arguments: DivArgs { + dst: d, + src1: a, + src2: b, + }, + } + } + .type: ScalarType = { .u16, .u32, .u64, + .s16, .s32, .s64 }; + + div.approx{.ftz}.f32 d, a, b => { + ast::Instruction::Div { + data: ast::DivDetails::Float(ast::DivFloatDetails{ + type_: f32, + flush_to_zero: Some(ftz), + kind: ast::DivFloatKind::Approx + }), + arguments: DivArgs { + dst: d, + src1: a, + src2: b, + }, + } + } + div.full{.ftz}.f32 d, a, b => { + ast::Instruction::Div { + data: ast::DivDetails::Float(ast::DivFloatDetails{ + type_: f32, + flush_to_zero: Some(ftz), + kind: ast::DivFloatKind::ApproxFull + }), + arguments: DivArgs { + dst: d, + src1: a, + src2: b, + }, + } + } + div.rnd{.ftz}.f32 d, a, b => { + ast::Instruction::Div { + data: ast::DivDetails::Float(ast::DivFloatDetails{ + type_: f32, + flush_to_zero: Some(ftz), + kind: ast::DivFloatKind::Rounding(rnd.into()) + }), + arguments: DivArgs { + dst: d, + src1: a, + src2: b, + }, + } + } + div.rnd.f64 d, a, b => { + ast::Instruction::Div { + data: ast::DivDetails::Float(ast::DivFloatDetails{ + type_: f64, + flush_to_zero: None, + kind: ast::DivFloatKind::Rounding(rnd.into()) + }), + arguments: DivArgs { + dst: d, + src1: a, + src2: b, + }, + } + } + .rnd: RawRoundingMode = { .rn, .rz, .rm, .rp }; + ScalarType = { .f32, .f64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-neg + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-neg + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-neg + neg.type d, a => { + ast::Instruction::Neg { + data: TypeFtz { + type_, + flush_to_zero: None + }, + arguments: NegArgs { dst: d, src: a, }, + } + } + .type: ScalarType = { .s16, .s32, .s64 }; + + neg{.ftz}.f32 d, a => { + ast::Instruction::Neg { + data: TypeFtz { + type_: f32, + flush_to_zero: Some(ftz) + }, + arguments: NegArgs { dst: d, src: a, }, + } + } + neg.f64 d, a => { + ast::Instruction::Neg { + data: TypeFtz { + type_: f64, + flush_to_zero: None + }, + arguments: NegArgs { dst: d, src: a, }, + } + } + neg{.ftz}.f16 d, a => { + ast::Instruction::Neg { + data: TypeFtz { + type_: f16, + flush_to_zero: Some(ftz) + }, + arguments: NegArgs { dst: d, src: a, }, + } + } + neg{.ftz}.f16x2 d, a => { + ast::Instruction::Neg { + data: TypeFtz { + type_: f16x2, + flush_to_zero: Some(ftz) + }, + arguments: NegArgs { dst: d, src: a, }, + } + } + neg.bf16 d, a => { + ast::Instruction::Neg { + data: TypeFtz { + type_: bf16, + flush_to_zero: None + }, + arguments: NegArgs { dst: d, src: a, }, + } + } + neg.bf16x2 d, a => { + ast::Instruction::Neg { + data: TypeFtz { + type_: bf16x2, + flush_to_zero: None + }, + arguments: NegArgs { dst: d, src: a, }, + } + } + ScalarType = { .f32, .f64, .f16, .f16x2, .bf16, .bf16x2 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-sin + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-cos + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-lg2 + sin.approx{.ftz}.f32 d, a => { + ast::Instruction::Sin { + data: ast::FlushToZero { + flush_to_zero: ftz + }, + arguments: SinArgs { dst: d, src: a, }, + } + } + cos.approx{.ftz}.f32 d, a => { + ast::Instruction::Cos { + data: ast::FlushToZero { + flush_to_zero: ftz + }, + arguments: CosArgs { dst: d, src: a, }, + } + } + lg2.approx{.ftz}.f32 d, a => { + ast::Instruction::Lg2 { + data: ast::FlushToZero { + flush_to_zero: ftz + }, + arguments: Lg2Args { dst: d, src: a, }, + } + } + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-ex2 + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-ex2 + ex2.approx{.ftz}.f32 d, a => { + ast::Instruction::Ex2 { + data: ast::TypeFtz { + type_: f32, + flush_to_zero: Some(ftz) + }, + arguments: Ex2Args { dst: d, src: a, }, + } + } + ex2.approx.atype d, a => { + ast::Instruction::Ex2 { + data: ast::TypeFtz { + type_: atype, + flush_to_zero: None + }, + arguments: Ex2Args { dst: d, src: a, }, + } + } + ex2.approx.ftz.btype d, a => { + ast::Instruction::Ex2 { + data: ast::TypeFtz { + type_: btype, + flush_to_zero: Some(true) + }, + arguments: Ex2Args { dst: d, src: a, }, + } + } + .atype: ScalarType = { .f16, .f16x2 }; + .btype: ScalarType = { .bf16, .bf16x2 }; + ScalarType = { .f32 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-clz + clz.type d, a => { + ast::Instruction::Clz { + data: type_, + arguments: ClzArgs { dst: d, src: a, }, + } + } + .type: ScalarType = { .b32, .b64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-brev + brev.type d, a => { + ast::Instruction::Brev { + data: type_, + arguments: BrevArgs { dst: d, src: a, }, + } + } + .type: ScalarType = { .b32, .b64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-popc + popc.type d, a => { + ast::Instruction::Popc { + data: type_, + arguments: PopcArgs { dst: d, src: a, }, + } + } + .type: ScalarType = { .b32, .b64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-xor + xor.type d, a, b => { + ast::Instruction::Xor { + data: type_, + arguments: XorArgs { dst: d, src1: a, src2: b, }, + } + } + .type: ScalarType = { .pred, .b16, .b32, .b64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-rem + rem.type d, a, b => { + ast::Instruction::Rem { + data: type_, + arguments: RemArgs { dst: d, src1: a, src2: b, }, + } + } + .type: ScalarType = { .u16, .u32, .u64, .s16, .s32, .s64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-bfe + bfe.type d, a, b, c => { + ast::Instruction::Bfe { + data: type_, + arguments: BfeArgs { dst: d, src1: a, src2: b, src3: c }, + } + } + .type: ScalarType = { .u32, .u64, .s32, .s64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-bfi + bfi.type f, a, b, c, d => { + ast::Instruction::Bfi { + data: type_, + arguments: BfiArgs { dst: f, src1: a, src2: b, src3: c, src4: d }, + } + } + .type: ScalarType = { .b32, .b64 }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-prmt + // prmt.b32{.mode} d, a, b, c; + // .mode = { .f4e, .b4e, .rc8, .ecl, .ecr, .rc16 }; + prmt.b32 d, a, b, c => { + match c { + ast::ParsedOperand::Imm(ImmediateValue::S64(control)) => ast::Instruction::Prmt { + data: control as u16, + arguments: PrmtArgs { + dst: d, src1: a, src2: b + } + }, + _ => ast::Instruction::PrmtSlow { + arguments: PrmtSlowArgs { + dst: d, src1: a, src2: b, src3: c + } + } + } + } + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-activemask + activemask.b32 d => { + ast::Instruction::Activemask { + arguments: ActivemaskArgs { dst: d } + } + } + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-membar + // fence{.sem}.scope; + // fence.op_restrict.release.cluster; + // fence.proxy.proxykind; + // fence.proxy.to_proxykind::from_proxykind.release.scope; + // fence.proxy.to_proxykind::from_proxykind.acquire.scope [addr], size; + //membar.proxy.proxykind; + //.sem = { .sc, .acq_rel }; + //.scope = { .cta, .cluster, .gpu, .sys }; + //.proxykind = { .alias, .async, async.global, .async.shared::{cta, cluster} }; + //.op_restrict = { .mbarrier_init }; + //.to_proxykind::from_proxykind = {.tensormap::generic}; + + membar.level => { + ast::Instruction::Membar { data: level } + } + membar.gl => { + ast::Instruction::Membar { data: MemScope::Gpu } + } + .level: MemScope = { .cta, .sys }; + + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-ret + ret{.uni} => { + Instruction::Ret { data: RetData { uniform: uni } } + } + +); + +#[cfg(test)] +mod tests { + use super::target; + use super::PtxParserState; + use super::Token; + use logos::Logos; + use winnow::prelude::*; + + #[test] + fn sm_11() { + let tokens = Token::lexer(".target sm_11") + .collect::<Result<Vec<_>, ()>>() + .unwrap(); + let stream = super::PtxParser { + input: &tokens[..], + state: PtxParserState::new(), + }; + assert_eq!(target.parse(stream).unwrap(), (11, None)); + } + + #[test] + fn sm_90a() { + let tokens = Token::lexer(".target sm_90a") + .collect::<Result<Vec<_>, ()>>() + .unwrap(); + let stream = super::PtxParser { + input: &tokens[..], + state: PtxParserState::new(), + }; + assert_eq!(target.parse(stream).unwrap(), (90, Some('a'))); + } + + #[test] + fn sm_90ab() { + let tokens = Token::lexer(".target sm_90ab") + .collect::<Result<Vec<_>, ()>>() + .unwrap(); + let stream = super::PtxParser { + input: &tokens[..], + state: PtxParserState::new(), + }; + assert!(target.parse(stream).is_err()); + } +} |