diff options
author | Andrzej Janik <[email protected]> | 2020-10-02 20:34:45 +0200 |
---|---|---|
committer | Andrzej Janik <[email protected]> | 2020-10-04 19:53:07 +0200 |
commit | 27d25865af2bf51ca55b223e634208234d1a141a (patch) | |
tree | 695f081f09cd22ffbc04effa0e947abba82bc50d | |
parent | 9a65dd32f5898eb9dd3edf7cdddb1513a7a754ed (diff) | |
download | ZLUDA-27d25865af2bf51ca55b223e634208234d1a141a.tar.gz ZLUDA-27d25865af2bf51ca55b223e634208234d1a141a.zip |
Add support for top-level global variables, improve array support
-rw-r--r-- | level_zero/src/ze.rs | 57 | ||||
-rw-r--r-- | notcuda/src/impl/function.rs | 2 | ||||
-rw-r--r-- | notcuda/src/impl/memory.rs | 2 | ||||
-rw-r--r-- | notcuda/src/impl/mod.rs | 2 | ||||
-rw-r--r-- | notcuda/src/impl/module.rs | 11 | ||||
-rw-r--r-- | notcuda/src/impl/stream.rs | 4 | ||||
-rw-r--r-- | notcuda/src/impl/test.rs | 2 | ||||
-rw-r--r-- | ptx/src/ast.rs | 265 | ||||
-rw-r--r-- | ptx/src/ptx.lalrpop | 153 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/global_array.ptx | 22 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/global_array.spvtxt | 54 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/mod.rs | 18 | ||||
-rw-r--r-- | ptx/src/translate.rs | 871 |
13 files changed, 1085 insertions, 378 deletions
diff --git a/level_zero/src/ze.rs b/level_zero/src/ze.rs index 559805e..5ced5d0 100644 --- a/level_zero/src/ze.rs +++ b/level_zero/src/ze.rs @@ -1,6 +1,6 @@ use crate::sys;
use std::{
- ffi::{c_void, CStr},
+ ffi::{c_void, CStr, CString},
fmt::Debug,
marker::PhantomData,
mem, ptr,
@@ -238,23 +238,16 @@ impl Drop for CommandQueue { pub struct Module(sys::ze_module_handle_t);
impl Module {
- pub unsafe fn as_ffi(&self) -> sys::ze_module_handle_t {
- self.0
- }
- pub unsafe fn from_ffi(x: sys::ze_module_handle_t) -> Self {
- Self(x)
- }
-
pub fn new_spirv(
ctx: &mut Context,
d: &Device,
bin: &[u8],
opts: Option<&CStr>,
- ) -> Result<Self> {
+ ) -> (Result<Self>, BuildLog) {
Module::new(ctx, true, d, bin, opts)
}
- pub fn new_native(ctx: &mut Context, d: &Device, bin: &[u8]) -> Result<Self> {
+ pub fn new_native(ctx: &mut Context, d: &Device, bin: &[u8]) -> (Result<Self>, BuildLog) {
Module::new(ctx, false, d, bin, None)
}
@@ -264,7 +257,7 @@ impl Module { d: &Device,
bin: &[u8],
opts: Option<&CStr>,
- ) -> Result<Self> {
+ ) -> (Result<Self>, BuildLog) {
let desc = sys::ze_module_desc_t {
stype: sys::ze_structure_type_t::ZE_STRUCTURE_TYPE_MODULE_DESC,
pNext: ptr::null(),
@@ -279,14 +272,14 @@ impl Module { pConstants: ptr::null(),
};
let mut result: sys::ze_module_handle_t = ptr::null_mut();
- check!(sys::zeModuleCreate(
- ctx.0,
- d.0,
- &desc,
- &mut result,
- ptr::null_mut()
- ));
- Ok(Module(result))
+ let mut log_handle = ptr::null_mut();
+ let err = unsafe { sys::zeModuleCreate(ctx.0, d.0, &desc, &mut result, &mut log_handle) };
+ let log = BuildLog(log_handle);
+ if err != crate::sys::ze_result_t::ZE_RESULT_SUCCESS {
+ (Result::Err(err), log)
+ } else {
+ (Ok(Module(result)), log)
+ }
}
}
@@ -297,6 +290,32 @@ impl Drop for Module { }
}
+pub struct BuildLog(sys::ze_module_build_log_handle_t);
+
+impl BuildLog {
+ pub unsafe fn as_ffi(&self) -> sys::ze_module_build_log_handle_t {
+ self.0
+ }
+ pub unsafe fn from_ffi(x: sys::ze_module_build_log_handle_t) -> Self {
+ Self(x)
+ }
+
+ pub fn get_cstring(&self) -> Result<CString> {
+ let mut size = 0;
+ check! { sys::zeModuleBuildLogGetString(self.0, &mut size, ptr::null_mut()) };
+ let mut str_vec = vec![0u8; size];
+ check! { sys::zeModuleBuildLogGetString(self.0, &mut size, str_vec.as_mut_ptr() as *mut i8) };
+ str_vec.pop();
+ Ok(CString::new(str_vec).map_err(|_| sys::ze_result_t::ZE_RESULT_ERROR_UNKNOWN)?)
+ }
+}
+
+impl Drop for BuildLog {
+ fn drop(&mut self) {
+ check_panic!(sys::zeModuleBuildLogDestroy(self.0));
+ }
+}
+
pub trait SafeRepr {}
impl SafeRepr for u8 {}
impl SafeRepr for i8 {}
diff --git a/notcuda/src/impl/function.rs b/notcuda/src/impl/function.rs index 6f8773e..0ab3bea 100644 --- a/notcuda/src/impl/function.rs +++ b/notcuda/src/impl/function.rs @@ -1,7 +1,7 @@ use ::std::os::raw::{c_uint, c_void}; use std::ptr; -use super::{context, device, stream::Stream, CUresult}; +use super::{device, stream::Stream, CUresult}; pub struct Function { pub base: l0::Kernel<'static>, diff --git a/notcuda/src/impl/memory.rs b/notcuda/src/impl/memory.rs index 3f92b5e..439b26f 100644 --- a/notcuda/src/impl/memory.rs +++ b/notcuda/src/impl/memory.rs @@ -46,7 +46,7 @@ unsafe fn memcpy_impl( Ok(())
}
-pub(crate) fn free_v2(mem: *mut c_void)-> l0::Result<()> {
+pub(crate) fn free_v2(_: *mut c_void)-> l0::Result<()> {
Ok(())
}
diff --git a/notcuda/src/impl/mod.rs b/notcuda/src/impl/mod.rs index 3d31da2..5a72ce4 100644 --- a/notcuda/src/impl/mod.rs +++ b/notcuda/src/impl/mod.rs @@ -1,4 +1,4 @@ -use crate::cuda::{CUctx_st, CUdevice, CUdeviceptr, CUfunc_st, CUfunction, CUmod_st, CUmodule, CUresult, CUstream, CUstream_st}; +use crate::cuda::{CUctx_st, CUdevice, CUdeviceptr, CUfunc_st, CUmod_st, CUresult, CUstream_st}; use std::{ffi::c_void, mem::{self, ManuallyDrop}, os::raw::c_int, sync::Mutex}; #[cfg(test)] diff --git a/notcuda/src/impl/module.rs b/notcuda/src/impl/module.rs index fc55f33..eea862b 100644 --- a/notcuda/src/impl/module.rs +++ b/notcuda/src/impl/module.rs @@ -1,13 +1,10 @@ use std::{ - collections::HashMap, ffi::c_void, ffi::CStr, ffi::CString, mem, os::raw::c_char, ptr, slice, - sync::Mutex, + collections::HashMap, ffi::CStr, ffi::CString, mem, os::raw::c_char, ptr, slice, sync::Mutex, }; use super::{function::Function, transmute_lifetime, CUresult}; use ptx; -use super::context; - pub type Module = Mutex<ModuleData>; pub struct ModuleData { @@ -67,14 +64,14 @@ impl ModuleData { l0::Module::new_spirv(&mut dev.l0_context, &dev.base, byte_il, None) }); match module { - Ok(Ok(module)) => Ok(Mutex::new(Self { + Ok((Ok(module), _)) => Ok(Mutex::new(Self { base: module, arg_lens: all_arg_lens .into_iter() .map(|(k, v)| (CString::new(k).unwrap(), v)) .collect(), })), - Ok(Err(err)) => Err(ModuleCompileError::from(err)), + Ok((Err(err), _)) => Err(ModuleCompileError::from(err)), Err(err) => Err(ModuleCompileError::from(err)), } } @@ -116,6 +113,6 @@ pub fn get_function( Ok(()) } -pub(crate) fn unload(decuda: *mut Module) -> Result<(), CUresult> { +pub(crate) fn unload(_: *mut Module) -> Result<(), CUresult> { Ok(()) } diff --git a/notcuda/src/impl/stream.rs b/notcuda/src/impl/stream.rs index 7410100..1844677 100644 --- a/notcuda/src/impl/stream.rs +++ b/notcuda/src/impl/stream.rs @@ -30,7 +30,7 @@ mod tests { use super::super::test::CudaDriverFns; use super::super::CUresult; - use std::{ffi::c_void, ptr}; + use std::ptr; const CU_STREAM_LEGACY: CUstream = 1 as *mut _; const CU_STREAM_PER_THREAD: CUstream = 2 as *mut _; @@ -41,7 +41,7 @@ mod tests { fn default_stream_uses_current_ctx_legacy<T: CudaDriverFns>() { default_stream_uses_current_ctx_impl::<T>(CU_STREAM_LEGACY); } - + fn default_stream_uses_current_ctx_ptsd<T: CudaDriverFns>() { default_stream_uses_current_ctx_impl::<T>(CU_STREAM_PER_THREAD); } diff --git a/notcuda/src/impl/test.rs b/notcuda/src/impl/test.rs index d4366b7..dbd2eff 100644 --- a/notcuda/src/impl/test.rs +++ b/notcuda/src/impl/test.rs @@ -1,6 +1,6 @@ #![allow(non_snake_case)] -use crate::{cuda::CUcontext, cuda::CUstream, r#impl as notcuda}; +use crate::{cuda::CUstream, r#impl as notcuda}; use crate::r#impl::CUresult; use crate::{cuda::CUuuid, r#impl::Encuda}; use ::std::{ diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index 048d43a..c6510da 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -1,6 +1,8 @@ -use std::convert::From; +use std::{convert::From, mem, num::ParseFloatError, str::FromStr}; use std::{marker::PhantomData, num::ParseIntError}; +use half::f16; + quick_error! { #[derive(Debug)] pub enum PtxError { @@ -9,11 +11,17 @@ quick_error! { display("{}", err) cause(err) } + ParseFloat (err: ParseFloatError) { + from() + display("{}", err) + cause(err) + } SyntaxError {} NonF32Ftz {} WrongArrayType {} WrongVectorElement {} MultiArrayVariable {} + ZeroDimensionArray {} } } @@ -53,7 +61,7 @@ macro_rules! sub_scalar_type { macro_rules! sub_type { ($type_name:ident { $($variant:ident ( $($field_type:ident),+ ) ),+ $(,)? } ) => { - #[derive(PartialEq, Eq, Clone, Copy)] + #[derive(PartialEq, Eq, Clone)] pub enum $type_name { $( $variant ($($field_type),+), @@ -80,11 +88,13 @@ sub_type! { } } +type VecU32 = Vec<u32>; + sub_type! { VariableLocalType { Scalar(SizedScalarType), Vector(SizedScalarType, u8), - Array(SizedScalarType, u32), + Array(SizedScalarType, VecU32), } } @@ -95,7 +105,7 @@ sub_type! { sub_type! { VariableParamType { Scalar(ParamScalarType), - Array(SizedScalarType, u32), + Array(SizedScalarType, VecU32), } } @@ -169,7 +179,12 @@ impl< pub struct Module<'a> { pub version: (u8, u8), - pub functions: Vec<ParsedFunction<'a>>, + pub directives: Vec<Directive<'a, ParsedArgParams<'a>>>, +} + +pub enum Directive<'a, P: ArgParams> { + Variable(Variable<VariableType, P::Id>), + Method(Function<'a, &'a str, Statement<P>>), } pub enum MethodDecl<'a, ID> { @@ -187,7 +202,7 @@ pub struct Function<'a, ID, S> { pub type ParsedFunction<'a> = Function<'a, &'a str, Statement<ParsedArgParams<'a>>>; -#[derive(PartialEq, Eq, Clone, Copy)] +#[derive(PartialEq, Eq, Clone)] pub enum FnArgumentType { Reg(VariableRegType), Param(VariableParamType), @@ -202,11 +217,11 @@ impl From<FnArgumentType> for Type { } } -#[derive(PartialEq, Eq, Hash, Clone, Copy)] +#[derive(PartialEq, Eq, Hash, Clone)] pub enum Type { Scalar(ScalarType), Vector(ScalarType, u8), - Array(ScalarType, u32), + Array(ScalarType, Vec<u32>), } #[derive(PartialEq, Eq, Hash, Clone, Copy)] @@ -274,6 +289,30 @@ sub_scalar_type!(FloatType { F64 }); +impl ScalarType { + pub fn size_of(self) -> u8 { + match self { + ScalarType::U8 => 1, + ScalarType::S8 => 1, + ScalarType::B8 => 1, + ScalarType::U16 => 2, + ScalarType::S16 => 2, + ScalarType::B16 => 2, + ScalarType::F16 => 2, + ScalarType::U32 => 4, + ScalarType::S32 => 4, + ScalarType::B32 => 4, + ScalarType::F32 => 4, + ScalarType::U64 => 8, + ScalarType::S64 => 8, + ScalarType::B64 => 8, + ScalarType::F64 => 8, + ScalarType::F16x2 => 4, + ScalarType::Pred => 1, + } + } +} + impl Default for ScalarType { fn default() -> Self { ScalarType::B8 @@ -296,13 +335,26 @@ pub struct Variable<T, ID> { pub align: Option<u32>, pub v_type: T, pub name: ID, + pub array_init: Vec<u8>, } -#[derive(Eq, PartialEq, Copy, Clone)] +#[derive(Eq, PartialEq, Clone)] pub enum VariableType { Reg(VariableRegType), Local(VariableLocalType), Param(VariableParamType), + Global(VariableLocalType), +} + +impl VariableType { + pub fn to_type(&self) -> (StateSpace, Type) { + match self { + VariableType::Reg(t) => (StateSpace::Reg, t.clone().into()), + VariableType::Local(t) => (StateSpace::Local, t.clone().into()), + VariableType::Param(t) => (StateSpace::Param, t.clone().into()), + VariableType::Global(t) => (StateSpace::Global, t.clone().into()), + } + } } impl From<VariableType> for Type { @@ -311,6 +363,7 @@ impl From<VariableType> for Type { VariableType::Reg(t) => t.into(), VariableType::Local(t) => t.into(), VariableType::Param(t) => t.into(), + VariableType::Global(t) => t.into(), } } } @@ -318,7 +371,6 @@ impl From<VariableType> for Type { #[derive(Copy, Clone, PartialEq, Eq)] pub enum StateSpace { Reg, - Sreg, Const, Global, Local, @@ -538,7 +590,7 @@ pub enum LdCacheOperator { Uncached, } -#[derive(Copy, Clone)] +#[derive(Clone)] pub struct MovDetails { pub typ: Type, pub src_is_address: bool, @@ -846,3 +898,194 @@ pub struct MinMaxFloat { pub nan: bool, pub typ: FloatType, } + +pub enum NumsOrArrays<'a> { + Nums(Vec<&'a str>), + Arrays(Vec<NumsOrArrays<'a>>), +} + +impl<'a> NumsOrArrays<'a> { + pub fn to_vec(self, typ: SizedScalarType, dimensions: &mut [u32]) -> Result<Vec<u8>, PtxError> { + self.normalize_dimensions(dimensions)?; + let sizeof_t = ScalarType::from(typ).size_of() as usize; + let result_size = dimensions.iter().fold(sizeof_t, |x, y| x * (*y as usize)); + let mut result = vec![0; result_size]; + self.parse_and_copy(typ, sizeof_t, dimensions, &mut result)?; + Ok(result) + } + + fn normalize_dimensions(&self, dimensions: &mut [u32]) -> Result<(), PtxError> { + match dimensions.first_mut() { + Some(first) => { + if *first == 0 { + *first = match self { + NumsOrArrays::Nums(v) => v.len() as u32, + NumsOrArrays::Arrays(v) => v.len() as u32, + }; + } + } + None => return Err(PtxError::ZeroDimensionArray), + } + for dim in dimensions { + if *dim == 0 { + return Err(PtxError::ZeroDimensionArray); + } + } + Ok(()) + } + + fn parse_and_copy( + &self, + t: SizedScalarType, + size_of_t: usize, + dimensions: &[u32], + result: &mut [u8], + ) -> Result<(), PtxError> { + match dimensions { + [] => unreachable!(), + [dim] => match self { + NumsOrArrays::Nums(vec) => { + if vec.len() > *dim as usize { + return Err(PtxError::ZeroDimensionArray); + } + for (idx, val) in vec.iter().enumerate() { + Self::parse_and_copy_single(t, idx, val, result)?; + } + } + NumsOrArrays::Arrays(_) => return Err(PtxError::ZeroDimensionArray), + }, + [first_dim, rest @ ..] => match self { + NumsOrArrays::Arrays(vec) => { + if vec.len() > *first_dim as usize { + return Err(PtxError::ZeroDimensionArray); + } + let size_of_element = rest.iter().fold(size_of_t, |x, y| x * (*y as usize)); + for (idx, this) in vec.iter().enumerate() { + this.parse_and_copy( + t, + size_of_t, + rest, + &mut result[(size_of_element * idx)..], + )?; + } + } + NumsOrArrays::Nums(_) => return Err(PtxError::ZeroDimensionArray), + }, + } + Ok(()) + } + + fn parse_and_copy_single( + t: SizedScalarType, + idx: usize, + str_val: &str, + output: &mut [u8], + ) -> Result<(), PtxError> { + match t { + SizedScalarType::B8 | SizedScalarType::U8 => { + Self::parse_and_copy_single_t::<u8>(idx, str_val, output)?; + } + SizedScalarType::B16 | SizedScalarType::U16 => { + Self::parse_and_copy_single_t::<u16>(idx, str_val, output)?; + } + SizedScalarType::B32 | SizedScalarType::U32 => { + Self::parse_and_copy_single_t::<u32>(idx, str_val, output)?; + } + SizedScalarType::B64 | SizedScalarType::U64 => { + Self::parse_and_copy_single_t::<u64>(idx, str_val, output)?; + } + SizedScalarType::S8 => { + Self::parse_and_copy_single_t::<i8>(idx, str_val, output)?; + } + SizedScalarType::S16 => { + Self::parse_and_copy_single_t::<i16>(idx, str_val, output)?; + } + SizedScalarType::S32 => { + Self::parse_and_copy_single_t::<i32>(idx, str_val, output)?; + } + SizedScalarType::S64 => { + Self::parse_and_copy_single_t::<i64>(idx, str_val, output)?; + } + SizedScalarType::F16 => { + Self::parse_and_copy_single_t::<f16>(idx, str_val, output)?; + } + SizedScalarType::F16x2 => todo!(), + SizedScalarType::F32 => { + Self::parse_and_copy_single_t::<f32>(idx, str_val, output)?; + } + SizedScalarType::F64 => { + Self::parse_and_copy_single_t::<f64>(idx, str_val, output)?; + } + } + Ok(()) + } + + fn parse_and_copy_single_t<T: Copy + FromStr>( + idx: usize, + str_val: &str, + output: &mut [u8], + ) -> Result<(), PtxError> + where + T::Err: Into<PtxError>, + { + let typed_output = unsafe { + std::slice::from_raw_parts_mut::<T>( + output.as_mut_ptr() as *mut _, + output.len() / mem::size_of::<T>(), + ) + }; + typed_output[idx] = str_val.parse::<T>().map_err(|e| e.into())?; + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn array_fails_multiple_0_dmiensions() { + let inp = NumsOrArrays::Nums(Vec::new()); + assert!(inp.to_vec(SizedScalarType::B8, &mut vec![0, 0]).is_err()); + } + + #[test] + fn array_fails_on_empty() { + let inp = NumsOrArrays::Nums(Vec::new()); + assert!(inp.to_vec(SizedScalarType::B8, &mut vec![0]).is_err()); + } + + #[test] + fn array_auto_sizes_0_dimension() { + let inp = NumsOrArrays::Arrays(vec![ + NumsOrArrays::Nums(vec!["1", "2"]), + NumsOrArrays::Nums(vec!["3", "4"]), + ]); + let mut dimensions = vec![0u32, 2]; + assert_eq!( + vec![1u8, 2, 3, 4], + inp.to_vec(SizedScalarType::B8, &mut dimensions).unwrap() + ); + assert_eq!(dimensions, vec![2u32, 2]); + } + + #[test] + fn array_fails_wrong_structure() { + let inp = NumsOrArrays::Arrays(vec![ + NumsOrArrays::Nums(vec!["1", "2"]), + NumsOrArrays::Arrays(vec![NumsOrArrays::Nums(vec!["1"])]), + ]); + let mut dimensions = vec![0u32, 2]; + assert!(inp.to_vec(SizedScalarType::B8, &mut dimensions).is_err()); + } + + #[test] + fn array_fails_too_long_component() { + let inp = NumsOrArrays::Arrays(vec![ + NumsOrArrays::Nums(vec!["1", "2", "3"]), + NumsOrArrays::Nums(vec!["4", "5"]), + ]); + let mut dimensions = vec![0u32, 2]; + assert!(inp.to_vec(SizedScalarType::B8, &mut dimensions).is_err()); + } +} diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index 2c0e365..0b6fa0f 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -2,6 +2,8 @@ use crate::ast; use crate::ast::UnwrapWithVec; use crate::{without_none, vector_index}; +use lalrpop_util::ParseError; + grammar<'a>(errors: &mut Vec<ast::PtxError>); extern { @@ -27,6 +29,7 @@ match { "{", "}", "<", ">", "|", + "=", ".acquire", ".address_size", ".align", @@ -94,7 +97,6 @@ match { ".sat", ".section", ".shared", - ".sreg", ".sys", ".target", ".to", @@ -176,8 +178,8 @@ ExtendedID : &'input str = { } pub Module: ast::Module<'input> = { - <v:Version> Target <f:Directive*> => { - ast::Module { version: v, functions: without_none(f) } + <v:Version> Target <d:Directive*> => { + ast::Module { version: v, directives: without_none(d) } } }; @@ -203,11 +205,12 @@ TargetSpecifier = { "map_f64_to_f32" }; -Directive: Option<ast::Function<'input, &'input str, ast::Statement<ast::ParsedArgParams<'input>>>> = { +Directive: Option<ast::Directive<'input, ast::ParsedArgParams<'input>>> = { AddressSize => None, - <f:Function> => Some(f), + <f:Function> => Some(ast::Directive::Method(f)), File => None, - Section => None + Section => None, + <v:GlobalVariable> ";" => Some(ast::Directive::Variable(v)), }; AddressSize = { @@ -242,9 +245,9 @@ FnArguments: Vec<ast::FnArgument<&'input str>> = { }; KernelInput: ast::Variable<ast::VariableParamType, &'input str> = { - <v:ParamVariable> => { + <v:ParamDeclaration> => { let (align, v_type, name) = v; - ast::Variable{ align, v_type, name } + ast::Variable{ align, v_type, name, array_init: Vec::new() } } } @@ -252,12 +255,12 @@ FnInput: ast::Variable<ast::FnArgumentType, &'input str> = { <v:RegVariable> => { let (align, v_type, name) = v; let v_type = ast::FnArgumentType::Reg(v_type); - ast::Variable{ align, v_type, name } + ast::Variable{ align, v_type, name, array_init: Vec::new() } }, - <v:ParamVariable> => { + <v:ParamDeclaration> => { let (align, v_type, name) = v; let v_type = ast::FnArgumentType::Param(v_type); - ast::Variable{ align, v_type, name } + ast::Variable{ align, v_type, name, array_init: Vec::new() } } } @@ -268,7 +271,6 @@ pub(crate) FunctionBody: Option<Vec<ast::Statement<ast::ParsedArgParams<'input>> StateSpaceSpecifier: ast::StateSpace = { ".reg" => ast::StateSpace::Reg, - ".sreg" => ast::StateSpace::Sreg, ".const" => ast::StateSpace::Const, ".global" => ast::StateSpace::Global, ".local" => ast::StateSpace::Local, @@ -344,13 +346,13 @@ Variable: ast::Variable<ast::VariableType, &'input str> = { <v:RegVariable> => { let (align, v_type, name) = v; let v_type = ast::VariableType::Reg(v_type); - ast::Variable {align, v_type, name} + ast::Variable {align, v_type, name, array_init: Vec::new()} }, LocalVariable, <v:ParamVariable> => { - let (align, v_type, name) = v; + let (align, array_init, v_type, name) = v; let v_type = ast::VariableType::Param(v_type); - ast::Variable {align, v_type, name} + ast::Variable {align, v_type, name, array_init} }, }; @@ -366,32 +368,60 @@ RegVariable: (Option<u32>, ast::VariableRegType, &'input str) = { } LocalVariable: ast::Variable<ast::VariableType, &'input str> = { - ".local" <align:Align?> <t:SizedScalarType> <name:ExtendedID> => { - let v_type = ast::VariableType::Local(ast::VariableLocalType::Scalar(t)); - ast::Variable {align, v_type, name} - }, - ".local" <align:Align?> <v_len:VectorPrefix> <t:SizedScalarType> <name:ExtendedID> => { - let v_type = ast::VariableType::Local(ast::VariableLocalType::Vector(t, v_len)); - ast::Variable {align, v_type, name} - }, - ".local" <align:Align?> <t:SizedScalarType> <name:ExtendedID> <arr:ArraySpecifier> => { - let v_type = ast::VariableType::Local(ast::VariableLocalType::Array(t, arr)); - ast::Variable {align, v_type, name} + ".local" <def:LocalVariableDefinition> => { + let (align, array_init, v_type, name) = def; + ast::Variable { align, v_type: ast::VariableType::Local(v_type), name, array_init } + } +} + +GlobalVariable: ast::Variable<ast::VariableType, &'input str> = { + ".global" <def:LocalVariableDefinition> => { + let (align, array_init, v_type, name) = def; + ast::Variable { align, v_type: ast::VariableType::Global(v_type), name, array_init } } } // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameter-state-space -ParamVariable: (Option<u32>, ast::VariableParamType, &'input str) = { +ParamVariable: (Option<u32>, Vec<u8>, ast::VariableParamType, &'input str) = { + ".param" <align:Align?> <t:ParamScalarType> <name:ExtendedID> => { + let v_type = ast::VariableParamType::Scalar(t); + (align, Vec::new(), v_type, name) + }, + ".param" <align:Align?> <arr:ArrayDefinition> => { + let (array_init, name, (t, dimensions)) = arr; + let v_type = ast::VariableParamType::Array(t, dimensions); + (align, array_init, v_type, name) + } +} + +ParamDeclaration: (Option<u32>, ast::VariableParamType, &'input str) = { ".param" <align:Align?> <t:ParamScalarType> <name:ExtendedID> => { let v_type = ast::VariableParamType::Scalar(t); (align, v_type, name) }, - ".param" <align:Align?> <t:SizedScalarType> <name:ExtendedID> <arr:ArraySpecifier> => { - let v_type = ast::VariableParamType::Array(t, arr); + ".param" <align:Align?> <arr:ArrayDeclaration> => { + let (name, (t, dimensions)) = arr; + let v_type = ast::VariableParamType::Array(t, dimensions); (align, v_type, name) } } +LocalVariableDefinition: (Option<u32>, Vec<u8>, ast::VariableLocalType, &'input str) = { + <align:Align?> <t:SizedScalarType> <name:ExtendedID> => { + let v_type = ast::VariableLocalType::Scalar(t); + (align, Vec::new(), v_type, name) + }, + <align:Align?> <v_len:VectorPrefix> <t:SizedScalarType> <name:ExtendedID> => { + let v_type = ast::VariableLocalType::Vector(t, v_len); + (align, Vec::new(), v_type, name) + }, + <align:Align?> <arr:ArrayDefinition> => { + let (array_init, name, (t, dimensions)) = arr; + let v_type = ast::VariableLocalType::Array(t, dimensions); + (align, array_init, v_type, name) + } +} + #[inline] SizedScalarType: ast::SizedScalarType = { ".b8" => ast::SizedScalarType::B8, @@ -431,12 +461,59 @@ ParamScalarType: ast::ParamScalarType = { ".f64" => ast::ParamScalarType::F64, } -ArraySpecifier: u32 = { - "[" <n:Num> "]" => { - let size = n.parse::<u32>(); - size.unwrap_with(errors) +ArrayDefinition: (Vec<u8>, &'input str, (ast::SizedScalarType, Vec<u32>)) = { + <typ:SizedScalarType> <name:ExtendedID> <dims:ArrayDimensions> <init:ArrayInitializer?> =>? { + let mut dims = dims; + let array_init = init.unwrap_or(ast::NumsOrArrays::Nums(Vec::new())).to_vec(typ, &mut dims)?; + Ok(( + array_init, + name, + (typ, dims) + )) } -}; +} + +ArrayDeclaration: (&'input str, (ast::SizedScalarType, Vec<u32>)) = { + <typ:SizedScalarType> <name:ExtendedID> <dims:ArrayDimension+> =>? { + let dims = dims.into_iter().map(|x| if x > 0 { Ok(x) } else { Err(ParseError::User { error: ast::PtxError::ZeroDimensionArray }) }).collect::<Result<_,_>>()?; + Ok((name, (typ, dims))) + } +} + +// [0] and [] are treated the same +ArrayDimensions: Vec<u32> = { + ArrayEmptyDimension => vec![0u32], + ArrayEmptyDimension <dims:ArrayDimension+> => { + let mut dims = dims; + let mut result = vec![0u32]; + result.append(&mut dims); + result + }, + <dims:ArrayDimension+> => dims +} + +ArrayEmptyDimension = { + "[" "]" +} + +ArrayDimension: u32 = { + "[" <n:Num> "]" =>? { + str::parse::<u32>(n).map_err(|e| ParseError::User { error: ast::PtxError::ParseInt(e) }) + } +} + +ArrayInitializer: ast::NumsOrArrays<'input> = { + "=" <nums:NumsOrArraysBracket> => nums +} + +NumsOrArraysBracket: ast::NumsOrArrays<'input> = { + "{" <nums:NumsOrArrays> "}" => nums +} + +NumsOrArrays: ast::NumsOrArrays<'input> = { + <n:Comma<NumsOrArraysBracket>> => ast::NumsOrArrays::Arrays(n), + <n:CommaNonEmpty<Num>> => ast::NumsOrArrays::Nums(n), +} Instruction: ast::Instruction<ast::ParsedArgParams<'input>> = { InstLd, @@ -1244,3 +1321,11 @@ Comma<T>: Vec<T> = { } } }; + +CommaNonEmpty<T>: Vec<T> = { + <v:(<T> ",")*> <e:T> => { + let mut v = v; + v.push(e); + v + } +}; diff --git a/ptx/src/test/spirv_run/global_array.ptx b/ptx/src/test/spirv_run/global_array.ptx new file mode 100644 index 0000000..7ac8bce --- /dev/null +++ b/ptx/src/test/spirv_run/global_array.ptx @@ -0,0 +1,22 @@ +.version 6.5
+.target sm_30
+.address_size 64
+
+.global .s32 foobar[4] = {1};
+
+.visible .entry global_array(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .u32 temp;
+
+ mov.u64 in_addr, foobar;
+ ld.param.u64 out_addr, [output];
+
+ ld.global.u32 temp, [in_addr];
+ st.global.u32 [out_addr], temp;
+ ret;
+}
\ No newline at end of file diff --git a/ptx/src/test/spirv_run/global_array.spvtxt b/ptx/src/test/spirv_run/global_array.spvtxt new file mode 100644 index 0000000..a4ed91d --- /dev/null +++ b/ptx/src/test/spirv_run/global_array.spvtxt @@ -0,0 +1,54 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int8 + OpCapability Int16 + OpCapability Int64 + OpCapability Float16 + OpCapability Float64 + %22 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %2 "global_array" %1 + %void = OpTypeVoid + %uint = OpTypeInt 32 0 + %uint_4 = OpConstant %uint 4 +%_arr_uint_uint_4 = OpTypeArray %uint %uint_4 +%_ptr_CrossWorkgroup__arr_uint_uint_4 = OpTypePointer CrossWorkgroup %_arr_uint_uint_4 + %uint_4_0 = OpConstant %uint 4 + %uint_1 = OpConstant %uint 1 + %uint_0 = OpConstant %uint 0 + %31 = OpConstantComposite %_arr_uint_uint_4 %uint_1 %uint_0 %uint_0 %uint_0 + %1 = OpVariable %_ptr_CrossWorkgroup__arr_uint_uint_4 CrossWorkgroup %31 + %ulong = OpTypeInt 64 0 + %33 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong +%_ptr_Function_uint = OpTypePointer Function %uint +%_ptr_CrossWorkgroup_uint = OpTypePointer CrossWorkgroup %uint + %2 = OpFunction %void None %33 + %8 = OpFunctionParameter %ulong + %9 = OpFunctionParameter %ulong + %20 = OpLabel + %3 = OpVariable %_ptr_Function_ulong Function + %4 = OpVariable %_ptr_Function_ulong Function + %5 = OpVariable %_ptr_Function_ulong Function + %6 = OpVariable %_ptr_Function_ulong Function + %7 = OpVariable %_ptr_Function_uint Function + OpStore %3 %8 + OpStore %4 %9 + %17 = OpConvertPtrToU %ulong %1 + %10 = OpCopyObject %ulong %17 + OpStore %5 %10 + %12 = OpLoad %ulong %4 + %11 = OpCopyObject %ulong %12 + OpStore %6 %11 + %14 = OpLoad %ulong %5 + %18 = OpConvertUToPtr %_ptr_CrossWorkgroup_uint %14 + %13 = OpLoad %uint %18 + OpStore %7 %13 + %15 = OpLoad %ulong %6 + %16 = OpLoad %uint %7 + %19 = OpConvertUToPtr %_ptr_CrossWorkgroup_uint %15 + OpStore %19 %16 + OpReturn + OpFunctionEnd diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index 8caf540..0c881d9 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -66,14 +66,18 @@ test_ptx!(b64tof64, [111u64], [111u64]); test_ptx!(implicit_param, [34u32], [34u32]);
test_ptx!(pred_not, [10u64, 11u64], [2u64, 0u64]);
test_ptx!(mad_s32, [2i32, 3i32, 4i32], [10i32, 10i32, 10i32]);
-test_ptx!(mul_wide, [0x01_00_00_00__01_00_00_00i64], [0x1_00_00_00_00_00_00i64]);
+test_ptx!(
+ mul_wide,
+ [0x01_00_00_00__01_00_00_00i64],
+ [0x1_00_00_00_00_00_00i64]
+);
test_ptx!(vector_extract, [1u8, 2u8, 3u8, 4u8], [3u8, 4u8, 1u8, 2u8]);
test_ptx!(shr, [-2i32], [-1i32]);
test_ptx!(or, [1u64, 2u64], [3u64]);
test_ptx!(sub, [2u64], [1u64]);
test_ptx!(min, [555i32, 444i32], [444i32]);
test_ptx!(max, [555i32, 444i32], [555i32]);
-
+test_ptx!(global_array, [0xDEADu32], [1u32]);
struct DisplayError<T: Debug> {
err: T,
@@ -131,7 +135,15 @@ fn run_spirv<T: From<u8> + ze::SafeRepr + Copy + Debug>( let mut devices = drv.devices()?;
let dev = devices.drain(0..1).next().unwrap();
let queue = ze::CommandQueue::new(&mut ctx, &dev)?;
- let module = ze::Module::new_spirv(&mut ctx, &dev, byte_il, None)?;
+ let (module, log) = ze::Module::new_spirv(&mut ctx, &dev, byte_il, None);
+ let module = match module {
+ Ok(m) => m,
+ Err(err) => {
+ let raw_err_string = log.get_cstring()?;
+ let err_string = raw_err_string.to_string_lossy();
+ panic!("{:?}\n{}", err, err_string);
+ }
+ };
let mut kernel = ze::Kernel::new_resident(&module, name)?;
kernel.set_indirect_access(
ze::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_DEVICE,
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 7c15744..a86ab3c 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -1,7 +1,7 @@ use crate::ast;
+use half::f16;
use rspirv::{binary::Disassemble, dr};
use std::collections::{hash_map, HashMap, HashSet};
-use std::convert::TryInto;
use std::{borrow::Cow, iter, mem};
use rspirv::binary::Assemble;
@@ -26,7 +26,7 @@ quick_error! { enum SpirvType {
Base(SpirvScalarKey),
Vector(SpirvScalarKey, u8),
- Array(SpirvScalarKey, u32),
+ Array(SpirvScalarKey, Vec<u32>),
Pointer(Box<SpirvType>, spirv::StorageClass),
Func(Option<Box<SpirvType>>, Vec<SpirvType>),
Struct(Vec<SpirvScalarKey>),
@@ -62,6 +62,7 @@ impl From<ast::ScalarType> for SpirvType { struct TypeWordMap {
void: spirv::Word,
complex: HashMap<SpirvType, spirv::Word>,
+ constants: HashMap<(SpirvType, u64), spirv::Word>,
}
// SPIR-V integer type definitions are signless, more below:
@@ -108,6 +109,7 @@ impl TypeWordMap { TypeWordMap {
void: void,
complex: HashMap::<SpirvType, spirv::Word>::new(),
+ constants: HashMap::new(),
}
}
@@ -154,13 +156,25 @@ impl TypeWordMap { .entry(t)
.or_insert_with(|| b.type_vector(base, len as u32))
}
- SpirvType::Array(typ, len) => {
- let base = self.get_or_add_spirv_scalar(b, typ);
+ SpirvType::Array(typ, array_dimensions) => {
let u32_type = self.get_or_add_scalar(b, ast::ScalarType::U32);
- *self.complex.entry(t).or_insert_with(|| {
- let len_word = b.constant_u32(u32_type, None, len);
- b.type_array(base, len_word)
- })
+ let (base_type, length) = match &*array_dimensions {
+ &[len] => {
+ let base = self.get_or_add_spirv_scalar(b, typ);
+ let len_const = b.constant_u32(u32_type, None, len);
+ (base, len_const)
+ }
+ array_dimensions => {
+ let base = self
+ .get_or_add(b, SpirvType::Array(typ, array_dimensions[1..].to_vec()));
+ let len_const = b.constant_u32(u32_type, None, array_dimensions[0]);
+ (base, len_const)
+ }
+ };
+ *self
+ .complex
+ .entry(SpirvType::Array(typ, array_dimensions))
+ .or_insert_with(|| b.type_array(base_type, length))
}
SpirvType::Func(ref out_params, ref in_params) => {
let out_t = match out_params {
@@ -211,16 +225,173 @@ impl TypeWordMap { self.get_or_add(b, SpirvType::Func(out_args, in_params.collect::<Vec<_>>())),
)
}
+
+ fn get_or_add_constant(
+ &mut self,
+ b: &mut dr::Builder,
+ typ: &ast::Type,
+ init: &[u8],
+ ) -> Result<spirv::Word, TranslateError> {
+ Ok(match typ {
+ ast::Type::Scalar(t) => match t {
+ ast::ScalarType::B8 | ast::ScalarType::U8 | ast::ScalarType::S8 => self
+ .get_or_add_constant_single::<u8, _, _>(
+ b,
+ *t,
+ init,
+ |v| v as u64,
+ |b, result_type, v| b.constant_u32(result_type, None, v as u32),
+ ),
+ ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => self
+ .get_or_add_constant_single::<u16, _, _>(
+ b,
+ *t,
+ init,
+ |v| v as u64,
+ |b, result_type, v| b.constant_u32(result_type, None, v as u32),
+ ),
+ ast::ScalarType::B32 | ast::ScalarType::U32 | ast::ScalarType::S32 => self
+ .get_or_add_constant_single::<u32, _, _>(
+ b,
+ *t,
+ init,
+ |v| v as u64,
+ |b, result_type, v| b.constant_u32(result_type, None, v),
+ ),
+ ast::ScalarType::B64 | ast::ScalarType::U64 | ast::ScalarType::S64 => self
+ .get_or_add_constant_single::<u64, _, _>(
+ b,
+ *t,
+ init,
+ |v| v,
+ |b, result_type, v| b.constant_u64(result_type, None, v),
+ ),
+ ast::ScalarType::F16 => self.get_or_add_constant_single::<f16, _, _>(
+ b,
+ *t,
+ init,
+ |v| unsafe { mem::transmute::<_, u16>(v) } as u64,
+ |b, result_type, v| b.constant_f32(result_type, None, v.to_f32()),
+ ),
+ ast::ScalarType::F32 => self.get_or_add_constant_single::<f32, _, _>(
+ b,
+ *t,
+ init,
+ |v| unsafe { mem::transmute::<_, u32>(v) } as u64,
+ |b, result_type, v| b.constant_f32(result_type, None, v),
+ ),
+ ast::ScalarType::F64 => self.get_or_add_constant_single::<f64, _, _>(
+ b,
+ *t,
+ init,
+ |v| unsafe { mem::transmute::<_, u64>(v) },
+ |b, result_type, v| b.constant_f64(result_type, None, v),
+ ),
+ ast::ScalarType::F16x2 => return Err(TranslateError::Todo),
+ ast::ScalarType::Pred => self.get_or_add_constant_single::<u8, _, _>(
+ b,
+ *t,
+ init,
+ |v| v as u64,
+ |b, result_type, v| {
+ if v == 0 {
+ b.constant_false(result_type, None)
+ } else {
+ b.constant_true(result_type, None)
+ }
+ },
+ ),
+ },
+ ast::Type::Vector(typ, len) => {
+ let result_type =
+ self.get_or_add(b, SpirvType::Vector(SpirvScalarKey::from(*typ), *len));
+ let size_of_t = typ.size_of();
+ let components = (0..*len)
+ .map(|x| {
+ self.get_or_add_constant(
+ b,
+ &ast::Type::Scalar(*typ),
+ &init[((size_of_t as usize) * (x as usize))..],
+ )
+ })
+ .collect::<Result<Vec<_>, _>>()?;
+ b.constant_composite(result_type, None, &components)
+ }
+ ast::Type::Array(typ, dims) => match dims.as_slice() {
+ [] => return Err(TranslateError::Unreachable),
+ [dim] => {
+ let result_type = self
+ .get_or_add(b, SpirvType::Array(SpirvScalarKey::from(*typ), vec![*dim]));
+ let size_of_t = typ.size_of();
+ let components = (0..*dim)
+ .map(|x| {
+ self.get_or_add_constant(
+ b,
+ &ast::Type::Scalar(*typ),
+ &init[((size_of_t as usize) * (x as usize))..],
+ )
+ })
+ .collect::<Result<Vec<_>, _>>()?;
+ b.constant_composite(result_type, None, &components)
+ }
+ [first_dim, rest @ ..] => {
+ let result_type = self.get_or_add(
+ b,
+ SpirvType::Array(SpirvScalarKey::from(*typ), rest.to_vec()),
+ );
+ let size_of_t = rest
+ .iter()
+ .fold(typ.size_of() as u32, |x, y| (x as u32) * (*y));
+ let components = (0..*first_dim)
+ .map(|x| {
+ self.get_or_add_constant(
+ b,
+ &ast::Type::Array(*typ, rest.to_vec()),
+ &init[((size_of_t as usize) * (x as usize))..],
+ )
+ })
+ .collect::<Result<Vec<_>, _>>()?;
+ b.constant_composite(result_type, None, &components)
+ }
+ },
+ })
+ }
+
+ fn get_or_add_constant_single<
+ T: Copy,
+ CastAsU64: FnOnce(T) -> u64,
+ InsertConstant: FnOnce(&mut dr::Builder, spirv::Word, T) -> spirv::Word,
+ >(
+ &mut self,
+ b: &mut dr::Builder,
+ key: ast::ScalarType,
+ init: &[u8],
+ cast: CastAsU64,
+ f: InsertConstant,
+ ) -> spirv::Word {
+ let value = unsafe { *(init.as_ptr() as *const T) };
+ let value_64 = cast(value);
+ let ht_key = (SpirvType::Base(SpirvScalarKey::from(key)), value_64);
+ match self.constants.get(&ht_key) {
+ Some(value) => *value,
+ None => {
+ let spirv_type = self.get_or_add_scalar(b, key);
+ let result = f(b, spirv_type, value);
+ self.constants.insert(ht_key, result);
+ result
+ }
+ }
+ }
}
pub fn to_spirv_module<'a>(
ast: ast::Module<'a>,
) -> Result<(dr::Module, HashMap<String, Vec<usize>>), TranslateError> {
let mut id_defs = GlobalStringIdResolver::new(1);
- let ssa_functions = ast
- .functions
+ let directives = ast
+ .directives
.into_iter()
- .map(|f| to_ssa_function(&mut id_defs, f))
+ .map(|f| translate_directive(&mut id_defs, f))
.collect::<Result<Vec<_>, _>>()?;
let mut builder = dr::Builder::new();
builder.reserve_ids(id_defs.current_id());
@@ -233,21 +404,28 @@ pub fn to_spirv_module<'a>( let mut map = TypeWordMap::new(&mut builder);
emit_builtins(&mut builder, &mut map, &id_defs);
let mut args_len = HashMap::new();
- for f in ssa_functions {
- let f_body = match f.body {
- Some(f) => f,
- None => continue,
- };
- emit_function_body_ops(&mut builder, &mut map, opencl_id, &f.globals)?;
- emit_function_header(
- &mut builder,
- &mut map,
- &id_defs,
- f.func_directive,
- &mut args_len,
- )?;
- emit_function_body_ops(&mut builder, &mut map, opencl_id, &f_body)?;
- builder.end_function()?;
+ for d in directives {
+ match d {
+ Directive::Variable(var) => {
+ emit_variable(&mut builder, &mut map, &var)?;
+ }
+ Directive::Method(f) => {
+ let f_body = match f.body {
+ Some(f) => f,
+ None => continue,
+ };
+ emit_function_body_ops(&mut builder, &mut map, opencl_id, &f.globals)?;
+ emit_function_header(
+ &mut builder,
+ &mut map,
+ &id_defs,
+ f.func_directive,
+ &mut args_len,
+ )?;
+ emit_function_body_ops(&mut builder, &mut map, opencl_id, &f_body)?;
+ builder.end_function()?;
+ }
+ }
}
Ok((builder.module(), args_len))
}
@@ -294,12 +472,18 @@ fn emit_function_header<'a>( let fn_id = match func_directive {
ast::MethodDecl::Kernel(name, _) => {
let fn_id = global.get_id(name)?;
- let interface = global
+ let mut global_variables = global
+ .variables_type_check
+ .iter()
+ .filter_map(|(k, t)| t.as_ref().map(|_| *k))
+ .collect::<Vec<_>>();
+ let mut interface = global
.special_registers
.iter()
.map(|(_, id)| *id)
.collect::<Vec<_>>();
- builder.entry_point(spirv::ExecutionModel::Kernel, fn_id, name, interface);
+ global_variables.append(&mut interface);
+ builder.entry_point(spirv::ExecutionModel::Kernel, fn_id, name, global_variables);
fn_id
}
ast::MethodDecl::Func(_, name, _) => name,
@@ -311,7 +495,7 @@ fn emit_function_header<'a>( func_type,
)?;
func_directive.visit_args(&mut |arg| {
- let result_type = map.get_or_add(builder, ast::Type::from(arg.v_type).into());
+ let result_type = map.get_or_add(builder, ast::Type::from(arg.v_type.clone()).into());
let inst = dr::Instruction::new(
spirv::Op::FunctionParameter,
Some(result_type),
@@ -355,7 +539,30 @@ fn emit_memory_model(builder: &mut dr::Builder) { );
}
-fn to_ssa_function<'a>(
+fn translate_directive<'input>(
+ id_defs: &mut GlobalStringIdResolver<'input>,
+ d: ast::Directive<'input, ast::ParsedArgParams<'input>>,
+) -> Result<Directive<'input>, TranslateError> {
+ Ok(match d {
+ ast::Directive::Variable(v) => Directive::Variable(translate_variable(id_defs, v)?),
+ ast::Directive::Method(f) => Directive::Method(translate_function(id_defs, f)?),
+ })
+}
+
+fn translate_variable<'a>(
+ id_defs: &mut GlobalStringIdResolver<'a>,
+ var: ast::Variable<ast::VariableType, &'a str>,
+) -> Result<ast::Variable<ast::VariableType, spirv::Word>, TranslateError> {
+ let (state_space, typ) = var.v_type.to_type();
+ Ok(ast::Variable {
+ align: var.align,
+ v_type: var.v_type,
+ name: id_defs.get_or_add_def_typed(var.name, (state_space.into(), typ)),
+ array_init: var.array_init,
+ })
+}
+
+fn translate_function<'a>(
id_defs: &mut GlobalStringIdResolver<'a>,
f: ast::ParsedFunction<'a>,
) -> Result<Function<'a>, TranslateError> {
@@ -368,9 +575,13 @@ fn expand_kernel_params<'a, 'b>( args: impl Iterator<Item = &'b ast::KernelArgument<&'a str>>,
) -> Vec<ast::KernelArgument<spirv::Word>> {
args.map(|a| ast::KernelArgument {
- name: fn_resolver.add_def(a.name, Some((StateSpace::Param, ast::Type::from(a.v_type)))),
- v_type: a.v_type,
+ name: fn_resolver.add_def(
+ a.name,
+ Some((StateSpace::Param, ast::Type::from(a.v_type.clone()))),
+ ),
+ v_type: a.v_type.clone(),
align: a.align,
+ array_init: Vec::new(),
})
.collect()
}
@@ -385,9 +596,10 @@ fn expand_fn_params<'a, 'b>( ast::FnArgumentType::Param(_) => StateSpace::Param,
};
ast::FnArgument {
- name: fn_resolver.add_def(a.name, Some((ss, ast::Type::from(a.v_type)))),
- v_type: a.v_type,
+ name: fn_resolver.add_def(a.name, Some((ss, ast::Type::from(a.v_type.clone())))),
+ v_type: a.v_type.clone(),
align: a.align,
+ array_init: Vec::new(),
}
})
.collect()
@@ -628,7 +840,7 @@ fn to_resolved_fn_args<T>( params
.into_iter()
.zip(params_decl.iter())
- .map(|(id, typ)| (id, *typ))
+ .map(|(id, typ)| (id, typ.clone()))
.collect::<Vec<_>>()
}
@@ -719,12 +931,13 @@ fn insert_mem_ssa_statements<'a, 'b>( let out_param = match &mut f_args {
ast::MethodDecl::Kernel(_, in_params) => {
for p in in_params.iter_mut() {
- let typ = ast::Type::from(p.v_type);
- let new_id = id_def.new_id(typ);
+ let typ = ast::Type::from(p.v_type.clone());
+ let new_id = id_def.new_id(typ.clone());
result.push(Statement::Variable(ast::Variable {
align: p.align,
- v_type: ast::VariableType::Param(p.v_type),
+ v_type: ast::VariableType::Param(p.v_type.clone()),
name: p.name,
+ array_init: p.array_init.clone(),
}));
result.push(Statement::StoreVar(
ast::Arg2St {
@@ -739,20 +952,21 @@ fn insert_mem_ssa_statements<'a, 'b>( }
ast::MethodDecl::Func(out_params, _, in_params) => {
for p in in_params.iter_mut() {
- let typ = ast::Type::from(p.v_type);
- let new_id = id_def.new_id(typ);
- let var_typ = ast::VariableType::from(p.v_type);
+ let typ = ast::Type::from(p.v_type.clone());
+ let new_id = id_def.new_id(typ.clone());
+ let var_typ = ast::VariableType::from(p.v_type.clone());
result.push(Statement::Variable(ast::Variable {
align: p.align,
v_type: var_typ,
name: p.name,
+ array_init: p.array_init.clone(),
}));
result.push(Statement::StoreVar(
ast::Arg2St {
src1: p.name,
src2: new_id,
},
- typ,
+ typ.clone(),
));
p.name = new_id;
}
@@ -760,8 +974,9 @@ fn insert_mem_ssa_statements<'a, 'b>( [p] => {
result.push(Statement::Variable(ast::Variable {
align: p.align,
- v_type: ast::VariableType::from(p.v_type),
+ v_type: ast::VariableType::from(p.v_type.clone()),
name: p.name,
+ array_init: p.array_init.clone(),
}));
Some(p.name)
}
@@ -779,13 +994,13 @@ fn insert_mem_ssa_statements<'a, 'b>( ast::Instruction::Ret(d) => {
if let Some(out_param) = out_param {
let typ = id_def.get_typed(out_param)?;
- let new_id = id_def.new_id(typ);
+ let new_id = id_def.new_id(typ.clone());
result.push(Statement::LoadVar(
ast::Arg2 {
dst: new_id,
src: out_param,
},
- typ,
+ typ.clone(),
));
result.push(Statement::RetValue(d, new_id));
} else {
@@ -824,7 +1039,7 @@ trait VisitVariable: Sized { 'a,
F: FnMut(
ArgumentDescriptor<spirv::Word>,
- Option<ast::Type>,
+ Option<&ast::Type>,
) -> Result<spirv::Word, TranslateError>,
>(
self,
@@ -835,7 +1050,7 @@ trait VisitVariableExpanded { fn visit_variable_extended<
F: FnMut(
ArgumentDescriptor<spirv::Word>,
- Option<ast::Type>,
+ Option<&ast::Type>,
) -> Result<spirv::Word, TranslateError>,
>(
self,
@@ -861,7 +1076,7 @@ fn insert_mem_ssa_statement_default<'a, F: VisitVariable>( | (t, ArgumentSemantics::DefaultRelaxed)
| (t, ArgumentSemantics::PhysicalPointer) => t,
};
- let generated_id = id_def.new_id(id_type);
+ let generated_id = id_def.new_id(id_type.clone());
if !desc.is_dst {
result.push(Statement::LoadVar(
Arg2 {
@@ -909,10 +1124,12 @@ fn expand_arguments<'a, 'b>( align,
v_type,
name,
+ array_init,
}) => result.push(Statement::Variable(ast::Variable {
align,
v_type,
name,
+ array_init,
})),
Statement::Label(id) => result.push(Statement::Label(id)),
Statement::Conditional(bra) => result.push(Statement::Conditional(bra)),
@@ -969,7 +1186,7 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { fn reg(
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
- _: Option<ast::Type>,
+ _: Option<&ast::Type>,
) -> Result<spirv::Word, TranslateError> {
Ok(desc.op)
}
@@ -977,8 +1194,9 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { fn reg_offset(
&mut self,
desc: ArgumentDescriptor<(spirv::Word, i32)>,
- mut typ: ast::Type,
+ typ: &ast::Type,
) -> Result<spirv::Word, TranslateError> {
+ let mut typ = typ.clone();
let (reg, offset) = desc.op;
match desc.sema {
ArgumentSemantics::Default
@@ -997,7 +1215,7 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { ScalarKind::Float2 => return Err(TranslateError::MismatchedType),
ScalarKind::Pred => return Err(TranslateError::MismatchedType),
};
- (scalar_t.width(), kind)
+ (scalar_t.size_of(), kind)
}
_ => return Err(TranslateError::MismatchedType),
};
@@ -1009,7 +1227,7 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { } else {
ast::ArithDetails::Unsigned(ast::UIntType::from_size(width))
};
- let id_constant_stmt = self.id_def.new_id(typ);
+ let id_constant_stmt = self.id_def.new_id(typ.clone());
let result_id = self.id_def.new_id(typ);
// TODO: check for edge cases around min value/max value/wrapping
if offset < 0 && kind != ScalarKind::Signed {
@@ -1060,10 +1278,10 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { fn immediate(
&mut self,
desc: ArgumentDescriptor<u32>,
- typ: ast::Type,
+ typ: &ast::Type,
) -> Result<spirv::Word, TranslateError> {
let scalar_t = if let ast::Type::Scalar(scalar) = typ {
- scalar
+ *scalar
} else {
todo!()
};
@@ -1098,14 +1316,14 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { fn vector(
&mut self,
desc: ArgumentDescriptor<&Vec<spirv::Word>>,
- typ: ast::Type,
+ typ: &ast::Type,
) -> Result<spirv::Word, TranslateError> {
let (scalar_type, vec_len) = typ.get_vector()?;
if !desc.is_dst {
- let mut new_id = self.id_def.new_id(typ);
- self.func.push(Statement::Undef(typ, new_id));
+ let mut new_id = self.id_def.new_id(typ.clone());
+ self.func.push(Statement::Undef(typ.clone(), new_id));
for (idx, id) in desc.op.iter().enumerate() {
- let newer_id = self.id_def.new_id(typ);
+ let newer_id = self.id_def.new_id(typ.clone());
self.func.push(Statement::Instruction(ast::Instruction::Mov(
ast::MovDetails {
typ: ast::Type::Scalar(scalar_type),
@@ -1124,7 +1342,7 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { }
Ok(new_id)
} else {
- let new_id = self.id_def.new_id(typ);
+ let new_id = self.id_def.new_id(typ.clone());
for (idx, id) in desc.op.iter().enumerate() {
Self::insert_composite_read(
&mut self.post_stmts,
@@ -1144,7 +1362,7 @@ impl<'a, 'b> ArgumentMapVisitor<TypedArgParams, ExpandedArgParams> for FlattenAr fn id(
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
- t: Option<ast::Type>,
+ t: Option<&ast::Type>,
) -> Result<spirv::Word, TranslateError> {
self.reg(desc, t)
}
@@ -1152,7 +1370,7 @@ impl<'a, 'b> ArgumentMapVisitor<TypedArgParams, ExpandedArgParams> for FlattenAr fn operand(
&mut self,
desc: ArgumentDescriptor<ast::Operand<spirv::Word>>,
- typ: ast::Type,
+ typ: &ast::Type,
) -> Result<spirv::Word, TranslateError> {
match desc.op {
ast::Operand::Reg(r) => self.reg(desc.new_op(r), Some(typ)),
@@ -1166,7 +1384,7 @@ impl<'a, 'b> ArgumentMapVisitor<TypedArgParams, ExpandedArgParams> for FlattenAr fn src_call_operand(
&mut self,
desc: ArgumentDescriptor<ast::CallOperand<spirv::Word>>,
- typ: ast::Type,
+ typ: &ast::Type,
) -> Result<spirv::Word, TranslateError> {
match desc.op {
ast::CallOperand::Reg(reg) => self.reg(desc.new_op(reg), Some(typ)),
@@ -1185,7 +1403,7 @@ impl<'a, 'b> ArgumentMapVisitor<TypedArgParams, ExpandedArgParams> for FlattenAr fn id_or_vector(
&mut self,
desc: ArgumentDescriptor<ast::IdOrVector<spirv::Word>>,
- typ: ast::Type,
+ typ: &ast::Type,
) -> Result<spirv::Word, TranslateError> {
match desc.op {
ast::IdOrVector::Reg(r) => self.reg(desc.new_op(r), Some(typ)),
@@ -1196,7 +1414,7 @@ impl<'a, 'b> ArgumentMapVisitor<TypedArgParams, ExpandedArgParams> for FlattenAr fn operand_or_vector(
&mut self,
desc: ArgumentDescriptor<ast::OperandOrVector<spirv::Word>>,
- typ: ast::Type,
+ typ: &ast::Type,
) -> Result<spirv::Word, TranslateError> {
match desc.op {
ast::OperandOrVector::Reg(r) => self.reg(desc.new_op(r), Some(typ)),
@@ -1236,8 +1454,8 @@ fn insert_implicit_conversions( None,
)?,
Statement::Instruction(inst) => {
- let mut default_conversion_fn = should_bitcast_wrapper
- as fn(_, _, _) -> Result<Option<ConversionKind>, TranslateError>;
+ let mut default_conversion_fn =
+ should_bitcast_wrapper as for<'a> fn(&'a ast::Type, &'a ast::Type, _) -> _;
let mut state_space = None;
if let ast::Instruction::Ld(d, _) = &inst {
state_space = Some(d.state_space);
@@ -1281,9 +1499,9 @@ fn insert_implicit_conversions_impl( func: &mut Vec<ExpandedStatement>,
id_def: &mut MutableNumericIdResolver,
stmt: impl VisitVariableExpanded,
- default_conversion_fn: fn(
- ast::Type,
- ast::Type,
+ default_conversion_fn: for<'a> fn(
+ &'a ast::Type,
+ &'a ast::Type,
Option<ast::LdStateSpace>,
) -> Result<Option<ConversionKind>, TranslateError>,
state_space: Option<ast::LdStateSpace>,
@@ -1315,16 +1533,16 @@ fn insert_implicit_conversions_impl( conversion_fn = force_bitcast_ptr_to_bit;
}
};
- match conversion_fn(operand_type, instr_type, state_space)? {
+ match conversion_fn(&operand_type, instr_type, state_space)? {
Some(conv_kind) => {
let conv_output = if desc.is_dst {
&mut post_conv
} else {
&mut *func
};
- let mut from = instr_type;
+ let mut from = instr_type.clone();
let mut to = operand_type;
- let mut src = id_def.new_id(instr_type);
+ let mut src = id_def.new_id(instr_type.clone());
let mut dst = desc.op;
let result = Ok(src);
if !desc.is_dst {
@@ -1358,17 +1576,17 @@ fn get_function_type( builder,
out_params
.iter()
- .map(|p| SpirvType::from(ast::Type::from(p.v_type))),
+ .map(|p| SpirvType::from(ast::Type::from(p.v_type.clone()))),
in_params
.iter()
- .map(|p| SpirvType::from(ast::Type::from(p.v_type))),
+ .map(|p| SpirvType::from(ast::Type::from(p.v_type.clone()))),
),
ast::MethodDecl::Kernel(_, params) => map.get_or_add_fn(
builder,
iter::empty(),
params
.iter()
- .map(|p| SpirvType::from(ast::Type::from(p.v_type))),
+ .map(|p| SpirvType::from(ast::Type::from(p.v_type.clone()))),
),
}
}
@@ -1398,7 +1616,7 @@ fn emit_function_body_ops( Statement::Call(call) => {
let (result_type, result_id) = match &*call.ret_params {
[(id, typ)] => (
- map.get_or_add(builder, SpirvType::from(ast::Type::from(*typ))),
+ map.get_or_add(builder, SpirvType::from(ast::Type::from(typ.clone()))),
Some(*id),
),
[] => (map.void(), None),
@@ -1411,28 +1629,8 @@ fn emit_function_body_ops( .collect::<Vec<_>>();
builder.function_call(result_type, result_id, call.func, arg_list)?;
}
- Statement::Variable(ast::Variable {
- align,
- v_type,
- name,
- }) => {
- let st_class = match v_type {
- ast::VariableType::Reg(_)
- | ast::VariableType::Param(_)
- | ast::VariableType::Local(_) => spirv::StorageClass::Function,
- };
- let type_id = map.get_or_add(
- builder,
- SpirvType::new_pointer(ast::Type::from(*v_type), st_class),
- );
- builder.variable(type_id, Some(*name), st_class, None);
- if let Some(align) = align {
- builder.decorate(
- *name,
- spirv::Decoration::Alignment,
- &[dr::Operand::LiteralInt32(*align)],
- );
- }
+ Statement::Variable(var) => {
+ emit_variable(builder, map, var)?;
}
Statement::Constant(cnst) => {
let typ_id = map.get_or_add_scalar(builder, cnst.typ);
@@ -1479,13 +1677,14 @@ fn emit_function_body_ops( if data.qualifier != ast::LdStQualifier::Weak {
todo!()
}
- let result_type = map.get_or_add(builder, SpirvType::from(data.typ));
+ let result_type = map.get_or_add(builder, SpirvType::from(data.typ.clone()));
match data.state_space {
ast::LdStateSpace::Generic | ast::LdStateSpace::Global => {
builder.load(result_type, Some(arg.dst), arg.src, None, [])?;
}
ast::LdStateSpace::Param | ast::LdStateSpace::Local => {
- let result_type = map.get_or_add(builder, SpirvType::from(data.typ));
+ let result_type =
+ map.get_or_add(builder, SpirvType::from(data.typ.clone()));
builder.copy_object(result_type, Some(arg.dst), arg.src)?;
}
_ => todo!(),
@@ -1498,7 +1697,8 @@ fn emit_function_body_ops( if data.state_space == ast::StStateSpace::Param
|| data.state_space == ast::StStateSpace::Local
{
- let result_type = map.get_or_add(builder, SpirvType::from(data.typ));
+ let result_type =
+ map.get_or_add(builder, SpirvType::from(data.typ.clone()));
builder.copy_object(result_type, Some(arg.src1), arg.src2)?;
} else if data.state_space == ast::StStateSpace::Generic
|| data.state_space == ast::StStateSpace::Global
@@ -1513,8 +1713,8 @@ fn emit_function_body_ops( ast::Instruction::Mov(d, arg) => match arg {
ast::Arg2Mov::Normal(ast::Arg2MovNormal { dst, src })
| ast::Arg2Mov::Member(ast::Arg2MovMember::Src(dst, src)) => {
- let result_type =
- map.get_or_add(builder, SpirvType::from(ast::Type::from(d.typ)));
+ let result_type = map
+ .get_or_add(builder, SpirvType::from(ast::Type::from(d.typ.clone())));
builder.copy_object(result_type, Some(*dst), *src)?;
}
ast::Arg2Mov::Member(ast::Arg2MovMember::Dst(
@@ -1645,7 +1845,7 @@ fn emit_function_body_ops( }
},
Statement::LoadVar(arg, typ) => {
- let type_id = map.get_or_add(builder, SpirvType::from(*typ));
+ let type_id = map.get_or_add(builder, SpirvType::from(typ.clone()));
builder.load(type_id, Some(arg.dst), arg.src, None, [])?;
}
Statement::StoreVar(arg, _) => {
@@ -1665,7 +1865,7 @@ fn emit_function_body_ops( )?;
}
Statement::Undef(t, id) => {
- let result_type = map.get_or_add(builder, SpirvType::from(*t));
+ let result_type = map.get_or_add(builder, SpirvType::from(t.clone()));
builder.undef(result_type, Some(*id));
}
}
@@ -1673,6 +1873,41 @@ fn emit_function_body_ops( Ok(())
}
+fn emit_variable(
+ builder: &mut dr::Builder,
+ map: &mut TypeWordMap,
+ var: &ast::Variable<ast::VariableType, spirv::Word>,
+) -> Result<(), TranslateError> {
+ let (should_init, st_class) = match var.v_type {
+ ast::VariableType::Reg(_) | ast::VariableType::Param(_) | ast::VariableType::Local(_) => {
+ (false, spirv::StorageClass::Function)
+ }
+ ast::VariableType::Global(_) => (true, spirv::StorageClass::CrossWorkgroup),
+ };
+ let type_id = map.get_or_add(
+ builder,
+ SpirvType::new_pointer(ast::Type::from(var.v_type.clone()), st_class),
+ );
+ let initalizer = if should_init {
+ Some(map.get_or_add_constant(
+ builder,
+ &ast::Type::from(var.v_type.clone()),
+ &*var.array_init,
+ )?)
+ } else {
+ None
+ };
+ builder.variable(type_id, Some(var.name), st_class, initalizer);
+ if let Some(align) = var.align {
+ builder.decorate(
+ var.name,
+ spirv::Decoration::Alignment,
+ &[dr::Operand::LiteralInt32(align)],
+ );
+ }
+ Ok(())
+}
+
fn emit_mad_uint(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
@@ -1876,7 +2111,7 @@ fn emit_cvt( dst: new_dst,
from: ast::Type::Scalar(src_t),
to: ast::Type::Scalar(ast::ScalarType::from_parts(
- dest_t.width(),
+ dest_t.size_of(),
src_t.kind(),
)),
kind: ConversionKind::Default,
@@ -2041,7 +2276,7 @@ fn emit_mul_sint( ]);
let mul_ext_type_id = map.get_or_add(builder, mul_ext_type);
let mul = builder.s_mul_extended(mul_ext_type_id, None, arg.src1, arg.src2)?;
- let instr_width = instruction_type.width();
+ let instr_width = instruction_type.size_of();
let instr_kind = instruction_type.kind();
let dst_type = ast::ScalarType::from_parts(instr_width * 2, instr_kind);
let dst_type_id = map.get_or_add_scalar(builder, dst_type);
@@ -2088,7 +2323,7 @@ fn emit_mul_uint( ]);
let mul_ext_type_id = map.get_or_add(builder, mul_ext_type);
let mul = builder.u_mul_extended(mul_ext_type_id, None, arg.src1, arg.src2)?;
- let instr_width = instruction_type.width();
+ let instr_width = instruction_type.size_of();
let instr_kind = instruction_type.kind();
let dst_type = ast::ScalarType::from_parts(instr_width * 2, instr_kind);
let dst_type_id = map.get_or_add_scalar(builder, dst_type);
@@ -2193,13 +2428,13 @@ fn emit_implicit_conversion( (_, _, ConversionKind::BitToPtr(space)) => {
let dst_type = map.get_or_add(
builder,
- SpirvType::Pointer(Box::new(SpirvType::from(cv.to)), space.to_spirv()),
+ SpirvType::Pointer(Box::new(SpirvType::from(cv.to.clone())), space.to_spirv()),
);
builder.convert_u_to_ptr(dst_type, Some(cv.dst), cv.src)?;
}
(TypeKind::Scalar, TypeKind::Scalar, ConversionKind::Default) => {
if from_parts.width == to_parts.width {
- let dst_type = map.get_or_add(builder, SpirvType::from(cv.to));
+ let dst_type = map.get_or_add(builder, SpirvType::from(cv.to.clone()));
if from_parts.scalar_kind != ScalarKind::Float
&& to_parts.scalar_kind != ScalarKind::Float
{
@@ -2222,7 +2457,8 @@ fn emit_implicit_conversion( scalar_kind: ScalarKind::Bit,
..to_parts
});
- let wide_bit_type_spirv = map.get_or_add(builder, SpirvType::from(wide_bit_type));
+ let wide_bit_type_spirv =
+ map.get_or_add(builder, SpirvType::from(wide_bit_type.clone()));
if to_parts.scalar_kind == ScalarKind::Unsigned
|| to_parts.scalar_kind == ScalarKind::Bit
{
@@ -2237,7 +2473,7 @@ fn emit_implicit_conversion( src: wide_bit_value,
dst: cv.dst,
from: wide_bit_type,
- to: cv.to,
+ to: cv.to.clone(),
kind: ConversionKind::Default,
},
)?;
@@ -2248,7 +2484,7 @@ fn emit_implicit_conversion( (TypeKind::Vector, TypeKind::Scalar, ConversionKind::Default)
| (TypeKind::Scalar, TypeKind::Array, ConversionKind::Default)
| (TypeKind::Array, TypeKind::Scalar, ConversionKind::Default) => {
- let into_type = map.get_or_add(builder, SpirvType::from(cv.to));
+ let into_type = map.get_or_add(builder, SpirvType::from(cv.to.clone()));
builder.bitcast(into_type, Some(cv.dst), cv.src)?;
}
_ => unreachable!(),
@@ -2301,23 +2537,29 @@ fn expand_map_variables<'a, 'b>( ast::VariableType::Reg(_) => StateSpace::Reg,
ast::VariableType::Local(_) => StateSpace::Local,
ast::VariableType::Param(_) => StateSpace::ParamReg,
+ ast::VariableType::Global(_) => todo!(),
};
match var.count {
Some(count) => {
- for new_id in id_defs.add_defs(var.var.name, count, ss, var.var.v_type.into()) {
+ for new_id in
+ id_defs.add_defs(var.var.name, count, ss, var.var.v_type.clone().into())
+ {
result.push(Statement::Variable(ast::Variable {
align: var.var.align,
- v_type: var.var.v_type,
+ v_type: var.var.v_type.clone(),
name: new_id,
+ array_init: var.var.array_init.clone(),
}))
}
}
None => {
- let new_id = id_defs.add_def(var.var.name, Some((ss, var.var.v_type.into())));
+ let new_id =
+ id_defs.add_def(var.var.name, Some((ss, var.var.v_type.clone().into())));
result.push(Statement::Variable(ast::Variable {
align: var.var.align,
- v_type: var.var.v_type,
+ v_type: var.var.v_type.clone(),
name: new_id,
+ array_init: var.var.array_init,
}));
}
}
@@ -2367,6 +2609,7 @@ impl PtxSpecialRegister { struct GlobalStringIdResolver<'input> {
current_id: spirv::Word,
variables: HashMap<Cow<'input, str>, spirv::Word>,
+ variables_type_check: HashMap<u32, Option<(StateSpace, ast::Type)>>,
special_registers: HashMap<PtxSpecialRegister, spirv::Word>,
fns: HashMap<spirv::Word, FnDecl>,
}
@@ -2381,13 +2624,26 @@ impl<'a> GlobalStringIdResolver<'a> { Self {
current_id: start_id,
variables: HashMap::new(),
+ variables_type_check: HashMap::new(),
special_registers: HashMap::new(),
fns: HashMap::new(),
}
}
fn get_or_add_def(&mut self, id: &'a str) -> spirv::Word {
- match self.variables.entry(Cow::Borrowed(id)) {
+ self.get_or_add_impl(id, None)
+ }
+
+ fn get_or_add_def_typed(&mut self, id: &'a str, typ: (StateSpace, ast::Type)) -> spirv::Word {
+ self.get_or_add_impl(id, Some(typ))
+ }
+
+ fn get_or_add_impl(
+ &mut self,
+ id: &'a str,
+ typ: Option<(StateSpace, ast::Type)>,
+ ) -> spirv::Word {
+ let id = match self.variables.entry(Cow::Borrowed(id)) {
hash_map::Entry::Occupied(e) => *(e.get()),
hash_map::Entry::Vacant(e) => {
let numeric_id = self.current_id;
@@ -2395,7 +2651,9 @@ impl<'a> GlobalStringIdResolver<'a> { self.current_id += 1;
numeric_id
}
- }
+ };
+ self.variables_type_check.insert(id, typ);
+ id
}
fn get_id(&self, id: &str) -> Result<spirv::Word, TranslateError> {
@@ -2422,6 +2680,7 @@ impl<'a> GlobalStringIdResolver<'a> { let mut fn_resolver = FnStringIdResolver {
current_id: &mut self.current_id,
global_variables: &self.variables,
+ global_type_check: &self.variables_type_check,
special_registers: &mut self.special_registers,
variables: vec![HashMap::new(); 1],
type_check: HashMap::new(),
@@ -2436,8 +2695,8 @@ impl<'a> GlobalStringIdResolver<'a> { self.fns.insert(
name_id,
FnDecl {
- ret_vals: ret_params_ids.iter().map(|p| p.v_type).collect(),
- params: params_ids.iter().map(|p| p.v_type).collect(),
+ ret_vals: ret_params_ids.iter().map(|p| p.v_type.clone()).collect(),
+ params: params_ids.iter().map(|p| p.v_type.clone()).collect(),
},
);
ast::MethodDecl::Func(ret_params_ids, name_id, params_ids)
@@ -2475,6 +2734,7 @@ impl<'input, 'a> GlobalFnDeclResolver<'input, 'a> { struct FnStringIdResolver<'input, 'b> {
current_id: &'b mut spirv::Word,
global_variables: &'b HashMap<Cow<'input, str>, spirv::Word>,
+ global_type_check: &'b HashMap<u32, Option<(StateSpace, ast::Type)>>,
special_registers: &'b mut HashMap<PtxSpecialRegister, spirv::Word>,
variables: Vec<HashMap<Cow<'input, str>, spirv::Word>>,
type_check: HashMap<u32, Option<(StateSpace, ast::Type)>>,
@@ -2484,6 +2744,7 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> { fn finish(self) -> NumericIdResolver<'b> {
NumericIdResolver {
current_id: self.current_id,
+ global_type_check: self.global_type_check,
type_check: self.type_check,
special_registers: self
.special_registers
@@ -2551,7 +2812,8 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> { .last_mut()
.unwrap()
.insert(Cow::Owned(format!("{}{}", base_id, i)), numeric_id + i);
- self.type_check.insert(numeric_id + i, Some((ss, typ)));
+ self.type_check
+ .insert(numeric_id + i, Some((ss, typ.clone())));
}
*self.current_id += count;
(0..count).into_iter().map(move |i| i + numeric_id)
@@ -2560,6 +2822,7 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> { struct NumericIdResolver<'b> {
current_id: &'b mut spirv::Word,
+ global_type_check: &'b HashMap<u32, Option<(StateSpace, ast::Type)>>,
type_check: HashMap<u32, Option<(StateSpace, ast::Type)>>,
special_registers: HashMap<spirv::Word, PtxSpecialRegister>,
}
@@ -2571,11 +2834,14 @@ impl<'b> NumericIdResolver<'b> { fn get_typed(&self, id: spirv::Word) -> Result<(StateSpace, ast::Type), TranslateError> {
match self.type_check.get(&id) {
- Some(Some(x)) => Ok(*x),
+ Some(Some(x)) => Ok(x.clone()),
Some(None) => Err(TranslateError::UntypedSymbol),
None => match self.special_registers.get(&id) {
Some(x) => Ok((StateSpace::Reg, x.get_type())),
- None => Err(TranslateError::UntypedSymbol),
+ None => match self.global_type_check.get(&id) {
+ Some(Some(x)) => Ok(x.clone()),
+ Some(None) | None => Err(TranslateError::UntypedSymbol),
+ },
},
}
}
@@ -2655,7 +2921,7 @@ impl<From: ArgParamsEx<Id = spirv::Word>> ResolvedCall<From> { is_dst: true,
sema: ArgumentSemantics::Default,
},
- Some(typ.into()),
+ Some(&typ.clone().into()),
)?;
Ok((new_id, typ))
})
@@ -2678,7 +2944,7 @@ impl<From: ArgParamsEx<Id = spirv::Word>> ResolvedCall<From> { is_dst: false,
sema: ArgumentSemantics::Default,
},
- typ.into(),
+ &typ.clone().into(),
)?;
Ok((new_id, typ))
})
@@ -2697,7 +2963,7 @@ impl VisitVariable for ResolvedCall<TypedArgParams> { 'a,
F: FnMut(
ArgumentDescriptor<spirv::Word>,
- Option<ast::Type>,
+ Option<&ast::Type>,
) -> Result<spirv::Word, TranslateError>,
>(
self,
@@ -2711,7 +2977,7 @@ impl VisitVariableExpanded for ResolvedCall<ExpandedArgParams> { fn visit_variable_extended<
F: FnMut(
ArgumentDescriptor<spirv::Word>,
- Option<ast::Type>,
+ Option<&ast::Type>,
) -> Result<spirv::Word, TranslateError>,
>(
self,
@@ -2821,6 +3087,24 @@ pub enum StateSpace { ParamReg,
}
+impl From<ast::StateSpace> for StateSpace {
+ fn from(ss: ast::StateSpace) -> Self {
+ match ss {
+ ast::StateSpace::Reg => StateSpace::Reg,
+ ast::StateSpace::Const => StateSpace::Const,
+ ast::StateSpace::Global => StateSpace::Global,
+ ast::StateSpace::Local => StateSpace::Local,
+ ast::StateSpace::Shared => StateSpace::Shared,
+ ast::StateSpace::Param => StateSpace::Param,
+ }
+ }
+}
+
+enum Directive<'input> {
+ Variable(ast::Variable<ast::VariableType, spirv::Word>),
+ Method(Function<'input>),
+}
+
struct Function<'input> {
pub func_directive: ast::MethodDecl<'input, spirv::Word>,
pub globals: Vec<ExpandedStatement>,
@@ -2831,27 +3115,27 @@ pub trait ArgumentMapVisitor<T: ArgParamsEx, U: ArgParamsEx> { fn id(
&mut self,
desc: ArgumentDescriptor<T::Id>,
- typ: Option<ast::Type>,
+ typ: Option<&ast::Type>,
) -> Result<U::Id, TranslateError>;
fn operand(
&mut self,
desc: ArgumentDescriptor<T::Operand>,
- typ: ast::Type,
+ typ: &ast::Type,
) -> Result<U::Operand, TranslateError>;
fn id_or_vector(
&mut self,
desc: ArgumentDescriptor<T::IdOrVector>,
- typ: ast::Type,
+ typ: &ast::Type,
) -> Result<U::IdOrVector, TranslateError>;
fn operand_or_vector(
&mut self,
desc: ArgumentDescriptor<T::OperandOrVector>,
- typ: ast::Type,
+ typ: &ast::Type,
) -> Result<U::OperandOrVector, TranslateError>;
fn src_call_operand(
&mut self,
desc: ArgumentDescriptor<T::CallOperand>,
- typ: ast::Type,
+ typ: &ast::Type,
) -> Result<U::CallOperand, TranslateError>;
fn src_member_operand(
&mut self,
@@ -2864,13 +3148,13 @@ impl<T> ArgumentMapVisitor<ExpandedArgParams, ExpandedArgParams> for T where
T: FnMut(
ArgumentDescriptor<spirv::Word>,
- Option<ast::Type>,
+ Option<&ast::Type>,
) -> Result<spirv::Word, TranslateError>,
{
fn id(
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
- t: Option<ast::Type>,
+ t: Option<&ast::Type>,
) -> Result<spirv::Word, TranslateError> {
self(desc, t)
}
@@ -2878,7 +3162,7 @@ where fn operand(
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
- t: ast::Type,
+ t: &ast::Type,
) -> Result<spirv::Word, TranslateError> {
self(desc, Some(t))
}
@@ -2886,7 +3170,7 @@ where fn id_or_vector(
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
- typ: ast::Type,
+ typ: &ast::Type,
) -> Result<spirv::Word, TranslateError> {
self(desc, Some(typ))
}
@@ -2894,7 +3178,7 @@ where fn operand_or_vector(
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
- typ: ast::Type,
+ typ: &ast::Type,
) -> Result<spirv::Word, TranslateError> {
self(desc, Some(typ))
}
@@ -2902,7 +3186,7 @@ where fn src_call_operand(
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
- t: ast::Type,
+ t: &ast::Type,
) -> Result<spirv::Word, TranslateError> {
self(desc, Some(t))
}
@@ -2912,7 +3196,7 @@ where desc: ArgumentDescriptor<spirv::Word>,
(scalar_type, _): (ast::ScalarType, u8),
) -> Result<spirv::Word, TranslateError> {
- self(desc.new_op(desc.op), Some(ast::Type::Scalar(scalar_type)))
+ self(desc.new_op(desc.op), Some(&ast::Type::Scalar(scalar_type)))
}
}
@@ -2923,7 +3207,7 @@ where fn id(
&mut self,
desc: ArgumentDescriptor<&str>,
- _: Option<ast::Type>,
+ _: Option<&ast::Type>,
) -> Result<spirv::Word, TranslateError> {
self(desc.op)
}
@@ -2931,7 +3215,7 @@ where fn operand(
&mut self,
desc: ArgumentDescriptor<ast::Operand<&str>>,
- _: ast::Type,
+ _: &ast::Type,
) -> Result<ast::Operand<spirv::Word>, TranslateError> {
match desc.op {
ast::Operand::Reg(id) => Ok(ast::Operand::Reg(self(id)?)),
@@ -2943,7 +3227,7 @@ where fn id_or_vector(
&mut self,
desc: ArgumentDescriptor<ast::IdOrVector<&'a str>>,
- _: ast::Type,
+ _: &ast::Type,
) -> Result<ast::IdOrVector<spirv::Word>, TranslateError> {
match desc.op {
ast::IdOrVector::Reg(id) => Ok(ast::IdOrVector::Reg(self(id)?)),
@@ -2956,7 +3240,7 @@ where fn operand_or_vector(
&mut self,
desc: ArgumentDescriptor<ast::OperandOrVector<&'a str>>,
- _: ast::Type,
+ _: &ast::Type,
) -> Result<ast::OperandOrVector<spirv::Word>, TranslateError> {
match desc.op {
ast::OperandOrVector::Reg(id) => Ok(ast::OperandOrVector::Reg(self(id)?)),
@@ -2973,7 +3257,7 @@ where fn src_call_operand(
&mut self,
desc: ArgumentDescriptor<ast::CallOperand<&str>>,
- _: ast::Type,
+ _: &ast::Type,
) -> Result<ast::CallOperand<spirv::Word>, TranslateError> {
match desc.op {
ast::CallOperand::Reg(id) => Ok(ast::CallOperand::Reg(self(id)?)),
@@ -3027,39 +3311,39 @@ impl<T: ArgParamsEx> ast::Instruction<T> { ) -> Result<ast::Instruction<U>, TranslateError> {
Ok(match self {
ast::Instruction::Abs(d, arg) => {
- ast::Instruction::Abs(d, arg.map(visitor, false, ast::Type::Scalar(d.typ))?)
+ ast::Instruction::Abs(d, arg.map(visitor, false, &ast::Type::Scalar(d.typ))?)
}
// Call instruction is converted to a call statement early on
ast::Instruction::Call(_) => return Err(TranslateError::Unreachable),
ast::Instruction::Ld(d, a) => {
- let inst_type = d.typ;
let is_param = d.state_space == ast::LdStateSpace::Param
|| d.state_space == ast::LdStateSpace::Local;
- ast::Instruction::Ld(d, a.map(visitor, inst_type, is_param)?)
+ let new_args = a.map(visitor, &d.typ, is_param)?;
+ ast::Instruction::Ld(d, new_args)
}
ast::Instruction::Mov(d, a) => {
- let mapped = a.map(visitor, d)?;
+ let mapped = a.map(visitor, &d)?;
ast::Instruction::Mov(d, mapped)
}
ast::Instruction::Mul(d, a) => {
let inst_type = d.get_type();
let is_wide = d.is_wide();
- ast::Instruction::Mul(d, a.map_non_shift(visitor, inst_type, is_wide)?)
+ ast::Instruction::Mul(d, a.map_non_shift(visitor, &inst_type, is_wide)?)
}
ast::Instruction::Add(d, a) => {
let inst_type = d.get_type();
- ast::Instruction::Add(d, a.map_non_shift(visitor, inst_type, false)?)
+ ast::Instruction::Add(d, a.map_non_shift(visitor, &inst_type, false)?)
}
ast::Instruction::Setp(d, a) => {
let inst_type = d.typ;
- ast::Instruction::Setp(d, a.map(visitor, ast::Type::Scalar(inst_type))?)
+ ast::Instruction::Setp(d, a.map(visitor, &ast::Type::Scalar(inst_type))?)
}
ast::Instruction::SetpBool(d, a) => {
let inst_type = d.typ;
- ast::Instruction::SetpBool(d, a.map(visitor, ast::Type::Scalar(inst_type))?)
+ ast::Instruction::SetpBool(d, a.map(visitor, &ast::Type::Scalar(inst_type))?)
}
ast::Instruction::Not(t, a) => {
- ast::Instruction::Not(t, a.map(visitor, false, t.to_type())?)
+ ast::Instruction::Not(t, a.map(visitor, false, &t.to_type())?)
}
ast::Instruction::Cvt(d, a) => {
let (dst_t, src_t) = match &d {
@@ -3080,46 +3364,46 @@ impl<T: ArgParamsEx> ast::Instruction<T> { ast::Type::Scalar(desc.src.into()),
),
};
- ast::Instruction::Cvt(d, a.map_cvt(visitor, dst_t, src_t)?)
+ ast::Instruction::Cvt(d, a.map_cvt(visitor, &dst_t, &src_t)?)
}
ast::Instruction::Shl(t, a) => {
- ast::Instruction::Shl(t, a.map_shift(visitor, t.to_type())?)
+ ast::Instruction::Shl(t, a.map_shift(visitor, &t.to_type())?)
}
ast::Instruction::Shr(t, a) => {
- ast::Instruction::Shr(t, a.map_shift(visitor, ast::Type::Scalar(t.into()))?)
+ ast::Instruction::Shr(t, a.map_shift(visitor, &ast::Type::Scalar(t.into()))?)
}
ast::Instruction::St(d, a) => {
- let inst_type = d.typ;
let is_param = d.state_space == ast::StStateSpace::Param
|| d.state_space == ast::StStateSpace::Local;
- ast::Instruction::St(d, a.map(visitor, inst_type, is_param)?)
+ let new_args = a.map(visitor, &d.typ, is_param)?;
+ ast::Instruction::St(d, new_args)
}
ast::Instruction::Bra(d, a) => ast::Instruction::Bra(d, a.map(visitor, None)?),
ast::Instruction::Ret(d) => ast::Instruction::Ret(d),
ast::Instruction::Cvta(d, a) => {
let inst_type = ast::Type::Scalar(ast::ScalarType::B64);
- ast::Instruction::Cvta(d, a.map(visitor, false, inst_type)?)
+ ast::Instruction::Cvta(d, a.map(visitor, false, &inst_type)?)
}
ast::Instruction::Mad(d, a) => {
let inst_type = d.get_type();
let is_wide = d.is_wide();
- ast::Instruction::Mad(d, a.map(visitor, inst_type, is_wide)?)
+ ast::Instruction::Mad(d, a.map(visitor, &inst_type, is_wide)?)
}
ast::Instruction::Or(t, a) => ast::Instruction::Or(
t,
- a.map_non_shift(visitor, ast::Type::Scalar(t.into()), false)?,
+ a.map_non_shift(visitor, &ast::Type::Scalar(t.into()), false)?,
),
ast::Instruction::Sub(d, a) => {
let typ = d.get_type();
- ast::Instruction::Sub(d, a.map_non_shift(visitor, typ, false)?)
+ ast::Instruction::Sub(d, a.map_non_shift(visitor, &typ, false)?)
}
ast::Instruction::Min(d, a) => {
let typ = d.get_type();
- ast::Instruction::Min(d, a.map_non_shift(visitor, typ, false)?)
+ ast::Instruction::Min(d, a.map_non_shift(visitor, &typ, false)?)
}
ast::Instruction::Max(d, a) => {
let typ = d.get_type();
- ast::Instruction::Max(d, a.map_non_shift(visitor, typ, false)?)
+ ast::Instruction::Max(d, a.map_non_shift(visitor, &typ, false)?)
}
})
}
@@ -3130,7 +3414,7 @@ impl VisitVariable for ast::Instruction<TypedArgParams> { 'a,
F: FnMut(
ArgumentDescriptor<spirv_headers::Word>,
- Option<ast::Type>,
+ Option<&ast::Type>,
) -> Result<spirv_headers::Word, TranslateError>,
>(
self,
@@ -3144,13 +3428,13 @@ impl<T> ArgumentMapVisitor<TypedArgParams, TypedArgParams> for T where
T: FnMut(
ArgumentDescriptor<spirv::Word>,
- Option<ast::Type>,
+ Option<&ast::Type>,
) -> Result<spirv::Word, TranslateError>,
{
fn id(
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
- t: Option<ast::Type>,
+ t: Option<&ast::Type>,
) -> Result<spirv::Word, TranslateError> {
self(desc, t)
}
@@ -3158,7 +3442,7 @@ where fn operand(
&mut self,
desc: ArgumentDescriptor<ast::Operand<spirv::Word>>,
- t: ast::Type,
+ t: &ast::Type,
) -> Result<ast::Operand<spirv::Word>, TranslateError> {
match desc.op {
ast::Operand::Reg(id) => Ok(ast::Operand::Reg(self(desc.new_op(id), Some(t))?)),
@@ -3173,7 +3457,7 @@ where fn src_call_operand(
&mut self,
desc: ArgumentDescriptor<ast::CallOperand<spirv::Word>>,
- t: ast::Type,
+ t: &ast::Type,
) -> Result<ast::CallOperand<spirv::Word>, TranslateError> {
match desc.op {
ast::CallOperand::Reg(id) => Ok(ast::CallOperand::Reg(self(desc.new_op(id), Some(t))?)),
@@ -3184,7 +3468,7 @@ where fn id_or_vector(
&mut self,
desc: ArgumentDescriptor<ast::IdOrVector<spirv::Word>>,
- typ: ast::Type,
+ typ: &ast::Type,
) -> Result<ast::IdOrVector<spirv::Word>, TranslateError> {
match desc.op {
ast::IdOrVector::Reg(id) => Ok(ast::IdOrVector::Reg(self(desc.new_op(id), Some(typ))?)),
@@ -3199,7 +3483,7 @@ where fn operand_or_vector(
&mut self,
desc: ArgumentDescriptor<ast::OperandOrVector<spirv::Word>>,
- typ: ast::Type,
+ typ: &ast::Type,
) -> Result<ast::OperandOrVector<spirv::Word>, TranslateError> {
match desc.op {
ast::OperandOrVector::Reg(id) => {
@@ -3226,7 +3510,7 @@ where Ok((
self(
desc.new_op(desc.op.0),
- Some(ast::Type::Vector(scalar_type.into(), vector_len)),
+ Some(&ast::Type::Vector(scalar_type.into(), vector_len)),
)?,
desc.op.1,
))
@@ -3238,7 +3522,7 @@ impl ast::Type { match self {
ast::Type::Scalar(scalar) => {
let kind = scalar.kind();
- let width = scalar.width();
+ let width = scalar.size_of();
if (kind != ScalarKind::Signed
&& kind != ScalarKind::Unsigned
&& kind != ScalarKind::Bit)
@@ -3255,25 +3539,25 @@ impl ast::Type { }
}
- fn to_parts(self) -> TypeParts {
+ fn to_parts(&self) -> TypeParts {
match self {
ast::Type::Scalar(scalar) => TypeParts {
kind: TypeKind::Scalar,
scalar_kind: scalar.kind(),
- width: scalar.width(),
- components: 0,
+ width: scalar.size_of(),
+ components: Vec::new(),
},
ast::Type::Vector(scalar, components) => TypeParts {
kind: TypeKind::Vector,
scalar_kind: scalar.kind(),
- width: scalar.width(),
- components: components as u32,
+ width: scalar.size_of(),
+ components: vec![*components as u32],
},
ast::Type::Array(scalar, components) => TypeParts {
kind: TypeKind::Array,
scalar_kind: scalar.kind(),
- width: scalar.width(),
- components: components,
+ width: scalar.size_of(),
+ components: components.clone(),
},
}
}
@@ -3285,7 +3569,7 @@ impl ast::Type { }
TypeKind::Vector => ast::Type::Vector(
ast::ScalarType::from_parts(t.width, t.scalar_kind),
- t.components as u8,
+ t.components[0] as u8,
),
TypeKind::Array => ast::Type::Array(
ast::ScalarType::from_parts(t.width, t.scalar_kind),
@@ -3295,12 +3579,12 @@ impl ast::Type { }
}
-#[derive(Eq, PartialEq, Copy, Clone)]
+#[derive(Eq, PartialEq, Clone)]
struct TypeParts {
kind: TypeKind,
scalar_kind: ScalarKind,
width: u8,
- components: u32,
+ components: Vec<u32>,
}
#[derive(Eq, PartialEq, Copy, Clone)]
@@ -3342,7 +3626,7 @@ impl VisitVariableExpanded for ast::Instruction<ExpandedArgParams> { fn visit_variable_extended<
F: FnMut(
ArgumentDescriptor<spirv_headers::Word>,
- Option<ast::Type>,
+ Option<&ast::Type>,
) -> Result<spirv_headers::Word, TranslateError>,
>(
self,
@@ -3368,7 +3652,7 @@ impl VisitVariableExpanded for CompositeRead { fn visit_variable_extended<
F: FnMut(
ArgumentDescriptor<spirv_headers::Word>,
- Option<ast::Type>,
+ Option<&ast::Type>,
) -> Result<spirv_headers::Word, TranslateError>,
>(
self,
@@ -3384,7 +3668,7 @@ impl VisitVariableExpanded for CompositeRead { is_dst: true,
sema: dst_sema,
},
- Some(ast::Type::Scalar(self.typ)),
+ Some(&ast::Type::Scalar(self.typ)),
)?,
src_composite: f(
ArgumentDescriptor {
@@ -3392,7 +3676,7 @@ impl VisitVariableExpanded for CompositeRead { is_dst: false,
sema: ArgumentSemantics::Default,
},
- Some(ast::Type::Vector(self.typ, self.src_len as u8)),
+ Some(&ast::Type::Vector(self.typ, self.src_len as u8)),
)?,
..self
}))
@@ -3411,7 +3695,7 @@ struct BrachCondition { if_false: spirv::Word,
}
-#[derive(Copy, Clone)]
+#[derive(Clone)]
struct ImplicitConversion {
src: spirv::Word,
dst: spirv::Word,
@@ -3471,11 +3755,12 @@ impl<'a> ast::Instruction<ast::ParsedArgParams<'a>> { }
impl ast::VariableParamType {
- fn width(self) -> usize {
+ fn width(&self) -> usize {
match self {
- ast::VariableParamType::Scalar(t) => ast::ScalarType::from(t).width() as usize,
+ ast::VariableParamType::Scalar(t) => ast::ScalarType::from(*t).size_of() as usize,
ast::VariableParamType::Array(t, len) => {
- (ast::ScalarType::from(t).width() as usize) * (len as usize)
+ (ast::ScalarType::from(*t).size_of() as usize)
+ * (len.iter().fold(1, |x, y| x * (*y)) as usize)
}
}
}
@@ -3489,7 +3774,7 @@ impl<T: ArgParamsEx> ast::Arg1<T> { fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
- t: Option<ast::Type>,
+ t: Option<&ast::Type>,
) -> Result<ast::Arg1<U>, TranslateError> {
let new_src = visitor.id(
ArgumentDescriptor {
@@ -3515,7 +3800,7 @@ impl<T: ArgParamsEx> ast::Arg2<T> { self,
visitor: &mut V,
src_is_addr: bool,
- t: ast::Type,
+ t: &ast::Type,
) -> Result<ast::Arg2<U>, TranslateError> {
let new_dst = visitor.id(
ArgumentDescriptor {
@@ -3546,8 +3831,8 @@ impl<T: ArgParamsEx> ast::Arg2<T> { fn map_cvt<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
- dst_t: ast::Type,
- src_t: ast::Type,
+ dst_t: &ast::Type,
+ src_t: &ast::Type,
) -> Result<ast::Arg2<U>, TranslateError> {
let dst = visitor.id(
ArgumentDescriptor {
@@ -3582,7 +3867,7 @@ impl<T: ArgParamsEx> ast::Arg2Ld<T> { fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
- t: ast::Type,
+ t: &ast::Type,
is_param: bool,
) -> Result<ast::Arg2Ld<U>, TranslateError> {
let dst = visitor.id_or_vector(
@@ -3591,7 +3876,7 @@ impl<T: ArgParamsEx> ast::Arg2Ld<T> { is_dst: true,
sema: ArgumentSemantics::DefaultRelaxed,
},
- t.into(),
+ &ast::Type::from(t.clone()),
)?;
let src = visitor.operand(
ArgumentDescriptor {
@@ -3622,7 +3907,7 @@ impl<T: ArgParamsEx> ast::Arg2St<T> { fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
- t: ast::Type,
+ t: &ast::Type,
is_param: bool,
) -> Result<ast::Arg2St<U>, TranslateError> {
let src1 = visitor.operand(
@@ -3653,7 +3938,7 @@ impl<T: ArgParamsEx> ast::Arg2Mov<T> { fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
- details: ast::MovDetails,
+ details: &ast::MovDetails,
) -> Result<ast::Arg2Mov<U>, TranslateError> {
Ok(match self {
ast::Arg2Mov::Normal(arg) => ast::Arg2Mov::Normal(arg.map(visitor, details)?),
@@ -3675,7 +3960,7 @@ impl<P: ArgParamsEx> ast::Arg2MovNormal<P> { fn map<U: ArgParamsEx, V: ArgumentMapVisitor<P, U>>(
self,
visitor: &mut V,
- details: ast::MovDetails,
+ details: &ast::MovDetails,
) -> Result<ast::Arg2MovNormal<U>, TranslateError> {
let dst = visitor.id_or_vector(
ArgumentDescriptor {
@@ -3683,7 +3968,7 @@ impl<P: ArgParamsEx> ast::Arg2MovNormal<P> { is_dst: true,
sema: ArgumentSemantics::Default,
},
- details.typ.into(),
+ &details.typ.clone().into(),
)?;
let src = visitor.operand_or_vector(
ArgumentDescriptor {
@@ -3695,7 +3980,7 @@ impl<P: ArgParamsEx> ast::Arg2MovNormal<P> { ArgumentSemantics::Default
},
},
- details.typ.into(),
+ &details.typ.clone().into(),
)?;
Ok(ast::Arg2MovNormal { dst, src })
}
@@ -3733,7 +4018,7 @@ impl<T: ArgParamsEx> ast::Arg2MovMember<T> { fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
- details: ast::MovDetails,
+ details: &ast::MovDetails,
) -> Result<ast::Arg2MovMember<U>, TranslateError> {
match self {
ast::Arg2MovMember::Dst((dst, len), composite_src, scalar_src) => {
@@ -3744,7 +4029,7 @@ impl<T: ArgParamsEx> ast::Arg2MovMember<T> { is_dst: true,
sema: ArgumentSemantics::Default,
},
- Some(ast::Type::Vector(scalar_type, details.dst_width)),
+ Some(&ast::Type::Vector(scalar_type, details.dst_width)),
)?;
let src1 = visitor.id(
ArgumentDescriptor {
@@ -3752,7 +4037,7 @@ impl<T: ArgParamsEx> ast::Arg2MovMember<T> { is_dst: false,
sema: ArgumentSemantics::Default,
},
- Some(ast::Type::Vector(scalar_type, details.dst_width)),
+ Some(&ast::Type::Vector(scalar_type, details.dst_width)),
)?;
let src2 = visitor.id(
ArgumentDescriptor {
@@ -3766,7 +4051,7 @@ impl<T: ArgParamsEx> ast::Arg2MovMember<T> { ArgumentSemantics::Default
},
},
- Some(details.typ.into()),
+ Some(&details.typ.clone().into()),
)?;
Ok(ast::Arg2MovMember::Dst((dst, len), src1, src2))
}
@@ -3777,7 +4062,7 @@ impl<T: ArgParamsEx> ast::Arg2MovMember<T> { is_dst: true,
sema: ArgumentSemantics::Default,
},
- Some(details.typ.into()),
+ Some(&details.typ.clone().into()),
)?;
let scalar_typ = details.typ.get_scalar()?;
let src = visitor.src_member_operand(
@@ -3798,7 +4083,7 @@ impl<T: ArgParamsEx> ast::Arg2MovMember<T> { is_dst: true,
sema: ArgumentSemantics::Default,
},
- Some(ast::Type::Vector(scalar_type, details.dst_width)),
+ Some(&ast::Type::Vector(scalar_type, details.dst_width)),
)?;
let composite_src = visitor.id(
ArgumentDescriptor {
@@ -3806,7 +4091,7 @@ impl<T: ArgParamsEx> ast::Arg2MovMember<T> { is_dst: false,
sema: ArgumentSemantics::Default,
},
- Some(ast::Type::Vector(scalar_type, details.dst_width)),
+ Some(&ast::Type::Vector(scalar_type, details.dst_width)),
)?;
let src = visitor.src_member_operand(
ArgumentDescriptor {
@@ -3838,16 +4123,21 @@ impl<T: ArgParamsEx> ast::Arg3<T> { fn map_non_shift<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
- typ: ast::Type,
+ typ: &ast::Type,
is_wide: bool,
) -> Result<ast::Arg3<U>, TranslateError> {
+ let wide_type = if is_wide {
+ Some(typ.clone().widen()?)
+ } else {
+ None
+ };
let dst = visitor.id(
ArgumentDescriptor {
op: self.dst,
is_dst: true,
sema: ArgumentSemantics::Default,
},
- Some(if is_wide { typ.widen()? } else { typ }),
+ Some(wide_type.as_ref().unwrap_or(typ)),
)?;
let src1 = visitor.operand(
ArgumentDescriptor {
@@ -3871,7 +4161,7 @@ impl<T: ArgParamsEx> ast::Arg3<T> { fn map_shift<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
- t: ast::Type,
+ t: &ast::Type,
) -> Result<ast::Arg3<U>, TranslateError> {
let dst = visitor.id(
ArgumentDescriptor {
@@ -3895,7 +4185,7 @@ impl<T: ArgParamsEx> ast::Arg3<T> { is_dst: false,
sema: ArgumentSemantics::Default,
},
- ast::Type::Scalar(ast::ScalarType::U32),
+ &ast::Type::Scalar(ast::ScalarType::U32),
)?;
Ok(ast::Arg3 { dst, src1, src2 })
}
@@ -3914,16 +4204,21 @@ impl<T: ArgParamsEx> ast::Arg4<T> { fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
- t: ast::Type,
+ t: &ast::Type,
is_wide: bool,
) -> Result<ast::Arg4<U>, TranslateError> {
+ let wide_type = if is_wide {
+ Some(t.clone().widen()?)
+ } else {
+ None
+ };
let dst = visitor.id(
ArgumentDescriptor {
op: self.dst,
is_dst: true,
sema: ArgumentSemantics::Default,
},
- Some(if is_wide { t.widen()? } else { t }),
+ Some(wide_type.as_ref().unwrap_or(t)),
)?;
let src1 = visitor.operand(
ArgumentDescriptor {
@@ -3971,7 +4266,7 @@ impl<T: ArgParamsEx> ast::Arg4Setp<T> { fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
- t: ast::Type,
+ t: &ast::Type,
) -> Result<ast::Arg4Setp<U>, TranslateError> {
let dst1 = visitor.id(
ArgumentDescriptor {
@@ -3979,7 +4274,7 @@ impl<T: ArgParamsEx> ast::Arg4Setp<T> { is_dst: true,
sema: ArgumentSemantics::Default,
},
- Some(ast::Type::Scalar(ast::ScalarType::Pred)),
+ Some(&ast::Type::Scalar(ast::ScalarType::Pred)),
)?;
let dst2 = self
.dst2
@@ -3990,7 +4285,7 @@ impl<T: ArgParamsEx> ast::Arg4Setp<T> { is_dst: true,
sema: ArgumentSemantics::Default,
},
- Some(ast::Type::Scalar(ast::ScalarType::Pred)),
+ Some(&ast::Type::Scalar(ast::ScalarType::Pred)),
)
})
.transpose()?;
@@ -4033,7 +4328,7 @@ impl<T: ArgParamsEx> ast::Arg5<T> { fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
- t: ast::Type,
+ t: &ast::Type,
) -> Result<ast::Arg5<U>, TranslateError> {
let dst1 = visitor.id(
ArgumentDescriptor {
@@ -4041,7 +4336,7 @@ impl<T: ArgParamsEx> ast::Arg5<T> { is_dst: true,
sema: ArgumentSemantics::Default,
},
- Some(ast::Type::Scalar(ast::ScalarType::Pred)),
+ Some(&ast::Type::Scalar(ast::ScalarType::Pred)),
)?;
let dst2 = self
.dst2
@@ -4052,7 +4347,7 @@ impl<T: ArgParamsEx> ast::Arg5<T> { is_dst: true,
sema: ArgumentSemantics::Default,
},
- Some(ast::Type::Scalar(ast::ScalarType::Pred)),
+ Some(&ast::Type::Scalar(ast::ScalarType::Pred)),
)
})
.transpose()?;
@@ -4078,7 +4373,7 @@ impl<T: ArgParamsEx> ast::Arg5<T> { is_dst: false,
sema: ArgumentSemantics::Default,
},
- ast::Type::Scalar(ast::ScalarType::Pred),
+ &ast::Type::Scalar(ast::ScalarType::Pred),
)?;
Ok(ast::Arg5 {
dst1,
@@ -4091,16 +4386,16 @@ impl<T: ArgParamsEx> ast::Arg5<T> { }
impl ast::Type {
- fn get_vector(self) -> Result<(ast::ScalarType, u8), TranslateError> {
+ fn get_vector(&self) -> Result<(ast::ScalarType, u8), TranslateError> {
match self {
- ast::Type::Vector(t, len) => Ok((t, len)),
+ ast::Type::Vector(t, len) => Ok((*t, *len)),
_ => Err(TranslateError::MismatchedType),
}
}
- fn get_scalar(self) -> Result<ast::ScalarType, TranslateError> {
+ fn get_scalar(&self) -> Result<ast::ScalarType, TranslateError> {
match self {
- ast::Type::Scalar(t) => Ok(t),
+ ast::Type::Scalar(t) => Ok(*t),
_ => Err(TranslateError::MismatchedType),
}
}
@@ -4141,28 +4436,6 @@ enum ScalarKind { }
impl ast::ScalarType {
- fn width(self) -> u8 {
- match self {
- ast::ScalarType::U8 => 1,
- ast::ScalarType::S8 => 1,
- ast::ScalarType::B8 => 1,
- ast::ScalarType::U16 => 2,
- ast::ScalarType::S16 => 2,
- ast::ScalarType::B16 => 2,
- ast::ScalarType::F16 => 2,
- ast::ScalarType::U32 => 4,
- ast::ScalarType::S32 => 4,
- ast::ScalarType::B32 => 4,
- ast::ScalarType::F32 => 4,
- ast::ScalarType::U64 => 8,
- ast::ScalarType::S64 => 8,
- ast::ScalarType::B64 => 8,
- ast::ScalarType::F64 => 8,
- ast::ScalarType::F16x2 => 4,
- ast::ScalarType::Pred => 1,
- }
- }
-
fn kind(self) -> ScalarKind {
match self {
ast::ScalarType::U8 => ScalarKind::Unsigned,
@@ -4283,20 +4556,6 @@ impl ast::MinMaxDetails { }
}
-impl ast::IntType {
- fn try_new(t: ast::ScalarType) -> Option<Self> {
- match t {
- ast::ScalarType::U16 => Some(ast::IntType::U16),
- ast::ScalarType::U32 => Some(ast::IntType::U32),
- ast::ScalarType::U64 => Some(ast::IntType::U64),
- ast::ScalarType::S16 => Some(ast::IntType::S16),
- ast::ScalarType::S32 => Some(ast::IntType::S32),
- ast::ScalarType::S64 => Some(ast::IntType::S64),
- _ => None,
- }
- }
-}
-
impl ast::SIntType {
fn from_size(width: u8) -> Self {
match width {
@@ -4372,8 +4631,8 @@ impl ast::MulDetails { }
fn force_bitcast(
- operand: ast::Type,
- instr: ast::Type,
+ operand: &ast::Type,
+ instr: &ast::Type,
_: Option<ast::LdStateSpace>,
) -> Result<Option<ConversionKind>, TranslateError> {
if instr != operand {
@@ -4384,8 +4643,8 @@ fn force_bitcast( }
fn bitcast_physical_pointer(
- operand_type: ast::Type,
- _: ast::Type,
+ operand_type: &ast::Type,
+ _: &ast::Type,
ss: Option<ast::LdStateSpace>,
) -> Result<Option<ConversionKind>, TranslateError> {
match operand_type {
@@ -4403,17 +4662,17 @@ fn bitcast_physical_pointer( }
fn force_bitcast_ptr_to_bit(
- _: ast::Type,
- _: ast::Type,
+ _: &ast::Type,
+ _: &ast::Type,
_: Option<ast::LdStateSpace>,
) -> Result<Option<ConversionKind>, TranslateError> {
Ok(Some(ConversionKind::PtrToBit))
}
-fn should_bitcast(instr: ast::Type, operand: ast::Type) -> bool {
+fn should_bitcast(instr: &ast::Type, operand: &ast::Type) -> bool {
match (instr, operand) {
(ast::Type::Scalar(inst), ast::Type::Scalar(operand)) => {
- if inst.width() != operand.width() {
+ if inst.size_of() != operand.size_of() {
return false;
}
match inst.kind() {
@@ -4431,22 +4690,22 @@ fn should_bitcast(instr: ast::Type, operand: ast::Type) -> bool { }
(ast::Type::Vector(inst, _), ast::Type::Vector(operand, _))
| (ast::Type::Array(inst, _), ast::Type::Array(operand, _)) => {
- should_bitcast(ast::Type::Scalar(inst), ast::Type::Scalar(operand))
+ should_bitcast(&ast::Type::Scalar(*inst), &ast::Type::Scalar(*operand))
}
_ => false,
}
}
fn should_bitcast_packed(
- operand: ast::Type,
- instr: ast::Type,
+ operand: &ast::Type,
+ instr: &ast::Type,
ss: Option<ast::LdStateSpace>,
) -> Result<Option<ConversionKind>, TranslateError> {
if let (ast::Type::Vector(vec_underlying_type, vec_len), ast::Type::Scalar(scalar)) =
(operand, instr)
{
if scalar.kind() == ScalarKind::Bit
- && scalar.width() == (vec_underlying_type.width() * vec_len)
+ && scalar.size_of() == (vec_underlying_type.size_of() * vec_len)
{
return Ok(Some(ConversionKind::Default));
}
@@ -4455,8 +4714,8 @@ fn should_bitcast_packed( }
fn should_bitcast_wrapper(
- operand: ast::Type,
- instr: ast::Type,
+ operand: &ast::Type,
+ instr: &ast::Type,
_: Option<ast::LdStateSpace>,
) -> Result<Option<ConversionKind>, TranslateError> {
if instr == operand {
@@ -4470,8 +4729,8 @@ fn should_bitcast_wrapper( }
fn should_convert_relaxed_src_wrapper(
- src_type: ast::Type,
- instr_type: ast::Type,
+ src_type: &ast::Type,
+ instr_type: &ast::Type,
_: Option<ast::LdStateSpace>,
) -> Result<Option<ConversionKind>, TranslateError> {
if src_type == instr_type {
@@ -4485,8 +4744,8 @@ fn should_convert_relaxed_src_wrapper( // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-source-operands
fn should_convert_relaxed_src(
- src_type: ast::Type,
- instr_type: ast::Type,
+ src_type: &ast::Type,
+ instr_type: &ast::Type,
) -> Option<ConversionKind> {
if src_type == instr_type {
return None;
@@ -4494,21 +4753,24 @@ fn should_convert_relaxed_src( match (src_type, instr_type) {
(ast::Type::Scalar(src_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() {
ScalarKind::Bit => {
- if instr_type.width() <= src_type.width() {
+ if instr_type.size_of() <= src_type.size_of() {
Some(ConversionKind::Default)
} else {
None
}
}
ScalarKind::Signed | ScalarKind::Unsigned => {
- if instr_type.width() <= src_type.width() && src_type.kind() != ScalarKind::Float {
+ if instr_type.size_of() <= src_type.size_of()
+ && src_type.kind() != ScalarKind::Float
+ {
Some(ConversionKind::Default)
} else {
None
}
}
ScalarKind::Float => {
- if instr_type.width() <= src_type.width() && src_type.kind() == ScalarKind::Bit {
+ if instr_type.size_of() <= src_type.size_of() && src_type.kind() == ScalarKind::Bit
+ {
Some(ConversionKind::Default)
} else {
None
@@ -4519,15 +4781,18 @@ fn should_convert_relaxed_src( },
(ast::Type::Vector(dst_type, _), ast::Type::Vector(instr_type, _))
| (ast::Type::Array(dst_type, _), ast::Type::Array(instr_type, _)) => {
- should_convert_relaxed_src(ast::Type::Scalar(dst_type), ast::Type::Scalar(instr_type))
+ should_convert_relaxed_src(
+ &ast::Type::Scalar(*dst_type),
+ &ast::Type::Scalar(*instr_type),
+ )
}
_ => None,
}
}
fn should_convert_relaxed_dst_wrapper(
- dst_type: ast::Type,
- instr_type: ast::Type,
+ dst_type: &ast::Type,
+ instr_type: &ast::Type,
_: Option<ast::LdStateSpace>,
) -> Result<Option<ConversionKind>, TranslateError> {
if dst_type == instr_type {
@@ -4541,8 +4806,8 @@ fn should_convert_relaxed_dst_wrapper( // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-destination-operands
fn should_convert_relaxed_dst(
- dst_type: ast::Type,
- instr_type: ast::Type,
+ dst_type: &ast::Type,
+ instr_type: &ast::Type,
) -> Option<ConversionKind> {
if dst_type == instr_type {
return None;
@@ -4550,7 +4815,7 @@ fn should_convert_relaxed_dst( match (dst_type, instr_type) {
(ast::Type::Scalar(dst_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() {
ScalarKind::Bit => {
- if instr_type.width() <= dst_type.width() {
+ if instr_type.size_of() <= dst_type.size_of() {
Some(ConversionKind::Default)
} else {
None
@@ -4558,9 +4823,9 @@ fn should_convert_relaxed_dst( }
ScalarKind::Signed => {
if dst_type.kind() != ScalarKind::Float {
- if instr_type.width() == dst_type.width() {
+ if instr_type.size_of() == dst_type.size_of() {
Some(ConversionKind::Default)
- } else if instr_type.width() < dst_type.width() {
+ } else if instr_type.size_of() < dst_type.size_of() {
Some(ConversionKind::SignExtend)
} else {
None
@@ -4570,14 +4835,17 @@ fn should_convert_relaxed_dst( }
}
ScalarKind::Unsigned => {
- if instr_type.width() <= dst_type.width() && dst_type.kind() != ScalarKind::Float {
+ if instr_type.size_of() <= dst_type.size_of()
+ && dst_type.kind() != ScalarKind::Float
+ {
Some(ConversionKind::Default)
} else {
None
}
}
ScalarKind::Float => {
- if instr_type.width() <= dst_type.width() && dst_type.kind() == ScalarKind::Bit {
+ if instr_type.size_of() <= dst_type.size_of() && dst_type.kind() == ScalarKind::Bit
+ {
Some(ConversionKind::Default)
} else {
None
@@ -4588,7 +4856,10 @@ fn should_convert_relaxed_dst( },
(ast::Type::Vector(dst_type, _), ast::Type::Vector(instr_type, _))
| (ast::Type::Array(dst_type, _), ast::Type::Array(instr_type, _)) => {
- should_convert_relaxed_dst(ast::Type::Scalar(dst_type), ast::Type::Scalar(instr_type))
+ should_convert_relaxed_dst(
+ &ast::Type::Scalar(*dst_type),
+ &ast::Type::Scalar(*instr_type),
+ )
}
_ => None,
}
@@ -4611,7 +4882,8 @@ impl<'a> ast::MethodDecl<'a, spirv::Word> { f(&ast::FnArgument {
align: arg.align,
name: arg.name,
- v_type: ast::FnArgumentType::Param(arg.v_type),
+ v_type: ast::FnArgumentType::Param(arg.v_type.clone()),
+ array_init: arg.array_init.clone(),
})
}),
}
@@ -4698,14 +4970,17 @@ mod tests { .collect::<Vec<_>>()
}
- fn assert_conversion_table<F: Fn(ast::Type, ast::Type) -> Option<ConversionKind>>(
+ fn assert_conversion_table<F: Fn(&ast::Type, &ast::Type) -> Option<ConversionKind>>(
table: &'static str,
f: F,
) {
let conv_table = parse_conversion_table(table);
for (instr_idx, instr_type) in SCALAR_TYPES.iter().enumerate() {
for (op_idx, op_type) in SCALAR_TYPES.iter().enumerate() {
- let conversion = f(ast::Type::Scalar(*op_type), ast::Type::Scalar(*instr_type));
+ let conversion = f(
+ &ast::Type::Scalar(*op_type),
+ &ast::Type::Scalar(*instr_type),
+ );
if instr_idx == op_idx {
assert_eq!(conversion, None);
} else {
|