Diffstat (limited to 'ptx_parser')
-rw-r--r--  ptx_parser/src/ast.rs  53
-rw-r--r--  ptx_parser/src/lib.rs  24
2 files changed, 67 insertions(+), 10 deletions(-)
diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs
index 65c624e..f0d3fbe 100644
--- a/ptx_parser/src/ast.rs
+++ b/ptx_parser/src/ast.rs
@@ -4,7 +4,7 @@ use super::{
};
use crate::{PtxError, PtxParserState};
use bitflags::bitflags;
-use std::{cmp::Ordering, num::NonZeroU8};
+use std::{alloc::Layout, cmp::Ordering, num::NonZeroU8};
pub enum Statement<P: Operand> {
Label(P::Ident),
@@ -806,6 +806,32 @@ impl Type {
None => Self::maybe_vector_parsed(prefix, scalar),
}
}
+
+ pub fn layout(&self) -> Layout {
+ match self {
+ Type::Scalar(type_) => type_.layout(),
+ Type::Vector(elements, scalar_type) => {
+ let scalar_layout = scalar_type.layout();
+ unsafe {
+ Layout::from_size_align_unchecked(
+ scalar_layout.size() * *elements as usize,
+ scalar_layout.align() * *elements as usize,
+ )
+ }
+ }
+ Type::Array(non_zero, scalar, vec) => {
+ let element_layout = Type::maybe_vector_parsed(*non_zero, *scalar).layout();
+ let len = vec.iter().copied().reduce(std::ops::Mul::mul).unwrap_or(0);
+ unsafe {
+ Layout::from_size_align_unchecked(
+ element_layout.size() * (len as usize),
+ element_layout.align(),
+ )
+ }
+ }
+ Type::Pointer(..) => Layout::new::<usize>(),
+ }
+ }
}
impl ScalarType {
@@ -831,6 +857,31 @@ impl ScalarType {
}
}
+ pub fn layout(self) -> Layout {
+ match self {
+ ScalarType::U8 | ScalarType::S8 | ScalarType::B8 => Layout::new::<u8>(),
+ ScalarType::U16
+ | ScalarType::S16
+ | ScalarType::B16
+ | ScalarType::F16
+ | ScalarType::BF16 => Layout::new::<u16>(),
+ ScalarType::U32
+ | ScalarType::S32
+ | ScalarType::B32
+ | ScalarType::F32
+ | ScalarType::U16x2
+ | ScalarType::S16x2
+ | ScalarType::F16x2
+ | ScalarType::BF16x2 => Layout::new::<u32>(),
+ ScalarType::U64 | ScalarType::S64 | ScalarType::B64 | ScalarType::F64 => {
+ Layout::new::<u64>()
+ }
+ ScalarType::B128 => unsafe { Layout::from_size_align_unchecked(16, 16) },
+ // Close enough
+ ScalarType::Pred => Layout::new::<u8>(),
+ }
+ }
+
pub fn kind(self) -> ScalarKind {
match self {
ScalarType::U8 => ScalarKind::Unsigned,
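
The new layout() helpers give byte size and alignment for every parsed PTX type. A rough, illustrative sketch of how a consumer could use them (it assumes only the methods added above plus std's Layout API and usize::next_multiple_of; frame_size is a hypothetical helper, not part of this change):

// Hypothetical example: compute the naturally-aligned size of a sequence of
// PTX variables using the Type::layout() method introduced in this commit.
fn frame_size(types: &[Type]) -> usize {
    let mut offset = 0usize;
    for t in types {
        let l = t.layout();
        offset = offset.next_multiple_of(l.align()); // pad to the type's alignment
        offset += l.size();                          // then reserve its size
    }
    offset
}
// e.g. ScalarType::F32.layout().size() == 4 and ScalarType::B128.layout().align() == 16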
diff --git a/ptx_parser/src/lib.rs b/ptx_parser/src/lib.rs
index fee11aa..b49503b 100644
--- a/ptx_parser/src/lib.rs
+++ b/ptx_parser/src/lib.rs
@@ -1349,10 +1349,10 @@ impl std::error::Error for TokenError {}
// * After parsing, each instruction needs to do some early validation and generate a specific,
// strongly-typed object. We want strong-typing because we have a single PTX parser frontend, but
// there can be multiple different code emitter backends
-// * Most importantly, instruction modifiers can come in aby order, so e.g. both
+// * Most importantly, instruction modifiers can come in any order, so e.g. both
// `ld.relaxed.global.u32 a, b` and `ld.global.relaxed.u32 a, b` are equally valid. This makes
// classic parsing generators fail: if we tried to generate parsing rules that cover every possible
-// ordering we'd need thousands of rules. This is not a purely theoretical problem. NVCC and Clang
+// ordering we'd need thousands of rules. This is not a purely theoretical problem. NVCC and Clang
// will always emit modifiers in the correct order, but people who write inline assembly usually
// get it wrong (even first party developers)
//
@@ -1398,7 +1398,7 @@ impl std::error::Error for TokenError {}
// * List of rules. They are associated with the preceding patterns (until different opcode or
// different rules). Rules are used to resolve modifiers. There are two types of rules:
// * Normal rule: `.foobar: FoobarEnum => { .a, .b, .c }`. This means that instead of `.foobar` we
-// expecte one of `.a`, `.b`, `.c` and will emit value FoobarEnum::DotA, FoobarEnum::DotB,
+// expect one of `.a`, `.b`, `.c` and will emit value FoobarEnum::DotA, FoobarEnum::DotB,
// FoobarEnum::DotC appropriately
// * Type-only rule: `FoobarEnum => { .a, .b, .c }` this means that all the occurrences of `.a` will
// emit FoobarEnum::DotA to the code block. This helps to avoid copy-paste errors
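
To make the any-order point above concrete, here is a hand-rolled, illustrative sketch (not the generated code; the enum and function names are made up) of what resolving order-independent modifiers into strongly-typed values boils down to. The rule machinery described in the comments generates roughly this kind of bookkeeping for each instruction:

// Illustrative only: `.relaxed.global` and `.global.relaxed` resolve to the
// same strongly-typed pair, while duplicate or unknown modifiers are rejected.
#[derive(Debug, PartialEq)]
enum MemOrder { DotRelaxed, DotAcquire }
#[derive(Debug, PartialEq)]
enum StateSpace { DotGlobal, DotShared }

fn resolve_ld_modifiers(mods: &[&str]) -> Option<(Option<MemOrder>, Option<StateSpace>)> {
    let (mut order, mut space) = (None, None);
    for m in mods {
        match *m {
            ".relaxed" if order.is_none() => order = Some(MemOrder::DotRelaxed),
            ".acquire" if order.is_none() => order = Some(MemOrder::DotAcquire),
            ".global" if space.is_none() => space = Some(StateSpace::DotGlobal),
            ".shared" if space.is_none() => space = Some(StateSpace::DotShared),
            _ => return None, // unknown modifier or a duplicate of the same rule
        }
    }
    Some((order, space))
}
// resolve_ld_modifiers(&[".relaxed", ".global"]) == resolve_ld_modifiers(&[".global", ".relaxed"])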
@@ -3233,36 +3233,42 @@ mod tests {
#[test]
fn sm_11() {
let tokens = Token::lexer(".target sm_11")
- .collect::<Result<Vec<_>, ()>>()
+ .collect::<Result<Vec<_>, _>>()
.unwrap();
+ let mut errors = Vec::new();
let stream = super::PtxParser {
input: &tokens[..],
- state: PtxParserState::new(),
+ state: PtxParserState::new(&mut errors),
};
assert_eq!(target.parse(stream).unwrap(), (11, None));
+ assert_eq!(errors.len(), 0);
}
#[test]
fn sm_90a() {
let tokens = Token::lexer(".target sm_90a")
- .collect::<Result<Vec<_>, ()>>()
+ .collect::<Result<Vec<_>, _>>()
.unwrap();
+ let mut errors = Vec::new();
let stream = super::PtxParser {
input: &tokens[..],
- state: PtxParserState::new(),
+ state: PtxParserState::new(&mut errors),
};
assert_eq!(target.parse(stream).unwrap(), (90, Some('a')));
+ assert_eq!(errors.len(), 0);
}
#[test]
fn sm_90ab() {
let tokens = Token::lexer(".target sm_90ab")
- .collect::<Result<Vec<_>, ()>>()
+ .collect::<Result<Vec<_>, _>>()
.unwrap();
+ let mut errors = Vec::new();
let stream = super::PtxParser {
input: &tokens[..],
- state: PtxParserState::new(),
+ state: PtxParserState::new(&mut errors),
};
assert!(target.parse(stream).is_err());
+ assert_eq!(errors.len(), 0);
}
}
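
The test changes above also document the new calling convention: the parser state no longer owns its error list, so callers allocate the buffer, pass it in, and inspect it after parsing. A minimal sketch of that pattern, assuming only the items the tests already use (Token, PtxParser, PtxParserState, target) and that ".target sm_80" lexes like the inputs above:

// Minimal sketch mirroring the updated tests: the caller owns the error
// buffer and checks it once parsing is done.
let tokens = Token::lexer(".target sm_80")
    .collect::<Result<Vec<_>, _>>()
    .unwrap();
let mut errors = Vec::new();
let stream = PtxParser {
    input: &tokens[..],
    state: PtxParserState::new(&mut errors),
};
let _ = target.parse(stream);
assert!(errors.is_empty(), "no recoverable errors expected");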