aboutsummaryrefslogtreecommitdiffhomepage
path: root/ptx_parser/src/ast.rs
diff options
context:
space:
mode:
Diffstat (limited to 'ptx_parser/src/ast.rs')
-rw-r--r--ptx_parser/src/ast.rs53
1 files changed, 52 insertions, 1 deletions
diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs
index 65c624e..f0d3fbe 100644
--- a/ptx_parser/src/ast.rs
+++ b/ptx_parser/src/ast.rs
@@ -4,7 +4,7 @@ use super::{
};
use crate::{PtxError, PtxParserState};
use bitflags::bitflags;
-use std::{cmp::Ordering, num::NonZeroU8};
+use std::{alloc::Layout, cmp::Ordering, num::NonZeroU8};
pub enum Statement<P: Operand> {
Label(P::Ident),
@@ -806,6 +806,32 @@ impl Type {
None => Self::maybe_vector_parsed(prefix, scalar),
}
}
+
+ pub fn layout(&self) -> Layout {
+ match self {
+ Type::Scalar(type_) => type_.layout(),
+ Type::Vector(elements, scalar_type) => {
+ let scalar_layout = scalar_type.layout();
+ unsafe {
+ Layout::from_size_align_unchecked(
+ scalar_layout.size() * *elements as usize,
+ scalar_layout.align() * *elements as usize,
+ )
+ }
+ }
+ Type::Array(non_zero, scalar, vec) => {
+ let element_layout = Type::maybe_vector_parsed(*non_zero, *scalar).layout();
+ let len = vec.iter().copied().reduce(std::ops::Mul::mul).unwrap_or(0);
+ unsafe {
+ Layout::from_size_align_unchecked(
+ element_layout.size() * (len as usize),
+ element_layout.align(),
+ )
+ }
+ }
+ Type::Pointer(..) => Layout::new::<usize>(),
+ }
+ }
}
impl ScalarType {
@@ -831,6 +857,31 @@ impl ScalarType {
}
}
+ pub fn layout(self) -> Layout {
+ match self {
+ ScalarType::U8 | ScalarType::S8 | ScalarType::B8 => Layout::new::<u8>(),
+ ScalarType::U16
+ | ScalarType::S16
+ | ScalarType::B16
+ | ScalarType::F16
+ | ScalarType::BF16 => Layout::new::<u16>(),
+ ScalarType::U32
+ | ScalarType::S32
+ | ScalarType::B32
+ | ScalarType::F32
+ | ScalarType::U16x2
+ | ScalarType::S16x2
+ | ScalarType::F16x2
+ | ScalarType::BF16x2 => Layout::new::<u32>(),
+ ScalarType::U64 | ScalarType::S64 | ScalarType::B64 | ScalarType::F64 => {
+ Layout::new::<u64>()
+ }
+ ScalarType::B128 => unsafe { Layout::from_size_align_unchecked(16, 16) },
+ // Close enough
+ ScalarType::Pred => Layout::new::<u8>(),
+ }
+ }
+
pub fn kind(self) -> ScalarKind {
match self {
ScalarType::U8 => ScalarKind::Unsigned,