diff options
Diffstat (limited to 'ptx_parser')
-rw-r--r-- | ptx_parser/src/ast.rs | 53 | ||||
-rw-r--r-- | ptx_parser/src/lib.rs | 24 |
2 files changed, 67 insertions, 10 deletions
diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index 65c624e..f0d3fbe 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -4,7 +4,7 @@ use super::{ };
use crate::{PtxError, PtxParserState};
use bitflags::bitflags;
-use std::{cmp::Ordering, num::NonZeroU8};
+use std::{alloc::Layout, cmp::Ordering, num::NonZeroU8};
pub enum Statement<P: Operand> {
Label(P::Ident),
@@ -806,6 +806,32 @@ impl Type { None => Self::maybe_vector_parsed(prefix, scalar),
}
}
+
+ pub fn layout(&self) -> Layout {
+ match self {
+ Type::Scalar(type_) => type_.layout(),
+ Type::Vector(elements, scalar_type) => {
+ let scalar_layout = scalar_type.layout();
+ unsafe {
+ Layout::from_size_align_unchecked(
+ scalar_layout.size() * *elements as usize,
+ scalar_layout.align() * *elements as usize,
+ )
+ }
+ }
+ Type::Array(non_zero, scalar, vec) => {
+ let element_layout = Type::maybe_vector_parsed(*non_zero, *scalar).layout();
+ let len = vec.iter().copied().reduce(std::ops::Mul::mul).unwrap_or(0);
+ unsafe {
+ Layout::from_size_align_unchecked(
+ element_layout.size() * (len as usize),
+ element_layout.align(),
+ )
+ }
+ }
+ Type::Pointer(..) => Layout::new::<usize>(),
+ }
+ }
}
impl ScalarType {
@@ -831,6 +857,31 @@ impl ScalarType { }
}
+ pub fn layout(self) -> Layout {
+ match self {
+ ScalarType::U8 | ScalarType::S8 | ScalarType::B8 => Layout::new::<u8>(),
+ ScalarType::U16
+ | ScalarType::S16
+ | ScalarType::B16
+ | ScalarType::F16
+ | ScalarType::BF16 => Layout::new::<u16>(),
+ ScalarType::U32
+ | ScalarType::S32
+ | ScalarType::B32
+ | ScalarType::F32
+ | ScalarType::U16x2
+ | ScalarType::S16x2
+ | ScalarType::F16x2
+ | ScalarType::BF16x2 => Layout::new::<u32>(),
+ ScalarType::U64 | ScalarType::S64 | ScalarType::B64 | ScalarType::F64 => {
+ Layout::new::<u64>()
+ }
+ ScalarType::B128 => unsafe { Layout::from_size_align_unchecked(16, 16) },
+ // Close enough
+ ScalarType::Pred => Layout::new::<u8>(),
+ }
+ }
+
pub fn kind(self) -> ScalarKind {
match self {
ScalarType::U8 => ScalarKind::Unsigned,
diff --git a/ptx_parser/src/lib.rs b/ptx_parser/src/lib.rs index fee11aa..b49503b 100644 --- a/ptx_parser/src/lib.rs +++ b/ptx_parser/src/lib.rs @@ -1349,10 +1349,10 @@ impl std::error::Error for TokenError {} // * After parsing, each instruction needs to do some early validation and generate a specific, // strongly-typed object. We want strong-typing because we have a single PTX parser frontend, but // there can be multiple different code emitter backends -// * Most importantly, instruction modifiers can come in aby order, so e.g. both +// * Most importantly, instruction modifiers can come in aby order, so e.g. both // `ld.relaxed.global.u32 a, b` and `ld.global.relaxed.u32 a, b` are equally valid. This makes // classic parsing generators fail: if we tried to generate parsing rules that cover every possible -// ordering we'd need thousands of rules. This is not a purely theoretical problem. NVCC and Clang +// ordering we'd need thousands of rules. This is not a purely theoretical problem. NVCC and Clang // will always emit modifiers in the correct order, but people who write inline assembly usually // get it wrong (even first party developers) // @@ -1398,7 +1398,7 @@ impl std::error::Error for TokenError {} // * List of rules. They are associated with the preceding patterns (until different opcode or // different rules). Rules are used to resolve modifiers. There are two types of rules: // * Normal rule: `.foobar: FoobarEnum => { .a, .b, .c }`. This means that instead of `.foobar` we -// expecte one of `.a`, `.b`, `.c` and will emit value FoobarEnum::DotA, FoobarEnum::DotB, +// expecte one of `.a`, `.b`, `.c` and will emit value FoobarEnum::DotA, FoobarEnum::DotB, // FoobarEnum::DotC appropriately // * Type-only rule: `FoobarEnum => { .a, .b, .c }` this means that all the occurences of `.a` will // emit FoobarEnum::DotA to the code block. This helps to avoid copy-paste errors @@ -3233,36 +3233,42 @@ mod tests { #[test] fn sm_11() { let tokens = Token::lexer(".target sm_11") - .collect::<Result<Vec<_>, ()>>() + .collect::<Result<Vec<_>, _>>() .unwrap(); + let mut errors = Vec::new(); let stream = super::PtxParser { input: &tokens[..], - state: PtxParserState::new(), + state: PtxParserState::new(&mut errors), }; assert_eq!(target.parse(stream).unwrap(), (11, None)); + assert_eq!(errors.len(), 0); } #[test] fn sm_90a() { let tokens = Token::lexer(".target sm_90a") - .collect::<Result<Vec<_>, ()>>() + .collect::<Result<Vec<_>, _>>() .unwrap(); + let mut errors = Vec::new(); let stream = super::PtxParser { input: &tokens[..], - state: PtxParserState::new(), + state: PtxParserState::new(&mut errors), }; assert_eq!(target.parse(stream).unwrap(), (90, Some('a'))); + assert_eq!(errors.len(), 0); } #[test] fn sm_90ab() { let tokens = Token::lexer(".target sm_90ab") - .collect::<Result<Vec<_>, ()>>() + .collect::<Result<Vec<_>, _>>() .unwrap(); + let mut errors = Vec::new(); let stream = super::PtxParser { input: &tokens[..], - state: PtxParserState::new(), + state: PtxParserState::new(&mut errors), }; assert!(target.parse(stream).is_err()); + assert_eq!(errors.len(), 0); } } |