aboutsummaryrefslogtreecommitdiffhomepage
path: root/ptx/src/translate.rs
diff options
context:
space:
mode:
Diffstat (limited to 'ptx/src/translate.rs')
-rw-r--r--ptx/src/translate.rs12436
1 files changed, 7093 insertions, 5343 deletions
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index 18d750f..041c690 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -1,1460 +1,2925 @@
-use crate::ast;
-use half::f16;
-use rspirv::dr;
-use std::{borrow::Cow, collections::BTreeSet, ffi::CString, hash::Hash, iter, mem};
-use std::{
- collections::{hash_map, HashMap, HashSet},
- convert::TryInto,
-};
-
-use rspirv::binary::Assemble;
-
-static ZLUDA_PTX_IMPL: &'static [u8] = include_bytes!("../lib/zluda_ptx_impl.spv");
+use crate::llvm::Message;
+use crate::{ast, emit, llvm, raytracing};
+use bit_vec::BitVec;
+use hip_common::raytracing::VariablesBlock;
+use hip_common::{kernel_metadata, CompilationMode};
+use paste::paste;
+pub use raytracing::Module as RaytracingModule;
+use rustc_hash::{FxHashMap, FxHashSet};
+use std::alloc::Layout;
+use std::cell::RefCell;
+use std::collections::{btree_map, hash_map, BTreeMap};
+use std::ffi::{CStr, CString};
+use std::num::NonZeroU32;
+use std::{borrow::Cow, collections::BTreeSet, hash::Hash, iter, mem, rc::Rc};
+use zluda_llvm::bit_writer::*;
+use zluda_llvm::core::LLVMPrintModuleToString;
+
+static ZLUDA_PTX_IMPL_AMD: &'static [u8] = include_bytes!("../lib/zluda_ptx_impl.bc");
+const ZLUDA_PTX_PREFIX: &'static str = "__zluda_ptx_impl__";
+
+macro_rules! derive_error {
+ (enum $type_:ident {
+ $( $variant:ident $(($underlying:ty))? ),+
+ }) => {
+ #[derive(Debug)]
+ pub enum $type_ {
+ $(
+ $variant $(($underlying))? ,
+ )+
+ }
+
+ impl $type_ {
+ $(
+ paste! {
+ #[allow(dead_code)]
+ pub(crate) fn [<$variant:snake>] ( $(x: $underlying)? ) -> Self {
+ let result = Self :: $variant $((x as $underlying))?;
+ if cfg!(debug_assertions) {
+ panic!("{:?}", result);
+ } else {
+ result
+ }
+ }
+ }
+ )+
+ }
-quick_error! {
- #[derive(Debug)]
- pub enum TranslateError {
- UnknownSymbol {}
- UntypedSymbol {}
- MismatchedType {}
- Spirv(err: rspirv::dr::Error) {
- from()
- display("{}", err)
- cause(err)
+ impl std::fmt::Display for $type_ {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
+ write!(f, "{:?}", self)
+ }
}
- Unreachable {}
- Todo {}
}
}
-#[cfg(debug_assertions)]
-fn error_unreachable() -> TranslateError {
- unreachable!()
+derive_error! {
+ enum TranslateError {
+ UnknownSymbol,
+ UntypedSymbol,
+ MismatchedType,
+ LLVM(llvm::Message),
+ Unreachable,
+ Todo,
+ UnexpectedPattern,
+ SymbolRedefinition
+ }
}
-#[cfg(not(debug_assertions))]
-fn error_unreachable() -> TranslateError {
- TranslateError::Unreachable
-}
+impl std::error::Error for TranslateError {}
-#[derive(PartialEq, Eq, Hash, Clone)]
-enum SpirvType {
- Base(SpirvScalarKey),
- Vector(SpirvScalarKey, u8),
- Array(SpirvScalarKey, Vec<u32>),
- Pointer(Box<SpirvType>, spirv::StorageClass),
- Func(Option<Box<SpirvType>>, Vec<SpirvType>),
- Struct(Vec<SpirvScalarKey>),
+pub struct Module<'input> {
+ pub(crate) llvm_module: llvm::Module,
+ pub(crate) _llvm_context: llvm::Context,
+ pub kernel_arguments: FxHashMap<String, Vec<Layout>>,
+ pub bitcode_modules: Vec<&'static [u8]>,
+ pub metadata: Metadata<'input>,
+ pub compilation_mode: CompilationMode,
}
-impl SpirvType {
- fn new_pointer(t: ast::Type, sc: spirv::StorageClass) -> Self {
- let key = t.into();
- SpirvType::Pointer(Box::new(key), sc)
+impl<'input> Module<'input> {
+ pub fn get_bitcode_main(&self) -> llvm::MemoryBuffer {
+ unsafe {
+ llvm::MemoryBuffer::from_ffi(LLVMWriteBitcodeToMemoryBuffer(self.llvm_module.get()))
+ }
}
-}
-impl From<ast::Type> for SpirvType {
- fn from(t: ast::Type) -> Self {
- match t {
- ast::Type::Scalar(t) => SpirvType::Base(t.into()),
- ast::Type::Vector(typ, len) => SpirvType::Vector(typ.into(), len),
- ast::Type::Array(t, len) => SpirvType::Array(t.into(), len),
- ast::Type::Pointer(pointer_t, state_space) => SpirvType::Pointer(
- Box::new(SpirvType::from(ast::Type::from(pointer_t))),
- state_space.to_spirv(),
- ),
+ pub fn get_llvm_text(&self) -> Message {
+ unsafe { llvm::Message::from_ffi(LLVMPrintModuleToString(self.llvm_module.get())) }
+ }
+
+ pub fn get_bitcode_all<'a>(
+ &'a self,
+ ) -> impl Iterator<Item = (llvm::MemoryBuffer, &'a CStr)> + '_ {
+ unsafe {
+ let main_bc = llvm::MemoryBuffer::from_ffi(LLVMWriteBitcodeToMemoryBuffer(
+ self.llvm_module.get(),
+ ));
+ let main_name = CStr::from_bytes_with_nul_unchecked(b"main\0");
+ iter::once((main_bc, main_name)).chain(self.bitcode_modules.iter().map(|ptx_impl| {
+ (
+ llvm::MemoryBuffer::create_no_copy(ptx_impl, false),
+ CStr::from_bytes_with_nul_unchecked(b"ptx_impl\0"),
+ )
+ }))
}
}
-}
-impl From<ast::PointerType> for ast::Type {
- fn from(t: ast::PointerType) -> Self {
- match t {
- ast::PointerType::Scalar(t) => ast::Type::Scalar(t),
- ast::PointerType::Vector(t, len) => ast::Type::Vector(t, len),
- ast::PointerType::Array(t, dims) => ast::Type::Array(t, dims),
- ast::PointerType::Pointer(t, space) => {
- ast::Type::Pointer(ast::PointerType::Scalar(t), space)
+ pub fn get_bitcode_multi<'a>(
+ mods: impl Iterator<Item = &'a Module<'input>>,
+ ) -> Vec<(llvm::MemoryBuffer, CString)>
+ where
+ 'input: 'a,
+ {
+ unsafe {
+ let mut main_bcs = Vec::new();
+ let mut bitcode_mods = Vec::new();
+ for (idx, mod_) in mods.enumerate() {
+ let main_bc = llvm::MemoryBuffer::from_ffi(LLVMWriteBitcodeToMemoryBuffer(
+ mod_.llvm_module.get(),
+ ));
+ main_bcs.push((
+ main_bc,
+ CString::from_vec_unchecked(format!("main_{}\0", idx).into_bytes()),
+ ));
+ for (sub_idx, bitcode) in mod_.bitcode_modules.iter().enumerate() {
+ bitcode_mods.push((
+ llvm::MemoryBuffer::create_no_copy(bitcode, false),
+ CString::from_vec_unchecked(
+ format!("ptx_impl_{}_{}\0", idx, sub_idx).into_bytes(),
+ ),
+ ));
+ }
}
+ main_bcs.extend(bitcode_mods);
+ main_bcs
}
}
}
-impl ast::Type {
- fn param_pointer_to(self, space: ast::LdStateSpace) -> Result<Self, TranslateError> {
- Ok(match self {
- ast::Type::Scalar(t) => ast::Type::Pointer(ast::PointerType::Scalar(t), space),
- ast::Type::Vector(t, len) => {
- ast::Type::Pointer(ast::PointerType::Vector(t, len), space)
- }
- ast::Type::Array(t, _) => ast::Type::Pointer(ast::PointerType::Scalar(t), space),
- ast::Type::Pointer(ast::PointerType::Scalar(t), space) => {
- ast::Type::Pointer(ast::PointerType::Pointer(t, space), space)
- }
- ast::Type::Pointer(_, _) => return Err(error_unreachable()),
- })
- }
+pub struct Metadata<'input> {
+ sm_version: u32,
+ kernel_metadata: Vec<(Cow<'input, str>, Option<NonZeroU32>, Option<NonZeroU32>)>,
}
-impl Into<spirv::StorageClass> for ast::PointerStateSpace {
- fn into(self) -> spirv::StorageClass {
- match self {
- ast::PointerStateSpace::Const => spirv::StorageClass::UniformConstant,
- ast::PointerStateSpace::Global => spirv::StorageClass::CrossWorkgroup,
- ast::PointerStateSpace::Shared => spirv::StorageClass::Workgroup,
- ast::PointerStateSpace::Param => spirv::StorageClass::Function,
- ast::PointerStateSpace::Generic => spirv::StorageClass::Generic,
+impl<'input> Metadata<'input> {
+ pub fn empty() -> Self {
+ Self {
+ sm_version: 0,
+ kernel_metadata: Vec::new(),
}
}
-}
-impl From<ast::ScalarType> for SpirvType {
- fn from(t: ast::ScalarType) -> Self {
- SpirvType::Base(t.into())
+ pub fn join(self, other: &Self) -> Self {
+ let sm_version = self.sm_version.max(other.sm_version);
+ let mut kernel_metadata = self.kernel_metadata;
+ kernel_metadata.extend(other.kernel_metadata.iter().cloned());
+ Self {
+ sm_version,
+ kernel_metadata,
+ }
}
-}
-struct TypeWordMap {
- void: spirv::Word,
- complex: HashMap<SpirvType, spirv::Word>,
- constants: HashMap<(SpirvType, u64), spirv::Word>,
+ pub fn to_elf_section(&self) -> Vec<u8> {
+ let mut result = Vec::new();
+ let metadata = kernel_metadata::zluda::write(
+ self.sm_version,
+ self.kernel_metadata
+ .iter()
+ .map(|(name, min, max)| (&**name, *min, *max)),
+ );
+ emit::emit_section(
+ hip_common::kernel_metadata::zluda::SECTION_STR,
+ &metadata,
+ &mut result,
+ );
+ result
+ }
}
-// SPIR-V integer type definitions are signless, more below:
-// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_unsignedsigned_a_unsigned_versus_signed_integers
-// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_validation_rules_for_kernel_a_href_capability_capabilities_a
-#[derive(PartialEq, Eq, Hash, Clone, Copy)]
-enum SpirvScalarKey {
- B8,
- B16,
- B32,
- B64,
- F16,
- F32,
- F64,
- Pred,
- F16x2,
+pub(crate) struct TranslationModule<'input, P: ast::ArgParams> {
+ pub(crate) sm_version: u32,
+ pub(crate) compilation_mode: CompilationMode,
+ pub(crate) id_defs: IdNameMapBuilder<'input>,
+ pub(crate) ptx_impl_imports: BTreeMap<String, Rc<RefCell<ast::MethodDeclaration<'input, Id>>>>,
+ pub(crate) directives: Vec<TranslationDirective<'input, P>>,
}
-impl From<ast::ScalarType> for SpirvScalarKey {
- fn from(t: ast::ScalarType) -> Self {
- match t {
- ast::ScalarType::B8 | ast::ScalarType::U8 | ast::ScalarType::S8 => SpirvScalarKey::B8,
- ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => {
- SpirvScalarKey::B16
- }
- ast::ScalarType::B32 | ast::ScalarType::U32 | ast::ScalarType::S32 => {
- SpirvScalarKey::B32
- }
- ast::ScalarType::B64 | ast::ScalarType::U64 | ast::ScalarType::S64 => {
- SpirvScalarKey::B64
+impl<'input, P: ast::ArgParams> TranslationModule<'input, P> {
+ fn new(compilation_mode: CompilationMode) -> Self {
+ let id_defs = IdNameMapBuilder::new(IdGenerator::new());
+ let ptx_impl_imports = BTreeMap::new();
+ let directives = Vec::new();
+ Self {
+ compilation_mode,
+ sm_version: 0,
+ id_defs,
+ ptx_impl_imports,
+ directives,
+ }
+ }
+}
+
+// https://docs.nvidia.com/cuda/ptx-writers-guide-to-interoperability/index.html#system-calls
+// There is a bunch of functions like __assertfail and vprintf that must be
+// replaced by imports from ZLUDA PTX implementation library
+fn extract_builtin_functions<'input>(
+ mut module: TranslationModule<'input, NormalizedArgParams>,
+) -> TranslationModule<'input, NormalizedArgParams> {
+ for directive in module.directives.iter_mut() {
+ if let TranslationDirective::Method(TranslationMethod {
+ source_name: Some(name),
+ body: None,
+ is_kernel: false,
+ ..
+ }) = directive
+ {
+ if is_builtin_function_name(&*name) {
+ *name = Cow::Owned([ZLUDA_PTX_PREFIX, name].concat());
}
- ast::ScalarType::F16 => SpirvScalarKey::F16,
- ast::ScalarType::F32 => SpirvScalarKey::F32,
- ast::ScalarType::F64 => SpirvScalarKey::F64,
- ast::ScalarType::F16x2 => SpirvScalarKey::F16x2,
- ast::ScalarType::Pred => SpirvScalarKey::Pred,
}
}
+ module
}
-impl TypeWordMap {
- fn new(b: &mut dr::Builder) -> TypeWordMap {
- let void = b.type_void();
- TypeWordMap {
- void: void,
- complex: HashMap::<SpirvType, spirv::Word>::new(),
- constants: HashMap::new(),
- }
+fn is_builtin_function_name(name: &str) -> bool {
+ match name {
+ "__assertfail" | "malloc" | "free" | "vprintf" => true,
+ _ => false,
}
+}
- fn void(&self) -> spirv::Word {
- self.void
+// PTX linking rules are fairly convoluted. Here's my understanding:
+// * For normal data (.global, .const) and non-kernel functions (.func)
+// * Symbol occurences must be equivalent under following rules:
+// * For data, symbol occurences must be of the same size. Alignment is ignored
+// * For functions, symbol occurences are strictly type-checked.
+// Number, type and alignmnt of input and return parameters must all match
+// * There are 3 classes of directives:
+// * Declarations. Only valid on functions
+// .func foobar();
+// * Definitions (complete definitions). Either data or functions
+// .func foobar() { ret; }
+// .global foobar .u32;
+// Both .global and .const are *always* initialized. If no explicit
+// initializer is present, they are zero-initialized
+// * Incomplete definitions. Data definitions using incomplete type. Only
+// known incomplete type is an array with at least one zero dimension
+// .extern .global foobar .b8 [];
+// Incomplete definitions *must* have an .extern linking specifier
+// * There can be only one definition (normal or incomplete) of a symbol
+// in a module
+// * There can be multiple declarations of a symbol in a module
+// * Declarations must all be the same: same linking specifier,
+// same argument list, same return list
+// * Data (.global and .const) is alwas accessible by cuModuleGetGlobal(...),
+// no matter the linking specifier. So this definition:
+// .global .u32 foobar1[1] = {1}
+// is not visible to othe modules during linking, but is accessible
+// by CUDA runtime after linking
+// * Non-kernel functions are never accessible by cuModuleGetGlobal(...)
+// * There are are four linking specifiers:
+// * (empty): static linking, a symbol is only visible inside a module
+// * For functions; separate functions with the same name in multiple, separate modules behave as expected
+// * For data, compiler selects the first symbol with the given name for access from cuModuleGetGlobal(...)
+// * This is only allowed linking specifier for local globals (globals defined inside function body),
+// which are not visible through cuModuleGetGlobal(...)
+// * .extern: means that symbol is completely-defined strictly in another module.
+// If the same symbol is completely-defined in the same module it's an error
+// It's legal to not resolve the declaration if it's unused
+// .extern is legal for:
+// * declarations and incomplete definitions
+// * normal definitions if all are true:
+// * it's a data definition
+// * it's a non-linking compilation
+// * initializer is not present
+// * .visible: symbol is strong (overrides .weak) and globally visible.
+// Multiple .visible symbol occurences during linking compilation are illegal
+// * .weak: symbol is weak and globally visible.
+// If there's no strong symbol occurence, first weak symbol occurence gets selected
+// * .common: symbol is strong (overrides .weak) and globally visible with some additional rules:
+// * applies only to .global
+// * selects the first occurence from the largest symbol occurences
+// * explicit initializer is only allowed on symbol occurences with the largest size
+fn resolve_linking<'a, 'input>(
+ ast_modules: &'a [ast::Module<'input>],
+ is_raytracing: bool,
+) -> Result<ResolvedLinking<'input>, TranslateError> {
+ let mut resolver = LinkingResolver::new(is_raytracing);
+ for ast_module in ast_modules {
+ resolver.start_module()?;
+ for (index, directive) in ast_module.directives.iter().enumerate() {
+ match directive {
+ ast::Directive::Variable(linking, multivar) => {
+ resolver.on_data(
+ index,
+ *linking,
+ Cow::Borrowed(multivar.variable.name),
+ multivar.variable.state_space,
+ &multivar.variable.type_,
+ )?;
+ }
+ ast::Directive::Method(linking, method) => {
+ resolver.on_function(
+ index,
+ method.body.is_some(),
+ *linking,
+ &method.func_directive,
+ )?;
+ }
+ }
+ }
}
+ resolver.close()
+}
- fn get_or_add_scalar(&mut self, b: &mut dr::Builder, t: ast::ScalarType) -> spirv::Word {
- let key: SpirvScalarKey = t.into();
- self.get_or_add_spirv_scalar(b, key)
+struct LinkingResolver<'a, 'input> {
+ explicit_globals: FxHashMap<Cow<'input, str>, SymbolState<'a, 'input>>,
+ implicit_globals: FxHashMap<Cow<'input, str>, (usize, usize)>,
+ local_definitions: LocalDirectives<'a, 'input>,
+ module_index: usize,
+ is_raytracing: bool,
+}
+
+impl<'a, 'input> LinkingResolver<'a, 'input> {
+ fn new(is_raytracing: bool) -> Self {
+ Self {
+ explicit_globals: FxHashMap::default(),
+ implicit_globals: FxHashMap::default(),
+ local_definitions: LocalDirectives::new(),
+ module_index: 0,
+ is_raytracing,
+ }
}
- fn get_or_add_spirv_scalar(&mut self, b: &mut dr::Builder, key: SpirvScalarKey) -> spirv::Word {
- *self
- .complex
- .entry(SpirvType::Base(key))
- .or_insert_with(|| match key {
- SpirvScalarKey::B8 => b.type_int(8, 0),
- SpirvScalarKey::B16 => b.type_int(16, 0),
- SpirvScalarKey::B32 => b.type_int(32, 0),
- SpirvScalarKey::B64 => b.type_int(64, 0),
- SpirvScalarKey::F16 => b.type_float(16),
- SpirvScalarKey::F32 => b.type_float(32),
- SpirvScalarKey::F64 => b.type_float(64),
- SpirvScalarKey::Pred => b.type_bool(),
- SpirvScalarKey::F16x2 => todo!(),
- })
+ fn start_module(&mut self) -> Result<(), TranslateError> {
+ self.module_index += 1;
+ mem::replace(&mut self.local_definitions, LocalDirectives::new()).check()
}
- fn get_or_add(&mut self, b: &mut dr::Builder, t: SpirvType) -> spirv::Word {
- match t {
- SpirvType::Base(key) => self.get_or_add_spirv_scalar(b, key),
- SpirvType::Pointer(ref typ, storage) => {
- let base = self.get_or_add(b, *typ.clone());
- *self
- .complex
- .entry(t)
- .or_insert_with(|| b.type_pointer(None, storage, base))
- }
- SpirvType::Vector(typ, len) => {
- let base = self.get_or_add_spirv_scalar(b, typ);
- *self
- .complex
- .entry(t)
- .or_insert_with(|| b.type_vector(base, len as u32))
- }
- SpirvType::Array(typ, array_dimensions) => {
- let u32_type = self.get_or_add_scalar(b, ast::ScalarType::U32);
- let (base_type, length) = match &*array_dimensions {
- &[len] => {
- let base = self.get_or_add_spirv_scalar(b, typ);
- let len_const = b.constant_u32(u32_type, None, len);
- (base, len_const)
- }
- array_dimensions => {
- let base = self
- .get_or_add(b, SpirvType::Array(typ, array_dimensions[1..].to_vec()));
- let len_const = b.constant_u32(u32_type, None, array_dimensions[0]);
- (base, len_const)
- }
- };
- *self
- .complex
- .entry(SpirvType::Array(typ, array_dimensions))
- .or_insert_with(|| b.type_array(base_type, length))
- }
- SpirvType::Func(ref out_params, ref in_params) => {
- let out_t = match out_params {
- Some(p) => self.get_or_add(b, *p.clone()),
- None => self.void(),
- };
- let in_t = in_params
- .iter()
- .map(|t| self.get_or_add(b, t.clone()))
- .collect::<Vec<_>>();
- *self
- .complex
- .entry(t)
- .or_insert_with(|| b.type_function(out_t, in_t))
- }
- SpirvType::Struct(ref underlying) => {
- let underlying_ids = underlying
- .iter()
- .map(|t| self.get_or_add_spirv_scalar(b, *t))
- .collect::<Vec<_>>();
- *self
- .complex
- .entry(t)
- .or_insert_with(|| b.type_struct(underlying_ids))
- }
- }
+ fn on_data(
+ &mut self,
+ location: usize,
+ linking: ast::LinkingDirective,
+ name: Cow<'input, str>,
+ space: ast::StateSpace,
+ type_: &ast::Type,
+ ) -> Result<(), TranslateError> {
+ if linking == ast::LinkingDirective::Common && space != ast::StateSpace::Global {
+ return Err(TranslateError::SymbolRedefinition);
+ }
+ self.local_definitions.on_data(name.clone())?;
+ let symbol = GlobalSymbol::Data {
+ size: type_.layout().size(),
+ space,
+ type_: type_.clone(),
+ };
+ self.update_global_symbol(
+ location,
+ linking,
+ name,
+ symbol,
+ space != ast::StateSpace::Shared,
+ )
}
- fn get_or_add_fn(
+ fn on_function(
&mut self,
- b: &mut dr::Builder,
- in_params: impl ExactSizeIterator<Item = SpirvType>,
- mut out_params: impl ExactSizeIterator<Item = SpirvType>,
- ) -> (spirv::Word, spirv::Word) {
- let (out_args, out_spirv_type) = if out_params.len() == 0 {
- (None, self.void())
- } else if out_params.len() == 1 {
- let arg_as_key = out_params.next().unwrap();
- (
- Some(Box::new(arg_as_key.clone())),
- self.get_or_add(b, arg_as_key),
- )
+ location: usize,
+ is_definition: bool,
+ linking: ast::LinkingDirective,
+ decl: &'a ast::MethodDeclaration<'input, &'input str>,
+ ) -> Result<(), TranslateError> {
+ // Common is legal only on .global
+ if linking == ast::LinkingDirective::Common {
+ return Err(TranslateError::SymbolRedefinition);
+ }
+ if is_definition {
+ self.local_definitions
+ .on_function_definition(linking, decl)?;
} else {
- todo!()
+ self.local_definitions
+ .on_function_declaration(linking, decl)?;
+ }
+ let symbol = GlobalSymbol::Method {
+ kernel: decl.name.is_kernel(),
+ declaration: !is_definition,
+ return_arguments: &decl.return_arguments,
+ input_arguments: &decl.input_arguments,
};
- (
- out_spirv_type,
- self.get_or_add(b, SpirvType::Func(out_args, in_params.collect::<Vec<_>>())),
+ self.update_global_symbol(
+ location,
+ linking,
+ Cow::Borrowed(decl.name()),
+ symbol,
+ decl.name.is_kernel(),
)
}
- fn get_or_add_constant(
+ fn update_global_symbol(
&mut self,
- b: &mut dr::Builder,
- typ: &ast::Type,
- init: &[u8],
- ) -> Result<spirv::Word, TranslateError> {
- Ok(match typ {
- ast::Type::Scalar(t) => match t {
- ast::ScalarType::B8 | ast::ScalarType::U8 | ast::ScalarType::S8 => self
- .get_or_add_constant_single::<u8, _, _>(
- b,
- *t,
- init,
- |v| v as u64,
- |b, result_type, v| b.constant_u32(result_type, None, v as u32),
- ),
- ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => self
- .get_or_add_constant_single::<u16, _, _>(
- b,
- *t,
- init,
- |v| v as u64,
- |b, result_type, v| b.constant_u32(result_type, None, v as u32),
- ),
- ast::ScalarType::B32 | ast::ScalarType::U32 | ast::ScalarType::S32 => self
- .get_or_add_constant_single::<u32, _, _>(
- b,
- *t,
- init,
- |v| v as u64,
- |b, result_type, v| b.constant_u32(result_type, None, v),
- ),
- ast::ScalarType::B64 | ast::ScalarType::U64 | ast::ScalarType::S64 => self
- .get_or_add_constant_single::<u64, _, _>(
- b,
- *t,
- init,
- |v| v,
- |b, result_type, v| b.constant_u64(result_type, None, v),
- ),
- ast::ScalarType::F16 => self.get_or_add_constant_single::<f16, _, _>(
- b,
- *t,
- init,
- |v| unsafe { mem::transmute::<_, u16>(v) } as u64,
- |b, result_type, v| b.constant_f32(result_type, None, v.to_f32()),
- ),
- ast::ScalarType::F32 => self.get_or_add_constant_single::<f32, _, _>(
- b,
- *t,
- init,
- |v| unsafe { mem::transmute::<_, u32>(v) } as u64,
- |b, result_type, v| b.constant_f32(result_type, None, v),
- ),
- ast::ScalarType::F64 => self.get_or_add_constant_single::<f64, _, _>(
- b,
- *t,
- init,
- |v| unsafe { mem::transmute::<_, u64>(v) },
- |b, result_type, v| b.constant_f64(result_type, None, v),
- ),
- ast::ScalarType::F16x2 => return Err(TranslateError::Todo),
- ast::ScalarType::Pred => self.get_or_add_constant_single::<u8, _, _>(
- b,
- *t,
- init,
- |v| v as u64,
- |b, result_type, v| {
- if v == 0 {
- b.constant_false(result_type, None)
+ new_location: usize,
+ new_linking: ast::LinkingDirective,
+ name: Cow<'input, str>,
+ new_symbol: GlobalSymbol<'a, 'input>,
+ implicit_global: bool,
+ ) -> Result<(), TranslateError> {
+ if new_linking == ast::LinkingDirective::None {
+ if implicit_global {
+ let will_be_shadowed = if let Some(global) = self.explicit_globals.get(&name) {
+ match global.symbol {
+ GlobalSymbol::Data { .. }
+ | GlobalSymbol::Method {
+ declaration: false, ..
+ } => true,
+ GlobalSymbol::Method {
+ declaration: true, ..
+ } => false,
+ }
+ } else {
+ false
+ };
+ if !will_be_shadowed {
+ if let hash_map::Entry::Vacant(entry) = self.implicit_globals.entry(name) {
+ entry.insert((self.module_index, new_location));
+ }
+ }
+ }
+ return Ok(());
+ }
+ let is_function_declaration = matches!(
+ new_symbol,
+ GlobalSymbol::Method {
+ declaration: true,
+ ..
+ }
+ );
+ if !is_function_declaration {
+ self.implicit_globals.remove(&name);
+ }
+ match self.explicit_globals.entry(name) {
+ hash_map::Entry::Occupied(mut entry) => {
+ let SymbolState {
+ module,
+ location,
+ linking,
+ symbol,
+ } = entry.get_mut();
+ let override_global = match (new_linking, *linking) {
+ (ast::LinkingDirective::None, _) | (_, ast::LinkingDirective::None) => {
+ return Err(TranslateError::unreachable())
+ }
+ (ast::LinkingDirective::Extern, _) => false,
+ (ast::LinkingDirective::Common, ast::LinkingDirective::Visible)
+ | (ast::LinkingDirective::Visible, ast::LinkingDirective::Common) => {
+ return Err(TranslateError::SymbolRedefinition);
+ }
+ (ast::LinkingDirective::Visible, ast::LinkingDirective::Visible) => {
+ // If it is in another module
+ if *module != self.module_index {
+ return Err(TranslateError::SymbolRedefinition);
} else {
- b.constant_true(result_type, None)
+ !is_function_declaration
}
+ }
+ (
+ ast::LinkingDirective::Visible | ast::LinkingDirective::Common,
+ ast::LinkingDirective::Weak | ast::LinkingDirective::Extern,
+ ) => true,
+ (ast::LinkingDirective::Common, ast::LinkingDirective::Common) => {
+ if let (
+ GlobalSymbol::Data {
+ size,
+ space: ast::StateSpace::Global,
+ type_,
+ },
+ GlobalSymbol::Data {
+ size: new_size,
+ space: ast::StateSpace::Global,
+ type_: new_type,
+ },
+ ) = (symbol, new_symbol)
+ {
+ if new_size > *size {
+ *type_ = new_type;
+ *size = new_size;
+ *module = self.module_index;
+ *location = new_location;
+ *linking = new_linking;
+ }
+ return Ok(());
+ } else {
+ return Err(TranslateError::SymbolRedefinition);
+ }
+ }
+ (ast::LinkingDirective::Weak, ast::LinkingDirective::Extern) => true,
+ (ast::LinkingDirective::Weak, ast::LinkingDirective::Visible)
+ | (ast::LinkingDirective::Weak, ast::LinkingDirective::Common) => false,
+ (ast::LinkingDirective::Weak, ast::LinkingDirective::Weak) => match symbol {
+ GlobalSymbol::Method {
+ declaration: true, ..
+ } => {
+ if let GlobalSymbol::Method {
+ declaration: false, ..
+ } = new_symbol
+ {
+ true
+ } else {
+ false
+ }
+ }
+ _ => false,
},
- ),
- },
- ast::Type::Vector(typ, len) => {
- let result_type =
- self.get_or_add(b, SpirvType::Vector(SpirvScalarKey::from(*typ), *len));
- let size_of_t = typ.size_of();
- let components = (0..*len)
- .map(|x| {
- self.get_or_add_constant(
- b,
- &ast::Type::Scalar(*typ),
- &init[((size_of_t as usize) * (x as usize))..],
- )
- })
- .collect::<Result<Vec<_>, _>>()?;
- b.constant_composite(result_type, None, &components)
- }
- ast::Type::Array(typ, dims) => match dims.as_slice() {
- [] => return Err(error_unreachable()),
- [dim] => {
- let result_type = self
- .get_or_add(b, SpirvType::Array(SpirvScalarKey::from(*typ), vec![*dim]));
- let size_of_t = typ.size_of();
- let components = (0..*dim)
- .map(|x| {
- self.get_or_add_constant(
- b,
- &ast::Type::Scalar(*typ),
- &init[((size_of_t as usize) * (x as usize))..],
- )
- })
- .collect::<Result<Vec<_>, _>>()?;
- b.constant_composite(result_type, None, &components)
+ };
+ if !new_symbol.is_compatible(symbol) {
+ return Err(TranslateError::SymbolRedefinition);
}
- [first_dim, rest @ ..] => {
- let result_type = self.get_or_add(
- b,
- SpirvType::Array(SpirvScalarKey::from(*typ), rest.to_vec()),
- );
- let size_of_t = rest
- .iter()
- .fold(typ.size_of() as u32, |x, y| (x as u32) * (*y));
- let components = (0..*first_dim)
- .map(|x| {
- self.get_or_add_constant(
- b,
- &ast::Type::Array(*typ, rest.to_vec()),
- &init[((size_of_t as usize) * (x as usize))..],
- )
- })
- .collect::<Result<Vec<_>, _>>()?;
- b.constant_composite(result_type, None, &components)
+ if override_global {
+ *symbol = new_symbol;
+ *module = self.module_index;
+ *location = new_location;
+ *linking = new_linking;
}
- },
- ast::Type::Pointer(typ, state_space) => {
- let base_t = typ.clone().into();
- let base = self.get_or_add_constant(b, &base_t, &[])?;
- let result_type = self.get_or_add(
- b,
- SpirvType::Pointer(
- Box::new(SpirvType::from(base_t)),
- (*state_space).to_spirv(),
- ),
- );
- b.variable(result_type, None, (*state_space).to_spirv(), Some(base))
}
+ hash_map::Entry::Vacant(entry) => {
+ entry.insert(SymbolState {
+ module: self.module_index,
+ location: new_location,
+ linking: new_linking,
+ symbol: new_symbol,
+ });
+ }
+ }
+ Ok(())
+ }
+
+ fn close(self) -> Result<ResolvedLinking<'input>, TranslateError> {
+ self.local_definitions.check()?;
+ for (_, state) in self.explicit_globals.iter() {
+ if state.linking == ast::LinkingDirective::Extern {
+ match state.symbol {
+ GlobalSymbol::Data {
+ space: ast::StateSpace::Shared,
+ ..
+ }
+ | GlobalSymbol::Method { .. } => {}
+ GlobalSymbol::Data { size, .. } if size != 0 && self.module_index == 1 => {}
+ _ => return Err(TranslateError::SymbolRedefinition),
+ }
+ } else if !self.is_raytracing {
+ if matches!(
+ state.symbol,
+ GlobalSymbol::Method {
+ declaration: true,
+ ..
+ }
+ ) {
+ return Err(TranslateError::SymbolRedefinition);
+ }
+ }
+ }
+ let explicit_globals = self
+ .explicit_globals
+ .into_iter()
+ .map(|(name, symbol)| {
+ let type_ = match symbol.symbol {
+ GlobalSymbol::Data { type_, .. } => Some(type_),
+ GlobalSymbol::Method { .. } => None,
+ };
+ (name, (symbol.module, symbol.location, type_))
+ })
+ .collect();
+ Ok(ResolvedLinking {
+ explicit_globals,
+ implicit_globals: self.implicit_globals,
})
}
+}
- fn get_or_add_constant_single<
- T: Copy,
- CastAsU64: FnOnce(T) -> u64,
- InsertConstant: FnOnce(&mut dr::Builder, spirv::Word, T) -> spirv::Word,
- >(
+struct ResolvedLinking<'input> {
+ explicit_globals: FxHashMap<Cow<'input, str>, (usize, usize, Option<ast::Type>)>,
+ implicit_globals: FxHashMap<Cow<'input, str>, (usize, usize)>,
+}
+
+impl<'input> ResolvedLinking<'input> {
+ fn get_adjustment(
&mut self,
- b: &mut dr::Builder,
- key: ast::ScalarType,
- init: &[u8],
- cast: CastAsU64,
- f: InsertConstant,
- ) -> spirv::Word {
- let value = unsafe { *(init.as_ptr() as *const T) };
- let value_64 = cast(value);
- let ht_key = (SpirvType::Base(SpirvScalarKey::from(key)), value_64);
- match self.constants.get(&ht_key) {
- Some(value) => *value,
- None => {
- let spirv_type = self.get_or_add_scalar(b, key);
- let result = f(b, spirv_type, value);
- self.constants.insert(ht_key, result);
- result
+ module: usize,
+ directive: usize,
+ name: Cow<'input, str>,
+ linking: ast::LinkingDirective,
+ explicit_initializer: bool,
+ ) -> Result<VisibilityAdjustment, TranslateError> {
+ if linking == ast::LinkingDirective::None {
+ if self.implicit_globals.get(&name).copied() == Some((module, directive)) {
+ Ok(VisibilityAdjustment::Global)
+ } else {
+ Ok(VisibilityAdjustment::Module)
+ }
+ } else {
+ if let Some((global_module, global_directive, type_)) = self.explicit_globals.get(&name)
+ {
+ if module == *global_module && directive == *global_directive {
+ Ok(VisibilityAdjustment::Global)
+ } else {
+ match linking {
+ ast::LinkingDirective::Extern
+ | ast::LinkingDirective::Weak
+ // Visible is possible and valid in case of function same-module declarations
+ | ast::LinkingDirective::Visible => {
+ Ok(VisibilityAdjustment::GlobalDeclaration(type_.clone()))
+ }
+ ast::LinkingDirective::Common => {
+ if explicit_initializer {
+ return Err(TranslateError::SymbolRedefinition);
+ }
+ Ok(VisibilityAdjustment::GlobalDeclaration(type_.clone()))
+ }
+ ast::LinkingDirective::None => {
+ Err(TranslateError::unreachable())
+ }
+ }
+ }
+ } else {
+ Err(TranslateError::unreachable())
}
}
}
}
-pub struct Module {
- pub spirv: dr::Module,
- pub kernel_info: HashMap<String, KernelInfo>,
- pub should_link_ptx_impl: Option<&'static [u8]>,
- pub build_options: CString,
+enum VisibilityAdjustment {
+ Global,
+ Module,
+ GlobalDeclaration(Option<ast::Type>),
}
-impl Module {
- pub fn assemble(&self) -> Vec<u32> {
- self.spirv.assemble()
+
+struct LocalDirectives<'a, 'input> {
+ directives: FxHashMap<Cow<'input, str>, LocalSymbol<'a, 'input>>,
+}
+
+impl<'a, 'input> LocalDirectives<'a, 'input> {
+ fn new() -> Self {
+ Self {
+ directives: FxHashMap::default(),
+ }
+ }
+
+ fn on_data(&mut self, name: Cow<'input, str>) -> Result<(), TranslateError> {
+ match self.directives.entry(name) {
+ hash_map::Entry::Occupied(_) => return Err(TranslateError::SymbolRedefinition),
+ hash_map::Entry::Vacant(entry) => {
+ entry.insert(LocalSymbol::Data);
+ }
+ }
+ Ok(())
+ }
+
+ fn on_function_definition(
+ &mut self,
+ decl_linking: ast::LinkingDirective,
+ decl: &'a ast::MethodDeclaration<'input, &'input str>,
+ ) -> Result<(), TranslateError> {
+ match self.directives.entry(Cow::Borrowed(decl.name())) {
+ hash_map::Entry::Occupied(mut entry) => match entry.get_mut() {
+ LocalSymbol::Data
+ | LocalSymbol::Function {
+ has_definition: true,
+ ..
+ } => return Err(TranslateError::SymbolRedefinition),
+ LocalSymbol::Function {
+ kernel,
+ ref mut has_definition,
+ return_arguments,
+ input_arguments,
+ linking,
+ } => {
+ if *kernel == decl.name.is_kernel() && decl_linking != *linking
+ || !is_variable_list_equivalent(&*decl.return_arguments, return_arguments)
+ || !is_variable_list_equivalent(&*decl.input_arguments, *input_arguments)
+ {
+ return Err(TranslateError::SymbolRedefinition);
+ }
+ *has_definition = true;
+ }
+ },
+ hash_map::Entry::Vacant(entry) => {
+ entry.insert(LocalSymbol::Function {
+ kernel: decl.name.is_kernel(),
+ has_definition: true,
+ linking: decl_linking,
+ return_arguments: &decl.return_arguments,
+ input_arguments: &decl.input_arguments,
+ });
+ }
+ }
+ Ok(())
+ }
+
+ fn on_function_declaration(
+ &mut self,
+ decl_linking: ast::LinkingDirective,
+ decl: &'a ast::MethodDeclaration<'input, &'input str>,
+ ) -> Result<(), TranslateError> {
+ match self.directives.entry(Cow::Borrowed(decl.name())) {
+ hash_map::Entry::Occupied(entry) => match entry.get() {
+ LocalSymbol::Data => return Err(TranslateError::SymbolRedefinition),
+ LocalSymbol::Function {
+ kernel,
+ has_definition: _,
+ linking,
+ return_arguments,
+ input_arguments,
+ } => {
+ if *kernel == decl.name.is_kernel() && *linking != decl_linking
+ || !is_variable_list_equivalent(&*decl.return_arguments, return_arguments)
+ || !is_variable_list_equivalent(&*decl.input_arguments, *input_arguments)
+ {
+ return Err(TranslateError::SymbolRedefinition);
+ }
+ }
+ },
+ hash_map::Entry::Vacant(entry) => {
+ entry.insert(LocalSymbol::Function {
+ kernel: decl.name.is_kernel(),
+ has_definition: false,
+ linking: decl_linking,
+ return_arguments: &decl.return_arguments,
+ input_arguments: &decl.input_arguments,
+ });
+ }
+ }
+ Ok(())
+ }
+
+ // At a first glance this looks incomplete, but:
+ // * Unresolved declarations at the global level are checked later,
+ // when we have symbols from all the modules
+ // * We don't check unresolved data with incomplete definitions, because
+ // they are invalid anyway, if data is incomplete it must be extern
+ // and hence checked at the global level
+ fn check(self) -> Result<(), TranslateError> {
+ for (_, symbol) in self.directives {
+ match symbol {
+ LocalSymbol::Data => {}
+ LocalSymbol::Function {
+ has_definition,
+ linking,
+ ..
+ } => {
+ if linking == ast::LinkingDirective::None && !has_definition {
+ return Err(TranslateError::SymbolRedefinition);
+ }
+ }
+ }
+ }
+ Ok(())
}
}
-pub struct KernelInfo {
- pub arguments_sizes: Vec<usize>,
- pub uses_shared_mem: bool,
+// Used to type-check declarations inside a module
+enum LocalSymbol<'a, 'input> {
+ Data,
+ Function {
+ kernel: bool,
+ has_definition: bool,
+ linking: ast::LinkingDirective,
+ return_arguments: &'a [ast::VariableDeclaration<&'input str>],
+ input_arguments: &'a [ast::VariableDeclaration<&'input str>],
+ },
}
-pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result<Module, TranslateError> {
- let mut id_defs = GlobalStringIdResolver::new(1);
- let mut ptx_impl_imports = HashMap::new();
- let directives = ast
- .directives
- .into_iter()
- .filter_map(|directive| {
- translate_directive(&mut id_defs, &mut ptx_impl_imports, directive).transpose()
- })
- .collect::<Result<Vec<_>, _>>()?;
- let must_link_ptx_impl = ptx_impl_imports.len() > 0;
- let directives = ptx_impl_imports
- .into_iter()
- .map(|(_, v)| v)
- .chain(directives.into_iter())
- .collect::<Vec<_>>();
- let mut builder = dr::Builder::new();
- builder.reserve_ids(id_defs.current_id());
- let call_map = get_call_map(&directives);
- let mut directives = convert_dynamic_shared_memory_usage(directives, &mut || builder.id());
- normalize_variable_decls(&mut directives);
- let denorm_information = compute_denorm_information(&directives);
- // https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_logicallayout_a_logical_layout_of_a_module
- builder.set_version(1, 3);
- emit_capabilities(&mut builder);
- emit_extensions(&mut builder);
- let opencl_id = emit_opencl_import(&mut builder);
- emit_memory_model(&mut builder);
- let mut map = TypeWordMap::new(&mut builder);
- emit_builtins(&mut builder, &mut map, &id_defs);
- let mut kernel_info = HashMap::new();
- let build_options = emit_denorm_build_string(&call_map, &denorm_information);
- emit_directives(
- &mut builder,
- &mut map,
- &id_defs,
- opencl_id,
- &denorm_information,
- &call_map,
- directives,
- &mut kernel_info,
- )?;
- let spirv = builder.module();
- Ok(Module {
- spirv,
- kernel_info,
- should_link_ptx_impl: if must_link_ptx_impl {
- Some(ZLUDA_PTX_IMPL)
- } else {
- None
- },
- build_options,
- })
+struct SymbolState<'a, 'input> {
+ module: usize,
+ location: usize,
+ linking: ast::LinkingDirective,
+ symbol: GlobalSymbol<'a, 'input>,
}
-// TODO: remove this once we have perf-function support for denorms
-fn emit_denorm_build_string(
- call_map: &HashMap<&str, HashSet<u32>>,
- denorm_information: &HashMap<MethodName, HashMap<u8, (spirv::FPDenormMode, isize)>>,
-) -> CString {
- let denorm_counts = denorm_information
- .iter()
- .map(|(method, meth_denorm)| {
- let f16_count = meth_denorm
- .get(&(mem::size_of::<f16>() as u8))
- .unwrap_or(&(spirv::FPDenormMode::FlushToZero, 0))
- .1;
- let f32_count = meth_denorm
- .get(&(mem::size_of::<f32>() as u8))
- .unwrap_or(&(spirv::FPDenormMode::FlushToZero, 0))
- .1;
- (method, (f16_count + f32_count))
- })
- .collect::<HashMap<_, _>>();
- let mut flush_over_preserve = 0;
- for (kernel, children) in call_map {
- flush_over_preserve += *denorm_counts.get(&MethodName::Kernel(kernel)).unwrap_or(&0);
- for child_fn in children {
- flush_over_preserve += *denorm_counts
- .get(&MethodName::Func(*child_fn))
- .unwrap_or(&0);
- }
- }
- if flush_over_preserve > 0 {
- CString::new("-cl-denorms-are-zero").unwrap()
- } else {
- CString::default()
+enum GlobalSymbol<'a, 'input> {
+ Data {
+ size: usize,
+ space: ast::StateSpace,
+ type_: ast::Type,
+ },
+ Method {
+ kernel: bool,
+ declaration: bool,
+ return_arguments: &'a [ast::VariableDeclaration<&'input str>],
+ input_arguments: &'a [ast::VariableDeclaration<&'input str>],
+ },
+}
+
+impl<'a, 'input> GlobalSymbol<'a, 'input> {
+ fn is_compatible(&self, old_symbol: &GlobalSymbol<'a, 'input>) -> bool {
+ match (self, old_symbol) {
+ (
+ GlobalSymbol::Data {
+ size,
+ space,
+ type_: _,
+ },
+ GlobalSymbol::Data {
+ size: old_size,
+ space: old_space,
+ type_: _,
+ },
+ ) => (*size == *old_size || *old_size == 0 || *size == 0) && (space == old_space),
+ (
+ GlobalSymbol::Method {
+ kernel,
+ declaration: _,
+ return_arguments,
+ input_arguments,
+ },
+ GlobalSymbol::Method {
+ kernel: old_kernel,
+ declaration: _,
+ return_arguments: old_return_arguments,
+ input_arguments: old_input_arguments,
+ },
+ ) => {
+ *kernel == *old_kernel
+ && is_variable_list_equivalent(return_arguments, old_return_arguments)
+ && is_variable_list_equivalent(input_arguments, old_input_arguments)
+ }
+ _ => false,
+ }
}
}
-fn emit_directives<'input>(
- builder: &mut dr::Builder,
- map: &mut TypeWordMap,
- id_defs: &GlobalStringIdResolver<'input>,
- opencl_id: spirv::Word,
- denorm_information: &HashMap<MethodName<'input>, HashMap<u8, (spirv::FPDenormMode, isize)>>,
- call_map: &HashMap<&'input str, HashSet<spirv::Word>>,
- directives: Vec<Directive>,
- kernel_info: &mut HashMap<String, KernelInfo>,
-) -> Result<(), TranslateError> {
- let empty_body = Vec::new();
- for d in directives.iter() {
- match d {
- Directive::Variable(var) => {
- emit_variable(builder, map, &var)?;
- }
- Directive::Method(f) => {
- let f_body = match &f.body {
- Some(f) => f,
- None => {
- if f.import_as.is_some() {
- &empty_body
- } else {
- continue;
- }
- }
- };
- for var in f.globals.iter() {
- emit_variable(builder, map, var)?;
- }
- emit_function_header(
- builder,
- map,
- &id_defs,
- &f.globals,
- &f.spirv_decl,
- &denorm_information,
- call_map,
- &directives,
- kernel_info,
- )?;
- emit_function_body_ops(builder, map, opencl_id, &f_body)?;
- builder.end_function()?;
- if let (ast::MethodDecl::Func(_, fn_id, _), Some(name)) =
- (&f.func_decl, &f.import_as)
- {
- builder.decorate(
- *fn_id,
- spirv::Decoration::LinkageAttributes,
- &[
- dr::Operand::LiteralString(name.clone()),
- dr::Operand::LinkageType(spirv::LinkageType::Import),
- ],
- );
+fn is_variable_list_equivalent<'a>(
+ left: &[ast::VariableDeclaration<&'a str>],
+ right: &[ast::VariableDeclaration<&'a str>],
+) -> bool {
+ fn equivalent_arguments<'a>(
+ ast::VariableDeclaration {
+ type_: l_type_,
+ state_space: l_state_space,
+ align: _,
+ name: _,
+ }: &ast::VariableDeclaration<&'a str>,
+ ast::VariableDeclaration {
+ type_: r_type_,
+ state_space: r_state_space,
+ align: _,
+ name: _,
+ }: &ast::VariableDeclaration<&'a str>,
+ ) -> bool {
+ l_type_ == r_type_ && l_state_space == r_state_space
+ }
+ let mut left = left.iter();
+ let mut right = right.iter();
+ loop {
+ match (left.next(), right.next()) {
+ (None, None) => break,
+ (None, Some(_)) => return false,
+ (Some(_), None) => return false,
+ (Some(left), Some(right)) => {
+ if !equivalent_arguments(left, right) {
+ return false;
}
}
}
}
- Ok(())
+ true
}
-fn get_call_map<'input>(
- module: &[Directive<'input>],
-) -> HashMap<&'input str, HashSet<spirv::Word>> {
- let mut directly_called_by = HashMap::new();
- for directive in module {
- match directive {
- Directive::Method(Function {
- func_decl,
- body: Some(statements),
- ..
- }) => {
- let call_key = MethodName::new(&func_decl);
- if let hash_map::Entry::Vacant(entry) = directly_called_by.entry(call_key) {
- entry.insert(Vec::new());
+// This is actually three transformations in one:
+// * Merge modules (linking-aware)
+// * Replace all string identifiers with numeric identifiers
+// * Convert predicates to branches
+// After those two conversions we can start inserting and removing additional
+// instructions freely
+fn link_and_normalize_modules<'input>(
+ asts: Vec<ast::Module<'input>>,
+ module: TranslationModule<'input, NormalizedArgParams>,
+ mut linking_resolver: ResolvedLinking<'input>,
+) -> Result<
+ (
+ TranslationModule<'input, NormalizedArgParams>,
+ FxHashMap<
+ Id,
+ (
+ Vec<ast::VariableDeclaration<Id>>,
+ Vec<ast::VariableDeclaration<Id>>,
+ ),
+ >,
+ ),
+ TranslateError,
+> {
+ let mut functions = get_existing_methods(&*module.directives);
+ let mut id_defs = module.id_defs;
+ let ptx_impl_imports = module.ptx_impl_imports;
+ let mut directives = module.directives;
+ let mut sm_version = 0;
+ let mut string_resolver = StringIdResolver::new(&mut id_defs, &directives)?;
+ for (mut module_index, ast) in asts.into_iter().enumerate() {
+ module_index += 1;
+ sm_version = sm_version.max(ast.sm_version);
+ let mut module_scope = string_resolver.start_module();
+ for (directive_index, directive) in ast.directives.into_iter().enumerate() {
+ match directive {
+ ast::Directive::Method(linking_directive, method) => {
+ directives.push(TranslationDirective::Method(normalize_method(
+ &mut linking_resolver,
+ &mut functions,
+ (module_index, directive_index),
+ &mut module_scope,
+ linking_directive,
+ method,
+ )?));
}
- for statement in statements {
- match statement {
- Statement::Call(call) => {
- multi_hash_map_append(&mut directly_called_by, call_key, call.func);
- }
- _ => {}
- }
+ ast::Directive::Variable(mut linking_directive, vars) => {
+ expand_multivariable2(
+ &mut module_scope,
+ iter::once(vars),
+ |scope, align, type_, space, name, mut initializer| {
+ let linking_adjustment = linking_resolver.get_adjustment(
+ module_index,
+ directive_index,
+ name.clone(),
+ linking_directive,
+ initializer.is_some(),
+ )?;
+ let (has_global_name, has_body, type_override) =
+ match linking_adjustment {
+ VisibilityAdjustment::Global => (true, true, None),
+ VisibilityAdjustment::Module => (false, true, None),
+ VisibilityAdjustment::GlobalDeclaration(type_override) => {
+ (true, false, type_override)
+ }
+ };
+ let type_ = type_override.unwrap_or_else(|| type_.clone());
+ let compiled_name = if has_global_name {
+ Some(name.clone())
+ } else {
+ None
+ };
+ if !has_body {
+ linking_directive = ast::LinkingDirective::Extern;
+ initializer = None;
+ }
+ directives.push(TranslationDirective::Variable(
+ linking_directive,
+ compiled_name,
+ scope.add_or_get_module_variable(
+ name,
+ has_global_name,
+ type_,
+ space,
+ align,
+ initializer,
+ )?,
+ ));
+ Ok(())
+ },
+ )?;
}
}
- _ => {}
}
}
- let mut result = HashMap::new();
- for (method_key, children) in directly_called_by.iter() {
- match method_key {
- MethodName::Kernel(name) => {
- let mut visited = HashSet::new();
- for child in children {
- add_call_map_single(&directly_called_by, &mut visited, *child);
- }
- result.insert(*name, visited);
+ Ok((
+ TranslationModule {
+ compilation_mode: module.compilation_mode,
+ sm_version,
+ id_defs,
+ ptx_impl_imports,
+ directives,
+ },
+ functions,
+ ))
+}
+
+fn get_existing_methods<'input, P: ast::ArgParams<Id = Id>>(
+ directives: &[TranslationDirective<'input, P>],
+) -> FxHashMap<
+ Id,
+ (
+ Vec<ast::VariableDeclaration<Id>>,
+ Vec<ast::VariableDeclaration<Id>>,
+ ),
+> {
+ let mut result = FxHashMap::default();
+ for directive in directives {
+ match directive {
+ TranslationDirective::Variable(..) => continue,
+ TranslationDirective::Method(method) => {
+ result.insert(
+ method.name,
+ (
+ method.return_arguments.clone(),
+ method.input_arguments.clone(),
+ ),
+ );
}
- MethodName::Func(_) => {}
}
}
result
}
-fn add_call_map_single<'input>(
- directly_called_by: &MultiHashMap<MethodName<'input>, spirv::Word>,
- visited: &mut HashSet<spirv::Word>,
- current: spirv::Word,
-) {
- if !visited.insert(current) {
- return;
- }
- if let Some(children) = directly_called_by.get(&MethodName::Func(current)) {
- for child in children {
- add_call_map_single(directly_called_by, visited, *child);
+fn normalize_method<'a, 'b, 'input>(
+ linking_resolver: &mut ResolvedLinking<'input>,
+ function_decls: &mut FxHashMap<
+ Id,
+ (
+ Vec<ast::VariableDeclaration<Id>>,
+ Vec<ast::VariableDeclaration<Id>>,
+ ),
+ >,
+ (module, directive): (usize, usize),
+ module_scope: &mut StringIdResolverScope<'a, 'b, 'input>,
+ linking_directive: ast::LinkingDirective,
+ method: ast::Function<'input, &'input str, ast::Statement<ast::ParsedArgParams<'input>>>,
+) -> Result<TranslationMethod<'input, NormalizedArgParams>, TranslateError> {
+ let is_kernel = method.func_directive.name.is_kernel();
+ let linking_adjustment = linking_resolver.get_adjustment(
+ module,
+ directive,
+ Cow::Borrowed(method.func_directive.name()),
+ linking_directive,
+ false,
+ )?;
+ let (has_global_name, has_body) = match linking_adjustment {
+ VisibilityAdjustment::Global => (true, true),
+ VisibilityAdjustment::Module => (false, true),
+ VisibilityAdjustment::GlobalDeclaration(_) => (true, false),
+ };
+ let name =
+ module_scope.add_or_get_at_module_level(method.func_directive.name(), has_global_name)?;
+ let mut fn_scope = module_scope.start_scope();
+ let return_arguments =
+ normalize_method_params(&mut fn_scope, &*method.func_directive.return_arguments)?;
+ let input_arguments =
+ normalize_method_params(&mut fn_scope, &*method.func_directive.input_arguments)?;
+ if !is_kernel {
+ if let hash_map::Entry::Vacant(entry) = function_decls.entry(name) {
+ entry.insert((return_arguments.clone(), input_arguments.clone()));
+ }
+ }
+ let source_name = if has_global_name {
+ Some(Cow::Borrowed(method.func_directive.name()))
+ } else {
+ None
+ };
+ let body = if has_body {
+ method
+ .body
+ .map(|body| {
+ let body = normalize_identifiers2(&mut fn_scope, body)?;
+ normalize_predicates2(&mut fn_scope, body)
+ })
+ .transpose()?
+ } else {
+ None
+ };
+ Ok(TranslationMethod {
+ return_arguments,
+ name,
+ input_arguments,
+ body,
+ tuning: method.tuning,
+ is_kernel,
+ source_name,
+ special_raytracing_linking: false,
+ })
+}
+
+fn normalize_method_params<'a, 'b, 'input>(
+ fn_scope: &mut StringIdResolverScope<'a, 'b, 'input>,
+ args: &[ast::VariableDeclaration<&'input str>],
+) -> Result<Vec<ast::VariableDeclaration<Id>>, TranslateError> {
+ args.iter()
+ .map(|a| {
+ Ok(ast::VariableDeclaration {
+ name: fn_scope.add_variable_checked(
+ a.name,
+ a.type_.clone(),
+ a.state_space,
+ a.align,
+ )?,
+ type_: a.type_.clone(),
+ state_space: a.state_space,
+ align: a.align,
+ })
+ })
+ .collect()
+}
+
+fn normalize_identifiers2<'a, 'b, 'input>(
+ scope: &mut StringIdResolverScope<'a, 'b, 'input>,
+ func: Vec<ast::Statement<ast::ParsedArgParams<'input>>>,
+) -> Result<Vec<NormalizedStatement>, TranslateError> {
+ gather_labels_in_scope(scope, &func)?;
+ let mut result = Vec::with_capacity(func.len());
+ for statement in func {
+ match statement {
+ ast::Statement::Block(block) => {
+ let mut scope = scope.start_scope();
+ result.extend(normalize_identifiers2(&mut scope, block)?);
+ }
+ ast::Statement::Label(name) => {
+ result.push(Statement::Label(scope.get_id_in_function_scopes(name)?))
+ }
+ ast::Statement::Instruction(p, i) => result.push(Statement::Instruction((
+ p.map(|p| p.map_variable(&mut |id| scope.get_id_in_module_scopes(id)))
+ .transpose()?,
+ i.map_variable(&mut |id| scope.get_id_in_module_scopes(id))?,
+ ))),
+ ast::Statement::Variable(vars) => {
+ expand_multivariable2(
+ scope,
+ vars.into_iter(),
+ |scope, align, type_, space, name, initializer| {
+ result.push(Statement::Variable(scope.register_variable(
+ name,
+ type_.clone(),
+ space,
+ align,
+ initializer,
+ )?));
+ Ok(())
+ },
+ )?;
+ }
+ ast::Statement::Callprototype(proto) => {
+ let name = scope.add_untyped_checked(proto.name)?;
+ scope.0.module.globals.function_prototypes.insert(
+ name,
+ Callprototype {
+ return_arguments: proto.return_arguments,
+ input_arguments: proto.input_arguments,
+ },
+ );
+ }
}
}
+ Ok(result)
}
-type MultiHashMap<K, V> = HashMap<K, Vec<V>>;
+fn expand_multivariable2<'a, 'b, 'input>(
+ scope: &mut StringIdResolverScope<'a, 'b, 'input>,
+ vars: impl Iterator<Item = ast::MultiVariableDefinition<&'input str>>,
+ mut inserter: impl FnMut(
+ &mut StringIdResolverScope<'a, 'b, 'input>,
+ Option<u32>,
+ &ast::Type,
+ ast::StateSpace,
+ Cow<'input, str>,
+ Option<ast::Initializer<Id>>,
+ ) -> Result<(), TranslateError>,
+) -> Result<(), TranslateError> {
+ for var in vars {
+ let initializer = match var.suffix {
+ Some(ast::DeclarationSuffix::Count(count)) => {
+ for offset in 0..count {
+ let name = Cow::Owned(format!("{}{}", var.variable.name, offset));
+ inserter(
+ scope,
+ var.variable.align,
+ &var.variable.type_,
+ var.variable.state_space,
+ name,
+ None,
+ )?;
+ }
+ return Ok(());
+ }
+ Some(ast::DeclarationSuffix::Initializer(init)) => {
+ Some(expand_initializer2(scope, init)?)
+ }
+ None => None,
+ };
+ let name = Cow::Borrowed(var.variable.name);
+ inserter(
+ scope,
+ var.variable.align,
+ &var.variable.type_,
+ var.variable.state_space,
+ name,
+ initializer,
+ )?;
+ }
+ Ok(())
+}
-fn multi_hash_map_append<K: Eq + std::hash::Hash, V>(m: &mut MultiHashMap<K, V>, key: K, value: V) {
- match m.entry(key) {
- hash_map::Entry::Occupied(mut entry) => {
- entry.get_mut().push(value);
+fn expand_initializer2<'a, 'b, 'input>(
+ scope: &mut StringIdResolverScope<'a, 'b, 'input>,
+ init: ast::Initializer<&'input str>,
+) -> Result<ast::Initializer<Id>, TranslateError> {
+ Ok(match init {
+ ast::Initializer::Constant(c) => ast::Initializer::Constant(c),
+ ast::Initializer::Global(g, type_) => {
+ ast::Initializer::Global(scope.get_id_in_module_scope(g)?, type_)
}
- hash_map::Entry::Vacant(entry) => {
- entry.insert(vec![value]);
+ ast::Initializer::GenericGlobal(g, type_) => {
+ ast::Initializer::GenericGlobal(scope.get_id_in_module_scope(g)?, type_)
}
- }
+ ast::Initializer::Add(add) => {
+ let (init1, init2) = *add;
+ ast::Initializer::Add(Box::new((
+ expand_initializer2(scope, init1)?,
+ expand_initializer2(scope, init2)?,
+ )))
+ }
+ ast::Initializer::Array(array) => ast::Initializer::Array(
+ array
+ .into_iter()
+ .map(|init| expand_initializer2(scope, init))
+ .collect::<Result<Vec<_>, _>>()?,
+ ),
+ })
}
-// PTX represents dynamically allocated shared local memory as
-// .extern .shared .align 4 .b8 shared_mem[];
-// In SPIRV/OpenCL world this is expressed as an additional argument
-// This pass looks for all uses of .extern .shared and converts them to
-// an additional method argument
-fn convert_dynamic_shared_memory_usage<'input>(
- module: Vec<Directive<'input>>,
- new_id: &mut impl FnMut() -> spirv::Word,
-) -> Vec<Directive<'input>> {
- let mut extern_shared_decls = HashMap::new();
- for dir in module.iter() {
- match dir {
- Directive::Variable(var) => {
- if let ast::VariableType::Shared(ast::VariableGlobalType::Pointer(p_type, _)) =
- var.v_type
- {
- extern_shared_decls.insert(var.name, p_type);
+fn normalize_predicates2<'a, 'b, 'input>(
+ scope: &mut StringIdResolverScope<'a, 'b, 'input>,
+ func: Vec<NormalizedStatement>,
+) -> Result<Vec<UnconditionalStatement>, TranslateError> {
+ let mut result = Vec::with_capacity(func.len());
+ for s in func {
+ match s {
+ Statement::Label(id) => result.push(Statement::Label(id)),
+ Statement::Instruction((pred, inst)) => {
+ if let Some(pred) = pred {
+ let if_true = scope.new_untyped();
+ let if_false = scope.new_untyped();
+ let folded_bra = match &inst {
+ ast::Instruction::Bra(_, arg) => Some(arg.src),
+ _ => None,
+ };
+ let mut branch = BrachCondition {
+ predicate: pred.label,
+ if_true: folded_bra.unwrap_or(if_true),
+ if_false,
+ };
+ if pred.not {
+ std::mem::swap(&mut branch.if_true, &mut branch.if_false);
+ }
+ result.push(Statement::Conditional(branch));
+ if folded_bra.is_none() {
+ result.push(Statement::Label(if_true));
+ result.push(Statement::Instruction(inst));
+ }
+ result.push(Statement::Label(if_false));
+ } else {
+ result.push(Statement::Instruction(inst));
}
}
- _ => {}
+ Statement::Variable(var) => result.push(Statement::Variable(var)),
+ // Blocks are flattened when resolving ids
+ _ => return Err(TranslateError::unreachable()),
}
}
- if extern_shared_decls.len() == 0 {
- return module;
+ Ok(result)
+}
+
+// Instructions can reference labels that are declared later on so
+// we gather ids of labels ahead of time
+fn gather_labels_in_scope<'a, 'b, 'input>(
+ scope: &mut StringIdResolverScope<'a, 'b, 'input>,
+ func: &[ast::Statement<ast::ParsedArgParams<'input>>],
+) -> Result<(), TranslateError> {
+ for s in func.iter() {
+ // Instructions can reference labels that are declared later so
+ // we gather ids of labels ahead of time
+ if let ast::Statement::Label(id) = s {
+ scope.add_untyped_checked(*id)?;
+ }
}
- let mut methods_using_extern_shared = HashSet::new();
- let mut directly_called_by = MultiHashMap::new();
- let module = module
+ Ok(())
+}
+
+fn resolve_instruction_types<'input>(
+ mut module: TranslationModule<'input, NormalizedArgParams>,
+ function_decls: FxHashMap<
+ Id,
+ (
+ Vec<ast::VariableDeclaration<Id>>,
+ Vec<ast::VariableDeclaration<Id>>,
+ ),
+ >,
+) -> Result<TranslationModule<'input, TypedArgParams>, TranslateError> {
+ let id_defs = &mut module.id_defs;
+ let directives = module
+ .directives
.into_iter()
- .map(|directive| match directive {
- Directive::Method(Function {
- func_decl,
- globals,
- body: Some(statements),
- import_as,
- spirv_decl,
- }) => {
- let call_key = MethodName::new(&func_decl);
- let statements = statements
- .into_iter()
- .map(|statement| match statement {
- Statement::Call(call) => {
- multi_hash_map_append(&mut directly_called_by, call.func, call_key);
- Statement::Call(call)
- }
- statement => statement.map_id(&mut |id, _| {
- if extern_shared_decls.contains_key(&id) {
- methods_using_extern_shared.insert(call_key);
- }
- id
- }),
+ .map(|directive| {
+ Ok(match directive {
+ TranslationDirective::Variable(linking, compiled_name, var) => {
+ TranslationDirective::Variable(
+ linking,
+ compiled_name,
+ resolve_initializers(id_defs, var)?,
+ )
+ }
+ TranslationDirective::Method(method) => {
+ let body = match method.body {
+ Some(body) => Some(resolve_instruction_types_method(
+ id_defs,
+ &function_decls,
+ body,
+ )?),
+ None => None,
+ };
+ TranslationDirective::Method(TranslationMethod {
+ return_arguments: method.return_arguments,
+ name: method.name,
+ input_arguments: method.input_arguments,
+ body,
+ tuning: method.tuning,
+ is_kernel: method.is_kernel,
+ source_name: method.source_name,
+ special_raytracing_linking: method.special_raytracing_linking,
})
- .collect();
- Directive::Method(Function {
- func_decl,
- globals,
- body: Some(statements),
- import_as,
- spirv_decl,
- })
- }
- directive => directive,
- })
- .collect::<Vec<_>>();
- // If there's a chain `kernel` -> `fn1` -> `fn2`, where only `fn2` uses extern shared,
- // make sure it gets propagated to `fn1` and `kernel`
- get_callers_of_extern_shared(&mut methods_using_extern_shared, &directly_called_by);
- // now visit every method declaration and inject those additional arguments
- module
- .into_iter()
- .map(|directive| match directive {
- Directive::Method(Function {
- func_decl,
- globals,
- body: Some(statements),
- import_as,
- mut spirv_decl,
- }) => {
- if !methods_using_extern_shared.contains(&spirv_decl.name) {
- return Directive::Method(Function {
- func_decl,
- globals,
- body: Some(statements),
- import_as,
- spirv_decl,
- });
}
- let shared_id_param = new_id();
- spirv_decl.input.push({
- ast::Variable {
- align: None,
- v_type: ast::Type::Pointer(
- ast::PointerType::Scalar(ast::ScalarType::U8),
- ast::LdStateSpace::Shared,
- ),
- array_init: Vec::new(),
- name: shared_id_param,
- }
- });
- spirv_decl.uses_shared_mem = true;
- let shared_var_id = new_id();
- let shared_var = ExpandedStatement::Variable(ast::Variable {
- align: None,
- name: shared_var_id,
- array_init: Vec::new(),
- v_type: ast::VariableType::Reg(ast::VariableRegType::Pointer(
- ast::SizedScalarType::B8,
- ast::PointerStateSpace::Shared,
- )),
- });
- let shared_var_st = ExpandedStatement::StoreVar(StoreVarDetails {
- arg: ast::Arg2St {
- src1: shared_var_id,
- src2: shared_id_param,
- },
- typ: ast::Type::Scalar(ast::ScalarType::B8),
- member_index: None,
- });
- let mut new_statements = vec![shared_var, shared_var_st];
- replace_uses_of_shared_memory(
- &mut new_statements,
- new_id,
- &extern_shared_decls,
- &mut methods_using_extern_shared,
- shared_id_param,
- shared_var_id,
- statements,
- );
- Directive::Method(Function {
- func_decl,
- globals,
- body: Some(new_statements),
- import_as,
- spirv_decl,
- })
- }
- directive => directive,
+ })
})
- .collect::<Vec<_>>()
+ .collect::<Result<Vec<_>, _>>()?;
+ Ok(TranslationModule {
+ compilation_mode: module.compilation_mode,
+ sm_version: module.sm_version,
+ directives,
+ id_defs: module.id_defs,
+ ptx_impl_imports: module.ptx_impl_imports,
+ })
}
-fn replace_uses_of_shared_memory<'a>(
- result: &mut Vec<ExpandedStatement>,
- new_id: &mut impl FnMut() -> spirv::Word,
- extern_shared_decls: &HashMap<spirv::Word, ast::SizedScalarType>,
- methods_using_extern_shared: &mut HashSet<MethodName<'a>>,
- shared_id_param: spirv::Word,
- shared_var_id: spirv::Word,
- statements: Vec<ExpandedStatement>,
-) {
- for statement in statements {
+fn resolve_instruction_types_method<'input>(
+ id_defs: &mut IdNameMapBuilder<'input>,
+ function_decls: &FxHashMap<
+ Id,
+ (
+ Vec<ast::VariableDeclaration<Id>>,
+ Vec<ast::VariableDeclaration<Id>>,
+ ),
+ >,
+ fn_body: Vec<UnconditionalStatement>,
+) -> Result<Vec<TypedStatement>, TranslateError> {
+ let mut result = Vec::<TypedStatement>::with_capacity(fn_body.len());
+ let mut constants = KernelConstantsVisitor::new();
+ for statement in fn_body {
match statement {
- Statement::Call(mut call) => {
- // We can safely skip checking call arguments,
- // because there's simply no way to pass shared ptr
- // without converting it to .b64 first
- if methods_using_extern_shared.contains(&MethodName::Func(call.func)) {
- call.param_list
- .push((shared_id_param, ast::FnArgumentType::Shared));
+ Statement::Instruction(inst) => match inst {
+ // TODO: Replace this with proper constant propagation
+ ast::Instruction::PrmtSlow { control, arg } => {
+ let inst = if let Some(control) = constants.try_get_constant(control) {
+ ast::Instruction::Prmt { control, arg }
+ } else {
+ ast::Instruction::PrmtSlow { control, arg }
+ };
+ let mut visitor =
+ VectorRepackVisitor::new(&mut constants, &mut result, id_defs);
+ let reresolved_call = inst.visit(&mut visitor)?;
+ visitor.func.push(reresolved_call);
+ visitor.func.extend(visitor.post_stmts);
}
- result.push(Statement::Call(call))
- }
- statement => {
- let new_statement = statement.map_id(&mut |id, _| {
- if let Some(typ) = extern_shared_decls.get(&id) {
- if *typ == ast::SizedScalarType::B8 {
- return shared_var_id;
+ ast::Instruction::Sust(mut details, args) => {
+ if let ast::Operand::Reg(image) = args.image {
+ let (image_type, _, _, _) = id_defs.get_typed(image)?;
+ if matches!(image_type, ast::Type::Surfref) {
+ details.direct = true;
}
- let replacement_id = new_id();
- result.push(Statement::Conversion(ImplicitConversion {
- src: shared_var_id,
- dst: replacement_id,
- from: ast::Type::Pointer(
- ast::PointerType::Scalar(ast::ScalarType::B8),
- ast::LdStateSpace::Shared,
- ),
- to: ast::Type::Pointer(
- ast::PointerType::Scalar((*typ).into()),
- ast::LdStateSpace::Shared,
- ),
- kind: ConversionKind::PtrToPtr { spirv_ptr: true },
- src_sema: ArgumentSemantics::Default,
- dst_sema: ArgumentSemantics::Default,
- }));
- replacement_id
- } else {
- id
}
- });
- result.push(new_statement);
+ let mut visitor =
+ VectorRepackVisitor::new(&mut constants, &mut result, id_defs);
+ let reresolved_call =
+ ast::Instruction::Sust(details, args).visit(&mut visitor)?;
+ visitor.func.push(reresolved_call);
+ visitor.func.extend(visitor.post_stmts);
+ }
+ ast::Instruction::Suld(mut details, args) => {
+ if let ast::Operand::Reg(image) = args.image {
+ let (image_type, _, _, _) = id_defs.get_typed(image)?;
+ if matches!(image_type, ast::Type::Surfref) {
+ details.direct = true;
+ }
+ }
+ let mut visitor =
+ VectorRepackVisitor::new(&mut constants, &mut result, id_defs);
+ let reresolved_call =
+ ast::Instruction::Suld(details, args).visit(&mut visitor)?;
+ visitor.func.push(reresolved_call);
+ visitor.func.extend(visitor.post_stmts);
+ }
+ ast::Instruction::Tex(mut details, args) => {
+ if let ast::Operand::Reg(image) = args.image {
+ let (image_type, _, _, _) = id_defs.get_typed(image)?;
+ if matches!(image_type, ast::Type::Texref) {
+ details.direct = true;
+ }
+ }
+ let mut visitor =
+ VectorRepackVisitor::new(&mut constants, &mut result, id_defs);
+ let reresolved_call =
+ ast::Instruction::Tex(details, args).visit(&mut visitor)?;
+ visitor.func.push(reresolved_call);
+ visitor.func.extend(visitor.post_stmts);
+ }
+ ast::Instruction::Mov(
+ mov,
+ ast::Arg2Mov {
+ dst: ast::Operand::Reg(dst_reg),
+ src: ast::Operand::Reg(src_reg),
+ },
+ ) if function_decls.contains_key(&src_reg) => {
+ if mov.typ != ast::Type::Scalar(ast::ScalarType::U64) {
+ return Err(TranslateError::mismatched_type());
+ }
+ result.push(TypedStatement::FunctionPointer(FunctionPointerDetails {
+ dst: dst_reg,
+ src: src_reg,
+ }));
+ }
+ ast::Instruction::Mov(
+ ast::MovDetails {
+ typ: ast::Type::Scalar(type_),
+ ..
+ },
+ ast::Arg2Mov {
+ dst: ast::Operand::Reg(dst_reg),
+ src: ast::Operand::Imm(src),
+ },
+ ) if type_.size_of() >= 2 && type_.is_integer() => {
+ constants.insert(
+ dst_reg,
+ Some(src.as_u16().ok_or_else(TranslateError::unreachable)?),
+ )?;
+ let mut noop_visitor = PassthroughVisitor;
+ let mut visitor =
+ VectorRepackVisitor::new(&mut noop_visitor, &mut result, id_defs);
+ let instruction = Statement::Instruction(inst.map(&mut visitor)?);
+ visitor.func.push(instruction);
+ visitor.func.extend(visitor.post_stmts);
+ }
+ ast::Instruction::Call(call) => {
+ let resolved_call = match function_decls.get(&call.func) {
+ Some((return_args, input_args)) => {
+ ResolvedCall::from_declaration(call, return_args, input_args)?
+ }
+ None => {
+ let callproto_name =
+ call.prototype.ok_or_else(TranslateError::unreachable)?;
+ let callproto = id_defs
+ .globals
+ .function_prototypes
+ .get(&callproto_name)
+ .ok_or_else(TranslateError::unreachable)?;
+ ResolvedCall::from_callprototype(call, callproto)?
+ }
+ };
+ let mut visitor =
+ VectorRepackVisitor::new(&mut constants, &mut result, id_defs);
+ let reresolved_call = resolved_call.visit(&mut visitor)?;
+ visitor.func.push(reresolved_call);
+ visitor.func.extend(visitor.post_stmts);
+ }
+ inst => {
+ let mut visitor =
+ VectorRepackVisitor::new(&mut constants, &mut result, id_defs);
+ let instruction = Statement::Instruction(inst.map(&mut visitor)?);
+ visitor.func.push(instruction);
+ visitor.func.extend(visitor.post_stmts);
+ }
+ },
+ Statement::Label(i) => result.push(Statement::Label(i)),
+ Statement::Variable(v) => {
+ result.push(Statement::Variable(resolve_initializers(id_defs, v)?))
}
+ Statement::Conditional(c) => result.push(Statement::Conditional(c)),
+ _ => return Err(TranslateError::unreachable()),
}
}
+ Ok(result)
}
-fn get_callers_of_extern_shared<'a>(
- methods_using_extern_shared: &mut HashSet<MethodName<'a>>,
- directly_called_by: &MultiHashMap<spirv::Word, MethodName<'a>>,
-) {
- let direct_uses_of_extern_shared = methods_using_extern_shared
- .iter()
- .filter_map(|method| {
- if let MethodName::Func(f_id) = method {
- Some(*f_id)
- } else {
- None
- }
- })
- .collect::<Vec<_>>();
- for fn_id in direct_uses_of_extern_shared {
- get_callers_of_extern_shared_single(methods_using_extern_shared, directly_called_by, fn_id);
- }
-}
-
-fn get_callers_of_extern_shared_single<'a>(
- methods_using_extern_shared: &mut HashSet<MethodName<'a>>,
- directly_called_by: &MultiHashMap<spirv::Word, MethodName<'a>>,
- fn_id: spirv::Word,
-) {
- if let Some(callers) = directly_called_by.get(&fn_id) {
- for caller in callers {
- if methods_using_extern_shared.insert(*caller) {
- if let MethodName::Func(caller_fn) = caller {
- get_callers_of_extern_shared_single(
- methods_using_extern_shared,
- directly_called_by,
- *caller_fn,
- );
+fn resolve_initializers<'input>(
+ id_defs: &mut IdNameMapBuilder<'input>,
+ mut v: Variable,
+) -> Result<Variable, TranslateError> {
+ fn resolve_initializer_impl<'input>(
+ id_defs: &mut IdNameMapBuilder<'input>,
+ init: &mut ast::Initializer<Id>,
+ ) -> Result<(), TranslateError> {
+ match init {
+ ast::Initializer::Constant(_) => {}
+ ast::Initializer::Global(name, type_)
+ | ast::Initializer::GenericGlobal(name, type_) => {
+ let (src_type, _, _, _) = id_defs.get_typed(*name)?;
+ *type_ = src_type;
+ }
+ ast::Initializer::Add(subinit) => {
+ resolve_initializer_impl(id_defs, &mut (*subinit).0)?;
+ resolve_initializer_impl(id_defs, &mut (*subinit).1)?;
+ }
+ ast::Initializer::Array(inits) => {
+ for init in inits.iter_mut() {
+ resolve_initializer_impl(id_defs, init)?;
}
}
}
+ Ok(())
+ }
+ if let Some(ref mut init) = v.initializer {
+ resolve_initializer_impl(id_defs, init)?;
}
+ Ok(v)
}
-type DenormCountMap<T> = HashMap<T, isize>;
+// TODO: All this garbage should be replaced with proper constant propagation or
+// at least ability to visit statements without moving them
+struct KernelConstantsVisitor {
+ constant_candidates: FxHashMap<Id, ConstantCandidate>,
+}
-fn denorm_count_map_update<T: Eq + Hash>(map: &mut DenormCountMap<T>, key: T, value: bool) {
- let num_value = if value { 1 } else { -1 };
- denorm_count_map_update_impl(map, key, num_value);
+struct ConstantCandidate {
+ value: u16,
+ used: bool,
+ not_const: bool,
}
-fn denorm_count_map_update_impl<T: Eq + Hash>(
- map: &mut DenormCountMap<T>,
- key: T,
- num_value: isize,
-) {
- match map.entry(key) {
- hash_map::Entry::Occupied(mut counter) => {
- *(counter.get_mut()) += num_value;
+impl KernelConstantsVisitor {
+ fn new() -> Self {
+ Self {
+ constant_candidates: FxHashMap::default(),
}
- hash_map::Entry::Vacant(entry) => {
- entry.insert(num_value);
+ }
+
+ fn insert(&mut self, id: Id, value: Option<u16>) -> Result<(), TranslateError> {
+ match self.constant_candidates.entry(id) {
+ hash_map::Entry::Occupied(mut entry) => {
+ let candidate = entry.get_mut();
+ if candidate.used {
+ return Err(TranslateError::unexpected_pattern());
+ }
+ candidate.not_const = true;
+ }
+ hash_map::Entry::Vacant(entry) => {
+ entry.insert(ConstantCandidate {
+ value: value.unwrap_or(u16::MAX),
+ used: false,
+ not_const: value.is_none(),
+ });
+ }
}
+ Ok(())
+ }
+
+ fn try_get_constant(&mut self, id: Id) -> Option<u16> {
+ self.constant_candidates.get_mut(&id).and_then(|candidate| {
+ if candidate.not_const {
+ return None;
+ }
+ candidate.used = true;
+ Some(candidate.value)
+ })
}
}
-// HACK ALERT!
-// This function is a "good enough" heuristic of whetever to mark f16/f32 operations
-// in the kernel as flushing denorms to zero or preserving them
-// PTX support per-instruction ftz information. Unfortunately SPIR-V has no
-// such capability, so instead we guesstimate which use is more common in the kernel
-// and emit suitable execution mode
-fn compute_denorm_information<'input>(
- module: &[Directive<'input>],
-) -> HashMap<MethodName<'input>, HashMap<u8, (spirv::FPDenormMode, isize)>> {
- let mut denorm_methods = HashMap::new();
- for directive in module {
- match directive {
- Directive::Variable(_) | Directive::Method(Function { body: None, .. }) => {}
- Directive::Method(Function {
- func_decl,
- body: Some(statements),
- ..
- }) => {
- let mut flush_counter = DenormCountMap::new();
- let method_key = MethodName::new(func_decl);
- for statement in statements {
- match statement {
- Statement::Instruction(inst) => {
- if let Some((flush, width)) = inst.flush_to_zero() {
- denorm_count_map_update(&mut flush_counter, width, flush);
- }
- }
- Statement::LoadVar(..) => {}
- Statement::StoreVar(..) => {}
- Statement::Call(_) => {}
- Statement::Conditional(_) => {}
- Statement::Conversion(_) => {}
- Statement::Constant(_) => {}
- Statement::RetValue(_, _) => {}
- Statement::Label(_) => {}
- Statement::Variable(_) => {}
- Statement::PtrAccess { .. } => {}
- Statement::RepackVector(_) => {}
- }
- }
- denorm_methods.insert(method_key, flush_counter);
+impl ArgumentMapVisitor<TypedArgParams, TypedArgParams> for KernelConstantsVisitor {
+ fn id(
+ &mut self,
+ desc: ArgumentDescriptor<Id>,
+ _: Option<(&ast::Type, ast::StateSpace)>,
+ ) -> Result<Id, TranslateError> {
+ if desc.is_dst {
+ self.insert(desc.op, None)?;
+ }
+ Ok(desc.op)
+ }
+
+ fn operand(
+ &mut self,
+ desc: ArgumentDescriptor<TypedOperand>,
+ _: &ast::Type,
+ _: ast::StateSpace,
+ ) -> Result<TypedOperand, TranslateError> {
+ if desc.is_dst {
+ if let TypedOperand::Reg(op) = desc.op {
+ self.insert(op, None)?;
}
}
+ Ok(desc.op)
}
- denorm_methods
+}
+
+struct PassthroughVisitor;
+
+impl<P: ArgParamsEx> ArgumentMapVisitor<P, P> for PassthroughVisitor {
+ fn id(
+ &mut self,
+ desc: ArgumentDescriptor<P::Id>,
+ _: Option<(&ast::Type, ast::StateSpace)>,
+ ) -> Result<P::Id, TranslateError> {
+ Ok(desc.op)
+ }
+
+ fn operand(
+ &mut self,
+ desc: ArgumentDescriptor<P::Operand>,
+ _: &ast::Type,
+ _: ast::StateSpace,
+ ) -> Result<P::Operand, TranslateError> {
+ Ok(desc.op)
+ }
+}
+
+fn convert_methods<'input, From: ast::ArgParams<Id = Id>, To: ast::ArgParams<Id = Id>>(
+ mut module: TranslationModule<'input, From>,
+ mut mapper: impl FnMut(
+ CompilationMode,
+ &mut IdNameMapBuilder<'input>,
+ &mut AdditionalFunctionDeclarations,
+ &mut [ast::VariableDeclaration<From::Id>],
+ &mut [ast::VariableDeclaration<From::Id>],
+ bool,
+ Vec<Statement<ast::Instruction<From>, From>>,
+ ) -> Result<Vec<Statement<ast::Instruction<To>, To>>, TranslateError>,
+) -> Result<TranslationModule<'input, To>, TranslateError> {
+ let compilation_mode = module.compilation_mode;
+ let id_defs = &mut module.id_defs;
+ let mut additional_declarations = AdditionalFunctionDeclarations::new();
+ let post_declarations_directives = module
+ .directives
.into_iter()
- .map(|(name, v)| {
- let width_to_denorm = v
- .into_iter()
- .map(|(k, flush_over_preserve)| {
- let mode = if flush_over_preserve > 0 {
- spirv::FPDenormMode::FlushToZero
- } else {
- spirv::FPDenormMode::Preserve
+ .map(|directive| {
+ Ok(match directive {
+ TranslationDirective::Method(mut method) => {
+ let body = match method.body {
+ Some(body) => Some(mapper(
+ compilation_mode,
+ id_defs,
+ &mut additional_declarations,
+ &mut method.return_arguments,
+ &mut method.input_arguments,
+ method.is_kernel,
+ body,
+ )?),
+ None => None,
};
- (k, (mode, flush_over_preserve))
- })
- .collect();
- (name, width_to_denorm)
+ TranslationDirective::Method(TranslationMethod {
+ return_arguments: method.return_arguments,
+ name: method.name,
+ input_arguments: method.input_arguments,
+ body,
+ tuning: method.tuning,
+ is_kernel: method.is_kernel,
+ source_name: method.source_name,
+ special_raytracing_linking: method.special_raytracing_linking,
+ })
+ }
+ TranslationDirective::Variable(linking, compiled_name, var) => {
+ TranslationDirective::Variable(linking, compiled_name, var)
+ }
+ })
})
- .collect()
+ .collect::<Result<Vec<_>, TranslateError>>()?;
+ let mut directives = Vec::with_capacity(post_declarations_directives.len());
+ additional_declarations.flush(&mut directives);
+ directives.extend(post_declarations_directives);
+ Ok(TranslationModule {
+ compilation_mode: module.compilation_mode,
+ sm_version: module.sm_version,
+ directives: directives,
+ id_defs: module.id_defs,
+ ptx_impl_imports: module.ptx_impl_imports,
+ })
}
-#[derive(Hash, PartialEq, Eq, Copy, Clone)]
-enum MethodName<'input> {
- Kernel(&'input str),
- Func(spirv::Word),
+fn convert_methods_simple<'input, From: ast::ArgParams<Id = Id>, To: ast::ArgParams<Id = Id>>(
+ module: TranslationModule<'input, From>,
+ mut mapper: impl FnMut(
+ &mut IdNameMapBuilder<'input>,
+ Vec<Statement<ast::Instruction<From>, From>>,
+ ) -> Result<Vec<Statement<ast::Instruction<To>, To>>, TranslateError>,
+) -> Result<TranslationModule<'input, To>, TranslateError> {
+ convert_methods(module, |_, id_defs, _, _, _, _, body| mapper(id_defs, body))
+}
+
+// NVIDIA PTX compiler emits methods that are declared like this:
+// .visible .func (.param .b8 retval[12]) foobar(.param .b64 arg1, .param .b64 arg2);
+// This pass converts them to a regular form:
+// .visible .func (.reg .b8 retval[12]) foobar(.reg .b64 arg1, .reg .b64 arg2);
+// and does appropriate adjustments to function calls
+fn deparamize_function_declarations<'input>(
+ mut module: TranslationModule<'input, TypedArgParams>,
+) -> Result<TranslationModule<'input, TypedArgParams>, TranslateError> {
+ let id_defs = &mut module.id_defs;
+ let mut delayed_deparamization = FxHashMap::default();
+ let directives = module
+ .directives
+ .into_iter()
+ .map(|directive| deparamize_directive(id_defs, directive, &mut delayed_deparamization))
+ .collect::<Result<Vec<_>, _>>()?;
+ Ok(TranslationModule {
+ directives,
+ ..module
+ })
}
-impl<'input> MethodName<'input> {
- fn new(decl: &ast::MethodDecl<'input, spirv::Word>) -> Self {
- match decl {
- ast::MethodDecl::Kernel { name, .. } => MethodName::Kernel(name),
- ast::MethodDecl::Func(_, id, _) => MethodName::Func(*id),
+fn deparamize_directive<'input>(
+ id_defs: &mut IdNameMapBuilder,
+ directive: TranslationDirective<'input, TypedArgParams>,
+ delayed_deparamization: &mut FxHashMap<Id, (BitVec, BitVec)>,
+) -> Result<TranslationDirective<'input, TypedArgParams>, TranslateError> {
+ Ok(match directive {
+ var @ TranslationDirective::Variable(..) => var,
+ TranslationDirective::Method(mut method) => {
+ let deparamized_args = get_deparamized_arguments(&method, delayed_deparamization);
+ deparamize_function_body(id_defs, &mut method, deparamized_args)?;
+ TranslationDirective::Method(method)
}
+ })
+}
+
+fn get_deparamized_arguments<'a, 'input, T: ast::ArgParams<Id = Id>>(
+ method: &TranslationMethod<'input, T>,
+ delayed_deparamization: &'a mut FxHashMap<Id, (BitVec, BitVec)>,
+) -> Option<&'a (BitVec, BitVec)> {
+ if method.is_kernel {
+ return None;
}
+ Some(
+ delayed_deparamization
+ .entry(method.name)
+ .or_insert_with(|| {
+ let return_deparams = get_deparamize_arg_list(&method.return_arguments[..]);
+ let input_deparams = get_deparamize_arg_list(&method.input_arguments[..]);
+ (return_deparams, input_deparams)
+ }),
+ )
}
-fn emit_builtins(
- builder: &mut dr::Builder,
- map: &mut TypeWordMap,
- id_defs: &GlobalStringIdResolver,
-) {
- for (reg, id) in id_defs.special_registers.builtins() {
- let result_type = map.get_or_add(
- builder,
- SpirvType::Pointer(
- Box::new(SpirvType::from(reg.get_type())),
- spirv::StorageClass::Input,
- ),
- );
- builder.variable(result_type, Some(id), spirv::StorageClass::Input, None);
- builder.decorate(
- id,
- spirv::Decoration::BuiltIn,
- &[dr::Operand::BuiltIn(reg.get_builtin())],
- );
+fn get_deparamize_arg_list(args: &[ast::VariableDeclaration<Id>]) -> BitVec {
+ let mut deparams = BitVec::from_elem(args.len(), false);
+ for (index, arg) in args.iter().enumerate() {
+ if arg.state_space == ast::StateSpace::Param {
+ deparams.set(index, true);
+ }
}
+ deparams
}
-fn emit_function_header<'a>(
- builder: &mut dr::Builder,
- map: &mut TypeWordMap,
- defined_globals: &GlobalStringIdResolver<'a>,
- synthetic_globals: &[ast::Variable<ast::VariableType, spirv::Word>],
- func_decl: &SpirvMethodDecl<'a>,
- _denorm_information: &HashMap<MethodName<'a>, HashMap<u8, (spirv::FPDenormMode, isize)>>,
- call_map: &HashMap<&'a str, HashSet<spirv::Word>>,
- direcitves: &[Directive],
- kernel_info: &mut HashMap<String, KernelInfo>,
+fn deparamize_function_body<'input>(
+ id_defs: &mut IdNameMapBuilder,
+ method: &mut TranslationMethod<'input, TypedArgParams>,
+ deparams: Option<&(BitVec, BitVec)>,
) -> Result<(), TranslateError> {
- if let MethodName::Kernel(name) = func_decl.name {
- let input_args = if !func_decl.uses_shared_mem {
- func_decl.input.as_slice()
- } else {
- &func_decl.input[0..func_decl.input.len() - 1]
- };
- let args_lens = input_args
- .iter()
- .map(|param| param.v_type.size_of())
- .collect();
- kernel_info.insert(
- name.to_string(),
- KernelInfo {
- arguments_sizes: args_lens,
- uses_shared_mem: func_decl.uses_shared_mem,
- },
- );
- }
- let (ret_type, func_type) =
- get_function_type(builder, map, &func_decl.input, &func_decl.output);
- let fn_id = match func_decl.name {
- MethodName::Kernel(name) => {
- let fn_id = defined_globals.get_id(name)?;
- let mut global_variables = defined_globals
- .variables_type_check
- .iter()
- .filter_map(|(k, t)| t.as_ref().map(|_| *k))
- .collect::<Vec<_>>();
- let mut interface = defined_globals.special_registers.interface();
- for ast::Variable { name, .. } in synthetic_globals {
- interface.push(*name);
- }
- let empty_hash_set = HashSet::new();
- let child_fns = call_map.get(name).unwrap_or(&empty_hash_set);
- for directive in direcitves {
- match directive {
- Directive::Method(Function {
- func_decl: ast::MethodDecl::Func(_, name, _),
- globals,
- ..
- }) => {
- if child_fns.contains(name) {
- for var in globals {
- interface.push(var.name);
- }
- }
- }
- _ => {}
- }
+ let mut result = Vec::with_capacity(method.body.as_ref().map(|body| body.len()).unwrap_or(0));
+ let has_body = method.body.is_some();
+ let mut return_args_flush: hash_map::HashMap<
+ Id,
+ Id,
+ std::hash::BuildHasherDefault<rustc_hash::FxHasher>,
+ > = FxHashMap::default();
+ if let Some((return_deparams, input_deparams)) = deparams {
+ for (index, return_arg) in method.return_arguments.iter_mut().enumerate() {
+ if !return_deparams
+ .get(index)
+ .ok_or_else(TranslateError::unreachable)?
+ {
+ continue;
+ }
+ if has_body {
+ let original_name = return_arg.name;
+ *return_arg = id_defs.register_variable_decl(
+ return_arg.align,
+ return_arg.type_.clone(),
+ ast::StateSpace::Reg,
+ );
+ return_args_flush.insert(return_arg.name, original_name);
+ result.push(Statement::Variable(Variable {
+ align: return_arg.align,
+ type_: return_arg.type_.clone(),
+ state_space: ast::StateSpace::Param,
+ name: original_name,
+ initializer: None,
+ }));
+ } else {
+ return_arg.state_space = ast::StateSpace::Reg;
+ }
+ }
+ for (index, input_arg) in method.input_arguments.iter_mut().enumerate() {
+ if !input_deparams
+ .get(index)
+ .ok_or_else(TranslateError::unreachable)?
+ {
+ continue;
+ }
+ if has_body {
+ let original_name = input_arg.name;
+ *input_arg = id_defs.register_variable_decl(
+ input_arg.align,
+ input_arg.type_.clone(),
+ ast::StateSpace::Reg,
+ );
+ result.push(Statement::Variable(Variable {
+ align: input_arg.align,
+ type_: input_arg.type_.clone(),
+ state_space: ast::StateSpace::Param,
+ name: original_name,
+ initializer: None,
+ }));
+ result.push(Statement::Instruction(ast::Instruction::St(
+ ast::StData {
+ qualifier: ast::LdStQualifier::Weak,
+ state_space: ast::StateSpace::Param,
+ caching: ast::StCacheOperator::Writeback,
+ typ: input_arg.type_.clone(),
+ },
+ ast::Arg2St {
+ src1: TypedOperand::Reg(original_name),
+ src2: TypedOperand::Reg(input_arg.name),
+ },
+ )));
+ } else {
+ input_arg.state_space = ast::StateSpace::Reg;
}
- global_variables.append(&mut interface);
- builder.entry_point(spirv::ExecutionModel::Kernel, fn_id, name, global_variables);
- fn_id
}
- MethodName::Func(name) => name,
+ }
+ let body = if let Some(ref mut body) = method.body {
+ std::mem::replace(body, Vec::new())
+ } else {
+ return Ok(());
};
- builder.begin_function(
- ret_type,
- Some(fn_id),
- spirv::FunctionControl::NONE,
- func_type,
- )?;
- // TODO: re-enable when Intel float control extension works
- /*
- if let Some(denorm_modes) = denorm_information.get(&func_decl.name) {
- for (size_of, denorm_mode) in denorm_modes {
- builder.decorate(
- fn_id,
- spirv::Decoration::FunctionDenormModeINTEL,
- [
- dr::Operand::LiteralInt32((*size_of as u32) * 8),
- dr::Operand::FPDenormMode(*denorm_mode),
- ],
- )
+ for statement in body {
+ match statement {
+ Statement::Instruction(ast::Instruction::Exit) => {
+ deparamize_instruction_ret(
+ deparams,
+ &method.return_arguments,
+ &return_args_flush,
+ &mut result,
+ ast::Instruction::Exit,
+ )?;
+ }
+ Statement::Instruction(ast::Instruction::Ret(ret)) => {
+ deparamize_instruction_ret(
+ deparams,
+ &method.return_arguments,
+ &return_args_flush,
+ &mut result,
+ ast::Instruction::Ret(ret),
+ )?;
+ }
+ Statement::Call(call) => {
+ deparamize_single_function_call(id_defs, &mut result, call)?;
+ }
+ statement => result.push(statement),
}
}
- */
- for input in &func_decl.input {
- let result_type = map.get_or_add(builder, SpirvType::from(input.v_type.clone()));
- let inst = dr::Instruction::new(
- spirv::Op::FunctionParameter,
- Some(result_type),
- Some(input.name),
- Vec::new(),
- );
- builder.function.as_mut().unwrap().parameters.push(inst);
- }
+ method.body = Some(result);
Ok(())
}
-fn emit_capabilities(builder: &mut dr::Builder) {
- builder.capability(spirv::Capability::GenericPointer);
- builder.capability(spirv::Capability::Linkage);
- builder.capability(spirv::Capability::Addresses);
- builder.capability(spirv::Capability::Kernel);
- builder.capability(spirv::Capability::Int8);
- builder.capability(spirv::Capability::Int16);
- builder.capability(spirv::Capability::Int64);
- builder.capability(spirv::Capability::Float16);
- builder.capability(spirv::Capability::Float64);
- // TODO: re-enable when Intel float control extension works
- //builder.capability(spirv::Capability::FunctionFloatControlINTEL);
+fn deparamize_instruction_ret(
+ deparams: Option<&(BitVec, BitVec)>,
+ return_arguments: &[ast::VariableDeclaration<Id>],
+ return_args_flush: &std::collections::HashMap<
+ Id,
+ Id,
+ std::hash::BuildHasherDefault<rustc_hash::FxHasher>,
+ >,
+ result: &mut Vec<Statement<ast::Instruction<TypedArgParams>, TypedArgParams>>,
+ ret: ast::Instruction<TypedArgParams>,
+) -> Result<(), TranslateError> {
+ if let Some((return_deparams, _)) = deparams {
+ for (index, return_arg) in return_arguments.iter().enumerate() {
+ if !return_deparams
+ .get(index)
+ .ok_or_else(TranslateError::unreachable)?
+ {
+ continue;
+ }
+ let src = return_args_flush[&return_arg.name];
+ result.push(Statement::Instruction(ast::Instruction::Ld(
+ ast::LdDetails {
+ qualifier: ast::LdStQualifier::Weak,
+ state_space: ast::StateSpace::Param,
+ caching: ast::LdCacheOperator::Cached,
+ typ: return_arg.type_.clone(),
+ non_coherent: false,
+ },
+ ast::Arg2Ld {
+ dst: TypedOperand::Reg(return_arg.name),
+ src: TypedOperand::Reg(src),
+ },
+ )));
+ }
+ }
+ result.push(Statement::Instruction(ret));
+ Ok(())
}
-// http://htmlpreview.github.io/?https://github.com/KhronosGroup/SPIRV-Registry/blob/master/extensions/KHR/SPV_KHR_float_controls.html
-fn emit_extensions(_builder: &mut dr::Builder) {
- // TODO: re-enable when Intel float control extension works
- //builder.extension("SPV_INTEL_float_controls2");
+fn deparamize_single_function_call(
+ id_defs: &mut IdNameMapBuilder,
+ result: &mut Vec<TypedStatement>,
+ call: ResolvedCall<TypedArgParams>,
+) -> Result<(), TranslateError> {
+ let input_arguments = call
+ .input_arguments
+ .into_iter()
+ .map(|(operand, type_, space)| match space {
+ ast::StateSpace::Param => {
+ let arg_id =
+ id_defs.register_intermediate(Some((type_.clone(), ast::StateSpace::Reg)));
+ result.push(Statement::Instruction(ast::Instruction::Ld(
+ ast::LdDetails {
+ qualifier: ast::LdStQualifier::Weak,
+ state_space: ast::StateSpace::Param,
+ caching: ast::LdCacheOperator::Cached,
+ typ: type_.clone(),
+ non_coherent: false,
+ },
+ ast::Arg2Ld {
+ dst: TypedOperand::Reg(arg_id),
+ src: operand,
+ },
+ )));
+ (TypedOperand::Reg(arg_id), type_, ast::StateSpace::Reg)
+ }
+ space => (operand, type_, space),
+ })
+ .collect::<Vec<_>>();
+ let mut post_statements = Vec::new();
+ let return_arguments = call
+ .return_arguments
+ .into_iter()
+ .map(|(operand, type_, state_space)| {
+ Ok(match state_space {
+ ast::StateSpace::Reg => (operand, type_, state_space),
+ ast::StateSpace::Param => {
+ let arg_id =
+ id_defs.register_intermediate(Some((type_.clone(), ast::StateSpace::Reg)));
+ post_statements.push(Statement::Instruction(ast::Instruction::St(
+ ast::StData {
+ qualifier: ast::LdStQualifier::Weak,
+ state_space: ast::StateSpace::Param,
+ caching: ast::StCacheOperator::Writeback,
+ typ: type_.clone(),
+ },
+ ast::Arg2St {
+ src1: TypedOperand::Reg(operand),
+ src2: TypedOperand::Reg(arg_id),
+ },
+ )));
+ (arg_id, type_, ast::StateSpace::Reg)
+ }
+ _ => return Err(TranslateError::unreachable()),
+ })
+ })
+ .collect::<Result<Vec<_>, _>>()?;
+ result.push(Statement::Call(ResolvedCall {
+ input_arguments,
+ return_arguments,
+ ..call
+ }));
+ result.extend(post_statements);
+ Ok(())
}
-fn emit_opencl_import(builder: &mut dr::Builder) -> spirv::Word {
- builder.ext_inst_import("OpenCL.std")
+fn insert_hardware_registers<'input>(
+ module: TranslationModule<'input, TypedArgParams>,
+) -> Result<TranslationModule<'input, TypedArgParams>, TranslateError> {
+ convert_methods_simple(module, insert_hardware_registers_impl)
}
-fn emit_memory_model(builder: &mut dr::Builder) {
- builder.memory_model(
- spirv::AddressingModel::Physical64,
- spirv::MemoryModel::OpenCL,
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#extended-precision-integer-arithmetic-instructions
+// NVIDIA documentation is misleading. In fact there is no single CC.CF,
+// but separate registers for overflow (`add` and `mad`) and underflow (`sub`)
+// For reference check the .ptx tests
+fn insert_hardware_registers_impl<'input>(
+ id_defs: &mut IdNameMapBuilder<'input>,
+ typed_statements: Vec<TypedStatement>,
+) -> Result<Vec<TypedStatement>, TranslateError> {
+ let mut result = Vec::with_capacity(typed_statements.len());
+ let overflow_flag_var = id_defs.register_variable_def(
+ None,
+ ast::Type::Scalar(ast::ScalarType::Pred),
+ ast::StateSpace::Reg,
+ Some(ast::Initializer::Constant(ast::ImmediateValue::U64(0))),
+ );
+ let underflow_flag_var = id_defs.register_variable_def(
+ None,
+ ast::Type::Scalar(ast::ScalarType::Pred),
+ ast::StateSpace::Reg,
+ Some(ast::Initializer::Constant(ast::ImmediateValue::U64(0))),
);
+ let overflow_flag = overflow_flag_var.name;
+ let underflow_flag = underflow_flag_var.name;
+ result.push(Statement::Variable(overflow_flag_var));
+ result.push(Statement::Variable(underflow_flag_var));
+ for statement in typed_statements {
+ match statement {
+ Statement::Instruction(ast::Instruction::MadC {
+ type_,
+ is_hi,
+ arg,
+ carry_out,
+ }) => result.push(Statement::MadC(MadCDetails {
+ type_,
+ is_hi,
+ arg: Arg4CarryIn::new(arg, carry_out, TypedOperand::Reg(overflow_flag)),
+ })),
+ Statement::Instruction(ast::Instruction::MadCC { type_, arg }) => {
+ result.push(Statement::MadCC(MadCCDetails {
+ type_,
+ arg: Arg4CarryOut::new(arg, TypedOperand::Reg(overflow_flag)),
+ }))
+ }
+ Statement::Instruction(ast::Instruction::AddC(details, args)) => {
+ result.push(Statement::AddC(
+ details.type_,
+ Arg3CarryIn::new(args, details.carry_out, TypedOperand::Reg(overflow_flag)),
+ ))
+ }
+ Statement::Instruction(ast::Instruction::AddCC(details, args)) => {
+ result.push(Statement::AddCC(
+ details,
+ Arg3CarryOut::new(args, TypedOperand::Reg(overflow_flag)),
+ ))
+ }
+ Statement::Instruction(ast::Instruction::SubC(details, args)) => {
+ result.push(Statement::SubC(
+ details.type_,
+ Arg3CarryIn::new(args, details.carry_out, TypedOperand::Reg(underflow_flag)),
+ ))
+ }
+ Statement::Instruction(ast::Instruction::SubCC(details, args)) => {
+ result.push(Statement::SubCC(
+ details,
+ Arg3CarryOut::new(args, TypedOperand::Reg(underflow_flag)),
+ ))
+ }
+ s => result.push(s),
+ }
+ }
+ Ok(result)
}
-fn translate_directive<'input>(
- id_defs: &mut GlobalStringIdResolver<'input>,
- ptx_impl_imports: &mut HashMap<String, Directive<'input>>,
- d: ast::Directive<'input, ast::ParsedArgParams<'input>>,
-) -> Result<Option<Directive<'input>>, TranslateError> {
- Ok(match d {
- ast::Directive::Variable(v) => Some(Directive::Variable(translate_variable(id_defs, v)?)),
- ast::Directive::Method(f) => {
- translate_function(id_defs, ptx_impl_imports, f)?.map(Directive::Method)
- }
- })
+fn fix_special_registers<'input>(
+ module: TranslationModule<'input, TypedArgParams>,
+) -> Result<TranslationModule<'input, TypedArgParams>, TranslateError> {
+ convert_methods(module, fix_special_registers_impl)
}
-fn translate_variable<'a>(
- id_defs: &mut GlobalStringIdResolver<'a>,
- var: ast::Variable<ast::VariableType, &'a str>,
-) -> Result<ast::Variable<ast::VariableType, spirv::Word>, TranslateError> {
- let (space, var_type) = var.v_type.to_type();
- let mut is_variable = false;
- let var_type = match space {
- ast::StateSpace::Reg => {
- is_variable = true;
- var_type
- }
- ast::StateSpace::Const => var_type.param_pointer_to(ast::LdStateSpace::Const)?,
- ast::StateSpace::Global => var_type.param_pointer_to(ast::LdStateSpace::Global)?,
- ast::StateSpace::Local => var_type.param_pointer_to(ast::LdStateSpace::Local)?,
- ast::StateSpace::Shared => {
- // If it's a pointer it will be translated to a method parameter later
- if let ast::Type::Pointer(..) = var_type {
- is_variable = true;
- var_type
- } else {
- var_type.param_pointer_to(ast::LdStateSpace::Shared)?
+fn fix_special_registers_impl<'input>(
+ _: CompilationMode,
+ id_defs: &mut IdNameMapBuilder<'input>,
+ ptx_imports: &mut AdditionalFunctionDeclarations,
+ _: &mut [ast::VariableDeclaration<Id>],
+ _: &mut [ast::VariableDeclaration<Id>],
+ _: bool,
+ typed_statements: Vec<TypedStatement>,
+) -> Result<Vec<TypedStatement>, TranslateError> {
+ let result = Vec::with_capacity(typed_statements.len());
+ let mut sreg_sresolver = SpecialRegisterResolver {
+ ptx_imports,
+ id_defs,
+ result,
+ };
+ for s in typed_statements {
+ match s {
+ Statement::Call(details) => {
+ let new_statement = details.visit(&mut sreg_sresolver)?;
+ sreg_sresolver.result.push(new_statement);
+ }
+ Statement::Instruction(details) => {
+ let new_statement = details.visit(&mut sreg_sresolver)?;
+ sreg_sresolver.result.push(new_statement);
+ }
+ Statement::Conditional(details) => {
+ let new_statement = details.visit(&mut sreg_sresolver)?;
+ sreg_sresolver.result.push(new_statement);
+ }
+ Statement::Conversion(details) => {
+ let new_statement = details.visit(&mut sreg_sresolver)?;
+ sreg_sresolver.result.push(new_statement);
}
+ Statement::PtrAccess(details) => {
+ let new_statement = details.visit(&mut sreg_sresolver)?;
+ sreg_sresolver.result.push(new_statement);
+ }
+ Statement::RepackVector(details) => {
+ let new_statement = details.visit(&mut sreg_sresolver)?;
+ sreg_sresolver.result.push(new_statement);
+ }
+ Statement::MadC(details) => {
+ let new_statement = details.visit(&mut sreg_sresolver)?;
+ sreg_sresolver.result.push(new_statement);
+ }
+ Statement::MadCC(details) => {
+ let new_statement = details.visit(&mut sreg_sresolver)?;
+ sreg_sresolver.result.push(new_statement);
+ }
+ Statement::AddC(details, arg) => {
+ let new_statement = VisitAddC(details, arg).visit(&mut sreg_sresolver)?;
+ sreg_sresolver.result.push(new_statement);
+ }
+ Statement::AddCC(type_, arg) => {
+ let new_statement = VisitAddCC(type_, arg).visit(&mut sreg_sresolver)?;
+ sreg_sresolver.result.push(new_statement);
+ }
+ Statement::SubC(details, arg) => {
+ let new_statement = VisitSubC(details, arg).visit(&mut sreg_sresolver)?;
+ sreg_sresolver.result.push(new_statement);
+ }
+ Statement::SubCC(type_, arg) => {
+ let new_statement = VisitSubCC(type_, arg).visit(&mut sreg_sresolver)?;
+ sreg_sresolver.result.push(new_statement);
+ }
+ s @ Statement::Variable(_)
+ | s @ Statement::Constant(_)
+ | s @ Statement::Label(_)
+ | s @ Statement::FunctionPointer(_) => sreg_sresolver.result.push(s),
+ _ => return Err(TranslateError::unreachable()),
}
- ast::StateSpace::Param => var_type.param_pointer_to(ast::LdStateSpace::Param)?,
- };
- Ok(ast::Variable {
- align: var.align,
- v_type: var.v_type,
- name: id_defs.get_or_add_def_typed(var.name, var_type, is_variable),
- array_init: var.array_init,
- })
+ }
+ Ok(sreg_sresolver.result)
}
-fn translate_function<'a>(
- id_defs: &mut GlobalStringIdResolver<'a>,
- ptx_impl_imports: &mut HashMap<String, Directive<'a>>,
- f: ast::ParsedFunction<'a>,
-) -> Result<Option<Function<'a>>, TranslateError> {
- let import_as = match &f.func_directive {
- ast::MethodDecl::Func(_, "__assertfail", _) => {
- Some("__zluda_ptx_impl____assertfail".to_owned())
- }
- _ => None,
- };
- let (str_resolver, fn_resolver, fn_decl) = id_defs.start_fn(&f.func_directive)?;
- let mut func = to_ssa(ptx_impl_imports, str_resolver, fn_resolver, fn_decl, f.body)?;
- func.import_as = import_as;
- if func.import_as.is_some() {
- ptx_impl_imports.insert(
- func.import_as.as_ref().unwrap().clone(),
- Directive::Method(func),
- );
- Ok(None)
- } else {
- Ok(Some(func))
+struct AdditionalFunctionDeclarations(
+ BTreeMap<
+ String,
+ (
+ Vec<ast::VariableDeclaration<Id>>,
+ Id,
+ Vec<ast::VariableDeclaration<Id>>,
+ ),
+ >,
+);
+
+impl AdditionalFunctionDeclarations {
+ fn new() -> Self {
+ Self(BTreeMap::new())
}
-}
-fn expand_kernel_params<'a, 'b>(
- fn_resolver: &mut FnStringIdResolver<'a, 'b>,
- args: impl Iterator<Item = &'b ast::KernelArgument<&'a str>>,
-) -> Result<Vec<ast::KernelArgument<spirv::Word>>, TranslateError> {
- args.map(|a| {
- Ok(ast::KernelArgument {
- name: fn_resolver.add_def(
- a.name,
- Some(ast::Type::from(a.v_type.clone()).param_pointer_to(ast::LdStateSpace::Param)?),
- false,
- ),
- v_type: a.v_type.clone(),
- align: a.align,
- array_init: Vec::new(),
+ fn add_or_get_declaration<'a>(
+ &mut self,
+ id_defs: &mut IdNameMapBuilder,
+ name: String,
+ return_arguments: impl Iterator<Item = (&'a ast::Type, ast::StateSpace)>,
+ input_arguments: impl Iterator<Item = (&'a ast::Type, ast::StateSpace)>,
+ ) -> Result<Id, TranslateError> {
+ Ok(match self.0.entry(name) {
+ btree_map::Entry::Vacant(entry) => {
+ let fn_id = id_defs.register_intermediate(None);
+ let return_arguments = fn_arguments_to_variables(id_defs, return_arguments);
+ let input_arguments = fn_arguments_to_variables(id_defs, input_arguments);
+ entry.insert((return_arguments, fn_id, input_arguments));
+ fn_id
+ }
+ btree_map::Entry::Occupied(entry) => entry.get().1,
})
- })
- .collect::<Result<_, _>>()
-}
+ }
-fn expand_fn_params<'a, 'b>(
- fn_resolver: &mut FnStringIdResolver<'a, 'b>,
- args: impl Iterator<Item = &'b ast::FnArgument<&'a str>>,
-) -> Result<Vec<ast::FnArgument<spirv::Word>>, TranslateError> {
- args.map(|a| {
- let is_variable = match a.v_type {
- ast::FnArgumentType::Reg(_) => true,
- _ => false,
- };
- let var_type = a.v_type.to_func_type();
- Ok(ast::FnArgument {
- name: fn_resolver.add_def(a.name, Some(var_type), is_variable),
- v_type: a.v_type.clone(),
- align: a.align,
- array_init: Vec::new(),
- })
- })
- .collect()
-}
-
-fn to_ssa<'input, 'b>(
- ptx_impl_imports: &mut HashMap<String, Directive>,
- mut id_defs: FnStringIdResolver<'input, 'b>,
- fn_defs: GlobalFnDeclResolver<'input, 'b>,
- f_args: ast::MethodDecl<'input, spirv::Word>,
- f_body: Option<Vec<ast::Statement<ast::ParsedArgParams<'input>>>>,
-) -> Result<Function<'input>, TranslateError> {
- let mut spirv_decl = SpirvMethodDecl::new(&f_args);
- let f_body = match f_body {
- Some(vec) => vec,
- None => {
- return Ok(Function {
- func_decl: f_args,
+ fn flush<'input, P: ast::ArgParams<Id = Id>>(
+ self,
+ directives: &mut Vec<TranslationDirective<'input, P>>,
+ ) {
+ for (name, (return_arguments, id, input_arguments)) in self.0 {
+ directives.push(TranslationDirective::Method(TranslationMethod {
+ return_arguments,
+ name: id,
+ input_arguments,
body: None,
- globals: Vec::new(),
- import_as: None,
- spirv_decl,
- })
+ tuning: Vec::new(),
+ is_kernel: false,
+ source_name: Some(Cow::Owned(name)),
+ special_raytracing_linking: false,
+ }));
}
- };
- let normalized_ids = normalize_identifiers(&mut id_defs, &fn_defs, f_body)?;
- let mut numeric_id_defs = id_defs.finish();
- let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs)?;
- let typed_statements =
- convert_to_typed_statements(unadorned_statements, &fn_defs, &mut numeric_id_defs)?;
- let typed_statements =
- convert_to_stateful_memory_access(&mut spirv_decl, typed_statements, &mut numeric_id_defs)?;
- let ssa_statements = insert_mem_ssa_statements(
- typed_statements,
- &mut numeric_id_defs,
- &f_args,
- &mut spirv_decl,
- )?;
- let ssa_statements = fix_builtins(ssa_statements, &mut numeric_id_defs)?;
- let mut numeric_id_defs = numeric_id_defs.finish();
- let expanded_statements = expand_arguments(ssa_statements, &mut numeric_id_defs)?;
- let expanded_statements =
- insert_implicit_conversions(expanded_statements, &mut numeric_id_defs)?;
- let mut numeric_id_defs = numeric_id_defs.unmut();
- let labeled_statements = normalize_labels(expanded_statements, &mut numeric_id_defs);
- let (f_body, globals) =
- extract_globals(labeled_statements, ptx_impl_imports, &mut numeric_id_defs);
- Ok(Function {
- func_decl: f_args,
- globals: globals,
- body: Some(f_body),
- import_as: None,
- spirv_decl,
- })
+ }
+}
+
+fn insert_mem_ssa_statements<'input>(
+ module: TranslationModule<'input, TypedArgParams>,
+) -> Result<TranslationModule<'input, TypedArgParams>, TranslateError> {
+ convert_methods(module, insert_mem_ssa_statements_impl)
}
-fn fix_builtins(
+fn insert_mem_ssa_statements_impl<'input>(
+ _: CompilationMode,
+ id_def: &mut IdNameMapBuilder<'input>,
+ _: &mut AdditionalFunctionDeclarations,
+ return_arguments: &mut [ast::VariableDeclaration<Id>],
+ input_arguments: &mut [ast::VariableDeclaration<Id>],
+ is_kernel: bool,
typed_statements: Vec<TypedStatement>,
- numeric_id_defs: &mut NumericIdResolver,
) -> Result<Vec<TypedStatement>, TranslateError> {
let mut result = Vec::with_capacity(typed_statements.len());
- for s in typed_statements {
- match s {
- Statement::LoadVar(
- mut
- details
- @
- LoadVarDetails {
- member_index: Some((_, Some(_))),
- ..
- },
- ) => {
- let index = details.member_index.unwrap().0;
- if index == 3 {
- result.push(Statement::Constant(ConstantDefinition {
- dst: details.arg.dst,
- typ: ast::ScalarType::U32,
- value: ast::ImmediateValue::U64(0),
- }));
- } else {
- let sreg_and_type = match numeric_id_defs.special_registers.get(details.arg.src)
- {
- Some(reg) => get_sreg_id_scalar_type(numeric_id_defs, reg),
- None => None,
- };
- let (sreg_src, scalar_typ, vector_width) = match sreg_and_type {
- Some(sreg_and_type) => sreg_and_type,
- None => {
- result.push(Statement::LoadVar(details));
- continue;
- }
- };
- let temp_id = numeric_id_defs.new_non_variable(Some(details.typ.clone()));
- let real_dst = details.arg.dst;
- details.arg.dst = temp_id;
- result.push(Statement::LoadVar(LoadVarDetails {
- arg: Arg2 {
- src: sreg_src,
- dst: temp_id,
- },
- typ: ast::Type::Scalar(scalar_typ),
- member_index: Some((index, Some(vector_width))),
- }));
- result.push(Statement::Conversion(ImplicitConversion {
- src: temp_id,
- dst: real_dst,
- from: ast::Type::Scalar(scalar_typ),
- to: ast::Type::Scalar(ast::ScalarType::U32),
- kind: ConversionKind::Default,
- src_sema: ArgumentSemantics::Default,
- dst_sema: ArgumentSemantics::Default,
- }));
+ if !is_kernel {
+ for arg in input_arguments.iter_mut() {
+ insert_mem_ssa_argument(id_def, &mut result, arg);
+ }
+ }
+ for arg in return_arguments.iter() {
+ insert_mem_ssa_argument_reg_return(&mut result, arg);
+ }
+ for statement in typed_statements {
+ match statement {
+ Statement::Call(call) => {
+ insert_mem_ssa_statement_default(id_def, &mut result, call.cast())?
+ }
+ Statement::Instruction(inst) => match inst {
+ ast::Instruction::Exit => {
+ insert_mma_ssa_statement_ret(
+ return_arguments,
+ &mut result,
+ ast::Instruction::Exit,
+ ast::RetData { uniform: false },
+ id_def,
+ );
}
+ ast::Instruction::Ret(d) => {
+ insert_mma_ssa_statement_ret(
+ return_arguments,
+ &mut result,
+ ast::Instruction::Ret(d),
+ d,
+ id_def,
+ );
+ }
+ inst => insert_mem_ssa_statement_default(id_def, &mut result, inst)?,
+ },
+ Statement::Conditional(bra) => {
+ insert_mem_ssa_statement_default(id_def, &mut result, bra)?
}
- s => result.push(s),
+ Statement::Conversion(conv) => {
+ insert_mem_ssa_statement_default(id_def, &mut result, conv)?
+ }
+ Statement::PtrAccess(ptr_access) => {
+ insert_mem_ssa_statement_default(id_def, &mut result, ptr_access)?
+ }
+ Statement::RepackVector(repack) => {
+ insert_mem_ssa_statement_default(id_def, &mut result, repack)?
+ }
+ Statement::FunctionPointer(func_ptr) => {
+ insert_mem_ssa_statement_default(id_def, &mut result, func_ptr)?
+ }
+ Statement::MadC(madc) => insert_mem_ssa_statement_default(id_def, &mut result, madc)?,
+ Statement::MadCC(madcc) => {
+ insert_mem_ssa_statement_default(id_def, &mut result, madcc)?
+ }
+ Statement::AddC(details, arg) => {
+ insert_mem_ssa_statement_default(id_def, &mut result, VisitAddC(details, arg))?
+ }
+ Statement::AddCC(type_, arg) => {
+ insert_mem_ssa_statement_default(id_def, &mut result, VisitAddCC(type_, arg))?
+ }
+ Statement::SubC(details, arg) => {
+ insert_mem_ssa_statement_default(id_def, &mut result, VisitSubC(details, arg))?
+ }
+ Statement::SubCC(type_, arg) => {
+ insert_mem_ssa_statement_default(id_def, &mut result, VisitSubCC(type_, arg))?
+ }
+ s @ Statement::Variable(_) | s @ Statement::Label(_) | s @ Statement::Constant(..) => {
+ result.push(s)
+ }
+ _ => return Err(TranslateError::unreachable()),
}
}
Ok(result)
}
-fn get_sreg_id_scalar_type(
- numeric_id_defs: &mut NumericIdResolver,
- sreg: PtxSpecialRegister,
-) -> Option<(spirv::Word, ast::ScalarType, u8)> {
- match sreg.normalized_sreg_and_type() {
- Some((normalized_sreg, typ, vec_width)) => Some((
- numeric_id_defs
- .special_registers
- .get_or_add(numeric_id_defs.current_id, normalized_sreg),
- typ,
- vec_width,
- )),
- None => None,
+fn insert_mma_ssa_statement_ret(
+ return_arguments: &mut [ast::VariableDeclaration<Id>],
+ result: &mut Vec<Statement<ast::Instruction<TypedArgParams>, TypedArgParams>>,
+ zero_inst: ast::Instruction<TypedArgParams>,
+ d: ast::RetData,
+ id_def: &mut IdNameMapBuilder<'_>,
+) {
+ if return_arguments.len() == 0 {
+ result.push(Statement::Instruction(zero_inst));
+ } else {
+ let return_ids = return_arguments
+ .iter()
+ .map(|return_reg| {
+ let new_id = id_def
+ .register_intermediate(Some((return_reg.type_.clone(), ast::StateSpace::Reg)));
+ result.push(Statement::LoadVar(LoadVarDetails {
+ arg: ast::Arg2 {
+ dst: new_id,
+ src: return_reg.name,
+ },
+ // TODO: ret with stateful conversion
+ _state_space: ast::StateSpace::Reg,
+ typ: return_reg.type_.clone(),
+ member_index: None,
+ }));
+ (new_id, return_reg.type_.clone())
+ })
+ .collect::<Vec<_>>();
+ result.push(Statement::RetValue(d, return_ids));
}
}
-fn extract_globals<'input, 'b>(
- sorted_statements: Vec<ExpandedStatement>,
- ptx_impl_imports: &mut HashMap<String, Directive>,
- id_def: &mut NumericIdResolver,
-) -> (
- Vec<ExpandedStatement>,
- Vec<ast::Variable<ast::VariableType, spirv::Word>>,
-) {
- let mut local = Vec::with_capacity(sorted_statements.len());
- let mut global = Vec::new();
- for statement in sorted_statements {
+fn expand_arguments<'input>(
+ module: TranslationModule<'input, TypedArgParams>,
+) -> Result<TranslationModule<'input, ExpandedArgParams>, TranslateError> {
+ convert_methods_simple(module, expand_arguments2_impl)
+}
+
+fn expand_arguments2_impl<'input>(
+ id_defs: &mut IdNameMapBuilder<'input>,
+ fn_body: Vec<TypedStatement>,
+) -> Result<Vec<ExpandedStatement>, TranslateError> {
+ let mut result = Vec::with_capacity(fn_body.len());
+ for statment in fn_body {
+ match statment {
+ Statement::Call(call) => {
+ let mut visitor = FlattenArguments::new(&mut result, id_defs);
+ let (new_call, post_stmts) = (call.map(&mut visitor)?, visitor.post_stmts);
+ result.push(Statement::Call(new_call));
+ result.extend(post_stmts);
+ }
+ Statement::Instruction(inst) => {
+ let mut visitor = FlattenArguments::new(&mut result, id_defs);
+ let (new_inst, post_stmts) = (inst.map(&mut visitor)?, visitor.post_stmts);
+ result.push(Statement::Instruction(new_inst));
+ result.extend(post_stmts);
+ }
+ Statement::Variable(Variable {
+ align,
+ type_,
+ state_space,
+ name,
+ initializer,
+ }) => result.push(Statement::Variable(Variable {
+ align,
+ type_,
+ state_space,
+ name,
+ initializer,
+ })),
+ Statement::PtrAccess(ptr_access) => {
+ let mut visitor = FlattenArguments::new(&mut result, id_defs);
+ let (new_inst, post_stmts) = (ptr_access.map(&mut visitor)?, visitor.post_stmts);
+ result.push(Statement::PtrAccess(new_inst));
+ result.extend(post_stmts);
+ }
+ Statement::RepackVector(repack) => {
+ let mut visitor = FlattenArguments::new(&mut result, id_defs);
+ let (new_inst, post_stmts) = (repack.map(&mut visitor)?, visitor.post_stmts);
+ result.push(Statement::RepackVector(new_inst));
+ result.extend(post_stmts);
+ }
+ Statement::MadC(madc) => {
+ let mut visitor = FlattenArguments::new(&mut result, id_defs);
+ let (new_inst, post_stmts) = (madc.visit(&mut visitor)?, visitor.post_stmts);
+ result.push(new_inst);
+ result.extend(post_stmts);
+ }
+ Statement::MadCC(madcc) => {
+ let mut visitor = FlattenArguments::new(&mut result, id_defs);
+ let (new_inst, post_stmts) = (madcc.visit(&mut visitor)?, visitor.post_stmts);
+ result.push(new_inst);
+ result.extend(post_stmts);
+ }
+ Statement::AddC(details, arg) => {
+ let mut visitor = FlattenArguments::new(&mut result, id_defs);
+ let (new_inst, post_stmts) = (
+ VisitAddC(details, arg).visit(&mut visitor)?,
+ visitor.post_stmts,
+ );
+ result.push(new_inst);
+ result.extend(post_stmts);
+ }
+ Statement::AddCC(type_, arg) => {
+ let mut visitor = FlattenArguments::new(&mut result, id_defs);
+ let (new_inst, post_stmts) = (
+ VisitAddCC(type_, arg).visit(&mut visitor)?,
+ visitor.post_stmts,
+ );
+ result.push(new_inst);
+ result.extend(post_stmts);
+ }
+ Statement::SubC(details, arg) => {
+ let mut visitor = FlattenArguments::new(&mut result, id_defs);
+ let (new_inst, post_stmts) = (
+ VisitSubC(details, arg).visit(&mut visitor)?,
+ visitor.post_stmts,
+ );
+ result.push(new_inst);
+ result.extend(post_stmts);
+ }
+ Statement::SubCC(type_, arg) => {
+ let mut visitor = FlattenArguments::new(&mut result, id_defs);
+ let (new_inst, post_stmts) = (
+ VisitSubCC(type_, arg).visit(&mut visitor)?,
+ visitor.post_stmts,
+ );
+ result.push(new_inst);
+ result.extend(post_stmts);
+ }
+ Statement::Label(id) => result.push(Statement::Label(id)),
+ Statement::Conditional(bra) => result.push(Statement::Conditional(bra)),
+ Statement::LoadVar(details) => result.push(Statement::LoadVar(details)),
+ Statement::StoreVar(details) => result.push(Statement::StoreVar(details)),
+ Statement::RetValue(d, id) => result.push(Statement::RetValue(d, id)),
+ Statement::Conversion(conv) => result.push(Statement::Conversion(conv)),
+ Statement::Constant(c) => result.push(Statement::Constant(c)),
+ Statement::FunctionPointer(d) => result.push(Statement::FunctionPointer(d)),
+ Statement::AsmVolatile { asm, constraints } => {
+ result.push(Statement::AsmVolatile { asm, constraints })
+ }
+ }
+ }
+ Ok(result)
+}
+
+/*
+ There are several kinds of implicit conversions in PTX:
+ * auto-bitcast: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#type-information-for-instructions-and-operands
+ * special ld/st/cvt conversion rules: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size
+ - ld.param: not documented, but for instruction `ld.param.<type> x, [y]`,
+ semantics are to first zext/chop/bitcast `y` as needed and then do
+ documented special ld/st/cvt conversion rules for destination operands
+ - st.param [x] y (used as function return arguments) same rule as above applies
+ - generic/global ld: for instruction `ld x, [y]`, y must be of type
+ b64/u64/s64, which is bitcast to a pointer, dereferenced and then
+ documented special ld/st/cvt conversion rules are applied to dst
+ - generic/global st: for instruction `st [x], y`, x must be of type
+ b64/u64/s64, which is bitcast to a pointer
+ * and many more
+*/
+fn insert_implicit_conversions<'input>(
+ module: TranslationModule<'input, ExpandedArgParams>,
+) -> Result<TranslationModule<'input, ExpandedArgParams>, TranslateError> {
+ convert_methods_simple(module, insert_implicit_conversions2_impl)
+}
+
+fn insert_implicit_conversions2_impl<'input>(
+ id_def: &mut IdNameMapBuilder<'input>,
+ fn_body: Vec<ExpandedStatement>,
+) -> Result<Vec<ExpandedStatement>, TranslateError> {
+ let mut result = Vec::with_capacity(fn_body.len());
+ for statement in fn_body {
match statement {
- Statement::Variable(
- var
- @
- ast::Variable {
- v_type: ast::VariableType::Shared(_),
- ..
+ Statement::Call(call) => {
+ insert_implicit_conversions_impl(&mut result, id_def, call)?;
+ }
+ Statement::Instruction(inst) => {
+ insert_implicit_conversions_impl(&mut result, id_def, inst)?;
+ }
+ Statement::PtrAccess(access) => {
+ insert_implicit_conversions_impl(&mut result, id_def, access)?;
+ }
+ Statement::RepackVector(repack) => {
+ insert_implicit_conversions_impl(&mut result, id_def, repack)?;
+ }
+ Statement::MadC(madc) => {
+ insert_implicit_conversions_impl(&mut result, id_def, madc)?;
+ }
+ Statement::MadCC(madcc) => {
+ insert_implicit_conversions_impl(&mut result, id_def, madcc)?;
+ }
+ Statement::AddC(details, arg) => {
+ insert_implicit_conversions_impl(&mut result, id_def, VisitAddC(details, arg))?;
+ }
+ Statement::AddCC(type_, arg) => {
+ insert_implicit_conversions_impl(&mut result, id_def, VisitAddCC(type_, arg))?;
+ }
+ Statement::SubC(details, arg) => {
+ insert_implicit_conversions_impl(&mut result, id_def, VisitSubC(details, arg))?;
+ }
+ Statement::SubCC(type_, arg) => {
+ insert_implicit_conversions_impl(&mut result, id_def, VisitSubCC(type_, arg))?;
+ }
+ s @ Statement::Conditional(_)
+ | s @ Statement::Conversion(_)
+ | s @ Statement::Label(_)
+ | s @ Statement::Constant(_)
+ | s @ Statement::Variable(_)
+ | s @ Statement::LoadVar(..)
+ | s @ Statement::StoreVar(..)
+ | s @ Statement::RetValue(..)
+ | s @ Statement::AsmVolatile { .. }
+ | s @ Statement::FunctionPointer(..) => result.push(s),
+ }
+ }
+ Ok(result)
+}
+
+fn normalize_labels<'input>(
+ module: TranslationModule<'input, ExpandedArgParams>,
+) -> Result<TranslationModule<'input, ExpandedArgParams>, TranslateError> {
+ convert_methods_simple(module, normalize_labels2_impl)
+}
+
+fn normalize_labels2_impl<'input>(
+ id_defs: &mut IdNameMapBuilder<'input>,
+ fn_body: Vec<ExpandedStatement>,
+) -> Result<Vec<ExpandedStatement>, TranslateError> {
+ let mut labels_in_use = FxHashSet::default();
+ for statement in fn_body.iter() {
+ match statement {
+ Statement::Instruction(i) => {
+ if let Some(target) = i.jump_target() {
+ labels_in_use.insert(target);
+ }
+ }
+ Statement::Conditional(cond) => {
+ labels_in_use.insert(cond.if_true);
+ labels_in_use.insert(cond.if_false);
+ }
+ Statement::Call(..)
+ | Statement::Variable(..)
+ | Statement::LoadVar(..)
+ | Statement::StoreVar(..)
+ | Statement::RetValue(..)
+ | Statement::Conversion(..)
+ | Statement::Constant(..)
+ | Statement::Label(..)
+ | Statement::PtrAccess { .. }
+ | Statement::RepackVector(..)
+ | Statement::MadC(..)
+ | Statement::MadCC(..)
+ | Statement::AddC(..)
+ | Statement::AddCC(..)
+ | Statement::SubC(..)
+ | Statement::SubCC(..)
+ | Statement::AsmVolatile { .. }
+ | Statement::FunctionPointer(..) => {}
+ }
+ }
+ Ok(
+ iter::once(Statement::Label(id_defs.register_intermediate(None)))
+ .chain(fn_body.into_iter().filter(|s| match s {
+ Statement::Label(i) => labels_in_use.contains(i),
+ _ => true,
+ }))
+ .collect::<Vec<_>>(),
+ )
+}
+
+fn hoist_globals<'input, P: ast::ArgParams<Id = Id>>(
+ module: TranslationModule<'input, P>,
+) -> TranslationModule<'input, P> {
+ let mut directives = Vec::with_capacity(module.directives.len());
+ for directive in module.directives {
+ match directive {
+ var @ TranslationDirective::Variable(..) => directives.push(var),
+ TranslationDirective::Method(method) => {
+ let body = method.body.map(|body| {
+ body.into_iter()
+ .filter_map(|statement| match statement {
+ Statement::Variable(
+ var @ Variable {
+ state_space: ast::StateSpace::Shared,
+ ..
+ },
+ )
+ | Statement::Variable(
+ var @ Variable {
+ state_space: ast::StateSpace::Global,
+ ..
+ },
+ ) => {
+ directives.push(TranslationDirective::Variable(
+ ast::LinkingDirective::None,
+ None,
+ var,
+ ));
+ None
+ }
+ statement => Some(statement),
+ })
+ .collect::<Vec<_>>()
+ });
+ directives.push(TranslationDirective::Method(TranslationMethod {
+ body,
+ ..method
+ }))
+ }
+ }
+ }
+ {
+ TranslationModule {
+ directives,
+ ..module
+ }
+ }
+}
+
+fn replace_instructions_with_builtins<'input>(
+ module: TranslationModule<'input, ExpandedArgParams>,
+) -> Result<TranslationModule<'input, ExpandedArgParams>, TranslateError> {
+ convert_methods(module, replace_instructions_with_builtins_impl)
+}
+
+fn replace_instructions_with_builtins_impl<'input>(
+ compilation_mode: CompilationMode,
+ id_def: &mut IdNameMapBuilder<'input>,
+ ptx_impl_imports: &mut AdditionalFunctionDeclarations,
+ _: &mut [ast::VariableDeclaration<Id>],
+ _: &mut [ast::VariableDeclaration<Id>],
+ _: bool,
+ fn_body: Vec<ExpandedStatement>,
+) -> Result<Vec<ExpandedStatement>, TranslateError> {
+ let mut statements = Vec::with_capacity(fn_body.len());
+ for statement in fn_body {
+ match statement {
+ Statement::Instruction(ast::Instruction::Nanosleep(arg)) => {
+ let fn_name = [ZLUDA_PTX_PREFIX, "nanosleep_u32"].concat();
+ statements.push(instruction_to_fn_call(
+ id_def,
+ ptx_impl_imports,
+ ast::Instruction::Nanosleep(arg),
+ fn_name,
+ )?);
+ }
+ Statement::Instruction(ast::Instruction::MatchAny(arg)) => {
+ let fn_name = [ZLUDA_PTX_PREFIX, "match_any_sync_b32"].concat();
+ statements.push(instruction_to_fn_call(
+ id_def,
+ ptx_impl_imports,
+ ast::Instruction::MatchAny(arg),
+ fn_name,
+ )?);
+ }
+ Statement::Instruction(ast::Instruction::Dp4a(type_, arg)) => {
+ let fn_name = [
+ ZLUDA_PTX_PREFIX,
+ "dp4a_",
+ type_.to_ptx_name(),
+ "_",
+ type_.to_ptx_name(),
+ ]
+ .concat();
+ statements.push(instruction_to_fn_call(
+ id_def,
+ ptx_impl_imports,
+ ast::Instruction::Dp4a(type_, arg),
+ fn_name,
+ )?);
+ }
+ Statement::Instruction(ast::Instruction::BarRed(op, arg3)) => {
+ let op_name = match op {
+ ast::ReductionOp::And => "and",
+ ast::ReductionOp::Or => "or",
+ ast::ReductionOp::Popc => "popc",
+ };
+ let dst_type = op.dst_type().to_ptx_name();
+ let fn_name = [ZLUDA_PTX_PREFIX, "bar_red_", op_name, "_", dst_type].concat();
+ statements.push(instruction_to_fn_call(
+ id_def,
+ ptx_impl_imports,
+ ast::Instruction::BarRed(op, arg3),
+ fn_name,
+ )?);
+ }
+ // We dispatch vote_sync_... by compilation mode suffix
+ // because LLVM crashes if there are both calls to `llvm.amdgcn.ballot.i32` and
+ // `llvm.amdgcn.ballot.i64` inside same function (even if only one of them can be called)
+ Statement::Instruction(ast::Instruction::Vote(
+ ast::VoteDetails {
+ mode: ast::VoteMode::Any,
+ negate_pred,
},
- )
- | Statement::Variable(
- var
- @
- ast::Variable {
- v_type: ast::VariableType::Global(_),
- ..
+ arg,
+ )) => {
+ let instr_suffix = if negate_pred { "_negate" } else { "" };
+ let mode_suffix = compilation_mode_suffix(compilation_mode);
+ let fn_name = [
+ ZLUDA_PTX_PREFIX,
+ "vote_sync_any_pred",
+ instr_suffix,
+ mode_suffix,
+ ]
+ .concat();
+ statements.push(instruction_to_fn_call(
+ id_def,
+ ptx_impl_imports,
+ ast::Instruction::Vote(
+ ast::VoteDetails {
+ mode: ast::VoteMode::Any,
+ negate_pred,
+ },
+ arg,
+ ),
+ fn_name,
+ )?);
+ }
+ Statement::Instruction(ast::Instruction::Vote(
+ ast::VoteDetails {
+ mode: ast::VoteMode::All,
+ negate_pred,
},
- ) => global.push(var),
+ arg,
+ )) => {
+ let instr_suffix = if negate_pred { "_negate" } else { "" };
+ let mode_suffix = compilation_mode_suffix(compilation_mode);
+ let fn_name = [
+ ZLUDA_PTX_PREFIX,
+ "vote_sync_all_pred",
+ instr_suffix,
+ mode_suffix,
+ ]
+ .concat();
+ statements.push(instruction_to_fn_call(
+ id_def,
+ ptx_impl_imports,
+ ast::Instruction::Vote(
+ ast::VoteDetails {
+ mode: ast::VoteMode::All,
+ negate_pred,
+ },
+ arg,
+ ),
+ fn_name,
+ )?);
+ }
+ Statement::Instruction(ast::Instruction::Vote(
+ ast::VoteDetails {
+ mode: ast::VoteMode::Ballot,
+ negate_pred,
+ },
+ arg,
+ )) => {
+ let instr_suffix = if negate_pred { "_negate" } else { "" };
+ let mode_suffix = compilation_mode_suffix(compilation_mode);
+ let fn_name = [
+ ZLUDA_PTX_PREFIX,
+ "vote_sync_ballot_b32",
+ instr_suffix,
+ mode_suffix,
+ ]
+ .concat();
+ statements.push(instruction_to_fn_call(
+ id_def,
+ ptx_impl_imports,
+ ast::Instruction::Vote(
+ ast::VoteDetails {
+ mode: ast::VoteMode::Ballot,
+ negate_pred,
+ },
+ arg,
+ ),
+ fn_name,
+ )?);
+ }
+ Statement::Instruction(ast::Instruction::Bar(details, arg)) => {
+ let fn_name = [ZLUDA_PTX_PREFIX, "barrier_sync"].concat();
+ statements.push(instruction_to_fn_call(
+ id_def,
+ ptx_impl_imports,
+ ast::Instruction::Bar(details, arg),
+ fn_name,
+ )?);
+ }
Statement::Instruction(ast::Instruction::Bfe { typ, arg }) => {
- local.push(to_ptx_impl_bfe_call(id_def, ptx_impl_imports, typ, arg));
+ let fn_name = [ZLUDA_PTX_PREFIX, "bfe_", typ.to_ptx_name()].concat();
+ statements.push(instruction_to_fn_call(
+ id_def,
+ ptx_impl_imports,
+ ast::Instruction::Bfe { typ, arg },
+ fn_name,
+ )?);
+ }
+ Statement::Instruction(ast::Instruction::Bfi { typ, arg }) => {
+ let fn_name = [ZLUDA_PTX_PREFIX, "bfi_", typ.to_ptx_name()].concat();
+ statements.push(instruction_to_fn_call(
+ id_def,
+ ptx_impl_imports,
+ ast::Instruction::Bfi { typ, arg },
+ fn_name,
+ )?);
+ }
+ Statement::Instruction(ast::Instruction::Activemask { arg }) => {
+ let fn_name = [ZLUDA_PTX_PREFIX, "activemask"].concat();
+ statements.push(instruction_to_fn_call(
+ id_def,
+ ptx_impl_imports,
+ ast::Instruction::Activemask { arg },
+ fn_name,
+ )?);
+ }
+ Statement::Instruction(ast::Instruction::Tex(tex, arg)) => {
+ let geometry = tex.geometry.as_ptx();
+ let fn_name = [
+ ZLUDA_PTX_PREFIX,
+ "tex",
+ tex.suffix(),
+ "_",
+ geometry,
+ "_v4",
+ "_",
+ tex.channel_type.to_ptx_name(),
+ "_",
+ tex.coordinate_type.to_ptx_name(),
+ ]
+ .concat();
+ statements.push(instruction_to_fn_call(
+ id_def,
+ ptx_impl_imports,
+ ast::Instruction::Tex(tex, arg),
+ fn_name,
+ )?);
+ }
+ Statement::Instruction(ast::Instruction::Shfl(shfl_mode, arg))
+ if arg.dst2.is_none() =>
+ {
+ let fn_name = [
+ ZLUDA_PTX_PREFIX,
+ "shfl_",
+ shfl_mode.to_ptx_name(),
+ "_b32_slow",
+ ]
+ .concat();
+ statements.push(instruction_to_fn_call(
+ id_def,
+ ptx_impl_imports,
+ ast::Instruction::Shfl(shfl_mode, arg),
+ fn_name,
+ )?);
+ }
+ Statement::Instruction(ast::Instruction::Shfl(shfl_mode, arg))
+ if arg.dst2.is_some() =>
+ {
+ replace_shfl_with_pred(id_def, ptx_impl_imports, &mut statements, shfl_mode, arg)?;
+ }
+ Statement::Instruction(ast::Instruction::Cvt(
+ ast::CvtDetails::FloatFromFloat(ast::CvtDesc {
+ saturate: true,
+ src,
+ dst,
+ rounding,
+ flush_to_zero,
+ }),
+ arg,
+ )) if src == dst => {
+ let fn_name = [
+ ZLUDA_PTX_PREFIX,
+ "cvt_sat_",
+ dst.to_ptx_name(),
+ "_",
+ dst.to_ptx_name(),
+ ]
+ .concat();
+ statements.push(instruction_to_fn_call(
+ id_def,
+ ptx_impl_imports,
+ ast::Instruction::Cvt(
+ ast::CvtDetails::FloatFromFloat(ast::CvtDesc {
+ saturate: true,
+ src,
+ dst,
+ rounding,
+ flush_to_zero,
+ }),
+ arg,
+ ),
+ fn_name,
+ )?);
+ }
+ Statement::Instruction(ast::Instruction::Shf(
+ ast::FunnelShift {
+ direction,
+ mode: ast::ShiftNormalization::Clamp,
+ },
+ arg,
+ )) => {
+ let direction_str = match direction {
+ ast::FunnelDirection::Left => "l",
+ ast::FunnelDirection::Right => "r",
+ };
+ let fn_name = [ZLUDA_PTX_PREFIX, "shf_", direction_str, "_clamp_b32"].concat();
+ statements.push(instruction_to_fn_call(
+ id_def,
+ ptx_impl_imports,
+ ast::Instruction::Shf(
+ ast::FunnelShift {
+ direction,
+ mode: ast::ShiftNormalization::Clamp,
+ },
+ arg,
+ ),
+ fn_name,
+ )?);
+ }
+ Statement::Instruction(ast::Instruction::Suld(suld, arg)) => {
+ let geometry = suld.geometry.as_ptx();
+ let vector = suld.vector_ptx()?;
+ let fn_name = [
+ ZLUDA_PTX_PREFIX,
+ "suld_b_",
+ suld.suffix(),
+ geometry,
+ vector,
+ "_",
+ suld.type_.to_ptx_name(),
+ "_trap",
+ ]
+ .concat();
+ statements.push(instruction_to_fn_call(
+ id_def,
+ ptx_impl_imports,
+ ast::Instruction::Suld(suld, arg),
+ fn_name,
+ )?);
+ }
+ Statement::Instruction(ast::Instruction::Sust(sust, arg)) => {
+ let geometry = sust.geometry.as_ptx();
+ let vector = sust.vector_ptx()?;
+ let fn_name = [
+ ZLUDA_PTX_PREFIX,
+ "sust_b_",
+ sust.suffix(),
+ geometry,
+ vector,
+ "_",
+ sust.type_.to_ptx_name(),
+ "_trap",
+ ]
+ .concat();
+ statements.push(instruction_to_fn_call(
+ id_def,
+ ptx_impl_imports,
+ ast::Instruction::Sust(sust, arg),
+ fn_name,
+ )?);
}
Statement::Instruction(ast::Instruction::Atom(
- d
- @
- ast::AtomDetails {
+ details @ ast::AtomDetails {
inner:
ast::AtomInnerDetails::Unsigned {
op: ast::AtomUIntOp::Inc,
@@ -1462,20 +2927,28 @@ fn extract_globals<'input, 'b>(
},
..
},
- a,
+ args,
)) => {
- local.push(to_ptx_impl_atomic_call(
+ let fn_name = [
+ ZLUDA_PTX_PREFIX,
+ "atom_",
+ details.semantics.to_ptx_name(),
+ "_",
+ details.scope.to_ptx_name(),
+ "_",
+ details.space.to_ptx_name(),
+ "_inc",
+ ]
+ .concat();
+ statements.push(instruction_to_fn_call(
id_def,
ptx_impl_imports,
- d,
- a,
- "inc",
- ));
+ ast::Instruction::Atom(details, args),
+ fn_name,
+ )?);
}
Statement::Instruction(ast::Instruction::Atom(
- d
- @
- ast::AtomDetails {
+ details @ ast::AtomDetails {
inner:
ast::AtomInnerDetails::Unsigned {
op: ast::AtomUIntOp::Dec,
@@ -1483,114 +2956,1408 @@ fn extract_globals<'input, 'b>(
},
..
},
- a,
+ args,
)) => {
- local.push(to_ptx_impl_atomic_call(
+ let fn_name = [
+ ZLUDA_PTX_PREFIX,
+ "atom_",
+ details.semantics.to_ptx_name(),
+ "_",
+ details.scope.to_ptx_name(),
+ "_",
+ details.space.to_ptx_name(),
+ "_dec",
+ ]
+ .concat();
+ statements.push(instruction_to_fn_call(
id_def,
ptx_impl_imports,
- d,
- a,
- "dec",
- ));
+ ast::Instruction::Atom(details, args),
+ fn_name,
+ )?);
+ }
+ Statement::Instruction(ast::Instruction::Cvt(
+ ast::CvtDetails::FloatFromInt(desc),
+ args,
+ )) => extract_global_cvt(
+ &mut statements,
+ ptx_impl_imports,
+ id_def,
+ desc.clone(),
+ ast::Instruction::Cvt(ast::CvtDetails::FloatFromInt(desc), args),
+ )?,
+ Statement::Instruction(ast::Instruction::Cvt(
+ ast::CvtDetails::IntFromFloat(desc),
+ args,
+ )) => extract_global_cvt(
+ &mut statements,
+ ptx_impl_imports,
+ id_def,
+ desc.clone(),
+ ast::Instruction::Cvt(ast::CvtDetails::IntFromFloat(desc), args),
+ )?,
+ Statement::Instruction(ast::Instruction::Cvt(
+ ast::CvtDetails::FloatFromFloat(desc),
+ args,
+ )) if desc.dst.size_of() < desc.src.size_of() => extract_global_cvt(
+ &mut statements,
+ ptx_impl_imports,
+ id_def,
+ desc.clone(),
+ ast::Instruction::Cvt(ast::CvtDetails::FloatFromFloat(desc), args),
+ )?,
+ Statement::Instruction(ast::Instruction::Mul(
+ ast::MulDetails::Signed(ast::MulInt {
+ control: ast::MulIntControl::High,
+ typ,
+ }),
+ args,
+ ))
+ | Statement::Instruction(ast::Instruction::Mul(
+ ast::MulDetails::Unsigned(ast::MulInt {
+ control: ast::MulIntControl::High,
+ typ,
+ }),
+ args,
+ )) if typ == ast::ScalarType::U64 || typ == ast::ScalarType::S64 => {
+ let fn_name = [ZLUDA_PTX_PREFIX, "mul_hi_", typ.to_ptx_name()].concat();
+ statements.push(instruction_to_fn_call(
+ id_def,
+ ptx_impl_imports,
+ ast::Instruction::Mul(
+ ast::MulDetails::Signed(ast::MulInt {
+ control: ast::MulIntControl::High,
+ typ,
+ }),
+ args,
+ ),
+ fn_name,
+ )?);
+ }
+ Statement::Instruction(ast::Instruction::Mad(
+ ast::MulDetails::Signed(ast::MulInt {
+ control: ast::MulIntControl::High,
+ typ,
+ }),
+ args,
+ ))
+ | Statement::Instruction(ast::Instruction::Mad(
+ ast::MulDetails::Unsigned(ast::MulInt {
+ control: ast::MulIntControl::High,
+ typ,
+ }),
+ args,
+ )) if typ == ast::ScalarType::U64 || typ == ast::ScalarType::S64 => {
+ let fn_name = [ZLUDA_PTX_PREFIX, "mad_hi_", typ.to_ptx_name()].concat();
+ statements.push(instruction_to_fn_call(
+ id_def,
+ ptx_impl_imports,
+ ast::Instruction::Mad(
+ ast::MulDetails::Signed(ast::MulInt {
+ control: ast::MulIntControl::High,
+ typ,
+ }),
+ args,
+ ),
+ fn_name,
+ )?);
}
- s => local.push(s),
+ s => statements.push(s),
}
}
- (local, global)
+ Ok(statements)
}
-fn normalize_variable_decls(directives: &mut Vec<Directive>) {
- for directive in directives {
+fn compilation_mode_suffix(compilation_mode: CompilationMode) -> &'static str {
+ match compilation_mode {
+ CompilationMode::Wave32 => "_32",
+ CompilationMode::Wave32OnWave64 => "_32on64",
+ CompilationMode::DoubleWave32OnWave64 => "_double32on64",
+ }
+}
+
+fn replace_shfl_with_pred<'input>(
+ id_defs: &mut IdNameMapBuilder<'input>,
+ ptx_impl_imports: &mut AdditionalFunctionDeclarations,
+ statements: &mut Vec<ExpandedStatement>,
+ shfl_mode: ast::ShflMode,
+ arg: ast::Arg5Shfl<ExpandedArgParams>,
+) -> Result<(), TranslateError> {
+ let fn_name = [
+ ZLUDA_PTX_PREFIX,
+ "shfl_",
+ shfl_mode.to_ptx_name(),
+ "_b32_pred_slow",
+ ]
+ .concat();
+ let inst = ast::Instruction::Shfl(shfl_mode, arg);
+ let mut arguments = Vec::new();
+ inst.visit(
+ &mut |desc: ArgumentDescriptor<Id>, typ: Option<(&ast::Type, ast::StateSpace)>| {
+ let (typ, space) = match typ {
+ Some((typ, space)) => (typ.clone(), space),
+ None => return Err(TranslateError::unreachable()),
+ };
+ let result = desc.op;
+ arguments.push((desc, typ, space));
+ Ok(result)
+ },
+ )?;
+ let return_arguments_count = arguments
+ .iter()
+ .position(|(desc, _, _)| !desc.is_dst)
+ .unwrap_or(arguments.len());
+ let (original_return_arguments, input_arguments) =
+ arguments.split_at_mut(return_arguments_count);
+ // Builtin call returns <2 x i32>, we have to unpack the vector and insert
+ // conversion for predicate
+ let call_return_arguments = [(
+ id_defs.register_intermediate(Some((
+ ast::Type::Vector(ast::ScalarType::U32, 2),
+ ast::StateSpace::Reg,
+ ))),
+ ast::Type::Vector(ast::ScalarType::U32, 2),
+ ast::StateSpace::Reg,
+ )];
+ let fn_id = ptx_impl_imports.add_or_get_declaration(
+ id_defs,
+ fn_name,
+ call_return_arguments
+ .iter()
+ .map(|(_, typ, state)| (typ, *state)),
+ input_arguments.iter().map(|(_, typ, state)| (typ, *state)),
+ )?;
+ statements.push(Statement::Call(ResolvedCall {
+ uniform: false,
+ name: fn_id,
+ return_arguments: call_return_arguments.to_vec(),
+ input_arguments: arguments_to_resolved_arguments(input_arguments),
+ is_indirect: false,
+ }));
+ let unpacked_elements = [
+ original_return_arguments[0].0.op,
+ id_defs.register_intermediate(Some((
+ ast::Type::Scalar(ast::ScalarType::U32),
+ ast::StateSpace::Reg,
+ ))),
+ ];
+ statements.push(Statement::RepackVector(RepackVectorDetails {
+ is_extract: true,
+ typ: ast::ScalarType::U32,
+ packed: call_return_arguments[0].0,
+ unpacked: unpacked_elements.to_vec(),
+ non_default_implicit_conversion: None,
+ }));
+ let constant_1 = id_defs.register_intermediate(Some((
+ ast::Type::Scalar(ast::ScalarType::U32),
+ ast::StateSpace::Reg,
+ )));
+ statements.push(Statement::Constant(ConstantDefinition {
+ dst: constant_1,
+ typ: ast::ScalarType::U32,
+ value: ast::ImmediateValue::U64(1),
+ }));
+ statements.push(Statement::Instruction(ast::Instruction::Setp(
+ ast::SetpData {
+ typ: ast::ScalarType::U32,
+ flush_to_zero: None,
+ cmp_op: ast::SetpCompareOp::Eq,
+ },
+ ast::Arg4Setp {
+ dst1: original_return_arguments[1].0.op,
+ dst2: None,
+ src1: unpacked_elements[1],
+ src2: constant_1,
+ },
+ )));
+ Ok(())
+}
+
+// HACK ALERT!
+// This function is a "good enough" heuristic of whetever to mark f16/f32 operations
+// in the kernel as flushing denorms to zero or preserving them
+// PTX support per-instruction ftz information. Unfortunately LLVM has no
+// such capability, so instead we guesstimate which use is more common in a
+// method and emit suitable attributes
+fn compute_denorm_statistics<'input, P: ast::ArgParams<Id = Id>>(
+ module: &TranslationModule<'input, P>,
+) -> FxHashMap<Id, DenormSummary> {
+ let mut denorm_methods = FxHashMap::default();
+ for directive in module.directives.iter() {
match directive {
- Directive::Method(Function {
- body: Some(func), ..
+ TranslationDirective::Variable(..)
+ | TranslationDirective::Method(TranslationMethod { body: None, .. }) => {}
+ TranslationDirective::Method(TranslationMethod {
+ name,
+ body: Some(statements),
+ ..
}) => {
- func[1..].sort_by_key(|s| match s {
- Statement::Variable(_) => 0,
- _ => 1,
- });
+ let mut fp32_flush_count = 0isize;
+ let mut nonfp32_flush_count = 0isize;
+ for statement in statements {
+ match statement {
+ Statement::Instruction(inst) => match inst.flush_to_zero() {
+ Some((flush, 4)) => {
+ fp32_flush_count += if flush { 1 } else { -1 };
+ }
+ Some((flush, _)) => {
+ nonfp32_flush_count += if flush { 1 } else { -1 };
+ }
+ None => {}
+ },
+ Statement::LoadVar(..) => {}
+ Statement::StoreVar(..) => {}
+ Statement::Call(_) => {}
+ Statement::Conditional(_) => {}
+ Statement::Conversion(_) => {}
+ Statement::Constant(_) => {}
+ Statement::RetValue(_, _) => {}
+ Statement::Label(_) => {}
+ Statement::Variable(_) => {}
+ Statement::PtrAccess { .. } => {}
+ Statement::RepackVector(_) => {}
+ Statement::FunctionPointer(_) => {}
+ Statement::MadC(_) => {}
+ Statement::MadCC(_) => {}
+ Statement::AddC(..) => {}
+ Statement::AddCC(..) => {}
+ Statement::SubC(..) => {}
+ Statement::SubCC(..) => {}
+ Statement::AsmVolatile { .. } => {}
+ }
+ }
+ let summary = DenormSummary {
+ f32: if fp32_flush_count > 0 {
+ FPDenormMode::FlushToZero
+ } else {
+ FPDenormMode::Preserve
+ },
+ f16f64: if nonfp32_flush_count > 0 {
+ FPDenormMode::FlushToZero
+ } else {
+ FPDenormMode::Preserve
+ },
+ };
+ denorm_methods.insert(*name, summary);
}
- _ => (),
}
}
+ denorm_methods
}
-fn convert_to_typed_statements(
- func: Vec<UnconditionalStatement>,
- fn_defs: &GlobalFnDeclResolver,
- id_defs: &mut NumericIdResolver,
-) -> Result<Vec<TypedStatement>, TranslateError> {
- let mut result = Vec::<TypedStatement>::with_capacity(func.len());
- for s in func {
- match s {
- Statement::Instruction(inst) => match inst {
- ast::Instruction::Call(call) => {
- // TODO: error out if lengths don't match
- let fn_def = fn_defs.get_fn_decl(call.func)?;
- let out_args = to_resolved_fn_args(call.ret_params, &*fn_def.ret_vals);
- let in_args = to_resolved_fn_args(call.param_list, &*fn_def.params);
- let (out_params, out_non_params): (Vec<_>, Vec<_>) = out_args
- .into_iter()
- .partition(|(_, arg_type)| arg_type.is_param());
- let normalized_input_args = out_params
+#[derive(Copy, Clone)]
+pub(crate) struct DenormSummary {
+ pub(crate) f32: FPDenormMode,
+ pub(crate) f16f64: FPDenormMode,
+}
+
+pub fn to_llvm_module<'input>(
+ compilation_mode: CompilationMode,
+ ast: Vec<ast::Module<'input>>,
+) -> Result<Module, TranslateError> {
+ to_llvm_module_impl2(compilation_mode, ast, None)
+}
+
+pub fn to_llvm_module_for_raytracing<'input>(
+ ast: ast::Module<'input>,
+ raytracing_fn: &str,
+ cumulative_attribute_variables: &VariablesBlock,
+) -> Result<RaytracingModule<'input>, TranslateError> {
+ let mut raytracing_state =
+ RaytracingTranslationState::new(raytracing_fn, cumulative_attribute_variables);
+ let compilation_module = to_llvm_module_impl2(
+ CompilationMode::Wave32,
+ vec![ast],
+ Some(&mut raytracing_state),
+ )?;
+ let entry_point_kind: RaytracingEntryPointKind = raytracing_state.entry_point_kind.unwrap();
+ let rt_section = hip_common::kernel_metadata::zluda_rt6::write(
+ &raytracing_state.new_attribute_variables,
+ &raytracing_state.variables,
+ entry_point_kind == RaytracingEntryPointKind::Callable,
+ );
+ let mut linker_module = Vec::new();
+ emit::emit_section(
+ hip_common::kernel_metadata::zluda_rt6::SECTION_STR,
+ &rt_section,
+ &mut linker_module,
+ );
+ Ok(RaytracingModule::new(
+ raytracing_state.kernel_name,
+ compilation_module,
+ raytracing_state.variables,
+ entry_point_kind,
+ raytracing_state.new_attribute_variables,
+ linker_module,
+ ))
+}
+
+pub(crate) struct RaytracingTranslationState<'a, 'input> {
+ pub(crate) entry_point_str: &'a str,
+ pub(crate) entry_point_id: Option<Id>,
+ pub(crate) entry_point_kind: Option<RaytracingEntryPointKind>,
+ pub(crate) kernel_name: String,
+ pub(crate) buffers: FxHashMap<Id, Cow<'input, str>>,
+ pub(crate) variables: VariablesBlock,
+ pub(crate) old_attribute_variables: &'a VariablesBlock,
+ pub(crate) new_attribute_variables: VariablesBlock,
+ pub(crate) reachable_user_functions: FxHashSet<Id>,
+}
+
+impl<'a, 'input> RaytracingTranslationState<'a, 'input> {
+ fn new(entry_point_str: &'a str, cumulative_attribute_variables: &'a VariablesBlock) -> Self {
+ Self {
+ entry_point_str,
+ old_attribute_variables: cumulative_attribute_variables,
+ entry_point_id: None,
+ entry_point_kind: None,
+ kernel_name: String::new(),
+ buffers: FxHashMap::default(),
+ variables: VariablesBlock::empty(),
+ new_attribute_variables: VariablesBlock::empty(),
+ reachable_user_functions: FxHashSet::default(),
+ }
+ }
+}
+
+#[derive(Copy, Clone, Eq, PartialEq)]
+pub(crate) enum RaytracingEntryPointKind {
+ BoundingBox,
+ Intersection,
+ Callable,
+ // Closest hit, any hit, attribute, ray generation, exception, miss
+ Unknown,
+}
+
+fn to_llvm_module_impl2<'a, 'input>(
+ compilation_mode: CompilationMode,
+ asts: Vec<ast::Module<'input>>,
+ mut raytracing: Option<&mut RaytracingTranslationState<'a, 'input>>,
+) -> Result<Module<'input>, TranslateError> {
+ let empty_module = if raytracing.is_some() {
+ raytracing::create_module_with_builtins()
+ } else {
+ TranslationModule::new(compilation_mode)
+ };
+ let linking = resolve_linking(&*asts, raytracing.is_some())?;
+ let (mut translation_module, functions) =
+ link_and_normalize_modules(asts, empty_module, linking)?;
+ if let Some(ref mut raytracing_state) = raytracing {
+ translation_module = raytracing::run_on_normalized(translation_module, raytracing_state)?;
+ }
+ let translation_module = extract_builtin_functions(translation_module);
+ let translation_module = resolve_instruction_types(translation_module, functions)?;
+ let mut translation_module = restructure_function_return_types(translation_module)?;
+ if let Some(ref mut raytracing_state) = raytracing {
+ translation_module = raytracing::run_on_typed(translation_module, raytracing_state)?;
+ }
+ let translation_module = deparamize_function_declarations(translation_module)?;
+ let translation_module = insert_hardware_registers(translation_module)?;
+ let translation_module = fix_special_registers(translation_module)?;
+ let translation_module = insert_mem_ssa_statements(translation_module)?;
+ let translation_module = expand_arguments(translation_module)?;
+ let mut translation_module = deparamize_variable_declarations(translation_module)?;
+ if let Some(ref mut raytracing_state) = raytracing {
+ // raytracing passes rely heavily on particular PTX patterns, they must run before implicit conversions
+ translation_module = raytracing::postprocess(translation_module, raytracing_state)?;
+ }
+ let translation_module = insert_implicit_conversions(translation_module)?;
+ let translation_module = insert_compilation_mode_prologue(translation_module);
+ let translation_module = normalize_labels(translation_module)?;
+ let translation_module = hoist_globals(translation_module);
+ let translation_module = move_variables_to_start(translation_module)?;
+ let mut translation_module = replace_instructions_with_builtins(translation_module)?;
+ if raytracing.is_some() {
+ translation_module = raytracing::replace_tex_builtins_hack(translation_module)?;
+ }
+ let call_graph = CallGraph::new(&translation_module.directives);
+ let translation_module = convert_dynamic_shared_memory_usage(translation_module, &call_graph)?;
+ let denorm_statistics = compute_denorm_statistics(&translation_module);
+ let kernel_arguments = get_kernel_arguments(&translation_module.directives)?;
+ let mut bitcode_modules = vec![ZLUDA_PTX_IMPL_AMD];
+ if raytracing.is_some() {
+ bitcode_modules.push(raytracing::bitcode());
+ }
+ let metadata = create_metadata(&translation_module);
+ let (llvm_context, llvm_module) = unsafe {
+ emit::emit_llvm_bitcode_and_linker_module(translation_module, denorm_statistics)?
+ };
+ Ok(Module {
+ metadata,
+ compilation_mode,
+ llvm_module,
+ kernel_arguments,
+ _llvm_context: llvm_context,
+ bitcode_modules,
+ })
+}
+
+// From "Performance Tips for Frontend Authors" (https://llvm.org/docs/Frontend/PerformanceTips.html):
+// "The SROA (Scalar Replacement Of Aggregates) and Mem2Reg passes only attempt to eliminate alloca
+// instructions that are in the entry basic block. Given SSA is the canonical form expected by much
+// of the optimizer; if allocas can not be eliminated by Mem2Reg or SROA, the optimizer is likely to
+// be less effective than it could be."
+// Empirically, this is true. Moving allocas to the start gives us less spill-happy assembly
+fn move_variables_to_start<'input, P: ast::ArgParams<Id = Id>>(
+ module: TranslationModule<'input, P>,
+) -> Result<TranslationModule<'input, P>, TranslateError> {
+ convert_methods_simple(module, move_variables_to_start_impl)
+}
+
+fn move_variables_to_start_impl<'input, P: ast::ArgParams>(
+ _: &mut IdNameMapBuilder<'input>,
+ fn_body: Vec<Statement<ast::Instruction<P>, P>>,
+) -> Result<Vec<Statement<ast::Instruction<P>, P>>, TranslateError> {
+ if fn_body.is_empty() {
+ return Ok(fn_body);
+ }
+ let mut result = (0..fn_body.len())
+ .into_iter()
+ .map(|_| mem::MaybeUninit::<_>::uninit())
+ .collect::<Vec<_>>();
+ let variables_count = fn_body.iter().fold(0, |acc, statement| {
+ acc + matches!(statement, Statement::Variable(..)) as usize
+ });
+ let mut variable = 1usize;
+ let mut non_variable = variables_count + 1;
+ // methods always start with an entry label
+ let mut statements = fn_body.into_iter();
+ let start_label = statements.next().ok_or_else(TranslateError::unreachable)?;
+ unsafe { result.get_unchecked_mut(0).write(start_label) };
+ for statement in statements {
+ let index = match statement {
+ Statement::Variable(_) => &mut variable,
+ _ => &mut non_variable,
+ };
+ unsafe { result.get_unchecked_mut(*index).write(statement) };
+ *index += 1;
+ }
+ Ok(unsafe { mem::transmute(result) })
+}
+
+// PTX definition of param state space does not translate cleanly into AMDGPU notion of an address space:
+//  .param in kernel arguments matches AMDGPU constant address space
+// .param in function arguments and variables matches AMDGPU private address space
+// This pass converts all instances of declarations in .param state space into constant or local state space appropriately
+// Previously we used AMDPGU generic address space for params and left it for LLVM to infer the right non-generic space,
+// but this made LLVM crash on some inputs (e.g. test alloca_call.ptx)
+fn deparamize_variable_declarations<'input>(
+ mut module: TranslationModule<'input, ExpandedArgParams>,
+) -> Result<TranslationModule<'input, ExpandedArgParams>, TranslateError> {
+ let id_def = &mut module.id_defs;
+ for directive in module.directives.iter_mut() {
+ match directive {
+ TranslationDirective::Variable(..) => {}
+ TranslationDirective::Method(method) => {
+ let mut new_space = FxHashMap::default();
+ if method.is_kernel {
+ let input_arguments: Vec<ast::VariableDeclaration<Id>> =
+ mem::replace(&mut method.input_arguments, Vec::new());
+ let input_arguments = input_arguments
.into_iter()
- .map(|(id, typ)| (ast::Operand::Reg(id), typ))
- .chain(in_args.into_iter())
- .collect();
- let resolved_call = ResolvedCall {
- uniform: call.uniform,
- ret_params: out_non_params,
- func: call.func,
- param_list: normalized_input_args,
- };
- let mut visitor = VectorRepackVisitor::new(&mut result, id_defs);
- let reresolved_call = resolved_call.visit(&mut visitor)?;
- visitor.func.push(reresolved_call);
- visitor.func.extend(visitor.post_stmts);
+ .map(|arg| {
+ if arg.state_space == ast::StateSpace::Param {
+ let new_arg = id_def.register_variable_decl(
+ arg.align,
+ arg.type_.clone(),
+ ast::StateSpace::Const,
+ );
+ new_space.insert(arg.name, (new_arg.name, ast::StateSpace::Const));
+ new_arg
+ } else {
+ arg
+ }
+ })
+ .collect::<Vec<_>>();
+ method.input_arguments = input_arguments;
+ }
+ method.body = method
+ .body
+ .take()
+ .map(|old_body| {
+ deparamize_variable_declarations_convert_body(id_def, new_space, old_body)
+ })
+ .transpose()?;
+ }
+ }
+ }
+ Ok(module)
+}
+
+fn deparamize_variable_declarations_convert_body<'input>(
+ id_def: &mut IdNameMapBuilder<'input>,
+ mut new_space: FxHashMap<Id, (Id, ast::StateSpace)>,
+ fn_body: Vec<ExpandedStatement>,
+) -> Result<Vec<ExpandedStatement>, TranslateError> {
+ let mut result = Vec::with_capacity(fn_body.len());
+ for statement in fn_body {
+ match statement {
+ Statement::Instruction(ast::Instruction::Mov(details, mut args)) => {
+ if let Some((new_name, _)) = new_space.get(&args.src) {
+ args.src = *new_name;
+ }
+ result.push(Statement::Instruction(ast::Instruction::Mov(details, args)));
+ }
+ Statement::Variable(
+ var @ Variable {
+ state_space: ast::StateSpace::Param,
+ ..
+ },
+ ) => {
+ let old_name = var.name;
+ let new_var = id_def.register_variable_def(
+ var.align,
+ var.type_,
+ ast::StateSpace::Local,
+ var.initializer,
+ );
+ new_space.insert(old_name, (new_var.name, ast::StateSpace::Local));
+ result.push(Statement::Variable(new_var));
+ }
+ Statement::PtrAccess(
+ mut ptr @ PtrAccess {
+ state_space: ast::StateSpace::Param,
+ ..
+ },
+ ) => {
+ if let Some((new_name, new_space)) = new_space.get(&ptr.ptr_src) {
+ ptr.state_space = *new_space;
+ ptr.ptr_src = *new_name;
+ } else {
+ ptr.state_space = ast::StateSpace::Const;
+ }
+ let old_name = ptr.dst;
+ ptr.dst = id_def
+ .register_intermediate(Some((ptr.underlying_type.clone(), ptr.state_space)));
+ new_space.insert(old_name, (ptr.dst, ptr.state_space));
+ result.push(Statement::PtrAccess(ptr));
+ }
+ Statement::Instruction(ast::Instruction::St(
+ mut details @ ast::StData {
+ state_space: ast::StateSpace::Param,
+ ..
+ },
+ mut args,
+ )) => {
+ if let Some((new_name, new_space)) = new_space.get(&args.src1) {
+ details.state_space = *new_space;
+ args.src1 = *new_name;
+ } else {
+ details.state_space = ast::StateSpace::Const;
+ }
+ result.push(Statement::Instruction(ast::Instruction::St(details, args)));
+ }
+ Statement::Instruction(ast::Instruction::Ld(
+ mut details @ ast::LdDetails {
+ state_space: ast::StateSpace::Param,
+ ..
+ },
+ mut args,
+ )) => {
+ if let Some((new_name, new_space)) = new_space.get(&args.src) {
+ details.state_space = *new_space;
+ args.src = *new_name;
+ } else {
+ details.state_space = ast::StateSpace::Const;
}
- ast::Instruction::Mov(mut d, ast::Arg2Mov { dst, src }) => {
- if let Some(src_id) = src.underlying() {
- let (typ, _) = id_defs.get_typed(*src_id)?;
- let take_address = match typ {
- ast::Type::Scalar(_) => false,
- ast::Type::Vector(_, _) => false,
- ast::Type::Array(_, _) => true,
- ast::Type::Pointer(_, _) => true,
- };
- d.src_is_address = take_address;
+ result.push(Statement::Instruction(ast::Instruction::Ld(details, args)));
+ }
+ s => result.push(s),
+ }
+ }
+ Ok(result)
+}
+
+fn create_metadata<'input>(
+ translation_module: &TranslationModule<'input, ExpandedArgParams>,
+) -> Metadata<'input> {
+ let mut kernel_metadata = Vec::new();
+ for directive in translation_module.directives.iter() {
+ match directive {
+ TranslationDirective::Method(method) => {
+ if method.tuning.is_empty() {
+ continue;
+ }
+ let name = match method.source_name {
+ Some(ref name) => name.clone(),
+ None => continue,
+ };
+ for tuning in method.tuning.iter().copied() {
+ match tuning {
+ // TODO: measure
+ ast::TuningDirective::MaxNReg(_)
+ | ast::TuningDirective::MinNCtaPerSm(_) => {}
+ ast::TuningDirective::MaxNtid(x, y, z) => {
+ let size = x as u64 * y as u64 * z as u64;
+ kernel_metadata.push((
+ name.clone(),
+ None,
+ NonZeroU32::new(size as u32),
+ ));
+ }
+ ast::TuningDirective::ReqNtid(x, y, z) => {
+ let size = x as u64 * y as u64 * z as u64;
+ kernel_metadata.push((
+ name.clone(),
+ NonZeroU32::new(size as u32),
+ NonZeroU32::new(size as u32),
+ ));
+ }
}
- let mut visitor = VectorRepackVisitor::new(&mut result, id_defs);
- let instruction = Statement::Instruction(
- ast::Instruction::Mov(d, ast::Arg2Mov { dst, src }).map(&mut visitor)?,
- );
- visitor.func.push(instruction);
- visitor.func.extend(visitor.post_stmts);
}
- inst => {
- let mut visitor = VectorRepackVisitor::new(&mut result, id_defs);
- let instruction = Statement::Instruction(inst.map(&mut visitor)?);
- visitor.func.push(instruction);
- visitor.func.extend(visitor.post_stmts);
+ }
+ TranslationDirective::Variable(..) => {}
+ }
+ }
+ Metadata {
+ sm_version: translation_module.sm_version,
+ kernel_metadata,
+ }
+}
+
+fn insert_compilation_mode_prologue<'input>(
+ mut translation_module: TranslationModule<'input, ExpandedArgParams>,
+) -> TranslationModule<'input, ExpandedArgParams> {
+ if translation_module.compilation_mode != CompilationMode::Wave32OnWave64 {
+ return translation_module;
+ }
+ for directive in translation_module.directives.iter_mut() {
+ match directive {
+ TranslationDirective::Method(TranslationMethod {
+ is_kernel,
+ body: Some(body),
+ tuning,
+ ..
+ }) => {
+ for t in tuning.iter_mut() {
+ match t {
+ ast::TuningDirective::MaxNReg(_)
+ | ast::TuningDirective::MinNCtaPerSm(_) => {}
+ ast::TuningDirective::MaxNtid(_, _, z) => {
+ *z *= 2;
+ }
+ ast::TuningDirective::ReqNtid(_, _, z) => {
+ *z *= 2;
+ }
+ }
}
- },
- Statement::Label(i) => result.push(Statement::Label(i)),
- Statement::Variable(v) => result.push(Statement::Variable(v)),
- Statement::Conditional(c) => result.push(Statement::Conditional(c)),
- _ => return Err(error_unreachable()),
+ if !*is_kernel {
+ continue;
+ }
+ let old_body = mem::replace(body, Vec::new());
+ let mut new_body = Vec::with_capacity(old_body.len() + 1);
+ // I'd rather use early exit on laneid like a normal person,
+ // but that leads to miscompilations, so here's the next best thing
+ let asm = "s_bcnt1_i32_b64 exec_lo, exec\ns_lshr_b32 exec_lo, exec_lo, 1\ns_bfm_b64 exec, exec_lo, 0";
+ let constraints = "~{scc}";
+ new_body.push(Statement::AsmVolatile { asm, constraints });
+ new_body.extend(old_body.into_iter());
+ *body = new_body;
+ }
+ TranslationDirective::Method(..) | TranslationDirective::Variable(..) => {}
+ }
+ }
+ translation_module
+}
+
+// THIS PASS IS AN AWFUL HACK TO WORK AROUND LLVM BUG
+// In certain situations LLVM will miscompile AMDGPU
+// binaries if the return type of a function is .b8 array.
+// For example, if the return of the function is float3 NVIDIA
+// frontend compiler will compile it to .b8[12].
+// Turns out if the return type is a .b8 array, then LLVM will
+// sometimes be unable to remove the alloca.
+// Which is fine, but for some reason AMDGPU has a bug
+// where it does not deallocate alloca
+// Our """solution""" is to convert all b8[] into b32[]
+fn restructure_function_return_types(
+ mut module: TranslationModule<TypedArgParams>,
+) -> Result<TranslationModule<TypedArgParams>, TranslateError> {
+ let id_defs = &mut module.id_defs;
+ module.directives = module
+ .directives
+ .into_iter()
+ .map(|directive| avoid_byte_array_returns(id_defs, directive))
+ .collect::<Result<Vec<_>, _>>()?;
+ Ok(module)
+}
+
+fn avoid_byte_array_returns<'input>(
+ id_defs: &mut IdNameMapBuilder<'input>,
+ mut directive: TranslationDirective<'input, TypedArgParams>,
+) -> Result<TranslationDirective<'input, TypedArgParams>, TranslateError> {
+ match directive {
+ TranslationDirective::Method(ref mut method) => {
+ if !method.is_kernel {
+ let return_arguments = &mut method.return_arguments;
+ let input_arguments = &mut method.input_arguments;
+ for arg in return_arguments
+ .iter_mut()
+ .chain(input_arguments.iter_mut())
+ {
+ if let (
+ ast::Type::Array(
+ ref mut scalar_type @ ast::ScalarType::B8,
+ ref mut dimensions,
+ ),
+ _,
+ ) = (&mut arg.type_, arg.state_space)
+ {
+ if dimensions.len() > 1 {
+ return Err(TranslateError::unexpected_pattern());
+ }
+ *scalar_type = ast::ScalarType::B32;
+ dimensions[0] = div_positive_round_up(dimensions[0], 4);
+ id_defs.change_type(
+ arg.name,
+ ast::Type::Array(ast::ScalarType::B32, dimensions.clone()),
+ )?;
+ }
+ }
+ }
+ for statement in method.body.iter_mut().flatten() {
+ if let Statement::Call(call) = statement {
+ let return_arguments = &mut call.return_arguments;
+ let input_arguments = &mut call.input_arguments;
+ for (type_, space) in return_arguments
+ .iter_mut()
+ .map(|(_, t, s)| (t, s))
+ .chain(input_arguments.iter_mut().map(|(_, t, s)| (t, s)))
+ {
+ if let (
+ ast::Type::Array(
+ ref mut scalar_type @ ast::ScalarType::B8,
+ ref mut dimensions,
+ ),
+ _,
+ ) = (type_, space)
+ {
+ if dimensions.len() > 1 {
+ return Err(TranslateError::unexpected_pattern());
+ }
+ *scalar_type = ast::ScalarType::B32;
+ dimensions[0] = div_positive_round_up(dimensions[0], 4);
+ }
+ }
+ }
+ if let Statement::Variable(Variable {
+ state_space: ast::StateSpace::Param,
+ type_:
+ ast::Type::Array(ref mut scalar_type @ ast::ScalarType::B8, ref mut dimensions),
+ ..
+ }) = statement
+ {
+ if dimensions.len() > 1 {
+ return Err(TranslateError::unexpected_pattern());
+ }
+ *scalar_type = ast::ScalarType::B32;
+ dimensions[0] = div_positive_round_up(dimensions[0], 4);
+ }
+ }
+ Ok(directive)
+ }
+ TranslationDirective::Variable(..) => Ok(directive),
+ }
+}
+
+fn div_positive_round_up(dividend: u32, divisor: u32) -> u32 {
+ let mut result = dividend / divisor;
+ if (dividend % divisor) != 0 {
+ result += 1;
+ }
+ result
+}
+
+fn get_kernel_arguments(
+ directives: &[Directive],
+) -> Result<FxHashMap<String, Vec<Layout>>, TranslateError> {
+ let mut result = FxHashMap::default();
+ for directive in directives.iter() {
+ match directive {
+ Directive::Method(TranslationMethod {
+ is_kernel: true,
+ source_name,
+ input_arguments,
+ ..
+ }) => {
+ let name = match source_name {
+ Some(name) => name,
+ None => continue,
+ };
+ let layout = input_arguments
+ .iter()
+ .map(|var| var.layout())
+ .collect::<Vec<_>>();
+ result.insert(name.to_string(), layout);
+ }
+ _ => continue,
}
}
Ok(result)
}
-struct VectorRepackVisitor<'a, 'b> {
- func: &'b mut Vec<TypedStatement>,
- id_def: &'b mut NumericIdResolver<'a>,
+pub(crate) struct CallGraph {
+ pub(crate) all_callees: FxHashMap<Id, FxHashSet<Id>>,
+}
+
+// TODO: resolve declarations
+impl CallGraph {
+ pub(crate) fn new<'input, P: ast::ArgParams<Id = Id>>(
+ module: &[TranslationDirective<'input, P>],
+ ) -> Self {
+ let mut has_body = FxHashSet::default();
+ let mut direct_callees = FxHashMap::default();
+ for directive in module {
+ match directive {
+ TranslationDirective::Method(TranslationMethod {
+ name,
+ body: Some(statements),
+ ..
+ }) => {
+ let call_key = *name;
+ has_body.insert(call_key);
+ if let hash_map::Entry::Vacant(entry) = direct_callees.entry(call_key) {
+ entry.insert(Vec::new());
+ }
+ for statement in statements {
+ match statement {
+ Statement::Call(ResolvedCall {
+ name,
+ is_indirect: false,
+ ..
+ }) => {
+ multi_hash_map_append(&mut direct_callees, call_key, *name);
+ }
+ _ => {}
+ }
+ }
+ }
+ _ => {}
+ }
+ }
+ let mut result = FxHashMap::default();
+ for (&method_key, children) in direct_callees.iter() {
+ let mut visited = FxHashSet::default();
+ for child in children {
+ if !has_body.contains(child) {
+ continue;
+ }
+ Self::add_call_map_single(&has_body, &direct_callees, &mut visited, *child);
+ }
+ result.insert(method_key, visited);
+ }
+ CallGraph {
+ all_callees: result,
+ }
+ }
+
+ fn add_call_map_single(
+ has_body: &FxHashSet<Id>,
+ directly_called_by: &FxHashMap<Id, Vec<Id>>,
+ visited: &mut FxHashSet<Id>,
+ current: Id,
+ ) {
+ if !visited.insert(current) {
+ return;
+ }
+ if let Some(children) = directly_called_by.get(&current) {
+ for child in children {
+ if !has_body.contains(child) {
+ continue;
+ }
+ Self::add_call_map_single(has_body, directly_called_by, visited, *child);
+ }
+ }
+ }
+
+ fn methods(&self) -> impl Iterator<Item = (Id, &FxHashSet<Id>)> {
+ self.all_callees
+ .iter()
+ .map(|(method, children)| (*method, children))
+ }
+}
+
+fn multi_hash_map_append<
+ K: Eq + std::hash::Hash,
+ V,
+ Collection: std::iter::Extend<V> + std::default::Default,
+>(
+ m: &mut FxHashMap<K, Collection>,
+ key: K,
+ value: V,
+) {
+ match m.entry(key) {
+ hash_map::Entry::Occupied(mut entry) => {
+ entry.get_mut().extend(iter::once(value));
+ }
+ hash_map::Entry::Vacant(entry) => {
+ entry.insert(Default::default()).extend(iter::once(value));
+ }
+ }
+}
+
+/*
+ PTX represents dynamically allocated shared local memory as
+ .extern .shared .b32 shared_mem[];
+ In SPIRV/OpenCL world this is expressed as an additional argument to the kernel
+ And in AMD compilation
+ This pass looks for all uses of .extern .shared and converts them to
+ an additional method argument
+ The question is how this artificial argument should be expressed. There are
+ several options:
+ * Straight conversion:
+ .shared .b32 shared_mem[]
+ * Introduce .param_shared statespace:
+ .param_shared .b32 shared_mem
+ or
+ .param_shared .b32 shared_mem[]
+ * Introduce .shared_ptr <SCALAR> type:
+ .param .shared_ptr .b32 shared_mem
+ * Reuse .ptr hint:
+ .param .u64 .ptr shared_mem
+ This is the most tempting, but also the most nonsensical, .ptr is just a
+ hint, which has no semantical meaning (and the output of our
+ transformation has a semantical meaning - we emit additional
+ "OpFunctionParameter ..." with type "OpTypePointer Workgroup ...")
+*/
+fn convert_dynamic_shared_memory_usage<'input>(
+ mut module: TranslationModule<'input, ExpandedArgParams>,
+ kernels_methods_call_map: &CallGraph,
+) -> Result<TranslationModule<'input, ExpandedArgParams>, TranslateError> {
+ let mut globals_shared = FxHashMap::default();
+ for dir in module.directives.iter() {
+ match dir {
+ Directive::Variable(
+ _,
+ _,
+ Variable {
+ state_space: ast::StateSpace::Shared,
+ name,
+ type_,
+ ..
+ },
+ ) => {
+ globals_shared.insert(*name, type_.clone());
+ }
+ _ => {}
+ }
+ }
+ if globals_shared.len() == 0 {
+ return Ok(module);
+ }
+ let mut methods_to_directly_used_shared_globals = FxHashMap::<_, FxHashSet<Id>>::default();
+ let remapped_directives = module
+ .directives
+ .into_iter()
+ .map(|directive| match directive {
+ Directive::Method(TranslationMethod {
+ return_arguments,
+ name,
+ input_arguments,
+ body: Some(statements),
+ tuning,
+ is_kernel,
+ source_name,
+ special_raytracing_linking: raytracing_linking,
+ }) => {
+ let call_key = name;
+ let statements = statements
+ .into_iter()
+ .map(|statement| {
+ statement.map_id(&mut |id, _| {
+ if globals_shared.get(&id).is_some() {
+ methods_to_directly_used_shared_globals
+ .entry(call_key)
+ .or_insert_with(FxHashSet::default)
+ .insert(id);
+ }
+ id
+ })
+ })
+ .collect();
+ Directive::Method(TranslationMethod {
+ return_arguments,
+ name,
+ input_arguments,
+ body: Some(statements),
+ tuning,
+ is_kernel,
+ source_name,
+ special_raytracing_linking: raytracing_linking,
+ })
+ }
+ directive => directive,
+ })
+ .collect::<Vec<_>>();
+ // If there's a chain `kernel` -> `fn1` -> `fn2`, where only `fn2` uses extern shared,
+ // make sure it gets propagated to `fn1` and `kernel`
+ let methods_to_indirectly_used_shared_globals = resolve_indirect_uses_of_globals_shared(
+ methods_to_directly_used_shared_globals,
+ kernels_methods_call_map,
+ );
+ // now visit every method declaration and inject those additional arguments
+ let mut directives = Vec::with_capacity(remapped_directives.len());
+ for directive in remapped_directives.into_iter() {
+ match directive {
+ Directive::Method(TranslationMethod {
+ return_arguments,
+ name,
+ mut input_arguments,
+ body,
+ tuning,
+ is_kernel,
+ source_name,
+ special_raytracing_linking: raytracing_linking,
+ }) => {
+ let statements: Option<
+ Vec<Statement<ast::Instruction<ExpandedArgParams>, ExpandedArgParams>>,
+ > = {
+ insert_arguments_remap_statements(
+ &mut module.id_defs.id_gen,
+ &globals_shared,
+ &methods_to_indirectly_used_shared_globals,
+ name,
+ is_kernel,
+ &mut input_arguments,
+ body,
+ )?
+ };
+ directives.push(Directive::Method(TranslationMethod {
+ return_arguments,
+ name,
+ input_arguments,
+ body: statements,
+ tuning,
+ is_kernel,
+ source_name,
+ special_raytracing_linking: raytracing_linking,
+ }));
+ }
+ directive => directives.push(directive),
+ }
+ }
+ Ok(TranslationModule {
+ directives,
+ ..module
+ })
+}
+
+fn insert_arguments_remap_statements<'input>(
+ new_id: &mut IdGenerator,
+ globals_shared: &FxHashMap<Id, ast::Type>,
+ methods_to_indirectly_used_shared_globals: &FxHashMap<Id, BTreeSet<Id>>,
+ method_name: Id,
+ is_kernel: bool,
+ input_arguments: &mut Vec<ast::VariableDeclaration<Id>>,
+ statements: Option<Vec<Statement<ast::Instruction<ExpandedArgParams>, ExpandedArgParams>>>,
+) -> Result<
+ Option<Vec<Statement<ast::Instruction<ExpandedArgParams>, ExpandedArgParams>>>,
+ TranslateError,
+> {
+ let method_globals = match methods_to_indirectly_used_shared_globals.get(&method_name) {
+ Some(method_globals) => method_globals,
+ None => return Ok(statements),
+ };
+ let remapped_globals_in_method = method_globals
+ .iter()
+ .map(|global| {
+ Ok((
+ *global,
+ (
+ if is_kernel { *global } else { new_id.next() },
+ globals_shared
+ .get(&global)
+ .ok_or_else(TranslateError::todo)?
+ .clone(),
+ ),
+ ))
+ })
+ .collect::<Result<BTreeMap<_, _>, _>>()?;
+ if !is_kernel {
+ for (_, (new_shared_global_id, shared_global_type)) in remapped_globals_in_method.iter() {
+ input_arguments.push(ast::VariableDeclaration {
+ align: None,
+ type_: shared_global_type.clone(),
+ state_space: ast::StateSpace::Shared,
+ name: *new_shared_global_id,
+ });
+ }
+ }
+ Ok(statements.map(|statements| {
+ replace_uses_of_shared_memory(
+ methods_to_indirectly_used_shared_globals,
+ statements,
+ remapped_globals_in_method,
+ )
+ }))
+}
+
+fn replace_uses_of_shared_memory<'input>(
+ methods_to_indirectly_used_shared_globals: &FxHashMap<Id, BTreeSet<Id>>,
+ statements: Vec<ExpandedStatement>,
+ remapped_globals_in_method: BTreeMap<Id, (Id, ast::Type)>,
+) -> Vec<ExpandedStatement> {
+ let mut result = Vec::with_capacity(statements.len());
+ for statement in statements {
+ match statement {
+ Statement::Call(mut call) => {
+ // We can safely skip checking call arguments,
+ // because there's simply no way to pass shared ptr
+ // without converting it to .b64 first
+ if let Some(shared_globals_used_by_callee) =
+ methods_to_indirectly_used_shared_globals.get(&call.name)
+ {
+ for &shared_global_used_by_callee in shared_globals_used_by_callee {
+ let (remapped_shared_id, type_) = remapped_globals_in_method
+ .get(&shared_global_used_by_callee)
+ .unwrap_or_else(|| todo!());
+ call.input_arguments.push((
+ *remapped_shared_id,
+ type_.clone(),
+ ast::StateSpace::Shared,
+ ));
+ }
+ }
+ result.push(Statement::Call(call))
+ }
+ statement => {
+ let new_statement = statement.map_id(&mut |id, _| {
+ if let Some((remapped_shared_id, _)) = remapped_globals_in_method.get(&id) {
+ *remapped_shared_id
+ } else {
+ id
+ }
+ });
+ result.push(new_statement);
+ }
+ }
+ }
+ result
+}
+
+// We need to compute two kinds of information:
+// * If it's a kernel -> size of .shared globals in use (direct or indirect)
+// * If it's a function -> does it use .shared global (directly or indirectly)
+fn resolve_indirect_uses_of_globals_shared<'input>(
+ methods_use_of_globals_shared: FxHashMap<Id, FxHashSet<Id>>,
+ kernels_methods_call_map: &CallGraph,
+) -> FxHashMap<Id, BTreeSet<Id>> {
+ let mut result = FxHashMap::default();
+ for (method, callees) in kernels_methods_call_map.methods() {
+ let mut indirect_globals = methods_use_of_globals_shared
+ .get(&method)
+ .into_iter()
+ .flatten()
+ .copied()
+ .collect::<BTreeSet<_>>();
+ for &callee in callees {
+ indirect_globals.extend(
+ methods_use_of_globals_shared
+ .get(&callee)
+ .into_iter()
+ .flatten()
+ .copied(),
+ );
+ }
+ result.insert(method, indirect_globals);
+ }
+ result
+}
+
+struct SpecialRegisterResolver<'a, 'input> {
+ id_defs: &'a mut IdNameMapBuilder<'input>,
+ ptx_imports: &'a mut AdditionalFunctionDeclarations,
+ result: Vec<TypedStatement>,
+}
+
+impl<'a, 'input> SpecialRegisterResolver<'a, 'input> {
+ fn replace_sreg(
+ &mut self,
+ desc: ArgumentDescriptor<Id>,
+ vector_index: Option<u8>,
+ ) -> Result<Id, TranslateError> {
+ if let Some(sreg) = self.id_defs.globals.special_registers.get(desc.op) {
+ if desc.is_dst {
+ return Err(TranslateError::mismatched_type());
+ }
+ let input_arguments = match (vector_index, sreg.get_function_input_type()) {
+ (Some(idx), Some(inp_type)) => {
+ if inp_type != ast::ScalarType::U8 {
+ return Err(TranslateError::unreachable());
+ }
+ let constant = self.id_defs.register_intermediate(Some((
+ ast::Type::Scalar(inp_type),
+ ast::StateSpace::Reg,
+ )));
+ self.result.push(Statement::Constant(ConstantDefinition {
+ dst: constant,
+ typ: inp_type,
+ value: ast::ImmediateValue::U64(idx as u64),
+ }));
+ vec![(
+ TypedOperand::Reg(constant),
+ ast::Type::Scalar(inp_type),
+ ast::StateSpace::Reg,
+ )]
+ }
+ (None, None) => Vec::new(),
+ _ => return Err(TranslateError::mismatched_type()),
+ };
+ let ocl_fn_name = [ZLUDA_PTX_PREFIX, sreg.get_unprefixed_function_name()].concat();
+ let return_type = sreg.get_function_return_type();
+ let fn_result = self.id_defs.register_intermediate(Some((
+ ast::Type::Scalar(return_type),
+ ast::StateSpace::Reg,
+ )));
+ let return_arguments = vec![(
+ fn_result,
+ ast::Type::Scalar(return_type),
+ ast::StateSpace::Reg,
+ )];
+ let fn_call = self.ptx_imports.add_or_get_declaration(
+ self.id_defs,
+ ocl_fn_name.to_string(),
+ return_arguments.iter().map(|(_, typ, space)| (typ, *space)),
+ input_arguments.iter().map(|(_, typ, space)| (typ, *space)),
+ )?;
+ self.result.push(Statement::Call(ResolvedCall {
+ uniform: false,
+ return_arguments,
+ name: fn_call,
+ input_arguments,
+ is_indirect: false,
+ }));
+ Ok(fn_result)
+ } else {
+ Ok(desc.op)
+ }
+ }
+}
+
+impl<'a, 'input> ArgumentMapVisitor<TypedArgParams, TypedArgParams>
+ for SpecialRegisterResolver<'a, 'input>
+{
+ fn id(
+ &mut self,
+ desc: ArgumentDescriptor<Id>,
+ _: Option<(&ast::Type, ast::StateSpace)>,
+ ) -> Result<Id, TranslateError> {
+ self.replace_sreg(desc, None)
+ }
+
+ fn operand(
+ &mut self,
+ desc: ArgumentDescriptor<TypedOperand>,
+ _typ: &ast::Type,
+ _state_space: ast::StateSpace,
+ ) -> Result<TypedOperand, TranslateError> {
+ Ok(match desc.op {
+ TypedOperand::Reg(reg) => TypedOperand::Reg(self.replace_sreg(desc.new_op(reg), None)?),
+ op @ TypedOperand::RegOffset(_, _) => op,
+ op @ TypedOperand::Imm(_) => op,
+ TypedOperand::VecMember(reg, idx) => {
+ TypedOperand::VecMember(self.replace_sreg(desc.new_op(reg), Some(idx))?, idx)
+ }
+ })
+ }
+}
+
+fn extract_global_cvt<'input>(
+ local: &mut Vec<ExpandedStatement>,
+ ptx_impl_imports: &mut AdditionalFunctionDeclarations,
+ id_def: &mut IdNameMapBuilder<'input>,
+ desc: ast::CvtDesc,
+ inst: ast::Instruction<ExpandedArgParams>,
+) -> Result<(), TranslateError> {
+ let fn_name = [
+ ZLUDA_PTX_PREFIX,
+ "cvt_",
+ rounding_to_ptx_name(desc.rounding),
+ "_",
+ desc.dst.to_ptx_name(),
+ "_",
+ desc.src.to_ptx_name(),
+ ]
+ .concat();
+ local.push(instruction_to_fn_call(
+ id_def,
+ ptx_impl_imports,
+ inst,
+ fn_name,
+ )?);
+ Ok(())
+}
+
+fn rounding_to_ptx_name(this: Option<ast::RoundingMode>) -> &'static str {
+ match this {
+ None | Some(ast::RoundingMode::NearestEven) => "rn",
+ Some(ast::RoundingMode::Zero) => "rz",
+ Some(ast::RoundingMode::NegativeInf) => "rm",
+ Some(ast::RoundingMode::PositiveInf) => "rp",
+ }
+}
+
+impl ast::AtomSemantics {
+ fn to_ptx_name(self) -> &'static str {
+ match self {
+ ast::AtomSemantics::Relaxed => "relaxed",
+ ast::AtomSemantics::Acquire => "acquire",
+ ast::AtomSemantics::Release => "release",
+ ast::AtomSemantics::AcquireRelease => "acq_rel",
+ }
+ }
+}
+
+impl ast::MemScope {
+ fn to_ptx_name(self) -> &'static str {
+ match self {
+ ast::MemScope::Cta => "cta",
+ ast::MemScope::Gpu => "gpu",
+ ast::MemScope::Sys => "sys",
+ }
+ }
+}
+
+impl ast::StateSpace {
+ fn to_ptx_name(self) -> &'static str {
+ match self {
+ ast::StateSpace::Generic => "generic",
+ ast::StateSpace::Global => "global",
+ ast::StateSpace::Shared => "shared",
+ ast::StateSpace::Reg => "reg",
+ ast::StateSpace::Const => "const",
+ ast::StateSpace::Local => "local",
+ ast::StateSpace::Param => "param",
+ ast::StateSpace::Sreg => "sreg",
+ }
+ }
+}
+
+impl ast::ShflMode {
+ fn to_ptx_name(self) -> &'static str {
+ match self {
+ ast::ShflMode::Up => "up",
+ ast::ShflMode::Down => "down",
+ ast::ShflMode::Bfly => "bfly",
+ ast::ShflMode::Idx => "idx",
+ }
+ }
+}
+struct VectorRepackVisitor<'a, 'input, V> {
+ extra_vistor: &'a mut V,
+ func: &'a mut Vec<TypedStatement>,
+ id_def: &'a mut IdNameMapBuilder<'input>,
post_stmts: Option<TypedStatement>,
}
-impl<'a, 'b> VectorRepackVisitor<'a, 'b> {
- fn new(func: &'b mut Vec<TypedStatement>, id_def: &'b mut NumericIdResolver<'a>) -> Self {
+impl<'a, 'input, V> VectorRepackVisitor<'a, 'input, V> {
+ fn new(
+ extra_vistor: &'a mut V,
+ func: &'a mut Vec<TypedStatement>,
+ id_def: &'a mut IdNameMapBuilder<'input>,
+ ) -> Self {
VectorRepackVisitor {
+ extra_vistor,
func,
id_def,
post_stmts: None,
@@ -1600,22 +4367,60 @@ impl<'a, 'b> VectorRepackVisitor<'a, 'b> {
fn convert_vector(
&mut self,
is_dst: bool,
- vector_sema: ArgumentSemantics,
+ non_default_implicit_conversion: Option<
+ fn(
+ (ast::StateSpace, &ast::Type),
+ (ast::StateSpace, &ast::Type),
+ ) -> Result<Option<ConversionKind>, TranslateError>,
+ >,
typ: &ast::Type,
- idx: Vec<spirv::Word>,
- ) -> Result<spirv::Word, TranslateError> {
- // mov.u32 foobar, {a,b};
+ state_space: ast::StateSpace,
+ idx: Vec<ast::RegOrImmediate<Id>>,
+ ) -> Result<Id, TranslateError> {
let scalar_t = match typ {
+ // mov.v2.u32 foobar, {a,b};
ast::Type::Vector(scalar_t, _) => *scalar_t,
- _ => return Err(TranslateError::MismatchedType),
+ // mov.b64 foobar, {a,b};
+ ast::Type::Scalar(scalar_t) => {
+ if scalar_t.kind() == ast::ScalarKind::Bit {
+ let total_size_of = scalar_t.size_of() as usize;
+ let scalar_size_of = total_size_of / idx.len();
+ if idx.len() * scalar_size_of == total_size_of {
+ ast::ScalarType::from_parts(scalar_size_of as u8, ast::ScalarKind::Bit)
+ } else {
+ return Err(TranslateError::mismatched_type());
+ }
+ } else {
+ return Err(TranslateError::mismatched_type());
+ }
+ }
+ _ => return Err(TranslateError::mismatched_type()),
};
- let temp_vec = self.id_def.new_non_variable(Some(typ.clone()));
+ let temp_vec = self
+ .id_def
+ .register_intermediate(Some((typ.clone(), state_space)));
+ let vector_members = idx
+ .into_iter()
+ .map(|vector_member| match vector_member {
+ ast::RegOrImmediate::Reg(reg) => reg,
+ ast::RegOrImmediate::Imm(immediate) => {
+ let (id, statement) = FlattenArguments::immediate_impl(
+ self.id_def,
+ immediate,
+ &ast::Type::Scalar(scalar_t),
+ ast::StateSpace::Reg,
+ );
+ self.func.push(statement);
+ id
+ }
+ })
+ .collect::<Vec<_>>();
let statement = Statement::RepackVector(RepackVectorDetails {
is_extract: is_dst,
typ: scalar_t,
packed: temp_vec,
- unpacked: idx,
- vector_sema,
+ unpacked: vector_members,
+ non_default_implicit_conversion,
});
if is_dst {
self.post_stmts = Some(statement);
@@ -1626,453 +4431,139 @@ impl<'a, 'b> VectorRepackVisitor<'a, 'b> {
}
}
-impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, TypedArgParams>
- for VectorRepackVisitor<'a, 'b>
+impl<'a, 'b, V: ArgumentMapVisitor<TypedArgParams, TypedArgParams>>
+ ArgumentMapVisitor<NormalizedArgParams, TypedArgParams> for VectorRepackVisitor<'a, 'b, V>
{
fn id(
&mut self,
- desc: ArgumentDescriptor<spirv::Word>,
- _: Option<&ast::Type>,
- ) -> Result<spirv::Word, TranslateError> {
- Ok(desc.op)
+ desc: ArgumentDescriptor<Id>,
+ type_: Option<(&ast::Type, ast::StateSpace)>,
+ ) -> Result<Id, TranslateError> {
+ self.extra_vistor.id(desc, type_)
}
fn operand(
&mut self,
- desc: ArgumentDescriptor<ast::Operand<spirv::Word>>,
+ desc: ArgumentDescriptor<ast::Operand<Id>>,
typ: &ast::Type,
+ state_space: ast::StateSpace,
) -> Result<TypedOperand, TranslateError> {
Ok(match desc.op {
- ast::Operand::Reg(reg) => TypedOperand::Reg(reg),
+ ast::Operand::Reg(reg) => {
+ self.extra_vistor
+ .operand(desc.new_op(TypedOperand::Reg(reg)), typ, state_space)?
+ }
ast::Operand::RegOffset(reg, offset) => TypedOperand::RegOffset(reg, offset),
ast::Operand::Imm(x) => TypedOperand::Imm(x),
ast::Operand::VecMember(vec, idx) => TypedOperand::VecMember(vec, idx),
- ast::Operand::VecPack(vec) => {
- TypedOperand::Reg(self.convert_vector(desc.is_dst, desc.sema, typ, vec)?)
- }
+ ast::Operand::VecPack(vec) => TypedOperand::Reg(self.convert_vector(
+ desc.is_dst,
+ desc.non_default_implicit_conversion,
+ typ,
+ state_space,
+ vec,
+ )?),
})
}
}
-//TODO: share common code between this and to_ptx_impl_bfe_call
-fn to_ptx_impl_atomic_call(
- id_defs: &mut NumericIdResolver,
- ptx_impl_imports: &mut HashMap<String, Directive>,
- details: ast::AtomDetails,
- arg: ast::Arg3<ExpandedArgParams>,
- op: &'static str,
-) -> ExpandedStatement {
- let semantics = ptx_semantics_name(details.semantics);
- let scope = ptx_scope_name(details.scope);
- let space = ptx_space_name(details.space);
- let fn_name = format!(
- "__zluda_ptx_impl__atom_{}_{}_{}_{}",
- semantics, scope, space, op
- );
- // TODO: extract to a function
- let ptr_space = match details.space {
- ast::AtomSpace::Generic => ast::PointerStateSpace::Generic,
- ast::AtomSpace::Global => ast::PointerStateSpace::Global,
- ast::AtomSpace::Shared => ast::PointerStateSpace::Shared,
- };
- let fn_id = match ptx_impl_imports.entry(fn_name) {
- hash_map::Entry::Vacant(entry) => {
- let fn_id = id_defs.new_non_variable(None);
- let func_decl = ast::MethodDecl::Func::<spirv::Word>(
- vec![ast::FnArgument {
- align: None,
- v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(
- ast::ScalarType::U32,
- )),
- name: id_defs.new_non_variable(None),
- array_init: Vec::new(),
- }],
- fn_id,
- vec![
- ast::FnArgument {
- align: None,
- v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Pointer(
- ast::SizedScalarType::U32,
- ptr_space,
- )),
- name: id_defs.new_non_variable(None),
- array_init: Vec::new(),
- },
- ast::FnArgument {
- align: None,
- v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(
- ast::ScalarType::U32,
- )),
- name: id_defs.new_non_variable(None),
- array_init: Vec::new(),
- },
- ],
- );
- let spirv_decl = SpirvMethodDecl::new(&func_decl);
- let func = Function {
- func_decl,
- globals: Vec::new(),
- body: None,
- import_as: Some(entry.key().clone()),
- spirv_decl,
- };
- entry.insert(Directive::Method(func));
- fn_id
- }
- hash_map::Entry::Occupied(entry) => match entry.get() {
- Directive::Method(Function {
- func_decl: ast::MethodDecl::Func(_, name, _),
- ..
- }) => *name,
- _ => unreachable!(),
- },
- };
- Statement::Call(ResolvedCall {
- uniform: false,
- func: fn_id,
- ret_params: vec![(
- arg.dst,
- ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(ast::ScalarType::U32)),
- )],
- param_list: vec![
- (
- arg.src1,
- ast::FnArgumentType::Reg(ast::VariableRegType::Pointer(
- ast::SizedScalarType::U32,
- ptr_space,
- )),
- ),
- (
- arg.src2,
- ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(ast::ScalarType::U32)),
- ),
- ],
- })
-}
-
-fn to_ptx_impl_bfe_call(
- id_defs: &mut NumericIdResolver,
- ptx_impl_imports: &mut HashMap<String, Directive>,
- typ: ast::IntType,
- arg: ast::Arg4<ExpandedArgParams>,
-) -> ExpandedStatement {
- let prefix = "__zluda_ptx_impl__";
- let suffix = match typ {
- ast::IntType::U32 => "bfe_u32",
- ast::IntType::U64 => "bfe_u64",
- ast::IntType::S32 => "bfe_s32",
- ast::IntType::S64 => "bfe_s64",
- _ => unreachable!(),
- };
- let fn_name = format!("{}{}", prefix, suffix);
- let fn_id = match ptx_impl_imports.entry(fn_name) {
- hash_map::Entry::Vacant(entry) => {
- let fn_id = id_defs.new_non_variable(None);
- let func_decl = ast::MethodDecl::Func::<spirv::Word>(
- vec![ast::FnArgument {
- align: None,
- v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
- name: id_defs.new_non_variable(None),
- array_init: Vec::new(),
- }],
- fn_id,
- vec![
- ast::FnArgument {
- align: None,
- v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
- name: id_defs.new_non_variable(None),
- array_init: Vec::new(),
- },
- ast::FnArgument {
- align: None,
- v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(
- ast::ScalarType::U32,
- )),
- name: id_defs.new_non_variable(None),
- array_init: Vec::new(),
- },
- ast::FnArgument {
- align: None,
- v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(
- ast::ScalarType::U32,
- )),
- name: id_defs.new_non_variable(None),
- array_init: Vec::new(),
- },
- ],
- );
- let spirv_decl = SpirvMethodDecl::new(&func_decl);
- let func = Function {
- func_decl,
- globals: Vec::new(),
- body: None,
- import_as: Some(entry.key().clone()),
- spirv_decl,
+fn instruction_to_fn_call<'input>(
+ id_defs: &mut IdNameMapBuilder,
+ ptx_impl_imports: &mut AdditionalFunctionDeclarations,
+ inst: ast::Instruction<ExpandedArgParams>,
+ fn_name: String,
+) -> Result<ExpandedStatement, TranslateError> {
+ let mut arguments = Vec::new();
+ inst.visit(
+ &mut |desc: ArgumentDescriptor<Id>, typ: Option<(&ast::Type, ast::StateSpace)>| {
+ let (typ, space) = match typ {
+ Some((typ, space)) => (typ.clone(), space),
+ None => return Err(TranslateError::unreachable()),
};
- entry.insert(Directive::Method(func));
- fn_id
- }
- hash_map::Entry::Occupied(entry) => match entry.get() {
- Directive::Method(Function {
- func_decl: ast::MethodDecl::Func(_, name, _),
- ..
- }) => *name,
- _ => unreachable!(),
+ let result = desc.op;
+ arguments.push((desc, typ, space));
+ Ok(result)
},
- };
- Statement::Call(ResolvedCall {
+ )?;
+ let return_arguments_count = arguments
+ .iter()
+ .position(|(desc, _, _)| !desc.is_dst)
+ .unwrap_or(arguments.len());
+ let (return_arguments, input_arguments) = arguments.split_at(return_arguments_count);
+ let fn_id = ptx_impl_imports.add_or_get_declaration(
+ id_defs,
+ fn_name,
+ return_arguments.iter().map(|(_, typ, state)| (typ, *state)),
+ input_arguments.iter().map(|(_, typ, state)| (typ, *state)),
+ )?;
+ Ok(Statement::Call(ResolvedCall {
uniform: false,
- func: fn_id,
- ret_params: vec![(
- arg.dst,
- ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
- )],
- param_list: vec![
- (
- arg.src1,
- ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
- ),
- (
- arg.src2,
- ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(ast::ScalarType::U32)),
- ),
- (
- arg.src3,
- ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(ast::ScalarType::U32)),
- ),
- ],
+ name: fn_id,
+ return_arguments: arguments_to_resolved_arguments(return_arguments),
+ input_arguments: arguments_to_resolved_arguments(input_arguments),
+ is_indirect: false,
+ }))
+}
+
+fn fn_arguments_to_variables<'a>(
+ id_defs: &mut IdNameMapBuilder,
+ args: impl Iterator<Item = (&'a ast::Type, ast::StateSpace)>,
+) -> Vec<ast::VariableDeclaration<Id>> {
+ args.map(|(typ, space)| ast::VariableDeclaration {
+ align: None,
+ type_: typ.clone(),
+ state_space: space,
+ name: id_defs.register_intermediate(None),
})
+ .collect::<Vec<_>>()
}
-fn to_resolved_fn_args<T>(
- params: Vec<T>,
- params_decl: &[ast::FnArgumentType],
-) -> Vec<(T, ast::FnArgumentType)> {
- params
- .into_iter()
- .zip(params_decl.iter())
- .map(|(id, typ)| (id, typ.clone()))
+fn arguments_to_resolved_arguments(
+ args: &[(ArgumentDescriptor<Id>, ast::Type, ast::StateSpace)],
+) -> Vec<(Id, ast::Type, ast::StateSpace)> {
+ args.iter()
+ .map(|(desc, typ, space)| (desc.op, typ.clone(), *space))
.collect::<Vec<_>>()
}
-fn normalize_labels(
- func: Vec<ExpandedStatement>,
- id_def: &mut NumericIdResolver,
-) -> Vec<ExpandedStatement> {
- let mut labels_in_use = HashSet::new();
- for s in func.iter() {
- match s {
- Statement::Instruction(i) => {
- if let Some(target) = i.jump_target() {
- labels_in_use.insert(target);
- }
- }
- Statement::Conditional(cond) => {
- labels_in_use.insert(cond.if_true);
- labels_in_use.insert(cond.if_false);
- }
- Statement::Call(..)
- | Statement::Variable(..)
- | Statement::LoadVar(..)
- | Statement::StoreVar(..)
- | Statement::RetValue(..)
- | Statement::Conversion(..)
- | Statement::Constant(..)
- | Statement::Label(..)
- | Statement::PtrAccess { .. }
- | Statement::RepackVector(..) => {}
- }
- }
- iter::once(Statement::Label(id_def.new_non_variable(None)))
- .chain(func.into_iter().filter(|s| match s {
- Statement::Label(i) => labels_in_use.contains(i),
- _ => true,
- }))
- .collect::<Vec<_>>()
-}
-
-fn normalize_predicates(
- func: Vec<NormalizedStatement>,
- id_def: &mut NumericIdResolver,
-) -> Result<Vec<UnconditionalStatement>, TranslateError> {
- let mut result = Vec::with_capacity(func.len());
- for s in func {
- match s {
- Statement::Label(id) => result.push(Statement::Label(id)),
- Statement::Instruction((pred, inst)) => {
- if let Some(pred) = pred {
- let if_true = id_def.new_non_variable(None);
- let if_false = id_def.new_non_variable(None);
- let folded_bra = match &inst {
- ast::Instruction::Bra(_, arg) => Some(arg.src),
- _ => None,
- };
- let mut branch = BrachCondition {
- predicate: pred.label,
- if_true: folded_bra.unwrap_or(if_true),
- if_false,
- };
- if pred.not {
- std::mem::swap(&mut branch.if_true, &mut branch.if_false);
- }
- result.push(Statement::Conditional(branch));
- if folded_bra.is_none() {
- result.push(Statement::Label(if_true));
- result.push(Statement::Instruction(inst));
- }
- result.push(Statement::Label(if_false));
- } else {
- result.push(Statement::Instruction(inst));
- }
- }
- Statement::Variable(var) => result.push(Statement::Variable(var)),
- // Blocks are flattened when resolving ids
- _ => return Err(error_unreachable()),
- }
- }
- Ok(result)
-}
-
-fn insert_mem_ssa_statements<'a, 'b>(
- func: Vec<TypedStatement>,
- id_def: &mut NumericIdResolver,
- ast_fn_decl: &'a ast::MethodDecl<'b, spirv::Word>,
- fn_decl: &mut SpirvMethodDecl,
-) -> Result<Vec<TypedStatement>, TranslateError> {
- let is_func = match ast_fn_decl {
- ast::MethodDecl::Func(..) => true,
- ast::MethodDecl::Kernel { .. } => false,
- };
- let mut result = Vec::with_capacity(func.len());
- for arg in fn_decl.output.iter() {
- match type_to_variable_type(&arg.v_type, is_func)? {
- Some(var_type) => {
- result.push(Statement::Variable(ast::Variable {
- align: arg.align,
- v_type: var_type,
- name: arg.name,
- array_init: arg.array_init.clone(),
- }));
- }
- None => return Err(error_unreachable()),
- }
- }
- for spirv_arg in fn_decl.input.iter_mut() {
- match type_to_variable_type(&spirv_arg.v_type, is_func)? {
- Some(var_type) => {
- let typ = spirv_arg.v_type.clone();
- let new_id = id_def.new_non_variable(Some(typ.clone()));
- result.push(Statement::Variable(ast::Variable {
- align: spirv_arg.align,
- v_type: var_type,
- name: spirv_arg.name,
- array_init: spirv_arg.array_init.clone(),
- }));
- result.push(Statement::StoreVar(StoreVarDetails {
- arg: ast::Arg2St {
- src1: spirv_arg.name,
- src2: new_id,
- },
- typ,
- member_index: None,
- }));
- spirv_arg.name = new_id;
- }
- None => {}
- }
- }
- for s in func {
- match s {
- Statement::Call(call) => {
- insert_mem_ssa_statement_default(id_def, &mut result, call.cast())?
- }
- Statement::Instruction(inst) => match inst {
- ast::Instruction::Ret(d) => {
- // TODO: handle multiple output args
- if let &[out_param] = &fn_decl.output.as_slice() {
- let (typ, _) = id_def.get_typed(out_param.name)?;
- let new_id = id_def.new_non_variable(Some(typ.clone()));
- result.push(Statement::LoadVar(LoadVarDetails {
- arg: ast::Arg2 {
- dst: new_id,
- src: out_param.name,
- },
- typ: typ.clone(),
- member_index: None,
- }));
- result.push(Statement::RetValue(d, new_id));
- } else {
- result.push(Statement::Instruction(ast::Instruction::Ret(d)))
- }
- }
- inst => insert_mem_ssa_statement_default(id_def, &mut result, inst)?,
- },
- Statement::Conditional(mut bra) => {
- let generated_id =
- id_def.new_non_variable(Some(ast::Type::Scalar(ast::ScalarType::Pred)));
- result.push(Statement::LoadVar(LoadVarDetails {
- arg: Arg2 {
- dst: generated_id,
- src: bra.predicate,
- },
- typ: ast::Type::Scalar(ast::ScalarType::Pred),
- member_index: None,
- }));
- bra.predicate = generated_id;
- result.push(Statement::Conditional(bra));
- }
- Statement::Conversion(conv) => {
- insert_mem_ssa_statement_default(id_def, &mut result, conv)?
- }
- Statement::PtrAccess(ptr_access) => {
- insert_mem_ssa_statement_default(id_def, &mut result, ptr_access)?
- }
- Statement::RepackVector(repack) => {
- insert_mem_ssa_statement_default(id_def, &mut result, repack)?
- }
- s @ Statement::Variable(_) | s @ Statement::Label(_) => result.push(s),
- _ => return Err(error_unreachable()),
- }
- }
- Ok(result)
+fn insert_mem_ssa_argument<'input>(
+ id_def: &mut IdNameMapBuilder<'input>,
+ func: &mut Vec<TypedStatement>,
+ arg: &mut ast::VariableDeclaration<Id>,
+) {
+ let new_id = id_def.register_intermediate(Some((arg.type_.clone(), arg.state_space)));
+ func.push(Statement::Variable(Variable {
+ align: arg.align,
+ type_: arg.type_.clone(),
+ state_space: ast::StateSpace::Reg,
+ name: arg.name,
+ initializer: None,
+ }));
+ func.push(Statement::StoreVar(StoreVarDetails {
+ arg: ast::Arg2St {
+ src1: arg.name,
+ src2: new_id,
+ },
+ type_: arg.type_.clone(),
+ member_index: None,
+ }));
+ arg.name = new_id;
}
-fn type_to_variable_type(
- t: &ast::Type,
- is_func: bool,
-) -> Result<Option<ast::VariableType>, TranslateError> {
- Ok(match t {
- ast::Type::Scalar(typ) => Some(ast::VariableType::Reg(ast::VariableRegType::Scalar(*typ))),
- ast::Type::Vector(typ, len) => Some(ast::VariableType::Reg(ast::VariableRegType::Vector(
- (*typ)
- .try_into()
- .map_err(|_| TranslateError::MismatchedType)?,
- *len,
- ))),
- ast::Type::Array(typ, len) => Some(ast::VariableType::Reg(ast::VariableRegType::Array(
- (*typ)
- .try_into()
- .map_err(|_| TranslateError::MismatchedType)?,
- len.clone(),
- ))),
- ast::Type::Pointer(ast::PointerType::Scalar(scalar_type), space) => {
- if is_func {
- return Ok(None);
- }
- Some(ast::VariableType::Reg(ast::VariableRegType::Pointer(
- scalar_type
- .clone()
- .try_into()
- .map_err(|_| error_unreachable())?,
- (*space).try_into().map_err(|_| error_unreachable())?,
- )))
- }
- ast::Type::Pointer(_, ast::LdStateSpace::Shared) => None,
- _ => return Err(error_unreachable()),
- })
+fn insert_mem_ssa_argument_reg_return(
+ func: &mut Vec<TypedStatement>,
+ arg: &ast::VariableDeclaration<Id>,
+) {
+ func.push(Statement::Variable(Variable {
+ align: arg.align,
+ type_: arg.type_.clone(),
+ state_space: arg.state_space,
+ name: arg.name,
+ initializer: None,
+ }));
}
-trait Visitable<From: ArgParamsEx, To: ArgParamsEx>: Sized {
+pub(crate) trait Visitable<From: ArgParamsEx, To: ArgParamsEx>: Sized {
fn visit(
self,
visitor: &mut impl ArgumentMapVisitor<From, To>,
@@ -2081,31 +4572,34 @@ trait Visitable<From: ArgParamsEx, To: ArgParamsEx>: Sized {
struct VisitArgumentDescriptor<
'a,
- Ctor: FnOnce(spirv::Word) -> Statement<ast::Instruction<U>, U>,
+ Ctor: FnOnce(Id) -> Statement<ast::Instruction<U>, U>,
U: ArgParamsEx,
> {
- desc: ArgumentDescriptor<spirv::Word>,
+ desc: ArgumentDescriptor<Id>,
typ: &'a ast::Type,
+ state_space: ast::StateSpace,
stmt_ctor: Ctor,
}
impl<
'a,
- Ctor: FnOnce(spirv::Word) -> Statement<ast::Instruction<U>, U>,
- T: ArgParamsEx<Id = spirv::Word>,
- U: ArgParamsEx<Id = spirv::Word>,
+ Ctor: FnOnce(Id) -> Statement<ast::Instruction<U>, U>,
+ T: ArgParamsEx<Id = Id>,
+ U: ArgParamsEx<Id = Id>,
> Visitable<T, U> for VisitArgumentDescriptor<'a, Ctor, U>
{
fn visit(
self,
visitor: &mut impl ArgumentMapVisitor<T, U>,
) -> Result<Statement<ast::Instruction<U>, U>, TranslateError> {
- Ok((self.stmt_ctor)(visitor.id(self.desc, Some(self.typ))?))
+ Ok((self.stmt_ctor)(
+ visitor.id(self.desc, Some((self.typ, self.state_space)))?,
+ ))
}
}
struct InsertMemSSAVisitor<'a, 'input> {
- id_def: &'a mut NumericIdResolver<'input>,
+ id_def: &'a mut IdNameMapBuilder<'input>,
func: &'a mut Vec<TypedStatement>,
post_statements: Vec<TypedStatement>,
}
@@ -2113,15 +4607,15 @@ struct InsertMemSSAVisitor<'a, 'input> {
impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
fn symbol(
&mut self,
- desc: ArgumentDescriptor<(spirv::Word, Option<u8>)>,
- expected_type: Option<&ast::Type>,
- ) -> Result<spirv::Word, TranslateError> {
+ desc: ArgumentDescriptor<(Id, Option<u8>)>,
+ expected: Option<(&ast::Type, ast::StateSpace)>,
+ ) -> Result<Id, TranslateError> {
let symbol = desc.op.0;
- if expected_type.is_none() {
+ if expected.is_none() {
return Ok(symbol);
};
- let (mut var_type, is_variable) = self.id_def.get_typed(symbol)?;
- if !is_variable {
+ let (mut var_type, var_space, _, is_variable) = self.id_def.get_typed(symbol)?;
+ if !var_space.is_compatible(ast::StateSpace::Reg) || !is_variable {
return Ok(symbol);
};
let member_index = match desc.op.1 {
@@ -2131,38 +4625,44 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
var_type = ast::Type::Scalar(scalar_t);
width
}
- _ => return Err(TranslateError::MismatchedType),
+ _ => return Err(TranslateError::mismatched_type()),
};
- Some((
- idx,
- if self.id_def.special_registers.get(symbol).is_some() {
- Some(vector_width)
- } else {
- None
- },
- ))
+ Some((idx, vector_width))
}
None => None,
};
- let generated_id = self.id_def.new_non_variable(Some(var_type.clone()));
+ let generated_id = self
+ .id_def
+ .register_intermediate(Some((var_type.clone(), ast::StateSpace::Reg)));
if !desc.is_dst {
self.func.push(Statement::LoadVar(LoadVarDetails {
arg: Arg2 {
dst: generated_id,
src: symbol,
},
+ _state_space: ast::StateSpace::Reg,
typ: var_type,
member_index,
}));
} else {
+ let (type_, member_index) = match member_index {
+ None => (var_type, None),
+ Some((idx, width)) => {
+ if let ast::Type::Scalar(scalar) = var_type {
+ (ast::Type::Vector(scalar, width), Some(idx))
+ } else {
+ return Err(TranslateError::unreachable());
+ }
+ }
+ };
self.post_statements
.push(Statement::StoreVar(StoreVarDetails {
arg: Arg2St {
src1: symbol,
src2: generated_id,
},
- typ: var_type,
- member_index: member_index.map(|(idx, _)| idx),
+ type_,
+ member_index,
}));
}
Ok(generated_id)
@@ -2174,9 +4674,9 @@ impl<'a, 'input> ArgumentMapVisitor<TypedArgParams, TypedArgParams>
{
fn id(
&mut self,
- desc: ArgumentDescriptor<spirv::Word>,
- typ: Option<&ast::Type>,
- ) -> Result<spirv::Word, TranslateError> {
+ desc: ArgumentDescriptor<Id>,
+ typ: Option<(&ast::Type, ast::StateSpace)>,
+ ) -> Result<Id, TranslateError> {
self.symbol(desc.new_op((desc.op, None)), typ)
}
@@ -2184,24 +4684,26 @@ impl<'a, 'input> ArgumentMapVisitor<TypedArgParams, TypedArgParams>
&mut self,
desc: ArgumentDescriptor<TypedOperand>,
typ: &ast::Type,
+ state_space: ast::StateSpace,
) -> Result<TypedOperand, TranslateError> {
Ok(match desc.op {
TypedOperand::Reg(reg) => {
- TypedOperand::Reg(self.symbol(desc.new_op((reg, None)), Some(typ))?)
- }
- TypedOperand::RegOffset(reg, offset) => {
- TypedOperand::RegOffset(self.symbol(desc.new_op((reg, None)), Some(typ))?, offset)
+ TypedOperand::Reg(self.symbol(desc.new_op((reg, None)), Some((typ, state_space)))?)
}
+ TypedOperand::RegOffset(reg, offset) => TypedOperand::RegOffset(
+ self.symbol(desc.new_op((reg, None)), Some((typ, state_space)))?,
+ offset,
+ ),
op @ TypedOperand::Imm(..) => op,
- TypedOperand::VecMember(symbol, index) => {
- TypedOperand::Reg(self.symbol(desc.new_op((symbol, Some(index))), Some(typ))?)
- }
+ TypedOperand::VecMember(symbol, index) => TypedOperand::Reg(
+ self.symbol(desc.new_op((symbol, Some(index))), Some((typ, state_space)))?,
+ ),
})
}
}
fn insert_mem_ssa_statement_default<'a, 'input, S: Visitable<TypedArgParams, TypedArgParams>>(
- id_def: &'a mut NumericIdResolver<'input>,
+ id_def: &'a mut IdNameMapBuilder<'input>,
func: &'a mut Vec<TypedStatement>,
stmt: S,
) -> Result<(), TranslateError> {
@@ -2216,71 +4718,14 @@ fn insert_mem_ssa_statement_default<'a, 'input, S: Visitable<TypedArgParams, Typ
Ok(())
}
-fn expand_arguments<'a, 'b>(
- func: Vec<TypedStatement>,
- id_def: &'b mut MutableNumericIdResolver<'a>,
-) -> Result<Vec<ExpandedStatement>, TranslateError> {
- let mut result = Vec::with_capacity(func.len());
- for s in func {
- match s {
- Statement::Call(call) => {
- let mut visitor = FlattenArguments::new(&mut result, id_def);
- let (new_call, post_stmts) = (call.map(&mut visitor)?, visitor.post_stmts);
- result.push(Statement::Call(new_call));
- result.extend(post_stmts);
- }
- Statement::Instruction(inst) => {
- let mut visitor = FlattenArguments::new(&mut result, id_def);
- let (new_inst, post_stmts) = (inst.map(&mut visitor)?, visitor.post_stmts);
- result.push(Statement::Instruction(new_inst));
- result.extend(post_stmts);
- }
- Statement::Variable(ast::Variable {
- align,
- v_type,
- name,
- array_init,
- }) => result.push(Statement::Variable(ast::Variable {
- align,
- v_type,
- name,
- array_init,
- })),
- Statement::PtrAccess(ptr_access) => {
- let mut visitor = FlattenArguments::new(&mut result, id_def);
- let (new_inst, post_stmts) = (ptr_access.map(&mut visitor)?, visitor.post_stmts);
- result.push(Statement::PtrAccess(new_inst));
- result.extend(post_stmts);
- }
- Statement::RepackVector(repack) => {
- let mut visitor = FlattenArguments::new(&mut result, id_def);
- let (new_inst, post_stmts) = (repack.map(&mut visitor)?, visitor.post_stmts);
- result.push(Statement::RepackVector(new_inst));
- result.extend(post_stmts);
- }
- Statement::Label(id) => result.push(Statement::Label(id)),
- Statement::Conditional(bra) => result.push(Statement::Conditional(bra)),
- Statement::LoadVar(details) => result.push(Statement::LoadVar(details)),
- Statement::StoreVar(details) => result.push(Statement::StoreVar(details)),
- Statement::RetValue(d, id) => result.push(Statement::RetValue(d, id)),
- Statement::Conversion(conv) => result.push(Statement::Conversion(conv)),
- Statement::Constant(_) => return Err(error_unreachable()),
- }
- }
- Ok(result)
-}
-
-struct FlattenArguments<'a, 'b> {
- func: &'b mut Vec<ExpandedStatement>,
- id_def: &'b mut MutableNumericIdResolver<'a>,
- post_stmts: Vec<ExpandedStatement>,
+struct FlattenArguments<'a, 'input, I, P: ast::ArgParams> {
+ func: &'a mut Vec<Statement<I, P>>,
+ id_def: &'a mut IdNameMapBuilder<'input>,
+ post_stmts: Vec<Statement<I, P>>,
}
-impl<'a, 'b> FlattenArguments<'a, 'b> {
- fn new(
- func: &'b mut Vec<ExpandedStatement>,
- id_def: &'b mut MutableNumericIdResolver<'a>,
- ) -> Self {
+impl<'a, 'input, I, P: ast::ArgParams> FlattenArguments<'a, 'input, I, P> {
+ fn new(func: &'a mut Vec<Statement<I, P>>, id_def: &'a mut IdNameMapBuilder<'input>) -> Self {
FlattenArguments {
func,
id_def,
@@ -2290,133 +4735,126 @@ impl<'a, 'b> FlattenArguments<'a, 'b> {
fn reg(
&mut self,
- desc: ArgumentDescriptor<spirv::Word>,
- _: Option<&ast::Type>,
- ) -> Result<spirv::Word, TranslateError> {
+ desc: ArgumentDescriptor<Id>,
+ _: Option<(&ast::Type, ast::StateSpace)>,
+ ) -> Result<Id, TranslateError> {
Ok(desc.op)
}
- fn reg_offset(
+ fn immediate(
&mut self,
- desc: ArgumentDescriptor<(spirv::Word, i32)>,
+ desc: ArgumentDescriptor<ast::ImmediateValue>,
typ: &ast::Type,
- ) -> Result<spirv::Word, TranslateError> {
- let (reg, offset) = desc.op;
- let add_type;
- match typ {
- ast::Type::Pointer(underlying_type, state_space) => {
- let reg_typ = self.id_def.get_typed(reg)?;
- if let ast::Type::Pointer(_, _) = reg_typ {
- let id_constant_stmt = self.id_def.new_non_variable(typ.clone());
- self.func.push(Statement::Constant(ConstantDefinition {
- dst: id_constant_stmt,
- typ: ast::ScalarType::S64,
- value: ast::ImmediateValue::S64(offset as i64),
- }));
- let dst = self.id_def.new_non_variable(typ.clone());
- self.func.push(Statement::PtrAccess(PtrAccess {
- underlying_type: underlying_type.clone(),
- state_space: *state_space,
- dst,
- ptr_src: reg,
- offset_src: id_constant_stmt,
- }));
- return Ok(dst);
- } else {
- add_type = self.id_def.get_typed(reg)?;
- }
- }
- _ => {
- add_type = typ.clone();
- }
- };
- let (width, kind) = match add_type {
- ast::Type::Scalar(scalar_t) => {
- let kind = match scalar_t.kind() {
- kind @ ScalarKind::Bit
- | kind @ ScalarKind::Unsigned
- | kind @ ScalarKind::Signed => kind,
- ScalarKind::Float => return Err(TranslateError::MismatchedType),
- ScalarKind::Float2 => return Err(TranslateError::MismatchedType),
- ScalarKind::Pred => return Err(TranslateError::MismatchedType),
- };
- (scalar_t.size_of(), kind)
- }
- _ => return Err(TranslateError::MismatchedType),
- };
- let arith_detail = if kind == ScalarKind::Signed {
- ast::ArithDetails::Signed(ast::ArithSInt {
- typ: ast::SIntType::from_size(width),
- saturate: false,
- })
+ state_space: ast::StateSpace,
+ ) -> Result<Id, TranslateError> {
+ let (id, statement) = Self::immediate_impl(self.id_def, desc.op, typ, state_space);
+ self.func.push(statement);
+ Ok(id)
+ }
+
+ fn immediate_impl(
+ id_def: &mut IdNameMapBuilder<'input>,
+ immediate: ast::ImmediateValue,
+ typ: &ast::Type,
+ state_space: ast::StateSpace,
+ ) -> (Id, Statement<I, P>) {
+ let scalar_t = if let ast::Type::Scalar(scalar) = typ {
+ *scalar
} else {
- ast::ArithDetails::Unsigned(ast::UIntType::from_size(width))
+ todo!()
};
- let id_constant_stmt = self.id_def.new_non_variable(add_type.clone());
- let result_id = self.id_def.new_non_variable(add_type);
- // TODO: check for edge cases around min value/max value/wrapping
- if offset < 0 && kind != ScalarKind::Signed {
+ let id = id_def.register_intermediate(Some((ast::Type::Scalar(scalar_t), state_space)));
+ (
+ id,
+ Statement::Constant(ConstantDefinition {
+ dst: id,
+ typ: scalar_t,
+ value: immediate,
+ }),
+ )
+ }
+}
+
+impl<'a, 'b> FlattenArguments<'a, 'b, ast::Instruction<ExpandedArgParams>, ExpandedArgParams> {
+ fn reg_offset(
+ &mut self,
+ desc: ArgumentDescriptor<(Id, i64)>,
+ typ: &ast::Type,
+ state_space: ast::StateSpace,
+ ) -> Result<Id, TranslateError> {
+ let (reg, offset) = desc.op;
+ if !desc.is_memory_access {
+ let (reg_type, reg_space, ..) = self.id_def.get_typed(reg)?;
+ if !reg_space.is_compatible(ast::StateSpace::Reg) {
+ return Err(TranslateError::mismatched_type());
+ }
+ let reg_scalar_type = match reg_type {
+ ast::Type::Scalar(underlying_type) => underlying_type,
+ _ => return Err(TranslateError::mismatched_type()),
+ };
+ let id_constant_stmt = self
+ .id_def
+ .register_intermediate(Some((reg_type.clone(), ast::StateSpace::Reg)));
self.func.push(Statement::Constant(ConstantDefinition {
dst: id_constant_stmt,
- typ: ast::ScalarType::from_parts(width, kind),
- value: ast::ImmediateValue::U64(-(offset as i64) as u64),
+ typ: reg_scalar_type,
+ value: ast::ImmediateValue::S64(offset),
}));
- self.func.push(Statement::Instruction(
- ast::Instruction::<ExpandedArgParams>::Sub(
- arith_detail,
- ast::Arg3 {
- dst: result_id,
- src1: reg,
- src2: id_constant_stmt,
- },
- ),
- ));
+ let arith_details = match reg_scalar_type.kind() {
+ ast::ScalarKind::Signed => ast::ArithDetails::Signed(ast::ArithSInt {
+ typ: reg_scalar_type,
+ saturate: false,
+ }),
+ ast::ScalarKind::Unsigned | ast::ScalarKind::Bit => {
+ ast::ArithDetails::Unsigned(reg_scalar_type)
+ }
+ _ => return Err(TranslateError::unreachable()),
+ };
+ let id_add_result = self
+ .id_def
+ .register_intermediate(Some((reg_type, state_space)));
+ self.func.push(Statement::Instruction(ast::Instruction::Add(
+ arith_details,
+ ast::Arg3 {
+ dst: id_add_result,
+ src1: reg,
+ src2: id_constant_stmt,
+ },
+ )));
+ Ok(id_add_result)
} else {
+ let id_constant_stmt = self.id_def.register_intermediate(Some((
+ ast::Type::Scalar(ast::ScalarType::S64),
+ ast::StateSpace::Reg,
+ )));
self.func.push(Statement::Constant(ConstantDefinition {
dst: id_constant_stmt,
- typ: ast::ScalarType::from_parts(width, kind),
- value: ast::ImmediateValue::S64(offset as i64),
+ typ: ast::ScalarType::S64,
+ value: ast::ImmediateValue::S64(offset),
}));
- self.func.push(Statement::Instruction(
- ast::Instruction::<ExpandedArgParams>::Add(
- arith_detail,
- ast::Arg3 {
- dst: result_id,
- src1: reg,
- src2: id_constant_stmt,
- },
- ),
- ));
+ let dst = self
+ .id_def
+ .register_intermediate(Some((typ.clone(), state_space)));
+ self.func.push(Statement::PtrAccess(PtrAccess {
+ underlying_type: typ.clone(),
+ state_space: state_space,
+ dst,
+ ptr_src: reg,
+ offset_src: id_constant_stmt,
+ }));
+ Ok(dst)
}
- Ok(result_id)
- }
-
- fn immediate(
- &mut self,
- desc: ArgumentDescriptor<ast::ImmediateValue>,
- typ: &ast::Type,
- ) -> Result<spirv::Word, TranslateError> {
- let scalar_t = if let ast::Type::Scalar(scalar) = typ {
- *scalar
- } else {
- todo!()
- };
- let id = self.id_def.new_non_variable(ast::Type::Scalar(scalar_t));
- self.func.push(Statement::Constant(ConstantDefinition {
- dst: id,
- typ: scalar_t,
- value: desc.op,
- }));
- Ok(id)
}
}
-impl<'a, 'b> ArgumentMapVisitor<TypedArgParams, ExpandedArgParams> for FlattenArguments<'a, 'b> {
+impl<'a, 'b> ArgumentMapVisitor<TypedArgParams, ExpandedArgParams>
+ for FlattenArguments<'a, 'b, ast::Instruction<ExpandedArgParams>, ExpandedArgParams>
+{
fn id(
&mut self,
- desc: ArgumentDescriptor<spirv::Word>,
- t: Option<&ast::Type>,
- ) -> Result<spirv::Word, TranslateError> {
+ desc: ArgumentDescriptor<Id>,
+ t: Option<(&ast::Type, ast::StateSpace)>,
+ ) -> Result<Id, TranslateError> {
self.reg(desc, t)
}
@@ -2424,2670 +4862,730 @@ impl<'a, 'b> ArgumentMapVisitor<TypedArgParams, ExpandedArgParams> for FlattenAr
&mut self,
desc: ArgumentDescriptor<TypedOperand>,
typ: &ast::Type,
- ) -> Result<spirv::Word, TranslateError> {
+ state_space: ast::StateSpace,
+ ) -> Result<Id, TranslateError> {
match desc.op {
- TypedOperand::Reg(r) => self.reg(desc.new_op(r), Some(typ)),
- TypedOperand::Imm(x) => self.immediate(desc.new_op(x), typ),
+ TypedOperand::Reg(r) => self.reg(desc.new_op(r), Some((typ, state_space))),
+ TypedOperand::Imm(x) => self.immediate(desc.new_op(x), typ, state_space),
TypedOperand::RegOffset(reg, offset) => {
- self.reg_offset(desc.new_op((reg, offset)), typ)
+ self.reg_offset(desc.new_op((reg, offset)), typ, state_space)
}
- TypedOperand::VecMember(..) => Err(error_unreachable()),
+ TypedOperand::VecMember(..) => Err(TranslateError::unreachable()),
}
}
}
-/*
- There are several kinds of implicit conversions in PTX:
- * auto-bitcast: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#type-information-for-instructions-and-operands
- * special ld/st/cvt conversion rules: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size
- - ld.param: not documented, but for instruction `ld.param.<type> x, [y]`,
- semantics are to first zext/chop/bitcast `y` as needed and then do
- documented special ld/st/cvt conversion rules for destination operands
- - st.param [x] y (used as function return arguments) same rule as above applies
- - generic/global ld: for instruction `ld x, [y]`, y must be of type
- b64/u64/s64, which is bitcast to a pointer, dereferenced and then
- documented special ld/st/cvt conversion rules are applied to dst
- - generic/global st: for instruction `st [x], y`, x must be of type
- b64/u64/s64, which is bitcast to a pointer
-*/
-fn insert_implicit_conversions(
- func: Vec<ExpandedStatement>,
- id_def: &mut MutableNumericIdResolver,
-) -> Result<Vec<ExpandedStatement>, TranslateError> {
- let mut result = Vec::with_capacity(func.len());
- for s in func.into_iter() {
- match s {
- Statement::Call(call) => insert_implicit_conversions_impl(
- &mut result,
- id_def,
- call,
- should_bitcast_wrapper,
- None,
- )?,
- Statement::Instruction(inst) => {
- let mut default_conversion_fn =
- should_bitcast_wrapper as for<'a> fn(&'a ast::Type, &'a ast::Type, _) -> _;
- let mut state_space = None;
- if let ast::Instruction::Ld(d, _) = &inst {
- state_space = Some(d.state_space);
- }
- if let ast::Instruction::St(d, _) = &inst {
- state_space = Some(d.state_space.to_ld_ss());
- }
- if let ast::Instruction::Atom(d, _) = &inst {
- state_space = Some(d.space.to_ld_ss());
- }
- if let ast::Instruction::AtomCas(d, _) = &inst {
- state_space = Some(d.space.to_ld_ss());
- }
- if let ast::Instruction::Mov(..) = &inst {
- default_conversion_fn = should_bitcast_packed;
- }
- insert_implicit_conversions_impl(
- &mut result,
- id_def,
- inst,
- default_conversion_fn,
- state_space,
- )?;
- }
- Statement::PtrAccess(PtrAccess {
- underlying_type,
- state_space,
- dst,
- ptr_src,
- offset_src: constant_src,
- }) => {
- let visit_desc = VisitArgumentDescriptor {
- desc: ArgumentDescriptor {
- op: ptr_src,
- is_dst: false,
- sema: ArgumentSemantics::PhysicalPointer,
- },
- typ: &ast::Type::Pointer(underlying_type.clone(), state_space),
- stmt_ctor: |new_ptr_src| {
- Statement::PtrAccess(PtrAccess {
- underlying_type,
- state_space,
- dst,
- ptr_src: new_ptr_src,
- offset_src: constant_src,
- })
- },
- };
- insert_implicit_conversions_impl(
- &mut result,
- id_def,
- visit_desc,
- bitcast_physical_pointer,
- Some(state_space),
- )?;
- }
- Statement::RepackVector(repack) => insert_implicit_conversions_impl(
- &mut result,
- id_def,
- repack,
- should_bitcast_wrapper,
- None,
- )?,
- s @ Statement::Conditional(_)
- | s @ Statement::Conversion(_)
- | s @ Statement::Label(_)
- | s @ Statement::Constant(_)
- | s @ Statement::Variable(_)
- | s @ Statement::LoadVar(..)
- | s @ Statement::StoreVar(..)
- | s @ Statement::RetValue(_, _) => result.push(s),
- }
- }
- Ok(result)
-}
-
fn insert_implicit_conversions_impl(
func: &mut Vec<ExpandedStatement>,
- id_def: &mut MutableNumericIdResolver,
+ id_def: &mut IdNameMapBuilder,
stmt: impl Visitable<ExpandedArgParams, ExpandedArgParams>,
- default_conversion_fn: for<'a> fn(
- &'a ast::Type,
- &'a ast::Type,
- Option<ast::LdStateSpace>,
- ) -> Result<Option<ConversionKind>, TranslateError>,
- state_space: Option<ast::LdStateSpace>,
) -> Result<(), TranslateError> {
let mut post_conv = Vec::new();
- let statement = stmt.visit(
- &mut |desc: ArgumentDescriptor<spirv::Word>, typ: Option<&ast::Type>| {
- let instr_type = match typ {
+ let statement =
+ stmt.visit(&mut |desc: ArgumentDescriptor<Id>,
+ typ: Option<(&ast::Type, ast::StateSpace)>| {
+ let (instr_type, instruction_space) = match typ {
None => return Ok(desc.op),
Some(t) => t,
};
- let operand_type = id_def.get_typed(desc.op)?;
- let mut conversion_fn = default_conversion_fn;
- match desc.sema {
- ArgumentSemantics::Default => {}
- ArgumentSemantics::DefaultRelaxed => {
- if desc.is_dst {
- conversion_fn = should_convert_relaxed_dst_wrapper;
- } else {
- conversion_fn = should_convert_relaxed_src_wrapper;
- }
- }
- ArgumentSemantics::PhysicalPointer => {
- conversion_fn = bitcast_physical_pointer;
- }
- ArgumentSemantics::RegisterPointer => {
- conversion_fn = bitcast_register_pointer;
- }
- ArgumentSemantics::Address => {
- conversion_fn = force_bitcast_ptr_to_bit;
- }
- };
- match conversion_fn(&operand_type, instr_type, state_space)? {
+ let (operand_type, operand_space, ..) = id_def.get_typed(desc.op)?;
+ let conversion_fn = desc
+ .non_default_implicit_conversion
+ .unwrap_or(default_implicit_conversion);
+ match conversion_fn(
+ (operand_space, &operand_type),
+ (instruction_space, instr_type),
+ )? {
Some(conv_kind) => {
let conv_output = if desc.is_dst {
&mut post_conv
} else {
&mut *func
};
- let mut from = instr_type.clone();
- let mut to = operand_type;
- let mut src = id_def.new_non_variable(instr_type.clone());
+ let mut from_type = instr_type.clone();
+ let mut from_space = instruction_space;
+ let mut to_type = operand_type;
+ let mut to_space = operand_space;
+ let mut src =
+ id_def.register_intermediate(Some((instr_type.clone(), instruction_space)));
let mut dst = desc.op;
let result = Ok(src);
if !desc.is_dst {
mem::swap(&mut src, &mut dst);
- mem::swap(&mut from, &mut to);
+ mem::swap(&mut from_type, &mut to_type);
+ mem::swap(&mut from_space, &mut to_space);
}
conv_output.push(Statement::Conversion(ImplicitConversion {
src,
dst,
- from,
- to,
+ from_type,
+ from_space,
+ to_type,
+ to_space,
kind: conv_kind,
- src_sema: ArgumentSemantics::Default,
- dst_sema: ArgumentSemantics::Default,
}));
result
}
None => Ok(desc.op),
}
- },
- )?;
+ })?;
func.push(statement);
func.append(&mut post_conv);
Ok(())
}
-fn get_function_type(
- builder: &mut dr::Builder,
- map: &mut TypeWordMap,
- spirv_input: &[ast::Variable<ast::Type, spirv::Word>],
- spirv_output: &[ast::Variable<ast::Type, spirv::Word>],
-) -> (spirv::Word, spirv::Word) {
- map.get_or_add_fn(
- builder,
- spirv_input
- .iter()
- .map(|var| SpirvType::from(var.v_type.clone())),
- spirv_output
- .iter()
- .map(|var| SpirvType::from(var.v_type.clone())),
- )
+#[derive(Ord, PartialOrd, Eq, PartialEq, Hash, Copy, Clone)]
+enum PtxSpecialRegister {
+ Tid,
+ Ntid,
+ Ctaid,
+ Nctaid,
+ Clock,
+ LanemaskLt,
+ LanemaskLe,
+ LanemaskGe,
+ Laneid,
+ Clock64,
}
-fn emit_function_body_ops(
- builder: &mut dr::Builder,
- map: &mut TypeWordMap,
- opencl: spirv::Word,
- func: &[ExpandedStatement],
-) -> Result<(), TranslateError> {
- for s in func {
- match s {
- Statement::Label(id) => {
- if builder.block.is_some() {
- builder.branch(*id)?;
- }
- builder.begin_block(Some(*id))?;
- }
- _ => {
- if builder.block.is_none() && builder.function.is_some() {
- builder.begin_block(None)?;
- }
- }
- }
+impl PtxSpecialRegister {
+ fn try_parse(s: &str) -> Option<Self> {
match s {
- Statement::Label(_) => (),
- Statement::Call(call) => {
- let (result_type, result_id) = match &*call.ret_params {
- [(id, typ)] => (
- map.get_or_add(builder, SpirvType::from(typ.to_func_type())),
- Some(*id),
- ),
- [] => (map.void(), None),
- _ => todo!(),
- };
- let arg_list = call
- .param_list
- .iter()
- .map(|(id, _)| *id)
- .collect::<Vec<_>>();
- builder.function_call(result_type, result_id, call.func, arg_list)?;
- }
- Statement::Variable(var) => {
- emit_variable(builder, map, var)?;
- }
- Statement::Constant(cnst) => {
- let typ_id = map.get_or_add_scalar(builder, cnst.typ);
- match (cnst.typ, cnst.value) {
- (ast::ScalarType::B8, ast::ImmediateValue::U64(value))
- | (ast::ScalarType::U8, ast::ImmediateValue::U64(value)) => {
- builder.constant_u32(typ_id, Some(cnst.dst), value as u8 as u32);
- }
- (ast::ScalarType::B16, ast::ImmediateValue::U64(value))
- | (ast::ScalarType::U16, ast::ImmediateValue::U64(value)) => {
- builder.constant_u32(typ_id, Some(cnst.dst), value as u16 as u32);
- }
- (ast::ScalarType::B32, ast::ImmediateValue::U64(value))
- | (ast::ScalarType::U32, ast::ImmediateValue::U64(value)) => {
- builder.constant_u32(typ_id, Some(cnst.dst), value as u32);
- }
- (ast::ScalarType::B64, ast::ImmediateValue::U64(value))
- | (ast::ScalarType::U64, ast::ImmediateValue::U64(value)) => {
- builder.constant_u64(typ_id, Some(cnst.dst), value);
- }
- (ast::ScalarType::S8, ast::ImmediateValue::U64(value)) => {
- builder.constant_u32(typ_id, Some(cnst.dst), value as i8 as u32);
- }
- (ast::ScalarType::S16, ast::ImmediateValue::U64(value)) => {
- builder.constant_u32(typ_id, Some(cnst.dst), value as i16 as u32);
- }
- (ast::ScalarType::S32, ast::ImmediateValue::U64(value)) => {
- builder.constant_u32(typ_id, Some(cnst.dst), value as i32 as u32);
- }
- (ast::ScalarType::S64, ast::ImmediateValue::U64(value)) => {
- builder.constant_u64(typ_id, Some(cnst.dst), value as i64 as u64);
- }
- (ast::ScalarType::B8, ast::ImmediateValue::S64(value))
- | (ast::ScalarType::U8, ast::ImmediateValue::S64(value)) => {
- builder.constant_u32(typ_id, Some(cnst.dst), value as u8 as u32);
- }
- (ast::ScalarType::B16, ast::ImmediateValue::S64(value))
- | (ast::ScalarType::U16, ast::ImmediateValue::S64(value)) => {
- builder.constant_u32(typ_id, Some(cnst.dst), value as u16 as u32);
- }
- (ast::ScalarType::B32, ast::ImmediateValue::S64(value))
- | (ast::ScalarType::U32, ast::ImmediateValue::S64(value)) => {
- builder.constant_u32(typ_id, Some(cnst.dst), value as u32);
- }
- (ast::ScalarType::B64, ast::ImmediateValue::S64(value))
- | (ast::ScalarType::U64, ast::ImmediateValue::S64(value)) => {
- builder.constant_u64(typ_id, Some(cnst.dst), value as u64);
- }
- (ast::ScalarType::S8, ast::ImmediateValue::S64(value)) => {
- builder.constant_u32(typ_id, Some(cnst.dst), value as i8 as u32);
- }
- (ast::ScalarType::S16, ast::ImmediateValue::S64(value)) => {
- builder.constant_u32(typ_id, Some(cnst.dst), value as i16 as u32);
- }
- (ast::ScalarType::S32, ast::ImmediateValue::S64(value)) => {
- builder.constant_u32(typ_id, Some(cnst.dst), value as i32 as u32);
- }
- (ast::ScalarType::S64, ast::ImmediateValue::S64(value)) => {
- builder.constant_u64(typ_id, Some(cnst.dst), value as u64);
- }
- (ast::ScalarType::F16, ast::ImmediateValue::F32(value)) => {
- builder.constant_f32(typ_id, Some(cnst.dst), f16::from_f32(value).to_f32());
- }
- (ast::ScalarType::F32, ast::ImmediateValue::F32(value)) => {
- builder.constant_f32(typ_id, Some(cnst.dst), value);
- }
- (ast::ScalarType::F64, ast::ImmediateValue::F32(value)) => {
- builder.constant_f64(typ_id, Some(cnst.dst), value as f64);
- }
- (ast::ScalarType::F16, ast::ImmediateValue::F64(value)) => {
- builder.constant_f32(typ_id, Some(cnst.dst), f16::from_f64(value).to_f32());
- }
- (ast::ScalarType::F32, ast::ImmediateValue::F64(value)) => {
- builder.constant_f32(typ_id, Some(cnst.dst), value as f32);
- }
- (ast::ScalarType::F64, ast::ImmediateValue::F64(value)) => {
- builder.constant_f64(typ_id, Some(cnst.dst), value);
- }
- (ast::ScalarType::Pred, ast::ImmediateValue::U64(value)) => {
- let bool_type = map.get_or_add_scalar(builder, ast::ScalarType::Pred);
- if value == 0 {
- builder.constant_false(bool_type, Some(cnst.dst));
- } else {
- builder.constant_true(bool_type, Some(cnst.dst));
- }
- }
- (ast::ScalarType::Pred, ast::ImmediateValue::S64(value)) => {
- let bool_type = map.get_or_add_scalar(builder, ast::ScalarType::Pred);
- if value == 0 {
- builder.constant_false(bool_type, Some(cnst.dst));
- } else {
- builder.constant_true(bool_type, Some(cnst.dst));
- }
- }
- _ => return Err(TranslateError::MismatchedType),
- }
- }
- Statement::Conversion(cv) => emit_implicit_conversion(builder, map, cv)?,
- Statement::Conditional(bra) => {
- builder.branch_conditional(bra.predicate, bra.if_true, bra.if_false, [])?;
- }
- Statement::Instruction(inst) => match inst {
- ast::Instruction::Abs(d, arg) => emit_abs(builder, map, opencl, d, arg)?,
- ast::Instruction::Call(_) => unreachable!(),
- // SPIR-V does not support marking jumps as guaranteed-converged
- ast::Instruction::Bra(_, arg) => {
- builder.branch(arg.src)?;
- }
- ast::Instruction::Ld(data, arg) => {
- if data.qualifier != ast::LdStQualifier::Weak {
- todo!()
- }
- let result_type =
- map.get_or_add(builder, SpirvType::from(ast::Type::from(data.typ.clone())));
- builder.load(
- result_type,
- Some(arg.dst),
- arg.src,
- Some(spirv::MemoryAccess::ALIGNED),
- [dr::Operand::LiteralInt32(
- ast::Type::from(data.typ.clone()).size_of() as u32,
- )],
- )?;
- }
- ast::Instruction::St(data, arg) => {
- if data.qualifier != ast::LdStQualifier::Weak {
- todo!()
- }
- builder.store(
- arg.src1,
- arg.src2,
- Some(spirv::MemoryAccess::ALIGNED),
- [dr::Operand::LiteralInt32(
- ast::Type::from(data.typ.clone()).size_of() as u32,
- )],
- )?;
- }
- // SPIR-V does not support ret as guaranteed-converged
- ast::Instruction::Ret(_) => builder.ret()?,
- ast::Instruction::Mov(d, arg) => {
- let result_type =
- map.get_or_add(builder, SpirvType::from(ast::Type::from(d.typ.clone())));
- builder.copy_object(result_type, Some(arg.dst), arg.src)?;
- }
- ast::Instruction::Mul(mul, arg) => match mul {
- ast::MulDetails::Signed(ref ctr) => {
- emit_mul_sint(builder, map, opencl, ctr, arg)?
- }
- ast::MulDetails::Unsigned(ref ctr) => {
- emit_mul_uint(builder, map, opencl, ctr, arg)?
- }
- ast::MulDetails::Float(ref ctr) => emit_mul_float(builder, map, ctr, arg)?,
- },
- ast::Instruction::Add(add, arg) => match add {
- ast::ArithDetails::Signed(ref desc) => {
- emit_add_int(builder, map, desc.typ.into(), desc.saturate, arg)?
- }
- ast::ArithDetails::Unsigned(ref desc) => {
- emit_add_int(builder, map, (*desc).into(), false, arg)?
- }
- ast::ArithDetails::Float(desc) => emit_add_float(builder, map, desc, arg)?,
- },
- ast::Instruction::Setp(setp, arg) => {
- if arg.dst2.is_some() {
- todo!()
- }
- emit_setp(builder, map, setp, arg)?;
- }
- ast::Instruction::Not(t, a) => {
- let result_type = map.get_or_add(builder, SpirvType::from(t.to_type()));
- let result_id = Some(a.dst);
- let operand = a.src;
- match t {
- ast::BooleanType::Pred => {
- // HACK ALERT
- // Temporary workaround until IGC gets its shit together
- // Currently IGC carries two copies of SPIRV-LLVM translator
- // a new one in /llvm-spirv/ and old one in /IGC/AdaptorOCL/SPIRV/.
- // Obviously, old and buggy one is used for compiling L0 SPIRV
- // https://github.com/intel/intel-graphics-compiler/issues/148
- let type_pred = map.get_or_add_scalar(builder, ast::ScalarType::Pred);
- let const_true = builder.constant_true(type_pred, None);
- let const_false = builder.constant_false(type_pred, None);
- builder.select(result_type, result_id, operand, const_false, const_true)
- }
- _ => builder.not(result_type, result_id, operand),
- }?;
- }
- ast::Instruction::Shl(t, a) => {
- let full_type = t.to_type();
- let size_of = full_type.size_of();
- let result_type = map.get_or_add(builder, SpirvType::from(full_type));
- let offset_src = insert_shift_hack(builder, map, a.src2, size_of)?;
- builder.shift_left_logical(result_type, Some(a.dst), a.src1, offset_src)?;
- }
- ast::Instruction::Shr(t, a) => {
- let full_type = ast::ScalarType::from(*t);
- let size_of = full_type.size_of();
- let result_type = map.get_or_add_scalar(builder, full_type);
- let offset_src = insert_shift_hack(builder, map, a.src2, size_of as usize)?;
- if t.signed() {
- builder.shift_right_arithmetic(
- result_type,
- Some(a.dst),
- a.src1,
- offset_src,
- )?;
- } else {
- builder.shift_right_logical(
- result_type,
- Some(a.dst),
- a.src1,
- offset_src,
- )?;
- }
- }
- ast::Instruction::Cvt(dets, arg) => {
- emit_cvt(builder, map, opencl, dets, arg)?;
- }
- ast::Instruction::Cvta(_, arg) => {
- // This would be only meaningful if const/slm/global pointers
- // had a different format than generic pointers, but they don't pretty much by ptx definition
- // Honestly, I have no idea why this instruction exists and is emitted by the compiler
- let result_type = map.get_or_add_scalar(builder, ast::ScalarType::B64);
- builder.copy_object(result_type, Some(arg.dst), arg.src)?;
- }
- ast::Instruction::SetpBool(_, _) => todo!(),
- ast::Instruction::Mad(mad, arg) => match mad {
- ast::MulDetails::Signed(ref desc) => {
- emit_mad_sint(builder, map, opencl, desc, arg)?
- }
- ast::MulDetails::Unsigned(ref desc) => {
- emit_mad_uint(builder, map, opencl, desc, arg)?
- }
- ast::MulDetails::Float(desc) => {
- emit_mad_float(builder, map, opencl, desc, arg)?
- }
- },
- ast::Instruction::Or(t, a) => {
- let result_type = map.get_or_add_scalar(builder, ast::ScalarType::from(*t));
- if *t == ast::BooleanType::Pred {
- builder.logical_or(result_type, Some(a.dst), a.src1, a.src2)?;
- } else {
- builder.bitwise_or(result_type, Some(a.dst), a.src1, a.src2)?;
- }
- }
- ast::Instruction::Sub(d, arg) => match d {
- ast::ArithDetails::Signed(desc) => {
- emit_sub_int(builder, map, desc.typ.into(), desc.saturate, arg)?;
- }
- ast::ArithDetails::Unsigned(desc) => {
- emit_sub_int(builder, map, (*desc).into(), false, arg)?;
- }
- ast::ArithDetails::Float(desc) => {
- emit_sub_float(builder, map, desc, arg)?;
- }
- },
- ast::Instruction::Min(d, a) => {
- emit_min(builder, map, opencl, d, a)?;
- }
- ast::Instruction::Max(d, a) => {
- emit_max(builder, map, opencl, d, a)?;
- }
- ast::Instruction::Rcp(d, a) => {
- emit_rcp(builder, map, d, a)?;
- }
- ast::Instruction::And(t, a) => {
- let result_type = map.get_or_add_scalar(builder, ast::ScalarType::from(*t));
- if *t == ast::BooleanType::Pred {
- builder.logical_and(result_type, Some(a.dst), a.src1, a.src2)?;
- } else {
- builder.bitwise_and(result_type, Some(a.dst), a.src1, a.src2)?;
- }
- }
- ast::Instruction::Selp(t, a) => {
- let result_type = map.get_or_add_scalar(builder, ast::ScalarType::from(*t));
- builder.select(result_type, Some(a.dst), a.src3, a.src1, a.src2)?;
- }
- // TODO: implement named barriers
- ast::Instruction::Bar(d, _) => {
- let workgroup_scope = map.get_or_add_constant(
- builder,
- &ast::Type::Scalar(ast::ScalarType::U32),
- &vec_repr(spirv::Scope::Workgroup as u32),
- )?;
- let barrier_semantics = match d {
- ast::BarDetails::SyncAligned => map.get_or_add_constant(
- builder,
- &ast::Type::Scalar(ast::ScalarType::U32),
- &vec_repr(
- spirv::MemorySemantics::CROSS_WORKGROUP_MEMORY
- | spirv::MemorySemantics::WORKGROUP_MEMORY
- | spirv::MemorySemantics::SEQUENTIALLY_CONSISTENT,
- ),
- )?,
- };
- builder.control_barrier(workgroup_scope, workgroup_scope, barrier_semantics)?;
- }
- ast::Instruction::Atom(details, arg) => {
- emit_atom(builder, map, details, arg)?;
- }
- ast::Instruction::AtomCas(details, arg) => {
- let result_type = map.get_or_add_scalar(builder, details.typ.into());
- let memory_const = map.get_or_add_constant(
- builder,
- &ast::Type::Scalar(ast::ScalarType::U32),
- &vec_repr(details.scope.to_spirv() as u32),
- )?;
- let semantics_const = map.get_or_add_constant(
- builder,
- &ast::Type::Scalar(ast::ScalarType::U32),
- &vec_repr(details.semantics.to_spirv().bits()),
- )?;
- builder.atomic_compare_exchange(
- result_type,
- Some(arg.dst),
- arg.src1,
- memory_const,
- semantics_const,
- semantics_const,
- arg.src3,
- arg.src2,
- )?;
- }
- ast::Instruction::Div(details, arg) => match details {
- ast::DivDetails::Unsigned(t) => {
- let result_type = map.get_or_add_scalar(builder, (*t).into());
- builder.u_div(result_type, Some(arg.dst), arg.src1, arg.src2)?;
- }
- ast::DivDetails::Signed(t) => {
- let result_type = map.get_or_add_scalar(builder, (*t).into());
- builder.s_div(result_type, Some(arg.dst), arg.src1, arg.src2)?;
- }
- ast::DivDetails::Float(t) => {
- let result_type = map.get_or_add_scalar(builder, t.typ.into());
- builder.f_div(result_type, Some(arg.dst), arg.src1, arg.src2)?;
- emit_float_div_decoration(builder, arg.dst, t.kind);
- }
- },
- ast::Instruction::Sqrt(details, a) => {
- emit_sqrt(builder, map, opencl, details, a)?;
- }
- ast::Instruction::Rsqrt(details, a) => {
- let result_type = map.get_or_add_scalar(builder, details.typ.into());
- builder.ext_inst(
- result_type,
- Some(a.dst),
- opencl,
- spirv::CLOp::native_rsqrt as spirv::Word,
- &[a.src],
- )?;
- }
- ast::Instruction::Neg(details, arg) => {
- let result_type = map.get_or_add_scalar(builder, details.typ);
- let negate_func = if details.typ.kind() == ScalarKind::Float {
- dr::Builder::f_negate
- } else {
- dr::Builder::s_negate
- };
- negate_func(builder, result_type, Some(arg.dst), arg.src)?;
- }
- ast::Instruction::Sin { arg, .. } => {
- let result_type = map.get_or_add_scalar(builder, ast::ScalarType::F32);
- builder.ext_inst(
- result_type,
- Some(arg.dst),
- opencl,
- spirv::CLOp::sin as u32,
- [arg.src],
- )?;
- }
- ast::Instruction::Cos { arg, .. } => {
- let result_type = map.get_or_add_scalar(builder, ast::ScalarType::F32);
- builder.ext_inst(
- result_type,
- Some(arg.dst),
- opencl,
- spirv::CLOp::cos as u32,
- [arg.src],
- )?;
- }
- ast::Instruction::Lg2 { arg, .. } => {
- let result_type = map.get_or_add_scalar(builder, ast::ScalarType::F32);
- builder.ext_inst(
- result_type,
- Some(arg.dst),
- opencl,
- spirv::CLOp::log2 as u32,
- [arg.src],
- )?;
- }
- ast::Instruction::Ex2 { arg, .. } => {
- let result_type = map.get_or_add_scalar(builder, ast::ScalarType::F32);
- builder.ext_inst(
- result_type,
- Some(arg.dst),
- opencl,
- spirv::CLOp::exp2 as u32,
- [arg.src],
- )?;
- }
- ast::Instruction::Clz { typ, arg } => {
- let result_type = map.get_or_add_scalar(builder, (*typ).into());
- builder.ext_inst(
- result_type,
- Some(arg.dst),
- opencl,
- spirv::CLOp::clz as u32,
- [arg.src],
- )?;
- }
- ast::Instruction::Brev { typ, arg } => {
- let result_type = map.get_or_add_scalar(builder, (*typ).into());
- builder.bit_reverse(result_type, Some(arg.dst), arg.src)?;
- }
- ast::Instruction::Popc { typ, arg } => {
- let result_type = map.get_or_add_scalar(builder, (*typ).into());
- builder.bit_count(result_type, Some(arg.dst), arg.src)?;
- }
- ast::Instruction::Xor { typ, arg } => {
- let builder_fn = match typ {
- ast::BooleanType::Pred => emit_logical_xor_spirv,
- _ => dr::Builder::bitwise_xor,
- };
- let result_type = map.get_or_add_scalar(builder, (*typ).into());
- builder_fn(builder, result_type, Some(arg.dst), arg.src1, arg.src2)?;
- }
- ast::Instruction::Bfe { typ, arg } => {
- let builder_fn = if typ.is_signed() {
- dr::Builder::bit_field_s_extract
- } else {
- dr::Builder::bit_field_u_extract
- };
- let result_type = map.get_or_add_scalar(builder, (*typ).into());
- builder_fn(
- builder,
- result_type,
- Some(arg.dst),
- arg.src1,
- arg.src2,
- arg.src3,
- )?;
- }
- ast::Instruction::Rem { typ, arg } => {
- let builder_fn = if typ.is_signed() {
- dr::Builder::s_mod
- } else {
- dr::Builder::u_mod
- };
- let result_type = map.get_or_add_scalar(builder, (*typ).into());
- builder_fn(builder, result_type, Some(arg.dst), arg.src1, arg.src2)?;
- }
- },
- Statement::LoadVar(details) => {
- emit_load_var(builder, map, details)?;
- }
- Statement::StoreVar(details) => {
- let dst_ptr = match details.member_index {
- Some(index) => {
- let result_ptr_type = map.get_or_add(
- builder,
- SpirvType::new_pointer(
- details.typ.clone(),
- spirv::StorageClass::Function,
- ),
- );
- let index_spirv = map.get_or_add_constant(
- builder,
- &ast::Type::Scalar(ast::ScalarType::U32),
- &vec_repr(index as u32),
- )?;
- builder.in_bounds_access_chain(
- result_ptr_type,
- None,
- details.arg.src1,
- &[index_spirv],
- )?
- }
- None => details.arg.src1,
- };
- builder.store(dst_ptr, details.arg.src2, None, [])?;
- }
- Statement::RetValue(_, id) => {
- builder.ret_value(*id)?;
- }
- Statement::PtrAccess(PtrAccess {
- underlying_type,
- state_space,
- dst,
- ptr_src,
- offset_src,
- }) => {
- let u8_pointer = map.get_or_add(
- builder,
- SpirvType::from(ast::Type::Pointer(
- ast::PointerType::Scalar(ast::ScalarType::U8),
- *state_space,
- )),
- );
- let result_type = map.get_or_add(
- builder,
- SpirvType::from(ast::Type::Pointer(underlying_type.clone(), *state_space)),
- );
- let ptr_src_u8 = builder.bitcast(u8_pointer, None, *ptr_src)?;
- let temp = builder.in_bounds_ptr_access_chain(
- u8_pointer,
- None,
- ptr_src_u8,
- *offset_src,
- &[],
- )?;
- builder.bitcast(result_type, Some(*dst), temp)?;
- }
- Statement::RepackVector(repack) => {
- if repack.is_extract {
- let scalar_type = map.get_or_add_scalar(builder, repack.typ);
- for (index, dst_id) in repack.unpacked.iter().enumerate() {
- builder.composite_extract(
- scalar_type,
- Some(*dst_id),
- repack.packed,
- &[index as u32],
- )?;
- }
- } else {
- let vector_type = map.get_or_add(
- builder,
- SpirvType::Vector(
- SpirvScalarKey::from(repack.typ),
- repack.unpacked.len() as u8,
- ),
- );
- let mut temp_vec = builder.undef(vector_type, None);
- for (index, src_id) in repack.unpacked.iter().enumerate() {
- temp_vec = builder.composite_insert(
- vector_type,
- None,
- *src_id,
- temp_vec,
- &[index as u32],
- )?;
- }
- builder.copy_object(vector_type, Some(repack.packed), temp_vec)?;
- }
- }
+ "%tid" => Some(Self::Tid),
+ "%ntid" => Some(Self::Ntid),
+ "%ctaid" => Some(Self::Ctaid),
+ "%nctaid" => Some(Self::Nctaid),
+ "%clock" => Some(Self::Clock),
+ "%lanemask_lt" => Some(Self::LanemaskLt),
+ "%lanemask_le" => Some(Self::LanemaskLe),
+ "%lanemask_ge" => Some(Self::LanemaskGe),
+ "%laneid" => Some(Self::Laneid),
+ "%clock64" => Some(Self::Clock64),
+ _ => None,
}
}
- Ok(())
-}
-// HACK ALERT
-// For some reason IGC fails linking if the value and shift size are of different type
-fn insert_shift_hack(
- builder: &mut dr::Builder,
- map: &mut TypeWordMap,
- offset_var: spirv::Word,
- size_of: usize,
-) -> Result<spirv::Word, TranslateError> {
- let result_type = match size_of {
- 2 => map.get_or_add_scalar(builder, ast::ScalarType::B16),
- 8 => map.get_or_add_scalar(builder, ast::ScalarType::B64),
- 4 => return Ok(offset_var),
- _ => return Err(error_unreachable()),
- };
- Ok(builder.u_convert(result_type, None, offset_var)?)
-}
-
-// TODO: check what kind of assembly do we emit
-fn emit_logical_xor_spirv(
- builder: &mut dr::Builder,
- result_type: spirv::Word,
- result_id: Option<spirv::Word>,
- op1: spirv::Word,
- op2: spirv::Word,
-) -> Result<spirv::Word, dr::Error> {
- let temp_or = builder.logical_or(result_type, None, op1, op2)?;
- let temp_and = builder.logical_and(result_type, None, op1, op2)?;
- let temp_neg = builder.logical_not(result_type, None, temp_and)?;
- builder.logical_and(result_type, result_id, temp_or, temp_neg)
-}
-
-fn emit_sqrt(
- builder: &mut dr::Builder,
- map: &mut TypeWordMap,
- opencl: spirv::Word,
- details: &ast::SqrtDetails,
- a: &ast::Arg2<ExpandedArgParams>,
-) -> Result<(), TranslateError> {
- let result_type = map.get_or_add_scalar(builder, details.typ.into());
- let (ocl_op, rounding) = match details.kind {
- ast::SqrtKind::Approx => (spirv::CLOp::native_sqrt, None),
- ast::SqrtKind::Rounding(rnd) => (spirv::CLOp::sqrt, Some(rnd)),
- };
- builder.ext_inst(
- result_type,
- Some(a.dst),
- opencl,
- ocl_op as spirv::Word,
- &[a.src],
- )?;
- emit_rounding_decoration(builder, a.dst, rounding);
- Ok(())
-}
-
-fn emit_float_div_decoration(builder: &mut dr::Builder, dst: spirv::Word, kind: ast::DivFloatKind) {
- match kind {
- ast::DivFloatKind::Approx => {
- builder.decorate(
- dst,
- spirv::Decoration::FPFastMathMode,
- &[dr::Operand::FPFastMathMode(
- spirv::FPFastMathMode::ALLOW_RECIP,
- )],
- );
- }
- ast::DivFloatKind::Rounding(rnd) => {
- emit_rounding_decoration(builder, dst, Some(rnd));
+ fn get_type(self) -> ast::Type {
+ match self {
+ PtxSpecialRegister::Tid
+ | PtxSpecialRegister::Ntid
+ | PtxSpecialRegister::Ctaid
+ | PtxSpecialRegister::Nctaid => ast::Type::Vector(self.get_function_return_type(), 4),
+ _ => ast::Type::Scalar(self.get_function_return_type()),
}
- ast::DivFloatKind::Full => {}
}
-}
-fn emit_atom(
- builder: &mut dr::Builder,
- map: &mut TypeWordMap,
- details: &ast::AtomDetails,
- arg: &ast::Arg3<ExpandedArgParams>,
-) -> Result<(), TranslateError> {
- let (spirv_op, typ) = match details.inner {
- ast::AtomInnerDetails::Bit { op, typ } => {
- let spirv_op = match op {
- ast::AtomBitOp::And => dr::Builder::atomic_and,
- ast::AtomBitOp::Or => dr::Builder::atomic_or,
- ast::AtomBitOp::Xor => dr::Builder::atomic_xor,
- ast::AtomBitOp::Exchange => dr::Builder::atomic_exchange,
- };
- (spirv_op, ast::ScalarType::from(typ))
- }
- ast::AtomInnerDetails::Unsigned { op, typ } => {
- let spirv_op = match op {
- ast::AtomUIntOp::Add => dr::Builder::atomic_i_add,
- ast::AtomUIntOp::Inc | ast::AtomUIntOp::Dec => {
- return Err(error_unreachable());
- }
- ast::AtomUIntOp::Min => dr::Builder::atomic_u_min,
- ast::AtomUIntOp::Max => dr::Builder::atomic_u_max,
- };
- (spirv_op, typ.into())
- }
- ast::AtomInnerDetails::Signed { op, typ } => {
- let spirv_op = match op {
- ast::AtomSIntOp::Add => dr::Builder::atomic_i_add,
- ast::AtomSIntOp::Min => dr::Builder::atomic_s_min,
- ast::AtomSIntOp::Max => dr::Builder::atomic_s_max,
- };
- (spirv_op, typ.into())
+ fn get_function_return_type(self) -> ast::ScalarType {
+ match self {
+ PtxSpecialRegister::Tid
+ | PtxSpecialRegister::Ntid
+ | PtxSpecialRegister::Ctaid
+ | PtxSpecialRegister::Nctaid
+ | PtxSpecialRegister::Clock
+ | PtxSpecialRegister::LanemaskLt
+ | PtxSpecialRegister::LanemaskLe
+ | PtxSpecialRegister::LanemaskGe
+ | PtxSpecialRegister::Laneid => ast::ScalarType::U32,
+ PtxSpecialRegister::Clock64 => ast::ScalarType::U64,
}
- // TODO: Hardware is capable of this, implement it through builtin
- ast::AtomInnerDetails::Float { .. } => todo!(),
- };
- let result_type = map.get_or_add_scalar(builder, typ);
- let memory_const = map.get_or_add_constant(
- builder,
- &ast::Type::Scalar(ast::ScalarType::U32),
- &vec_repr(details.scope.to_spirv() as u32),
- )?;
- let semantics_const = map.get_or_add_constant(
- builder,
- &ast::Type::Scalar(ast::ScalarType::U32),
- &vec_repr(details.semantics.to_spirv().bits()),
- )?;
- spirv_op(
- builder,
- result_type,
- Some(arg.dst),
- arg.src1,
- memory_const,
- semantics_const,
- arg.src2,
- )?;
- Ok(())
-}
-
-#[derive(Clone)]
-struct PtxImplImport {
- out_arg: ast::Type,
- fn_id: u32,
- in_args: Vec<ast::Type>,
-}
-
-fn ptx_semantics_name(sema: ast::AtomSemantics) -> &'static str {
- match sema {
- ast::AtomSemantics::Relaxed => "relaxed",
- ast::AtomSemantics::Acquire => "acquire",
- ast::AtomSemantics::Release => "release",
- ast::AtomSemantics::AcquireRelease => "acq_rel",
}
-}
-fn ptx_scope_name(scope: ast::MemScope) -> &'static str {
- match scope {
- ast::MemScope::Cta => "cta",
- ast::MemScope::Gpu => "gpu",
- ast::MemScope::Sys => "sys",
- }
-}
-
-fn ptx_space_name(space: ast::AtomSpace) -> &'static str {
- match space {
- ast::AtomSpace::Generic => "generic",
- ast::AtomSpace::Global => "global",
- ast::AtomSpace::Shared => "shared",
- }
-}
-
-fn emit_mul_float(
- builder: &mut dr::Builder,
- map: &mut TypeWordMap,
- ctr: &ast::ArithFloat,
- arg: &ast::Arg3<ExpandedArgParams>,
-) -> Result<(), dr::Error> {
- if ctr.saturate {
- todo!()
+ fn get_function_input_type(self) -> Option<ast::ScalarType> {
+ match self {
+ PtxSpecialRegister::Tid
+ | PtxSpecialRegister::Ntid
+ | PtxSpecialRegister::Ctaid
+ | PtxSpecialRegister::Nctaid => Some(ast::ScalarType::U8),
+ PtxSpecialRegister::Clock
+ | PtxSpecialRegister::Clock64
+ | PtxSpecialRegister::LanemaskLt
+ | PtxSpecialRegister::LanemaskLe
+ | PtxSpecialRegister::LanemaskGe
+ | PtxSpecialRegister::Laneid => None,
+ }
}
- let result_type = map.get_or_add_scalar(builder, ctr.typ.into());
- builder.f_mul(result_type, Some(arg.dst), arg.src1, arg.src2)?;
- emit_rounding_decoration(builder, arg.dst, ctr.rounding);
- Ok(())
-}
-
-fn emit_rcp(
- builder: &mut dr::Builder,
- map: &mut TypeWordMap,
- desc: &ast::RcpDetails,
- a: &ast::Arg2<ExpandedArgParams>,
-) -> Result<(), TranslateError> {
- let (instr_type, constant) = if desc.is_f64 {
- (ast::ScalarType::F64, vec_repr(1.0f64))
- } else {
- (ast::ScalarType::F32, vec_repr(1.0f32))
- };
- let one = map.get_or_add_constant(builder, &ast::Type::Scalar(instr_type), &constant)?;
- let result_type = map.get_or_add_scalar(builder, instr_type);
- builder.f_div(result_type, Some(a.dst), one, a.src)?;
- emit_rounding_decoration(builder, a.dst, desc.rounding);
- builder.decorate(
- a.dst,
- spirv::Decoration::FPFastMathMode,
- &[dr::Operand::FPFastMathMode(
- spirv::FPFastMathMode::ALLOW_RECIP,
- )],
- );
- Ok(())
-}
-fn vec_repr<T: Copy>(t: T) -> Vec<u8> {
- let mut result = vec![0; mem::size_of::<T>()];
- unsafe { std::ptr::copy_nonoverlapping(&t, result.as_mut_ptr() as *mut _, 1) };
- result
-}
-
-fn emit_variable(
- builder: &mut dr::Builder,
- map: &mut TypeWordMap,
- var: &ast::Variable<ast::VariableType, spirv::Word>,
-) -> Result<(), TranslateError> {
- let (must_init, st_class) = match var.v_type {
- ast::VariableType::Reg(_) | ast::VariableType::Param(_) | ast::VariableType::Local(_) => {
- (false, spirv::StorageClass::Function)
+ fn get_unprefixed_function_name(self) -> &'static str {
+ match self {
+ PtxSpecialRegister::Tid => "sreg_tid",
+ PtxSpecialRegister::Ntid => "sreg_ntid",
+ PtxSpecialRegister::Ctaid => "sreg_ctaid",
+ PtxSpecialRegister::Nctaid => "sreg_nctaid",
+ PtxSpecialRegister::Clock => "sreg_clock",
+ PtxSpecialRegister::Clock64 => "sreg_clock64",
+ PtxSpecialRegister::LanemaskLt => "sreg_lanemask_lt",
+ PtxSpecialRegister::LanemaskLe => "sreg_lanemask_le",
+ PtxSpecialRegister::LanemaskGe => "sreg_lanemask_ge",
+ PtxSpecialRegister::Laneid => "sreg_laneid",
}
- ast::VariableType::Global(_) => (true, spirv::StorageClass::CrossWorkgroup),
- ast::VariableType::Shared(_) => (false, spirv::StorageClass::Workgroup),
- };
- let initalizer = if var.array_init.len() > 0 {
- Some(map.get_or_add_constant(
- builder,
- &ast::Type::from(var.v_type.clone()),
- &*var.array_init,
- )?)
- } else if must_init {
- let type_id = map.get_or_add(
- builder,
- SpirvType::from(ast::Type::from(var.v_type.clone())),
- );
- Some(builder.constant_null(type_id, None))
- } else {
- None
- };
- let ptr_type_id = map.get_or_add(
- builder,
- SpirvType::new_pointer(ast::Type::from(var.v_type.clone()), st_class),
- );
- builder.variable(ptr_type_id, Some(var.name), st_class, initalizer);
- if let Some(align) = var.align {
- builder.decorate(
- var.name,
- spirv::Decoration::Alignment,
- &[dr::Operand::LiteralInt32(align)],
- );
}
- Ok(())
}
-fn emit_mad_uint(
- builder: &mut dr::Builder,
- map: &mut TypeWordMap,
- opencl: spirv::Word,
- desc: &ast::MulUInt,
- arg: &ast::Arg4<ExpandedArgParams>,
-) -> Result<(), dr::Error> {
- let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ)));
- match desc.control {
- ast::MulIntControl::Low => {
- let mul_result = builder.i_mul(inst_type, None, arg.src1, arg.src2)?;
- builder.i_add(inst_type, Some(arg.dst), arg.src3, mul_result)?;
- }
- ast::MulIntControl::High => {
- builder.ext_inst(
- inst_type,
- Some(arg.dst),
- opencl,
- spirv::CLOp::u_mad_hi as spirv::Word,
- [arg.src1, arg.src2, arg.src3],
- )?;
- }
- ast::MulIntControl::Wide => todo!(),
- };
- Ok(())
+struct SpecialRegistersMap {
+ reg_to_id: FxHashMap<PtxSpecialRegister, Id>,
+ id_to_reg: FxHashMap<Id, PtxSpecialRegister>,
}
-fn emit_mad_sint(
- builder: &mut dr::Builder,
- map: &mut TypeWordMap,
- opencl: spirv::Word,
- desc: &ast::MulSInt,
- arg: &ast::Arg4<ExpandedArgParams>,
-) -> Result<(), dr::Error> {
- let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ)));
- match desc.control {
- ast::MulIntControl::Low => {
- let mul_result = builder.i_mul(inst_type, None, arg.src1, arg.src2)?;
- builder.i_add(inst_type, Some(arg.dst), arg.src3, mul_result)?;
- }
- ast::MulIntControl::High => {
- builder.ext_inst(
- inst_type,
- Some(arg.dst),
- opencl,
- spirv::CLOp::s_mad_hi as spirv::Word,
- [arg.src1, arg.src2, arg.src3],
- )?;
+impl SpecialRegistersMap {
+ fn new() -> Self {
+ SpecialRegistersMap {
+ reg_to_id: FxHashMap::default(),
+ id_to_reg: FxHashMap::default(),
}
- ast::MulIntControl::Wide => todo!(),
- };
- Ok(())
-}
-
-fn emit_mad_float(
- builder: &mut dr::Builder,
- map: &mut TypeWordMap,
- opencl: spirv::Word,
- desc: &ast::ArithFloat,
- arg: &ast::Arg4<ExpandedArgParams>,
-) -> Result<(), dr::Error> {
- let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ)));
- builder.ext_inst(
- inst_type,
- Some(arg.dst),
- opencl,
- spirv::CLOp::mad as spirv::Word,
- [arg.src1, arg.src2, arg.src3],
- )?;
- Ok(())
-}
-
-fn emit_add_float(
- builder: &mut dr::Builder,
- map: &mut TypeWordMap,
- desc: &ast::ArithFloat,
- arg: &ast::Arg3<ExpandedArgParams>,
-) -> Result<(), dr::Error> {
- let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ)));
- builder.f_add(inst_type, Some(arg.dst), arg.src1, arg.src2)?;
- emit_rounding_decoration(builder, arg.dst, desc.rounding);
- Ok(())
-}
-
-fn emit_sub_float(
- builder: &mut dr::Builder,
- map: &mut TypeWordMap,
- desc: &ast::ArithFloat,
- arg: &ast::Arg3<ExpandedArgParams>,
-) -> Result<(), dr::Error> {
- let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ)));
- builder.f_sub(inst_type, Some(arg.dst), arg.src1, arg.src2)?;
- emit_rounding_decoration(builder, arg.dst, desc.rounding);
- Ok(())
-}
-
-fn emit_min(
- builder: &mut dr::Builder,
- map: &mut TypeWordMap,
- opencl: spirv::Word,
- desc: &ast::MinMaxDetails,
- arg: &ast::Arg3<ExpandedArgParams>,
-) -> Result<(), dr::Error> {
- let cl_op = match desc {
- ast::MinMaxDetails::Signed(_) => spirv::CLOp::s_min,
- ast::MinMaxDetails::Unsigned(_) => spirv::CLOp::u_min,
- ast::MinMaxDetails::Float(_) => spirv::CLOp::fmin,
- };
- let inst_type = map.get_or_add(builder, SpirvType::from(desc.get_type()));
- builder.ext_inst(
- inst_type,
- Some(arg.dst),
- opencl,
- cl_op as spirv::Word,
- [arg.src1, arg.src2],
- )?;
- Ok(())
-}
+ }
-fn emit_max(
- builder: &mut dr::Builder,
- map: &mut TypeWordMap,
- opencl: spirv::Word,
- desc: &ast::MinMaxDetails,
- arg: &ast::Arg3<ExpandedArgParams>,
-) -> Result<(), dr::Error> {
- let cl_op = match desc {
- ast::MinMaxDetails::Signed(_) => spirv::CLOp::s_max,
- ast::MinMaxDetails::Unsigned(_) => spirv::CLOp::u_max,
- ast::MinMaxDetails::Float(_) => spirv::CLOp::fmax,
- };
- let inst_type = map.get_or_add(builder, SpirvType::from(desc.get_type()));
- builder.ext_inst(
- inst_type,
- Some(arg.dst),
- opencl,
- cl_op as spirv::Word,
- [arg.src1, arg.src2],
- )?;
- Ok(())
-}
+ fn get(&self, id: Id) -> Option<PtxSpecialRegister> {
+ self.id_to_reg.get(&id).copied()
+ }
-fn emit_cvt(
- builder: &mut dr::Builder,
- map: &mut TypeWordMap,
- opencl: spirv::Word,
- dets: &ast::CvtDetails,
- arg: &ast::Arg2<ExpandedArgParams>,
-) -> Result<(), TranslateError> {
- match dets {
- ast::CvtDetails::FloatFromFloat(desc) => {
- if desc.saturate {
- todo!()
- }
- let dest_t: ast::ScalarType = desc.dst.into();
- let result_type = map.get_or_add(builder, SpirvType::from(dest_t));
- if desc.dst == desc.src {
- match desc.rounding {
- Some(ast::RoundingMode::NearestEven) => {
- builder.ext_inst(
- result_type,
- Some(arg.dst),
- opencl,
- spirv::CLOp::rint as u32,
- [arg.src],
- )?;
- }
- Some(ast::RoundingMode::Zero) => {
- builder.ext_inst(
- result_type,
- Some(arg.dst),
- opencl,
- spirv::CLOp::trunc as u32,
- [arg.src],
- )?;
- }
- Some(ast::RoundingMode::NegativeInf) => {
- builder.ext_inst(
- result_type,
- Some(arg.dst),
- opencl,
- spirv::CLOp::floor as u32,
- [arg.src],
- )?;
- }
- Some(ast::RoundingMode::PositiveInf) => {
- builder.ext_inst(
- result_type,
- Some(arg.dst),
- opencl,
- spirv::CLOp::ceil as u32,
- [arg.src],
- )?;
- }
- None => {
- builder.copy_object(result_type, Some(arg.dst), arg.src)?;
- }
- }
- } else {
- builder.f_convert(result_type, Some(arg.dst), arg.src)?;
- emit_rounding_decoration(builder, arg.dst, desc.rounding);
- }
- }
- ast::CvtDetails::FloatFromInt(desc) => {
- if desc.saturate {
- todo!()
- }
- let dest_t: ast::ScalarType = desc.dst.into();
- let result_type = map.get_or_add(builder, SpirvType::from(dest_t));
- if desc.src.is_signed() {
- builder.convert_s_to_f(result_type, Some(arg.dst), arg.src)?;
- } else {
- builder.convert_u_to_f(result_type, Some(arg.dst), arg.src)?;
- }
- emit_rounding_decoration(builder, arg.dst, desc.rounding);
- }
- ast::CvtDetails::IntFromFloat(desc) => {
- let dest_t: ast::ScalarType = desc.dst.into();
- let result_type = map.get_or_add(builder, SpirvType::from(dest_t));
- if desc.dst.is_signed() {
- builder.convert_f_to_s(result_type, Some(arg.dst), arg.src)?;
- } else {
- builder.convert_f_to_u(result_type, Some(arg.dst), arg.src)?;
- }
- emit_rounding_decoration(builder, arg.dst, desc.rounding);
- emit_saturating_decoration(builder, arg.dst, desc.saturate);
- }
- ast::CvtDetails::IntFromInt(desc) => {
- let dest_t: ast::ScalarType = desc.dst.into();
- let src_t: ast::ScalarType = desc.src.into();
- // first do shortening/widening
- let src = if desc.dst.width() != desc.src.width() {
- let new_dst = if dest_t.kind() == src_t.kind() {
- arg.dst
- } else {
- builder.id()
- };
- let cv = ImplicitConversion {
- src: arg.src,
- dst: new_dst,
- from: ast::Type::Scalar(src_t),
- to: ast::Type::Scalar(ast::ScalarType::from_parts(
- dest_t.size_of(),
- src_t.kind(),
- )),
- kind: ConversionKind::Default,
- src_sema: ArgumentSemantics::Default,
- dst_sema: ArgumentSemantics::Default,
- };
- emit_implicit_conversion(builder, map, &cv)?;
- new_dst
- } else {
- arg.src
- };
- if dest_t.kind() == src_t.kind() {
- return Ok(());
- }
- // now do actual conversion
- let result_type = map.get_or_add(builder, SpirvType::from(dest_t));
- if desc.saturate {
- if desc.dst.is_signed() {
- builder.sat_convert_u_to_s(result_type, Some(arg.dst), src)?;
- } else {
- builder.sat_convert_s_to_u(result_type, Some(arg.dst), src)?;
- }
- } else {
- builder.bitcast(result_type, Some(arg.dst), src)?;
+ fn get_or_add(&mut self, id_gen: &mut IdGenerator, reg: PtxSpecialRegister) -> Id {
+ match self.reg_to_id.entry(reg) {
+ hash_map::Entry::Occupied(e) => *e.get(),
+ hash_map::Entry::Vacant(e) => {
+ let numeric_id = id_gen.next();
+ e.insert(numeric_id);
+ self.id_to_reg.insert(numeric_id, reg);
+ numeric_id
}
}
}
- Ok(())
}
-fn emit_saturating_decoration(builder: &mut dr::Builder, dst: u32, saturate: bool) {
- if saturate {
- builder.decorate(dst, spirv::Decoration::SaturatedConversion, []);
- }
-}
+#[derive(Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
+pub struct Id(NonZeroU32);
-fn emit_rounding_decoration(
- builder: &mut dr::Builder,
- dst: spirv::Word,
- rounding: Option<ast::RoundingMode>,
-) {
- if let Some(rounding) = rounding {
- builder.decorate(
- dst,
- spirv::Decoration::FPRoundingMode,
- [rounding.to_spirv()],
- );
+impl Id {
+ pub(crate) fn get(self) -> u32 {
+ self.0.get()
}
}
-impl ast::RoundingMode {
- fn to_spirv(self) -> rspirv::dr::Operand {
- let mode = match self {
- ast::RoundingMode::NearestEven => spirv::FPRoundingMode::RTE,
- ast::RoundingMode::Zero => spirv::FPRoundingMode::RTZ,
- ast::RoundingMode::PositiveInf => spirv::FPRoundingMode::RTP,
- ast::RoundingMode::NegativeInf => spirv::FPRoundingMode::RTN,
- };
- rspirv::dr::Operand::FPRoundingMode(mode)
- }
+pub(crate) struct IdGenerator {
+ pub(crate) next: NonZeroU32,
}
-fn emit_setp(
- builder: &mut dr::Builder,
- map: &mut TypeWordMap,
- setp: &ast::SetpData,
- arg: &ast::Arg4Setp<ExpandedArgParams>,
-) -> Result<(), dr::Error> {
- let result_type = map.get_or_add(builder, SpirvType::Base(SpirvScalarKey::Pred));
- let result_id = Some(arg.dst1);
- let operand_1 = arg.src1;
- let operand_2 = arg.src2;
- match (setp.cmp_op, setp.typ.kind()) {
- (ast::SetpCompareOp::Eq, ScalarKind::Signed)
- | (ast::SetpCompareOp::Eq, ScalarKind::Unsigned)
- | (ast::SetpCompareOp::Eq, ScalarKind::Bit) => {
- builder.i_equal(result_type, result_id, operand_1, operand_2)
- }
- (ast::SetpCompareOp::Eq, ScalarKind::Float) => {
- builder.f_ord_equal(result_type, result_id, operand_1, operand_2)
- }
- (ast::SetpCompareOp::NotEq, ScalarKind::Signed)
- | (ast::SetpCompareOp::NotEq, ScalarKind::Unsigned)
- | (ast::SetpCompareOp::NotEq, ScalarKind::Bit) => {
- builder.i_not_equal(result_type, result_id, operand_1, operand_2)
- }
- (ast::SetpCompareOp::NotEq, ScalarKind::Float) => {
- builder.f_ord_not_equal(result_type, result_id, operand_1, operand_2)
- }
- (ast::SetpCompareOp::Less, ScalarKind::Unsigned)
- | (ast::SetpCompareOp::Less, ScalarKind::Bit) => {
- builder.u_less_than(result_type, result_id, operand_1, operand_2)
- }
- (ast::SetpCompareOp::Less, ScalarKind::Signed) => {
- builder.s_less_than(result_type, result_id, operand_1, operand_2)
- }
- (ast::SetpCompareOp::Less, ScalarKind::Float) => {
- builder.f_ord_less_than(result_type, result_id, operand_1, operand_2)
- }
- (ast::SetpCompareOp::LessOrEq, ScalarKind::Unsigned)
- | (ast::SetpCompareOp::LessOrEq, ScalarKind::Bit) => {
- builder.u_less_than_equal(result_type, result_id, operand_1, operand_2)
- }
- (ast::SetpCompareOp::LessOrEq, ScalarKind::Signed) => {
- builder.s_less_than_equal(result_type, result_id, operand_1, operand_2)
- }
- (ast::SetpCompareOp::LessOrEq, ScalarKind::Float) => {
- builder.f_ord_less_than_equal(result_type, result_id, operand_1, operand_2)
- }
- (ast::SetpCompareOp::Greater, ScalarKind::Unsigned)
- | (ast::SetpCompareOp::Greater, ScalarKind::Bit) => {
- builder.u_greater_than(result_type, result_id, operand_1, operand_2)
- }
- (ast::SetpCompareOp::Greater, ScalarKind::Signed) => {
- builder.s_greater_than(result_type, result_id, operand_1, operand_2)
- }
- (ast::SetpCompareOp::Greater, ScalarKind::Float) => {
- builder.f_ord_greater_than(result_type, result_id, operand_1, operand_2)
- }
- (ast::SetpCompareOp::GreaterOrEq, ScalarKind::Unsigned)
- | (ast::SetpCompareOp::GreaterOrEq, ScalarKind::Bit) => {
- builder.u_greater_than_equal(result_type, result_id, operand_1, operand_2)
- }
- (ast::SetpCompareOp::GreaterOrEq, ScalarKind::Signed) => {
- builder.s_greater_than_equal(result_type, result_id, operand_1, operand_2)
- }
- (ast::SetpCompareOp::GreaterOrEq, ScalarKind::Float) => {
- builder.f_ord_greater_than_equal(result_type, result_id, operand_1, operand_2)
- }
- (ast::SetpCompareOp::NanEq, _) => {
- builder.f_unord_equal(result_type, result_id, operand_1, operand_2)
- }
- (ast::SetpCompareOp::NanNotEq, _) => {
- builder.f_unord_not_equal(result_type, result_id, operand_1, operand_2)
- }
- (ast::SetpCompareOp::NanLess, _) => {
- builder.f_unord_less_than(result_type, result_id, operand_1, operand_2)
- }
- (ast::SetpCompareOp::NanLessOrEq, _) => {
- builder.f_unord_less_than_equal(result_type, result_id, operand_1, operand_2)
- }
- (ast::SetpCompareOp::NanGreater, _) => {
- builder.f_unord_greater_than(result_type, result_id, operand_1, operand_2)
- }
- (ast::SetpCompareOp::NanGreaterOrEq, _) => {
- builder.f_unord_greater_than_equal(result_type, result_id, operand_1, operand_2)
+impl IdGenerator {
+ pub(crate) fn new() -> Self {
+ Self {
+ next: unsafe { NonZeroU32::new_unchecked(1) },
}
- _ => todo!(),
- }?;
- Ok(())
-}
+ }
-fn emit_mul_sint(
- builder: &mut dr::Builder,
- map: &mut TypeWordMap,
- opencl: spirv::Word,
- desc: &ast::MulSInt,
- arg: &ast::Arg3<ExpandedArgParams>,
-) -> Result<(), dr::Error> {
- let instruction_type = ast::ScalarType::from(desc.typ);
- let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ)));
- match desc.control {
- ast::MulIntControl::Low => {
- builder.i_mul(inst_type, Some(arg.dst), arg.src1, arg.src2)?;
- }
- ast::MulIntControl::High => {
- builder.ext_inst(
- inst_type,
- Some(arg.dst),
- opencl,
- spirv::CLOp::s_mul_hi as spirv::Word,
- [arg.src1, arg.src2],
- )?;
- }
- ast::MulIntControl::Wide => {
- let mul_ext_type = SpirvType::Struct(vec![
- SpirvScalarKey::from(instruction_type),
- SpirvScalarKey::from(instruction_type),
- ]);
- let mul_ext_type_id = map.get_or_add(builder, mul_ext_type);
- let mul = builder.s_mul_extended(mul_ext_type_id, None, arg.src1, arg.src2)?;
- let instr_width = instruction_type.size_of();
- let instr_kind = instruction_type.kind();
- let dst_type = ast::ScalarType::from_parts(instr_width * 2, instr_kind);
- let dst_type_id = map.get_or_add_scalar(builder, dst_type);
- struct2_bitcast_to_wide(
- builder,
- map,
- SpirvScalarKey::from(instruction_type),
- inst_type,
- arg.dst,
- dst_type_id,
- mul,
- )?;
- }
+ pub(crate) fn next(&mut self) -> Id {
+ self.reserve(1).next().unwrap()
}
- Ok(())
-}
-fn emit_mul_uint(
- builder: &mut dr::Builder,
- map: &mut TypeWordMap,
- opencl: spirv::Word,
- desc: &ast::MulUInt,
- arg: &ast::Arg3<ExpandedArgParams>,
-) -> Result<(), dr::Error> {
- let instruction_type = ast::ScalarType::from(desc.typ);
- let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ)));
- match desc.control {
- ast::MulIntControl::Low => {
- builder.i_mul(inst_type, Some(arg.dst), arg.src1, arg.src2)?;
- }
- ast::MulIntControl::High => {
- builder.ext_inst(
- inst_type,
- Some(arg.dst),
- opencl,
- spirv::CLOp::u_mul_hi as spirv::Word,
- [arg.src1, arg.src2],
- )?;
- }
- ast::MulIntControl::Wide => {
- let mul_ext_type = SpirvType::Struct(vec![
- SpirvScalarKey::from(instruction_type),
- SpirvScalarKey::from(instruction_type),
- ]);
- let mul_ext_type_id = map.get_or_add(builder, mul_ext_type);
- let mul = builder.u_mul_extended(mul_ext_type_id, None, arg.src1, arg.src2)?;
- let instr_width = instruction_type.size_of();
- let instr_kind = instruction_type.kind();
- let dst_type = ast::ScalarType::from_parts(instr_width * 2, instr_kind);
- let dst_type_id = map.get_or_add_scalar(builder, dst_type);
- struct2_bitcast_to_wide(
- builder,
- map,
- SpirvScalarKey::from(instruction_type),
- inst_type,
- arg.dst,
- dst_type_id,
- mul,
- )?;
+ // Returns first reserved id
+ pub(crate) fn reserve(&mut self, count: usize) -> impl ExactSizeIterator<Item = Id> + Clone {
+ let value = self.next.get();
+ unsafe {
+ self.next = NonZeroU32::new_unchecked(value + count as u32);
+ let start = Id(NonZeroU32::new_unchecked(value));
+ let end = Id(self.next);
+ (start.0.get()..end.0.get()).map(|i| Id(NonZeroU32::new_unchecked(i)))
}
}
- Ok(())
}
-// Surprisingly, structs can't be bitcast, so we route everything through a vector
-fn struct2_bitcast_to_wide(
- builder: &mut dr::Builder,
- map: &mut TypeWordMap,
- base_type_key: SpirvScalarKey,
- instruction_type: spirv::Word,
- dst: spirv::Word,
- dst_type_id: spirv::Word,
- src: spirv::Word,
-) -> Result<(), dr::Error> {
- let low_bits = builder.composite_extract(instruction_type, None, src, [0])?;
- let high_bits = builder.composite_extract(instruction_type, None, src, [1])?;
- let vector_type = map.get_or_add(builder, SpirvType::Vector(base_type_key, 2));
- let vector = builder.composite_construct(vector_type, None, [low_bits, high_bits])?;
- builder.bitcast(dst_type_id, Some(dst), vector)?;
- Ok(())
+pub(crate) struct IdNameMapBuilder<'input> {
+ pub(crate) id_gen: IdGenerator,
+ type_check: FxHashMap<Id, Option<(ast::Type, ast::StateSpace, Option<u32>, bool)>>,
+ pub(crate) globals: GlobalsResolver<'input>,
}
-fn emit_abs(
- builder: &mut dr::Builder,
- map: &mut TypeWordMap,
- opencl: spirv::Word,
- d: &ast::AbsDetails,
- arg: &ast::Arg2<ExpandedArgParams>,
-) -> Result<(), dr::Error> {
- let scalar_t = ast::ScalarType::from(d.typ);
- let result_type = map.get_or_add(builder, SpirvType::from(scalar_t));
- let cl_abs = if scalar_t.kind() == ScalarKind::Signed {
- spirv::CLOp::s_abs
- } else {
- spirv::CLOp::fabs
- };
- builder.ext_inst(
- result_type,
- Some(arg.dst),
- opencl,
- cl_abs as spirv::Word,
- [arg.src],
- )?;
- Ok(())
-}
-
-fn emit_add_int(
- builder: &mut dr::Builder,
- map: &mut TypeWordMap,
- typ: ast::ScalarType,
- saturate: bool,
- arg: &ast::Arg3<ExpandedArgParams>,
-) -> Result<(), dr::Error> {
- if saturate {
- todo!()
+impl<'input> IdNameMapBuilder<'input> {
+ pub(crate) fn new(id_gen: IdGenerator) -> Self {
+ let globals = GlobalsResolver {
+ variables: FxHashMap::default(),
+ reverse_variables: FxHashMap::default(),
+ special_registers: SpecialRegistersMap::new(),
+ function_prototypes: FxHashMap::default(),
+ };
+ Self {
+ id_gen,
+ globals,
+ type_check: FxHashMap::default(),
+ }
}
- let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(typ)));
- builder.i_add(inst_type, Some(arg.dst), arg.src1, arg.src2)?;
- Ok(())
-}
-fn emit_sub_int(
- builder: &mut dr::Builder,
- map: &mut TypeWordMap,
- typ: ast::ScalarType,
- saturate: bool,
- arg: &ast::Arg3<ExpandedArgParams>,
-) -> Result<(), dr::Error> {
- if saturate {
- todo!()
+ pub(crate) fn get_or_add_non_variable<T: Into<Cow<'input, str>>>(&mut self, id: T) -> Id {
+ self.get_or_add_impl(id.into(), None)
}
- let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(typ)));
- builder.i_sub(inst_type, Some(arg.dst), arg.src1, arg.src2)?;
- Ok(())
-}
-fn emit_implicit_conversion(
- builder: &mut dr::Builder,
- map: &mut TypeWordMap,
- cv: &ImplicitConversion,
-) -> Result<(), TranslateError> {
- let from_parts = cv.from.to_parts();
- let to_parts = cv.to.to_parts();
- match (from_parts.kind, to_parts.kind, cv.kind) {
- (_, _, ConversionKind::PtrToBit(typ)) => {
- let dst_type = map.get_or_add_scalar(builder, typ.into());
- builder.convert_ptr_to_u(dst_type, Some(cv.dst), cv.src)?;
- }
- (_, _, ConversionKind::BitToPtr(_)) => {
- let dst_type = map.get_or_add(builder, SpirvType::from(cv.to.clone()));
- builder.convert_u_to_ptr(dst_type, Some(cv.dst), cv.src)?;
- }
- (TypeKind::Scalar, TypeKind::Scalar, ConversionKind::Default) => {
- if from_parts.width == to_parts.width {
- let dst_type = map.get_or_add(builder, SpirvType::from(cv.to.clone()));
- if from_parts.scalar_kind != ScalarKind::Float
- && to_parts.scalar_kind != ScalarKind::Float
- {
- // It is noop, but another instruction expects result of this conversion
- builder.copy_object(dst_type, Some(cv.dst), cv.src)?;
- } else {
- builder.bitcast(dst_type, Some(cv.dst), cv.src)?;
- }
- } else {
- // This block is safe because it's illegal to implictly convert between floating point instructions
- let same_width_bit_type = map.get_or_add(
- builder,
- SpirvType::from(ast::Type::from_parts(TypeParts {
- scalar_kind: ScalarKind::Bit,
- ..from_parts
- })),
- );
- let same_width_bit_value = builder.bitcast(same_width_bit_type, None, cv.src)?;
- let wide_bit_type = ast::Type::from_parts(TypeParts {
- scalar_kind: ScalarKind::Bit,
- ..to_parts
- });
- let wide_bit_type_spirv =
- map.get_or_add(builder, SpirvType::from(wide_bit_type.clone()));
- if to_parts.scalar_kind == ScalarKind::Unsigned
- || to_parts.scalar_kind == ScalarKind::Bit
- {
- builder.u_convert(wide_bit_type_spirv, Some(cv.dst), same_width_bit_value)?;
- } else {
- let conversion_fn = if from_parts.scalar_kind == ScalarKind::Signed
- && to_parts.scalar_kind == ScalarKind::Signed
- {
- dr::Builder::s_convert
- } else {
- dr::Builder::u_convert
- };
- let wide_bit_value =
- conversion_fn(builder, wide_bit_type_spirv, None, same_width_bit_value)?;
- emit_implicit_conversion(
- builder,
- map,
- &ImplicitConversion {
- src: wide_bit_value,
- dst: cv.dst,
- from: wide_bit_type,
- to: cv.to.clone(),
- kind: ConversionKind::Default,
- src_sema: cv.src_sema,
- dst_sema: cv.dst_sema,
- },
- )?;
- }
- }
- }
- (TypeKind::Scalar, TypeKind::Scalar, ConversionKind::SignExtend) => todo!(),
- (TypeKind::Vector, TypeKind::Scalar, ConversionKind::Default)
- | (TypeKind::Scalar, TypeKind::Array, ConversionKind::Default)
- | (TypeKind::Array, TypeKind::Scalar, ConversionKind::Default) => {
- let into_type = map.get_or_add(builder, SpirvType::from(cv.to.clone()));
- builder.bitcast(into_type, Some(cv.dst), cv.src)?;
- }
- (_, _, ConversionKind::PtrToPtr { spirv_ptr }) => {
- let result_type = if spirv_ptr {
- map.get_or_add(
- builder,
- SpirvType::Pointer(
- Box::new(SpirvType::from(cv.to.clone())),
- spirv::StorageClass::Function,
- ),
- )
- } else {
- map.get_or_add(builder, SpirvType::from(cv.to.clone()))
- };
- builder.bitcast(result_type, Some(cv.dst), cv.src)?;
+ fn get_or_add_impl(
+ &mut self,
+ name: Cow<'input, str>,
+ type_: Option<(ast::Type, ast::StateSpace, Option<u32>)>,
+ ) -> Id {
+ let id = self.globals.get_or_add_impl(&mut self.id_gen, name.clone());
+ if cfg!(debug_assertions) {
+ eprintln!("{}: {}", id.get(), name.to_owned());
}
- _ => unreachable!(),
+ self.type_check
+ .insert(id, type_.map(|(t, ss, align)| (t, ss, align, true)));
+ id
}
- Ok(())
-}
-
-fn emit_load_var(
- builder: &mut dr::Builder,
- map: &mut TypeWordMap,
- details: &LoadVarDetails,
-) -> Result<(), TranslateError> {
- let result_type = map.get_or_add(builder, SpirvType::from(details.typ.clone()));
- match details.member_index {
- Some((index, Some(width))) => {
- let vector_type = match details.typ {
- ast::Type::Scalar(scalar_t) => ast::Type::Vector(scalar_t, width),
- _ => return Err(TranslateError::MismatchedType),
- };
- let vector_type_spirv = map.get_or_add(builder, SpirvType::from(vector_type));
- let vector_temp = builder.load(vector_type_spirv, None, details.arg.src, None, [])?;
- builder.composite_extract(
- result_type,
- Some(details.arg.dst),
- vector_temp,
- &[index as u32],
- )?;
- }
- Some((index, None)) => {
- let result_ptr_type = map.get_or_add(
- builder,
- SpirvType::new_pointer(details.typ.clone(), spirv::StorageClass::Function),
- );
- let index_spirv = map.get_or_add_constant(
- builder,
- &ast::Type::Scalar(ast::ScalarType::U32),
- &vec_repr(index as u32),
- )?;
- let src = builder.in_bounds_access_chain(
- result_ptr_type,
- None,
- details.arg.src,
- &[index_spirv],
- )?;
- builder.load(result_type, Some(details.arg.dst), src, None, [])?;
- }
- None => {
- builder.load(
- result_type,
- Some(details.arg.dst),
- details.arg.src,
- None,
- [],
- )?;
- }
- };
- Ok(())
-}
-fn normalize_identifiers<'a, 'b>(
- id_defs: &mut FnStringIdResolver<'a, 'b>,
- fn_defs: &GlobalFnDeclResolver<'a, 'b>,
- func: Vec<ast::Statement<ast::ParsedArgParams<'a>>>,
-) -> Result<Vec<NormalizedStatement>, TranslateError> {
- for s in func.iter() {
- match s {
- ast::Statement::Label(id) => {
- id_defs.add_def(*id, None, false);
- }
- _ => (),
- }
- }
- let mut result = Vec::new();
- for s in func {
- expand_map_variables(id_defs, fn_defs, &mut result, s)?;
+ pub(crate) fn register_intermediate(
+ &mut self,
+ typ: Option<(ast::Type, ast::StateSpace)>,
+ ) -> Id {
+ let new_id = self.id_gen.next();
+ self.type_check
+ .insert(new_id, typ.map(|(t, space)| (t, space, None, false)));
+ new_id
}
- Ok(result)
-}
-fn expand_map_variables<'a, 'b>(
- id_defs: &mut FnStringIdResolver<'a, 'b>,
- fn_defs: &GlobalFnDeclResolver<'a, 'b>,
- result: &mut Vec<NormalizedStatement>,
- s: ast::Statement<ast::ParsedArgParams<'a>>,
-) -> Result<(), TranslateError> {
- match s {
- ast::Statement::Block(block) => {
- id_defs.start_block();
- for s in block {
- expand_map_variables(id_defs, fn_defs, result, s)?;
- }
- id_defs.end_block();
- }
- ast::Statement::Label(name) => result.push(Statement::Label(id_defs.get_id(name)?)),
- ast::Statement::Instruction(p, i) => result.push(Statement::Instruction((
- p.map(|p| p.map_variable(&mut |id| id_defs.get_id(id)))
- .transpose()?,
- i.map_variable(&mut |id| id_defs.get_id(id))?,
- ))),
- ast::Statement::Variable(var) => {
- let mut var_type = ast::Type::from(var.var.v_type.clone());
- let mut is_variable = false;
- var_type = match var.var.v_type {
- ast::VariableType::Reg(_) => {
- is_variable = true;
- var_type
- }
- ast::VariableType::Shared(_) => {
- // If it's a pointer it will be translated to a method parameter later
- if let ast::Type::Pointer(..) = var_type {
- is_variable = true;
- var_type
- } else {
- var_type.param_pointer_to(ast::LdStateSpace::Shared)?
- }
- }
- ast::VariableType::Global(_) => {
- var_type.param_pointer_to(ast::LdStateSpace::Global)?
- }
- ast::VariableType::Param(_) => {
- var_type.param_pointer_to(ast::LdStateSpace::Param)?
- }
- ast::VariableType::Local(_) => {
- var_type.param_pointer_to(ast::LdStateSpace::Local)?
- }
- };
- match var.count {
- Some(count) => {
- for new_id in id_defs.add_defs(var.var.name, count, var_type, is_variable) {
- result.push(Statement::Variable(ast::Variable {
- align: var.var.align,
- v_type: var.var.v_type.clone(),
- name: new_id,
- array_init: var.var.array_init.clone(),
- }))
- }
- }
- None => {
- let new_id = id_defs.add_def(var.var.name, Some(var_type), is_variable);
- result.push(Statement::Variable(ast::Variable {
- align: var.var.align,
- v_type: var.var.v_type.clone(),
- name: new_id,
- array_init: var.var.array_init,
- }));
- }
- }
- }
- };
- Ok(())
-}
-
-// TODO: detect more patterns (mov, call via reg, call via param)
-// TODO: don't convert to ptr if the register is not ultimately used for ld/st
-// TODO: once insert_mem_ssa_statements is moved to later, move this pass after
-// argument expansion
-// TODO: propagate through calls?
-fn convert_to_stateful_memory_access<'a>(
- func_args: &mut SpirvMethodDecl,
- func_body: Vec<TypedStatement>,
- id_defs: &mut NumericIdResolver<'a>,
-) -> Result<Vec<TypedStatement>, TranslateError> {
- let func_args_64bit = func_args
- .input
- .iter()
- .filter_map(|arg| match arg.v_type {
- ast::Type::Scalar(ast::ScalarType::U64)
- | ast::Type::Scalar(ast::ScalarType::B64)
- | ast::Type::Scalar(ast::ScalarType::S64) => Some(arg.name),
- _ => None,
- })
- .collect::<HashSet<_>>();
- let mut stateful_markers = Vec::new();
- let mut stateful_init_reg = MultiHashMap::new();
- for statement in func_body.iter() {
- match statement {
- Statement::Instruction(ast::Instruction::Cvta(
- ast::CvtaDetails {
- to: ast::CvtaStateSpace::Global,
- size: ast::CvtaSize::U64,
- from: ast::CvtaStateSpace::Generic,
- },
- arg,
- )) => {
- if let (TypedOperand::Reg(dst), Some(src)) =
- (arg.dst, arg.src.upcast().underlying())
- {
- if is_64_bit_integer(id_defs, *src) && is_64_bit_integer(id_defs, dst) {
- stateful_markers.push((dst, *src));
- }
- }
- }
- Statement::Instruction(ast::Instruction::Ld(
- ast::LdDetails {
- state_space: ast::LdStateSpace::Param,
- typ: ast::LdStType::Scalar(ast::LdStScalarType::U64),
- ..
- },
- arg,
- ))
- | Statement::Instruction(ast::Instruction::Ld(
- ast::LdDetails {
- state_space: ast::LdStateSpace::Param,
- typ: ast::LdStType::Scalar(ast::LdStScalarType::S64),
- ..
- },
- arg,
- ))
- | Statement::Instruction(ast::Instruction::Ld(
- ast::LdDetails {
- state_space: ast::LdStateSpace::Param,
- typ: ast::LdStType::Scalar(ast::LdStScalarType::B64),
- ..
- },
- arg,
- )) => {
- if let (TypedOperand::Reg(dst), Some(src)) =
- (&arg.dst, arg.src.upcast().underlying())
- {
- if func_args_64bit.contains(src) {
- multi_hash_map_append(&mut stateful_init_reg, *dst, *src);
- }
- }
- }
- _ => {}
+ // This is for identifiers which will be emitted later as OpVariable
+ // They are candidates for insertion of LoadVar/StoreVar
+ pub(crate) fn register_variable_decl(
+ &mut self,
+ align: Option<u32>,
+ type_: ast::Type,
+ state_space: ast::StateSpace,
+ ) -> ast::VariableDeclaration<Id> {
+ let name = self.id_gen.next();
+ self.type_check
+ .insert(name, Some((type_.clone(), state_space, align, true)));
+ ast::VariableDeclaration {
+ align,
+ type_,
+ state_space,
+ name,
}
}
- let mut func_args_ptr = HashSet::new();
- let mut regs_ptr_current = HashSet::new();
- for (dst, src) in stateful_markers {
- if let Some(func_args) = stateful_init_reg.get(&src) {
- for a in func_args {
- func_args_ptr.insert(*a);
- regs_ptr_current.insert(src);
- regs_ptr_current.insert(dst);
- }
+ pub(crate) fn register_variable_def(
+ &mut self,
+ align: Option<u32>,
+ type_: ast::Type,
+ state_space: ast::StateSpace,
+ initializer: Option<ast::Initializer<Id>>,
+ ) -> Variable {
+ let name = self.id_gen.next();
+ self.type_check
+ .insert(name, Some((type_.clone(), state_space, align, true)));
+ Variable {
+ name,
+ align,
+ type_,
+ state_space,
+ initializer,
}
}
- // BTreeSet here to have a stable order of iteration,
- // unfortunately our tests rely on it
- let mut regs_ptr_seen = BTreeSet::new();
- while regs_ptr_current.len() > 0 {
- let mut regs_ptr_new = HashSet::new();
- for statement in func_body.iter() {
- match statement {
- Statement::Instruction(ast::Instruction::Add(
- ast::ArithDetails::Unsigned(ast::UIntType::U64),
- arg,
- ))
- | Statement::Instruction(ast::Instruction::Add(
- ast::ArithDetails::Signed(ast::ArithSInt {
- typ: ast::SIntType::S64,
- saturate: false,
- }),
- arg,
- ))
- | Statement::Instruction(ast::Instruction::Sub(
- ast::ArithDetails::Unsigned(ast::UIntType::U64),
- arg,
- ))
- | Statement::Instruction(ast::Instruction::Sub(
- ast::ArithDetails::Signed(ast::ArithSInt {
- typ: ast::SIntType::S64,
- saturate: false,
- }),
- arg,
- )) => {
- if let (TypedOperand::Reg(dst), Some(src1)) =
- (arg.dst, arg.src1.upcast().underlying())
- {
- if regs_ptr_current.contains(src1) && !regs_ptr_seen.contains(src1) {
- regs_ptr_new.insert(dst);
- }
- } else if let (TypedOperand::Reg(dst), Some(src2)) =
- (arg.dst, arg.src2.upcast().underlying())
- {
- if regs_ptr_current.contains(src2) && !regs_ptr_seen.contains(src2) {
- regs_ptr_new.insert(dst);
- }
- }
- }
- _ => {}
- }
+
+ pub(crate) fn get_typed(
+ &self,
+ id: Id,
+ ) -> Result<(ast::Type, ast::StateSpace, Option<u32>, bool), TranslateError> {
+ match self.type_check.get(&id) {
+ Some(Some(x)) => Ok(x.clone()),
+ Some(None) => Err(TranslateError::untyped_symbol()),
+ None => match self.globals.special_registers.get(id) {
+ Some(x) => Ok((x.get_type(), ast::StateSpace::Sreg, None, true)),
+ None => match self.type_check.get(&id) {
+ Some(Some(result)) => Ok(result.clone()),
+ Some(None) | None => Err(TranslateError::untyped_symbol()),
+ },
+ },
}
- for id in regs_ptr_current {
- regs_ptr_seen.insert(id);
- }
- regs_ptr_current = regs_ptr_new;
- }
- drop(regs_ptr_current);
- let mut remapped_ids = HashMap::new();
- let mut result = Vec::with_capacity(regs_ptr_seen.len() + func_body.len());
- for reg in regs_ptr_seen {
- let new_id = id_defs.new_variable(ast::Type::Pointer(
- ast::PointerType::Scalar(ast::ScalarType::U8),
- ast::LdStateSpace::Global,
- ));
- result.push(Statement::Variable(ast::Variable {
- align: None,
- name: new_id,
- array_init: Vec::new(),
- v_type: ast::VariableType::Reg(ast::VariableRegType::Pointer(
- ast::SizedScalarType::U8,
- ast::PointerStateSpace::Global,
- )),
- }));
- remapped_ids.insert(reg, new_id);
}
- for statement in func_body {
- match statement {
- l @ Statement::Label(_) => result.push(l),
- c @ Statement::Conditional(_) => result.push(c),
- Statement::Variable(var) => {
- if !remapped_ids.contains_key(&var.name) {
- result.push(Statement::Variable(var));
- }
- }
- Statement::Instruction(ast::Instruction::Add(
- ast::ArithDetails::Unsigned(ast::UIntType::U64),
- arg,
- ))
- | Statement::Instruction(ast::Instruction::Add(
- ast::ArithDetails::Signed(ast::ArithSInt {
- typ: ast::SIntType::S64,
- saturate: false,
- }),
- arg,
- )) if is_add_ptr_direct(&remapped_ids, &arg) => {
- let (ptr, offset) = match arg.src1.upcast().underlying() {
- Some(src1) if remapped_ids.contains_key(src1) => {
- (remapped_ids.get(src1).unwrap(), arg.src2)
- }
- Some(src2) if remapped_ids.contains_key(src2) => {
- (remapped_ids.get(src2).unwrap(), arg.src1)
- }
- _ => return Err(error_unreachable()),
- };
- let dst = arg.dst.upcast().unwrap_reg()?;
- result.push(Statement::PtrAccess(PtrAccess {
- underlying_type: ast::PointerType::Scalar(ast::ScalarType::U8),
- state_space: ast::LdStateSpace::Global,
- dst: *remapped_ids.get(&dst).unwrap(),
- ptr_src: *ptr,
- offset_src: offset,
- }))
- }
- Statement::Instruction(ast::Instruction::Sub(
- ast::ArithDetails::Unsigned(ast::UIntType::U64),
- arg,
- ))
- | Statement::Instruction(ast::Instruction::Sub(
- ast::ArithDetails::Signed(ast::ArithSInt {
- typ: ast::SIntType::S64,
- saturate: false,
- }),
- arg,
- )) if is_add_ptr_direct(&remapped_ids, &arg) => {
- let (ptr, offset) = match arg.src1.upcast().underlying() {
- Some(src1) if remapped_ids.contains_key(src1) => {
- (remapped_ids.get(src1).unwrap(), arg.src2)
- }
- Some(src2) if remapped_ids.contains_key(src2) => {
- (remapped_ids.get(src2).unwrap(), arg.src1)
- }
- _ => return Err(error_unreachable()),
- };
- let offset_neg =
- id_defs.new_non_variable(Some(ast::Type::Scalar(ast::ScalarType::S64)));
- result.push(Statement::Instruction(ast::Instruction::Neg(
- ast::NegDetails {
- typ: ast::ScalarType::S64,
- flush_to_zero: None,
- },
- ast::Arg2 {
- src: offset,
- dst: TypedOperand::Reg(offset_neg),
- },
- )));
- let dst = arg.dst.upcast().unwrap_reg()?;
- result.push(Statement::PtrAccess(PtrAccess {
- underlying_type: ast::PointerType::Scalar(ast::ScalarType::U8),
- state_space: ast::LdStateSpace::Global,
- dst: *remapped_ids.get(&dst).unwrap(),
- ptr_src: *ptr,
- offset_src: TypedOperand::Reg(offset_neg),
- }))
- }
- Statement::Instruction(inst) => {
- let mut post_statements = Vec::new();
- let new_statement = inst.visit(
- &mut |arg_desc: ArgumentDescriptor<spirv::Word>,
- expected_type: Option<&ast::Type>| {
- convert_to_stateful_memory_access_postprocess(
- id_defs,
- &remapped_ids,
- &func_args_ptr,
- &mut result,
- &mut post_statements,
- arg_desc,
- expected_type,
- )
- },
- )?;
- result.push(new_statement);
- result.extend(post_statements);
- }
- Statement::Call(call) => {
- let mut post_statements = Vec::new();
- let new_statement = call.visit(
- &mut |arg_desc: ArgumentDescriptor<spirv::Word>,
- expected_type: Option<&ast::Type>| {
- convert_to_stateful_memory_access_postprocess(
- id_defs,
- &remapped_ids,
- &func_args_ptr,
- &mut result,
- &mut post_statements,
- arg_desc,
- expected_type,
- )
- },
- )?;
- result.push(new_statement);
- result.extend(post_statements);
- }
- Statement::RepackVector(pack) => {
- let mut post_statements = Vec::new();
- let new_statement = pack.visit(
- &mut |arg_desc: ArgumentDescriptor<spirv::Word>,
- expected_type: Option<&ast::Type>| {
- convert_to_stateful_memory_access_postprocess(
- id_defs,
- &remapped_ids,
- &func_args_ptr,
- &mut result,
- &mut post_statements,
- arg_desc,
- expected_type,
- )
- },
- )?;
- result.push(new_statement);
- result.extend(post_statements);
+
+ fn change_type(&mut self, id: Id, new_type: ast::Type) -> Result<(), TranslateError> {
+ Ok(match self.type_check.get_mut(&id) {
+ Some(Some((type_, ..))) => {
+ *type_ = new_type;
}
- _ => return Err(error_unreachable()),
- }
- }
- for arg in func_args.input.iter_mut() {
- if func_args_ptr.contains(&arg.name) {
- arg.v_type = ast::Type::Pointer(
- ast::PointerType::Scalar(ast::ScalarType::U8),
- ast::LdStateSpace::Global,
- );
- }
+ _ => return Err(TranslateError::unreachable()),
+ })
}
- Ok(result)
}
-fn convert_to_stateful_memory_access_postprocess(
- id_defs: &mut NumericIdResolver,
- remapped_ids: &HashMap<spirv::Word, spirv::Word>,
- func_args_ptr: &HashSet<spirv::Word>,
- result: &mut Vec<TypedStatement>,
- post_statements: &mut Vec<TypedStatement>,
- arg_desc: ArgumentDescriptor<spirv::Word>,
- expected_type: Option<&ast::Type>,
-) -> Result<spirv::Word, TranslateError> {
- Ok(match remapped_ids.get(&arg_desc.op) {
- Some(new_id) => {
- // We skip conversion here to trigger PtrAcces in a later pass
- let old_type = match expected_type {
- Some(ast::Type::Pointer(_, ast::LdStateSpace::Global)) => return Ok(*new_id),
- _ => id_defs.get_typed(arg_desc.op)?.0,
- };
- let old_type_clone = old_type.clone();
- let converting_id = id_defs.new_non_variable(Some(old_type_clone));
- if arg_desc.is_dst {
- post_statements.push(Statement::Conversion(ImplicitConversion {
- src: converting_id,
- dst: *new_id,
- from: old_type,
- to: ast::Type::Pointer(
- ast::PointerType::Scalar(ast::ScalarType::U8),
- ast::LdStateSpace::Global,
- ),
- kind: ConversionKind::BitToPtr(ast::LdStateSpace::Global),
- src_sema: ArgumentSemantics::Default,
- dst_sema: arg_desc.sema,
- }));
- converting_id
- } else {
- result.push(Statement::Conversion(ImplicitConversion {
- src: *new_id,
- dst: converting_id,
- from: ast::Type::Pointer(
- ast::PointerType::Scalar(ast::ScalarType::U8),
- ast::LdStateSpace::Global,
- ),
- to: old_type,
- kind: ConversionKind::PtrToBit(ast::UIntType::U64),
- src_sema: arg_desc.sema,
- dst_sema: ArgumentSemantics::Default,
- }));
- converting_id
- }
- }
- None => match func_args_ptr.get(&arg_desc.op) {
- Some(new_id) => {
- if arg_desc.is_dst {
- return Err(error_unreachable());
- }
- // We skip conversion here to trigger PtrAcces in a later pass
- let old_type = match expected_type {
- Some(ast::Type::Pointer(_, ast::LdStateSpace::Global)) => return Ok(*new_id),
- _ => id_defs.get_typed(arg_desc.op)?.0,
- };
- let old_type_clone = old_type.clone();
- let converting_id = id_defs.new_non_variable(Some(old_type));
- result.push(Statement::Conversion(ImplicitConversion {
- src: *new_id,
- dst: converting_id,
- from: ast::Type::Pointer(
- ast::PointerType::Pointer(ast::ScalarType::U8, ast::LdStateSpace::Global),
- ast::LdStateSpace::Param,
- ),
- to: old_type_clone,
- kind: ConversionKind::PtrToPtr { spirv_ptr: false },
- src_sema: arg_desc.sema,
- dst_sema: ArgumentSemantics::Default,
- }));
- converting_id
- }
- None => arg_desc.op,
- },
- })
+pub(crate) struct GlobalsResolver<'input> {
+ // Thos two fields below are only used by raytracing
+ // TODO: move to raytracing-specific structures
+ pub(crate) variables: FxHashMap<Cow<'input, str>, Id>,
+ pub(crate) reverse_variables: FxHashMap<Id, Cow<'input, str>>,
+ special_registers: SpecialRegistersMap,
+ pub(crate) function_prototypes: FxHashMap<Id, Callprototype>,
}
-fn is_add_ptr_direct(remapped_ids: &HashMap<u32, u32>, arg: &ast::Arg3<TypedArgParams>) -> bool {
- match arg.dst {
- TypedOperand::Imm(..) | TypedOperand::RegOffset(..) | TypedOperand::VecMember(..) => {
- return false
- }
- TypedOperand::Reg(dst) => {
- if !remapped_ids.contains_key(&dst) {
- return false;
- }
- match arg.src1.upcast().underlying() {
- Some(src1) if remapped_ids.contains_key(src1) => true,
- Some(src2) if remapped_ids.contains_key(src2) => true,
- _ => false,
+impl<'input> GlobalsResolver<'input> {
+ fn get_or_add_impl(&mut self, id_gen: &mut IdGenerator, id: Cow<'input, str>) -> Id {
+ let id = match self.variables.entry(id) {
+ hash_map::Entry::Occupied(e) => *(e.get()),
+ hash_map::Entry::Vacant(e) => {
+ let new_id = id_gen.next();
+ self.reverse_variables.insert(new_id, e.key().clone());
+ e.insert(new_id);
+ new_id
}
- }
+ };
+ id
}
}
-fn is_64_bit_integer(id_defs: &NumericIdResolver, id: spirv::Word) -> bool {
- match id_defs.get_typed(id) {
- Ok((ast::Type::Scalar(ast::ScalarType::U64), _))
- | Ok((ast::Type::Scalar(ast::ScalarType::S64), _))
- | Ok((ast::Type::Scalar(ast::ScalarType::B64), _)) => true,
- _ => false,
- }
+pub struct Callprototype {
+ pub return_arguments: Vec<(ast::Type, ast::StateSpace)>,
+ pub input_arguments: Vec<(ast::Type, ast::StateSpace)>,
}
-#[derive(Ord, PartialOrd, Eq, PartialEq, Hash, Copy, Clone)]
-enum PtxSpecialRegister {
- Tid,
- Tid64,
- Ntid,
- Ntid64,
- Ctaid,
- Ctaid64,
- Nctaid,
- Nctaid64,
+pub(crate) struct StringIdResolver<'a, 'input> {
+ module: &'a mut IdNameMapBuilder<'input>,
+ scopes: Vec<FxHashMap<Cow<'input, str>, Id>>,
}
-impl PtxSpecialRegister {
- fn try_parse(s: &str) -> Option<Self> {
- match s {
- "%tid" => Some(Self::Tid),
- "%ntid" => Some(Self::Ntid),
- "%ctaid" => Some(Self::Ctaid),
- "%nctaid" => Some(Self::Nctaid),
- _ => None,
- }
- }
-
- fn get_type(self) -> ast::Type {
- match self {
- PtxSpecialRegister::Tid => ast::Type::Vector(ast::ScalarType::U32, 4),
- PtxSpecialRegister::Tid64 => ast::Type::Vector(ast::ScalarType::U64, 3),
- PtxSpecialRegister::Ntid => ast::Type::Vector(ast::ScalarType::U32, 4),
- PtxSpecialRegister::Ntid64 => ast::Type::Vector(ast::ScalarType::U64, 3),
- PtxSpecialRegister::Ctaid => ast::Type::Vector(ast::ScalarType::U32, 4),
- PtxSpecialRegister::Ctaid64 => ast::Type::Vector(ast::ScalarType::U64, 3),
- PtxSpecialRegister::Nctaid => ast::Type::Vector(ast::ScalarType::U32, 4),
- PtxSpecialRegister::Nctaid64 => ast::Type::Vector(ast::ScalarType::U64, 3),
- }
- }
-
- fn get_builtin(self) -> spirv::BuiltIn {
- match self {
- PtxSpecialRegister::Tid | PtxSpecialRegister::Tid64 => {
- spirv::BuiltIn::LocalInvocationId
- }
- PtxSpecialRegister::Ntid | PtxSpecialRegister::Ntid64 => spirv::BuiltIn::WorkgroupSize,
- PtxSpecialRegister::Ctaid | PtxSpecialRegister::Ctaid64 => spirv::BuiltIn::WorkgroupId,
- PtxSpecialRegister::Nctaid | PtxSpecialRegister::Nctaid64 => {
- spirv::BuiltIn::NumWorkgroups
+impl<'a, 'input> StringIdResolver<'a, 'input> {
+ fn new<P: ast::ArgParams<Id = Id>>(
+ module_resolver: &'a mut IdNameMapBuilder<'input>,
+ existing_directives: &[TranslationDirective<'input, P>],
+ ) -> Result<Self, TranslateError> {
+ let mut result = Self {
+ module: module_resolver,
+ scopes: vec![FxHashMap::default(), FxHashMap::default()],
+ };
+ for directive in existing_directives {
+ match directive {
+ TranslationDirective::Variable(..) => return Err(TranslateError::unreachable()),
+ TranslationDirective::Method(method) => {
+ let string_name = result
+ .module
+ .globals
+ .reverse_variables
+ .get(&method.name)
+ .ok_or_else(TranslateError::unreachable)?;
+ result.scopes[StringIdResolverScope::IMPLICIT_GLOBALS]
+ .insert(string_name.clone(), method.name);
+ }
}
}
+ Ok(result)
}
- fn normalized_sreg_and_type(self) -> Option<(PtxSpecialRegister, ast::ScalarType, u8)> {
- match self {
- PtxSpecialRegister::Tid => Some((PtxSpecialRegister::Tid64, ast::ScalarType::U64, 3)),
- PtxSpecialRegister::Ntid => Some((PtxSpecialRegister::Ntid64, ast::ScalarType::U64, 3)),
- PtxSpecialRegister::Ctaid => {
- Some((PtxSpecialRegister::Ctaid64, ast::ScalarType::U64, 3))
- }
- PtxSpecialRegister::Nctaid => {
- Some((PtxSpecialRegister::Nctaid64, ast::ScalarType::U64, 3))
- }
- PtxSpecialRegister::Tid64
- | PtxSpecialRegister::Ntid64
- | PtxSpecialRegister::Ctaid64
- | PtxSpecialRegister::Nctaid64 => None,
- }
+ fn start_module<'b>(&'b mut self) -> StringIdResolverScope<'a, 'b, 'input> {
+ self.scopes.push(FxHashMap::default());
+ StringIdResolverScope(self)
}
}
-struct SpecialRegistersMap {
- reg_to_id: HashMap<PtxSpecialRegister, spirv::Word>,
- id_to_reg: HashMap<spirv::Word, PtxSpecialRegister>,
-}
+pub(crate) struct StringIdResolverScope<'a, 'b, 'input>(&'b mut StringIdResolver<'a, 'input>);
-impl SpecialRegistersMap {
- fn new() -> Self {
- SpecialRegistersMap {
- reg_to_id: HashMap::new(),
- id_to_reg: HashMap::new(),
- }
+impl<'a, 'b, 'input> StringIdResolverScope<'a, 'b, 'input> {
+ // Items with visible, weak, etc. visibility. Accessible only by items
+ // taking part in cross-module linking (so items with visible, weak, etc. visibility)
+ const CROSS_MODULE: usize = 0;
+ // Some items are not explicitly declared, but are anyway visible inside a module.
+ // Currently this is the scope for raytracing function declarations
+ // TOOD: refactor special registers (activemask, etc.) to use this scope
+ const IMPLICIT_GLOBALS: usize = 1;
+ const CURRENT_MODULE: usize = 2;
+
+ fn start_scope<'x>(&'x mut self) -> StringIdResolverScope<'a, 'x, 'input> {
+ self.0.scopes.push(FxHashMap::default());
+ StringIdResolverScope(self.0)
}
- fn builtins<'a>(&'a self) -> impl Iterator<Item = (PtxSpecialRegister, spirv::Word)> + 'a {
- self.reg_to_id.iter().filter_map(|(sreg, id)| {
- if sreg.normalized_sreg_and_type().is_none() {
- Some((*sreg, *id))
- } else {
- None
- }
- })
+ fn get_id_in_module_scope(&self, name: &str) -> Result<Id, TranslateError> {
+ self.0.scopes[Self::CURRENT_MODULE]
+ .get(name)
+ .copied()
+ .ok_or_else(TranslateError::unknown_symbol)
}
- fn interface(&self) -> Vec<spirv::Word> {
- self.reg_to_id
+ fn get_id_in_function_scopes(&self, name: &str) -> Result<Id, TranslateError> {
+ let func_scopes_count = self.0.scopes.len() - (Self::CURRENT_MODULE + 1);
+ self.0
+ .scopes
.iter()
- .filter_map(|(sreg, id)| {
- if sreg.normalized_sreg_and_type().is_none() {
- Some(*id)
- } else {
- None
- }
+ .rev()
+ .take(func_scopes_count)
+ .find_map(|scope| scope.get(name))
+ .copied()
+ .ok_or_else(TranslateError::unknown_symbol)
+ }
+
+ fn get_id_in_module_scopes(&mut self, name: &str) -> Result<Id, TranslateError> {
+ // Scope 0 is global scope
+ let func_scopes_count = self.0.scopes.len() - (Self::CROSS_MODULE + 1);
+ PtxSpecialRegister::try_parse(name)
+ .map(|sreg| {
+ self.0
+ .module
+ .globals
+ .special_registers
+ .get_or_add(&mut self.0.module.id_gen, sreg)
})
- .collect::<Vec<_>>()
- }
-
- fn get(&self, id: spirv::Word) -> Option<PtxSpecialRegister> {
- self.id_to_reg.get(&id).copied()
+ .or_else(|| {
+ self.0
+ .scopes
+ .iter()
+ .rev()
+ .take(func_scopes_count)
+ .find_map(|scope| scope.get(name))
+ .copied()
+ })
+ .ok_or_else(TranslateError::unknown_symbol)
}
- fn get_or_add(&mut self, current_id: &mut spirv::Word, reg: PtxSpecialRegister) -> spirv::Word {
- match self.reg_to_id.entry(reg) {
- hash_map::Entry::Occupied(e) => *e.get(),
- hash_map::Entry::Vacant(e) => {
- let numeric_id = *current_id;
- *current_id += 1;
- e.insert(numeric_id);
- self.id_to_reg.insert(numeric_id, reg);
- numeric_id
+ fn add_or_get_at_module_level(
+ &mut self,
+ name: &'input str,
+ use_global_scope: bool,
+ ) -> Result<Id, TranslateError> {
+ debug_assert!(self.0.scopes.len() == 3);
+ if self.0.scopes[Self::IMPLICIT_GLOBALS].get(name).is_some() {
+ return Err(TranslateError::symbol_redefinition());
+ }
+ if use_global_scope {
+ let id = Self::get_or_add_untyped_in_scope(
+ &mut self.0.module,
+ &mut self.0.scopes[Self::CROSS_MODULE],
+ Cow::Borrowed(name),
+ None,
+ );
+ match self.0.scopes[Self::CURRENT_MODULE].entry(Cow::Borrowed(name)) {
+ hash_map::Entry::Occupied(existing_id) => {
+ if *existing_id.get() != id {
+ return Err(TranslateError::unreachable());
+ }
+ }
+ hash_map::Entry::Vacant(entry) => {
+ entry.insert(id);
+ }
}
+ Ok(id)
+ } else {
+ Ok(Self::get_or_add_untyped_in_scope(
+ &mut self.0.module,
+ &mut self.0.scopes[Self::CURRENT_MODULE],
+ Cow::Borrowed(name),
+ None,
+ ))
}
}
-}
-struct GlobalStringIdResolver<'input> {
- current_id: spirv::Word,
- variables: HashMap<Cow<'input, str>, spirv::Word>,
- variables_type_check: HashMap<u32, Option<(ast::Type, bool)>>,
- special_registers: SpecialRegistersMap,
- fns: HashMap<spirv::Word, FnDecl>,
-}
-
-pub struct FnDecl {
- ret_vals: Vec<ast::FnArgumentType>,
- params: Vec<ast::FnArgumentType>,
-}
-
-impl<'a> GlobalStringIdResolver<'a> {
- fn new(start_id: spirv::Word) -> Self {
- Self {
- current_id: start_id,
- variables: HashMap::new(),
- variables_type_check: HashMap::new(),
- special_registers: SpecialRegistersMap::new(),
- fns: HashMap::new(),
+ fn get_or_add_untyped_in_scope(
+ id_def: &mut IdNameMapBuilder<'input>,
+ scope: &mut FxHashMap<Cow<'input, str>, Id>,
+ name: Cow<'input, str>,
+ type_: Option<(ast::Type, ast::StateSpace, Option<u32>, bool)>,
+ ) -> Id {
+ match scope.entry(name) {
+ hash_map::Entry::Occupied(entry) => *entry.get(),
+ hash_map::Entry::Vacant(entry) => {
+ let id = id_def.id_gen.next();
+ if cfg!(debug_assertions) {
+ eprintln!("{}: {}", id.get(), entry.key().to_owned());
+ }
+ id_def.type_check.insert(id, type_);
+ entry.insert(id);
+ id
+ }
}
}
- fn get_or_add_def(&mut self, id: &'a str) -> spirv::Word {
- self.get_or_add_impl(id, None)
+ fn add_untyped_checked(&mut self, name: &'input str) -> Result<Id, TranslateError> {
+ let (id, overwrite) = self.add_untyped_impl(name);
+ if overwrite {
+ Err(TranslateError::SymbolRedefinition)
+ } else {
+ Ok(id)
+ }
}
- fn get_or_add_def_typed(
+ fn add_variable_checked(
&mut self,
- id: &'a str,
- typ: ast::Type,
- is_variable: bool,
- ) -> spirv::Word {
- self.get_or_add_impl(id, Some((typ, is_variable)))
- }
-
- fn get_or_add_impl(&mut self, id: &'a str, typ: Option<(ast::Type, bool)>) -> spirv::Word {
- let id = match self.variables.entry(Cow::Borrowed(id)) {
- hash_map::Entry::Occupied(e) => *(e.get()),
- hash_map::Entry::Vacant(e) => {
- let numeric_id = self.current_id;
- e.insert(numeric_id);
- self.current_id += 1;
- numeric_id
- }
- };
- self.variables_type_check.insert(id, typ);
- id
- }
-
- fn get_id(&self, id: &str) -> Result<spirv::Word, TranslateError> {
- self.variables
- .get(id)
- .copied()
- .ok_or(TranslateError::UnknownSymbol)
+ name: &'input str,
+ type_: ast::Type,
+ space: ast::StateSpace,
+ align: Option<u32>,
+ ) -> Result<Id, TranslateError> {
+ let id = self.0.module.id_gen.next();
+ self.0
+ .module
+ .type_check
+ .insert(id, Some((type_, space, align, true)));
+ let old = self
+ .0
+ .scopes
+ .last_mut()
+ .unwrap()
+ .insert(Cow::Borrowed(name), id);
+ if old.is_some() {
+ Err(TranslateError::SymbolRedefinition)
+ } else {
+ Ok(id)
+ }
}
- fn current_id(&self) -> spirv::Word {
- self.current_id
+ fn add_untyped_impl(&mut self, name: &'input str) -> (Id, bool) {
+ let id = self.0.module.id_gen.next();
+ self.0.module.type_check.insert(id, None);
+ let old = self
+ .0
+ .scopes
+ .last_mut()
+ .unwrap()
+ .insert(Cow::Borrowed(name), id);
+ (id, old.is_some())
}
- fn start_fn<'b>(
- &'b mut self,
- header: &'b ast::MethodDecl<'a, &'a str>,
- ) -> Result<
- (
- FnStringIdResolver<'a, 'b>,
- GlobalFnDeclResolver<'a, 'b>,
- ast::MethodDecl<'a, spirv::Word>,
- ),
- TranslateError,
- > {
- // In case a function decl was inserted earlier we want to use its id
- let name_id = self.get_or_add_def(header.name());
- let mut fn_resolver = FnStringIdResolver {
- current_id: &mut self.current_id,
- global_variables: &self.variables,
- global_type_check: &self.variables_type_check,
- special_registers: &mut self.special_registers,
- variables: vec![HashMap::new(); 1],
- type_check: HashMap::new(),
- };
- let new_fn_decl = match header {
- ast::MethodDecl::Kernel { name, in_args } => ast::MethodDecl::Kernel {
- name,
- in_args: expand_kernel_params(&mut fn_resolver, in_args.iter())?,
- },
- ast::MethodDecl::Func(ret_params, _, params) => {
- let ret_params_ids = expand_fn_params(&mut fn_resolver, ret_params.iter())?;
- let params_ids = expand_fn_params(&mut fn_resolver, params.iter())?;
- self.fns.insert(
- name_id,
- FnDecl {
- ret_vals: ret_params_ids.iter().map(|p| p.v_type.clone()).collect(),
- params: params_ids.iter().map(|p| p.v_type.clone()).collect(),
- },
- );
- ast::MethodDecl::Func(ret_params_ids, name_id, params_ids)
+ fn add_or_get_module_variable(
+ &mut self,
+ name: Cow<'input, str>,
+ use_global_scope: bool,
+ type_: ast::Type,
+ state_space: ast::StateSpace,
+ align: Option<u32>,
+ initializer: Option<ast::Initializer<Id>>,
+ ) -> Result<Variable, TranslateError> {
+ let id = if use_global_scope {
+ let id = Self::get_or_add_untyped_in_scope(
+ &mut self.0.module,
+ &mut self.0.scopes[Self::CROSS_MODULE],
+ name.clone(),
+ Some((type_.clone(), state_space, align, true)),
+ );
+ if self.0.scopes[Self::CURRENT_MODULE]
+ .insert(name.clone(), id)
+ .is_some()
+ {
+ return Err(TranslateError::unreachable());
}
+ id
+ } else {
+ Self::get_or_add_untyped_in_scope(
+ &mut self.0.module,
+ &mut self.0.scopes[Self::CURRENT_MODULE],
+ name.clone(),
+ Some((type_.clone(), state_space, align, true)),
+ )
};
- Ok((
- fn_resolver,
- GlobalFnDeclResolver {
- variables: &self.variables,
- fns: &self.fns,
- },
- new_fn_decl,
- ))
- }
-}
-
-pub struct GlobalFnDeclResolver<'input, 'a> {
- variables: &'a HashMap<Cow<'input, str>, spirv::Word>,
- fns: &'a HashMap<spirv::Word, FnDecl>,
-}
-
-impl<'input, 'a> GlobalFnDeclResolver<'input, 'a> {
- fn get_fn_decl(&self, id: spirv::Word) -> Result<&FnDecl, TranslateError> {
- self.fns.get(&id).ok_or(TranslateError::UnknownSymbol)
- }
-
- fn get_fn_decl_str(&self, id: &str) -> Result<&'a FnDecl, TranslateError> {
- match self.variables.get(id).map(|var_id| self.fns.get(var_id)) {
- Some(Some(fn_d)) => Ok(fn_d),
- _ => Err(TranslateError::UnknownSymbol),
- }
- }
-}
-
-struct FnStringIdResolver<'input, 'b> {
- current_id: &'b mut spirv::Word,
- global_variables: &'b HashMap<Cow<'input, str>, spirv::Word>,
- global_type_check: &'b HashMap<u32, Option<(ast::Type, bool)>>,
- special_registers: &'b mut SpecialRegistersMap,
- variables: Vec<HashMap<Cow<'input, str>, spirv::Word>>,
- type_check: HashMap<u32, Option<(ast::Type, bool)>>,
-}
-
-impl<'a, 'b> FnStringIdResolver<'a, 'b> {
- fn finish(self) -> NumericIdResolver<'b> {
- NumericIdResolver {
- current_id: self.current_id,
- global_type_check: self.global_type_check,
- type_check: self.type_check,
- special_registers: self.special_registers,
- }
- }
-
- fn start_block(&mut self) {
- self.variables.push(HashMap::new())
- }
-
- fn end_block(&mut self) {
- self.variables.pop();
+ self.0.module.globals.variables.insert(name.clone(), id);
+ self.0.module.globals.reverse_variables.insert(id, name);
+ Ok(Variable {
+ align,
+ type_,
+ state_space,
+ name: id,
+ initializer,
+ })
}
- fn get_id(&mut self, id: &str) -> Result<spirv::Word, TranslateError> {
- for scope in self.variables.iter().rev() {
- match scope.get(id) {
- Some(id) => return Ok(*id),
- None => continue,
- }
- }
- match self.global_variables.get(id) {
- Some(id) => Ok(*id),
- None => {
- let sreg =
- PtxSpecialRegister::try_parse(id).ok_or(TranslateError::UnknownSymbol)?;
- Ok(self.special_registers.get_or_add(self.current_id, sreg))
- }
+ fn register_variable(
+ &mut self,
+ name: Cow<'input, str>,
+ type_: ast::Type,
+ state_space: ast::StateSpace,
+ align: Option<u32>,
+ initializer: Option<ast::Initializer<Id>>,
+ ) -> Result<Variable, TranslateError> {
+ let id = self.0.module.id_gen.next();
+ self.0
+ .module
+ .type_check
+ .insert(id, Some((type_.clone(), state_space, align, true)));
+ let old = self.0.scopes.last_mut().unwrap().insert(name, id);
+ if old.is_some() {
+ Err(TranslateError::SymbolRedefinition)
+ } else {
+ Ok(Variable {
+ align,
+ type_: type_,
+ state_space,
+ name: id,
+ initializer,
+ })
}
}
- fn add_def(&mut self, id: &'a str, typ: Option<ast::Type>, is_variable: bool) -> spirv::Word {
- let numeric_id = *self.current_id;
- self.variables
- .last_mut()
- .unwrap()
- .insert(Cow::Borrowed(id), numeric_id);
- self.type_check
- .insert(numeric_id, typ.map(|t| (t, is_variable)));
- *self.current_id += 1;
- numeric_id
+ fn new_untyped(&mut self) -> Id {
+ let id = self.0.module.id_gen.next();
+ self.0.module.type_check.insert(id, None);
+ id
}
+}
- #[must_use]
- fn add_defs(
- &mut self,
- base_id: &'a str,
- count: u32,
- typ: ast::Type,
- is_variable: bool,
- ) -> impl Iterator<Item = spirv::Word> {
- let numeric_id = *self.current_id;
- for i in 0..count {
- self.variables
- .last_mut()
- .unwrap()
- .insert(Cow::Owned(format!("{}{}", base_id, i)), numeric_id + i);
- self.type_check
- .insert(numeric_id + i, Some((typ.clone(), is_variable)));
- }
- *self.current_id += count;
- (0..count).into_iter().map(move |i| i + numeric_id)
+impl<'a, 'b, 'input> Drop for StringIdResolverScope<'a, 'b, 'input> {
+ fn drop(&mut self) {
+ self.0.scopes.pop();
}
}
-struct NumericIdResolver<'b> {
- current_id: &'b mut spirv::Word,
- global_type_check: &'b HashMap<u32, Option<(ast::Type, bool)>>,
- type_check: HashMap<u32, Option<(ast::Type, bool)>>,
- special_registers: &'b mut SpecialRegistersMap,
+pub(crate) struct FunctionPointerDetails {
+ pub(crate) dst: Id,
+ pub(crate) src: Id,
}
-impl<'b> NumericIdResolver<'b> {
- fn finish(self) -> MutableNumericIdResolver<'b> {
- MutableNumericIdResolver { base: self }
- }
-
- fn get_typed(&self, id: spirv::Word) -> Result<(ast::Type, bool), TranslateError> {
- match self.type_check.get(&id) {
- Some(Some(x)) => Ok(x.clone()),
- Some(None) => Err(TranslateError::UntypedSymbol),
- None => match self.special_registers.get(id) {
- Some(x) => Ok((x.get_type(), true)),
- None => match self.global_type_check.get(&id) {
- Some(Some(result)) => Ok(result.clone()),
- Some(None) | None => Err(TranslateError::UntypedSymbol),
+impl<T: ArgParamsEx<Id = Id>, U: ArgParamsEx<Id = Id>> Visitable<T, U> for FunctionPointerDetails {
+ fn visit(
+ self,
+ visitor: &mut impl ArgumentMapVisitor<T, U>,
+ ) -> Result<Statement<ast::Instruction<U>, U>, TranslateError> {
+ Ok(Statement::FunctionPointer(FunctionPointerDetails {
+ dst: visitor.id(
+ ArgumentDescriptor {
+ op: self.dst,
+ is_dst: true,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
- },
- }
- }
-
- // This is for identifiers which will be emitted later as OpVariable
- // They are candidates for insertion of LoadVar/StoreVar
- fn new_variable(&mut self, typ: ast::Type) -> spirv::Word {
- let new_id = *self.current_id;
- self.type_check.insert(new_id, Some((typ, true)));
- *self.current_id += 1;
- new_id
- }
-
- fn new_non_variable(&mut self, typ: Option<ast::Type>) -> spirv::Word {
- let new_id = *self.current_id;
- self.type_check.insert(new_id, typ.map(|t| (t, false)));
- *self.current_id += 1;
- new_id
+ Some((
+ &ast::Type::Scalar(ast::ScalarType::U64),
+ ast::StateSpace::Reg,
+ )),
+ )?,
+ src: visitor.id(
+ ArgumentDescriptor {
+ op: self.src,
+ is_dst: false,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
+ },
+ None,
+ )?,
+ }))
}
}
-struct MutableNumericIdResolver<'b> {
- base: NumericIdResolver<'b>,
+pub(crate) struct MadCDetails<P: ast::ArgParams> {
+ pub(crate) type_: ast::ScalarType,
+ pub(crate) is_hi: bool,
+ pub(crate) arg: Arg4CarryIn<P>,
}
-impl<'b> MutableNumericIdResolver<'b> {
- fn unmut(self) -> NumericIdResolver<'b> {
- self.base
+impl<T: ArgParamsEx<Id = Id>, U: ArgParamsEx<Id = Id>> Visitable<T, U> for MadCDetails<T> {
+ fn visit(
+ self,
+ visitor: &mut impl ArgumentMapVisitor<T, U>,
+ ) -> Result<Statement<ast::Instruction<U>, U>, TranslateError> {
+ Ok(Statement::MadC(MadCDetails {
+ type_: self.type_,
+ is_hi: self.is_hi,
+ arg: self.arg.map(visitor, self.type_)?,
+ }))
}
+}
- fn get_typed(&self, id: spirv::Word) -> Result<ast::Type, TranslateError> {
- self.base.get_typed(id).map(|(t, _)| t)
- }
+pub(crate) struct MadCCDetails<P: ast::ArgParams> {
+ pub(crate) type_: ast::ScalarType,
+ pub(crate) arg: Arg4CarryOut<P>,
+}
- fn new_non_variable(&mut self, typ: ast::Type) -> spirv::Word {
- self.base.new_non_variable(Some(typ))
+impl<T: ArgParamsEx<Id = Id>, U: ArgParamsEx<Id = Id>> Visitable<T, U> for MadCCDetails<T> {
+ fn visit(
+ self,
+ visitor: &mut impl ArgumentMapVisitor<T, U>,
+ ) -> Result<Statement<ast::Instruction<U>, U>, TranslateError> {
+ Ok(Statement::MadCC(MadCCDetails {
+ type_: self.type_,
+ arg: self.arg.map(visitor, self.type_)?,
+ }))
}
}
-enum Statement<I, P: ast::ArgParams> {
- Label(u32),
- Variable(ast::Variable<ast::VariableType, P::Id>),
+pub(crate) enum Statement<I, P: ast::ArgParams> {
+ Label(Id),
+ Variable(Variable),
Instruction(I),
// SPIR-V compatible replacement for PTX predicates
Conditional(BrachCondition),
@@ -5096,13 +5594,24 @@ enum Statement<I, P: ast::ArgParams> {
StoreVar(StoreVarDetails),
Conversion(ImplicitConversion),
Constant(ConstantDefinition),
- RetValue(ast::RetData, spirv::Word),
+ RetValue(ast::RetData, Vec<(Id, ast::Type)>),
PtrAccess(PtrAccess<P>),
RepackVector(RepackVectorDetails),
+ FunctionPointer(FunctionPointerDetails),
+ MadC(MadCDetails<P>),
+ MadCC(MadCCDetails<P>),
+ AddC(ast::ScalarType, Arg3CarryIn<P>),
+ AddCC(ast::ScalarType, Arg3CarryOut<P>),
+ SubC(ast::ScalarType, Arg3CarryIn<P>),
+ SubCC(ast::ScalarType, Arg3CarryOut<P>),
+ AsmVolatile {
+ asm: &'static str,
+ constraints: &'static str,
+ },
}
impl ExpandedStatement {
- fn map_id(self, f: &mut impl FnMut(spirv::Word, bool) -> spirv::Word) -> ExpandedStatement {
+ pub(crate) fn map_id(self, f: &mut impl FnMut(Id, bool) -> Id) -> ExpandedStatement {
match self {
Statement::Label(id) => Statement::Label(f(id, false)),
Statement::Variable(mut var) => {
@@ -5110,7 +5619,8 @@ impl ExpandedStatement {
Statement::Variable(var)
}
Statement::Instruction(inst) => inst
- .visit(&mut |arg: ArgumentDescriptor<_>, _: Option<&ast::Type>| {
+ .visit(&mut |arg: ArgumentDescriptor<_>,
+ _: Option<(&ast::Type, ast::StateSpace)>| {
Ok(f(arg.op, arg.is_dst))
})
.unwrap(),
@@ -5125,16 +5635,17 @@ impl ExpandedStatement {
Statement::StoreVar(details)
}
Statement::Call(mut call) => {
- for (id, typ) in call.ret_params.iter_mut() {
- let is_dst = match typ {
- ast::FnArgumentType::Reg(_) => true,
- ast::FnArgumentType::Param(_) => false,
- ast::FnArgumentType::Shared => false,
+ for (id, _, space) in call.return_arguments.iter_mut() {
+ let is_dst = match space {
+ ast::StateSpace::Reg => true,
+ ast::StateSpace::Param => false,
+ ast::StateSpace::Shared => false,
+ _ => todo!(),
};
*id = f(*id, is_dst);
}
- call.func = f(call.func, false);
- for (id, _) in call.param_list.iter_mut() {
+ call.name = f(call.name, false);
+ for (id, _, _) in call.input_arguments.iter_mut() {
*id = f(*id, false);
}
Statement::Call(call)
@@ -5154,9 +5665,12 @@ impl ExpandedStatement {
constant.dst = f(constant.dst, true);
Statement::Constant(constant)
}
- Statement::RetValue(data, id) => {
- let id = f(id, false);
- Statement::RetValue(data, id)
+ Statement::RetValue(data, ids) => {
+ let ids = ids
+ .into_iter()
+ .map(|(id, type_)| (f(id, false), type_))
+ .collect();
+ Statement::RetValue(data, ids)
}
Statement::PtrAccess(PtrAccess {
underlying_type,
@@ -5189,39 +5703,86 @@ impl ExpandedStatement {
..repack
})
}
+ Statement::FunctionPointer(FunctionPointerDetails { dst, src }) => {
+ Statement::FunctionPointer(FunctionPointerDetails {
+ dst: f(dst, true),
+ src: f(src, false),
+ })
+ }
+ Statement::MadC(madc) => madc
+ .visit(&mut |arg: ArgumentDescriptor<_>,
+ _: Option<(&ast::Type, ast::StateSpace)>| {
+ Ok(f(arg.op, arg.is_dst))
+ })
+ .unwrap(),
+ Statement::MadCC(madcc) => madcc
+ .visit(&mut |arg: ArgumentDescriptor<_>,
+ _: Option<(&ast::Type, ast::StateSpace)>| {
+ Ok(f(arg.op, arg.is_dst))
+ })
+ .unwrap(),
+ Statement::AddC(details, arg) => VisitAddC(details, arg)
+ .visit(&mut |arg: ArgumentDescriptor<_>,
+ _: Option<(&ast::Type, ast::StateSpace)>| {
+ Ok(f(arg.op, arg.is_dst))
+ })
+ .unwrap(),
+ Statement::AddCC(details, arg) => VisitAddCC(details, arg)
+ .visit(&mut |arg: ArgumentDescriptor<_>,
+ _: Option<(&ast::Type, ast::StateSpace)>| {
+ Ok(f(arg.op, arg.is_dst))
+ })
+ .unwrap(),
+ Statement::SubC(details, arg) => VisitSubC(details, arg)
+ .visit(&mut |arg: ArgumentDescriptor<_>,
+ _: Option<(&ast::Type, ast::StateSpace)>| {
+ Ok(f(arg.op, arg.is_dst))
+ })
+ .unwrap(),
+ Statement::SubCC(details, arg) => VisitSubCC(details, arg)
+ .visit(&mut |arg: ArgumentDescriptor<_>,
+ _: Option<(&ast::Type, ast::StateSpace)>| {
+ Ok(f(arg.op, arg.is_dst))
+ })
+ .unwrap(),
+ Statement::AsmVolatile { asm, constraints } => {
+ Statement::AsmVolatile { asm, constraints }
+ }
}
}
}
-struct LoadVarDetails {
- arg: ast::Arg2<ExpandedArgParams>,
- typ: ast::Type,
+pub(crate) struct LoadVarDetails {
+ pub(crate) arg: ast::Arg2<ExpandedArgParams>,
+ pub(crate) typ: ast::Type,
+ pub(crate) _state_space: ast::StateSpace,
// (index, vector_width)
- // HACK ALERT
- // For some reason IGC explodes when you try to load from builtin vectors
- // using OpInBoundsAccessChain, the one true way to do it is to
- // OpLoad+OpCompositeExtract
- member_index: Option<(u8, Option<u8>)>,
+ pub(crate) member_index: Option<(u8, u8)>,
}
-struct StoreVarDetails {
- arg: ast::Arg2St<ExpandedArgParams>,
- typ: ast::Type,
- member_index: Option<u8>,
+pub(crate) struct StoreVarDetails {
+ pub(crate) arg: ast::Arg2St<ExpandedArgParams>,
+ pub(crate) type_: ast::Type,
+ pub(crate) member_index: Option<u8>,
}
-struct RepackVectorDetails {
- is_extract: bool,
- typ: ast::ScalarType,
- packed: spirv::Word,
- unpacked: Vec<spirv::Word>,
- vector_sema: ArgumentSemantics,
+pub(crate) struct RepackVectorDetails {
+ pub(crate) is_extract: bool,
+ pub(crate) typ: ast::ScalarType,
+ pub(crate) packed: Id,
+ pub(crate) unpacked: Vec<Id>,
+ pub(crate) non_default_implicit_conversion: Option<
+ fn(
+ (ast::StateSpace, &ast::Type),
+ (ast::StateSpace, &ast::Type),
+ ) -> Result<Option<ConversionKind>, TranslateError>,
+ >,
}
impl RepackVectorDetails {
fn map<
- From: ArgParamsEx<Id = spirv::Word>,
- To: ArgParamsEx<Id = spirv::Word>,
+ From: ArgParamsEx<Id = Id>,
+ To: ArgParamsEx<Id = Id>,
V: ArgumentMapVisitor<From, To>,
>(
self,
@@ -5231,13 +5792,17 @@ impl RepackVectorDetails {
ArgumentDescriptor {
op: self.packed,
is_dst: !self.is_extract,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
- Some(&ast::Type::Vector(self.typ, self.unpacked.len() as u8)),
+ Some((
+ &ast::Type::Vector(self.typ, self.unpacked.len() as u8),
+ ast::StateSpace::Reg,
+ )),
)?;
let scalar_type = self.typ;
let is_extract = self.is_extract;
- let vector_sema = self.vector_sema;
+ let non_default_implicit_conversion = self.non_default_implicit_conversion;
let vector = self
.unpacked
.into_iter()
@@ -5246,9 +5811,10 @@ impl RepackVectorDetails {
ArgumentDescriptor {
op: id,
is_dst: is_extract,
- sema: vector_sema,
+ is_memory_access: false,
+ non_default_implicit_conversion,
},
- Some(&ast::Type::Scalar(scalar_type)),
+ Some((&ast::Type::Scalar(scalar_type), ast::StateSpace::Reg)),
)
})
.collect::<Result<_, _>>()?;
@@ -5257,14 +5823,12 @@ impl RepackVectorDetails {
typ: self.typ,
packed: scalar,
unpacked: vector,
- vector_sema,
+ non_default_implicit_conversion,
})
}
}
-impl<T: ArgParamsEx<Id = spirv::Word>, U: ArgParamsEx<Id = spirv::Word>> Visitable<T, U>
- for RepackVectorDetails
-{
+impl<T: ArgParamsEx<Id = Id>, U: ArgParamsEx<Id = Id>> Visitable<T, U> for RepackVectorDetails {
fn visit(
self,
visitor: &mut impl ArgumentMapVisitor<T, U>,
@@ -5273,79 +5837,160 @@ impl<T: ArgParamsEx<Id = spirv::Word>, U: ArgParamsEx<Id = spirv::Word>> Visitab
}
}
-struct ResolvedCall<P: ast::ArgParams> {
+pub(crate) struct ResolvedCall<P: ast::ArgParams> {
pub uniform: bool,
- pub ret_params: Vec<(P::Id, ast::FnArgumentType)>,
- pub func: P::Id,
- pub param_list: Vec<(P::Operand, ast::FnArgumentType)>,
+ pub return_arguments: Vec<(P::Id, ast::Type, ast::StateSpace)>,
+ pub name: P::Id,
+ pub input_arguments: Vec<(P::Operand, ast::Type, ast::StateSpace)>,
+ pub is_indirect: bool,
}
impl<T: ast::ArgParams> ResolvedCall<T> {
+ fn from_declaration<'input>(
+ call: ast::CallInst<T>,
+ return_arguments: &[ast::VariableDeclaration<Id>],
+ input_arguments: &[ast::VariableDeclaration<Id>],
+ ) -> Result<Self, TranslateError> {
+ if call.ret_params.len() != return_arguments.len()
+ || call.param_list.len() != input_arguments.len()
+ {
+ return Err(TranslateError::mismatched_type());
+ }
+ let return_arguments = call
+ .ret_params
+ .into_iter()
+ .zip(return_arguments.iter())
+ .map(|(id, var_decl)| (id, var_decl.type_.clone(), var_decl.state_space))
+ .collect::<Vec<_>>();
+ let input_arguments = call
+ .param_list
+ .into_iter()
+ .zip(input_arguments.iter())
+ .map(|(id, var_decl)| (id, var_decl.type_.clone(), var_decl.state_space))
+ .collect::<Vec<_>>();
+ Ok(Self {
+ return_arguments,
+ input_arguments,
+ uniform: call.uniform,
+ name: call.func,
+ is_indirect: false,
+ })
+ }
+
+ fn from_callprototype<'input>(
+ call: ast::CallInst<T>,
+ proto: &Callprototype,
+ ) -> Result<Self, TranslateError> {
+ if call.ret_params.len() != proto.return_arguments.len()
+ || call.param_list.len() != proto.input_arguments.len()
+ {
+ return Err(TranslateError::mismatched_type());
+ }
+ let return_arguments = call
+ .ret_params
+ .into_iter()
+ .zip(proto.return_arguments.iter())
+ .map(|(id, (type_, state_space))| (id, type_.clone(), *state_space))
+ .collect::<Vec<_>>();
+ let input_arguments = call
+ .param_list
+ .into_iter()
+ .zip(proto.input_arguments.iter())
+ .map(|(id, (type_, state_space))| (id, type_.clone(), *state_space))
+ .collect::<Vec<_>>();
+ Ok(Self {
+ return_arguments,
+ input_arguments,
+ uniform: call.uniform,
+ name: call.func,
+ is_indirect: true,
+ })
+ }
+
fn cast<U: ast::ArgParams<Id = T::Id, Operand = T::Operand>>(self) -> ResolvedCall<U> {
ResolvedCall {
uniform: self.uniform,
- ret_params: self.ret_params,
- func: self.func,
- param_list: self.param_list,
+ return_arguments: self.return_arguments,
+ name: self.name,
+ input_arguments: self.input_arguments,
+ is_indirect: self.is_indirect,
}
}
}
-impl<From: ArgParamsEx<Id = spirv::Word>> ResolvedCall<From> {
- fn map<To: ArgParamsEx<Id = spirv::Word>, V: ArgumentMapVisitor<From, To>>(
+impl<From: ArgParamsEx<Id = Id>> ResolvedCall<From> {
+ fn map<To: ArgParamsEx<Id = Id>, V: ArgumentMapVisitor<From, To>>(
self,
visitor: &mut V,
) -> Result<ResolvedCall<To>, TranslateError> {
- let ret_params = self
- .ret_params
+ let return_arguments = self
+ .return_arguments
.into_iter()
- .map::<Result<_, TranslateError>, _>(|(id, typ)| {
+ .map::<Result<_, TranslateError>, _>(|(id, typ, space)| {
let new_id = visitor.id(
ArgumentDescriptor {
op: id,
- is_dst: !typ.is_param(),
- sema: typ.semantics(),
+ is_dst: space != ast::StateSpace::Param,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
- Some(&typ.to_func_type()),
+ Some((&typ, space)),
)?;
- Ok((new_id, typ))
+ Ok((new_id, typ, space))
})
.collect::<Result<Vec<_>, _>>()?;
- let func = visitor.id(
- ArgumentDescriptor {
- op: self.func,
- is_dst: false,
- sema: ArgumentSemantics::Default,
- },
- None,
- )?;
- let param_list = self
- .param_list
+ let func = if self.is_indirect {
+ visitor.id(
+ ArgumentDescriptor {
+ op: self.name,
+ is_dst: false,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
+ },
+ Some((
+ &ast::Type::Scalar(ast::ScalarType::B64),
+ ast::StateSpace::Reg,
+ )),
+ )
+ } else {
+ visitor.id(
+ ArgumentDescriptor {
+ op: self.name,
+ is_dst: false,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
+ },
+ None,
+ )
+ }?;
+ let input_arguments = self
+ .input_arguments
.into_iter()
- .map::<Result<_, TranslateError>, _>(|(id, typ)| {
+ .map::<Result<_, TranslateError>, _>(|(id, typ, space)| {
let new_id = visitor.operand(
ArgumentDescriptor {
op: id,
is_dst: false,
- sema: typ.semantics(),
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
- &typ.to_func_type(),
+ &typ,
+ space,
)?;
- Ok((new_id, typ))
+ Ok((new_id, typ, space))
})
.collect::<Result<Vec<_>, _>>()?;
Ok(ResolvedCall {
uniform: self.uniform,
- ret_params,
- func,
- param_list,
+ return_arguments,
+ name: func,
+ input_arguments,
+ is_indirect: self.is_indirect,
})
}
}
-impl<T: ArgParamsEx<Id = spirv::Word>, U: ArgParamsEx<Id = spirv::Word>> Visitable<T, U>
- for ResolvedCall<T>
-{
+impl<T: ArgParamsEx<Id = Id>, U: ArgParamsEx<Id = Id>> Visitable<T, U> for ResolvedCall<T> {
fn visit(
self,
visitor: &mut impl ArgumentMapVisitor<T, U>,
@@ -5354,44 +5999,38 @@ impl<T: ArgParamsEx<Id = spirv::Word>, U: ArgParamsEx<Id = spirv::Word>> Visitab
}
}
-impl<P: ArgParamsEx<Id = spirv::Word>> PtrAccess<P> {
- fn map<To: ArgParamsEx<Id = spirv::Word>, V: ArgumentMapVisitor<P, To>>(
+impl<P: ArgParamsEx<Id = Id>> PtrAccess<P> {
+ fn map<To: ArgParamsEx<Id = Id>, V: ArgumentMapVisitor<P, To>>(
self,
visitor: &mut V,
) -> Result<PtrAccess<To>, TranslateError> {
- let sema = match self.state_space {
- ast::LdStateSpace::Const
- | ast::LdStateSpace::Global
- | ast::LdStateSpace::Shared
- | ast::LdStateSpace::Generic => ArgumentSemantics::PhysicalPointer,
- ast::LdStateSpace::Local | ast::LdStateSpace::Param => {
- ArgumentSemantics::RegisterPointer
- }
- };
- let ptr_type = ast::Type::Pointer(self.underlying_type.clone(), self.state_space);
let new_dst = visitor.id(
ArgumentDescriptor {
op: self.dst,
is_dst: true,
- sema,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
- Some(&ptr_type),
+ Some((&self.underlying_type, self.state_space)),
)?;
let new_ptr_src = visitor.id(
ArgumentDescriptor {
op: self.ptr_src,
is_dst: false,
- sema,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
- Some(&ptr_type),
+ Some((&self.underlying_type, self.state_space)),
)?;
let new_constant_src = visitor.operand(
ArgumentDescriptor {
op: self.offset_src,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
&ast::Type::Scalar(ast::ScalarType::S64),
+ ast::StateSpace::Reg,
)?;
Ok(PtrAccess {
underlying_type: self.underlying_type,
@@ -5403,9 +6042,7 @@ impl<P: ArgParamsEx<Id = spirv::Word>> PtrAccess<P> {
}
}
-impl<T: ArgParamsEx<Id = spirv::Word>, U: ArgParamsEx<Id = spirv::Word>> Visitable<T, U>
- for PtrAccess<T>
-{
+impl<T: ArgParamsEx<Id = Id>, U: ArgParamsEx<Id = Id>> Visitable<T, U> for PtrAccess<T> {
fn visit(
self,
visitor: &mut impl ArgumentMapVisitor<T, U>,
@@ -5414,41 +6051,22 @@ impl<T: ArgParamsEx<Id = spirv::Word>, U: ArgParamsEx<Id = spirv::Word>> Visitab
}
}
-pub trait ArgParamsEx: ast::ArgParams + Sized {
- fn get_fn_decl<'x, 'b>(
- id: &Self::Id,
- decl: &'b GlobalFnDeclResolver<'x, 'b>,
- ) -> Result<&'b FnDecl, TranslateError>;
-}
+pub trait ArgParamsEx: ast::ArgParams + Sized {}
-impl<'input> ArgParamsEx for ast::ParsedArgParams<'input> {
- fn get_fn_decl<'x, 'b>(
- id: &Self::Id,
- decl: &'b GlobalFnDeclResolver<'x, 'b>,
- ) -> Result<&'b FnDecl, TranslateError> {
- decl.get_fn_decl_str(id)
- }
-}
+impl<'input> ArgParamsEx for ast::ParsedArgParams<'input> {}
-enum NormalizedArgParams {}
+pub(crate) enum NormalizedArgParams {}
impl ast::ArgParams for NormalizedArgParams {
- type Id = spirv::Word;
- type Operand = ast::Operand<spirv::Word>;
+ type Id = Id;
+ type Operand = ast::Operand<Id>;
}
-impl ArgParamsEx for NormalizedArgParams {
- fn get_fn_decl<'a, 'b>(
- id: &Self::Id,
- decl: &'b GlobalFnDeclResolver<'a, 'b>,
- ) -> Result<&'b FnDecl, TranslateError> {
- decl.get_fn_decl(*id)
- }
-}
+impl ArgParamsEx for NormalizedArgParams {}
type NormalizedStatement = Statement<
(
- Option<ast::PredAt<spirv::Word>>,
+ Option<ast::PredAt<Id>>,
ast::Instruction<NormalizedArgParams>,
),
NormalizedArgParams,
@@ -5456,119 +6074,112 @@ type NormalizedStatement = Statement<
type UnconditionalStatement = Statement<ast::Instruction<NormalizedArgParams>, NormalizedArgParams>;
-enum TypedArgParams {}
+pub(crate) enum TypedArgParams {}
impl ast::ArgParams for TypedArgParams {
- type Id = spirv::Word;
+ type Id = Id;
type Operand = TypedOperand;
}
-impl ArgParamsEx for TypedArgParams {
- fn get_fn_decl<'a, 'b>(
- id: &Self::Id,
- decl: &'b GlobalFnDeclResolver<'a, 'b>,
- ) -> Result<&'b FnDecl, TranslateError> {
- decl.get_fn_decl(*id)
- }
-}
+impl ArgParamsEx for TypedArgParams {}
#[derive(Copy, Clone)]
-enum TypedOperand {
- Reg(spirv::Word),
- RegOffset(spirv::Word, i32),
+pub(crate) enum TypedOperand {
+ Reg(Id),
+ RegOffset(Id, i64),
Imm(ast::ImmediateValue),
- VecMember(spirv::Word, u8),
-}
-
-impl TypedOperand {
- fn upcast(self) -> ast::Operand<spirv::Word> {
- match self {
- TypedOperand::Reg(reg) => ast::Operand::Reg(reg),
- TypedOperand::RegOffset(reg, idx) => ast::Operand::RegOffset(reg, idx),
- TypedOperand::Imm(x) => ast::Operand::Imm(x),
- TypedOperand::VecMember(vec, idx) => ast::Operand::VecMember(vec, idx),
- }
- }
+ VecMember(Id, u8),
}
type TypedStatement = Statement<ast::Instruction<TypedArgParams>, TypedArgParams>;
-enum ExpandedArgParams {}
-type ExpandedStatement = Statement<ast::Instruction<ExpandedArgParams>, ExpandedArgParams>;
+pub(crate) enum ExpandedArgParams {}
+pub(crate) type ExpandedStatement =
+ Statement<ast::Instruction<ExpandedArgParams>, ExpandedArgParams>;
impl ast::ArgParams for ExpandedArgParams {
- type Id = spirv::Word;
- type Operand = spirv::Word;
+ type Id = Id;
+ type Operand = Id;
}
-impl ArgParamsEx for ExpandedArgParams {
- fn get_fn_decl<'a, 'b>(
- id: &Self::Id,
- decl: &'b GlobalFnDeclResolver<'a, 'b>,
- ) -> Result<&'b FnDecl, TranslateError> {
- decl.get_fn_decl(*id)
- }
+impl ArgParamsEx for ExpandedArgParams {}
+
+pub(crate) type Directive<'input> = TranslationDirective<'input, ExpandedArgParams>;
+
+pub(crate) enum TranslationDirective<'input, P: ast::ArgParams> {
+ Variable(ast::LinkingDirective, Option<Cow<'input, str>>, Variable),
+ Method(TranslationMethod<'input, P>),
}
-enum Directive<'input> {
- Variable(ast::Variable<ast::VariableType, spirv::Word>),
- Method(Function<'input>),
+pub(crate) struct Variable {
+ pub align: Option<u32>,
+ pub type_: ast::Type,
+ pub state_space: ast::StateSpace,
+ pub name: Id,
+ pub initializer: Option<ast::Initializer<Id>>,
}
-struct Function<'input> {
- pub func_decl: ast::MethodDecl<'input, spirv::Word>,
- pub spirv_decl: SpirvMethodDecl<'input>,
- pub globals: Vec<ast::Variable<ast::VariableType, spirv::Word>>,
- pub body: Option<Vec<ExpandedStatement>>,
- import_as: Option<String>,
+pub(crate) type Function<'input> = TranslationMethod<'input, ExpandedArgParams>;
+
+pub(crate) struct TranslationMethod<'input, P: ast::ArgParams> {
+ pub(crate) return_arguments: Vec<ast::VariableDeclaration<P::Id>>,
+ pub(crate) name: P::Id,
+ pub(crate) input_arguments: Vec<ast::VariableDeclaration<P::Id>>,
+ pub(crate) body: Option<Vec<Statement<ast::Instruction<P>, P>>>,
+ pub(crate) tuning: Vec<ast::TuningDirective>,
+ pub(crate) is_kernel: bool,
+ pub(crate) source_name: Option<Cow<'input, str>>,
+ pub(crate) special_raytracing_linking: bool,
}
-pub trait ArgumentMapVisitor<T: ArgParamsEx, U: ArgParamsEx> {
+pub(crate) trait ArgumentMapVisitor<T: ArgParamsEx, U: ArgParamsEx> {
fn id(
&mut self,
desc: ArgumentDescriptor<T::Id>,
- typ: Option<&ast::Type>,
+ typ: Option<(&ast::Type, ast::StateSpace)>,
) -> Result<U::Id, TranslateError>;
fn operand(
&mut self,
desc: ArgumentDescriptor<T::Operand>,
typ: &ast::Type,
+ state_space: ast::StateSpace,
) -> Result<U::Operand, TranslateError>;
}
impl<T> ArgumentMapVisitor<ExpandedArgParams, ExpandedArgParams> for T
where
T: FnMut(
- ArgumentDescriptor<spirv::Word>,
- Option<&ast::Type>,
- ) -> Result<spirv::Word, TranslateError>,
+ ArgumentDescriptor<Id>,
+ Option<(&ast::Type, ast::StateSpace)>,
+ ) -> Result<Id, TranslateError>,
{
fn id(
&mut self,
- desc: ArgumentDescriptor<spirv::Word>,
- t: Option<&ast::Type>,
- ) -> Result<spirv::Word, TranslateError> {
+ desc: ArgumentDescriptor<Id>,
+ t: Option<(&ast::Type, ast::StateSpace)>,
+ ) -> Result<Id, TranslateError> {
self(desc, t)
}
fn operand(
&mut self,
- desc: ArgumentDescriptor<spirv::Word>,
+ desc: ArgumentDescriptor<Id>,
typ: &ast::Type,
- ) -> Result<spirv::Word, TranslateError> {
- self(desc, Some(typ))
+ state_space: ast::StateSpace,
+ ) -> Result<Id, TranslateError> {
+ self(desc, Some((typ, state_space)))
}
}
impl<'a, T> ArgumentMapVisitor<ast::ParsedArgParams<'a>, NormalizedArgParams> for T
where
- T: FnMut(&str) -> Result<spirv::Word, TranslateError>,
+ T: FnMut(&str) -> Result<Id, TranslateError>,
{
fn id(
&mut self,
desc: ArgumentDescriptor<&str>,
- _: Option<&ast::Type>,
- ) -> Result<spirv::Word, TranslateError> {
+ _: Option<(&ast::Type, ast::StateSpace)>,
+ ) -> Result<Id, TranslateError> {
self(desc.op)
}
@@ -5576,7 +6187,8 @@ where
&mut self,
desc: ArgumentDescriptor<ast::Operand<&str>>,
typ: &ast::Type,
- ) -> Result<ast::Operand<spirv::Word>, TranslateError> {
+ state_space: ast::StateSpace,
+ ) -> Result<ast::Operand<Id>, TranslateError> {
Ok(match desc.op {
ast::Operand::Reg(id) => ast::Operand::Reg(self(id)?),
ast::Operand::RegOffset(id, imm) => ast::Operand::RegOffset(self(id)?, imm),
@@ -5584,39 +6196,37 @@ where
ast::Operand::VecMember(id, member) => ast::Operand::VecMember(self(id)?, member),
ast::Operand::VecPack(ref ids) => ast::Operand::VecPack(
ids.into_iter()
- .map(|id| self.id(desc.new_op(id), Some(typ)))
+ .map(|id_or_immediate| {
+ Ok::<_, TranslateError>(match id_or_immediate {
+ ast::RegOrImmediate::Reg(reg) => ast::RegOrImmediate::Reg(
+ self.id(desc.new_op(reg), Some((typ, state_space)))?,
+ ),
+ ast::RegOrImmediate::Imm(imm) => ast::RegOrImmediate::Imm(*imm),
+ })
+ })
.collect::<Result<Vec<_>, _>>()?,
),
})
}
}
-pub struct ArgumentDescriptor<Op> {
- op: Op,
- is_dst: bool,
- sema: ArgumentSemantics,
+pub(crate) struct ArgumentDescriptor<Op> {
+ pub(crate) op: Op,
+ pub(crate) is_dst: bool,
+ pub(crate) is_memory_access: bool,
+ pub(crate) non_default_implicit_conversion: Option<
+ fn(
+ (ast::StateSpace, &ast::Type),
+ (ast::StateSpace, &ast::Type),
+ ) -> Result<Option<ConversionKind>, TranslateError>,
+ >,
}
-
-pub struct PtrAccess<P: ast::ArgParams> {
- underlying_type: ast::PointerType,
- state_space: ast::LdStateSpace,
- dst: spirv::Word,
- ptr_src: spirv::Word,
- offset_src: P::Operand,
-}
-
-#[derive(Copy, Clone, PartialEq, Eq, Debug)]
-pub enum ArgumentSemantics {
- // normal register access
- Default,
- // normal register access with relaxed conversion rules (ld/st)
- DefaultRelaxed,
- // st/ld global
- PhysicalPointer,
- // st/ld .param, .local
- RegisterPointer,
- // mov of .local/.global variables
- Address,
+pub(crate) struct PtrAccess<P: ast::ArgParams> {
+ pub(crate) underlying_type: ast::Type,
+ pub(crate) state_space: ast::StateSpace,
+ pub(crate) dst: Id,
+ pub(crate) ptr_src: Id,
+ pub(crate) offset_src: P::Operand,
}
impl<T> ArgumentDescriptor<T> {
@@ -5624,7 +6234,8 @@ impl<T> ArgumentDescriptor<T> {
ArgumentDescriptor {
op: u,
is_dst: self.is_dst,
- sema: self.sema,
+ is_memory_access: self.is_memory_access,
+ non_default_implicit_conversion: self.non_default_implicit_conversion,
}
}
}
@@ -5639,7 +6250,7 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
ast::Instruction::Abs(d, arg.map(visitor, &ast::Type::Scalar(d.typ))?)
}
// Call instruction is converted to a call statement early on
- ast::Instruction::Call(_) => return Err(error_unreachable()),
+ ast::Instruction::Call(_) => return Err(TranslateError::unreachable()),
ast::Instruction::Ld(d, a) => {
let new_args = a.map(visitor, &d)?;
ast::Instruction::Ld(d, new_args)
@@ -5651,53 +6262,44 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
ast::Instruction::Mul(d, a) => {
let inst_type = d.get_type();
let is_wide = d.is_wide();
- ast::Instruction::Mul(d, a.map_non_shift(visitor, &inst_type, is_wide)?)
+ ast::Instruction::Mul(d, a.map_generic(visitor, &inst_type, is_wide)?)
}
ast::Instruction::Add(d, a) => {
- let inst_type = d.get_type();
- ast::Instruction::Add(d, a.map_non_shift(visitor, &inst_type, false)?)
+ let inst_type = ast::Type::Scalar(d.get_type());
+ ast::Instruction::Add(d, a.map_generic(visitor, &inst_type, false)?)
}
ast::Instruction::Setp(d, a) => {
let inst_type = d.typ;
ast::Instruction::Setp(d, a.map(visitor, &ast::Type::Scalar(inst_type))?)
}
ast::Instruction::SetpBool(d, a) => {
- let inst_type = d.typ;
+ let inst_type = d.base.typ;
ast::Instruction::SetpBool(d, a.map(visitor, &ast::Type::Scalar(inst_type))?)
}
- ast::Instruction::Not(t, a) => ast::Instruction::Not(t, a.map(visitor, &t.to_type())?),
+ ast::Instruction::Not(t, a) => {
+ ast::Instruction::Not(t, a.map(visitor, &ast::Type::Scalar(t))?)
+ }
ast::Instruction::Cvt(d, a) => {
- let (dst_t, src_t) = match &d {
- ast::CvtDetails::FloatFromFloat(desc) => (
- ast::Type::Scalar(desc.dst.into()),
- ast::Type::Scalar(desc.src.into()),
- ),
- ast::CvtDetails::FloatFromInt(desc) => (
- ast::Type::Scalar(desc.dst.into()),
- ast::Type::Scalar(desc.src.into()),
- ),
- ast::CvtDetails::IntFromFloat(desc) => (
- ast::Type::Scalar(desc.dst.into()),
- ast::Type::Scalar(desc.src.into()),
- ),
- ast::CvtDetails::IntFromInt(desc) => (
- ast::Type::Scalar(desc.dst.into()),
- ast::Type::Scalar(desc.src.into()),
- ),
+ let (dst_t, src_t, int_to_int) = match &d {
+ ast::CvtDetails::FloatFromFloat(desc) => (desc.dst, desc.src, false),
+ ast::CvtDetails::FloatFromInt(desc) => (desc.dst, desc.src, false),
+ ast::CvtDetails::IntFromFloat(desc) => (desc.dst, desc.src, false),
+ ast::CvtDetails::IntFromInt(desc) => (desc.dst, desc.src, true),
};
- ast::Instruction::Cvt(d, a.map_different_types(visitor, &dst_t, &src_t)?)
+ ast::Instruction::Cvt(d, a.map_cvt(visitor, dst_t, src_t, int_to_int)?)
}
ast::Instruction::Shl(t, a) => {
- ast::Instruction::Shl(t, a.map_shift(visitor, &t.to_type())?)
+ ast::Instruction::Shl(t, a.map_shift(visitor, &ast::Type::Scalar(t))?)
}
ast::Instruction::Shr(t, a) => {
ast::Instruction::Shr(t, a.map_shift(visitor, &ast::Type::Scalar(t.into()))?)
}
ast::Instruction::St(d, a) => {
- let new_args = a.map(visitor, &d)?;
+ let new_args = a.map(visitor, &d.typ, d.state_space)?;
ast::Instruction::St(d, new_args)
}
- ast::Instruction::Bra(d, a) => ast::Instruction::Bra(d, a.map(visitor, None)?),
+ ast::Instruction::Bra(d, a) => ast::Instruction::Bra(d, a.map(visitor, false, None)?),
+ ast::Instruction::Exit => ast::Instruction::Exit,
ast::Instruction::Ret(d) => ast::Instruction::Ret(d),
ast::Instruction::Cvta(d, a) => {
let inst_type = ast::Type::Scalar(ast::ScalarType::B64);
@@ -5708,36 +6310,37 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
let is_wide = d.is_wide();
ast::Instruction::Mad(d, a.map(visitor, &inst_type, is_wide)?)
}
+ ast::Instruction::Fma(d, a) => {
+ let inst_type = ast::Type::Scalar(d.typ);
+ ast::Instruction::Fma(d, a.map(visitor, &inst_type, false)?)
+ }
ast::Instruction::Or(t, a) => ast::Instruction::Or(
t,
- a.map_non_shift(visitor, &ast::Type::Scalar(t.into()), false)?,
+ a.map_generic(visitor, &ast::Type::Scalar(t.into()), false)?,
),
ast::Instruction::Sub(d, a) => {
- let typ = d.get_type();
- ast::Instruction::Sub(d, a.map_non_shift(visitor, &typ, false)?)
+ let typ = ast::Type::Scalar(d.get_type());
+ ast::Instruction::Sub(d, a.map_generic(visitor, &typ, false)?)
}
ast::Instruction::Min(d, a) => {
- let typ = d.get_type();
- ast::Instruction::Min(d, a.map_non_shift(visitor, &typ, false)?)
+ let typ = ast::Type::Scalar(d.get_type());
+ ast::Instruction::Min(d, a.map_generic(visitor, &typ, false)?)
}
ast::Instruction::Max(d, a) => {
- let typ = d.get_type();
- ast::Instruction::Max(d, a.map_non_shift(visitor, &typ, false)?)
+ let typ = ast::Type::Scalar(d.get_type());
+ ast::Instruction::Max(d, a.map_generic(visitor, &typ, false)?)
}
ast::Instruction::Rcp(d, a) => {
- let typ = ast::Type::Scalar(if d.is_f64 {
- ast::ScalarType::F64
- } else {
- ast::ScalarType::F32
- });
+ let typ = ast::Type::Scalar(d.type_);
ast::Instruction::Rcp(d, a.map(visitor, &typ)?)
}
ast::Instruction::And(t, a) => ast::Instruction::And(
t,
- a.map_non_shift(visitor, &ast::Type::Scalar(t.into()), false)?,
+ a.map_generic(visitor, &ast::Type::Scalar(t.into()), false)?,
),
ast::Instruction::Selp(t, a) => ast::Instruction::Selp(t, a.map_selp(visitor, t)?),
ast::Instruction::Bar(d, a) => ast::Instruction::Bar(d, a.map(visitor)?),
+ ast::Instruction::BarWarp(d, a) => ast::Instruction::BarWarp(d, a.map(visitor)?),
ast::Instruction::Atom(d, a) => {
ast::Instruction::Atom(d, a.map_atom(visitor, d.inner.get_type(), d.space)?)
}
@@ -5745,10 +6348,11 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
ast::Instruction::AtomCas(d, a.map_atom(visitor, d.typ, d.space)?)
}
ast::Instruction::Div(d, a) => {
- ast::Instruction::Div(d, a.map_non_shift(visitor, &d.get_type(), false)?)
+ ast::Instruction::Div(d, a.map_generic(visitor, &d.get_type(), false)?)
}
ast::Instruction::Sqrt(d, a) => {
- ast::Instruction::Sqrt(d, a.map(visitor, &ast::Type::Scalar(d.typ.into()))?)
+ let type_ = ast::Type::Scalar(d.type_);
+ ast::Instruction::Sqrt(d, a.map(visitor, &type_)?)
}
ast::Instruction::Rsqrt(d, a) => {
ast::Instruction::Rsqrt(d, a.map(visitor, &ast::Type::Scalar(d.typ.into()))?)
@@ -5792,6 +6396,14 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
arg: arg.map_different_types(visitor, &dst_type, &src_type)?,
}
}
+ ast::Instruction::Bfind(details, arg) => {
+ let dst_type = ast::Type::Scalar(ast::ScalarType::B32);
+ let src_type = ast::Type::Scalar(details.type_);
+ ast::Instruction::Bfind(
+ details,
+ arg.map_different_types(visitor, &dst_type, &src_type)?,
+ )
+ }
ast::Instruction::Brev { typ, arg } => {
let full_type = ast::Type::Scalar(typ.into());
ast::Instruction::Brev {
@@ -5811,7 +6423,7 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
let full_type = ast::Type::Scalar(typ.into());
ast::Instruction::Xor {
typ,
- arg: arg.map_non_shift(visitor, &full_type, false)?,
+ arg: arg.map_generic(visitor, &full_type, false)?,
}
}
ast::Instruction::Bfe { typ, arg } => {
@@ -5821,13 +6433,167 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
arg: arg.map_bfe(visitor, &full_type)?,
}
}
+ ast::Instruction::Bfi { typ, arg } => {
+ let full_type = ast::Type::Scalar(typ.into());
+ ast::Instruction::Bfi {
+ typ,
+ arg: arg.map_bfi(visitor, &full_type)?,
+ }
+ }
ast::Instruction::Rem { typ, arg } => {
let full_type = ast::Type::Scalar(typ.into());
ast::Instruction::Rem {
typ,
- arg: arg.map_non_shift(visitor, &full_type, false)?,
+ arg: arg.map_generic(visitor, &full_type, false)?,
}
}
+ ast::Instruction::Prmt { control, arg } => ast::Instruction::Prmt {
+ control,
+ arg: arg.map_prmt(visitor)?,
+ },
+ ast::Instruction::PrmtSlow { control, arg } => ast::Instruction::PrmtSlow {
+ arg: arg.map_prmt(visitor)?,
+ control: ast::Arg1 { src: control }
+ .map(
+ visitor,
+ false,
+ Some((
+ &ast::Type::Scalar(ast::ScalarType::B32),
+ ast::StateSpace::Reg,
+ )),
+ )?
+ .src,
+ },
+ ast::Instruction::Activemask { arg } => ast::Instruction::Activemask {
+ arg: arg.map(
+ visitor,
+ true,
+ Some((
+ &ast::Type::Scalar(ast::ScalarType::B32),
+ ast::StateSpace::Reg,
+ )),
+ )?,
+ },
+ ast::Instruction::Membar { level } => ast::Instruction::Membar { level },
+ ast::Instruction::MadC {
+ type_,
+ arg,
+ is_hi,
+ carry_out,
+ } => ast::Instruction::MadC {
+ type_,
+ is_hi,
+ carry_out,
+ arg: arg.map(visitor, &ast::Type::Scalar(type_), false)?,
+ },
+ ast::Instruction::MadCC { type_, arg } => ast::Instruction::MadCC {
+ type_,
+ arg: arg.map(visitor, &ast::Type::Scalar(type_), false)?,
+ },
+ ast::Instruction::Tex(details, arg) => {
+ let image_type_space = if details.direct {
+ (ast::Type::Texref, ast::StateSpace::Global)
+ } else {
+ (
+ ast::Type::Scalar(ast::ScalarType::B64),
+ ast::StateSpace::Reg,
+ )
+ };
+ let arg = arg.map(
+ visitor,
+ image_type_space,
+ details.geometry,
+ ast::Type::Vector(details.channel_type, 4),
+ details.coordinate_type,
+ )?;
+ ast::Instruction::Tex(details, arg)
+ }
+ ast::Instruction::Suld(details, arg) => {
+ let arg = arg.map(
+ visitor,
+ (ast::Type::Surfref, ast::StateSpace::Global),
+ details.geometry,
+ details.value_type(),
+ ast::ScalarType::B32,
+ )?;
+ ast::Instruction::Suld(details, arg)
+ }
+ ast::Instruction::Sust(details, arg) => {
+ let arg = arg.map(visitor, &details)?;
+ ast::Instruction::Sust(details, arg)
+ }
+ ast::Instruction::Shfl(mode, arg) => {
+ let arg = arg.map(visitor)?;
+ ast::Instruction::Shfl(mode, arg)
+ }
+ ast::Instruction::Shf(details, arg) => {
+ let arg = arg.map(visitor, &ast::Type::Scalar(ast::ScalarType::B32), false)?;
+ ast::Instruction::Shf(details, arg)
+ }
+ ast::Instruction::Vote(details, arg) => {
+ let arg = arg.map_vote(visitor, details.mode)?;
+ ast::Instruction::Vote(details, arg)
+ }
+ ast::Instruction::BarRed(details, arg) => {
+ let arg = arg.map_bar_red(visitor, details)?;
+ ast::Instruction::BarRed(details, arg)
+ }
+ ast::Instruction::Trap => ast::Instruction::Trap,
+ ast::Instruction::Brkpt => ast::Instruction::Brkpt,
+ ast::Instruction::AddC(details, arg) => {
+ let arg = arg.map_generic(visitor, &ast::Type::Scalar(details.type_), false)?;
+ ast::Instruction::AddC(details, arg)
+ }
+ ast::Instruction::AddCC(type_, arg) => {
+ let arg = arg.map_generic(visitor, &ast::Type::Scalar(type_), false)?;
+ ast::Instruction::AddCC(type_, arg)
+ }
+ ast::Instruction::SubC(details, arg) => {
+ let arg = arg.map_generic(visitor, &ast::Type::Scalar(details.type_), false)?;
+ ast::Instruction::SubC(details, arg)
+ }
+ ast::Instruction::SubCC(type_, arg) => {
+ let arg = arg.map_generic(visitor, &ast::Type::Scalar(type_), false)?;
+ ast::Instruction::SubCC(type_, arg)
+ }
+ ast::Instruction::Vshr(arg) => ast::Instruction::Vshr(arg.map(
+ visitor,
+ &ast::Type::Scalar(ast::ScalarType::U32),
+ false,
+ )?),
+ ast::Instruction::Set(details, arg) => {
+ let arg = arg.map_different_types(
+ visitor,
+ &ast::Type::Scalar(details.dst_type),
+ &ast::Type::Scalar(details.src_type),
+ )?;
+ ast::Instruction::Set(details, arg)
+ }
+ ast::Instruction::Dp4a(type_, arg) => {
+ let arg = arg.map(visitor, &ast::Type::Scalar(type_), false)?;
+ ast::Instruction::Dp4a(type_, arg)
+ }
+ ast::Instruction::MatchAny(arg) => {
+ let arg =
+ arg.map_generic(visitor, &ast::Type::Scalar(ast::ScalarType::B32), false)?;
+ ast::Instruction::MatchAny(arg)
+ }
+ ast::Instruction::Red(details, args) => {
+ let args = args.map(
+ visitor,
+ &ast::Type::Scalar(details.inner.get_type()),
+ details.space,
+ )?;
+ ast::Instruction::Red(details, args)
+ }
+ ast::Instruction::Nanosleep(a) => ast::Instruction::Nanosleep(a.map(
+ visitor,
+ false,
+ Some((
+ &ast::Type::Scalar(ast::ScalarType::U32),
+ ast::StateSpace::Reg,
+ )),
+ )?),
})
}
}
@@ -5842,11 +6608,7 @@ impl<T: ArgParamsEx, U: ArgParamsEx> Visitable<T, U> for ast::Instruction<T> {
}
impl ImplicitConversion {
- fn map<
- T: ArgParamsEx<Id = spirv::Word>,
- U: ArgParamsEx<Id = spirv::Word>,
- V: ArgumentMapVisitor<T, U>,
- >(
+ fn map<T: ArgParamsEx<Id = Id>, U: ArgParamsEx<Id = Id>, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
) -> Result<Statement<ast::Instruction<U>, U>, TranslateError> {
@@ -5854,17 +6616,19 @@ impl ImplicitConversion {
ArgumentDescriptor {
op: self.dst,
is_dst: true,
- sema: self.dst_sema,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
- Some(&self.to),
+ Some((&self.to_type, self.to_space)),
)?;
let new_src = visitor.id(
ArgumentDescriptor {
op: self.src,
is_dst: false,
- sema: self.src_sema,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
- Some(&self.from),
+ Some((&self.from_type, self.from_space)),
)?;
Ok(Statement::Conversion({
ImplicitConversion {
@@ -5876,7 +6640,13 @@ impl ImplicitConversion {
}
}
-impl<From: ArgParamsEx<Id = spirv::Word>, To: ArgParamsEx<Id = spirv::Word>> Visitable<From, To>
+#[derive(Copy, Clone)]
+pub(crate) enum FPDenormMode {
+ FlushToZero,
+ Preserve,
+}
+
+impl<From: ArgParamsEx<Id = Id>, To: ArgParamsEx<Id = Id>> Visitable<From, To>
for ImplicitConversion
{
fn visit(
@@ -5890,15 +6660,15 @@ impl<From: ArgParamsEx<Id = spirv::Word>, To: ArgParamsEx<Id = spirv::Word>> Vis
impl<T> ArgumentMapVisitor<TypedArgParams, TypedArgParams> for T
where
T: FnMut(
- ArgumentDescriptor<spirv::Word>,
- Option<&ast::Type>,
- ) -> Result<spirv::Word, TranslateError>,
+ ArgumentDescriptor<Id>,
+ Option<(&ast::Type, ast::StateSpace)>,
+ ) -> Result<Id, TranslateError>,
{
fn id(
&mut self,
- desc: ArgumentDescriptor<spirv::Word>,
- t: Option<&ast::Type>,
- ) -> Result<spirv::Word, TranslateError> {
+ desc: ArgumentDescriptor<Id>,
+ t: Option<(&ast::Type, ast::StateSpace)>,
+ ) -> Result<Id, TranslateError> {
self(desc, t)
}
@@ -5906,106 +6676,101 @@ where
&mut self,
desc: ArgumentDescriptor<TypedOperand>,
typ: &ast::Type,
+ state_space: ast::StateSpace,
) -> Result<TypedOperand, TranslateError> {
Ok(match desc.op {
- TypedOperand::Reg(id) => TypedOperand::Reg(self(desc.new_op(id), Some(typ))?),
+ TypedOperand::Reg(id) => {
+ TypedOperand::Reg(self(desc.new_op(id), Some((typ, state_space)))?)
+ }
TypedOperand::Imm(imm) => TypedOperand::Imm(imm),
TypedOperand::RegOffset(id, imm) => {
- TypedOperand::RegOffset(self(desc.new_op(id), Some(typ))?, imm)
+ TypedOperand::RegOffset(self(desc.new_op(id), Some((typ, state_space)))?, imm)
}
TypedOperand::VecMember(reg, index) => {
let scalar_type = match typ {
ast::Type::Scalar(scalar_t) => *scalar_t,
- _ => return Err(error_unreachable()),
+ _ => return Err(TranslateError::unreachable()),
};
let vec_type = ast::Type::Vector(scalar_type, index + 1);
- TypedOperand::VecMember(self(desc.new_op(reg), Some(&vec_type))?, index)
+ TypedOperand::VecMember(
+ self(desc.new_op(reg), Some((&vec_type, state_space)))?,
+ index,
+ )
}
})
}
}
impl ast::Type {
- fn widen(self) -> Result<Self, TranslateError> {
+ pub(crate) fn widen(self) -> Result<Self, TranslateError> {
match self {
- ast::Type::Scalar(scalar) => {
- let kind = scalar.kind();
- let width = scalar.size_of();
- if (kind != ScalarKind::Signed
- && kind != ScalarKind::Unsigned
- && kind != ScalarKind::Bit)
- || (width == 8)
- {
- return Err(TranslateError::MismatchedType);
- }
- Ok(ast::Type::Scalar(ast::ScalarType::from_parts(
- width * 2,
- kind,
- )))
- }
- _ => Err(error_unreachable()),
+ ast::Type::Scalar(scalar) => Ok(ast::Type::Scalar(scalar.widen()?)),
+ _ => Err(TranslateError::unreachable()),
}
}
- fn to_parts(&self) -> TypeParts {
+ pub(crate) fn to_parts(&self) -> TypeParts {
+ let width = self.layout().size() as u8;
match self {
ast::Type::Scalar(scalar) => TypeParts {
kind: TypeKind::Scalar,
+ state_space: ast::StateSpace::Reg,
scalar_kind: scalar.kind(),
- width: scalar.size_of(),
components: Vec::new(),
- state_space: ast::LdStateSpace::Global,
+ width,
},
ast::Type::Vector(scalar, components) => TypeParts {
kind: TypeKind::Vector,
+ state_space: ast::StateSpace::Reg,
scalar_kind: scalar.kind(),
- width: scalar.size_of(),
components: vec![*components as u32],
- state_space: ast::LdStateSpace::Global,
+ width,
},
ast::Type::Array(scalar, components) => TypeParts {
kind: TypeKind::Array,
+ state_space: ast::StateSpace::Reg,
scalar_kind: scalar.kind(),
- width: scalar.size_of(),
components: components.clone(),
- state_space: ast::LdStateSpace::Global,
+ width,
},
- ast::Type::Pointer(ast::PointerType::Scalar(scalar), state_space) => TypeParts {
- kind: TypeKind::PointerScalar,
+ ast::Type::Pointer(scalar, space) => TypeParts {
+ kind: TypeKind::Pointer,
+ state_space: *space,
scalar_kind: scalar.kind(),
- width: scalar.size_of(),
components: Vec::new(),
- state_space: *state_space,
+ width,
},
- ast::Type::Pointer(ast::PointerType::Vector(scalar, len), state_space) => TypeParts {
- kind: TypeKind::PointerVector,
- scalar_kind: scalar.kind(),
- width: scalar.size_of(),
- components: vec![*len as u32],
- state_space: *state_space,
+ ast::Type::Texref => TypeParts {
+ kind: TypeKind::Texref,
+ state_space: ast::StateSpace::Global,
+ scalar_kind: ast::ScalarKind::Bit,
+ components: Vec::new(),
+ width,
},
- ast::Type::Pointer(ast::PointerType::Array(scalar, components), state_space) => {
- TypeParts {
- kind: TypeKind::PointerArray,
- scalar_kind: scalar.kind(),
- width: scalar.size_of(),
- components: components.clone(),
- state_space: *state_space,
- }
- }
- ast::Type::Pointer(ast::PointerType::Pointer(scalar, inner_space), state_space) => {
+ ast::Type::Surfref => TypeParts {
+ kind: TypeKind::Surfref,
+ state_space: ast::StateSpace::Global,
+ scalar_kind: ast::ScalarKind::Bit,
+ components: Vec::new(),
+ width,
+ },
+ ast::Type::Struct(fields) => {
+ let components = fields
+ .iter()
+ .map(|field| unsafe { mem::transmute::<_, u16>(*field) as u32 })
+ .collect();
TypeParts {
- kind: TypeKind::PointerPointer,
- scalar_kind: scalar.kind(),
- width: scalar.size_of(),
- components: vec![*inner_space as u32],
- state_space: *state_space,
+ kind: TypeKind::Struct,
+ state_space: ast::StateSpace::Reg,
+ scalar_kind: ast::ScalarKind::Bit,
+ components,
+ width,
}
}
}
}
- fn from_parts(t: TypeParts) -> Self {
+ pub(crate) fn from_parts(t: TypeParts) -> Self {
match t.kind {
TypeKind::Scalar => {
ast::Type::Scalar(ast::ScalarType::from_parts(t.width, t.scalar_kind))
@@ -6018,77 +6783,106 @@ impl ast::Type {
ast::ScalarType::from_parts(t.width, t.scalar_kind),
t.components,
),
- TypeKind::PointerScalar => ast::Type::Pointer(
- ast::PointerType::Scalar(ast::ScalarType::from_parts(t.width, t.scalar_kind)),
- t.state_space,
- ),
- TypeKind::PointerVector => ast::Type::Pointer(
- ast::PointerType::Vector(
- ast::ScalarType::from_parts(t.width, t.scalar_kind),
- t.components[0] as u8,
- ),
- t.state_space,
- ),
- TypeKind::PointerArray => ast::Type::Pointer(
- ast::PointerType::Array(
- ast::ScalarType::from_parts(t.width, t.scalar_kind),
- t.components,
- ),
+ TypeKind::Pointer => ast::Type::Pointer(
+ ast::ScalarType::from_parts(t.width, t.scalar_kind),
t.state_space,
),
- TypeKind::PointerPointer => ast::Type::Pointer(
- ast::PointerType::Pointer(
- ast::ScalarType::from_parts(t.width, t.scalar_kind),
- unsafe { mem::transmute::<_, ast::LdStateSpace>(t.components[0] as u8) },
- ),
- t.state_space,
+ TypeKind::Texref => ast::Type::Texref,
+ TypeKind::Surfref => ast::Type::Surfref,
+ TypeKind::Struct => ast::Type::Struct(
+ t.components
+ .into_iter()
+ .map(|component| unsafe { mem::transmute(component as u16) })
+ .collect(),
),
}
}
- pub fn size_of(&self) -> usize {
+ pub(crate) fn layout(&self) -> Layout {
match self {
- ast::Type::Scalar(typ) => typ.size_of() as usize,
- ast::Type::Vector(typ, len) => (typ.size_of() as usize) * (*len as usize),
- ast::Type::Array(typ, len) => len
- .iter()
- .fold(typ.size_of() as usize, |x, y| (x as usize) * (*y as usize)),
- ast::Type::Pointer(_, _) => mem::size_of::<usize>(),
+ ast::Type::Scalar(typ) => {
+ let size_of = typ.size_of() as usize;
+ unsafe { Layout::from_size_align_unchecked(size_of, size_of) }
+ }
+ ast::Type::Vector(typ, len) => {
+ let size_of = typ.size_of() as usize * (*len) as usize;
+ unsafe { Layout::from_size_align_unchecked(size_of, size_of) }
+ }
+ ast::Type::Array(typ, len) => {
+ let scalar_size_of = typ.size_of() as usize;
+ let len = len
+ .iter()
+ .fold(typ.size_of() as usize, |x, y| (x as usize) * (*y as usize));
+ unsafe { Layout::from_size_align_unchecked(scalar_size_of * len, scalar_size_of) }
+ }
+ ast::Type::Struct(fields) => {
+ let mut layout = Layout::new::<()>();
+ for field in fields {
+ layout = layout.extend(field.to_type().layout()).unwrap().0
+ }
+ layout.pad_to_align()
+ }
+ ast::Type::Pointer(..) => Layout::new::<*const ()>(),
+ ast::Type::Texref | ast::Type::Surfref => Layout::new::<usize>(),
}
}
}
#[derive(Eq, PartialEq, Clone)]
-struct TypeParts {
- kind: TypeKind,
- scalar_kind: ScalarKind,
- width: u8,
- components: Vec<u32>,
- state_space: ast::LdStateSpace,
+pub(crate) struct TypeParts {
+ pub(crate) kind: TypeKind,
+ pub(crate) scalar_kind: ast::ScalarKind,
+ pub(crate) width: u8,
+ pub(crate) state_space: ast::StateSpace,
+ pub(crate) components: Vec<u32>,
}
#[derive(Eq, PartialEq, Copy, Clone)]
-enum TypeKind {
+pub(crate) enum TypeKind {
Scalar,
Vector,
Array,
- PointerScalar,
- PointerVector,
- PointerArray,
- PointerPointer,
+ Pointer,
+ Texref,
+ Surfref,
+ Struct,
}
-impl ast::Instruction<ExpandedArgParams> {
- fn jump_target(&self) -> Option<spirv::Word> {
+impl<T: ast::ArgParams<Id = Id>> ast::Instruction<T> {
+ fn jump_target(&self) -> Option<Id> {
match self {
ast::Instruction::Bra(_, a) => Some(a.src),
_ => None,
}
}
+}
+impl<T: ast::ArgParams> ast::Instruction<T> {
// .wide instructions don't support ftz, so it's enough to just look at the
// type declared by the instruction
fn flush_to_zero(&self) -> Option<(bool, u8)> {
+ fn scalar_size_of(type_: ast::ScalarType) -> u8 {
+ match type_ {
+ ast::ScalarType::U8 => 1,
+ ast::ScalarType::S8 => 1,
+ ast::ScalarType::B8 => 1,
+ ast::ScalarType::U16 => 2,
+ ast::ScalarType::S16 => 2,
+ ast::ScalarType::B16 => 2,
+ ast::ScalarType::F16 => 2,
+ ast::ScalarType::U32 => 4,
+ ast::ScalarType::S32 => 4,
+ ast::ScalarType::B32 => 4,
+ ast::ScalarType::F32 => 4,
+ ast::ScalarType::U64 => 8,
+ ast::ScalarType::S64 => 8,
+ ast::ScalarType::B64 => 8,
+ ast::ScalarType::F64 => 8,
+ ast::ScalarType::F16x2 => 2,
+ ast::ScalarType::Pred => 1,
+ }
+ }
+
match self {
ast::Instruction::Ld(_, _) => None,
ast::Instruction::St(_, _) => None,
@@ -6098,14 +6892,29 @@ impl ast::Instruction<ExpandedArgParams> {
ast::Instruction::Shl(_, _) => None,
ast::Instruction::Shr(_, _) => None,
ast::Instruction::Ret(_) => None,
+ ast::Instruction::Exit => None,
+ ast::Instruction::Trap => None,
+ ast::Instruction::Brkpt => None,
ast::Instruction::Call(_) => None,
ast::Instruction::Or(_, _) => None,
ast::Instruction::And(_, _) => None,
ast::Instruction::Cvta(_, _) => None,
ast::Instruction::Selp(_, _) => None,
ast::Instruction::Bar(_, _) => None,
+ ast::Instruction::BarWarp(_, _) => None,
ast::Instruction::Atom(_, _) => None,
+ ast::Instruction::Red(_, _) => None,
ast::Instruction::AtomCas(_, _) => None,
+ ast::Instruction::MadC { .. } => None,
+ ast::Instruction::MadCC { .. } => None,
+ ast::Instruction::BarRed { .. } => None,
+ ast::Instruction::AddC { .. } => None,
+ ast::Instruction::AddCC { .. } => None,
+ ast::Instruction::SubC { .. } => None,
+ ast::Instruction::SubCC { .. } => None,
+ ast::Instruction::Vshr { .. } => None,
+ ast::Instruction::Dp4a { .. } => None,
+ ast::Instruction::MatchAny { .. } => None,
ast::Instruction::Sub(ast::ArithDetails::Signed(_), _) => None,
ast::Instruction::Sub(ast::ArithDetails::Unsigned(_), _) => None,
ast::Instruction::Add(ast::ArithDetails::Signed(_), _) => None,
@@ -6123,33 +6932,58 @@ impl ast::Instruction<ExpandedArgParams> {
ast::Instruction::Div(ast::DivDetails::Unsigned(_), _) => None,
ast::Instruction::Div(ast::DivDetails::Signed(_), _) => None,
ast::Instruction::Clz { .. } => None,
+ ast::Instruction::Bfind(..) => None,
ast::Instruction::Brev { .. } => None,
ast::Instruction::Popc { .. } => None,
ast::Instruction::Xor { .. } => None,
ast::Instruction::Bfe { .. } => None,
+ ast::Instruction::Bfi { .. } => None,
ast::Instruction::Rem { .. } => None,
+ ast::Instruction::Prmt { .. } => None,
+ ast::Instruction::PrmtSlow { .. } => None,
+ ast::Instruction::Activemask { .. } => None,
+ ast::Instruction::Membar { .. } => None,
+ ast::Instruction::Tex(..) => None,
+ ast::Instruction::Suld(..) => None,
+ ast::Instruction::Sust(..) => None,
+ ast::Instruction::Shfl(..) => None,
+ ast::Instruction::Shf(..) => None,
+ ast::Instruction::Vote(..) => None,
+ ast::Instruction::Nanosleep(..) => None,
ast::Instruction::Sub(ast::ArithDetails::Float(float_control), _)
| ast::Instruction::Add(ast::ArithDetails::Float(float_control), _)
| ast::Instruction::Mul(ast::MulDetails::Float(float_control), _)
- | ast::Instruction::Mad(ast::MulDetails::Float(float_control), _) => float_control
- .flush_to_zero
- .map(|ftz| (ftz, ast::ScalarType::from(float_control.typ).size_of())),
+ | ast::Instruction::Mad(ast::MulDetails::Float(float_control), _) => {
+ float_control.flush_to_zero.map(|ftz| {
+ (
+ ftz,
+ scalar_size_of(ast::ScalarType::from(float_control.typ)),
+ )
+ })
+ }
+ ast::Instruction::Fma(d, _) => d.flush_to_zero.map(|ftz| (ftz, scalar_size_of(d.typ))),
ast::Instruction::Setp(details, _) => details
.flush_to_zero
- .map(|ftz| (ftz, details.typ.size_of())),
+ .map(|ftz| (ftz, scalar_size_of(details.typ))),
ast::Instruction::SetpBool(details, _) => details
+ .base
.flush_to_zero
- .map(|ftz| (ftz, details.typ.size_of())),
+ .map(|ftz| (ftz, scalar_size_of(details.base.typ))),
ast::Instruction::Abs(details, _) => details
.flush_to_zero
- .map(|ftz| (ftz, details.typ.size_of())),
+ .map(|ftz| (ftz, scalar_size_of(details.typ))),
ast::Instruction::Min(ast::MinMaxDetails::Float(float_control), _)
- | ast::Instruction::Max(ast::MinMaxDetails::Float(float_control), _) => float_control
- .flush_to_zero
- .map(|ftz| (ftz, ast::ScalarType::from(float_control.typ).size_of())),
+ | ast::Instruction::Max(ast::MinMaxDetails::Float(float_control), _) => {
+ float_control.flush_to_zero.map(|ftz| {
+ (
+ ftz,
+ scalar_size_of(ast::ScalarType::from(float_control.typ)),
+ )
+ })
+ }
ast::Instruction::Rcp(details, _) => details
.flush_to_zero
- .map(|ftz| (ftz, if details.is_f64 { 8 } else { 4 })),
+ .map(|ftz| (ftz, scalar_size_of(details.type_))),
// Modifier .ftz can only be specified when either .dtype or .atype
// is .f32 and applies only to single precision (.f32) inputs and results.
ast::Instruction::Cvt(
@@ -6162,23 +6996,31 @@ impl ast::Instruction<ExpandedArgParams> {
) => flush_to_zero.map(|ftz| (ftz, 4)),
ast::Instruction::Div(ast::DivDetails::Float(details), _) => details
.flush_to_zero
- .map(|ftz| (ftz, ast::ScalarType::from(details.typ).size_of())),
+ .map(|ftz| (ftz, scalar_size_of(ast::ScalarType::from(details.typ)))),
ast::Instruction::Sqrt(details, _) => details
.flush_to_zero
- .map(|ftz| (ftz, ast::ScalarType::from(details.typ).size_of())),
+ .map(|ftz| (ftz, scalar_size_of(details.type_))),
ast::Instruction::Rsqrt(details, _) => Some((
details.flush_to_zero,
- ast::ScalarType::from(details.typ).size_of(),
+ scalar_size_of(ast::ScalarType::from(details.typ)),
)),
ast::Instruction::Neg(details, _) => details
.flush_to_zero
- .map(|ftz| (ftz, details.typ.size_of())),
+ .map(|ftz| (ftz, scalar_size_of(details.typ))),
ast::Instruction::Sin { flush_to_zero, .. }
| ast::Instruction::Cos { flush_to_zero, .. }
| ast::Instruction::Lg2 { flush_to_zero, .. }
| ast::Instruction::Ex2 { flush_to_zero, .. } => {
Some((*flush_to_zero, mem::size_of::<f32>() as u8))
}
+ ast::Instruction::Set(
+ ast::SetData {
+ flush_to_zero,
+ src_type,
+ ..
+ },
+ _,
+ ) => Some((*flush_to_zero, scalar_size_of(*src_type))),
}
}
}
@@ -6186,37 +7028,64 @@ impl ast::Instruction<ExpandedArgParams> {
type Arg2 = ast::Arg2<ExpandedArgParams>;
type Arg2St = ast::Arg2St<ExpandedArgParams>;
-struct ConstantDefinition {
- pub dst: spirv::Word,
+pub(crate) struct ConstantDefinition {
+ pub dst: Id,
pub typ: ast::ScalarType,
pub value: ast::ImmediateValue,
}
-struct BrachCondition {
- predicate: spirv::Word,
- if_true: spirv::Word,
- if_false: spirv::Word,
+pub(crate) struct BrachCondition {
+ pub(crate) predicate: Id,
+ pub(crate) if_true: Id,
+ pub(crate) if_false: Id,
+}
+
+impl<From: ArgParamsEx<Id = Id>, To: ArgParamsEx<Id = Id>> Visitable<From, To> for BrachCondition {
+ fn visit(
+ self,
+ visitor: &mut impl ArgumentMapVisitor<From, To>,
+ ) -> Result<Statement<ast::Instruction<To>, To>, TranslateError> {
+ let predicate = visitor.id(
+ ArgumentDescriptor {
+ op: self.predicate,
+ is_dst: false,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
+ },
+ Some((
+ &ast::Type::Scalar(ast::ScalarType::Pred),
+ ast::StateSpace::Reg,
+ )),
+ )?;
+ let if_true = self.if_true;
+ let if_false = self.if_false;
+ Ok(Statement::Conditional(BrachCondition {
+ predicate,
+ if_true,
+ if_false,
+ }))
+ }
}
#[derive(Clone)]
-struct ImplicitConversion {
- src: spirv::Word,
- dst: spirv::Word,
- from: ast::Type,
- to: ast::Type,
- kind: ConversionKind,
- src_sema: ArgumentSemantics,
- dst_sema: ArgumentSemantics,
-}
-
-#[derive(PartialEq, Copy, Clone)]
-enum ConversionKind {
+pub(crate) struct ImplicitConversion {
+ pub(crate) src: Id,
+ pub(crate) dst: Id,
+ pub(crate) from_type: ast::Type,
+ pub(crate) to_type: ast::Type,
+ pub(crate) from_space: ast::StateSpace,
+ pub(crate) to_space: ast::StateSpace,
+ pub(crate) kind: ConversionKind,
+}
+
+#[derive(PartialEq, Clone)]
+pub(crate) enum ConversionKind {
Default,
// zero-extend/chop/bitcast depending on types
SignExtend,
- BitToPtr(ast::LdStateSpace),
- PtrToBit(ast::UIntType),
- PtrToPtr { spirv_ptr: bool },
+ BitToPtr,
+ PtrToPtr,
+ AddressOf,
}
impl<T> ast::PredAt<T> {
@@ -6233,10 +7102,13 @@ impl<T> ast::PredAt<T> {
}
impl<'a> ast::Instruction<ast::ParsedArgParams<'a>> {
- fn map_variable<F: FnMut(&str) -> Result<spirv::Word, TranslateError>>(
+ fn map_variable<F>(
self,
f: &mut F,
- ) -> Result<ast::Instruction<NormalizedArgParams>, TranslateError> {
+ ) -> Result<ast::Instruction<NormalizedArgParams>, TranslateError>
+ where
+ F: for<'x> FnMut(&'x str) -> Result<Id, TranslateError>,
+ {
match self {
ast::Instruction::Call(call) => {
let call_inst = ast::CallInst {
@@ -6252,6 +7124,7 @@ impl<'a> ast::Instruction<ast::ParsedArgParams<'a>> {
.into_iter()
.map(|p| p.map_variable(f))
.collect::<Result<_, _>>()?,
+ prototype: call.prototype.map(f).transpose()?,
};
Ok(ast::Instruction::Call(call_inst))
}
@@ -6260,17 +7133,419 @@ impl<'a> ast::Instruction<ast::ParsedArgParams<'a>> {
}
}
+pub(crate) struct Arg4CarryOut<P: ast::ArgParams> {
+ pub dst: P::Operand,
+ pub carry_out: P::Operand,
+ pub src1: P::Operand,
+ pub src2: P::Operand,
+ pub src3: P::Operand,
+}
+
+impl<T: ArgParamsEx> Arg4CarryOut<T> {
+ fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
+ self,
+ visitor: &mut V,
+ type_: ast::ScalarType,
+ ) -> Result<Arg4CarryOut<U>, TranslateError> {
+ let dst = visitor.operand(
+ ArgumentDescriptor {
+ op: self.dst,
+ is_dst: true,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
+ },
+ &ast::Type::Scalar(type_),
+ ast::StateSpace::Reg,
+ )?;
+ let carry_out = visitor.operand(
+ ArgumentDescriptor {
+ op: self.carry_out,
+ is_dst: true,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
+ },
+ &ast::Type::Scalar(ast::ScalarType::Pred),
+ ast::StateSpace::Reg,
+ )?;
+ let src1 = visitor.operand(
+ ArgumentDescriptor {
+ op: self.src1,
+ is_dst: false,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
+ },
+ &ast::Type::Scalar(type_),
+ ast::StateSpace::Reg,
+ )?;
+ let src2 = visitor.operand(
+ ArgumentDescriptor {
+ op: self.src2,
+ is_dst: false,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
+ },
+ &ast::Type::Scalar(type_),
+ ast::StateSpace::Reg,
+ )?;
+ let src3 = visitor.operand(
+ ArgumentDescriptor {
+ op: self.src3,
+ is_dst: false,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
+ },
+ &ast::Type::Scalar(type_),
+ ast::StateSpace::Reg,
+ )?;
+ Ok(Arg4CarryOut {
+ dst,
+ src1,
+ src2,
+ src3,
+ carry_out,
+ })
+ }
+
+ fn new(arg: ast::Arg4<T>, carry_flag: T::Operand) -> Arg4CarryOut<T> {
+ Arg4CarryOut {
+ dst: arg.dst,
+ src1: arg.src1,
+ src2: arg.src2,
+ src3: arg.src3,
+ carry_out: carry_flag,
+ }
+ }
+}
+
+pub(crate) struct Arg4CarryIn<P: ast::ArgParams> {
+ pub dst: P::Operand,
+ pub carry_out: Option<P::Operand>,
+ pub carry_in: P::Operand,
+ pub src1: P::Operand,
+ pub src2: P::Operand,
+ pub src3: P::Operand,
+}
+
+impl<T: ArgParamsEx> Arg4CarryIn<T> {
+ fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
+ self,
+ visitor: &mut V,
+ type_: ast::ScalarType,
+ ) -> Result<Arg4CarryIn<U>, TranslateError> {
+ let dst = visitor.operand(
+ ArgumentDescriptor {
+ op: self.dst,
+ is_dst: true,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
+ },
+ &ast::Type::Scalar(type_),
+ ast::StateSpace::Reg,
+ )?;
+ let carry_out = self
+ .carry_out
+ .map(|carry_out| {
+ visitor.operand(
+ ArgumentDescriptor {
+ op: carry_out,
+ is_dst: true,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
+ },
+ &ast::Type::Scalar(ast::ScalarType::Pred),
+ ast::StateSpace::Reg,
+ )
+ })
+ .transpose()?;
+ let carry_in = visitor.operand(
+ ArgumentDescriptor {
+ op: self.carry_in,
+ is_dst: false,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
+ },
+ &ast::Type::Scalar(ast::ScalarType::Pred),
+ ast::StateSpace::Reg,
+ )?;
+ let src1 = visitor.operand(
+ ArgumentDescriptor {
+ op: self.src1,
+ is_dst: false,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
+ },
+ &ast::Type::Scalar(type_),
+ ast::StateSpace::Reg,
+ )?;
+ let src2 = visitor.operand(
+ ArgumentDescriptor {
+ op: self.src2,
+ is_dst: false,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
+ },
+ &ast::Type::Scalar(type_),
+ ast::StateSpace::Reg,
+ )?;
+ let src3 = visitor.operand(
+ ArgumentDescriptor {
+ op: self.src3,
+ is_dst: false,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
+ },
+ &ast::Type::Scalar(type_),
+ ast::StateSpace::Reg,
+ )?;
+ Ok(Arg4CarryIn {
+ dst,
+ src1,
+ src2,
+ src3,
+ carry_in,
+ carry_out,
+ })
+ }
+}
+
+impl<T: ArgParamsEx> Arg4CarryIn<T>
+where
+ T::Operand: Copy,
+{
+ fn new(arg: ast::Arg4<T>, carry_out: bool, carry_flag: T::Operand) -> Arg4CarryIn<T> {
+ Arg4CarryIn {
+ dst: arg.dst,
+ src1: arg.src1,
+ src2: arg.src2,
+ src3: arg.src3,
+ carry_in: carry_flag,
+ carry_out: if carry_out { Some(carry_flag) } else { None },
+ }
+ }
+}
+
+pub(crate) struct Arg3CarryOut<P: ast::ArgParams> {
+ pub dst: P::Operand,
+ pub carry_flag: P::Operand,
+ pub src1: P::Operand,
+ pub src2: P::Operand,
+}
+
+impl<P: ArgParamsEx> Arg3CarryOut<P> {
+ fn map<U: ArgParamsEx, V: ArgumentMapVisitor<P, U>>(
+ self,
+ visitor: &mut V,
+ type_: ast::ScalarType,
+ ) -> Result<Arg3CarryOut<U>, TranslateError> {
+ let dst = visitor.operand(
+ ArgumentDescriptor {
+ op: self.dst,
+ is_dst: true,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
+ },
+ &ast::Type::Scalar(type_),
+ ast::StateSpace::Reg,
+ )?;
+ let carry_flag = visitor.operand(
+ ArgumentDescriptor {
+ op: self.carry_flag,
+ is_dst: true,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
+ },
+ &ast::Type::Scalar(ast::ScalarType::Pred),
+ ast::StateSpace::Reg,
+ )?;
+ let src1 = visitor.operand(
+ ArgumentDescriptor {
+ op: self.src1,
+ is_dst: false,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
+ },
+ &ast::Type::Scalar(type_),
+ ast::StateSpace::Reg,
+ )?;
+ let src2 = visitor.operand(
+ ArgumentDescriptor {
+ op: self.src2,
+ is_dst: false,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
+ },
+ &ast::Type::Scalar(type_),
+ ast::StateSpace::Reg,
+ )?;
+ Ok(Arg3CarryOut {
+ dst,
+ carry_flag,
+ src1,
+ src2,
+ })
+ }
+
+ fn new(args: ast::Arg3<P>, carry_flag: P::Operand) -> Arg3CarryOut<P> {
+ Self {
+ dst: args.dst,
+ carry_flag,
+ src1: args.src1,
+ src2: args.src2,
+ }
+ }
+}
+
+pub(crate) struct Arg3CarryIn<P: ast::ArgParams> {
+ pub dst: P::Operand,
+ pub carry_out: Option<P::Operand>,
+ pub carry_in: P::Operand,
+ pub src1: P::Operand,
+ pub src2: P::Operand,
+}
+
+impl<P: ArgParamsEx> Arg3CarryIn<P> {
+ fn map<U: ArgParamsEx, V: ArgumentMapVisitor<P, U>>(
+ self,
+ visitor: &mut V,
+ type_: ast::ScalarType,
+ ) -> Result<Arg3CarryIn<U>, TranslateError> {
+ let dst = visitor.operand(
+ ArgumentDescriptor {
+ op: self.dst,
+ is_dst: true,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
+ },
+ &ast::Type::Scalar(type_),
+ ast::StateSpace::Reg,
+ )?;
+ let carry_out = self
+ .carry_out
+ .map(|carry_out| {
+ visitor.operand(
+ ArgumentDescriptor {
+ op: carry_out,
+ is_dst: true,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
+ },
+ &ast::Type::Scalar(ast::ScalarType::Pred),
+ ast::StateSpace::Reg,
+ )
+ })
+ .transpose()?;
+ let carry_in = visitor.operand(
+ ArgumentDescriptor {
+ op: self.carry_in,
+ is_dst: false,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
+ },
+ &ast::Type::Scalar(ast::ScalarType::Pred),
+ ast::StateSpace::Reg,
+ )?;
+ let src1 = visitor.operand(
+ ArgumentDescriptor {
+ op: self.src1,
+ is_dst: false,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
+ },
+ &ast::Type::Scalar(type_),
+ ast::StateSpace::Reg,
+ )?;
+ let src2 = visitor.operand(
+ ArgumentDescriptor {
+ op: self.src2,
+ is_dst: false,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
+ },
+ &ast::Type::Scalar(type_),
+ ast::StateSpace::Reg,
+ )?;
+ Ok(Arg3CarryIn {
+ dst,
+ carry_in,
+ carry_out,
+ src1,
+ src2,
+ })
+ }
+}
+
+impl<P: ArgParamsEx> Arg3CarryIn<P>
+where
+ P::Operand: Copy,
+{
+ fn new(args: ast::Arg3<P>, carry_out: bool, carry_flag: P::Operand) -> Arg3CarryIn<P> {
+ Arg3CarryIn {
+ dst: args.dst,
+ carry_in: carry_flag,
+ carry_out: if carry_out { Some(carry_flag) } else { None },
+ src1: args.src1,
+ src2: args.src2,
+ }
+ }
+}
+
+pub(crate) struct VisitAddC<P: ast::ArgParams>(pub ast::ScalarType, pub Arg3CarryIn<P>);
+
+impl<T: ArgParamsEx<Id = Id>, U: ArgParamsEx<Id = Id>> Visitable<T, U> for VisitAddC<T> {
+ fn visit(
+ self,
+ visitor: &mut impl ArgumentMapVisitor<T, U>,
+ ) -> Result<Statement<ast::Instruction<U>, U>, TranslateError> {
+ Ok(Statement::AddC(self.0, self.1.map(visitor, self.0)?))
+ }
+}
+
+pub(crate) struct VisitAddCC<P: ast::ArgParams>(pub ast::ScalarType, pub Arg3CarryOut<P>);
+
+impl<T: ArgParamsEx<Id = Id>, U: ArgParamsEx<Id = Id>> Visitable<T, U> for VisitAddCC<T> {
+ fn visit(
+ self,
+ visitor: &mut impl ArgumentMapVisitor<T, U>,
+ ) -> Result<Statement<ast::Instruction<U>, U>, TranslateError> {
+ Ok(Statement::AddCC(self.0, self.1.map(visitor, self.0)?))
+ }
+}
+
+pub(crate) struct VisitSubC<P: ast::ArgParams>(pub ast::ScalarType, pub Arg3CarryIn<P>);
+
+impl<T: ArgParamsEx<Id = Id>, U: ArgParamsEx<Id = Id>> Visitable<T, U> for VisitSubC<T> {
+ fn visit(
+ self,
+ visitor: &mut impl ArgumentMapVisitor<T, U>,
+ ) -> Result<Statement<ast::Instruction<U>, U>, TranslateError> {
+ Ok(Statement::SubC(self.0, self.1.map(visitor, self.0)?))
+ }
+}
+
+pub(crate) struct VisitSubCC<P: ast::ArgParams>(pub ast::ScalarType, pub Arg3CarryOut<P>);
+
+impl<T: ArgParamsEx<Id = Id>, U: ArgParamsEx<Id = Id>> Visitable<T, U> for VisitSubCC<T> {
+ fn visit(
+ self,
+ visitor: &mut impl ArgumentMapVisitor<T, U>,
+ ) -> Result<Statement<ast::Instruction<U>, U>, TranslateError> {
+ Ok(Statement::SubCC(self.0, self.1.map(visitor, self.0)?))
+ }
+}
+
impl<T: ArgParamsEx> ast::Arg1<T> {
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
- t: Option<&ast::Type>,
+ is_dst: bool,
+ t: Option<(&ast::Type, ast::StateSpace)>,
) -> Result<ast::Arg1<U>, TranslateError> {
let new_src = visitor.id(
ArgumentDescriptor {
op: self.src,
- is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_dst,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
t,
)?;
@@ -6287,9 +7562,11 @@ impl<T: ArgParamsEx> ast::Arg1Bar<T> {
ArgumentDescriptor {
op: self.src,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
&ast::Type::Scalar(ast::ScalarType::U32),
+ ast::StateSpace::Reg,
)?;
Ok(ast::Arg1Bar { src: new_src })
}
@@ -6305,17 +7582,21 @@ impl<T: ArgParamsEx> ast::Arg2<T> {
ArgumentDescriptor {
op: self.dst,
is_dst: true,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
t,
+ ast::StateSpace::Reg,
)?;
let new_src = visitor.operand(
ArgumentDescriptor {
op: self.src,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
t,
+ ast::StateSpace::Reg,
)?;
Ok(ast::Arg2 {
dst: new_dst,
@@ -6323,6 +7604,44 @@ impl<T: ArgParamsEx> ast::Arg2<T> {
})
}
+ fn map_cvt<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
+ self,
+ visitor: &mut V,
+ dst_t: ast::ScalarType,
+ src_t: ast::ScalarType,
+ is_int_to_int: bool,
+ ) -> Result<ast::Arg2<U>, TranslateError> {
+ let dst = visitor.operand(
+ ArgumentDescriptor {
+ op: self.dst,
+ is_dst: true,
+ is_memory_access: false,
+ non_default_implicit_conversion: if is_int_to_int {
+ Some(should_convert_relaxed_dst_wrapper)
+ } else {
+ None
+ },
+ },
+ &ast::Type::Scalar(dst_t),
+ ast::StateSpace::Reg,
+ )?;
+ let src = visitor.operand(
+ ArgumentDescriptor {
+ op: self.src,
+ is_dst: false,
+ is_memory_access: false,
+ non_default_implicit_conversion: if is_int_to_int {
+ Some(should_convert_relaxed_src_wrapper)
+ } else {
+ None
+ },
+ },
+ &ast::Type::Scalar(src_t),
+ ast::StateSpace::Reg,
+ )?;
+ Ok(ast::Arg2 { dst, src })
+ }
+
fn map_different_types<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
@@ -6333,17 +7652,21 @@ impl<T: ArgParamsEx> ast::Arg2<T> {
ArgumentDescriptor {
op: self.dst,
is_dst: true,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
dst_t,
+ ast::StateSpace::Reg,
)?;
let src = visitor.operand(
ArgumentDescriptor {
op: self.src,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
src_t,
+ ast::StateSpace::Reg,
)?;
Ok(ast::Arg2 { dst, src })
}
@@ -6359,26 +7682,21 @@ impl<T: ArgParamsEx> ast::Arg2Ld<T> {
ArgumentDescriptor {
op: self.dst,
is_dst: true,
- sema: ArgumentSemantics::DefaultRelaxed,
+ is_memory_access: false,
+ non_default_implicit_conversion: Some(should_convert_relaxed_dst_wrapper),
},
&ast::Type::from(details.typ.clone()),
+ ast::StateSpace::Reg,
)?;
- let is_logical_ptr = details.state_space == ast::LdStateSpace::Param
- || details.state_space == ast::LdStateSpace::Local;
let src = visitor.operand(
ArgumentDescriptor {
op: self.src,
is_dst: false,
- sema: if is_logical_ptr {
- ArgumentSemantics::RegisterPointer
- } else {
- ArgumentSemantics::PhysicalPointer
- },
+ is_memory_access: true,
+ non_default_implicit_conversion: None,
},
- &ast::Type::Pointer(
- ast::PointerType::from(details.typ.clone()),
- details.state_space,
- ),
+ &details.typ,
+ details.state_space,
)?;
Ok(ast::Arg2Ld { dst, src })
}
@@ -6388,32 +7706,28 @@ impl<T: ArgParamsEx> ast::Arg2St<T> {
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
- details: &ast::StData,
+ type_: &ast::Type,
+ state_space: ast::StateSpace,
) -> Result<ast::Arg2St<U>, TranslateError> {
- let is_logical_ptr = details.state_space == ast::StStateSpace::Param
- || details.state_space == ast::StStateSpace::Local;
let src1 = visitor.operand(
ArgumentDescriptor {
op: self.src1,
is_dst: false,
- sema: if is_logical_ptr {
- ArgumentSemantics::RegisterPointer
- } else {
- ArgumentSemantics::PhysicalPointer
- },
+ is_memory_access: true,
+ non_default_implicit_conversion: None,
},
- &ast::Type::Pointer(
- ast::PointerType::from(details.typ.clone()),
- details.state_space.to_ld_ss(),
- ),
+ &type_,
+ state_space,
)?;
let src2 = visitor.operand(
ArgumentDescriptor {
op: self.src2,
is_dst: false,
- sema: ArgumentSemantics::DefaultRelaxed,
+ is_memory_access: false,
+ non_default_implicit_conversion: Some(should_convert_relaxed_src_wrapper),
},
- &details.typ.clone().into(),
+ &type_.clone().into(),
+ ast::StateSpace::Reg,
)?;
Ok(ast::Arg2St { src1, src2 })
}
@@ -6429,28 +7743,28 @@ impl<T: ArgParamsEx> ast::Arg2Mov<T> {
ArgumentDescriptor {
op: self.dst,
is_dst: true,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
&details.typ.clone().into(),
+ ast::StateSpace::Reg,
)?;
let src = visitor.operand(
ArgumentDescriptor {
op: self.src,
is_dst: false,
- sema: if details.src_is_address {
- ArgumentSemantics::Address
- } else {
- ArgumentSemantics::Default
- },
+ is_memory_access: false,
+ non_default_implicit_conversion: Some(implicit_conversion_mov),
},
&details.typ.clone().into(),
+ ast::StateSpace::Reg,
)?;
Ok(ast::Arg2Mov { dst, src })
}
}
impl<T: ArgParamsEx> ast::Arg3<T> {
- fn map_non_shift<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
+ fn map_generic<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
typ: &ast::Type,
@@ -6465,25 +7779,70 @@ impl<T: ArgParamsEx> ast::Arg3<T> {
ArgumentDescriptor {
op: self.dst,
is_dst: true,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
wide_type.as_ref().unwrap_or(typ),
+ ast::StateSpace::Reg,
)?;
let src1 = visitor.operand(
ArgumentDescriptor {
op: self.src1,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
typ,
+ ast::StateSpace::Reg,
)?;
let src2 = visitor.operand(
ArgumentDescriptor {
op: self.src2,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
typ,
+ ast::StateSpace::Reg,
+ )?;
+ Ok(ast::Arg3 { dst, src1, src2 })
+ }
+
+ fn map_different_types<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
+ self,
+ visitor: &mut V,
+ dst_type: &ast::Type,
+ src_type: &ast::Type,
+ ) -> Result<ast::Arg3<U>, TranslateError> {
+ let dst = visitor.operand(
+ ArgumentDescriptor {
+ op: self.dst,
+ is_dst: true,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
+ },
+ dst_type,
+ ast::StateSpace::Reg,
+ )?;
+ let src1 = visitor.operand(
+ ArgumentDescriptor {
+ op: self.src1,
+ is_dst: false,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
+ },
+ src_type,
+ ast::StateSpace::Reg,
+ )?;
+ let src2 = visitor.operand(
+ ArgumentDescriptor {
+ op: self.src2,
+ is_dst: false,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
+ },
+ src_type,
+ ast::StateSpace::Reg,
)?;
Ok(ast::Arg3 { dst, src1, src2 })
}
@@ -6497,25 +7856,31 @@ impl<T: ArgParamsEx> ast::Arg3<T> {
ArgumentDescriptor {
op: self.dst,
is_dst: true,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
t,
+ ast::StateSpace::Reg,
)?;
let src1 = visitor.operand(
ArgumentDescriptor {
op: self.src1,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
t,
+ ast::StateSpace::Reg,
)?;
let src2 = visitor.operand(
ArgumentDescriptor {
op: self.src2,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
&ast::Type::Scalar(ast::ScalarType::U32),
+ ast::StateSpace::Reg,
)?;
Ok(ast::Arg3 { dst, src1, src2 })
}
@@ -6524,38 +7889,322 @@ impl<T: ArgParamsEx> ast::Arg3<T> {
self,
visitor: &mut V,
t: ast::ScalarType,
- state_space: ast::AtomSpace,
+ state_space: ast::StateSpace,
) -> Result<ast::Arg3<U>, TranslateError> {
let scalar_type = ast::ScalarType::from(t);
let dst = visitor.operand(
ArgumentDescriptor {
op: self.dst,
is_dst: true,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
&ast::Type::Scalar(scalar_type),
+ ast::StateSpace::Reg,
)?;
let src1 = visitor.operand(
ArgumentDescriptor {
op: self.src1,
is_dst: false,
- sema: ArgumentSemantics::PhysicalPointer,
+ is_memory_access: true,
+ non_default_implicit_conversion: None,
},
- &ast::Type::Pointer(
- ast::PointerType::Scalar(scalar_type),
- state_space.to_ld_ss(),
- ),
+ &ast::Type::Scalar(scalar_type),
+ state_space,
)?;
let src2 = visitor.operand(
ArgumentDescriptor {
op: self.src2,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
&ast::Type::Scalar(scalar_type),
+ ast::StateSpace::Reg,
+ )?;
+ Ok(ast::Arg3 { dst, src1, src2 })
+ }
+
+ fn map_prmt<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
+ self,
+ visitor: &mut V,
+ ) -> Result<ast::Arg3<U>, TranslateError> {
+ let dst = visitor.operand(
+ ArgumentDescriptor {
+ op: self.dst,
+ is_dst: true,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
+ },
+ &ast::Type::Scalar(ast::ScalarType::B32),
+ ast::StateSpace::Reg,
+ )?;
+ let src1 = visitor.operand(
+ ArgumentDescriptor {
+ op: self.src1,
+ is_dst: false,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
+ },
+ &ast::Type::Scalar(ast::ScalarType::B32),
+ ast::StateSpace::Reg,
+ )?;
+ let src2 = visitor.operand(
+ ArgumentDescriptor {
+ op: self.src2,
+ is_dst: false,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
+ },
+ &ast::Type::Scalar(ast::ScalarType::B32),
+ ast::StateSpace::Reg,
+ )?;
+ Ok(ast::Arg3 { dst, src1, src2 })
+ }
+
+ fn map_vote<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
+ self,
+ visitor: &mut V,
+ mode: ast::VoteMode,
+ ) -> Result<ast::Arg3<U>, TranslateError> {
+ let dst_type = match mode {
+ ast::VoteMode::Ballot => ast::ScalarType::B32,
+ ast::VoteMode::All | ast::VoteMode::Any | ast::VoteMode::Uni => ast::ScalarType::Pred,
+ };
+ let dst = visitor.operand(
+ ArgumentDescriptor {
+ op: self.dst,
+ is_dst: true,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
+ },
+ &ast::Type::Scalar(dst_type),
+ ast::StateSpace::Reg,
+ )?;
+ let src1 = visitor.operand(
+ ArgumentDescriptor {
+ op: self.src1,
+ is_dst: false,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
+ },
+ &ast::Type::Scalar(ast::ScalarType::Pred),
+ ast::StateSpace::Reg,
+ )?;
+ let src2 = visitor.operand(
+ ArgumentDescriptor {
+ op: self.src2,
+ is_dst: false,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
+ },
+ &ast::Type::Scalar(ast::ScalarType::B32),
+ ast::StateSpace::Reg,
)?;
Ok(ast::Arg3 { dst, src1, src2 })
}
+
+ fn map_bar_red<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
+ self,
+ visitor: &mut V,
+ op: ast::ReductionOp,
+ ) -> Result<ast::Arg3<U>, TranslateError> {
+ let dst = visitor.operand(
+ ArgumentDescriptor {
+ op: self.dst,
+ is_dst: true,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
+ },
+ &ast::Type::Scalar(op.dst_type()),
+ ast::StateSpace::Reg,
+ )?;
+ let src1 = visitor.operand(
+ ArgumentDescriptor {
+ op: self.src1,
+ is_dst: false,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
+ },
+ &ast::Type::Scalar(ast::ScalarType::B32),
+ ast::StateSpace::Reg,
+ )?;
+ let src2 = visitor.operand(
+ ArgumentDescriptor {
+ op: self.src2,
+ is_dst: false,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
+ },
+ &ast::Type::Scalar(ast::ScalarType::Pred),
+ ast::StateSpace::Reg,
+ )?;
+ Ok(ast::Arg3 { dst, src1, src2 })
+ }
+}
+
+fn texture_geometry_to_vec_length(geometry: ast::TextureGeometry) -> u8 {
+ match geometry {
+ ast::TextureGeometry::OneD | ast::TextureGeometry::Array1D => 1u8,
+ ast::TextureGeometry::TwoD | ast::TextureGeometry::Array2D => 2,
+ ast::TextureGeometry::ThreeD => 4,
+ }
+}
+
+impl<T: ArgParamsEx> ast::Arg4Tex<T> {
+ fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
+ self,
+ visitor: &mut V,
+ (image_type, image_space): (ast::Type, ast::StateSpace),
+ geometry: ast::TextureGeometry,
+ value_type: ast::Type,
+ coordinate_type: ast::ScalarType,
+ ) -> Result<ast::Arg4Tex<U>, TranslateError> {
+ let dst = visitor.operand(
+ ArgumentDescriptor {
+ op: self.dst,
+ is_dst: true,
+ is_memory_access: false,
+ non_default_implicit_conversion: Some(should_convert_relaxed_dst_wrapper),
+ },
+ &value_type,
+ ast::StateSpace::Reg,
+ )?;
+ let image = visitor.operand(
+ ArgumentDescriptor {
+ op: self.image,
+ is_dst: false,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
+ },
+ &image_type,
+ image_space,
+ )?;
+ let layer = self
+ .layer
+ .map(|layer| {
+ visitor.operand(
+ ArgumentDescriptor {
+ op: layer,
+ is_dst: false,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
+ },
+ &ast::Type::Scalar(ast::ScalarType::U32),
+ ast::StateSpace::Reg,
+ )
+ })
+ .transpose()?;
+ let coord_length = texture_geometry_to_vec_length(geometry);
+ let coordinates = visitor.operand(
+ ArgumentDescriptor {
+ op: self.coordinates,
+ is_dst: false,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
+ },
+ &ast::Type::Vector(coordinate_type, coord_length),
+ ast::StateSpace::Reg,
+ )?;
+ Ok(ast::Arg4Tex {
+ dst,
+ image,
+ layer,
+ coordinates,
+ })
+ }
+}
+
+impl<T: ArgParamsEx> ast::Arg4Sust<T> {
+ pub(crate) fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
+ self,
+ visitor: &mut V,
+ details: &ast::SurfaceDetails,
+ ) -> Result<ast::Arg4Sust<U>, TranslateError> {
+ let image = visitor.operand(
+ ArgumentDescriptor {
+ op: self.image,
+ is_dst: false,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
+ },
+ &ast::Type::Surfref,
+ ast::StateSpace::Global,
+ )?;
+ let layer = self
+ .layer
+ .map(|layer| {
+ visitor.operand(
+ ArgumentDescriptor {
+ op: layer,
+ is_dst: false,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
+ },
+ &ast::Type::Scalar(ast::ScalarType::U32),
+ ast::StateSpace::Reg,
+ )
+ })
+ .transpose()?;
+ let coord_length = texture_geometry_to_vec_length(details.geometry);
+ let coordinates = visitor.operand(
+ ArgumentDescriptor {
+ op: self.coordinates,
+ is_dst: false,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
+ },
+ &ast::Type::Vector(ast::ScalarType::B32, coord_length),
+ ast::StateSpace::Reg,
+ )?;
+ let value_type = details.value_type();
+ let value = visitor.operand(
+ ArgumentDescriptor {
+ op: self.value,
+ is_dst: false,
+ is_memory_access: false,
+ non_default_implicit_conversion: Some(should_convert_relaxed_src_wrapper),
+ },
+ &value_type,
+ ast::StateSpace::Reg,
+ )?;
+ Ok(ast::Arg4Sust {
+ image,
+ coordinates,
+ layer,
+ value,
+ })
+ }
+}
+
+impl ast::TextureGeometry {
+ fn as_ptx(self) -> &'static str {
+ match self {
+ ast::TextureGeometry::OneD => "1d",
+ ast::TextureGeometry::TwoD => "2d",
+ ast::TextureGeometry::ThreeD => "3d",
+ ast::TextureGeometry::Array1D => "a1d",
+ ast::TextureGeometry::Array2D => "a2d",
+ }
+ }
+}
+
+impl ast::SurfaceDetails {
+ fn value_type(&self) -> ast::Type {
+ match self.vector {
+ Some(vec_length) => ast::Type::Vector(self.type_, vec_length),
+ None => ast::Type::Scalar(self.type_),
+ }
+ }
+
+ fn vector_ptx(&self) -> Result<&'static str, TranslateError> {
+ Ok(match self.vector {
+ Some(2) => "_v2",
+ Some(4) => "_v4",
+ Some(_) => return Err(TranslateError::unreachable()),
+ None => "",
+ })
+ }
}
impl<T: ArgParamsEx> ast::Arg4<T> {
@@ -6566,41 +8215,49 @@ impl<T: ArgParamsEx> ast::Arg4<T> {
is_wide: bool,
) -> Result<ast::Arg4<U>, TranslateError> {
let wide_type = if is_wide {
- Some(t.clone().widen()?)
+ t.clone().widen()?
} else {
- None
+ t.clone()
};
let dst = visitor.operand(
ArgumentDescriptor {
op: self.dst,
is_dst: true,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
- wide_type.as_ref().unwrap_or(t),
+ &wide_type,
+ ast::StateSpace::Reg,
)?;
let src1 = visitor.operand(
ArgumentDescriptor {
op: self.src1,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
t,
+ ast::StateSpace::Reg,
)?;
let src2 = visitor.operand(
ArgumentDescriptor {
op: self.src2,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
t,
+ ast::StateSpace::Reg,
)?;
let src3 = visitor.operand(
ArgumentDescriptor {
op: self.src3,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
- t,
+ &wide_type,
+ ast::StateSpace::Reg,
)?;
Ok(ast::Arg4 {
dst,
@@ -6613,39 +8270,47 @@ impl<T: ArgParamsEx> ast::Arg4<T> {
fn map_selp<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
- t: ast::SelpType,
+ t: ast::ScalarType,
) -> Result<ast::Arg4<U>, TranslateError> {
let dst = visitor.operand(
ArgumentDescriptor {
op: self.dst,
is_dst: true,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
&ast::Type::Scalar(t.into()),
+ ast::StateSpace::Reg,
)?;
let src1 = visitor.operand(
ArgumentDescriptor {
op: self.src1,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
&ast::Type::Scalar(t.into()),
+ ast::StateSpace::Reg,
)?;
let src2 = visitor.operand(
ArgumentDescriptor {
op: self.src2,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
&ast::Type::Scalar(t.into()),
+ ast::StateSpace::Reg,
)?;
let src3 = visitor.operand(
ArgumentDescriptor {
op: self.src3,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
&ast::Type::Scalar(ast::ScalarType::Pred),
+ ast::StateSpace::Reg,
)?;
Ok(ast::Arg4 {
dst,
@@ -6658,44 +8323,49 @@ impl<T: ArgParamsEx> ast::Arg4<T> {
fn map_atom<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
- t: ast::BitType,
- state_space: ast::AtomSpace,
+ t: ast::ScalarType,
+ state_space: ast::StateSpace,
) -> Result<ast::Arg4<U>, TranslateError> {
let scalar_type = ast::ScalarType::from(t);
let dst = visitor.operand(
ArgumentDescriptor {
op: self.dst,
is_dst: true,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
&ast::Type::Scalar(scalar_type),
+ ast::StateSpace::Reg,
)?;
let src1 = visitor.operand(
ArgumentDescriptor {
op: self.src1,
is_dst: false,
- sema: ArgumentSemantics::PhysicalPointer,
+ is_memory_access: true,
+ non_default_implicit_conversion: None,
},
- &ast::Type::Pointer(
- ast::PointerType::Scalar(scalar_type),
- state_space.to_ld_ss(),
- ),
+ &ast::Type::Scalar(scalar_type),
+ state_space,
)?;
let src2 = visitor.operand(
ArgumentDescriptor {
op: self.src2,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
&ast::Type::Scalar(scalar_type),
+ ast::StateSpace::Reg,
)?;
let src3 = visitor.operand(
ArgumentDescriptor {
op: self.src3,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
&ast::Type::Scalar(scalar_type),
+ ast::StateSpace::Reg,
)?;
Ok(ast::Arg4 {
dst,
@@ -6714,34 +8384,42 @@ impl<T: ArgParamsEx> ast::Arg4<T> {
ArgumentDescriptor {
op: self.dst,
is_dst: true,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
typ,
+ ast::StateSpace::Reg,
)?;
let src1 = visitor.operand(
ArgumentDescriptor {
op: self.src1,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
typ,
+ ast::StateSpace::Reg,
)?;
let u32_type = ast::Type::Scalar(ast::ScalarType::U32);
let src2 = visitor.operand(
ArgumentDescriptor {
op: self.src2,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
&u32_type,
+ ast::StateSpace::Reg,
)?;
let src3 = visitor.operand(
ArgumentDescriptor {
op: self.src3,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
&u32_type,
+ ast::StateSpace::Reg,
)?;
Ok(ast::Arg4 {
dst,
@@ -6762,9 +8440,13 @@ impl<T: ArgParamsEx> ast::Arg4Setp<T> {
ArgumentDescriptor {
op: self.dst1,
is_dst: true,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
- Some(&ast::Type::Scalar(ast::ScalarType::Pred)),
+ Some((
+ &ast::Type::Scalar(ast::ScalarType::Pred),
+ ast::StateSpace::Reg,
+ )),
)?;
let dst2 = self
.dst2
@@ -6773,9 +8455,13 @@ impl<T: ArgParamsEx> ast::Arg4Setp<T> {
ArgumentDescriptor {
op: dst2,
is_dst: true,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
- Some(&ast::Type::Scalar(ast::ScalarType::Pred)),
+ Some((
+ &ast::Type::Scalar(ast::ScalarType::Pred),
+ ast::StateSpace::Reg,
+ )),
)
})
.transpose()?;
@@ -6783,17 +8469,21 @@ impl<T: ArgParamsEx> ast::Arg4Setp<T> {
ArgumentDescriptor {
op: self.src1,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
t,
+ ast::StateSpace::Reg,
)?;
let src2 = visitor.operand(
ArgumentDescriptor {
op: self.src2,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
t,
+ ast::StateSpace::Reg,
)?;
Ok(ast::Arg4Setp {
dst1,
@@ -6804,6 +8494,72 @@ impl<T: ArgParamsEx> ast::Arg4Setp<T> {
}
}
+impl<T: ArgParamsEx> ast::Arg5<T> {
+ fn map_bfi<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
+ self,
+ visitor: &mut V,
+ base_type: &ast::Type,
+ ) -> Result<ast::Arg5<U>, TranslateError> {
+ let dst = visitor.operand(
+ ArgumentDescriptor {
+ op: self.dst,
+ is_dst: true,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
+ },
+ base_type,
+ ast::StateSpace::Reg,
+ )?;
+ let src1 = visitor.operand(
+ ArgumentDescriptor {
+ op: self.src1,
+ is_dst: false,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
+ },
+ base_type,
+ ast::StateSpace::Reg,
+ )?;
+ let src2 = visitor.operand(
+ ArgumentDescriptor {
+ op: self.src2,
+ is_dst: false,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
+ },
+ base_type,
+ ast::StateSpace::Reg,
+ )?;
+ let src3 = visitor.operand(
+ ArgumentDescriptor {
+ op: self.src3,
+ is_dst: false,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
+ },
+ &ast::Type::Scalar(ast::ScalarType::U32),
+ ast::StateSpace::Reg,
+ )?;
+ let src4 = visitor.operand(
+ ArgumentDescriptor {
+ op: self.src4,
+ is_dst: false,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
+ },
+ &ast::Type::Scalar(ast::ScalarType::U32),
+ ast::StateSpace::Reg,
+ )?;
+ Ok(ast::Arg5 {
+ dst,
+ src1,
+ src2,
+ src3,
+ src4,
+ })
+ }
+}
+
impl<T: ArgParamsEx> ast::Arg5Setp<T> {
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
@@ -6814,9 +8570,13 @@ impl<T: ArgParamsEx> ast::Arg5Setp<T> {
ArgumentDescriptor {
op: self.dst1,
is_dst: true,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
- Some(&ast::Type::Scalar(ast::ScalarType::Pred)),
+ Some((
+ &ast::Type::Scalar(ast::ScalarType::Pred),
+ ast::StateSpace::Reg,
+ )),
)?;
let dst2 = self
.dst2
@@ -6825,9 +8585,13 @@ impl<T: ArgParamsEx> ast::Arg5Setp<T> {
ArgumentDescriptor {
op: dst2,
is_dst: true,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
- Some(&ast::Type::Scalar(ast::ScalarType::Pred)),
+ Some((
+ &ast::Type::Scalar(ast::ScalarType::Pred),
+ ast::StateSpace::Reg,
+ )),
)
})
.transpose()?;
@@ -6835,25 +8599,31 @@ impl<T: ArgParamsEx> ast::Arg5Setp<T> {
ArgumentDescriptor {
op: self.src1,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
t,
+ ast::StateSpace::Reg,
)?;
let src2 = visitor.operand(
ArgumentDescriptor {
op: self.src2,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
t,
+ ast::StateSpace::Reg,
)?;
let src3 = visitor.operand(
ArgumentDescriptor {
op: self.src3,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
&ast::Type::Scalar(ast::ScalarType::Pred),
+ ast::StateSpace::Reg,
)?;
Ok(ast::Arg5Setp {
dst1,
@@ -6865,6 +8635,80 @@ impl<T: ArgParamsEx> ast::Arg5Setp<T> {
}
}
+impl<T: ArgParamsEx> ast::Arg5Shfl<T> {
+ fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
+ self,
+ visitor: &mut V,
+ ) -> Result<ast::Arg5Shfl<U>, TranslateError> {
+ let dst1 = visitor.id(
+ ArgumentDescriptor {
+ op: self.dst1,
+ is_dst: true,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
+ },
+ Some((
+ &ast::Type::Scalar(ast::ScalarType::B32),
+ ast::StateSpace::Reg,
+ )),
+ )?;
+ let dst2 = self
+ .dst2
+ .map(|dst2| {
+ visitor.id(
+ ArgumentDescriptor {
+ op: dst2,
+ is_dst: true,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
+ },
+ Some((
+ &ast::Type::Scalar(ast::ScalarType::Pred),
+ ast::StateSpace::Reg,
+ )),
+ )
+ })
+ .transpose()?;
+ let src1 = visitor.operand(
+ ArgumentDescriptor {
+ op: self.src1,
+ is_dst: false,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
+ },
+ &ast::Type::Scalar(ast::ScalarType::B32),
+ ast::StateSpace::Reg,
+ )?;
+ let src2 = visitor.operand(
+ ArgumentDescriptor {
+ op: self.src2,
+ is_dst: false,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
+ },
+ &ast::Type::Scalar(ast::ScalarType::B32),
+ ast::StateSpace::Reg,
+ )?;
+ let src3 = visitor.operand(
+ ArgumentDescriptor {
+ op: self.src3,
+ is_dst: false,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
+ },
+ &ast::Type::Scalar(ast::ScalarType::B32),
+ ast::StateSpace::Reg,
+ )?;
+ Ok(ast::Arg5Shfl {
+ dst1,
+ dst2,
+ src1,
+ src2,
+ src3,
+ })
+ }
+}
+
impl<T> ast::Operand<T> {
fn map_variable<U, F: FnMut(T) -> Result<U, TranslateError>>(
self,
@@ -6875,145 +8719,82 @@ impl<T> ast::Operand<T> {
ast::Operand::RegOffset(reg, offset) => ast::Operand::RegOffset(f(reg)?, offset),
ast::Operand::Imm(x) => ast::Operand::Imm(x),
ast::Operand::VecMember(reg, idx) => ast::Operand::VecMember(f(reg)?, idx),
- ast::Operand::VecPack(vec) => {
- ast::Operand::VecPack(vec.into_iter().map(f).collect::<Result<_, _>>()?)
- }
+ ast::Operand::VecPack(vec) => ast::Operand::VecPack(
+ vec.into_iter()
+ .map(|reg_or_immediate| {
+ Ok::<_, TranslateError>(match reg_or_immediate {
+ ast::RegOrImmediate::Reg(reg) => ast::RegOrImmediate::Reg(f(reg)?),
+ ast::RegOrImmediate::Imm(imm) => ast::RegOrImmediate::Imm(imm),
+ })
+ })
+ .collect::<Result<_, _>>()?,
+ ),
})
}
}
-impl ast::Operand<spirv::Word> {
- fn unwrap_reg(&self) -> Result<spirv::Word, TranslateError> {
- match self {
- ast::Operand::Reg(reg) => Ok(*reg),
- _ => Err(error_unreachable()),
+impl ast::ScalarType {
+ pub(crate) fn widen(self) -> Result<Self, TranslateError> {
+ let kind = self.kind();
+ let width = self.size_of();
+ if (kind != ast::ScalarKind::Signed
+ && kind != ast::ScalarKind::Unsigned
+ && kind != ast::ScalarKind::Bit)
+ || (width == 8)
+ {
+ return Err(TranslateError::mismatched_type());
}
+ Ok(ast::ScalarType::from_parts(width * 2, kind))
}
-}
-impl ast::StStateSpace {
- fn to_ld_ss(self) -> ast::LdStateSpace {
- match self {
- ast::StStateSpace::Generic => ast::LdStateSpace::Generic,
- ast::StStateSpace::Global => ast::LdStateSpace::Global,
- ast::StStateSpace::Local => ast::LdStateSpace::Local,
- ast::StStateSpace::Param => ast::LdStateSpace::Param,
- ast::StStateSpace::Shared => ast::LdStateSpace::Shared,
- }
- }
-}
-
-#[derive(Clone, Copy, PartialEq, Eq)]
-enum ScalarKind {
- Bit,
- Unsigned,
- Signed,
- Float,
- Float2,
- Pred,
-}
-
-impl ast::ScalarType {
- fn kind(self) -> ScalarKind {
- match self {
- ast::ScalarType::U8 => ScalarKind::Unsigned,
- ast::ScalarType::U16 => ScalarKind::Unsigned,
- ast::ScalarType::U32 => ScalarKind::Unsigned,
- ast::ScalarType::U64 => ScalarKind::Unsigned,
- ast::ScalarType::S8 => ScalarKind::Signed,
- ast::ScalarType::S16 => ScalarKind::Signed,
- ast::ScalarType::S32 => ScalarKind::Signed,
- ast::ScalarType::S64 => ScalarKind::Signed,
- ast::ScalarType::B8 => ScalarKind::Bit,
- ast::ScalarType::B16 => ScalarKind::Bit,
- ast::ScalarType::B32 => ScalarKind::Bit,
- ast::ScalarType::B64 => ScalarKind::Bit,
- ast::ScalarType::F16 => ScalarKind::Float,
- ast::ScalarType::F32 => ScalarKind::Float,
- ast::ScalarType::F64 => ScalarKind::Float,
- ast::ScalarType::F16x2 => ScalarKind::Float2,
- ast::ScalarType::Pred => ScalarKind::Pred,
- }
- }
-
- fn from_parts(width: u8, kind: ScalarKind) -> Self {
+ pub(crate) fn from_parts(width: u8, kind: ast::ScalarKind) -> Self {
match kind {
- ScalarKind::Float => match width {
+ ast::ScalarKind::Float => match width {
2 => ast::ScalarType::F16,
4 => ast::ScalarType::F32,
8 => ast::ScalarType::F64,
_ => unreachable!(),
},
- ScalarKind::Bit => match width {
+ ast::ScalarKind::Bit => match width {
1 => ast::ScalarType::B8,
2 => ast::ScalarType::B16,
4 => ast::ScalarType::B32,
8 => ast::ScalarType::B64,
_ => unreachable!(),
},
- ScalarKind::Signed => match width {
+ ast::ScalarKind::Signed => match width {
1 => ast::ScalarType::S8,
2 => ast::ScalarType::S16,
4 => ast::ScalarType::S32,
8 => ast::ScalarType::S64,
_ => unreachable!(),
},
- ScalarKind::Unsigned => match width {
+ ast::ScalarKind::Unsigned => match width {
1 => ast::ScalarType::U8,
2 => ast::ScalarType::U16,
4 => ast::ScalarType::U32,
8 => ast::ScalarType::U64,
_ => unreachable!(),
},
- ScalarKind::Float2 => match width {
+ ast::ScalarKind::Float2 => match width {
4 => ast::ScalarType::F16x2,
_ => unreachable!(),
},
- ScalarKind::Pred => ast::ScalarType::Pred,
+ ast::ScalarKind::Pred => ast::ScalarType::Pred,
}
}
}
-impl ast::BooleanType {
- fn to_type(self) -> ast::Type {
- match self {
- ast::BooleanType::Pred => ast::Type::Scalar(ast::ScalarType::Pred),
- ast::BooleanType::B16 => ast::Type::Scalar(ast::ScalarType::B16),
- ast::BooleanType::B32 => ast::Type::Scalar(ast::ScalarType::B32),
- ast::BooleanType::B64 => ast::Type::Scalar(ast::ScalarType::B64),
- }
- }
-}
-
-impl ast::ShlType {
- fn to_type(self) -> ast::Type {
- match self {
- ast::ShlType::B16 => ast::Type::Scalar(ast::ScalarType::B16),
- ast::ShlType::B32 => ast::Type::Scalar(ast::ScalarType::B32),
- ast::ShlType::B64 => ast::Type::Scalar(ast::ScalarType::B64),
- }
- }
-}
-
-impl ast::ShrType {
- fn signed(&self) -> bool {
+impl ast::ArithDetails {
+ pub(crate) fn get_type(&self) -> ast::ScalarType {
match self {
- ast::ShrType::S16 | ast::ShrType::S32 | ast::ShrType::S64 => true,
- _ => false,
+ ast::ArithDetails::Unsigned(t) => *t,
+ ast::ArithDetails::Signed(d) => d.typ,
+ ast::ArithDetails::Float(d) => d.typ,
}
}
}
-impl ast::ArithDetails {
- fn get_type(&self) -> ast::Type {
- ast::Type::Scalar(match self {
- ast::ArithDetails::Unsigned(t) => (*t).into(),
- ast::ArithDetails::Signed(d) => d.typ.into(),
- ast::ArithDetails::Float(d) => d.typ.into(),
- })
- }
-}
-
impl ast::MulDetails {
fn get_type(&self) -> ast::Type {
ast::Type::Scalar(match self {
@@ -7025,12 +8806,12 @@ impl ast::MulDetails {
}
impl ast::MinMaxDetails {
- fn get_type(&self) -> ast::Type {
- ast::Type::Scalar(match self {
- ast::MinMaxDetails::Signed(t) => (*t).into(),
- ast::MinMaxDetails::Unsigned(t) => (*t).into(),
- ast::MinMaxDetails::Float(d) => d.typ.into(),
- })
+ pub(crate) fn get_type(&self) -> ast::ScalarType {
+ match self {
+ ast::MinMaxDetails::Signed(t) => *t,
+ ast::MinMaxDetails::Unsigned(t) => *t,
+ ast::MinMaxDetails::Float(d) => d.typ,
+ }
}
}
@@ -7055,60 +8836,35 @@ impl ast::AtomInnerDetails {
}
}
-impl ast::SIntType {
- fn from_size(width: u8) -> Self {
- match width {
- 1 => ast::SIntType::S8,
- 2 => ast::SIntType::S16,
- 4 => ast::SIntType::S32,
- 8 => ast::SIntType::S64,
- _ => unreachable!(),
- }
+impl ast::StateSpace {
+ fn is_compatible(self, other: ast::StateSpace) -> bool {
+ self == other
+ || self == ast::StateSpace::Reg && other == ast::StateSpace::Sreg
+ || self == ast::StateSpace::Sreg && other == ast::StateSpace::Reg
}
-}
-impl ast::UIntType {
- fn from_size(width: u8) -> Self {
- match width {
- 1 => ast::UIntType::U8,
- 2 => ast::UIntType::U16,
- 4 => ast::UIntType::U32,
- 8 => ast::UIntType::U64,
- _ => unreachable!(),
- }
- }
-}
-
-impl ast::LdStateSpace {
- fn to_spirv(self) -> spirv::StorageClass {
+ fn coerces_to_generic(self) -> bool {
match self {
- ast::LdStateSpace::Const => spirv::StorageClass::UniformConstant,
- ast::LdStateSpace::Generic => spirv::StorageClass::Generic,
- ast::LdStateSpace::Global => spirv::StorageClass::CrossWorkgroup,
- ast::LdStateSpace::Local => spirv::StorageClass::Function,
- ast::LdStateSpace::Shared => spirv::StorageClass::Workgroup,
- ast::LdStateSpace::Param => spirv::StorageClass::Function,
- }
- }
-}
-
-impl From<ast::FnArgumentType> for ast::VariableType {
- fn from(t: ast::FnArgumentType) -> Self {
- match t {
- ast::FnArgumentType::Reg(t) => ast::VariableType::Reg(t),
- ast::FnArgumentType::Param(t) => ast::VariableType::Param(t),
- ast::FnArgumentType::Shared => todo!(),
+ ast::StateSpace::Global
+ | ast::StateSpace::Const
+ | ast::StateSpace::Local
+ | ast::StateSpace::Shared => true,
+ ast::StateSpace::Reg
+ | ast::StateSpace::Param
+ | ast::StateSpace::Generic
+ | ast::StateSpace::Sreg => false,
}
}
-}
-impl<T> ast::Operand<T> {
- fn underlying(&self) -> Option<&T> {
+ fn is_addressable(self) -> bool {
match self {
- ast::Operand::Reg(r) | ast::Operand::RegOffset(r, _) => Some(r),
- ast::Operand::Imm(_) => None,
- ast::Operand::VecMember(reg, _) => Some(reg),
- ast::Operand::VecPack(..) => None,
+ ast::StateSpace::Const
+ | ast::StateSpace::Generic
+ | ast::StateSpace::Global
+ | ast::StateSpace::Local
+ | ast::StateSpace::Shared
+ | ast::StateSpace::Param => true,
+ ast::StateSpace::Reg | ast::StateSpace::Sreg => false,
}
}
}
@@ -7123,140 +8879,115 @@ impl ast::MulDetails {
}
}
-impl ast::AtomSpace {
- fn to_ld_ss(self) -> ast::LdStateSpace {
- match self {
- ast::AtomSpace::Generic => ast::LdStateSpace::Generic,
- ast::AtomSpace::Global => ast::LdStateSpace::Global,
- ast::AtomSpace::Shared => ast::LdStateSpace::Shared,
- }
- }
-}
-
-impl ast::MemScope {
- fn to_spirv(self) -> spirv::Scope {
- match self {
- ast::MemScope::Cta => spirv::Scope::Workgroup,
- ast::MemScope::Gpu => spirv::Scope::Device,
- ast::MemScope::Sys => spirv::Scope::CrossDevice,
+impl ast::SurfaceDetails {
+ fn suffix(&self) -> &'static str {
+ match self.direct {
+ true => "",
+ false => "indirect_",
}
}
}
-impl ast::AtomSemantics {
- fn to_spirv(self) -> spirv::MemorySemantics {
- match self {
- ast::AtomSemantics::Relaxed => spirv::MemorySemantics::RELAXED,
- ast::AtomSemantics::Acquire => spirv::MemorySemantics::ACQUIRE,
- ast::AtomSemantics::Release => spirv::MemorySemantics::RELEASE,
- ast::AtomSemantics::AcquireRelease => spirv::MemorySemantics::ACQUIRE_RELEASE,
+impl ast::TexDetails {
+ fn suffix(&self) -> &'static str {
+ match self.direct {
+ true => "",
+ false => "_indirect",
}
}
}
-impl ast::FnArgumentType {
- fn semantics(&self) -> ArgumentSemantics {
- match self {
- ast::FnArgumentType::Reg(_) => ArgumentSemantics::Default,
- ast::FnArgumentType::Param(_) => ArgumentSemantics::RegisterPointer,
- ast::FnArgumentType::Shared => ArgumentSemantics::PhysicalPointer,
- }
- }
-}
-
-fn bitcast_register_pointer(
- operand_type: &ast::Type,
- instr_type: &ast::Type,
- ss: Option<ast::LdStateSpace>,
+fn default_implicit_conversion(
+ (operand_space, operand_type): (ast::StateSpace, &ast::Type),
+ (instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
) -> Result<Option<ConversionKind>, TranslateError> {
- bitcast_physical_pointer(operand_type, instr_type, ss)
+ if !instruction_space.is_compatible(operand_space) {
+ default_implicit_conversion_space(
+ (operand_space, operand_type),
+ (instruction_space, instruction_type),
+ )
+ } else if instruction_type != operand_type {
+ default_implicit_conversion_type(instruction_space, operand_type, instruction_type)
+ } else {
+ Ok(None)
+ }
}
-fn bitcast_physical_pointer(
- operand_type: &ast::Type,
- instr_type: &ast::Type,
- ss: Option<ast::LdStateSpace>,
+// Space is different
+fn default_implicit_conversion_space(
+ (operand_space, operand_type): (ast::StateSpace, &ast::Type),
+ (instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
) -> Result<Option<ConversionKind>, TranslateError> {
- match operand_type {
- // array decays to a pointer
- ast::Type::Array(op_scalar_t, _) => {
- if let ast::Type::Pointer(instr_scalar_t, instr_space) = instr_type {
- if ss == Some(*instr_space) {
- if ast::Type::Scalar(*op_scalar_t) == ast::Type::from(instr_scalar_t.clone()) {
- Ok(None)
- } else {
- Ok(Some(ConversionKind::PtrToPtr { spirv_ptr: false }))
- }
+ if (instruction_space == ast::StateSpace::Generic && operand_space.coerces_to_generic())
+ || (operand_space == ast::StateSpace::Generic && instruction_space.coerces_to_generic())
+ {
+ Ok(Some(ConversionKind::PtrToPtr))
+ } else if operand_space.is_compatible(ast::StateSpace::Reg) {
+ match operand_type {
+ ast::Type::Pointer(operand_ptr_type, operand_ptr_space)
+ if *operand_ptr_space == instruction_space =>
+ {
+ if instruction_type != &ast::Type::Scalar(*operand_ptr_type) {
+ Ok(Some(ConversionKind::PtrToPtr))
} else {
- if ss == Some(ast::LdStateSpace::Generic)
- || *instr_space == ast::LdStateSpace::Generic
- {
- Ok(Some(ConversionKind::PtrToPtr { spirv_ptr: false }))
- } else {
- Err(TranslateError::MismatchedType)
- }
+ Ok(None)
}
- } else {
- Err(TranslateError::MismatchedType)
}
- }
- ast::Type::Scalar(ast::ScalarType::B64)
- | ast::Type::Scalar(ast::ScalarType::U64)
- | ast::Type::Scalar(ast::ScalarType::S64) => {
- if let Some(space) = ss {
- Ok(Some(ConversionKind::BitToPtr(space)))
- } else {
- Err(error_unreachable())
- }
- }
- ast::Type::Scalar(ast::ScalarType::B32)
- | ast::Type::Scalar(ast::ScalarType::U32)
- | ast::Type::Scalar(ast::ScalarType::S32) => match ss {
- Some(ast::LdStateSpace::Shared)
- | Some(ast::LdStateSpace::Generic)
- | Some(ast::LdStateSpace::Param)
- | Some(ast::LdStateSpace::Local) => {
- Ok(Some(ConversionKind::BitToPtr(ast::LdStateSpace::Shared)))
- }
- _ => Err(TranslateError::MismatchedType),
- },
- ast::Type::Pointer(op_scalar_t, op_space) => {
- if let ast::Type::Pointer(instr_scalar_t, instr_space) = instr_type {
- if op_space == instr_space {
- if op_scalar_t == instr_scalar_t {
- Ok(None)
- } else {
- Ok(Some(ConversionKind::PtrToPtr { spirv_ptr: false }))
- }
+ // TODO: 32 bit
+ ast::Type::Scalar(ast::ScalarType::B64)
+ | ast::Type::Scalar(ast::ScalarType::U64)
+ | ast::Type::Scalar(ast::ScalarType::S64) => match instruction_space {
+ ast::StateSpace::Global
+ | ast::StateSpace::Generic
+ | ast::StateSpace::Const
+ | ast::StateSpace::Local
+ | ast::StateSpace::Param
+ | ast::StateSpace::Shared => Ok(Some(ConversionKind::BitToPtr)),
+ _ => Err(TranslateError::mismatched_type()),
+ },
+ ast::Type::Scalar(ast::ScalarType::B32)
+ | ast::Type::Scalar(ast::ScalarType::U32)
+ | ast::Type::Scalar(ast::ScalarType::S32) => match instruction_space {
+ ast::StateSpace::Const | ast::StateSpace::Local | ast::StateSpace::Shared => {
+ Ok(Some(ConversionKind::BitToPtr))
+ }
+ _ => Err(TranslateError::mismatched_type()),
+ },
+ _ => Err(TranslateError::mismatched_type()),
+ }
+ } else if instruction_space.is_compatible(ast::StateSpace::Reg) {
+ match instruction_type {
+ ast::Type::Pointer(instruction_ptr_type, instruction_ptr_space)
+ if operand_space == *instruction_ptr_space =>
+ {
+ if operand_type != &ast::Type::Scalar(*instruction_ptr_type) {
+ Ok(Some(ConversionKind::PtrToPtr))
} else {
- if *op_space == ast::LdStateSpace::Generic
- || *instr_space == ast::LdStateSpace::Generic
- {
- Ok(Some(ConversionKind::PtrToPtr { spirv_ptr: false }))
- } else {
- Err(TranslateError::MismatchedType)
- }
+ Ok(None)
}
- } else {
- Err(TranslateError::MismatchedType)
}
+ _ => Err(TranslateError::mismatched_type()),
}
- _ => Err(TranslateError::MismatchedType),
+ } else {
+ Err(TranslateError::mismatched_type())
}
}
-fn force_bitcast_ptr_to_bit(
- _: &ast::Type,
- instr_type: &ast::Type,
- _: Option<ast::LdStateSpace>,
+// Space is same, but type is different
+fn default_implicit_conversion_type(
+ space: ast::StateSpace,
+ operand_type: &ast::Type,
+ instruction_type: &ast::Type,
) -> Result<Option<ConversionKind>, TranslateError> {
- // TODO: verify this on f32, u16 and the like
- if let ast::Type::Scalar(scalar_t) = instr_type {
- if let Ok(int_type) = (*scalar_t).try_into() {
- return Ok(Some(ConversionKind::PtrToBit(int_type)));
+ if space.is_compatible(ast::StateSpace::Reg) {
+ if should_bitcast(instruction_type, operand_type) {
+ Ok(Some(ConversionKind::Default))
+ } else {
+ Err(TranslateError::mismatched_type())
}
+ } else {
+ Ok(Some(ConversionKind::PtrToPtr))
}
- Err(TranslateError::MismatchedType)
}
fn should_bitcast(instr: &ast::Type, operand: &ast::Type) -> bool {
@@ -7266,18 +8997,28 @@ fn should_bitcast(instr: &ast::Type, operand: &ast::Type) -> bool {
return false;
}
match inst.kind() {
- ScalarKind::Bit => operand.kind() != ScalarKind::Bit,
- ScalarKind::Float => operand.kind() == ScalarKind::Bit,
- ScalarKind::Signed => {
- operand.kind() == ScalarKind::Bit || operand.kind() == ScalarKind::Unsigned
+ ast::ScalarKind::Bit => operand.kind() != ast::ScalarKind::Bit,
+ ast::ScalarKind::Float => operand.kind() == ast::ScalarKind::Bit,
+ ast::ScalarKind::Signed => {
+ operand.kind() == ast::ScalarKind::Bit
+ || operand.kind() == ast::ScalarKind::Unsigned
}
- ScalarKind::Unsigned => {
- operand.kind() == ScalarKind::Bit || operand.kind() == ScalarKind::Signed
+ ast::ScalarKind::Unsigned => {
+ operand.kind() == ast::ScalarKind::Bit
+ || operand.kind() == ast::ScalarKind::Signed
+ || operand.kind() == ast::ScalarKind::Float2
}
- ScalarKind::Float2 => false,
- ScalarKind::Pred => false,
+ ast::ScalarKind::Float2 => {
+ operand.kind() == ast::ScalarKind::Bit
+ || operand.kind() == ast::ScalarKind::Unsigned
+ }
+ ast::ScalarKind::Pred => false,
}
}
+ (ast::Type::Scalar(scalar), ast::Type::Vector(vector, width))
+ | (ast::Type::Vector(vector, width), ast::Type::Scalar(scalar)) => {
+ scalar.kind() == ast::ScalarKind::Bit && *width * vector.size_of() == scalar.size_of()
+ }
(ast::Type::Vector(inst, _), ast::Type::Vector(operand, _))
| (ast::Type::Array(inst, _), ast::Type::Array(operand, _)) => {
should_bitcast(&ast::Type::Scalar(*inst), &ast::Type::Scalar(*operand))
@@ -7286,49 +9027,61 @@ fn should_bitcast(instr: &ast::Type, operand: &ast::Type) -> bool {
}
}
-fn should_bitcast_packed(
- operand: &ast::Type,
- instr: &ast::Type,
- ss: Option<ast::LdStateSpace>,
+fn implicit_conversion_mov(
+ (operand_space, operand_type): (ast::StateSpace, &ast::Type),
+ (instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
) -> Result<Option<ConversionKind>, TranslateError> {
- if let (ast::Type::Vector(vec_underlying_type, vec_len), ast::Type::Scalar(scalar)) =
- (operand, instr)
- {
- if scalar.kind() == ScalarKind::Bit
- && scalar.size_of() == (vec_underlying_type.size_of() * vec_len)
+ // instruction_space is always reg
+ if operand_space.is_compatible(ast::StateSpace::Reg) {
+ if let (ast::Type::Vector(vec_underlying_type, vec_len), ast::Type::Scalar(scalar)) =
+ (operand_type, instruction_type)
{
- return Ok(Some(ConversionKind::Default));
+ if scalar.kind() == ast::ScalarKind::Bit
+ && scalar.size_of() == (vec_underlying_type.size_of() * vec_len)
+ {
+ return Ok(Some(ConversionKind::Default));
+ }
}
+ // TODO: verify .params addressability:
+ // * kernel arg
+ // * func arg
+ // * variable
}
- should_bitcast_wrapper(operand, instr, ss)
+ if is_addressable(operand_type, operand_space) {
+ return Ok(Some(ConversionKind::AddressOf));
+ }
+ default_implicit_conversion(
+ (operand_space, operand_type),
+ (instruction_space, instruction_type),
+ )
}
-fn should_bitcast_wrapper(
- operand: &ast::Type,
- instr: &ast::Type,
- _: Option<ast::LdStateSpace>,
-) -> Result<Option<ConversionKind>, TranslateError> {
- if instr == operand {
- return Ok(None);
+fn is_addressable(type_: &ast::Type, state_space: ast::StateSpace) -> bool {
+ if state_space.is_addressable() {
+ return true;
}
- if should_bitcast(instr, operand) {
- Ok(Some(ConversionKind::Default))
- } else {
- Err(TranslateError::MismatchedType)
+ if !state_space.is_compatible(ast::StateSpace::Reg) {
+ return false;
+ }
+ match type_ {
+ ast::Type::Pointer(_, space) => space.is_addressable(),
+ _ => false,
}
}
fn should_convert_relaxed_src_wrapper(
- src_type: &ast::Type,
- instr_type: &ast::Type,
- _: Option<ast::LdStateSpace>,
+ (operand_space, operand_type): (ast::StateSpace, &ast::Type),
+ (instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
) -> Result<Option<ConversionKind>, TranslateError> {
- if src_type == instr_type {
+ if !operand_space.is_compatible(instruction_space) {
+ return Err(TranslateError::mismatched_type());
+ }
+ if operand_type == instruction_type {
return Ok(None);
}
- match should_convert_relaxed_src(src_type, instr_type) {
+ match should_convert_relaxed_src(operand_type, instruction_type) {
conv @ Some(_) => Ok(conv),
- None => Err(TranslateError::MismatchedType),
+ None => Err(TranslateError::mismatched_type()),
}
}
@@ -7342,32 +9095,33 @@ fn should_convert_relaxed_src(
}
match (src_type, instr_type) {
(ast::Type::Scalar(src_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() {
- ScalarKind::Bit => {
+ ast::ScalarKind::Bit => {
if instr_type.size_of() <= src_type.size_of() {
Some(ConversionKind::Default)
} else {
None
}
}
- ScalarKind::Signed | ScalarKind::Unsigned => {
+ ast::ScalarKind::Signed | ast::ScalarKind::Unsigned => {
if instr_type.size_of() <= src_type.size_of()
- && src_type.kind() != ScalarKind::Float
+ && src_type.kind() != ast::ScalarKind::Float
{
Some(ConversionKind::Default)
} else {
None
}
}
- ScalarKind::Float => {
- if instr_type.size_of() <= src_type.size_of() && src_type.kind() == ScalarKind::Bit
+ ast::ScalarKind::Float => {
+ if instr_type.size_of() <= src_type.size_of()
+ && src_type.kind() == ast::ScalarKind::Bit
{
Some(ConversionKind::Default)
} else {
None
}
}
- ScalarKind::Float2 => todo!(),
- ScalarKind::Pred => None,
+ ast::ScalarKind::Float2 => todo!(),
+ ast::ScalarKind::Pred => None,
},
(ast::Type::Vector(dst_type, _), ast::Type::Vector(instr_type, _))
| (ast::Type::Array(dst_type, _), ast::Type::Array(instr_type, _)) => {
@@ -7381,16 +9135,18 @@ fn should_convert_relaxed_src(
}
fn should_convert_relaxed_dst_wrapper(
- dst_type: &ast::Type,
- instr_type: &ast::Type,
- _: Option<ast::LdStateSpace>,
+ (operand_space, operand_type): (ast::StateSpace, &ast::Type),
+ (instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
) -> Result<Option<ConversionKind>, TranslateError> {
- if dst_type == instr_type {
+ if !operand_space.is_compatible(instruction_space) {
+ return Err(TranslateError::mismatched_type());
+ }
+ if operand_type == instruction_type {
return Ok(None);
}
- match should_convert_relaxed_dst(dst_type, instr_type) {
+ match should_convert_relaxed_dst(operand_type, instruction_type) {
conv @ Some(_) => Ok(conv),
- None => Err(TranslateError::MismatchedType),
+ None => Err(TranslateError::mismatched_type()),
}
}
@@ -7404,15 +9160,15 @@ fn should_convert_relaxed_dst(
}
match (dst_type, instr_type) {
(ast::Type::Scalar(dst_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() {
- ScalarKind::Bit => {
+ ast::ScalarKind::Bit => {
if instr_type.size_of() <= dst_type.size_of() {
Some(ConversionKind::Default)
} else {
None
}
}
- ScalarKind::Signed => {
- if dst_type.kind() != ScalarKind::Float {
+ ast::ScalarKind::Signed => {
+ if dst_type.kind() != ast::ScalarKind::Float {
if instr_type.size_of() == dst_type.size_of() {
Some(ConversionKind::Default)
} else if instr_type.size_of() < dst_type.size_of() {
@@ -7424,25 +9180,26 @@ fn should_convert_relaxed_dst(
None
}
}
- ScalarKind::Unsigned => {
+ ast::ScalarKind::Unsigned => {
if instr_type.size_of() <= dst_type.size_of()
- && dst_type.kind() != ScalarKind::Float
+ && dst_type.kind() != ast::ScalarKind::Float
{
Some(ConversionKind::Default)
} else {
None
}
}
- ScalarKind::Float => {
- if instr_type.size_of() <= dst_type.size_of() && dst_type.kind() == ScalarKind::Bit
+ ast::ScalarKind::Float => {
+ if instr_type.size_of() <= dst_type.size_of()
+ && dst_type.kind() == ast::ScalarKind::Bit
{
Some(ConversionKind::Default)
} else {
None
}
}
- ScalarKind::Float2 => todo!(),
- ScalarKind::Pred => None,
+ ast::ScalarKind::Float2 => todo!(),
+ ast::ScalarKind::Pred => None,
},
(ast::Type::Vector(dst_type, _), ast::Type::Vector(instr_type, _))
| (ast::Type::Array(dst_type, _), ast::Type::Array(instr_type, _)) => {
@@ -7455,77 +9212,35 @@ fn should_convert_relaxed_dst(
}
}
-impl<'a> ast::MethodDecl<'a, &'a str> {
+impl<'a> ast::MethodDeclaration<'a, &'a str> {
fn name(&self) -> &'a str {
- match self {
- ast::MethodDecl::Kernel { name, .. } => name,
- ast::MethodDecl::Func(_, name, _) => name,
+ match self.name {
+ ast::MethodName::Kernel(name) => name,
+ ast::MethodName::Func(name) => name,
}
}
}
-struct SpirvMethodDecl<'input> {
- input: Vec<ast::Variable<ast::Type, spirv::Word>>,
- output: Vec<ast::Variable<ast::Type, spirv::Word>>,
- name: MethodName<'input>,
- uses_shared_mem: bool,
+#[derive(Copy, Clone)]
+pub(crate) enum ConstType<'a> {
+ Type(&'a ast::Type),
+ ArraySubtype(ast::ScalarType, &'a [u32]),
}
-impl<'input> SpirvMethodDecl<'input> {
- fn new(ast_decl: &ast::MethodDecl<'input, spirv::Word>) -> Self {
- let (input, output) = match ast_decl {
- ast::MethodDecl::Kernel { in_args, .. } => {
- let spirv_input = in_args
- .iter()
- .map(|var| {
- let v_type = match &var.v_type {
- ast::KernelArgumentType::Normal(t) => {
- ast::FnArgumentType::Param(t.clone())
- }
- ast::KernelArgumentType::Shared => ast::FnArgumentType::Shared,
- };
- ast::Variable {
- name: var.name,
- align: var.align,
- v_type: v_type.to_kernel_type(),
- array_init: var.array_init.clone(),
- }
- })
- .collect();
- (spirv_input, Vec::new())
- }
- ast::MethodDecl::Func(out_args, _, in_args) => {
- let (param_output, non_param_output): (Vec<_>, Vec<_>) =
- out_args.iter().partition(|var| var.v_type.is_param());
- let spirv_output = non_param_output
- .into_iter()
- .cloned()
- .map(|var| ast::Variable {
- name: var.name,
- align: var.align,
- v_type: var.v_type.to_func_type(),
- array_init: var.array_init.clone(),
- })
- .collect();
- let spirv_input = param_output
- .into_iter()
- .cloned()
- .chain(in_args.iter().cloned())
- .map(|var| ast::Variable {
- name: var.name,
- align: var.align,
- v_type: var.v_type.to_func_type(),
- array_init: var.array_init.clone(),
- })
- .collect();
- (spirv_input, spirv_output)
- }
- };
- SpirvMethodDecl {
- input,
- output,
- name: MethodName::new(ast_decl),
- uses_shared_mem: false,
+impl ast::ScalarType {
+ pub fn is_integer(self) -> bool {
+ match self.kind() {
+ ast::ScalarKind::Unsigned | ast::ScalarKind::Signed | ast::ScalarKind::Bit => true,
+ ast::ScalarKind::Float | ast::ScalarKind::Float2 | ast::ScalarKind::Pred => false,
+ }
+ }
+}
+
+impl ast::ReductionOp {
+ fn dst_type(self) -> ast::ScalarType {
+ match self {
+ ast::ReductionOp::And | ast::ReductionOp::Or => ast::ScalarType::Pred,
+ ast::ReductionOp::Popc => ast::ScalarType::U32,
}
}
}
@@ -7639,4 +9354,39 @@ mod tests {
fn should_convert_relaxed_dst_all_combinations() {
assert_conversion_table(RELAXED_DST_CONVERSION_TABLE, should_convert_relaxed_dst);
}
+
+ #[test]
+ fn returns_correct_layout_for_array_params() {
+ use crate::{ModuleParser, ModuleParserExt};
+
+ let ptx = r#"
+ .version 6.5
+ .target sm_30
+ .address_size 64
+
+ .visible .entry kernel(.param .align 8 .b8 kernel_param[72])
+ {
+ ret;
+ }"#;
+
+ let ast = ModuleParser::parse_checked(ptx).unwrap();
+ if let ast::Directive::Method(
+ _,
+ ast::Function {
+ func_directive:
+ ast::MethodDeclaration {
+ input_arguments, ..
+ },
+ ..
+ },
+ ) = &ast.directives[0]
+ {
+ assert_eq!(input_arguments.len(), 1);
+ assert_eq!(input_arguments[0].layout(), unsafe {
+ Layout::from_size_align_unchecked(72, 8)
+ });
+ } else {
+ panic!()
+ }
+ }
}