aboutsummaryrefslogtreecommitdiffhomepage
path: root/ptx_parser
diff options
context:
space:
mode:
Diffstat (limited to 'ptx_parser')
-rw-r--r--ptx_parser/Cargo.toml17
-rw-r--r--ptx_parser/src/ast.rs1695
-rw-r--r--ptx_parser/src/check_args.py69
-rw-r--r--ptx_parser/src/lib.rs3269
4 files changed, 5050 insertions, 0 deletions
diff --git a/ptx_parser/Cargo.toml b/ptx_parser/Cargo.toml
new file mode 100644
index 0000000..9032de5
--- /dev/null
+++ b/ptx_parser/Cargo.toml
@@ -0,0 +1,17 @@
+[package]
+name = "ptx_parser"
+version = "0.0.0"
+authors = ["Andrzej Janik <[email protected]>"]
+edition = "2021"
+
+[lib]
+
+[dependencies]
+logos = "0.14"
+winnow = { version = "0.6.18" }
+#winnow = { version = "0.6.18", features = ["debug"] }
+ptx_parser_macros = { path = "../ptx_parser_macros" }
+thiserror = "1.0"
+bitflags = "1.2"
+rustc-hash = "2.0.0"
+derive_more = { version = "1", features = ["display"] }
diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs
new file mode 100644
index 0000000..d0dc303
--- /dev/null
+++ b/ptx_parser/src/ast.rs
@@ -0,0 +1,1695 @@
+use super::{
+ AtomSemantics, MemScope, RawRoundingMode, RawSetpCompareOp, ScalarType, SetpBoolPostOp,
+ StateSpace, VectorPrefix,
+};
+use crate::{PtxError, PtxParserState};
+use bitflags::bitflags;
+use std::{cmp::Ordering, num::NonZeroU8};
+
+pub enum Statement<P: Operand> {
+ Label(P::Ident),
+ Variable(MultiVariable<P::Ident>),
+ Instruction(Option<PredAt<P::Ident>>, Instruction<P>),
+ Block(Vec<Statement<P>>),
+}
+
+// We define the instruction enum through the macro instead of normally, because we have some of how
+// we use this type in the compilee. Each instruction can be logically split into two parts:
+// properties that define instruction semantics (e.g. is memory load volatile?) that don't change
+// during compilation and arguments (e.g. memory load source and destination) that evolve during
+// compilation. To support compilation passes we need to be able to visit (and change) every
+// argument in a generic way. This macro has visibility over all the fields. Consequently, we use it
+// to generate visitor functions. There re three functions to support three different semantics:
+// visit-by-ref, visit-by-mutable-ref, visit-and-map. In a previous version of the compiler it was
+// done by hand and was very limiting (we supported only visit-and-map).
+// The visitor must implement appropriate visitor trait defined below this macro. For convenience,
+// we implemented visitors for some corresponding FnMut(...) types.
+// Properties in this macro are used to encode information about the instruction arguments (what
+// Rust type is used for it post-parsing, what PTX type does it expect, what PTX address space does
+// it expect, etc.).
+// This information is then available to a visitor.
+ptx_parser_macros::generate_instruction_type!(
+ pub enum Instruction<T: Operand> {
+ Mov {
+ type: { &data.typ },
+ data: MovDetails,
+ arguments<T>: {
+ dst: T,
+ src: T
+ }
+ },
+ Ld {
+ type: { &data.typ },
+ data: LdDetails,
+ arguments<T>: {
+ dst: {
+ repr: T,
+ relaxed_type_check: true,
+ },
+ src: {
+ repr: T,
+ space: { data.state_space },
+ }
+ }
+ },
+ Add {
+ type: { Type::from(data.type_()) },
+ data: ArithDetails,
+ arguments<T>: {
+ dst: T,
+ src1: T,
+ src2: T,
+ }
+ },
+ St {
+ type: { &data.typ },
+ data: StData,
+ arguments<T>: {
+ src1: {
+ repr: T,
+ space: { data.state_space },
+ },
+ src2: {
+ repr: T,
+ relaxed_type_check: true,
+ }
+ }
+ },
+ Mul {
+ type: { Type::from(data.type_()) },
+ data: MulDetails,
+ arguments<T>: {
+ dst: {
+ repr: T,
+ type: { Type::from(data.dst_type()) },
+ },
+ src1: T,
+ src2: T,
+ }
+ },
+ Setp {
+ data: SetpData,
+ arguments<T>: {
+ dst1: {
+ repr: T,
+ type: Type::from(ScalarType::Pred)
+ },
+ dst2: {
+ repr: Option<T>,
+ type: Type::from(ScalarType::Pred)
+ },
+ src1: {
+ repr: T,
+ type: Type::from(data.type_),
+ },
+ src2: {
+ repr: T,
+ type: Type::from(data.type_),
+ }
+ }
+ },
+ SetpBool {
+ data: SetpBoolData,
+ arguments<T>: {
+ dst1: {
+ repr: T,
+ type: Type::from(ScalarType::Pred)
+ },
+ dst2: {
+ repr: Option<T>,
+ type: Type::from(ScalarType::Pred)
+ },
+ src1: {
+ repr: T,
+ type: Type::from(data.base.type_),
+ },
+ src2: {
+ repr: T,
+ type: Type::from(data.base.type_),
+ },
+ src3: {
+ repr: T,
+ type: Type::from(ScalarType::Pred)
+ }
+ }
+ },
+ Not {
+ data: ScalarType,
+ type: { Type::Scalar(data.clone()) },
+ arguments<T>: {
+ dst: T,
+ src: T,
+ }
+ },
+ Or {
+ data: ScalarType,
+ type: { Type::Scalar(data.clone()) },
+ arguments<T>: {
+ dst: T,
+ src1: T,
+ src2: T,
+ }
+ },
+ And {
+ data: ScalarType,
+ type: { Type::Scalar(data.clone()) },
+ arguments<T>: {
+ dst: T,
+ src1: T,
+ src2: T,
+ }
+ },
+ Bra {
+ type: !,
+ arguments<T::Ident>: {
+ src: T
+ }
+ },
+ Call {
+ data: CallDetails,
+ arguments: CallArgs<T>,
+ visit: arguments.visit(data, visitor)?,
+ visit_mut: arguments.visit_mut(data, visitor)?,
+ map: Instruction::Call{ arguments: arguments.map(&data, visitor)?, data }
+ },
+ Cvt {
+ data: CvtDetails,
+ arguments<T>: {
+ dst: {
+ repr: T,
+ type: { Type::Scalar(data.to) },
+ // TODO: double check
+ relaxed_type_check: true,
+ },
+ src: {
+ repr: T,
+ type: { Type::Scalar(data.from) },
+ relaxed_type_check: true,
+ },
+ }
+ },
+ Shr {
+ data: ShrData,
+ type: { Type::Scalar(data.type_.clone()) },
+ arguments<T>: {
+ dst: T,
+ src1: T,
+ src2: {
+ repr: T,
+ type: { Type::Scalar(ScalarType::U32) },
+ },
+ }
+ },
+ Shl {
+ data: ScalarType,
+ type: { Type::Scalar(data.clone()) },
+ arguments<T>: {
+ dst: T,
+ src1: T,
+ src2: {
+ repr: T,
+ type: { Type::Scalar(ScalarType::U32) },
+ },
+ }
+ },
+ Ret {
+ data: RetData
+ },
+ Cvta {
+ data: CvtaDetails,
+ type: { Type::Scalar(ScalarType::B64) },
+ arguments<T>: {
+ dst: T,
+ src: T,
+ }
+ },
+ Abs {
+ data: TypeFtz,
+ type: { Type::Scalar(data.type_) },
+ arguments<T>: {
+ dst: T,
+ src: T,
+ }
+ },
+ Mad {
+ type: { Type::from(data.type_()) },
+ data: MadDetails,
+ arguments<T>: {
+ dst: {
+ repr: T,
+ type: { Type::from(data.dst_type()) },
+ },
+ src1: T,
+ src2: T,
+ src3: T,
+ }
+ },
+ Fma {
+ type: { Type::from(data.type_) },
+ data: ArithFloat,
+ arguments<T>: {
+ dst: T,
+ src1: T,
+ src2: T,
+ src3: T,
+ }
+ },
+ Sub {
+ type: { Type::from(data.type_()) },
+ data: ArithDetails,
+ arguments<T>: {
+ dst: T,
+ src1: T,
+ src2: T,
+ }
+ },
+ Min {
+ type: { Type::from(data.type_()) },
+ data: MinMaxDetails,
+ arguments<T>: {
+ dst: T,
+ src1: T,
+ src2: T,
+ }
+ },
+ Max {
+ type: { Type::from(data.type_()) },
+ data: MinMaxDetails,
+ arguments<T>: {
+ dst: T,
+ src1: T,
+ src2: T,
+ }
+ },
+ Rcp {
+ type: { Type::from(data.type_) },
+ data: RcpData,
+ arguments<T>: {
+ dst: T,
+ src: T,
+ }
+ },
+ Sqrt {
+ type: { Type::from(data.type_) },
+ data: RcpData,
+ arguments<T>: {
+ dst: T,
+ src: T,
+ }
+ },
+ Rsqrt {
+ type: { Type::from(data.type_) },
+ data: TypeFtz,
+ arguments<T>: {
+ dst: T,
+ src: T,
+ }
+ },
+ Selp {
+ type: { Type::Scalar(data.clone()) },
+ data: ScalarType,
+ arguments<T>: {
+ dst: T,
+ src1: T,
+ src2: T,
+ src3: {
+ repr: T,
+ type: Type::Scalar(ScalarType::Pred)
+ },
+ }
+ },
+ Bar {
+ type: Type::Scalar(ScalarType::U32),
+ data: BarData,
+ arguments<T>: {
+ src1: T,
+ src2: Option<T>,
+ }
+ },
+ Atom {
+ type: &data.type_,
+ data: AtomDetails,
+ arguments<T>: {
+ dst: T,
+ src1: {
+ repr: T,
+ space: { data.space },
+ },
+ src2: T,
+ }
+ },
+ AtomCas {
+ type: Type::Scalar(data.type_),
+ data: AtomCasDetails,
+ arguments<T>: {
+ dst: T,
+ src1: {
+ repr: T,
+ space: { data.space },
+ },
+ src2: T,
+ src3: T,
+ }
+ },
+ Div {
+ type: Type::Scalar(data.type_()),
+ data: DivDetails,
+ arguments<T>: {
+ dst: T,
+ src1: T,
+ src2: T,
+ }
+ },
+ Neg {
+ type: Type::Scalar(data.type_),
+ data: TypeFtz,
+ arguments<T>: {
+ dst: T,
+ src: T
+ }
+ },
+ Sin {
+ type: Type::Scalar(ScalarType::F32),
+ data: FlushToZero,
+ arguments<T>: {
+ dst: T,
+ src: T
+ }
+ },
+ Cos {
+ type: Type::Scalar(ScalarType::F32),
+ data: FlushToZero,
+ arguments<T>: {
+ dst: T,
+ src: T
+ }
+ },
+ Lg2 {
+ type: Type::Scalar(ScalarType::F32),
+ data: FlushToZero,
+ arguments<T>: {
+ dst: T,
+ src: T
+ }
+ },
+ Ex2 {
+ type: Type::Scalar(ScalarType::F32),
+ data: TypeFtz,
+ arguments<T>: {
+ dst: T,
+ src: T
+ }
+ },
+ Clz {
+ type: Type::Scalar(data.clone()),
+ data: ScalarType,
+ arguments<T>: {
+ dst: {
+ repr: T,
+ type: Type::Scalar(ScalarType::U32)
+ },
+ src: T
+ }
+ },
+ Brev {
+ type: Type::Scalar(data.clone()),
+ data: ScalarType,
+ arguments<T>: {
+ dst: T,
+ src: T
+ }
+ },
+ Popc {
+ type: Type::Scalar(data.clone()),
+ data: ScalarType,
+ arguments<T>: {
+ dst: {
+ repr: T,
+ type: Type::Scalar(ScalarType::U32)
+ },
+ src: T
+ }
+ },
+ Xor {
+ type: Type::Scalar(data.clone()),
+ data: ScalarType,
+ arguments<T>: {
+ dst: T,
+ src1: T,
+ src2: T
+ }
+ },
+ Rem {
+ type: Type::Scalar(data.clone()),
+ data: ScalarType,
+ arguments<T>: {
+ dst: T,
+ src1: T,
+ src2: T
+ }
+ },
+ Bfe {
+ type: Type::Scalar(data.clone()),
+ data: ScalarType,
+ arguments<T>: {
+ dst: T,
+ src1: T,
+ src2: {
+ repr: T,
+ type: Type::Scalar(ScalarType::U32)
+ },
+ src3: {
+ repr: T,
+ type: Type::Scalar(ScalarType::U32)
+ },
+ }
+ },
+ Bfi {
+ type: Type::Scalar(data.clone()),
+ data: ScalarType,
+ arguments<T>: {
+ dst: T,
+ src1: T,
+ src2: T,
+ src3: {
+ repr: T,
+ type: Type::Scalar(ScalarType::U32)
+ },
+ src4: {
+ repr: T,
+ type: Type::Scalar(ScalarType::U32)
+ },
+ }
+ },
+ PrmtSlow {
+ type: Type::Scalar(ScalarType::U32),
+ arguments<T>: {
+ dst: T,
+ src1: T,
+ src2: T,
+ src3: T
+ }
+ },
+ Prmt {
+ type: Type::Scalar(ScalarType::B32),
+ data: u16,
+ arguments<T>: {
+ dst: T,
+ src1: T,
+ src2: T
+ }
+ },
+ Activemask {
+ type: Type::Scalar(ScalarType::B32),
+ arguments<T>: {
+ dst: T
+ }
+ },
+ Membar {
+ data: MemScope
+ },
+ Trap { }
+ }
+);
+
+pub trait Visitor<T: Operand, Err> {
+ fn visit(
+ &mut self,
+ args: &T,
+ type_space: Option<(&Type, StateSpace)>,
+ is_dst: bool,
+ relaxed_type_check: bool,
+ ) -> Result<(), Err>;
+ fn visit_ident(
+ &mut self,
+ args: &T::Ident,
+ type_space: Option<(&Type, StateSpace)>,
+ is_dst: bool,
+ relaxed_type_check: bool,
+ ) -> Result<(), Err>;
+}
+
+impl<
+ T: Operand,
+ Err,
+ Fn: FnMut(&T, Option<(&Type, StateSpace)>, bool, bool) -> Result<(), Err>,
+ > Visitor<T, Err> for Fn
+{
+ fn visit(
+ &mut self,
+ args: &T,
+ type_space: Option<(&Type, StateSpace)>,
+ is_dst: bool,
+ relaxed_type_check: bool,
+ ) -> Result<(), Err> {
+ (self)(args, type_space, is_dst, relaxed_type_check)
+ }
+
+ fn visit_ident(
+ &mut self,
+ args: &T::Ident,
+ type_space: Option<(&Type, StateSpace)>,
+ is_dst: bool,
+ relaxed_type_check: bool,
+ ) -> Result<(), Err> {
+ (self)(
+ &T::from_ident(*args),
+ type_space,
+ is_dst,
+ relaxed_type_check,
+ )
+ }
+}
+
+pub trait VisitorMut<T: Operand, Err> {
+ fn visit(
+ &mut self,
+ args: &mut T,
+ type_space: Option<(&Type, StateSpace)>,
+ is_dst: bool,
+ relaxed_type_check: bool,
+ ) -> Result<(), Err>;
+ fn visit_ident(
+ &mut self,
+ args: &mut T::Ident,
+ type_space: Option<(&Type, StateSpace)>,
+ is_dst: bool,
+ relaxed_type_check: bool,
+ ) -> Result<(), Err>;
+}
+
+pub trait VisitorMap<From: Operand, To: Operand, Err> {
+ fn visit(
+ &mut self,
+ args: From,
+ type_space: Option<(&Type, StateSpace)>,
+ is_dst: bool,
+ relaxed_type_check: bool,
+ ) -> Result<To, Err>;
+ fn visit_ident(
+ &mut self,
+ args: From::Ident,
+ type_space: Option<(&Type, StateSpace)>,
+ is_dst: bool,
+ relaxed_type_check: bool,
+ ) -> Result<To::Ident, Err>;
+}
+
+impl<T: Copy, U: Copy, Err, Fn> VisitorMap<ParsedOperand<T>, ParsedOperand<U>, Err> for Fn
+where
+ Fn: FnMut(T, Option<(&Type, StateSpace)>, bool, bool) -> Result<U, Err>,
+{
+ fn visit(
+ &mut self,
+ args: ParsedOperand<T>,
+ type_space: Option<(&Type, StateSpace)>,
+ is_dst: bool,
+ relaxed_type_check: bool,
+ ) -> Result<ParsedOperand<U>, Err> {
+ Ok(match args {
+ ParsedOperand::Reg(ident) => {
+ ParsedOperand::Reg((self)(ident, type_space, is_dst, relaxed_type_check)?)
+ }
+ ParsedOperand::RegOffset(ident, imm) => ParsedOperand::RegOffset(
+ (self)(ident, type_space, is_dst, relaxed_type_check)?,
+ imm,
+ ),
+ ParsedOperand::Imm(imm) => ParsedOperand::Imm(imm),
+ ParsedOperand::VecMember(ident, index) => ParsedOperand::VecMember(
+ (self)(ident, type_space, is_dst, relaxed_type_check)?,
+ index,
+ ),
+ ParsedOperand::VecPack(vec) => ParsedOperand::VecPack(
+ vec.into_iter()
+ .map(|ident| (self)(ident, type_space, is_dst, relaxed_type_check))
+ .collect::<Result<Vec<_>, _>>()?,
+ ),
+ })
+ }
+
+ fn visit_ident(
+ &mut self,
+ args: T,
+ type_space: Option<(&Type, StateSpace)>,
+ is_dst: bool,
+ relaxed_type_check: bool,
+ ) -> Result<U, Err> {
+ (self)(args, type_space, is_dst, relaxed_type_check)
+ }
+}
+
+impl<T: Operand<Ident = T>, U: Operand<Ident = U>, Err, Fn> VisitorMap<T, U, Err> for Fn
+where
+ Fn: FnMut(T, Option<(&Type, StateSpace)>, bool, bool) -> Result<U, Err>,
+{
+ fn visit(
+ &mut self,
+ args: T,
+ type_space: Option<(&Type, StateSpace)>,
+ is_dst: bool,
+ relaxed_type_check: bool,
+ ) -> Result<U, Err> {
+ (self)(args, type_space, is_dst, relaxed_type_check)
+ }
+
+ fn visit_ident(
+ &mut self,
+ args: T,
+ type_space: Option<(&Type, StateSpace)>,
+ is_dst: bool,
+ relaxed_type_check: bool,
+ ) -> Result<U, Err> {
+ (self)(args, type_space, is_dst, relaxed_type_check)
+ }
+}
+
+trait VisitOperand<Err> {
+ type Operand: Operand;
+ #[allow(unused)] // Used by generated code
+ fn visit(&self, fn_: impl FnMut(&Self::Operand) -> Result<(), Err>) -> Result<(), Err>;
+ #[allow(unused)] // Used by generated code
+ fn visit_mut(
+ &mut self,
+ fn_: impl FnMut(&mut Self::Operand) -> Result<(), Err>,
+ ) -> Result<(), Err>;
+}
+
+impl<T: Operand, Err> VisitOperand<Err> for T {
+ type Operand = Self;
+ fn visit(&self, mut fn_: impl FnMut(&Self::Operand) -> Result<(), Err>) -> Result<(), Err> {
+ fn_(self)
+ }
+ fn visit_mut(
+ &mut self,
+ mut fn_: impl FnMut(&mut Self::Operand) -> Result<(), Err>,
+ ) -> Result<(), Err> {
+ fn_(self)
+ }
+}
+
+impl<T: Operand, Err> VisitOperand<Err> for Option<T> {
+ type Operand = T;
+ fn visit(&self, mut fn_: impl FnMut(&Self::Operand) -> Result<(), Err>) -> Result<(), Err> {
+ if let Some(x) = self {
+ fn_(x)?;
+ }
+ Ok(())
+ }
+ fn visit_mut(
+ &mut self,
+ mut fn_: impl FnMut(&mut Self::Operand) -> Result<(), Err>,
+ ) -> Result<(), Err> {
+ if let Some(x) = self {
+ fn_(x)?;
+ }
+ Ok(())
+ }
+}
+
+impl<T: Operand, Err> VisitOperand<Err> for Vec<T> {
+ type Operand = T;
+ fn visit(&self, mut fn_: impl FnMut(&Self::Operand) -> Result<(), Err>) -> Result<(), Err> {
+ for o in self {
+ fn_(o)?;
+ }
+ Ok(())
+ }
+ fn visit_mut(
+ &mut self,
+ mut fn_: impl FnMut(&mut Self::Operand) -> Result<(), Err>,
+ ) -> Result<(), Err> {
+ for o in self {
+ fn_(o)?;
+ }
+ Ok(())
+ }
+}
+
+trait MapOperand<Err>: Sized {
+ type Input;
+ type Output<U>;
+ #[allow(unused)] // Used by generated code
+ fn map<U>(
+ self,
+ fn_: impl FnOnce(Self::Input) -> Result<U, Err>,
+ ) -> Result<Self::Output<U>, Err>;
+}
+
+impl<T: Operand, Err> MapOperand<Err> for T {
+ type Input = Self;
+ type Output<U> = U;
+ fn map<U>(self, fn_: impl FnOnce(T) -> Result<U, Err>) -> Result<U, Err> {
+ fn_(self)
+ }
+}
+
+impl<T: Operand, Err> MapOperand<Err> for Option<T> {
+ type Input = T;
+ type Output<U> = Option<U>;
+ fn map<U>(self, fn_: impl FnOnce(T) -> Result<U, Err>) -> Result<Option<U>, Err> {
+ self.map(|x| fn_(x)).transpose()
+ }
+}
+
+pub struct MultiVariable<ID> {
+ pub var: Variable<ID>,
+ pub count: Option<u32>,
+}
+
+#[derive(Clone)]
+pub struct Variable<ID> {
+ pub align: Option<u32>,
+ pub v_type: Type,
+ pub state_space: StateSpace,
+ pub name: ID,
+ pub array_init: Vec<u8>,
+}
+
+pub struct PredAt<ID> {
+ pub not: bool,
+ pub label: ID,
+}
+
+#[derive(PartialEq, Eq, Clone, Hash)]
+pub enum Type {
+ // .param.b32 foo;
+ Scalar(ScalarType),
+ // .param.v2.b32 foo;
+ Vector(u8, ScalarType),
+ // .param.b32 foo[4];
+ Array(Option<NonZeroU8>, ScalarType, Vec<u32>),
+ Pointer(ScalarType, StateSpace),
+}
+
+impl Type {
+ pub(crate) fn maybe_vector(vector: Option<VectorPrefix>, scalar: ScalarType) -> Self {
+ match vector {
+ Some(prefix) => Type::Vector(prefix.len().get(), scalar),
+ None => Type::Scalar(scalar),
+ }
+ }
+
+ pub(crate) fn maybe_vector_parsed(prefix: Option<NonZeroU8>, scalar: ScalarType) -> Self {
+ match prefix {
+ Some(prefix) => Type::Vector(prefix.get(), scalar),
+ None => Type::Scalar(scalar),
+ }
+ }
+
+ pub(crate) fn maybe_array(
+ prefix: Option<NonZeroU8>,
+ scalar: ScalarType,
+ array: Option<Vec<u32>>,
+ ) -> Self {
+ match array {
+ Some(dimensions) => Type::Array(prefix, scalar, dimensions),
+ None => Self::maybe_vector_parsed(prefix, scalar),
+ }
+ }
+}
+
+impl ScalarType {
+ pub fn size_of(self) -> u8 {
+ match self {
+ ScalarType::U8 | ScalarType::S8 | ScalarType::B8 => 1,
+ ScalarType::U16
+ | ScalarType::S16
+ | ScalarType::B16
+ | ScalarType::F16
+ | ScalarType::BF16 => 2,
+ ScalarType::U32
+ | ScalarType::S32
+ | ScalarType::B32
+ | ScalarType::F32
+ | ScalarType::U16x2
+ | ScalarType::S16x2
+ | ScalarType::F16x2
+ | ScalarType::BF16x2 => 4,
+ ScalarType::U64 | ScalarType::S64 | ScalarType::B64 | ScalarType::F64 => 8,
+ ScalarType::B128 => 16,
+ ScalarType::Pred => 1,
+ }
+ }
+
+ pub fn kind(self) -> ScalarKind {
+ match self {
+ ScalarType::U8 => ScalarKind::Unsigned,
+ ScalarType::U16 => ScalarKind::Unsigned,
+ ScalarType::U16x2 => ScalarKind::Unsigned,
+ ScalarType::U32 => ScalarKind::Unsigned,
+ ScalarType::U64 => ScalarKind::Unsigned,
+ ScalarType::S8 => ScalarKind::Signed,
+ ScalarType::S16 => ScalarKind::Signed,
+ ScalarType::S16x2 => ScalarKind::Signed,
+ ScalarType::S32 => ScalarKind::Signed,
+ ScalarType::S64 => ScalarKind::Signed,
+ ScalarType::B8 => ScalarKind::Bit,
+ ScalarType::B16 => ScalarKind::Bit,
+ ScalarType::B32 => ScalarKind::Bit,
+ ScalarType::B64 => ScalarKind::Bit,
+ ScalarType::B128 => ScalarKind::Bit,
+ ScalarType::F16 => ScalarKind::Float,
+ ScalarType::F16x2 => ScalarKind::Float,
+ ScalarType::F32 => ScalarKind::Float,
+ ScalarType::F64 => ScalarKind::Float,
+ ScalarType::BF16 => ScalarKind::Float,
+ ScalarType::BF16x2 => ScalarKind::Float,
+ ScalarType::Pred => ScalarKind::Pred,
+ }
+ }
+}
+
+#[derive(Clone, Copy, PartialEq, Eq)]
+pub enum ScalarKind {
+ Bit,
+ Unsigned,
+ Signed,
+ Float,
+ Pred,
+}
+impl From<ScalarType> for Type {
+ fn from(value: ScalarType) -> Self {
+ Type::Scalar(value)
+ }
+}
+
+#[derive(Clone)]
+pub struct MovDetails {
+ pub typ: super::Type,
+ pub src_is_address: bool,
+ // two fields below are in use by member moves
+ pub dst_width: u8,
+ pub src_width: u8,
+ // This is in use by auto-generated movs
+ pub relaxed_src2_conv: bool,
+}
+
+impl MovDetails {
+ pub(crate) fn new(vector: Option<VectorPrefix>, scalar: ScalarType) -> Self {
+ MovDetails {
+ typ: Type::maybe_vector(vector, scalar),
+ src_is_address: false,
+ dst_width: 0,
+ src_width: 0,
+ relaxed_src2_conv: false,
+ }
+ }
+}
+
+#[derive(Clone)]
+pub enum ParsedOperand<Ident> {
+ Reg(Ident),
+ RegOffset(Ident, i32),
+ Imm(ImmediateValue),
+ VecMember(Ident, u8),
+ VecPack(Vec<Ident>),
+}
+
+impl<Ident: Copy> Operand for ParsedOperand<Ident> {
+ type Ident = Ident;
+
+ fn from_ident(ident: Self::Ident) -> Self {
+ ParsedOperand::Reg(ident)
+ }
+}
+
+pub trait Operand: Sized {
+ type Ident: Copy;
+
+ fn from_ident(ident: Self::Ident) -> Self;
+}
+
+#[derive(Copy, Clone)]
+pub enum ImmediateValue {
+ U64(u64),
+ S64(i64),
+ F32(f32),
+ F64(f64),
+}
+
+#[derive(Copy, Clone, PartialEq, Eq)]
+pub enum StCacheOperator {
+ Writeback,
+ L2Only,
+ Streaming,
+ Writethrough,
+}
+
+#[derive(Copy, Clone, PartialEq, Eq)]
+pub enum LdCacheOperator {
+ Cached,
+ L2Only,
+ Streaming,
+ LastUse,
+ Uncached,
+}
+
+#[derive(Copy, Clone)]
+pub enum ArithDetails {
+ Integer(ArithInteger),
+ Float(ArithFloat),
+}
+
+impl ArithDetails {
+ pub fn type_(&self) -> ScalarType {
+ match self {
+ ArithDetails::Integer(t) => t.type_,
+ ArithDetails::Float(arith) => arith.type_,
+ }
+ }
+}
+
+#[derive(Copy, Clone)]
+pub struct ArithInteger {
+ pub type_: ScalarType,
+ pub saturate: bool,
+}
+
+#[derive(Copy, Clone)]
+pub struct ArithFloat {
+ pub type_: ScalarType,
+ pub rounding: Option<RoundingMode>,
+ pub flush_to_zero: Option<bool>,
+ pub saturate: bool,
+}
+
+#[derive(Copy, Clone, PartialEq, Eq)]
+pub enum LdStQualifier {
+ Weak,
+ Volatile,
+ Relaxed(MemScope),
+ Acquire(MemScope),
+ Release(MemScope),
+}
+
+#[derive(PartialEq, Eq, Copy, Clone)]
+pub enum RoundingMode {
+ NearestEven,
+ Zero,
+ NegativeInf,
+ PositiveInf,
+}
+
+pub struct LdDetails {
+ pub qualifier: LdStQualifier,
+ pub state_space: StateSpace,
+ pub caching: LdCacheOperator,
+ pub typ: Type,
+ pub non_coherent: bool,
+}
+
+pub struct StData {
+ pub qualifier: LdStQualifier,
+ pub state_space: StateSpace,
+ pub caching: StCacheOperator,
+ pub typ: Type,
+}
+
+#[derive(Copy, Clone)]
+pub struct RetData {
+ pub uniform: bool,
+}
+
+#[derive(Copy, Clone, PartialEq, Eq)]
+pub enum TuningDirective {
+ MaxNReg(u32),
+ MaxNtid(u32, u32, u32),
+ ReqNtid(u32, u32, u32),
+ MinNCtaPerSm(u32),
+}
+
+pub struct MethodDeclaration<'input, ID> {
+ pub return_arguments: Vec<Variable<ID>>,
+ pub name: MethodName<'input, ID>,
+ pub input_arguments: Vec<Variable<ID>>,
+ pub shared_mem: Option<ID>,
+}
+
+impl<'input> MethodDeclaration<'input, &'input str> {
+ pub fn name(&self) -> &'input str {
+ match self.name {
+ MethodName::Kernel(n) => n,
+ MethodName::Func(n) => n,
+ }
+ }
+}
+
+#[derive(Hash, PartialEq, Eq, Copy, Clone)]
+pub enum MethodName<'input, ID> {
+ Kernel(&'input str),
+ Func(ID),
+}
+
+bitflags! {
+ pub struct LinkingDirective: u8 {
+ const NONE = 0b000;
+ const EXTERN = 0b001;
+ const VISIBLE = 0b10;
+ const WEAK = 0b100;
+ }
+}
+
+pub struct Function<'a, ID, S> {
+ pub func_directive: MethodDeclaration<'a, ID>,
+ pub tuning: Vec<TuningDirective>,
+ pub body: Option<Vec<S>>,
+}
+
+pub enum Directive<'input, O: Operand> {
+ Variable(LinkingDirective, Variable<O::Ident>),
+ Method(
+ LinkingDirective,
+ Function<'input, &'input str, Statement<O>>,
+ ),
+}
+
+pub struct Module<'input> {
+ pub version: (u8, u8),
+ pub directives: Vec<Directive<'input, ParsedOperand<&'input str>>>,
+}
+
+#[derive(Copy, Clone)]
+pub enum MulDetails {
+ Integer {
+ type_: ScalarType,
+ control: MulIntControl,
+ },
+ Float(ArithFloat),
+}
+
+impl MulDetails {
+ pub fn type_(&self) -> ScalarType {
+ match self {
+ MulDetails::Integer { type_, .. } => *type_,
+ MulDetails::Float(arith) => arith.type_,
+ }
+ }
+
+ pub fn dst_type(&self) -> ScalarType {
+ match self {
+ MulDetails::Integer {
+ type_,
+ control: MulIntControl::Wide,
+ } => match type_ {
+ ScalarType::U16 => ScalarType::U32,
+ ScalarType::S16 => ScalarType::S32,
+ ScalarType::U32 => ScalarType::U64,
+ ScalarType::S32 => ScalarType::S64,
+ _ => unreachable!(),
+ },
+ _ => self.type_(),
+ }
+ }
+}
+
+#[derive(Copy, Clone, PartialEq, Eq)]
+pub enum MulIntControl {
+ Low,
+ High,
+ Wide,
+}
+
+pub struct SetpData {
+ pub type_: ScalarType,
+ pub flush_to_zero: Option<bool>,
+ pub cmp_op: SetpCompareOp,
+}
+
+impl SetpData {
+ pub(crate) fn try_parse(
+ state: &mut PtxParserState,
+ cmp_op: super::RawSetpCompareOp,
+ ftz: bool,
+ type_: ScalarType,
+ ) -> Self {
+ let flush_to_zero = match (ftz, type_) {
+ (_, ScalarType::F32) => Some(ftz),
+ (true, _) => {
+ state.errors.push(PtxError::NonF32Ftz);
+ None
+ }
+ _ => None
+ };
+ let type_kind = type_.kind();
+ let cmp_op = if type_kind == ScalarKind::Float {
+ SetpCompareOp::Float(SetpCompareFloat::from(cmp_op))
+ } else {
+ match SetpCompareInt::try_from((cmp_op, type_kind)) {
+ Ok(op) => SetpCompareOp::Integer(op),
+ Err(err) => {
+ state.errors.push(err);
+ SetpCompareOp::Integer(SetpCompareInt::Eq)
+ }
+ }
+ };
+ Self {
+ type_,
+ flush_to_zero,
+ cmp_op,
+ }
+ }
+}
+
+pub struct SetpBoolData {
+ pub base: SetpData,
+ pub bool_op: SetpBoolPostOp,
+ pub negate_src3: bool,
+}
+
+#[derive(PartialEq, Eq, Copy, Clone)]
+pub enum SetpCompareOp {
+ Integer(SetpCompareInt),
+ Float(SetpCompareFloat),
+}
+
+#[derive(PartialEq, Eq, Copy, Clone)]
+pub enum SetpCompareInt {
+ Eq,
+ NotEq,
+ UnsignedLess,
+ UnsignedLessOrEq,
+ UnsignedGreater,
+ UnsignedGreaterOrEq,
+ SignedLess,
+ SignedLessOrEq,
+ SignedGreater,
+ SignedGreaterOrEq,
+}
+
+#[derive(PartialEq, Eq, Copy, Clone)]
+pub enum SetpCompareFloat {
+ Eq,
+ NotEq,
+ Less,
+ LessOrEq,
+ Greater,
+ GreaterOrEq,
+ NanEq,
+ NanNotEq,
+ NanLess,
+ NanLessOrEq,
+ NanGreater,
+ NanGreaterOrEq,
+ IsNotNan,
+ IsAnyNan,
+}
+
+impl TryFrom<(RawSetpCompareOp, ScalarKind)> for SetpCompareInt {
+ type Error = PtxError;
+
+ fn try_from((value, kind): (RawSetpCompareOp, ScalarKind)) -> Result<Self, PtxError> {
+ match (value, kind) {
+ (RawSetpCompareOp::Eq, _) => Ok(SetpCompareInt::Eq),
+ (RawSetpCompareOp::Ne, _) => Ok(SetpCompareInt::NotEq),
+ (RawSetpCompareOp::Lt | RawSetpCompareOp::Lo, ScalarKind::Signed) => {
+ Ok(SetpCompareInt::SignedLess)
+ }
+ (RawSetpCompareOp::Lt | RawSetpCompareOp::Lo, _) => Ok(SetpCompareInt::UnsignedLess),
+ (RawSetpCompareOp::Le | RawSetpCompareOp::Ls, ScalarKind::Signed) => {
+ Ok(SetpCompareInt::SignedLessOrEq)
+ }
+ (RawSetpCompareOp::Le | RawSetpCompareOp::Ls, _) => {
+ Ok(SetpCompareInt::UnsignedLessOrEq)
+ }
+ (RawSetpCompareOp::Gt | RawSetpCompareOp::Hi, ScalarKind::Signed) => {
+ Ok(SetpCompareInt::SignedGreater)
+ }
+ (RawSetpCompareOp::Gt | RawSetpCompareOp::Hi, _) => Ok(SetpCompareInt::UnsignedGreater),
+ (RawSetpCompareOp::Ge | RawSetpCompareOp::Hs, ScalarKind::Signed) => {
+ Ok(SetpCompareInt::SignedGreaterOrEq)
+ }
+ (RawSetpCompareOp::Ge | RawSetpCompareOp::Hs, _) => {
+ Ok(SetpCompareInt::UnsignedGreaterOrEq)
+ }
+ (RawSetpCompareOp::Equ, _) => Err(PtxError::WrongType),
+ (RawSetpCompareOp::Neu, _) => Err(PtxError::WrongType),
+ (RawSetpCompareOp::Ltu, _) => Err(PtxError::WrongType),
+ (RawSetpCompareOp::Leu, _) => Err(PtxError::WrongType),
+ (RawSetpCompareOp::Gtu, _) => Err(PtxError::WrongType),
+ (RawSetpCompareOp::Geu, _) => Err(PtxError::WrongType),
+ (RawSetpCompareOp::Num, _) => Err(PtxError::WrongType),
+ (RawSetpCompareOp::Nan, _) => Err(PtxError::WrongType),
+ }
+ }
+}
+
+impl From<RawSetpCompareOp> for SetpCompareFloat {
+ fn from(value: RawSetpCompareOp) -> Self {
+ match value {
+ RawSetpCompareOp::Eq => SetpCompareFloat::Eq,
+ RawSetpCompareOp::Ne => SetpCompareFloat::NotEq,
+ RawSetpCompareOp::Lt => SetpCompareFloat::Less,
+ RawSetpCompareOp::Le => SetpCompareFloat::LessOrEq,
+ RawSetpCompareOp::Gt => SetpCompareFloat::Greater,
+ RawSetpCompareOp::Ge => SetpCompareFloat::GreaterOrEq,
+ RawSetpCompareOp::Lo => SetpCompareFloat::Less,
+ RawSetpCompareOp::Ls => SetpCompareFloat::LessOrEq,
+ RawSetpCompareOp::Hi => SetpCompareFloat::Greater,
+ RawSetpCompareOp::Hs => SetpCompareFloat::GreaterOrEq,
+ RawSetpCompareOp::Equ => SetpCompareFloat::NanEq,
+ RawSetpCompareOp::Neu => SetpCompareFloat::NanNotEq,
+ RawSetpCompareOp::Ltu => SetpCompareFloat::NanLess,
+ RawSetpCompareOp::Leu => SetpCompareFloat::NanLessOrEq,
+ RawSetpCompareOp::Gtu => SetpCompareFloat::NanGreater,
+ RawSetpCompareOp::Geu => SetpCompareFloat::NanGreaterOrEq,
+ RawSetpCompareOp::Num => SetpCompareFloat::IsNotNan,
+ RawSetpCompareOp::Nan => SetpCompareFloat::IsAnyNan,
+ }
+ }
+}
+
+pub struct CallDetails {
+ pub uniform: bool,
+ pub return_arguments: Vec<(Type, StateSpace)>,
+ pub input_arguments: Vec<(Type, StateSpace)>,
+}
+
+pub struct CallArgs<T: Operand> {
+ pub return_arguments: Vec<T::Ident>,
+ pub func: T::Ident,
+ pub input_arguments: Vec<T>,
+}
+
+impl<T: Operand> CallArgs<T> {
+ #[allow(dead_code)] // Used by generated code
+ fn visit<Err>(
+ &self,
+ details: &CallDetails,
+ visitor: &mut impl Visitor<T, Err>,
+ ) -> Result<(), Err> {
+ for (param, (type_, space)) in self
+ .return_arguments
+ .iter()
+ .zip(details.return_arguments.iter())
+ {
+ visitor.visit_ident(param, Some((type_, *space)), true, false)?;
+ }
+ visitor.visit_ident(&self.func, None, false, false)?;
+ for (param, (type_, space)) in self
+ .input_arguments
+ .iter()
+ .zip(details.input_arguments.iter())
+ {
+ visitor.visit(param, Some((type_, *space)), false, false)?;
+ }
+ Ok(())
+ }
+
+ #[allow(dead_code)] // Used by generated code
+ fn visit_mut<Err>(
+ &mut self,
+ details: &CallDetails,
+ visitor: &mut impl VisitorMut<T, Err>,
+ ) -> Result<(), Err> {
+ for (param, (type_, space)) in self
+ .return_arguments
+ .iter_mut()
+ .zip(details.return_arguments.iter())
+ {
+ visitor.visit_ident(param, Some((type_, *space)), true, false)?;
+ }
+ visitor.visit_ident(&mut self.func, None, false, false)?;
+ for (param, (type_, space)) in self
+ .input_arguments
+ .iter_mut()
+ .zip(details.input_arguments.iter())
+ {
+ visitor.visit(param, Some((type_, *space)), false, false)?;
+ }
+ Ok(())
+ }
+
+ #[allow(dead_code)] // Used by generated code
+ fn map<U: Operand, Err>(
+ self,
+ details: &CallDetails,
+ visitor: &mut impl VisitorMap<T, U, Err>,
+ ) -> Result<CallArgs<U>, Err> {
+ let return_arguments = self
+ .return_arguments
+ .into_iter()
+ .zip(details.return_arguments.iter())
+ .map(|(param, (type_, space))| {
+ visitor.visit_ident(param, Some((type_, *space)), true, false)
+ })
+ .collect::<Result<Vec<_>, _>>()?;
+ let func = visitor.visit_ident(self.func, None, false, false)?;
+ let input_arguments = self
+ .input_arguments
+ .into_iter()
+ .zip(details.input_arguments.iter())
+ .map(|(param, (type_, space))| {
+ visitor.visit(param, Some((type_, *space)), false, false)
+ })
+ .collect::<Result<Vec<_>, _>>()?;
+ Ok(CallArgs {
+ return_arguments,
+ func,
+ input_arguments,
+ })
+ }
+}
+
+pub struct CvtDetails {
+ pub from: ScalarType,
+ pub to: ScalarType,
+ pub mode: CvtMode,
+}
+
+pub enum CvtMode {
+ // int from int
+ ZeroExtend,
+ SignExtend,
+ Truncate,
+ Bitcast,
+ SaturateUnsignedToSigned,
+ SaturateSignedToUnsigned,
+ // float from float
+ FPExtend {
+ flush_to_zero: Option<bool>,
+ },
+ FPTruncate {
+ // float rounding
+ rounding: RoundingMode,
+ flush_to_zero: Option<bool>,
+ },
+ FPRound {
+ integer_rounding: Option<RoundingMode>,
+ flush_to_zero: Option<bool>,
+ },
+ // int from float
+ SignedFromFP {
+ rounding: RoundingMode,
+ flush_to_zero: Option<bool>,
+ }, // integer rounding
+ UnsignedFromFP {
+ rounding: RoundingMode,
+ flush_to_zero: Option<bool>,
+ }, // integer rounding
+ // float from int, ftz is allowed in the grammar, but clearly nonsensical
+ FPFromSigned(RoundingMode), // float rounding
+ FPFromUnsigned(RoundingMode), // float rounding
+}
+
+impl CvtDetails {
+ pub(crate) fn new(
+ errors: &mut Vec<PtxError>,
+ rnd: Option<RawRoundingMode>,
+ ftz: bool,
+ saturate: bool,
+ dst: ScalarType,
+ src: ScalarType,
+ ) -> Self {
+ if saturate && dst.kind() == ScalarKind::Float {
+ errors.push(PtxError::SyntaxError);
+ }
+ // Modifier .ftz can only be specified when either .dtype or .atype is .f32 and applies only to single precision (.f32) inputs and results.
+ let flush_to_zero = match (dst, src) {
+ (ScalarType::F32, _) | (_, ScalarType::F32) => Some(ftz),
+ _ => {
+ if ftz {
+ errors.push(PtxError::NonF32Ftz);
+ }
+ None
+ }
+ };
+ let rounding = rnd.map(Into::into);
+ let mut unwrap_rounding = || match rounding {
+ Some(rnd) => rnd,
+ None => {
+ errors.push(PtxError::SyntaxError);
+ RoundingMode::NearestEven
+ }
+ };
+ let mode = match (dst.kind(), src.kind()) {
+ (ScalarKind::Float, ScalarKind::Float) => match dst.size_of().cmp(&src.size_of()) {
+ Ordering::Less => CvtMode::FPTruncate {
+ rounding: unwrap_rounding(),
+ flush_to_zero,
+ },
+ Ordering::Equal => CvtMode::FPRound {
+ integer_rounding: rounding,
+ flush_to_zero,
+ },
+ Ordering::Greater => {
+ if rounding.is_some() {
+ errors.push(PtxError::SyntaxError);
+ }
+ CvtMode::FPExtend { flush_to_zero }
+ }
+ },
+ (ScalarKind::Unsigned, ScalarKind::Float) => CvtMode::UnsignedFromFP {
+ rounding: unwrap_rounding(),
+ flush_to_zero,
+ },
+ (ScalarKind::Signed, ScalarKind::Float) => CvtMode::SignedFromFP {
+ rounding: unwrap_rounding(),
+ flush_to_zero,
+ },
+ (ScalarKind::Float, ScalarKind::Unsigned) => CvtMode::FPFromUnsigned(unwrap_rounding()),
+ (ScalarKind::Float, ScalarKind::Signed) => CvtMode::FPFromSigned(unwrap_rounding()),
+ (ScalarKind::Signed, ScalarKind::Unsigned) if saturate => {
+ CvtMode::SaturateUnsignedToSigned
+ }
+ (ScalarKind::Unsigned, ScalarKind::Signed) if saturate => {
+ CvtMode::SaturateSignedToUnsigned
+ }
+ (ScalarKind::Unsigned, ScalarKind::Signed)
+ | (ScalarKind::Signed, ScalarKind::Unsigned)
+ if dst.size_of() == src.size_of() =>
+ {
+ CvtMode::Bitcast
+ }
+ (ScalarKind::Unsigned, ScalarKind::Unsigned)
+ | (ScalarKind::Signed, ScalarKind::Signed) => match dst.size_of().cmp(&src.size_of()) {
+ Ordering::Less => CvtMode::Truncate,
+ Ordering::Equal => CvtMode::Bitcast,
+ Ordering::Greater => {
+ if src.kind() == ScalarKind::Signed {
+ CvtMode::SignExtend
+ } else {
+ CvtMode::ZeroExtend
+ }
+ }
+ },
+ (ScalarKind::Unsigned, ScalarKind::Signed) => CvtMode::SaturateSignedToUnsigned,
+ (_, _) => {
+ errors.push(PtxError::SyntaxError);
+ CvtMode::Bitcast
+ }
+ };
+ CvtDetails {
+ mode,
+ to: dst,
+ from: src,
+ }
+ }
+}
+
+pub struct CvtIntToIntDesc {
+ pub dst: ScalarType,
+ pub src: ScalarType,
+ pub saturate: bool,
+}
+
+pub struct CvtDesc {
+ pub rounding: Option<RoundingMode>,
+ pub flush_to_zero: Option<bool>,
+ pub saturate: bool,
+ pub dst: ScalarType,
+ pub src: ScalarType,
+}
+
+pub struct ShrData {
+ pub type_: ScalarType,
+ pub kind: RightShiftKind,
+}
+
+pub enum RightShiftKind {
+ Arithmetic,
+ Logical,
+}
+
+pub struct CvtaDetails {
+ pub state_space: StateSpace,
+ pub direction: CvtaDirection,
+}
+
+pub enum CvtaDirection {
+ GenericToExplicit,
+ ExplicitToGeneric,
+}
+
+#[derive(Copy, Clone, PartialEq, Eq)]
+pub struct TypeFtz {
+ pub flush_to_zero: Option<bool>,
+ pub type_: ScalarType,
+}
+
+#[derive(Copy, Clone)]
+pub enum MadDetails {
+ Integer {
+ control: MulIntControl,
+ saturate: bool,
+ type_: ScalarType,
+ },
+ Float(ArithFloat),
+}
+
+impl MadDetails {
+ pub fn dst_type(&self) -> ScalarType {
+ match self {
+ MadDetails::Integer {
+ type_,
+ control: MulIntControl::Wide,
+ ..
+ } => match type_ {
+ ScalarType::U16 => ScalarType::U32,
+ ScalarType::S16 => ScalarType::S32,
+ ScalarType::U32 => ScalarType::U64,
+ ScalarType::S32 => ScalarType::S64,
+ _ => unreachable!(),
+ },
+ _ => self.type_(),
+ }
+ }
+
+ fn type_(&self) -> ScalarType {
+ match self {
+ MadDetails::Integer { type_, .. } => *type_,
+ MadDetails::Float(arith) => arith.type_,
+ }
+ }
+}
+
+#[derive(Copy, Clone)]
+pub enum MinMaxDetails {
+ Signed(ScalarType),
+ Unsigned(ScalarType),
+ Float(MinMaxFloat),
+}
+
+impl MinMaxDetails {
+ pub fn type_(&self) -> ScalarType {
+ match self {
+ MinMaxDetails::Signed(t) => *t,
+ MinMaxDetails::Unsigned(t) => *t,
+ MinMaxDetails::Float(float) => float.type_,
+ }
+ }
+}
+
+#[derive(Copy, Clone)]
+pub struct MinMaxFloat {
+ pub flush_to_zero: Option<bool>,
+ pub nan: bool,
+ pub type_: ScalarType,
+}
+
+#[derive(Copy, Clone)]
+pub struct RcpData {
+ pub kind: RcpKind,
+ pub flush_to_zero: Option<bool>,
+ pub type_: ScalarType,
+}
+
+#[derive(Copy, Clone, Eq, PartialEq)]
+pub enum RcpKind {
+ Approx,
+ Compliant(RoundingMode),
+}
+
+pub struct BarData {
+ pub aligned: bool,
+}
+
+pub struct AtomDetails {
+ pub type_: Type,
+ pub semantics: AtomSemantics,
+ pub scope: MemScope,
+ pub space: StateSpace,
+ pub op: AtomicOp,
+}
+
+#[derive(Copy, Clone)]
+pub enum AtomicOp {
+ And,
+ Or,
+ Xor,
+ Exchange,
+ Add,
+ IncrementWrap,
+ DecrementWrap,
+ SignedMin,
+ UnsignedMin,
+ SignedMax,
+ UnsignedMax,
+ FloatAdd,
+ FloatMin,
+ FloatMax,
+}
+
+impl AtomicOp {
+ pub(crate) fn new(op: super::RawAtomicOp, kind: ScalarKind) -> Self {
+ use super::RawAtomicOp;
+ match (op, kind) {
+ (RawAtomicOp::And, _) => Self::And,
+ (RawAtomicOp::Or, _) => Self::Or,
+ (RawAtomicOp::Xor, _) => Self::Xor,
+ (RawAtomicOp::Exch, _) => Self::Exchange,
+ (RawAtomicOp::Add, ScalarKind::Float) => Self::FloatAdd,
+ (RawAtomicOp::Add, _) => Self::Add,
+ (RawAtomicOp::Inc, _) => Self::IncrementWrap,
+ (RawAtomicOp::Dec, _) => Self::DecrementWrap,
+ (RawAtomicOp::Min, ScalarKind::Signed) => Self::SignedMin,
+ (RawAtomicOp::Min, ScalarKind::Float) => Self::FloatMin,
+ (RawAtomicOp::Min, _) => Self::UnsignedMin,
+ (RawAtomicOp::Max, ScalarKind::Signed) => Self::SignedMax,
+ (RawAtomicOp::Max, ScalarKind::Float) => Self::FloatMax,
+ (RawAtomicOp::Max, _) => Self::UnsignedMax,
+ }
+ }
+}
+
+pub struct AtomCasDetails {
+ pub type_: ScalarType,
+ pub semantics: AtomSemantics,
+ pub scope: MemScope,
+ pub space: StateSpace,
+}
+
+#[derive(Copy, Clone)]
+pub enum DivDetails {
+ Unsigned(ScalarType),
+ Signed(ScalarType),
+ Float(DivFloatDetails),
+}
+
+impl DivDetails {
+ pub fn type_(&self) -> ScalarType {
+ match self {
+ DivDetails::Unsigned(t) => *t,
+ DivDetails::Signed(t) => *t,
+ DivDetails::Float(float) => float.type_,
+ }
+ }
+}
+
+#[derive(Copy, Clone)]
+pub struct DivFloatDetails {
+ pub type_: ScalarType,
+ pub flush_to_zero: Option<bool>,
+ pub kind: DivFloatKind,
+}
+
+#[derive(Copy, Clone, Eq, PartialEq)]
+pub enum DivFloatKind {
+ Approx,
+ ApproxFull,
+ Rounding(RoundingMode),
+}
+
+#[derive(Copy, Clone, Eq, PartialEq)]
+pub struct FlushToZero {
+ pub flush_to_zero: bool,
+}
diff --git a/ptx_parser/src/check_args.py b/ptx_parser/src/check_args.py
new file mode 100644
index 0000000..04ffdb9
--- /dev/null
+++ b/ptx_parser/src/check_args.py
@@ -0,0 +1,69 @@
+import os, sys, subprocess
+
+
+SPACE = [".reg", ".sreg", ".param", ".param::entry", ".param::func", ".local", ".global", ".const", ".shared", ".shared::cta", ".shared::cluster"]
+TYPE_AND_INIT = ["", " = 1", "[1]", "[1] = {1}"]
+MULTIVAR = ["", "<1>" ]
+VECTOR = ["", ".v2" ]
+
+HEADER = """
+ .version 8.5
+ .target sm_90
+ .address_size 64
+"""
+
+
+def directive(space, variable, multivar, vector):
+ return """{3}
+ {0} {4} .b32 variable{2} {1};
+ """.format(space, variable, multivar, HEADER, vector)
+
+def entry_arg(space, variable, multivar, vector):
+ return """{3}
+ .entry foobar ( {0} {4} .b32 variable{2} {1})
+ {{
+ ret;
+ }}
+ """.format(space, variable, multivar, HEADER, vector)
+
+
+def fn_arg(space, variable, multivar, vector):
+ return """{3}
+ .func foobar ( {0} {4} .b32 variable{2} {1})
+ {{
+ ret;
+ }}
+ """.format(space, variable, multivar, HEADER, vector)
+
+
+def fn_body(space, variable, multivar, vector):
+ return """{3}
+ .func foobar ()
+ {{
+ {0} {4} .b32 variable{2} {1};
+ ret;
+ }}
+ """.format(space, variable, multivar, HEADER, vector)
+
+
+def generate(generator):
+ legal = []
+ for space in SPACE:
+ for init in TYPE_AND_INIT:
+ for multi in MULTIVAR:
+ for vector in VECTOR:
+ ptx = generator(space, init, multi, vector)
+ if 0 == subprocess.call(["C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v12.6\\bin\\ptxas.exe", "-arch", "sm_90", "-ias", ptx], stdout = subprocess.DEVNULL): #
+ legal.append((space, vector, init, multi))
+ print(generator.__name__)
+ print(legal)
+
+
+def main():
+ generate(directive)
+ generate(entry_arg)
+ generate(fn_arg)
+ generate(fn_body)
+
+if __name__ == "__main__":
+ main()
diff --git a/ptx_parser/src/lib.rs b/ptx_parser/src/lib.rs
new file mode 100644
index 0000000..f842ace
--- /dev/null
+++ b/ptx_parser/src/lib.rs
@@ -0,0 +1,3269 @@
+use derive_more::Display;
+use logos::Logos;
+use ptx_parser_macros::derive_parser;
+use rustc_hash::FxHashMap;
+use std::fmt::Debug;
+use std::iter;
+use std::num::{NonZeroU8, ParseFloatError, ParseIntError};
+use winnow::ascii::dec_uint;
+use winnow::combinator::*;
+use winnow::error::{ErrMode, ErrorKind};
+use winnow::stream::Accumulate;
+use winnow::token::any;
+use winnow::{
+ error::{ContextError, ParserError},
+ stream::{Offset, Stream, StreamIsPartial},
+ PResult,
+};
+use winnow::{prelude::*, Stateful};
+
+mod ast;
+pub use ast::*;
+
+impl From<RawMulIntControl> for ast::MulIntControl {
+ fn from(value: RawMulIntControl) -> Self {
+ match value {
+ RawMulIntControl::Lo => ast::MulIntControl::Low,
+ RawMulIntControl::Hi => ast::MulIntControl::High,
+ RawMulIntControl::Wide => ast::MulIntControl::Wide,
+ }
+ }
+}
+
+impl From<RawStCacheOperator> for ast::StCacheOperator {
+ fn from(value: RawStCacheOperator) -> Self {
+ match value {
+ RawStCacheOperator::Wb => ast::StCacheOperator::Writeback,
+ RawStCacheOperator::Cg => ast::StCacheOperator::L2Only,
+ RawStCacheOperator::Cs => ast::StCacheOperator::Streaming,
+ RawStCacheOperator::Wt => ast::StCacheOperator::Writethrough,
+ }
+ }
+}
+
+impl From<RawLdCacheOperator> for ast::LdCacheOperator {
+ fn from(value: RawLdCacheOperator) -> Self {
+ match value {
+ RawLdCacheOperator::Ca => ast::LdCacheOperator::Cached,
+ RawLdCacheOperator::Cg => ast::LdCacheOperator::L2Only,
+ RawLdCacheOperator::Cs => ast::LdCacheOperator::Streaming,
+ RawLdCacheOperator::Lu => ast::LdCacheOperator::LastUse,
+ RawLdCacheOperator::Cv => ast::LdCacheOperator::Uncached,
+ }
+ }
+}
+
+impl From<RawLdStQualifier> for ast::LdStQualifier {
+ fn from(value: RawLdStQualifier) -> Self {
+ match value {
+ RawLdStQualifier::Weak => ast::LdStQualifier::Weak,
+ RawLdStQualifier::Volatile => ast::LdStQualifier::Volatile,
+ }
+ }
+}
+
+impl From<RawRoundingMode> for ast::RoundingMode {
+ fn from(value: RawRoundingMode) -> Self {
+ match value {
+ RawRoundingMode::Rn | RawRoundingMode::Rni => ast::RoundingMode::NearestEven,
+ RawRoundingMode::Rz | RawRoundingMode::Rzi => ast::RoundingMode::Zero,
+ RawRoundingMode::Rm | RawRoundingMode::Rmi => ast::RoundingMode::NegativeInf,
+ RawRoundingMode::Rp | RawRoundingMode::Rpi => ast::RoundingMode::PositiveInf,
+ }
+ }
+}
+
+impl VectorPrefix {
+ pub(crate) fn len(self) -> NonZeroU8 {
+ unsafe {
+ match self {
+ VectorPrefix::V2 => NonZeroU8::new_unchecked(2),
+ VectorPrefix::V4 => NonZeroU8::new_unchecked(4),
+ VectorPrefix::V8 => NonZeroU8::new_unchecked(8),
+ }
+ }
+ }
+}
+
+struct PtxParserState<'a, 'input> {
+ errors: &'a mut Vec<PtxError>,
+ function_declarations:
+ FxHashMap<&'input str, (Vec<(ast::Type, StateSpace)>, Vec<(ast::Type, StateSpace)>)>,
+}
+
+impl<'a, 'input> PtxParserState<'a, 'input> {
+ fn new(errors: &'a mut Vec<PtxError>) -> Self {
+ Self {
+ errors,
+ function_declarations: FxHashMap::default(),
+ }
+ }
+
+ fn record_function(&mut self, function_decl: &MethodDeclaration<'input, &'input str>) {
+ let name = match function_decl.name {
+ MethodName::Kernel(name) => name,
+ MethodName::Func(name) => name,
+ };
+ let return_arguments = Self::get_type_space(&*function_decl.return_arguments);
+ let input_arguments = Self::get_type_space(&*function_decl.input_arguments);
+ // TODO: check if declarations match
+ self.function_declarations
+ .insert(name, (return_arguments, input_arguments));
+ }
+
+ fn get_type_space(input_arguments: &[Variable<&str>]) -> Vec<(Type, StateSpace)> {
+ input_arguments
+ .iter()
+ .map(|var| (var.v_type.clone(), var.state_space))
+ .collect::<Vec<_>>()
+ }
+}
+
+impl<'a, 'input> Debug for PtxParserState<'a, 'input> {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ f.debug_struct("PtxParserState")
+ .field("errors", &self.errors) /* .field("function_decl", &self.function_decl) */
+ .finish()
+ }
+}
+
+type PtxParser<'a, 'input> = Stateful<&'a [Token<'input>], PtxParserState<'a, 'input>>;
+
+fn ident<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<&'input str> {
+ any.verify_map(|t| {
+ if let Token::Ident(text) = t {
+ Some(text)
+ } else if let Some(text) = t.opcode_text() {
+ Some(text)
+ } else {
+ None
+ }
+ })
+ .parse_next(stream)
+}
+
+fn dot_ident<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<&'input str> {
+ any.verify_map(|t| {
+ if let Token::DotIdent(text) = t {
+ Some(text)
+ } else {
+ None
+ }
+ })
+ .parse_next(stream)
+}
+
+fn num<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<(&'input str, u32, bool)> {
+ any.verify_map(|t| {
+ Some(match t {
+ Token::Hex(s) => {
+ if s.ends_with('U') {
+ (&s[2..s.len() - 1], 16, true)
+ } else {
+ (&s[2..], 16, false)
+ }
+ }
+ Token::Decimal(s) => {
+ let radix = if s.starts_with('0') { 8 } else { 10 };
+ if s.ends_with('U') {
+ (&s[..s.len() - 1], radix, true)
+ } else {
+ (s, radix, false)
+ }
+ }
+ _ => return None,
+ })
+ })
+ .parse_next(stream)
+}
+
+fn take_error<'a, 'input: 'a, O, E>(
+ mut parser: impl Parser<PtxParser<'a, 'input>, Result<O, (O, PtxError)>, E>,
+) -> impl Parser<PtxParser<'a, 'input>, O, E> {
+ move |input: &mut PtxParser<'a, 'input>| {
+ Ok(match parser.parse_next(input)? {
+ Ok(x) => x,
+ Err((x, err)) => {
+ input.state.errors.push(err);
+ x
+ }
+ })
+ }
+}
+
+fn int_immediate<'a, 'input>(input: &mut PtxParser<'a, 'input>) -> PResult<ast::ImmediateValue> {
+ take_error((opt(Token::Minus), num).map(|(neg, x)| {
+ let (num, radix, is_unsigned) = x;
+ if neg.is_some() {
+ match i64::from_str_radix(num, radix) {
+ Ok(x) => Ok(ast::ImmediateValue::S64(-x)),
+ Err(err) => Err((ast::ImmediateValue::S64(0), PtxError::from(err))),
+ }
+ } else if is_unsigned {
+ match u64::from_str_radix(num, radix) {
+ Ok(x) => Ok(ast::ImmediateValue::U64(x)),
+ Err(err) => Err((ast::ImmediateValue::U64(0), PtxError::from(err))),
+ }
+ } else {
+ match i64::from_str_radix(num, radix) {
+ Ok(x) => Ok(ast::ImmediateValue::S64(x)),
+ Err(_) => match u64::from_str_radix(num, radix) {
+ Ok(x) => Ok(ast::ImmediateValue::U64(x)),
+ Err(err) => Err((ast::ImmediateValue::U64(0), PtxError::from(err))),
+ },
+ }
+ }
+ }))
+ .parse_next(input)
+}
+
+fn f32<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<f32> {
+ take_error(any.verify_map(|t| match t {
+ Token::F32(f) => Some(match u32::from_str_radix(&f[2..], 16) {
+ Ok(x) => Ok(f32::from_bits(x)),
+ Err(err) => Err((0.0, PtxError::from(err))),
+ }),
+ _ => None,
+ }))
+ .parse_next(stream)
+}
+
+fn f64<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<f64> {
+ take_error(any.verify_map(|t| match t {
+ Token::F64(f) => Some(match u64::from_str_radix(&f[2..], 16) {
+ Ok(x) => Ok(f64::from_bits(x)),
+ Err(err) => Err((0.0, PtxError::from(err))),
+ }),
+ _ => None,
+ }))
+ .parse_next(stream)
+}
+
+fn s32<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<i32> {
+ take_error((opt(Token::Minus), num).map(|(sign, x)| {
+ let (text, radix, _) = x;
+ match i32::from_str_radix(text, radix) {
+ Ok(x) => Ok(if sign.is_some() { -x } else { x }),
+ Err(err) => Err((0, PtxError::from(err))),
+ }
+ }))
+ .parse_next(stream)
+}
+
+fn u8<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<u8> {
+ take_error(num.map(|x| {
+ let (text, radix, _) = x;
+ match u8::from_str_radix(text, radix) {
+ Ok(x) => Ok(x),
+ Err(err) => Err((0, PtxError::from(err))),
+ }
+ }))
+ .parse_next(stream)
+}
+
+fn u32<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<u32> {
+ take_error(num.map(|x| {
+ let (text, radix, _) = x;
+ match u32::from_str_radix(text, radix) {
+ Ok(x) => Ok(x),
+ Err(err) => Err((0, PtxError::from(err))),
+ }
+ }))
+ .parse_next(stream)
+}
+
+fn immediate_value<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<ast::ImmediateValue> {
+ alt((
+ int_immediate,
+ f32.map(ast::ImmediateValue::F32),
+ f64.map(ast::ImmediateValue::F64),
+ ))
+ .parse_next(stream)
+}
+
+pub fn parse_module_unchecked<'input>(text: &'input str) -> Option<ast::Module<'input>> {
+ let lexer = Token::lexer(text);
+ let input = lexer.collect::<Result<Vec<_>, _>>().ok()?;
+ let mut errors = Vec::new();
+ let state = PtxParserState::new(&mut errors);
+ let parser = PtxParser {
+ state,
+ input: &input[..],
+ };
+ let parsing_result = module.parse(parser).ok();
+ if !errors.is_empty() {
+ None
+ } else {
+ parsing_result
+ }
+}
+
+pub fn parse_module_checked<'input>(
+ text: &'input str,
+) -> Result<ast::Module<'input>, Vec<PtxError>> {
+ let mut lexer = Token::lexer(text);
+ let mut errors = Vec::new();
+ let mut tokens = Vec::new();
+ loop {
+ let maybe_token = match lexer.next() {
+ Some(maybe_token) => maybe_token,
+ None => break,
+ };
+ match maybe_token {
+ Ok(token) => tokens.push(token),
+ Err(mut err) => {
+ err.0 = lexer.span();
+ errors.push(PtxError::from(err))
+ }
+ }
+ }
+ if !errors.is_empty() {
+ return Err(errors);
+ }
+ let parse_result = {
+ let state = PtxParserState::new(&mut errors);
+ let parser = PtxParser {
+ state,
+ input: &tokens[..],
+ };
+ module
+ .parse(parser)
+ .map_err(|err| PtxError::Parser(err.into_inner()))
+ };
+ match parse_result {
+ Ok(result) if errors.is_empty() => Ok(result),
+ Ok(_) => Err(errors),
+ Err(err) => {
+ errors.push(err);
+ Err(errors)
+ }
+ }
+}
+
+fn module<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<ast::Module<'input>> {
+ (
+ version,
+ target,
+ opt(address_size),
+ repeat_without_none(directive),
+ eof,
+ )
+ .map(|(version, _, _, directives, _)| ast::Module {
+ version,
+ directives,
+ })
+ .parse_next(stream)
+}
+
+fn address_size<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<()> {
+ (Token::DotAddressSize, u8_literal(64))
+ .void()
+ .parse_next(stream)
+}
+
+fn version<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<(u8, u8)> {
+ (Token::DotVersion, u8, Token::Dot, u8)
+ .map(|(_, major, _, minor)| (major, minor))
+ .parse_next(stream)
+}
+
+fn target<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<(u32, Option<char>)> {
+ preceded(Token::DotTarget, ident.and_then(shader_model)).parse_next(stream)
+}
+
+fn shader_model<'a>(stream: &mut &str) -> PResult<(u32, Option<char>)> {
+ (
+ "sm_",
+ dec_uint,
+ opt(any.verify(|c: &char| c.is_ascii_lowercase())),
+ eof,
+ )
+ .map(|(_, digits, arch_variant, _)| (digits, arch_variant))
+ .parse_next(stream)
+}
+
+fn directive<'a, 'input>(
+ stream: &mut PtxParser<'a, 'input>,
+) -> PResult<Option<ast::Directive<'input, ast::ParsedOperand<&'input str>>>> {
+ alt((
+ function.map(|(linking, func)| Some(ast::Directive::Method(linking, func))),
+ file.map(|_| None),
+ section.map(|_| None),
+ (module_variable, Token::Semicolon)
+ .map(|((linking, var), _)| Some(ast::Directive::Variable(linking, var))),
+ ))
+ .parse_next(stream)
+}
+
+fn module_variable<'a, 'input>(
+ stream: &mut PtxParser<'a, 'input>,
+) -> PResult<(ast::LinkingDirective, ast::Variable<&'input str>)> {
+ let linking = linking_directives.parse_next(stream)?;
+ let var = global_space
+ .flat_map(|space| multi_variable(linking.contains(LinkingDirective::EXTERN), space))
+ // TODO: support multi var in globals
+ .map(|multi_var| multi_var.var)
+ .parse_next(stream)?;
+ Ok((linking, var))
+}
+
+fn file<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<()> {
+ (
+ Token::DotFile,
+ u32,
+ Token::String,
+ opt((Token::Comma, u32, Token::Comma, u32)),
+ )
+ .void()
+ .parse_next(stream)
+}
+
+fn section<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<()> {
+ (
+ Token::DotSection.void(),
+ dot_ident.void(),
+ Token::LBrace.void(),
+ repeat::<_, _, (), _, _>(0.., section_dwarf_line),
+ Token::RBrace.void(),
+ )
+ .void()
+ .parse_next(stream)
+}
+
+fn section_dwarf_line<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<()> {
+ alt((
+ (section_label, Token::Colon).void(),
+ (Token::DotB32, section_label, opt((Token::Add, u32))).void(),
+ (Token::DotB64, section_label, opt((Token::Add, u32))).void(),
+ (
+ any_bit_type,
+ separated::<_, _, (), _, _, _, _>(1.., u32, Token::Comma),
+ )
+ .void(),
+ ))
+ .parse_next(stream)
+}
+
+fn any_bit_type<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<()> {
+ alt((Token::DotB8, Token::DotB16, Token::DotB32, Token::DotB64))
+ .void()
+ .parse_next(stream)
+}
+
+fn section_label<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<()> {
+ alt((ident, dot_ident)).void().parse_next(stream)
+}
+
+fn function<'a, 'input>(
+ stream: &mut PtxParser<'a, 'input>,
+) -> PResult<(
+ ast::LinkingDirective,
+ ast::Function<'input, &'input str, ast::Statement<ParsedOperand<&'input str>>>,
+)> {
+ let (linking, function) = (
+ linking_directives,
+ method_declaration,
+ repeat(0.., tuning_directive),
+ function_body,
+ )
+ .map(|(linking, func_directive, tuning, body)| {
+ (
+ linking,
+ ast::Function {
+ func_directive,
+ tuning,
+ body,
+ },
+ )
+ })
+ .parse_next(stream)?;
+ stream.state.record_function(&function.func_directive);
+ Ok((linking, function))
+}
+
+fn linking_directives<'a, 'input>(
+ stream: &mut PtxParser<'a, 'input>,
+) -> PResult<ast::LinkingDirective> {
+ repeat(
+ 0..,
+ dispatch! { any;
+ Token::DotExtern => empty.value(ast::LinkingDirective::EXTERN),
+ Token::DotVisible => empty.value(ast::LinkingDirective::VISIBLE),
+ Token::DotWeak => empty.value(ast::LinkingDirective::WEAK),
+ _ => fail
+ },
+ )
+ .fold(|| ast::LinkingDirective::NONE, |x, y| x | y)
+ .parse_next(stream)
+}
+
+fn tuning_directive<'a, 'input>(
+ stream: &mut PtxParser<'a, 'input>,
+) -> PResult<ast::TuningDirective> {
+ dispatch! {any;
+ Token::DotMaxnreg => u32.map(ast::TuningDirective::MaxNReg),
+ Token::DotMaxntid => tuple1to3_u32.map(|(nx, ny, nz)| ast::TuningDirective::MaxNtid(nx, ny, nz)),
+ Token::DotReqntid => tuple1to3_u32.map(|(nx, ny, nz)| ast::TuningDirective::ReqNtid(nx, ny, nz)),
+ Token::DotMinnctapersm => u32.map(ast::TuningDirective::MinNCtaPerSm),
+ _ => fail
+ }
+ .parse_next(stream)
+}
+
+fn method_declaration<'a, 'input>(
+ stream: &mut PtxParser<'a, 'input>,
+) -> PResult<ast::MethodDeclaration<'input, &'input str>> {
+ dispatch! {any;
+ Token::DotEntry => (ident, kernel_arguments).map(|(name, input_arguments)| ast::MethodDeclaration{
+ return_arguments: Vec::new(), name: ast::MethodName::Kernel(name), input_arguments, shared_mem: None
+ }),
+ Token::DotFunc => (opt(fn_arguments), ident, fn_arguments).map(|(return_arguments, name,input_arguments)| {
+ let return_arguments = return_arguments.unwrap_or_else(|| Vec::new());
+ let name = ast::MethodName::Func(name);
+ ast::MethodDeclaration{ return_arguments, name, input_arguments, shared_mem: None }
+ }),
+ _ => fail
+ }
+ .parse_next(stream)
+}
+
+fn fn_arguments<'a, 'input>(
+ stream: &mut PtxParser<'a, 'input>,
+) -> PResult<Vec<ast::Variable<&'input str>>> {
+ delimited(
+ Token::LParen,
+ separated(0.., fn_input, Token::Comma),
+ Token::RParen,
+ )
+ .parse_next(stream)
+}
+
+fn kernel_arguments<'a, 'input>(
+ stream: &mut PtxParser<'a, 'input>,
+) -> PResult<Vec<ast::Variable<&'input str>>> {
+ delimited(
+ Token::LParen,
+ separated(0.., kernel_input, Token::Comma),
+ Token::RParen,
+ )
+ .parse_next(stream)
+}
+
+fn kernel_input<'a, 'input>(
+ stream: &mut PtxParser<'a, 'input>,
+) -> PResult<ast::Variable<&'input str>> {
+ preceded(Token::DotParam, method_parameter(StateSpace::Param)).parse_next(stream)
+}
+
+fn fn_input<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<ast::Variable<&'input str>> {
+ dispatch! { any;
+ Token::DotParam => method_parameter(StateSpace::Param),
+ Token::DotReg => method_parameter(StateSpace::Reg),
+ _ => fail
+ }
+ .parse_next(stream)
+}
+
+fn tuple1to3_u32<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<(u32, u32, u32)> {
+ struct Tuple3AccumulateU32 {
+ index: usize,
+ value: (u32, u32, u32),
+ }
+
+ impl Accumulate<u32> for Tuple3AccumulateU32 {
+ fn initial(_: Option<usize>) -> Self {
+ Self {
+ index: 0,
+ value: (1, 1, 1),
+ }
+ }
+
+ fn accumulate(&mut self, value: u32) {
+ match self.index {
+ 0 => {
+ self.value = (value, self.value.1, self.value.2);
+ self.index = 1;
+ }
+ 1 => {
+ self.value = (self.value.0, value, self.value.2);
+ self.index = 2;
+ }
+ 2 => {
+ self.value = (self.value.0, self.value.1, value);
+ self.index = 3;
+ }
+ _ => unreachable!(),
+ }
+ }
+ }
+
+ separated::<_, _, Tuple3AccumulateU32, _, _, _, _>(1..=3, u32, Token::Comma)
+ .map(|acc| acc.value)
+ .parse_next(stream)
+}
+
+fn function_body<'a, 'input>(
+ stream: &mut PtxParser<'a, 'input>,
+) -> PResult<Option<Vec<ast::Statement<ParsedOperandStr<'input>>>>> {
+ dispatch! {any;
+ Token::LBrace => terminated(repeat_without_none(statement), Token::RBrace).map(Some),
+ Token::Semicolon => empty.map(|_| None),
+ _ => fail
+ }
+ .parse_next(stream)
+}
+
+fn statement<'a, 'input>(
+ stream: &mut PtxParser<'a, 'input>,
+) -> PResult<Option<Statement<ParsedOperandStr<'input>>>> {
+ alt((
+ label.map(Some),
+ debug_directive.map(|_| None),
+ terminated(
+ method_space
+ .flat_map(|space| multi_variable(false, space))
+ .map(|var| Some(Statement::Variable(var))),
+ Token::Semicolon,
+ ),
+ predicated_instruction.map(Some),
+ pragma.map(|_| None),
+ block_statement.map(Some),
+ ))
+ .parse_next(stream)
+}
+
+fn pragma<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<()> {
+ (Token::DotPragma, Token::String, Token::Semicolon)
+ .void()
+ .parse_next(stream)
+}
+
+fn method_parameter<'a, 'input: 'a>(
+ state_space: StateSpace,
+) -> impl Parser<PtxParser<'a, 'input>, Variable<&'input str>, ContextError> {
+ move |stream: &mut PtxParser<'a, 'input>| {
+ let (align, vector, type_, name) = variable_declaration.parse_next(stream)?;
+ let array_dimensions = if state_space != StateSpace::Reg {
+ opt(array_dimensions).parse_next(stream)?
+ } else {
+ None
+ };
+ // TODO: push this check into array_dimensions(...)
+ if let Some(ref dims) = array_dimensions {
+ if dims[0] == 0 {
+ return Err(ErrMode::from_error_kind(stream, ErrorKind::Verify));
+ }
+ }
+ Ok(Variable {
+ align,
+ v_type: Type::maybe_array(vector, type_, array_dimensions),
+ state_space,
+ name,
+ array_init: Vec::new(),
+ })
+ }
+}
+
+// TODO: split to a separate type
+fn variable_declaration<'a, 'input>(
+ stream: &mut PtxParser<'a, 'input>,
+) -> PResult<(Option<u32>, Option<NonZeroU8>, ScalarType, &'input str)> {
+ (
+ opt(align.verify(|x| x.count_ones() == 1)),
+ vector_prefix,
+ scalar_type,
+ ident,
+ )
+ .parse_next(stream)
+}
+
+fn multi_variable<'a, 'input: 'a>(
+ extern_: bool,
+ state_space: StateSpace,
+) -> impl Parser<PtxParser<'a, 'input>, MultiVariable<&'input str>, ContextError> {
+ move |stream: &mut PtxParser<'a, 'input>| {
+ let ((align, vector, type_, name), count) = (
+ variable_declaration,
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameterized-variable-names
+ opt(delimited(Token::Lt, u32.verify(|x| *x != 0), Token::Gt)),
+ )
+ .parse_next(stream)?;
+ if count.is_some() {
+ return Ok(MultiVariable {
+ var: Variable {
+ align,
+ v_type: Type::maybe_vector_parsed(vector, type_),
+ state_space,
+ name,
+ array_init: Vec::new(),
+ },
+ count,
+ });
+ }
+ let mut array_dimensions = if state_space != StateSpace::Reg {
+ opt(array_dimensions).parse_next(stream)?
+ } else {
+ None
+ };
+ let initializer = match state_space {
+ StateSpace::Global | StateSpace::Const => match array_dimensions {
+ Some(ref mut dimensions) => {
+ opt(array_initializer(vector, type_, dimensions)).parse_next(stream)?
+ }
+ None => opt(value_initializer(vector, type_)).parse_next(stream)?,
+ },
+ _ => None,
+ };
+ if let Some(ref dims) = array_dimensions {
+ if !extern_ && dims[0] == 0 {
+ return Err(ErrMode::from_error_kind(stream, ErrorKind::Verify));
+ }
+ }
+ Ok(MultiVariable {
+ var: Variable {
+ align,
+ v_type: Type::maybe_array(vector, type_, array_dimensions),
+ state_space,
+ name,
+ array_init: initializer.unwrap_or(Vec::new()),
+ },
+ count,
+ })
+ }
+}
+
+fn array_initializer<'a, 'input: 'a>(
+ vector: Option<NonZeroU8>,
+ type_: ScalarType,
+ array_dimensions: &mut Vec<u32>,
+) -> impl Parser<PtxParser<'a, 'input>, Vec<u8>, ContextError> + '_ {
+ move |stream: &mut PtxParser<'a, 'input>| {
+ Token::Eq.parse_next(stream)?;
+ let mut result = Vec::new();
+ // TODO: vector constants and multi dim arrays
+ if vector.is_some() || array_dimensions[0] == 0 || array_dimensions.len() > 1 {
+ return Err(ErrMode::from_error_kind(stream, ErrorKind::Verify));
+ }
+ delimited(
+ Token::LBrace,
+ separated(
+ 0..=array_dimensions[0] as usize,
+ single_value_append(&mut result, type_),
+ Token::Comma,
+ ),
+ Token::RBrace,
+ )
+ .parse_next(stream)?;
+ // pad with zeros
+ let result_size = type_.size_of() as usize * array_dimensions[0] as usize;
+ result.extend(iter::repeat(0u8).take(result_size - result.len()));
+ Ok(result)
+ }
+}
+
+fn value_initializer<'a, 'input: 'a>(
+ vector: Option<NonZeroU8>,
+ type_: ScalarType,
+) -> impl Parser<PtxParser<'a, 'input>, Vec<u8>, ContextError> {
+ move |stream: &mut PtxParser<'a, 'input>| {
+ Token::Eq.parse_next(stream)?;
+ let mut result = Vec::new();
+ // TODO: vector constants
+ if vector.is_some() {
+ return Err(ErrMode::from_error_kind(stream, ErrorKind::Verify));
+ }
+ single_value_append(&mut result, type_).parse_next(stream)?;
+ Ok(result)
+ }
+}
+
+fn single_value_append<'a, 'input: 'a>(
+ accumulator: &mut Vec<u8>,
+ type_: ScalarType,
+) -> impl Parser<PtxParser<'a, 'input>, (), ContextError> + '_ {
+ move |stream: &mut PtxParser<'a, 'input>| {
+ let value = immediate_value.parse_next(stream)?;
+ match (type_, value) {
+ (ScalarType::U8 | ScalarType::B8, ImmediateValue::U64(x)) => {
+ accumulator.extend_from_slice(&(x as u8).to_le_bytes())
+ }
+ (ScalarType::U8 | ScalarType::B8, ImmediateValue::S64(x)) => {
+ accumulator.extend_from_slice(&(x as u8).to_le_bytes())
+ }
+ (ScalarType::U16 | ScalarType::B16, ImmediateValue::U64(x)) => {
+ accumulator.extend_from_slice(&(x as u16).to_le_bytes())
+ }
+ (ScalarType::U16 | ScalarType::B16, ImmediateValue::S64(x)) => {
+ accumulator.extend_from_slice(&(x as u16).to_le_bytes())
+ }
+ (ScalarType::U32 | ScalarType::B32, ImmediateValue::U64(x)) => {
+ accumulator.extend_from_slice(&(x as u32).to_le_bytes())
+ }
+ (ScalarType::U32 | ScalarType::B32, ImmediateValue::S64(x)) => {
+ accumulator.extend_from_slice(&(x as u32).to_le_bytes())
+ }
+ (ScalarType::U64 | ScalarType::B64, ImmediateValue::U64(x)) => {
+ accumulator.extend_from_slice(&(x as u64).to_le_bytes())
+ }
+ (ScalarType::U64 | ScalarType::B64, ImmediateValue::S64(x)) => {
+ accumulator.extend_from_slice(&(x as u64).to_le_bytes())
+ }
+ (ScalarType::S8, ImmediateValue::U64(x)) => {
+ accumulator.extend_from_slice(&(x as i8).to_le_bytes())
+ }
+ (ScalarType::S8, ImmediateValue::S64(x)) => {
+ accumulator.extend_from_slice(&(x as i8).to_le_bytes())
+ }
+ (ScalarType::S16, ImmediateValue::U64(x)) => {
+ accumulator.extend_from_slice(&(x as i16).to_le_bytes())
+ }
+ (ScalarType::S16, ImmediateValue::S64(x)) => {
+ accumulator.extend_from_slice(&(x as i16).to_le_bytes())
+ }
+ (ScalarType::S32, ImmediateValue::U64(x)) => {
+ accumulator.extend_from_slice(&(x as i32).to_le_bytes())
+ }
+ (ScalarType::S32, ImmediateValue::S64(x)) => {
+ accumulator.extend_from_slice(&(x as i32).to_le_bytes())
+ }
+ (ScalarType::S64, ImmediateValue::U64(x)) => {
+ accumulator.extend_from_slice(&(x as i64).to_le_bytes())
+ }
+ (ScalarType::S64, ImmediateValue::S64(x)) => {
+ accumulator.extend_from_slice(&(x as i64).to_le_bytes())
+ }
+ (ScalarType::F32, ImmediateValue::F32(x)) => {
+ accumulator.extend_from_slice(&x.to_le_bytes())
+ }
+ (ScalarType::F64, ImmediateValue::F64(x)) => {
+ accumulator.extend_from_slice(&x.to_le_bytes())
+ }
+ _ => return Err(ErrMode::from_error_kind(stream, ErrorKind::Verify)),
+ }
+ Ok(())
+ }
+}
+
+fn array_dimensions<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<Vec<u32>> {
+ let dimension = delimited(
+ Token::LBracket,
+ opt(u32).verify(|dim| *dim != Some(0)),
+ Token::RBracket,
+ )
+ .parse_next(stream)?;
+ let result = vec![dimension.unwrap_or(0)];
+ repeat_fold_0_or_more(
+ delimited(
+ Token::LBracket,
+ u32.verify(|dim| *dim != 0),
+ Token::RBracket,
+ ),
+ move || result,
+ |mut result: Vec<u32>, x| {
+ result.push(x);
+ result
+ },
+ stream,
+ )
+}
+
+// Copied and fixed from Winnow sources (fold_repeat0_)
+// Winnow Repeat::fold takes FnMut() -> Result to initalize accumulator,
+// this really should be FnOnce() -> Result
+fn repeat_fold_0_or_more<I, O, E, F, G, H, R>(
+ mut f: F,
+ init: H,
+ mut g: G,
+ input: &mut I,
+) -> PResult<R, E>
+where
+ I: Stream,
+ F: Parser<I, O, E>,
+ G: FnMut(R, O) -> R,
+ H: FnOnce() -> R,
+ E: ParserError<I>,
+{
+ use winnow::error::ErrMode;
+ let mut res = init();
+ loop {
+ let start = input.checkpoint();
+ match f.parse_next(input) {
+ Ok(o) => {
+ res = g(res, o);
+ }
+ Err(ErrMode::Backtrack(_)) => {
+ input.reset(&start);
+ return Ok(res);
+ }
+ Err(e) => {
+ return Err(e);
+ }
+ }
+ }
+}
+
+fn global_space<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<StateSpace> {
+ alt((
+ Token::DotGlobal.value(StateSpace::Global),
+ Token::DotConst.value(StateSpace::Const),
+ Token::DotShared.value(StateSpace::Shared),
+ ))
+ .parse_next(stream)
+}
+
+fn method_space<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<StateSpace> {
+ alt((
+ Token::DotReg.value(StateSpace::Reg),
+ Token::DotLocal.value(StateSpace::Local),
+ Token::DotParam.value(StateSpace::Param),
+ global_space,
+ ))
+ .parse_next(stream)
+}
+
+fn align<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<u32> {
+ preceded(Token::DotAlign, u32).parse_next(stream)
+}
+
+fn vector_prefix<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<Option<NonZeroU8>> {
+ opt(alt((
+ Token::DotV2.value(unsafe { NonZeroU8::new_unchecked(2) }),
+ Token::DotV4.value(unsafe { NonZeroU8::new_unchecked(4) }),
+ Token::DotV8.value(unsafe { NonZeroU8::new_unchecked(8) }),
+ )))
+ .parse_next(stream)
+}
+
+fn scalar_type<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<ScalarType> {
+ any.verify_map(|t| {
+ Some(match t {
+ Token::DotS8 => ScalarType::S8,
+ Token::DotS16 => ScalarType::S16,
+ Token::DotS16x2 => ScalarType::S16x2,
+ Token::DotS32 => ScalarType::S32,
+ Token::DotS64 => ScalarType::S64,
+ Token::DotU8 => ScalarType::U8,
+ Token::DotU16 => ScalarType::U16,
+ Token::DotU16x2 => ScalarType::U16x2,
+ Token::DotU32 => ScalarType::U32,
+ Token::DotU64 => ScalarType::U64,
+ Token::DotB8 => ScalarType::B8,
+ Token::DotB16 => ScalarType::B16,
+ Token::DotB32 => ScalarType::B32,
+ Token::DotB64 => ScalarType::B64,
+ Token::DotB128 => ScalarType::B128,
+ Token::DotPred => ScalarType::Pred,
+ Token::DotF16 => ScalarType::F16,
+ Token::DotF16x2 => ScalarType::F16x2,
+ Token::DotF32 => ScalarType::F32,
+ Token::DotF64 => ScalarType::F64,
+ Token::DotBF16 => ScalarType::BF16,
+ Token::DotBF16x2 => ScalarType::BF16x2,
+ _ => return None,
+ })
+ })
+ .parse_next(stream)
+}
+
+fn predicated_instruction<'a, 'input>(
+ stream: &mut PtxParser<'a, 'input>,
+) -> PResult<ast::Statement<ParsedOperandStr<'input>>> {
+ (opt(pred_at), parse_instruction, Token::Semicolon)
+ .map(|(p, i, _)| ast::Statement::Instruction(p, i))
+ .parse_next(stream)
+}
+
+fn pred_at<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<ast::PredAt<&'input str>> {
+ (Token::At, opt(Token::Exclamation), ident)
+ .map(|(_, not, label)| ast::PredAt {
+ not: not.is_some(),
+ label,
+ })
+ .parse_next(stream)
+}
+
+fn label<'a, 'input>(
+ stream: &mut PtxParser<'a, 'input>,
+) -> PResult<ast::Statement<ParsedOperandStr<'input>>> {
+ terminated(ident, Token::Colon)
+ .map(|l| ast::Statement::Label(l))
+ .parse_next(stream)
+}
+
+fn debug_directive<'a, 'input>(stream: &mut PtxParser<'a, 'input>) -> PResult<()> {
+ (
+ Token::DotLoc,
+ u32,
+ u32,
+ u32,
+ opt((
+ Token::Comma,
+ ident_literal("function_name"),
+ ident,
+ dispatch! { any;
+ Token::Comma => (ident_literal("inlined_at"), u32, u32, u32).void(),
+ Token::Plus => (u32, Token::Comma, ident_literal("inlined_at"), u32, u32, u32).void(),
+ _ => fail
+ },
+ )),
+ )
+ .void()
+ .parse_next(stream)
+}
+
+fn block_statement<'a, 'input>(
+ stream: &mut PtxParser<'a, 'input>,
+) -> PResult<ast::Statement<ParsedOperandStr<'input>>> {
+ delimited(Token::LBrace, repeat_without_none(statement), Token::RBrace)
+ .map(|s| ast::Statement::Block(s))
+ .parse_next(stream)
+}
+
+fn repeat_without_none<Input: Stream, Output, Error: ParserError<Input>>(
+ parser: impl Parser<Input, Option<Output>, Error>,
+) -> impl Parser<Input, Vec<Output>, Error> {
+ repeat(0.., parser).fold(Vec::new, |mut acc: Vec<_>, item| {
+ if let Some(item) = item {
+ acc.push(item);
+ }
+ acc
+ })
+}
+
+fn ident_literal<
+ 'a,
+ 'input,
+ I: Stream<Token = Token<'input>> + StreamIsPartial,
+ E: ParserError<I>,
+>(
+ s: &'input str,
+) -> impl Parser<I, (), E> + 'input {
+ move |stream: &mut I| {
+ any.verify(|t| matches!(t, Token::Ident(text) if *text == s))
+ .void()
+ .parse_next(stream)
+ }
+}
+
+fn u8_literal<'a, 'input>(x: u8) -> impl Parser<PtxParser<'a, 'input>, (), ContextError> {
+ move |stream: &mut PtxParser| u8.verify(|t| *t == x).void().parse_next(stream)
+}
+
+impl<Ident> ast::ParsedOperand<Ident> {
+ fn parse<'a, 'input>(
+ stream: &mut PtxParser<'a, 'input>,
+ ) -> PResult<ast::ParsedOperand<&'input str>> {
+ use winnow::combinator::*;
+ use winnow::token::any;
+ fn vector_index<'input>(inp: &'input str) -> Result<u8, PtxError> {
+ match inp {
+ ".x" | ".r" => Ok(0),
+ ".y" | ".g" => Ok(1),
+ ".z" | ".b" => Ok(2),
+ ".w" | ".a" => Ok(3),
+ _ => Err(PtxError::WrongVectorElement),
+ }
+ }
+ fn ident_operands<'a, 'input>(
+ stream: &mut PtxParser<'a, 'input>,
+ ) -> PResult<ast::ParsedOperand<&'input str>> {
+ let main_ident = ident.parse_next(stream)?;
+ alt((
+ preceded(Token::Plus, s32)
+ .map(move |offset| ast::ParsedOperand::RegOffset(main_ident, offset)),
+ take_error(dot_ident.map(move |suffix| {
+ let vector_index = vector_index(suffix)
+ .map_err(move |e| (ast::ParsedOperand::VecMember(main_ident, 0), e))?;
+ Ok(ast::ParsedOperand::VecMember(main_ident, vector_index))
+ })),
+ empty.value(ast::ParsedOperand::Reg(main_ident)),
+ ))
+ .parse_next(stream)
+ }
+ fn vector_operand<'a, 'input>(
+ stream: &mut PtxParser<'a, 'input>,
+ ) -> PResult<Vec<&'input str>> {
+ let (_, r1, _, r2) = (Token::LBrace, ident, Token::Comma, ident).parse_next(stream)?;
+ // TODO: parse .v8 literals
+ dispatch! {any;
+ Token::RBrace => empty.map(|_| vec![r1, r2]),
+ Token::Comma => (ident, Token::Comma, ident, Token::RBrace).map(|(r3, _, r4, _)| vec![r1, r2, r3, r4]),
+ _ => fail
+ }
+ .parse_next(stream)
+ }
+ alt((
+ ident_operands,
+ immediate_value.map(ast::ParsedOperand::Imm),
+ vector_operand.map(ast::ParsedOperand::VecPack),
+ ))
+ .parse_next(stream)
+ }
+}
+
+#[derive(Debug, thiserror::Error)]
+pub enum PtxError {
+ #[error("{source}")]
+ ParseInt {
+ #[from]
+ source: ParseIntError,
+ },
+ #[error("{source}")]
+ ParseFloat {
+ #[from]
+ source: ParseFloatError,
+ },
+ #[error("{source}")]
+ Lexer {
+ #[from]
+ source: TokenError,
+ },
+ #[error("")]
+ Parser(ContextError),
+ #[error("")]
+ Todo,
+ #[error("")]
+ SyntaxError,
+ #[error("")]
+ NonF32Ftz,
+ #[error("")]
+ Unsupported32Bit,
+ #[error("")]
+ WrongType,
+ #[error("")]
+ UnknownFunction,
+ #[error("")]
+ MalformedCall,
+ #[error("")]
+ WrongArrayType,
+ #[error("")]
+ WrongVectorElement,
+ #[error("")]
+ MultiArrayVariable,
+ #[error("")]
+ ZeroDimensionArray,
+ #[error("")]
+ ArrayInitalizer,
+ #[error("")]
+ NonExternPointer,
+ #[error("{start}:{end}")]
+ UnrecognizedStatement { start: usize, end: usize },
+ #[error("{start}:{end}")]
+ UnrecognizedDirective { start: usize, end: usize },
+}
+
+#[derive(Debug)]
+struct ReverseStream<'a, T>(pub &'a [T]);
+
+impl<'i, T> Stream for ReverseStream<'i, T>
+where
+ T: Clone + ::std::fmt::Debug,
+{
+ type Token = T;
+ type Slice = &'i [T];
+
+ type IterOffsets =
+ std::iter::Enumerate<std::iter::Cloned<std::iter::Rev<std::slice::Iter<'i, T>>>>;
+
+ type Checkpoint = &'i [T];
+
+ fn iter_offsets(&self) -> Self::IterOffsets {
+ self.0.iter().rev().cloned().enumerate()
+ }
+
+ fn eof_offset(&self) -> usize {
+ self.0.len()
+ }
+
+ fn next_token(&mut self) -> Option<Self::Token> {
+ let (token, next) = self.0.split_last()?;
+ self.0 = next;
+ Some(token.clone())
+ }
+
+ fn offset_for<P>(&self, predicate: P) -> Option<usize>
+ where
+ P: Fn(Self::Token) -> bool,
+ {
+ self.0.iter().rev().position(|b| predicate(b.clone()))
+ }
+
+ fn offset_at(&self, tokens: usize) -> Result<usize, winnow::error::Needed> {
+ if let Some(needed) = tokens
+ .checked_sub(self.0.len())
+ .and_then(std::num::NonZeroUsize::new)
+ {
+ Err(winnow::error::Needed::Size(needed))
+ } else {
+ Ok(tokens)
+ }
+ }
+
+ fn next_slice(&mut self, offset: usize) -> Self::Slice {
+ let offset = self.0.len() - offset;
+ let (next, slice) = self.0.split_at(offset);
+ self.0 = next;
+ slice
+ }
+
+ fn checkpoint(&self) -> Self::Checkpoint {
+ self.0
+ }
+
+ fn reset(&mut self, checkpoint: &Self::Checkpoint) {
+ self.0 = checkpoint;
+ }
+
+ fn raw(&self) -> &dyn std::fmt::Debug {
+ self
+ }
+}
+
+impl<'a, T> Offset<&'a [T]> for ReverseStream<'a, T> {
+ fn offset_from(&self, start: &&'a [T]) -> usize {
+ let fst = start.as_ptr();
+ let snd = self.0.as_ptr();
+
+ debug_assert!(
+ snd <= fst,
+ "`Offset::offset_from({snd:?}, {fst:?})` only accepts slices of `self`"
+ );
+ (fst as usize - snd as usize) / std::mem::size_of::<T>()
+ }
+}
+
+impl<'a, T> StreamIsPartial for ReverseStream<'a, T> {
+ type PartialState = ();
+
+ fn complete(&mut self) -> Self::PartialState {}
+
+ fn restore_partial(&mut self, _state: Self::PartialState) {}
+
+ fn is_partial_supported() -> bool {
+ false
+ }
+}
+
+impl<'input, I: Stream<Token = Self> + StreamIsPartial, E: ParserError<I>> Parser<I, Self, E>
+ for Token<'input>
+{
+ fn parse_next(&mut self, input: &mut I) -> PResult<Self, E> {
+ any.verify(|t| t == self).parse_next(input)
+ }
+}
+
+fn bra<'a, 'input>(
+ stream: &mut PtxParser<'a, 'input>,
+) -> PResult<ast::Instruction<ParsedOperandStr<'input>>> {
+ preceded(
+ opt(Token::DotUni),
+ any.verify_map(|t| match t {
+ Token::Ident(ident) => Some(ast::Instruction::Bra {
+ arguments: BraArgs { src: ident },
+ }),
+ _ => None,
+ }),
+ )
+ .parse_next(stream)
+}
+
+fn call<'a, 'input>(
+ stream: &mut PtxParser<'a, 'input>,
+) -> PResult<ast::Instruction<ParsedOperandStr<'input>>> {
+ let (uni, return_arguments, name, input_arguments) = (
+ opt(Token::DotUni),
+ opt((
+ Token::LParen,
+ separated(1.., ident, Token::Comma).map(|x: Vec<_>| x),
+ Token::RParen,
+ Token::Comma,
+ )
+ .map(|(_, arguments, _, _)| arguments)),
+ ident,
+ opt((
+ Token::Comma.void(),
+ Token::LParen.void(),
+ separated(1.., ParsedOperand::<&'input str>::parse, Token::Comma).map(|x: Vec<_>| x),
+ Token::RParen.void(),
+ )
+ .map(|(_, _, arguments, _)| arguments)),
+ )
+ .parse_next(stream)?;
+ let uniform = uni.is_some();
+ let recorded_fn = match stream.state.function_declarations.get(name) {
+ Some(decl) => decl,
+ None => {
+ stream.state.errors.push(PtxError::UnknownFunction);
+ return Ok(empty_call(uniform, name));
+ }
+ };
+ let return_arguments = return_arguments.unwrap_or(Vec::new());
+ let input_arguments = input_arguments.unwrap_or(Vec::new());
+ if recorded_fn.0.len() != return_arguments.len() || recorded_fn.1.len() != input_arguments.len()
+ {
+ stream.state.errors.push(PtxError::MalformedCall);
+ return Ok(empty_call(uniform, name));
+ }
+ let data = CallDetails {
+ uniform,
+ return_arguments: recorded_fn.0.clone(),
+ input_arguments: recorded_fn.1.clone(),
+ };
+ let arguments = CallArgs {
+ return_arguments,
+ func: name,
+ input_arguments,
+ };
+ Ok(ast::Instruction::Call { data, arguments })
+}
+
+fn empty_call<'input>(
+ uniform: bool,
+ name: &'input str,
+) -> ast::Instruction<ParsedOperandStr<'input>> {
+ ast::Instruction::Call {
+ data: CallDetails {
+ uniform,
+ return_arguments: Vec::new(),
+ input_arguments: Vec::new(),
+ },
+ arguments: CallArgs {
+ return_arguments: Vec::new(),
+ func: name,
+ input_arguments: Vec::new(),
+ },
+ }
+}
+
+type ParsedOperandStr<'input> = ast::ParsedOperand<&'input str>;
+
+#[derive(Clone, PartialEq, Default, Debug, Display)]
+#[display("({}:{})", _0.start, _0.end)]
+pub struct TokenError(std::ops::Range<usize>);
+
+impl std::error::Error for TokenError {}
+
+// This macro is responsible for generating parser code for instruction parser.
+// Instruction parsing is by far the most complex part of parsing PTX code:
+// * There are tens of instruction kinds, each with slightly different parsing rules
+// * After parsing, each instruction needs to do some early validation and generate a specific,
+// strongly-typed object. We want strong-typing because we have a single PTX parser frontend, but
+// there can be multiple different code emitter backends
+// * Most importantly, instruction modifiers can come in aby order, so e.g. both
+// `ld.relaxed.global.u32 a, b` and `ld.global.relaxed.u32 a, b` are equally valid. This makes
+// classic parsing generators fail: if we tried to generate parsing rules that cover every possible
+// ordering we'd need thousands of rules. This is not a purely theoretical problem. NVCC and Clang
+// will always emit modifiers in the correct order, but people who write inline assembly usually
+// get it wrong (even first party developers)
+//
+// This macro exists purely to generate repetitive code for parsing each instruction. It is
+// _not_ self-contained and is _not_ general-purpose: it relies on certain types and functions from
+// the enclosing module
+//
+// derive_parser!(...) input is split into three parts:
+// * Token type definition
+// * Partial enums
+// * Parsing definitions
+//
+// Token type definition:
+// This is the enum type that will be usesby the instruction parser. For every instruction and
+// modifier, derive_parser!(...) will add appropriate variant into this type. So e.g. if there is a
+// rule for for `bar.sync` then those two variants wil be appended to the Token enum:
+// #[token("bar")] Bar,
+// #[token(".sync")] DotSync,
+//
+// Partial enums:
+// With proper annotations, derive_parser!(...) parsing definitions are able to interpret
+// instruction modifiers as variants of a single enum type. So e.g. for definitions `ld.u32` and
+// `ld.u64` the macro can generate `enum ScalarType { U32, U64 }`. The problem is that for some
+// (but not all) of those generated enum types we want to add some attributes and additional
+// variants. In order to do so, you need to define this enum and derive_parser!(...) will append to
+// the type instead of creating a new type. This is sort of replacement for partial classes known
+// from C#
+//
+// Parsing definitions:
+// Parsing definitions consist of a list of patterns and rules:
+// * Pattern consists of:
+// * Opcode: `ld`
+// * Modifiers, always start with a dot: `.global`, `.relaxed`. Optionals are enclosed in braces
+// * Arguments: `a`, `b`. Optionals are enclosed in braces
+// * Code block: => { <code expression> }. Code blocks implictly take all modifiers ansd arguments
+// as parameters. All modifiers and arguments are passed to the code block:
+// * If it is an alternative (as defined in rules list later):
+// * If it is mandatory then its type is Foo (as defined by the relevant rule)
+// * If it is optional then its type is Option<Foo>
+// * Otherwise:
+// * If it is mandatory then it is skipped
+// * If it is optional then its type is `bool`
+// * List of rules. They are associated with the preceding patterns (until different opcode or
+// different rules). Rules are used to resolve modifiers. There are two types of rules:
+// * Normal rule: `.foobar: FoobarEnum => { .a, .b, .c }`. This means that instead of `.foobar` we
+// expecte one of `.a`, `.b`, `.c` and will emit value FoobarEnum::DotA, FoobarEnum::DotB,
+// FoobarEnum::DotC appropriately
+// * Type-only rule: `FoobarEnum => { .a, .b, .c }` this means that all the occurences of `.a` will
+// emit FoobarEnum::DotA to the code block. This helps to avoid copy-paste errors
+// Additionally, you can opt out from the usual parsing rule generation with a special `<=` pattern.
+// See `call` instruction to see it in action
+derive_parser!(
+ #[derive(Logos, PartialEq, Eq, Debug, Clone, Copy)]
+ #[logos(skip r"(?:\s+)|(?://[^\n\r]*[\n\r]*)|(?:/\*[^*]*\*+(?:[^/*][^*]*\*+)*/)")]
+ #[logos(error = TokenError)]
+ enum Token<'input> {
+ #[token(",")]
+ Comma,
+ #[token(".")]
+ Dot,
+ #[token(":")]
+ Colon,
+ #[token(";")]
+ Semicolon,
+ #[token("@")]
+ At,
+ #[regex(r"[a-zA-Z][a-zA-Z0-9_$]*|[_$%][a-zA-Z0-9_$]+", |lex| lex.slice(), priority = 0)]
+ Ident(&'input str),
+ #[regex(r"\.[a-zA-Z][a-zA-Z0-9_$]*|\.[_$%][a-zA-Z0-9_$]+", |lex| lex.slice(), priority = 0)]
+ DotIdent(&'input str),
+ #[regex(r#""[^"]*""#)]
+ String,
+ #[token("|")]
+ Pipe,
+ #[token("!")]
+ Exclamation,
+ #[token("(")]
+ LParen,
+ #[token(")")]
+ RParen,
+ #[token("[")]
+ LBracket,
+ #[token("]")]
+ RBracket,
+ #[token("{")]
+ LBrace,
+ #[token("}")]
+ RBrace,
+ #[token("<")]
+ Lt,
+ #[token(">")]
+ Gt,
+ #[regex(r"0[fF][0-9a-zA-Z]{8}", |lex| lex.slice())]
+ F32(&'input str),
+ #[regex(r"0[dD][0-9a-zA-Z]{16}", |lex| lex.slice())]
+ F64(&'input str),
+ #[regex(r"0[xX][0-9a-zA-Z]+U?", |lex| lex.slice())]
+ Hex(&'input str),
+ #[regex(r"[0-9]+U?", |lex| lex.slice())]
+ Decimal(&'input str),
+ #[token("-")]
+ Minus,
+ #[token("+")]
+ Plus,
+ #[token("=")]
+ Eq,
+ #[token(".version")]
+ DotVersion,
+ #[token(".loc")]
+ DotLoc,
+ #[token(".reg")]
+ DotReg,
+ #[token(".align")]
+ DotAlign,
+ #[token(".pragma")]
+ DotPragma,
+ #[token(".maxnreg")]
+ DotMaxnreg,
+ #[token(".maxntid")]
+ DotMaxntid,
+ #[token(".reqntid")]
+ DotReqntid,
+ #[token(".minnctapersm")]
+ DotMinnctapersm,
+ #[token(".entry")]
+ DotEntry,
+ #[token(".func")]
+ DotFunc,
+ #[token(".extern")]
+ DotExtern,
+ #[token(".visible")]
+ DotVisible,
+ #[token(".target")]
+ DotTarget,
+ #[token(".address_size")]
+ DotAddressSize,
+ #[token(".action")]
+ DotSection,
+ #[token(".file")]
+ DotFile
+ }
+
+ #[derive(Copy, Clone, PartialEq, Eq, Hash)]
+ pub enum StateSpace {
+ Reg,
+ Generic,
+ Sreg,
+ }
+
+ #[derive(Copy, Clone, PartialEq, Eq, Hash)]
+ pub enum MemScope { }
+
+ #[derive(Copy, Clone, PartialEq, Eq, Hash)]
+ pub enum ScalarType { }
+
+ #[derive(Copy, Clone, PartialEq, Eq, Hash)]
+ pub enum SetpBoolPostOp { }
+
+ #[derive(Copy, Clone, PartialEq, Eq, Hash)]
+ pub enum AtomSemantics { }
+
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov
+ mov{.vec}.type d, a => {
+ Instruction::Mov {
+ data: ast::MovDetails::new(vec, type_),
+ arguments: MovArgs { dst: d, src: a },
+ }
+ }
+ .vec: VectorPrefix = { .v2, .v4 };
+ .type: ScalarType = { .pred,
+ .b16, .b32, .b64,
+ .u16, .u32, .u64,
+ .s16, .s32, .s64,
+ .f32, .f64 };
+
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-st
+ st{.weak}{.ss}{.cop}{.level::eviction_priority}{.level::cache_hint}{.vec}.type [a], b{, cache_policy} => {
+ if level_eviction_priority.is_some() || level_cache_hint || cache_policy.is_some() {
+ state.errors.push(PtxError::Todo);
+ }
+ Instruction::St {
+ data: StData {
+ qualifier: weak.unwrap_or(RawLdStQualifier::Weak).into(),
+ state_space: ss.unwrap_or(StateSpace::Generic),
+ caching: cop.unwrap_or(RawStCacheOperator::Wb).into(),
+ typ: ast::Type::maybe_vector(vec, type_)
+ },
+ arguments: StArgs { src1:a, src2:b }
+ }
+ }
+ st.volatile{.ss}{.vec}.type [a], b => {
+ Instruction::St {
+ data: StData {
+ qualifier: volatile.into(),
+ state_space: ss.unwrap_or(StateSpace::Generic),
+ caching: ast::StCacheOperator::Writeback,
+ typ: ast::Type::maybe_vector(vec, type_)
+ },
+ arguments: StArgs { src1:a, src2:b }
+ }
+ }
+ st.relaxed.scope{.ss}{.level::eviction_priority}{.level::cache_hint}{.vec}.type [a], b{, cache_policy} => {
+ if level_eviction_priority.is_some() || level_cache_hint || cache_policy.is_some() {
+ state.errors.push(PtxError::Todo);
+ }
+ Instruction::St {
+ data: StData {
+ qualifier: ast::LdStQualifier::Relaxed(scope),
+ state_space: ss.unwrap_or(StateSpace::Generic),
+ caching: ast::StCacheOperator::Writeback,
+ typ: ast::Type::maybe_vector(vec, type_)
+ },
+ arguments: StArgs { src1:a, src2:b }
+ }
+ }
+ st.release.scope{.ss}{.level::eviction_priority}{.level::cache_hint}{.vec}.type [a], b{, cache_policy} => {
+ if level_eviction_priority.is_some() || level_cache_hint || cache_policy.is_some() {
+ state.errors.push(PtxError::Todo);
+ }
+ Instruction::St {
+ data: StData {
+ qualifier: ast::LdStQualifier::Release(scope),
+ state_space: ss.unwrap_or(StateSpace::Generic),
+ caching: ast::StCacheOperator::Writeback,
+ typ: ast::Type::maybe_vector(vec, type_)
+ },
+ arguments: StArgs { src1:a, src2:b }
+ }
+ }
+ st.mmio.relaxed.sys{.global}.type [a], b => {
+ state.errors.push(PtxError::Todo);
+ Instruction::St {
+ data: ast::StData {
+ qualifier: ast::LdStQualifier::Relaxed(MemScope::Sys),
+ state_space: global.unwrap_or(StateSpace::Generic),
+ caching: ast::StCacheOperator::Writeback,
+ typ: type_.into()
+ },
+ arguments: ast::StArgs { src1:a, src2:b }
+ }
+ }
+ .ss: StateSpace = { .global, .local, .param{::func}, .shared{::cta, ::cluster} };
+ .level::eviction_priority: EvictionPriority =
+ { .L1::evict_normal, .L1::evict_unchanged, .L1::evict_first, .L1::evict_last, .L1::no_allocate };
+ .level::cache_hint = { .L2::cache_hint };
+ .cop: RawStCacheOperator = { .wb, .cg, .cs, .wt };
+ .scope: MemScope = { .cta, .cluster, .gpu, .sys };
+ .vec: VectorPrefix = { .v2, .v4 };
+ .type: ScalarType = { .b8, .b16, .b32, .b64, .b128,
+ .u8, .u16, .u32, .u64,
+ .s8, .s16, .s32, .s64,
+ .f32, .f64 };
+ RawLdStQualifier = { .weak, .volatile };
+ StateSpace = { .global };
+
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-ld
+ ld{.weak}{.ss}{.cop}{.level::eviction_priority}{.level::cache_hint}{.level::prefetch_size}{.vec}.type d, [a]{.unified}{, cache_policy} => {
+ let (a, unified) = a;
+ if level_eviction_priority.is_some() || level_cache_hint || level_prefetch_size.is_some() || unified || cache_policy.is_some() {
+ state.errors.push(PtxError::Todo);
+ }
+ Instruction::Ld {
+ data: LdDetails {
+ qualifier: weak.unwrap_or(RawLdStQualifier::Weak).into(),
+ state_space: ss.unwrap_or(StateSpace::Generic),
+ caching: cop.unwrap_or(RawLdCacheOperator::Ca).into(),
+ typ: ast::Type::maybe_vector(vec, type_),
+ non_coherent: false
+ },
+ arguments: LdArgs { dst:d, src:a }
+ }
+ }
+ ld.volatile{.ss}{.level::prefetch_size}{.vec}.type d, [a] => {
+ if level_prefetch_size.is_some() {
+ state.errors.push(PtxError::Todo);
+ }
+ Instruction::Ld {
+ data: LdDetails {
+ qualifier: volatile.into(),
+ state_space: ss.unwrap_or(StateSpace::Generic),
+ caching: ast::LdCacheOperator::Cached,
+ typ: ast::Type::maybe_vector(vec, type_),
+ non_coherent: false
+ },
+ arguments: LdArgs { dst:d, src:a }
+ }
+ }
+ ld.relaxed.scope{.ss}{.level::eviction_priority}{.level::cache_hint}{.level::prefetch_size}{.vec}.type d, [a]{, cache_policy} => {
+ if level_eviction_priority.is_some() || level_cache_hint || level_prefetch_size.is_some() || cache_policy.is_some() {
+ state.errors.push(PtxError::Todo);
+ }
+ Instruction::Ld {
+ data: LdDetails {
+ qualifier: ast::LdStQualifier::Relaxed(scope),
+ state_space: ss.unwrap_or(StateSpace::Generic),
+ caching: ast::LdCacheOperator::Cached,
+ typ: ast::Type::maybe_vector(vec, type_),
+ non_coherent: false
+ },
+ arguments: LdArgs { dst:d, src:a }
+ }
+ }
+ ld.acquire.scope{.ss}{.level::eviction_priority}{.level::cache_hint}{.level::prefetch_size}{.vec}.type d, [a]{, cache_policy} => {
+ if level_eviction_priority.is_some() || level_cache_hint || level_prefetch_size.is_some() || cache_policy.is_some() {
+ state.errors.push(PtxError::Todo);
+ }
+ Instruction::Ld {
+ data: LdDetails {
+ qualifier: ast::LdStQualifier::Acquire(scope),
+ state_space: ss.unwrap_or(StateSpace::Generic),
+ caching: ast::LdCacheOperator::Cached,
+ typ: ast::Type::maybe_vector(vec, type_),
+ non_coherent: false
+ },
+ arguments: LdArgs { dst:d, src:a }
+ }
+ }
+ ld.mmio.relaxed.sys{.global}.type d, [a] => {
+ state.errors.push(PtxError::Todo);
+ Instruction::Ld {
+ data: LdDetails {
+ qualifier: ast::LdStQualifier::Relaxed(MemScope::Sys),
+ state_space: global.unwrap_or(StateSpace::Generic),
+ caching: ast::LdCacheOperator::Cached,
+ typ: type_.into(),
+ non_coherent: false
+ },
+ arguments: LdArgs { dst:d, src:a }
+ }
+ }
+ .ss: StateSpace = { .const, .global, .local, .param{::entry, ::func}, .shared{::cta, ::cluster} };
+ .cop: RawLdCacheOperator = { .ca, .cg, .cs, .lu, .cv };
+ .level::eviction_priority: EvictionPriority =
+ { .L1::evict_normal, .L1::evict_unchanged, .L1::evict_first, .L1::evict_last, .L1::no_allocate };
+ .level::cache_hint = { .L2::cache_hint };
+ .level::prefetch_size: PrefetchSize = { .L2::64B, .L2::128B, .L2::256B };
+ .scope: MemScope = { .cta, .cluster, .gpu, .sys };
+ .vec: VectorPrefix = { .v2, .v4 };
+ .type: ScalarType = { .b8, .b16, .b32, .b64, .b128,
+ .u8, .u16, .u32, .u64,
+ .s8, .s16, .s32, .s64,
+ .f32, .f64 };
+ RawLdStQualifier = { .weak, .volatile };
+ StateSpace = { .global };
+
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld-global-nc
+ ld.global{.cop}.nc{.level::eviction_priority}{.level::cache_hint}{.level::prefetch_size}{.vec}.type d, [a]{, cache_policy} => {
+ if cop.is_some() && level_eviction_priority.is_some() {
+ state.errors.push(PtxError::SyntaxError);
+ }
+ if level_eviction_priority.is_some() || level_cache_hint || level_prefetch_size.is_some() || cache_policy.is_some() {
+ state.errors.push(PtxError::Todo);
+ }
+ Instruction::Ld {
+ data: LdDetails {
+ qualifier: ast::LdStQualifier::Weak,
+ state_space: global,
+ caching: cop.unwrap_or(RawLdCacheOperator::Ca).into(),
+ typ: Type::maybe_vector(vec, type_),
+ non_coherent: true
+ },
+ arguments: LdArgs { dst:d, src:a }
+ }
+ }
+ .cop: RawLdCacheOperator = { .ca, .cg, .cs };
+ .level::eviction_priority: EvictionPriority =
+ { .L1::evict_normal, .L1::evict_unchanged,
+ .L1::evict_first, .L1::evict_last, .L1::no_allocate};
+ .level::cache_hint = { .L2::cache_hint };
+ .level::prefetch_size: PrefetchSize = { .L2::64B, .L2::128B, .L2::256B };
+ .vec: VectorPrefix = { .v2, .v4 };
+ .type: ScalarType = { .b8, .b16, .b32, .b64, .b128,
+ .u8, .u16, .u32, .u64,
+ .s8, .s16, .s32, .s64,
+ .f32, .f64 };
+ StateSpace = { .global };
+
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/#integer-arithmetic-instructions-add
+ add.type d, a, b => {
+ Instruction::Add {
+ data: ast::ArithDetails::Integer(
+ ast::ArithInteger {
+ type_,
+ saturate: false
+ }
+ ),
+ arguments: AddArgs {
+ dst: d, src1: a, src2: b
+ }
+ }
+ }
+ add{.sat}.s32 d, a, b => {
+ Instruction::Add {
+ data: ast::ArithDetails::Integer(
+ ast::ArithInteger {
+ type_: s32,
+ saturate: sat
+ }
+ ),
+ arguments: AddArgs {
+ dst: d, src1: a, src2: b
+ }
+ }
+ }
+ .type: ScalarType = { .u16, .u32, .u64,
+ .s16, .s64,
+ .u16x2, .s16x2 };
+ ScalarType = { .s32 };
+
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/#floating-point-instructions-add
+ add{.rnd}{.ftz}{.sat}.f32 d, a, b => {
+ Instruction::Add {
+ data: ast::ArithDetails::Float(
+ ast::ArithFloat {
+ type_: f32,
+ rounding: rnd.map(Into::into),
+ flush_to_zero: Some(ftz),
+ saturate: sat
+ }
+ ),
+ arguments: AddArgs {
+ dst: d, src1: a, src2: b
+ }
+ }
+ }
+ add{.rnd}.f64 d, a, b => {
+ Instruction::Add {
+ data: ast::ArithDetails::Float(
+ ast::ArithFloat {
+ type_: f64,
+ rounding: rnd.map(Into::into),
+ flush_to_zero: None,
+ saturate: false
+ }
+ ),
+ arguments: AddArgs {
+ dst: d, src1: a, src2: b
+ }
+ }
+ }
+ .rnd: RawRoundingMode = { .rn, .rz, .rm, .rp };
+ ScalarType = { .f32, .f64 };
+
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/#half-precision-floating-point-instructions-add
+ add{.rnd}{.ftz}{.sat}.f16 d, a, b => {
+ Instruction::Add {
+ data: ast::ArithDetails::Float(
+ ast::ArithFloat {
+ type_: f16,
+ rounding: rnd.map(Into::into),
+ flush_to_zero: Some(ftz),
+ saturate: sat
+ }
+ ),
+ arguments: AddArgs {
+ dst: d, src1: a, src2: b
+ }
+ }
+ }
+ add{.rnd}{.ftz}{.sat}.f16x2 d, a, b => {
+ Instruction::Add {
+ data: ast::ArithDetails::Float(
+ ast::ArithFloat {
+ type_: f16x2,
+ rounding: rnd.map(Into::into),
+ flush_to_zero: Some(ftz),
+ saturate: sat
+ }
+ ),
+ arguments: AddArgs {
+ dst: d, src1: a, src2: b
+ }
+ }
+ }
+ add{.rnd}.bf16 d, a, b => {
+ Instruction::Add {
+ data: ast::ArithDetails::Float(
+ ast::ArithFloat {
+ type_: bf16,
+ rounding: rnd.map(Into::into),
+ flush_to_zero: None,
+ saturate: false
+ }
+ ),
+ arguments: AddArgs {
+ dst: d, src1: a, src2: b
+ }
+ }
+ }
+ add{.rnd}.bf16x2 d, a, b => {
+ Instruction::Add {
+ data: ast::ArithDetails::Float(
+ ast::ArithFloat {
+ type_: bf16x2,
+ rounding: rnd.map(Into::into),
+ flush_to_zero: None,
+ saturate: false
+ }
+ ),
+ arguments: AddArgs {
+ dst: d, src1: a, src2: b
+ }
+ }
+ }
+ .rnd: RawRoundingMode = { .rn };
+ ScalarType = { .f16, .f16x2, .bf16, .bf16x2 };
+
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-mul
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-mul
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-mul
+ mul.mode.type d, a, b => {
+ ast::Instruction::Mul {
+ data: ast::MulDetails::Integer {
+ type_,
+ control: mode.into()
+ },
+ arguments: MulArgs { dst: d, src1: a, src2: b }
+ }
+ }
+ .mode: RawMulIntControl = { .hi, .lo };
+ .type: ScalarType = { .u16, .u32, .u64,
+ .s16, .s32, .s64 };
+ // "The .wide suffix is supported only for 16- and 32-bit integer types"
+ mul.wide.type d, a, b => {
+ ast::Instruction::Mul {
+ data: ast::MulDetails::Integer {
+ type_,
+ control: wide.into()
+ },
+ arguments: MulArgs { dst: d, src1: a, src2: b }
+ }
+ }
+ .type: ScalarType = { .u16, .u32,
+ .s16, .s32 };
+ RawMulIntControl = { .wide };
+
+ mul{.rnd}{.ftz}{.sat}.f32 d, a, b => {
+ ast::Instruction::Mul {
+ data: ast::MulDetails::Float (
+ ast::ArithFloat {
+ type_: f32,
+ rounding: rnd.map(Into::into),
+ flush_to_zero: Some(ftz),
+ saturate: sat,
+ }
+ ),
+ arguments: MulArgs { dst: d, src1: a, src2: b }
+ }
+ }
+ mul{.rnd}.f64 d, a, b => {
+ ast::Instruction::Mul {
+ data: ast::MulDetails::Float (
+ ast::ArithFloat {
+ type_: f64,
+ rounding: rnd.map(Into::into),
+ flush_to_zero: None,
+ saturate: false,
+ }
+ ),
+ arguments: MulArgs { dst: d, src1: a, src2: b }
+ }
+ }
+ .rnd: RawRoundingMode = { .rn, .rz, .rm, .rp };
+ ScalarType = { .f32, .f64 };
+
+ mul{.rnd}{.ftz}{.sat}.f16 d, a, b => {
+ ast::Instruction::Mul {
+ data: ast::MulDetails::Float (
+ ast::ArithFloat {
+ type_: f16,
+ rounding: rnd.map(Into::into),
+ flush_to_zero: Some(ftz),
+ saturate: sat,
+ }
+ ),
+ arguments: MulArgs { dst: d, src1: a, src2: b }
+ }
+ }
+ mul{.rnd}{.ftz}{.sat}.f16x2 d, a, b => {
+ ast::Instruction::Mul {
+ data: ast::MulDetails::Float (
+ ast::ArithFloat {
+ type_: f16x2,
+ rounding: rnd.map(Into::into),
+ flush_to_zero: Some(ftz),
+ saturate: sat,
+ }
+ ),
+ arguments: MulArgs { dst: d, src1: a, src2: b }
+ }
+ }
+ mul{.rnd}.bf16 d, a, b => {
+ ast::Instruction::Mul {
+ data: ast::MulDetails::Float (
+ ast::ArithFloat {
+ type_: bf16,
+ rounding: rnd.map(Into::into),
+ flush_to_zero: None,
+ saturate: false,
+ }
+ ),
+ arguments: MulArgs { dst: d, src1: a, src2: b }
+ }
+ }
+ mul{.rnd}.bf16x2 d, a, b => {
+ ast::Instruction::Mul {
+ data: ast::MulDetails::Float (
+ ast::ArithFloat {
+ type_: bf16x2,
+ rounding: rnd.map(Into::into),
+ flush_to_zero: None,
+ saturate: false,
+ }
+ ),
+ arguments: MulArgs { dst: d, src1: a, src2: b }
+ }
+ }
+ .rnd: RawRoundingMode = { .rn };
+ ScalarType = { .f16, .f16x2, .bf16, .bf16x2 };
+
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#comparison-and-selection-instructions-setp
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-comparison-instructions-setp
+ setp.CmpOp{.ftz}.type p[|q], a, b => {
+ let data = ast::SetpData::try_parse(state, cmpop, ftz, type_);
+ ast::Instruction::Setp {
+ data,
+ arguments: SetpArgs { dst1: p, dst2: q, src1: a, src2: b }
+ }
+ }
+ setp.CmpOp.BoolOp{.ftz}.type p[|q], a, b, {!}c => {
+ let (negate_src3, c) = c;
+ let base = ast::SetpData::try_parse(state, cmpop, ftz, type_);
+ let data = ast::SetpBoolData {
+ base,
+ bool_op: boolop,
+ negate_src3
+ };
+ ast::Instruction::SetpBool {
+ data,
+ arguments: SetpBoolArgs { dst1: p, dst2: q, src1: a, src2: b, src3: c }
+ }
+ }
+ .CmpOp: RawSetpCompareOp = { .eq, .ne, .lt, .le, .gt, .ge,
+ .lo, .ls, .hi, .hs, // signed
+ .equ, .neu, .ltu, .leu, .gtu, .geu, .num, .nan }; // float-only
+ .BoolOp: SetpBoolPostOp = { .and, .or, .xor };
+ .type: ScalarType = { .b16, .b32, .b64,
+ .u16, .u32, .u64,
+ .s16, .s32, .s64,
+ .f32, .f64,
+ .f16, .f16x2, .bf16, .bf16x2 };
+
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-not
+ not.type d, a => {
+ ast::Instruction::Not {
+ data: type_,
+ arguments: NotArgs { dst: d, src: a }
+ }
+ }
+ .type: ScalarType = { .pred, .b16, .b32, .b64 };
+
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-or
+ or.type d, a, b => {
+ ast::Instruction::Or {
+ data: type_,
+ arguments: OrArgs { dst: d, src1: a, src2: b }
+ }
+ }
+ .type: ScalarType = { .pred, .b16, .b32, .b64 };
+
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-and
+ and.type d, a, b => {
+ ast::Instruction::And {
+ data: type_,
+ arguments: AndArgs { dst: d, src1: a, src2: b }
+ }
+ }
+ .type: ScalarType = { .pred, .b16, .b32, .b64 };
+
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-bra
+ bra <= { bra(stream) }
+
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-call
+ call <= { call(stream) }
+
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt
+ cvt{.ifrnd}{.ftz}{.sat}.dtype.atype d, a => {
+ let data = ast::CvtDetails::new(&mut state.errors, ifrnd, ftz, sat, dtype, atype);
+ let arguments = ast::CvtArgs { dst: d, src: a };
+ ast::Instruction::Cvt {
+ data, arguments
+ }
+ }
+ // cvt.frnd2{.relu}{.satfinite}.f16.f32 d, a;
+ // cvt.frnd2{.relu}{.satfinite}.f16x2.f32 d, a, b;
+ // cvt.frnd2{.relu}{.satfinite}.bf16.f32 d, a;
+ // cvt.frnd2{.relu}{.satfinite}.bf16x2.f32 d, a, b;
+ // cvt.rna{.satfinite}.tf32.f32 d, a;
+ // cvt.frnd2{.relu}.tf32.f32 d, a;
+ // cvt.rn.satfinite{.relu}.f8x2type.f32 d, a, b;
+ // cvt.rn.satfinite{.relu}.f8x2type.f16x2 d, a;
+ // cvt.rn.{.relu}.f16x2.f8x2type d, a;
+
+ .ifrnd: RawRoundingMode = { .rn, .rz, .rm, .rp, .rni, .rzi, .rmi, .rpi };
+ .frnd2: RawRoundingMode = { .rn, .rz };
+ .dtype: ScalarType = { .u8, .u16, .u32, .u64,
+ .s8, .s16, .s32, .s64,
+ .bf16, .f16, .f32, .f64 };
+ .atype: ScalarType = { .u8, .u16, .u32, .u64,
+ .s8, .s16, .s32, .s64,
+ .bf16, .f16, .f32, .f64 };
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-shl
+ shl.type d, a, b => {
+ ast::Instruction::Shl { data: type_, arguments: ShlArgs { dst: d, src1: a, src2: b } }
+ }
+ .type: ScalarType = { .b16, .b32, .b64 };
+
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-shr
+ shr.type d, a, b => {
+ let kind = if type_.kind() == ast::ScalarKind::Signed { RightShiftKind::Arithmetic} else { RightShiftKind::Logical };
+ ast::Instruction::Shr {
+ data: ast::ShrData { type_, kind },
+ arguments: ShrArgs { dst: d, src1: a, src2: b }
+ }
+ }
+ .type: ScalarType = { .b16, .b32, .b64,
+ .u16, .u32, .u64,
+ .s16, .s32, .s64 };
+
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvta
+ cvta.space.size p, a => {
+ if size != ScalarType::U64 {
+ state.errors.push(PtxError::Unsupported32Bit);
+ }
+ let data = ast::CvtaDetails {
+ state_space: space,
+ direction: ast::CvtaDirection::ExplicitToGeneric
+ };
+ let arguments = ast::CvtaArgs {
+ dst: p, src: a
+ };
+ ast::Instruction::Cvta {
+ data, arguments
+ }
+ }
+ cvta.to.space.size p, a => {
+ if size != ScalarType::U64 {
+ state.errors.push(PtxError::Unsupported32Bit);
+ }
+ let data = ast::CvtaDetails {
+ state_space: space,
+ direction: ast::CvtaDirection::GenericToExplicit
+ };
+ let arguments = ast::CvtaArgs {
+ dst: p, src: a
+ };
+ ast::Instruction::Cvta {
+ data, arguments
+ }
+ }
+ .space: StateSpace = { .const, .global, .local, .shared{::cta, ::cluster}, .param{::entry} };
+ .size: ScalarType = { .u32, .u64 };
+
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-abs
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-abs
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-abs
+ abs.type d, a => {
+ ast::Instruction::Abs {
+ data: ast::TypeFtz {
+ flush_to_zero: None,
+ type_
+ },
+ arguments: ast::AbsArgs {
+ dst: d, src: a
+ }
+ }
+ }
+ abs{.ftz}.f32 d, a => {
+ ast::Instruction::Abs {
+ data: ast::TypeFtz {
+ flush_to_zero: Some(ftz),
+ type_: f32
+ },
+ arguments: ast::AbsArgs {
+ dst: d, src: a
+ }
+ }
+ }
+ abs.f64 d, a => {
+ ast::Instruction::Abs {
+ data: ast::TypeFtz {
+ flush_to_zero: None,
+ type_: f64
+ },
+ arguments: ast::AbsArgs {
+ dst: d, src: a
+ }
+ }
+ }
+ abs{.ftz}.f16 d, a => {
+ ast::Instruction::Abs {
+ data: ast::TypeFtz {
+ flush_to_zero: Some(ftz),
+ type_: f16
+ },
+ arguments: ast::AbsArgs {
+ dst: d, src: a
+ }
+ }
+ }
+ abs{.ftz}.f16x2 d, a => {
+ ast::Instruction::Abs {
+ data: ast::TypeFtz {
+ flush_to_zero: Some(ftz),
+ type_: f16x2
+ },
+ arguments: ast::AbsArgs {
+ dst: d, src: a
+ }
+ }
+ }
+ abs.bf16 d, a => {
+ ast::Instruction::Abs {
+ data: ast::TypeFtz {
+ flush_to_zero: None,
+ type_: bf16
+ },
+ arguments: ast::AbsArgs {
+ dst: d, src: a
+ }
+ }
+ }
+ abs.bf16x2 d, a => {
+ ast::Instruction::Abs {
+ data: ast::TypeFtz {
+ flush_to_zero: None,
+ type_: bf16x2
+ },
+ arguments: ast::AbsArgs {
+ dst: d, src: a
+ }
+ }
+ }
+ .type: ScalarType = { .s16, .s32, .s64 };
+ ScalarType = { .f32, .f64, .f16, .f16x2, .bf16, .bf16x2 };
+
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-mad
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-mad
+ mad.mode.type d, a, b, c => {
+ ast::Instruction::Mad {
+ data: ast::MadDetails::Integer {
+ type_,
+ control: mode.into(),
+ saturate: false
+ },
+ arguments: MadArgs { dst: d, src1: a, src2: b, src3: c }
+ }
+ }
+ .type: ScalarType = { .u16, .u32, .u64,
+ .s16, .s32, .s64 };
+ .mode: RawMulIntControl = { .hi, .lo };
+
+ // The .wide suffix is supported only for 16-bit and 32-bit integer types.
+ mad.wide.type d, a, b, c => {
+ ast::Instruction::Mad {
+ data: ast::MadDetails::Integer {
+ type_,
+ control: wide.into(),
+ saturate: false
+ },
+ arguments: MadArgs { dst: d, src1: a, src2: b, src3: c }
+ }
+ }
+ .type: ScalarType = { .u16, .u32,
+ .s16, .s32 };
+ RawMulIntControl = { .wide };
+
+ mad.hi.sat.s32 d, a, b, c => {
+ ast::Instruction::Mad {
+ data: ast::MadDetails::Integer {
+ type_: s32,
+ control: hi.into(),
+ saturate: true
+ },
+ arguments: MadArgs { dst: d, src1: a, src2: b, src3: c }
+ }
+ }
+ RawMulIntControl = { .hi };
+ ScalarType = { .s32 };
+
+ mad{.ftz}{.sat}.f32 d, a, b, c => {
+ ast::Instruction::Mad {
+ data: ast::MadDetails::Float(
+ ast::ArithFloat {
+ type_: f32,
+ rounding: None,
+ flush_to_zero: Some(ftz),
+ saturate: sat
+ }
+ ),
+ arguments: MadArgs { dst: d, src1: a, src2: b, src3: c }
+ }
+ }
+ mad.rnd{.ftz}{.sat}.f32 d, a, b, c => {
+ ast::Instruction::Mad {
+ data: ast::MadDetails::Float(
+ ast::ArithFloat {
+ type_: f32,
+ rounding: Some(rnd.into()),
+ flush_to_zero: Some(ftz),
+ saturate: sat
+ }
+ ),
+ arguments: MadArgs { dst: d, src1: a, src2: b, src3: c }
+ }
+ }
+ mad.rnd.f64 d, a, b, c => {
+ ast::Instruction::Mad {
+ data: ast::MadDetails::Float(
+ ast::ArithFloat {
+ type_: f64,
+ rounding: Some(rnd.into()),
+ flush_to_zero: None,
+ saturate: false
+ }
+ ),
+ arguments: MadArgs { dst: d, src1: a, src2: b, src3: c }
+ }
+ }
+ .rnd: RawRoundingMode = { .rn, .rz, .rm, .rp };
+ ScalarType = { .f32, .f64 };
+
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-fma
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-fma
+ fma.rnd{.ftz}{.sat}.f32 d, a, b, c => {
+ ast::Instruction::Fma {
+ data: ast::ArithFloat {
+ type_: f32,
+ rounding: Some(rnd.into()),
+ flush_to_zero: Some(ftz),
+ saturate: sat
+ },
+ arguments: FmaArgs { dst: d, src1: a, src2: b, src3: c }
+ }
+ }
+ fma.rnd.f64 d, a, b, c => {
+ ast::Instruction::Fma {
+ data: ast::ArithFloat {
+ type_: f64,
+ rounding: Some(rnd.into()),
+ flush_to_zero: None,
+ saturate: false
+ },
+ arguments: FmaArgs { dst: d, src1: a, src2: b, src3: c }
+ }
+ }
+ .rnd: RawRoundingMode = { .rn, .rz, .rm, .rp };
+ ScalarType = { .f32, .f64 };
+
+ fma.rnd{.ftz}{.sat}.f16 d, a, b, c => {
+ ast::Instruction::Fma {
+ data: ast::ArithFloat {
+ type_: f16,
+ rounding: Some(rnd.into()),
+ flush_to_zero: Some(ftz),
+ saturate: sat
+ },
+ arguments: FmaArgs { dst: d, src1: a, src2: b, src3: c }
+ }
+ }
+ //fma.rnd{.ftz}{.sat}.f16x2 d, a, b, c;
+ //fma.rnd{.ftz}.relu.f16 d, a, b, c;
+ //fma.rnd{.ftz}.relu.f16x2 d, a, b, c;
+ //fma.rnd{.relu}.bf16 d, a, b, c;
+ //fma.rnd{.relu}.bf16x2 d, a, b, c;
+ //fma.rnd.oob.{relu}.type d, a, b, c;
+ .rnd: RawRoundingMode = { .rn };
+ ScalarType = { .f16 };
+
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-sub
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-sub
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-sub
+ sub.type d, a, b => {
+ ast::Instruction::Sub {
+ data: ast::ArithDetails::Integer(
+ ArithInteger {
+ type_,
+ saturate: false
+ }
+ ),
+ arguments: SubArgs { dst: d, src1: a, src2: b }
+ }
+ }
+ sub.sat.s32 d, a, b => {
+ ast::Instruction::Sub {
+ data: ast::ArithDetails::Integer(
+ ArithInteger {
+ type_: s32,
+ saturate: true
+ }
+ ),
+ arguments: SubArgs { dst: d, src1: a, src2: b }
+ }
+ }
+ .type: ScalarType = { .u16, .u32, .u64,
+ .s16, .s32, .s64 };
+ ScalarType = { .s32 };
+
+ sub{.rnd}{.ftz}{.sat}.f32 d, a, b => {
+ ast::Instruction::Sub {
+ data: ast::ArithDetails::Float(
+ ast::ArithFloat {
+ type_: f32,
+ rounding: rnd.map(Into::into),
+ flush_to_zero: Some(ftz),
+ saturate: sat
+ }
+ ),
+ arguments: SubArgs { dst: d, src1: a, src2: b }
+ }
+ }
+ sub{.rnd}.f64 d, a, b => {
+ ast::Instruction::Sub {
+ data: ast::ArithDetails::Float(
+ ast::ArithFloat {
+ type_: f64,
+ rounding: rnd.map(Into::into),
+ flush_to_zero: None,
+ saturate: false
+ }
+ ),
+ arguments: SubArgs { dst: d, src1: a, src2: b }
+ }
+ }
+ .rnd: RawRoundingMode = { .rn, .rz, .rm, .rp };
+ ScalarType = { .f32, .f64 };
+
+ sub{.rnd}{.ftz}{.sat}.f16 d, a, b => {
+ ast::Instruction::Sub {
+ data: ast::ArithDetails::Float(
+ ast::ArithFloat {
+ type_: f16,
+ rounding: rnd.map(Into::into),
+ flush_to_zero: Some(ftz),
+ saturate: sat
+ }
+ ),
+ arguments: SubArgs { dst: d, src1: a, src2: b }
+ }
+ }
+ sub{.rnd}{.ftz}{.sat}.f16x2 d, a, b => {
+ ast::Instruction::Sub {
+ data: ast::ArithDetails::Float(
+ ast::ArithFloat {
+ type_: f16x2,
+ rounding: rnd.map(Into::into),
+ flush_to_zero: Some(ftz),
+ saturate: sat
+ }
+ ),
+ arguments: SubArgs { dst: d, src1: a, src2: b }
+ }
+ }
+ sub{.rnd}.bf16 d, a, b => {
+ ast::Instruction::Sub {
+ data: ast::ArithDetails::Float(
+ ast::ArithFloat {
+ type_: bf16,
+ rounding: rnd.map(Into::into),
+ flush_to_zero: None,
+ saturate: false
+ }
+ ),
+ arguments: SubArgs { dst: d, src1: a, src2: b }
+ }
+ }
+ sub{.rnd}.bf16x2 d, a, b => {
+ ast::Instruction::Sub {
+ data: ast::ArithDetails::Float(
+ ast::ArithFloat {
+ type_: bf16x2,
+ rounding: rnd.map(Into::into),
+ flush_to_zero: None,
+ saturate: false
+ }
+ ),
+ arguments: SubArgs { dst: d, src1: a, src2: b }
+ }
+ }
+ .rnd: RawRoundingMode = { .rn };
+ ScalarType = { .f16, .f16x2, .bf16, .bf16x2 };
+
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-min
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-min
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-min
+ min.atype d, a, b => {
+ ast::Instruction::Min {
+ data: if atype.kind() == ast::ScalarKind::Signed {
+ ast::MinMaxDetails::Signed(atype)
+ } else {
+ ast::MinMaxDetails::Unsigned(atype)
+ },
+ arguments: MinArgs { dst: d, src1: a, src2: b }
+ }
+ }
+ //min{.relu}.btype d, a, b => { todo!() }
+ min.btype d, a, b => {
+ ast::Instruction::Min {
+ data: ast::MinMaxDetails::Signed(btype),
+ arguments: MinArgs { dst: d, src1: a, src2: b }
+ }
+ }
+ .atype: ScalarType = { .u16, .u32, .u64,
+ .u16x2, .s16, .s64 };
+ .btype: ScalarType = { .s16x2, .s32 };
+
+ //min{.ftz}{.NaN}{.xorsign.abs}.f32 d, a, b;
+ min{.ftz}{.NaN}.f32 d, a, b => {
+ ast::Instruction::Min {
+ data: ast::MinMaxDetails::Float(
+ MinMaxFloat {
+ flush_to_zero: Some(ftz),
+ nan,
+ type_: f32
+ }
+ ),
+ arguments: MinArgs { dst: d, src1: a, src2: b }
+ }
+ }
+ min.f64 d, a, b => {
+ ast::Instruction::Min {
+ data: ast::MinMaxDetails::Float(
+ MinMaxFloat {
+ flush_to_zero: None,
+ nan: false,
+ type_: f64
+ }
+ ),
+ arguments: MinArgs { dst: d, src1: a, src2: b }
+ }
+ }
+ ScalarType = { .f32, .f64 };
+
+ //min{.ftz}{.NaN}{.xorsign.abs}.f16 d, a, b;
+ //min{.ftz}{.NaN}{.xorsign.abs}.f16x2 d, a, b;
+ //min{.NaN}{.xorsign.abs}.bf16 d, a, b;
+ //min{.NaN}{.xorsign.abs}.bf16x2 d, a, b;
+ min{.ftz}{.NaN}.f16 d, a, b => {
+ ast::Instruction::Min {
+ data: ast::MinMaxDetails::Float(
+ MinMaxFloat {
+ flush_to_zero: Some(ftz),
+ nan,
+ type_: f16
+ }
+ ),
+ arguments: MinArgs { dst: d, src1: a, src2: b }
+ }
+ }
+ min{.ftz}{.NaN}.f16x2 d, a, b => {
+ ast::Instruction::Min {
+ data: ast::MinMaxDetails::Float(
+ MinMaxFloat {
+ flush_to_zero: Some(ftz),
+ nan,
+ type_: f16x2
+ }
+ ),
+ arguments: MinArgs { dst: d, src1: a, src2: b }
+ }
+ }
+ min{.NaN}.bf16 d, a, b => {
+ ast::Instruction::Min {
+ data: ast::MinMaxDetails::Float(
+ MinMaxFloat {
+ flush_to_zero: None,
+ nan,
+ type_: bf16
+ }
+ ),
+ arguments: MinArgs { dst: d, src1: a, src2: b }
+ }
+ }
+ min{.NaN}.bf16x2 d, a, b => {
+ ast::Instruction::Min {
+ data: ast::MinMaxDetails::Float(
+ MinMaxFloat {
+ flush_to_zero: None,
+ nan,
+ type_: bf16x2
+ }
+ ),
+ arguments: MinArgs { dst: d, src1: a, src2: b }
+ }
+ }
+ ScalarType = { .f16, .f16x2, .bf16, .bf16x2 };
+
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-max
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-max
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-max
+ max.atype d, a, b => {
+ ast::Instruction::Max {
+ data: if atype.kind() == ast::ScalarKind::Signed {
+ ast::MinMaxDetails::Signed(atype)
+ } else {
+ ast::MinMaxDetails::Unsigned(atype)
+ },
+ arguments: MaxArgs { dst: d, src1: a, src2: b }
+ }
+ }
+ //max{.relu}.btype d, a, b => { todo!() }
+ max.btype d, a, b => {
+ ast::Instruction::Max {
+ data: ast::MinMaxDetails::Signed(btype),
+ arguments: MaxArgs { dst: d, src1: a, src2: b }
+ }
+ }
+ .atype: ScalarType = { .u16, .u32, .u64,
+ .u16x2, .s16, .s64 };
+ .btype: ScalarType = { .s16x2, .s32 };
+
+ //max{.ftz}{.NaN}{.xorsign.abs}.f32 d, a, b;
+ max{.ftz}{.NaN}.f32 d, a, b => {
+ ast::Instruction::Max {
+ data: ast::MinMaxDetails::Float(
+ MinMaxFloat {
+ flush_to_zero: Some(ftz),
+ nan,
+ type_: f32
+ }
+ ),
+ arguments: MaxArgs { dst: d, src1: a, src2: b }
+ }
+ }
+ max.f64 d, a, b => {
+ ast::Instruction::Max {
+ data: ast::MinMaxDetails::Float(
+ MinMaxFloat {
+ flush_to_zero: None,
+ nan: false,
+ type_: f64
+ }
+ ),
+ arguments: MaxArgs { dst: d, src1: a, src2: b }
+ }
+ }
+ ScalarType = { .f32, .f64 };
+
+ //max{.ftz}{.NaN}{.xorsign.abs}.f16 d, a, b;
+ //max{.ftz}{.NaN}{.xorsign.abs}.f16x2 d, a, b;
+ //max{.NaN}{.xorsign.abs}.bf16 d, a, b;
+ //max{.NaN}{.xorsign.abs}.bf16x2 d, a, b;
+ max{.ftz}{.NaN}.f16 d, a, b => {
+ ast::Instruction::Max {
+ data: ast::MinMaxDetails::Float(
+ MinMaxFloat {
+ flush_to_zero: Some(ftz),
+ nan,
+ type_: f16
+ }
+ ),
+ arguments: MaxArgs { dst: d, src1: a, src2: b }
+ }
+ }
+ max{.ftz}{.NaN}.f16x2 d, a, b => {
+ ast::Instruction::Max {
+ data: ast::MinMaxDetails::Float(
+ MinMaxFloat {
+ flush_to_zero: Some(ftz),
+ nan,
+ type_: f16x2
+ }
+ ),
+ arguments: MaxArgs { dst: d, src1: a, src2: b }
+ }
+ }
+ max{.NaN}.bf16 d, a, b => {
+ ast::Instruction::Max {
+ data: ast::MinMaxDetails::Float(
+ MinMaxFloat {
+ flush_to_zero: None,
+ nan,
+ type_: bf16
+ }
+ ),
+ arguments: MaxArgs { dst: d, src1: a, src2: b }
+ }
+ }
+ max{.NaN}.bf16x2 d, a, b => {
+ ast::Instruction::Max {
+ data: ast::MinMaxDetails::Float(
+ MinMaxFloat {
+ flush_to_zero: None,
+ nan,
+ type_: bf16x2
+ }
+ ),
+ arguments: MaxArgs { dst: d, src1: a, src2: b }
+ }
+ }
+ ScalarType = { .f16, .f16x2, .bf16, .bf16x2 };
+
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-rcp
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-rcp-approx-ftz-f64
+ rcp.approx{.ftz}.type d, a => {
+ ast::Instruction::Rcp {
+ data: ast::RcpData {
+ kind: ast::RcpKind::Approx,
+ flush_to_zero: Some(ftz),
+ type_
+ },
+ arguments: RcpArgs { dst: d, src: a }
+ }
+ }
+ rcp.rnd{.ftz}.f32 d, a => {
+ ast::Instruction::Rcp {
+ data: ast::RcpData {
+ kind: ast::RcpKind::Compliant(rnd.into()),
+ flush_to_zero: Some(ftz),
+ type_: f32
+ },
+ arguments: RcpArgs { dst: d, src: a }
+ }
+ }
+ rcp.rnd.f64 d, a => {
+ ast::Instruction::Rcp {
+ data: ast::RcpData {
+ kind: ast::RcpKind::Compliant(rnd.into()),
+ flush_to_zero: None,
+ type_: f64
+ },
+ arguments: RcpArgs { dst: d, src: a }
+ }
+ }
+ .type: ScalarType = { .f32, .f64 };
+ .rnd: RawRoundingMode = { .rn, .rz, .rm, .rp };
+ ScalarType = { .f32, .f64 };
+
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-sqrt
+ sqrt.approx{.ftz}.f32 d, a => {
+ ast::Instruction::Sqrt {
+ data: ast::RcpData {
+ kind: ast::RcpKind::Approx,
+ flush_to_zero: Some(ftz),
+ type_: f32
+ },
+ arguments: SqrtArgs { dst: d, src: a }
+ }
+ }
+ sqrt.rnd{.ftz}.f32 d, a => {
+ ast::Instruction::Sqrt {
+ data: ast::RcpData {
+ kind: ast::RcpKind::Compliant(rnd.into()),
+ flush_to_zero: Some(ftz),
+ type_: f32
+ },
+ arguments: SqrtArgs { dst: d, src: a }
+ }
+ }
+ sqrt.rnd.f64 d, a => {
+ ast::Instruction::Sqrt {
+ data: ast::RcpData {
+ kind: ast::RcpKind::Compliant(rnd.into()),
+ flush_to_zero: None,
+ type_: f64
+ },
+ arguments: SqrtArgs { dst: d, src: a }
+ }
+ }
+ .rnd: RawRoundingMode = { .rn, .rz, .rm, .rp };
+ ScalarType = { .f32, .f64 };
+
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-rsqrt
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-rsqrt-approx-ftz-f64
+ rsqrt.approx{.ftz}.f32 d, a => {
+ ast::Instruction::Rsqrt {
+ data: ast::TypeFtz {
+ flush_to_zero: Some(ftz),
+ type_: f32
+ },
+ arguments: RsqrtArgs { dst: d, src: a }
+ }
+ }
+ rsqrt.approx.f64 d, a => {
+ ast::Instruction::Rsqrt {
+ data: ast::TypeFtz {
+ flush_to_zero: None,
+ type_: f64
+ },
+ arguments: RsqrtArgs { dst: d, src: a }
+ }
+ }
+ rsqrt.approx.ftz.f64 d, a => {
+ ast::Instruction::Rsqrt {
+ data: ast::TypeFtz {
+ flush_to_zero: None,
+ type_: f64
+ },
+ arguments: RsqrtArgs { dst: d, src: a }
+ }
+ }
+ ScalarType = { .f32, .f64 };
+
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#comparison-and-selection-instructions-selp
+ selp.type d, a, b, c => {
+ ast::Instruction::Selp {
+ data: type_,
+ arguments: SelpArgs { dst: d, src1: a, src2: b, src3: c }
+ }
+ }
+ .type: ScalarType = { .b16, .b32, .b64,
+ .u16, .u32, .u64,
+ .s16, .s32, .s64,
+ .f32, .f64 };
+
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-bar
+ barrier{.cta}.sync{.aligned} a{, b} => {
+ let _ = cta;
+ ast::Instruction::Bar {
+ data: ast::BarData { aligned },
+ arguments: BarArgs { src1: a, src2: b }
+ }
+ }
+ //barrier{.cta}.arrive{.aligned} a, b;
+ //barrier{.cta}.red.popc{.aligned}.u32 d, a{, b}, {!}c;
+ //barrier{.cta}.red.op{.aligned}.pred p, a{, b}, {!}c;
+ bar{.cta}.sync a{, b} => {
+ let _ = cta;
+ ast::Instruction::Bar {
+ data: ast::BarData { aligned: true },
+ arguments: BarArgs { src1: a, src2: b }
+ }
+ }
+ //bar{.cta}.arrive a, b;
+ //bar{.cta}.red.popc.u32 d, a{, b}, {!}c;
+ //bar{.cta}.red.op.pred p, a{, b}, {!}c;
+ //.op = { .and, .or };
+
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-atom
+ atom{.sem}{.scope}{.space}.op{.level::cache_hint}.type d, [a], b{, cache_policy} => {
+ if level_cache_hint || cache_policy.is_some() {
+ state.errors.push(PtxError::Todo);
+ }
+ ast::Instruction::Atom {
+ data: AtomDetails {
+ semantics: sem.map(Into::into).unwrap_or(AtomSemantics::Relaxed),
+ scope: scope.unwrap_or(MemScope::Gpu),
+ space: space.unwrap_or(StateSpace::Generic),
+ op: ast::AtomicOp::new(op, type_.kind()),
+ type_: type_.into()
+ },
+ arguments: AtomArgs { dst: d, src1: a, src2: b }
+ }
+ }
+ atom{.sem}{.scope}{.space}.cas.cas_type d, [a], b, c => {
+ ast::Instruction::AtomCas {
+ data: AtomCasDetails {
+ semantics: sem.map(Into::into).unwrap_or(AtomSemantics::Relaxed),
+ scope: scope.unwrap_or(MemScope::Gpu),
+ space: space.unwrap_or(StateSpace::Generic),
+ type_: cas_type
+ },
+ arguments: AtomCasArgs { dst: d, src1: a, src2: b, src3: c }
+ }
+ }
+ atom{.sem}{.scope}{.space}.exch{.level::cache_hint}.b128 d, [a], b{, cache_policy} => {
+ if level_cache_hint || cache_policy.is_some() {
+ state.errors.push(PtxError::Todo);
+ }
+ ast::Instruction::Atom {
+ data: AtomDetails {
+ semantics: sem.map(Into::into).unwrap_or(AtomSemantics::Relaxed),
+ scope: scope.unwrap_or(MemScope::Gpu),
+ space: space.unwrap_or(StateSpace::Generic),
+ op: ast::AtomicOp::new(exch, b128.kind()),
+ type_: b128.into()
+ },
+ arguments: AtomArgs { dst: d, src1: a, src2: b }
+ }
+ }
+ atom{.sem}{.scope}{.global}.float_op{.level::cache_hint}.vec_32_bit.f32 d, [a], b{, cache_policy} => {
+ if level_cache_hint || cache_policy.is_some() {
+ state.errors.push(PtxError::Todo);
+ }
+ ast::Instruction::Atom {
+ data: AtomDetails {
+ semantics: sem.map(Into::into).unwrap_or(AtomSemantics::Relaxed),
+ scope: scope.unwrap_or(MemScope::Gpu),
+ space: global.unwrap_or(StateSpace::Generic),
+ op: ast::AtomicOp::new(float_op, f32.kind()),
+ type_: ast::Type::Vector(vec_32_bit.len().get(), f32)
+ },
+ arguments: AtomArgs { dst: d, src1: a, src2: b }
+ }
+ }
+ atom{.sem}{.scope}{.global}.float_op.noftz{.level::cache_hint}{.vec_16_bit}.half_word_type d, [a], b{, cache_policy} => {
+ if level_cache_hint || cache_policy.is_some() {
+ state.errors.push(PtxError::Todo);
+ }
+ ast::Instruction::Atom {
+ data: AtomDetails {
+ semantics: sem.map(Into::into).unwrap_or(AtomSemantics::Relaxed),
+ scope: scope.unwrap_or(MemScope::Gpu),
+ space: global.unwrap_or(StateSpace::Generic),
+ op: ast::AtomicOp::new(float_op, half_word_type.kind()),
+ type_: ast::Type::maybe_vector(vec_16_bit, half_word_type)
+ },
+ arguments: AtomArgs { dst: d, src1: a, src2: b }
+ }
+ }
+ atom{.sem}{.scope}{.global}.float_op.noftz{.level::cache_hint}{.vec_32_bit}.packed_type d, [a], b{, cache_policy} => {
+ if level_cache_hint || cache_policy.is_some() {
+ state.errors.push(PtxError::Todo);
+ }
+ ast::Instruction::Atom {
+ data: AtomDetails {
+ semantics: sem.map(Into::into).unwrap_or(AtomSemantics::Relaxed),
+ scope: scope.unwrap_or(MemScope::Gpu),
+ space: global.unwrap_or(StateSpace::Generic),
+ op: ast::AtomicOp::new(float_op, packed_type.kind()),
+ type_: ast::Type::maybe_vector(vec_32_bit, packed_type)
+ },
+ arguments: AtomArgs { dst: d, src1: a, src2: b }
+ }
+ }
+ .space: StateSpace = { .global, .shared{::cta, ::cluster} };
+ .sem: AtomSemantics = { .relaxed, .acquire, .release, .acq_rel };
+ .scope: MemScope = { .cta, .cluster, .gpu, .sys };
+ .op: RawAtomicOp = { .and, .or, .xor,
+ .exch,
+ .add, .inc, .dec,
+ .min, .max };
+ .level::cache_hint = { .L2::cache_hint };
+ .type: ScalarType = { .b32, .b64, .u32, .u64, .s32, .s64, .f32, .f64 };
+ .cas_type: ScalarType = { .b32, .b64, .u32, .u64, .s32, .s64, .f32, .f64, .b16, .b128 };
+ .half_word_type: ScalarType = { .f16, .bf16 };
+ .packed_type: ScalarType = { .f16x2, .bf16x2 };
+ .vec_16_bit: VectorPrefix = { .v2, .v4, .v8 };
+ .vec_32_bit: VectorPrefix = { .v2, .v4 };
+ .float_op: RawAtomicOp = { .add, .min, .max };
+ ScalarType = { .b16, .b128, .f32 };
+ StateSpace = { .global };
+ RawAtomicOp = { .exch };
+
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-div
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-div
+ div.type d, a, b => {
+ ast::Instruction::Div {
+ data: if type_.kind() == ast::ScalarKind::Signed {
+ ast::DivDetails::Signed(type_)
+ } else {
+ ast::DivDetails::Unsigned(type_)
+ },
+ arguments: DivArgs {
+ dst: d,
+ src1: a,
+ src2: b,
+ },
+ }
+ }
+ .type: ScalarType = { .u16, .u32, .u64,
+ .s16, .s32, .s64 };
+
+ div.approx{.ftz}.f32 d, a, b => {
+ ast::Instruction::Div {
+ data: ast::DivDetails::Float(ast::DivFloatDetails{
+ type_: f32,
+ flush_to_zero: Some(ftz),
+ kind: ast::DivFloatKind::Approx
+ }),
+ arguments: DivArgs {
+ dst: d,
+ src1: a,
+ src2: b,
+ },
+ }
+ }
+ div.full{.ftz}.f32 d, a, b => {
+ ast::Instruction::Div {
+ data: ast::DivDetails::Float(ast::DivFloatDetails{
+ type_: f32,
+ flush_to_zero: Some(ftz),
+ kind: ast::DivFloatKind::ApproxFull
+ }),
+ arguments: DivArgs {
+ dst: d,
+ src1: a,
+ src2: b,
+ },
+ }
+ }
+ div.rnd{.ftz}.f32 d, a, b => {
+ ast::Instruction::Div {
+ data: ast::DivDetails::Float(ast::DivFloatDetails{
+ type_: f32,
+ flush_to_zero: Some(ftz),
+ kind: ast::DivFloatKind::Rounding(rnd.into())
+ }),
+ arguments: DivArgs {
+ dst: d,
+ src1: a,
+ src2: b,
+ },
+ }
+ }
+ div.rnd.f64 d, a, b => {
+ ast::Instruction::Div {
+ data: ast::DivDetails::Float(ast::DivFloatDetails{
+ type_: f64,
+ flush_to_zero: None,
+ kind: ast::DivFloatKind::Rounding(rnd.into())
+ }),
+ arguments: DivArgs {
+ dst: d,
+ src1: a,
+ src2: b,
+ },
+ }
+ }
+ .rnd: RawRoundingMode = { .rn, .rz, .rm, .rp };
+ ScalarType = { .f32, .f64 };
+
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-neg
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-neg
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-neg
+ neg.type d, a => {
+ ast::Instruction::Neg {
+ data: TypeFtz {
+ type_,
+ flush_to_zero: None
+ },
+ arguments: NegArgs { dst: d, src: a, },
+ }
+ }
+ .type: ScalarType = { .s16, .s32, .s64 };
+
+ neg{.ftz}.f32 d, a => {
+ ast::Instruction::Neg {
+ data: TypeFtz {
+ type_: f32,
+ flush_to_zero: Some(ftz)
+ },
+ arguments: NegArgs { dst: d, src: a, },
+ }
+ }
+ neg.f64 d, a => {
+ ast::Instruction::Neg {
+ data: TypeFtz {
+ type_: f64,
+ flush_to_zero: None
+ },
+ arguments: NegArgs { dst: d, src: a, },
+ }
+ }
+ neg{.ftz}.f16 d, a => {
+ ast::Instruction::Neg {
+ data: TypeFtz {
+ type_: f16,
+ flush_to_zero: Some(ftz)
+ },
+ arguments: NegArgs { dst: d, src: a, },
+ }
+ }
+ neg{.ftz}.f16x2 d, a => {
+ ast::Instruction::Neg {
+ data: TypeFtz {
+ type_: f16x2,
+ flush_to_zero: Some(ftz)
+ },
+ arguments: NegArgs { dst: d, src: a, },
+ }
+ }
+ neg.bf16 d, a => {
+ ast::Instruction::Neg {
+ data: TypeFtz {
+ type_: bf16,
+ flush_to_zero: None
+ },
+ arguments: NegArgs { dst: d, src: a, },
+ }
+ }
+ neg.bf16x2 d, a => {
+ ast::Instruction::Neg {
+ data: TypeFtz {
+ type_: bf16x2,
+ flush_to_zero: None
+ },
+ arguments: NegArgs { dst: d, src: a, },
+ }
+ }
+ ScalarType = { .f32, .f64, .f16, .f16x2, .bf16, .bf16x2 };
+
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-sin
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-cos
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-lg2
+ sin.approx{.ftz}.f32 d, a => {
+ ast::Instruction::Sin {
+ data: ast::FlushToZero {
+ flush_to_zero: ftz
+ },
+ arguments: SinArgs { dst: d, src: a, },
+ }
+ }
+ cos.approx{.ftz}.f32 d, a => {
+ ast::Instruction::Cos {
+ data: ast::FlushToZero {
+ flush_to_zero: ftz
+ },
+ arguments: CosArgs { dst: d, src: a, },
+ }
+ }
+ lg2.approx{.ftz}.f32 d, a => {
+ ast::Instruction::Lg2 {
+ data: ast::FlushToZero {
+ flush_to_zero: ftz
+ },
+ arguments: Lg2Args { dst: d, src: a, },
+ }
+ }
+
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-ex2
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#half-precision-floating-point-instructions-ex2
+ ex2.approx{.ftz}.f32 d, a => {
+ ast::Instruction::Ex2 {
+ data: ast::TypeFtz {
+ type_: f32,
+ flush_to_zero: Some(ftz)
+ },
+ arguments: Ex2Args { dst: d, src: a, },
+ }
+ }
+ ex2.approx.atype d, a => {
+ ast::Instruction::Ex2 {
+ data: ast::TypeFtz {
+ type_: atype,
+ flush_to_zero: None
+ },
+ arguments: Ex2Args { dst: d, src: a, },
+ }
+ }
+ ex2.approx.ftz.btype d, a => {
+ ast::Instruction::Ex2 {
+ data: ast::TypeFtz {
+ type_: btype,
+ flush_to_zero: Some(true)
+ },
+ arguments: Ex2Args { dst: d, src: a, },
+ }
+ }
+ .atype: ScalarType = { .f16, .f16x2 };
+ .btype: ScalarType = { .bf16, .bf16x2 };
+ ScalarType = { .f32 };
+
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-clz
+ clz.type d, a => {
+ ast::Instruction::Clz {
+ data: type_,
+ arguments: ClzArgs { dst: d, src: a, },
+ }
+ }
+ .type: ScalarType = { .b32, .b64 };
+
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-brev
+ brev.type d, a => {
+ ast::Instruction::Brev {
+ data: type_,
+ arguments: BrevArgs { dst: d, src: a, },
+ }
+ }
+ .type: ScalarType = { .b32, .b64 };
+
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-popc
+ popc.type d, a => {
+ ast::Instruction::Popc {
+ data: type_,
+ arguments: PopcArgs { dst: d, src: a, },
+ }
+ }
+ .type: ScalarType = { .b32, .b64 };
+
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-xor
+ xor.type d, a, b => {
+ ast::Instruction::Xor {
+ data: type_,
+ arguments: XorArgs { dst: d, src1: a, src2: b, },
+ }
+ }
+ .type: ScalarType = { .pred, .b16, .b32, .b64 };
+
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-rem
+ rem.type d, a, b => {
+ ast::Instruction::Rem {
+ data: type_,
+ arguments: RemArgs { dst: d, src1: a, src2: b, },
+ }
+ }
+ .type: ScalarType = { .u16, .u32, .u64, .s16, .s32, .s64 };
+
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-bfe
+ bfe.type d, a, b, c => {
+ ast::Instruction::Bfe {
+ data: type_,
+ arguments: BfeArgs { dst: d, src1: a, src2: b, src3: c },
+ }
+ }
+ .type: ScalarType = { .u32, .u64, .s32, .s64 };
+
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-bfi
+ bfi.type f, a, b, c, d => {
+ ast::Instruction::Bfi {
+ data: type_,
+ arguments: BfiArgs { dst: f, src1: a, src2: b, src3: c, src4: d },
+ }
+ }
+ .type: ScalarType = { .b32, .b64 };
+
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-prmt
+ // prmt.b32{.mode} d, a, b, c;
+ // .mode = { .f4e, .b4e, .rc8, .ecl, .ecr, .rc16 };
+ prmt.b32 d, a, b, c => {
+ match c {
+ ast::ParsedOperand::Imm(ImmediateValue::S64(control)) => ast::Instruction::Prmt {
+ data: control as u16,
+ arguments: PrmtArgs {
+ dst: d, src1: a, src2: b
+ }
+ },
+ _ => ast::Instruction::PrmtSlow {
+ arguments: PrmtSlowArgs {
+ dst: d, src1: a, src2: b, src3: c
+ }
+ }
+ }
+ }
+
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-activemask
+ activemask.b32 d => {
+ ast::Instruction::Activemask {
+ arguments: ActivemaskArgs { dst: d }
+ }
+ }
+
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-membar
+ // fence{.sem}.scope;
+ // fence.op_restrict.release.cluster;
+ // fence.proxy.proxykind;
+ // fence.proxy.to_proxykind::from_proxykind.release.scope;
+ // fence.proxy.to_proxykind::from_proxykind.acquire.scope [addr], size;
+ //membar.proxy.proxykind;
+ //.sem = { .sc, .acq_rel };
+ //.scope = { .cta, .cluster, .gpu, .sys };
+ //.proxykind = { .alias, .async, async.global, .async.shared::{cta, cluster} };
+ //.op_restrict = { .mbarrier_init };
+ //.to_proxykind::from_proxykind = {.tensormap::generic};
+
+ membar.level => {
+ ast::Instruction::Membar { data: level }
+ }
+ membar.gl => {
+ ast::Instruction::Membar { data: MemScope::Gpu }
+ }
+ .level: MemScope = { .cta, .sys };
+
+ // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-ret
+ ret{.uni} => {
+ Instruction::Ret { data: RetData { uniform: uni } }
+ }
+
+);
+
+#[cfg(test)]
+mod tests {
+ use super::target;
+ use super::PtxParserState;
+ use super::Token;
+ use logos::Logos;
+ use winnow::prelude::*;
+
+ #[test]
+ fn sm_11() {
+ let tokens = Token::lexer(".target sm_11")
+ .collect::<Result<Vec<_>, ()>>()
+ .unwrap();
+ let stream = super::PtxParser {
+ input: &tokens[..],
+ state: PtxParserState::new(),
+ };
+ assert_eq!(target.parse(stream).unwrap(), (11, None));
+ }
+
+ #[test]
+ fn sm_90a() {
+ let tokens = Token::lexer(".target sm_90a")
+ .collect::<Result<Vec<_>, ()>>()
+ .unwrap();
+ let stream = super::PtxParser {
+ input: &tokens[..],
+ state: PtxParserState::new(),
+ };
+ assert_eq!(target.parse(stream).unwrap(), (90, Some('a')));
+ }
+
+ #[test]
+ fn sm_90ab() {
+ let tokens = Token::lexer(".target sm_90ab")
+ .collect::<Result<Vec<_>, ()>>()
+ .unwrap();
+ let stream = super::PtxParser {
+ input: &tokens[..],
+ state: PtxParserState::new(),
+ };
+ assert!(target.parse(stream).is_err());
+ }
+}