aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorAndrzej Janik <[email protected]>2020-10-02 20:34:45 +0200
committerAndrzej Janik <[email protected]>2020-10-04 19:53:07 +0200
commit27d25865af2bf51ca55b223e634208234d1a141a (patch)
tree695f081f09cd22ffbc04effa0e947abba82bc50d
parent9a65dd32f5898eb9dd3edf7cdddb1513a7a754ed (diff)
downloadZLUDA-27d25865af2bf51ca55b223e634208234d1a141a.tar.gz
ZLUDA-27d25865af2bf51ca55b223e634208234d1a141a.zip
Add support for top-level global variables, improve array support
-rw-r--r--level_zero/src/ze.rs57
-rw-r--r--notcuda/src/impl/function.rs2
-rw-r--r--notcuda/src/impl/memory.rs2
-rw-r--r--notcuda/src/impl/mod.rs2
-rw-r--r--notcuda/src/impl/module.rs11
-rw-r--r--notcuda/src/impl/stream.rs4
-rw-r--r--notcuda/src/impl/test.rs2
-rw-r--r--ptx/src/ast.rs265
-rw-r--r--ptx/src/ptx.lalrpop153
-rw-r--r--ptx/src/test/spirv_run/global_array.ptx22
-rw-r--r--ptx/src/test/spirv_run/global_array.spvtxt54
-rw-r--r--ptx/src/test/spirv_run/mod.rs18
-rw-r--r--ptx/src/translate.rs871
13 files changed, 1085 insertions, 378 deletions
diff --git a/level_zero/src/ze.rs b/level_zero/src/ze.rs
index 559805e..5ced5d0 100644
--- a/level_zero/src/ze.rs
+++ b/level_zero/src/ze.rs
@@ -1,6 +1,6 @@
use crate::sys;
use std::{
- ffi::{c_void, CStr},
+ ffi::{c_void, CStr, CString},
fmt::Debug,
marker::PhantomData,
mem, ptr,
@@ -238,23 +238,16 @@ impl Drop for CommandQueue {
pub struct Module(sys::ze_module_handle_t);
impl Module {
- pub unsafe fn as_ffi(&self) -> sys::ze_module_handle_t {
- self.0
- }
- pub unsafe fn from_ffi(x: sys::ze_module_handle_t) -> Self {
- Self(x)
- }
-
pub fn new_spirv(
ctx: &mut Context,
d: &Device,
bin: &[u8],
opts: Option<&CStr>,
- ) -> Result<Self> {
+ ) -> (Result<Self>, BuildLog) {
Module::new(ctx, true, d, bin, opts)
}
- pub fn new_native(ctx: &mut Context, d: &Device, bin: &[u8]) -> Result<Self> {
+ pub fn new_native(ctx: &mut Context, d: &Device, bin: &[u8]) -> (Result<Self>, BuildLog) {
Module::new(ctx, false, d, bin, None)
}
@@ -264,7 +257,7 @@ impl Module {
d: &Device,
bin: &[u8],
opts: Option<&CStr>,
- ) -> Result<Self> {
+ ) -> (Result<Self>, BuildLog) {
let desc = sys::ze_module_desc_t {
stype: sys::ze_structure_type_t::ZE_STRUCTURE_TYPE_MODULE_DESC,
pNext: ptr::null(),
@@ -279,14 +272,14 @@ impl Module {
pConstants: ptr::null(),
};
let mut result: sys::ze_module_handle_t = ptr::null_mut();
- check!(sys::zeModuleCreate(
- ctx.0,
- d.0,
- &desc,
- &mut result,
- ptr::null_mut()
- ));
- Ok(Module(result))
+ let mut log_handle = ptr::null_mut();
+ let err = unsafe { sys::zeModuleCreate(ctx.0, d.0, &desc, &mut result, &mut log_handle) };
+ let log = BuildLog(log_handle);
+ if err != crate::sys::ze_result_t::ZE_RESULT_SUCCESS {
+ (Result::Err(err), log)
+ } else {
+ (Ok(Module(result)), log)
+ }
}
}
@@ -297,6 +290,32 @@ impl Drop for Module {
}
}
+pub struct BuildLog(sys::ze_module_build_log_handle_t);
+
+impl BuildLog {
+ pub unsafe fn as_ffi(&self) -> sys::ze_module_build_log_handle_t {
+ self.0
+ }
+ pub unsafe fn from_ffi(x: sys::ze_module_build_log_handle_t) -> Self {
+ Self(x)
+ }
+
+ pub fn get_cstring(&self) -> Result<CString> {
+ let mut size = 0;
+ check! { sys::zeModuleBuildLogGetString(self.0, &mut size, ptr::null_mut()) };
+ let mut str_vec = vec![0u8; size];
+ check! { sys::zeModuleBuildLogGetString(self.0, &mut size, str_vec.as_mut_ptr() as *mut i8) };
+ str_vec.pop();
+ Ok(CString::new(str_vec).map_err(|_| sys::ze_result_t::ZE_RESULT_ERROR_UNKNOWN)?)
+ }
+}
+
+impl Drop for BuildLog {
+ fn drop(&mut self) {
+ check_panic!(sys::zeModuleBuildLogDestroy(self.0));
+ }
+}
+
pub trait SafeRepr {}
impl SafeRepr for u8 {}
impl SafeRepr for i8 {}
diff --git a/notcuda/src/impl/function.rs b/notcuda/src/impl/function.rs
index 6f8773e..0ab3bea 100644
--- a/notcuda/src/impl/function.rs
+++ b/notcuda/src/impl/function.rs
@@ -1,7 +1,7 @@
use ::std::os::raw::{c_uint, c_void};
use std::ptr;
-use super::{context, device, stream::Stream, CUresult};
+use super::{device, stream::Stream, CUresult};
pub struct Function {
pub base: l0::Kernel<'static>,
diff --git a/notcuda/src/impl/memory.rs b/notcuda/src/impl/memory.rs
index 3f92b5e..439b26f 100644
--- a/notcuda/src/impl/memory.rs
+++ b/notcuda/src/impl/memory.rs
@@ -46,7 +46,7 @@ unsafe fn memcpy_impl(
Ok(())
}
-pub(crate) fn free_v2(mem: *mut c_void)-> l0::Result<()> {
+pub(crate) fn free_v2(_: *mut c_void)-> l0::Result<()> {
Ok(())
}
diff --git a/notcuda/src/impl/mod.rs b/notcuda/src/impl/mod.rs
index 3d31da2..5a72ce4 100644
--- a/notcuda/src/impl/mod.rs
+++ b/notcuda/src/impl/mod.rs
@@ -1,4 +1,4 @@
-use crate::cuda::{CUctx_st, CUdevice, CUdeviceptr, CUfunc_st, CUfunction, CUmod_st, CUmodule, CUresult, CUstream, CUstream_st};
+use crate::cuda::{CUctx_st, CUdevice, CUdeviceptr, CUfunc_st, CUmod_st, CUresult, CUstream_st};
use std::{ffi::c_void, mem::{self, ManuallyDrop}, os::raw::c_int, sync::Mutex};
#[cfg(test)]
diff --git a/notcuda/src/impl/module.rs b/notcuda/src/impl/module.rs
index fc55f33..eea862b 100644
--- a/notcuda/src/impl/module.rs
+++ b/notcuda/src/impl/module.rs
@@ -1,13 +1,10 @@
use std::{
- collections::HashMap, ffi::c_void, ffi::CStr, ffi::CString, mem, os::raw::c_char, ptr, slice,
- sync::Mutex,
+ collections::HashMap, ffi::CStr, ffi::CString, mem, os::raw::c_char, ptr, slice, sync::Mutex,
};
use super::{function::Function, transmute_lifetime, CUresult};
use ptx;
-use super::context;
-
pub type Module = Mutex<ModuleData>;
pub struct ModuleData {
@@ -67,14 +64,14 @@ impl ModuleData {
l0::Module::new_spirv(&mut dev.l0_context, &dev.base, byte_il, None)
});
match module {
- Ok(Ok(module)) => Ok(Mutex::new(Self {
+ Ok((Ok(module), _)) => Ok(Mutex::new(Self {
base: module,
arg_lens: all_arg_lens
.into_iter()
.map(|(k, v)| (CString::new(k).unwrap(), v))
.collect(),
})),
- Ok(Err(err)) => Err(ModuleCompileError::from(err)),
+ Ok((Err(err), _)) => Err(ModuleCompileError::from(err)),
Err(err) => Err(ModuleCompileError::from(err)),
}
}
@@ -116,6 +113,6 @@ pub fn get_function(
Ok(())
}
-pub(crate) fn unload(decuda: *mut Module) -> Result<(), CUresult> {
+pub(crate) fn unload(_: *mut Module) -> Result<(), CUresult> {
Ok(())
}
diff --git a/notcuda/src/impl/stream.rs b/notcuda/src/impl/stream.rs
index 7410100..1844677 100644
--- a/notcuda/src/impl/stream.rs
+++ b/notcuda/src/impl/stream.rs
@@ -30,7 +30,7 @@ mod tests {
use super::super::test::CudaDriverFns;
use super::super::CUresult;
- use std::{ffi::c_void, ptr};
+ use std::ptr;
const CU_STREAM_LEGACY: CUstream = 1 as *mut _;
const CU_STREAM_PER_THREAD: CUstream = 2 as *mut _;
@@ -41,7 +41,7 @@ mod tests {
fn default_stream_uses_current_ctx_legacy<T: CudaDriverFns>() {
default_stream_uses_current_ctx_impl::<T>(CU_STREAM_LEGACY);
}
-
+
fn default_stream_uses_current_ctx_ptsd<T: CudaDriverFns>() {
default_stream_uses_current_ctx_impl::<T>(CU_STREAM_PER_THREAD);
}
diff --git a/notcuda/src/impl/test.rs b/notcuda/src/impl/test.rs
index d4366b7..dbd2eff 100644
--- a/notcuda/src/impl/test.rs
+++ b/notcuda/src/impl/test.rs
@@ -1,6 +1,6 @@
#![allow(non_snake_case)]
-use crate::{cuda::CUcontext, cuda::CUstream, r#impl as notcuda};
+use crate::{cuda::CUstream, r#impl as notcuda};
use crate::r#impl::CUresult;
use crate::{cuda::CUuuid, r#impl::Encuda};
use ::std::{
diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs
index 048d43a..c6510da 100644
--- a/ptx/src/ast.rs
+++ b/ptx/src/ast.rs
@@ -1,6 +1,8 @@
-use std::convert::From;
+use std::{convert::From, mem, num::ParseFloatError, str::FromStr};
use std::{marker::PhantomData, num::ParseIntError};
+use half::f16;
+
quick_error! {
#[derive(Debug)]
pub enum PtxError {
@@ -9,11 +11,17 @@ quick_error! {
display("{}", err)
cause(err)
}
+ ParseFloat (err: ParseFloatError) {
+ from()
+ display("{}", err)
+ cause(err)
+ }
SyntaxError {}
NonF32Ftz {}
WrongArrayType {}
WrongVectorElement {}
MultiArrayVariable {}
+ ZeroDimensionArray {}
}
}
@@ -53,7 +61,7 @@ macro_rules! sub_scalar_type {
macro_rules! sub_type {
($type_name:ident { $($variant:ident ( $($field_type:ident),+ ) ),+ $(,)? } ) => {
- #[derive(PartialEq, Eq, Clone, Copy)]
+ #[derive(PartialEq, Eq, Clone)]
pub enum $type_name {
$(
$variant ($($field_type),+),
@@ -80,11 +88,13 @@ sub_type! {
}
}
+type VecU32 = Vec<u32>;
+
sub_type! {
VariableLocalType {
Scalar(SizedScalarType),
Vector(SizedScalarType, u8),
- Array(SizedScalarType, u32),
+ Array(SizedScalarType, VecU32),
}
}
@@ -95,7 +105,7 @@ sub_type! {
sub_type! {
VariableParamType {
Scalar(ParamScalarType),
- Array(SizedScalarType, u32),
+ Array(SizedScalarType, VecU32),
}
}
@@ -169,7 +179,12 @@ impl<
pub struct Module<'a> {
pub version: (u8, u8),
- pub functions: Vec<ParsedFunction<'a>>,
+ pub directives: Vec<Directive<'a, ParsedArgParams<'a>>>,
+}
+
+pub enum Directive<'a, P: ArgParams> {
+ Variable(Variable<VariableType, P::Id>),
+ Method(Function<'a, &'a str, Statement<P>>),
}
pub enum MethodDecl<'a, ID> {
@@ -187,7 +202,7 @@ pub struct Function<'a, ID, S> {
pub type ParsedFunction<'a> = Function<'a, &'a str, Statement<ParsedArgParams<'a>>>;
-#[derive(PartialEq, Eq, Clone, Copy)]
+#[derive(PartialEq, Eq, Clone)]
pub enum FnArgumentType {
Reg(VariableRegType),
Param(VariableParamType),
@@ -202,11 +217,11 @@ impl From<FnArgumentType> for Type {
}
}
-#[derive(PartialEq, Eq, Hash, Clone, Copy)]
+#[derive(PartialEq, Eq, Hash, Clone)]
pub enum Type {
Scalar(ScalarType),
Vector(ScalarType, u8),
- Array(ScalarType, u32),
+ Array(ScalarType, Vec<u32>),
}
#[derive(PartialEq, Eq, Hash, Clone, Copy)]
@@ -274,6 +289,30 @@ sub_scalar_type!(FloatType {
F64
});
+impl ScalarType {
+ pub fn size_of(self) -> u8 {
+ match self {
+ ScalarType::U8 => 1,
+ ScalarType::S8 => 1,
+ ScalarType::B8 => 1,
+ ScalarType::U16 => 2,
+ ScalarType::S16 => 2,
+ ScalarType::B16 => 2,
+ ScalarType::F16 => 2,
+ ScalarType::U32 => 4,
+ ScalarType::S32 => 4,
+ ScalarType::B32 => 4,
+ ScalarType::F32 => 4,
+ ScalarType::U64 => 8,
+ ScalarType::S64 => 8,
+ ScalarType::B64 => 8,
+ ScalarType::F64 => 8,
+ ScalarType::F16x2 => 4,
+ ScalarType::Pred => 1,
+ }
+ }
+}
+
impl Default for ScalarType {
fn default() -> Self {
ScalarType::B8
@@ -296,13 +335,26 @@ pub struct Variable<T, ID> {
pub align: Option<u32>,
pub v_type: T,
pub name: ID,
+ pub array_init: Vec<u8>,
}
-#[derive(Eq, PartialEq, Copy, Clone)]
+#[derive(Eq, PartialEq, Clone)]
pub enum VariableType {
Reg(VariableRegType),
Local(VariableLocalType),
Param(VariableParamType),
+ Global(VariableLocalType),
+}
+
+impl VariableType {
+ pub fn to_type(&self) -> (StateSpace, Type) {
+ match self {
+ VariableType::Reg(t) => (StateSpace::Reg, t.clone().into()),
+ VariableType::Local(t) => (StateSpace::Local, t.clone().into()),
+ VariableType::Param(t) => (StateSpace::Param, t.clone().into()),
+ VariableType::Global(t) => (StateSpace::Global, t.clone().into()),
+ }
+ }
}
impl From<VariableType> for Type {
@@ -311,6 +363,7 @@ impl From<VariableType> for Type {
VariableType::Reg(t) => t.into(),
VariableType::Local(t) => t.into(),
VariableType::Param(t) => t.into(),
+ VariableType::Global(t) => t.into(),
}
}
}
@@ -318,7 +371,6 @@ impl From<VariableType> for Type {
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum StateSpace {
Reg,
- Sreg,
Const,
Global,
Local,
@@ -538,7 +590,7 @@ pub enum LdCacheOperator {
Uncached,
}
-#[derive(Copy, Clone)]
+#[derive(Clone)]
pub struct MovDetails {
pub typ: Type,
pub src_is_address: bool,
@@ -846,3 +898,194 @@ pub struct MinMaxFloat {
pub nan: bool,
pub typ: FloatType,
}
+
+pub enum NumsOrArrays<'a> {
+ Nums(Vec<&'a str>),
+ Arrays(Vec<NumsOrArrays<'a>>),
+}
+
+impl<'a> NumsOrArrays<'a> {
+ pub fn to_vec(self, typ: SizedScalarType, dimensions: &mut [u32]) -> Result<Vec<u8>, PtxError> {
+ self.normalize_dimensions(dimensions)?;
+ let sizeof_t = ScalarType::from(typ).size_of() as usize;
+ let result_size = dimensions.iter().fold(sizeof_t, |x, y| x * (*y as usize));
+ let mut result = vec![0; result_size];
+ self.parse_and_copy(typ, sizeof_t, dimensions, &mut result)?;
+ Ok(result)
+ }
+
+ fn normalize_dimensions(&self, dimensions: &mut [u32]) -> Result<(), PtxError> {
+ match dimensions.first_mut() {
+ Some(first) => {
+ if *first == 0 {
+ *first = match self {
+ NumsOrArrays::Nums(v) => v.len() as u32,
+ NumsOrArrays::Arrays(v) => v.len() as u32,
+ };
+ }
+ }
+ None => return Err(PtxError::ZeroDimensionArray),
+ }
+ for dim in dimensions {
+ if *dim == 0 {
+ return Err(PtxError::ZeroDimensionArray);
+ }
+ }
+ Ok(())
+ }
+
+ fn parse_and_copy(
+ &self,
+ t: SizedScalarType,
+ size_of_t: usize,
+ dimensions: &[u32],
+ result: &mut [u8],
+ ) -> Result<(), PtxError> {
+ match dimensions {
+ [] => unreachable!(),
+ [dim] => match self {
+ NumsOrArrays::Nums(vec) => {
+ if vec.len() > *dim as usize {
+ return Err(PtxError::ZeroDimensionArray);
+ }
+ for (idx, val) in vec.iter().enumerate() {
+ Self::parse_and_copy_single(t, idx, val, result)?;
+ }
+ }
+ NumsOrArrays::Arrays(_) => return Err(PtxError::ZeroDimensionArray),
+ },
+ [first_dim, rest @ ..] => match self {
+ NumsOrArrays::Arrays(vec) => {
+ if vec.len() > *first_dim as usize {
+ return Err(PtxError::ZeroDimensionArray);
+ }
+ let size_of_element = rest.iter().fold(size_of_t, |x, y| x * (*y as usize));
+ for (idx, this) in vec.iter().enumerate() {
+ this.parse_and_copy(
+ t,
+ size_of_t,
+ rest,
+ &mut result[(size_of_element * idx)..],
+ )?;
+ }
+ }
+ NumsOrArrays::Nums(_) => return Err(PtxError::ZeroDimensionArray),
+ },
+ }
+ Ok(())
+ }
+
+ fn parse_and_copy_single(
+ t: SizedScalarType,
+ idx: usize,
+ str_val: &str,
+ output: &mut [u8],
+ ) -> Result<(), PtxError> {
+ match t {
+ SizedScalarType::B8 | SizedScalarType::U8 => {
+ Self::parse_and_copy_single_t::<u8>(idx, str_val, output)?;
+ }
+ SizedScalarType::B16 | SizedScalarType::U16 => {
+ Self::parse_and_copy_single_t::<u16>(idx, str_val, output)?;
+ }
+ SizedScalarType::B32 | SizedScalarType::U32 => {
+ Self::parse_and_copy_single_t::<u32>(idx, str_val, output)?;
+ }
+ SizedScalarType::B64 | SizedScalarType::U64 => {
+ Self::parse_and_copy_single_t::<u64>(idx, str_val, output)?;
+ }
+ SizedScalarType::S8 => {
+ Self::parse_and_copy_single_t::<i8>(idx, str_val, output)?;
+ }
+ SizedScalarType::S16 => {
+ Self::parse_and_copy_single_t::<i16>(idx, str_val, output)?;
+ }
+ SizedScalarType::S32 => {
+ Self::parse_and_copy_single_t::<i32>(idx, str_val, output)?;
+ }
+ SizedScalarType::S64 => {
+ Self::parse_and_copy_single_t::<i64>(idx, str_val, output)?;
+ }
+ SizedScalarType::F16 => {
+ Self::parse_and_copy_single_t::<f16>(idx, str_val, output)?;
+ }
+ SizedScalarType::F16x2 => todo!(),
+ SizedScalarType::F32 => {
+ Self::parse_and_copy_single_t::<f32>(idx, str_val, output)?;
+ }
+ SizedScalarType::F64 => {
+ Self::parse_and_copy_single_t::<f64>(idx, str_val, output)?;
+ }
+ }
+ Ok(())
+ }
+
+ fn parse_and_copy_single_t<T: Copy + FromStr>(
+ idx: usize,
+ str_val: &str,
+ output: &mut [u8],
+ ) -> Result<(), PtxError>
+ where
+ T::Err: Into<PtxError>,
+ {
+ let typed_output = unsafe {
+ std::slice::from_raw_parts_mut::<T>(
+ output.as_mut_ptr() as *mut _,
+ output.len() / mem::size_of::<T>(),
+ )
+ };
+ typed_output[idx] = str_val.parse::<T>().map_err(|e| e.into())?;
+ Ok(())
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn array_fails_multiple_0_dmiensions() {
+ let inp = NumsOrArrays::Nums(Vec::new());
+ assert!(inp.to_vec(SizedScalarType::B8, &mut vec![0, 0]).is_err());
+ }
+
+ #[test]
+ fn array_fails_on_empty() {
+ let inp = NumsOrArrays::Nums(Vec::new());
+ assert!(inp.to_vec(SizedScalarType::B8, &mut vec![0]).is_err());
+ }
+
+ #[test]
+ fn array_auto_sizes_0_dimension() {
+ let inp = NumsOrArrays::Arrays(vec![
+ NumsOrArrays::Nums(vec!["1", "2"]),
+ NumsOrArrays::Nums(vec!["3", "4"]),
+ ]);
+ let mut dimensions = vec![0u32, 2];
+ assert_eq!(
+ vec![1u8, 2, 3, 4],
+ inp.to_vec(SizedScalarType::B8, &mut dimensions).unwrap()
+ );
+ assert_eq!(dimensions, vec![2u32, 2]);
+ }
+
+ #[test]
+ fn array_fails_wrong_structure() {
+ let inp = NumsOrArrays::Arrays(vec![
+ NumsOrArrays::Nums(vec!["1", "2"]),
+ NumsOrArrays::Arrays(vec![NumsOrArrays::Nums(vec!["1"])]),
+ ]);
+ let mut dimensions = vec![0u32, 2];
+ assert!(inp.to_vec(SizedScalarType::B8, &mut dimensions).is_err());
+ }
+
+ #[test]
+ fn array_fails_too_long_component() {
+ let inp = NumsOrArrays::Arrays(vec![
+ NumsOrArrays::Nums(vec!["1", "2", "3"]),
+ NumsOrArrays::Nums(vec!["4", "5"]),
+ ]);
+ let mut dimensions = vec![0u32, 2];
+ assert!(inp.to_vec(SizedScalarType::B8, &mut dimensions).is_err());
+ }
+}
diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop
index 2c0e365..0b6fa0f 100644
--- a/ptx/src/ptx.lalrpop
+++ b/ptx/src/ptx.lalrpop
@@ -2,6 +2,8 @@ use crate::ast;
use crate::ast::UnwrapWithVec;
use crate::{without_none, vector_index};
+use lalrpop_util::ParseError;
+
grammar<'a>(errors: &mut Vec<ast::PtxError>);
extern {
@@ -27,6 +29,7 @@ match {
"{", "}",
"<", ">",
"|",
+ "=",
".acquire",
".address_size",
".align",
@@ -94,7 +97,6 @@ match {
".sat",
".section",
".shared",
- ".sreg",
".sys",
".target",
".to",
@@ -176,8 +178,8 @@ ExtendedID : &'input str = {
}
pub Module: ast::Module<'input> = {
- <v:Version> Target <f:Directive*> => {
- ast::Module { version: v, functions: without_none(f) }
+ <v:Version> Target <d:Directive*> => {
+ ast::Module { version: v, directives: without_none(d) }
}
};
@@ -203,11 +205,12 @@ TargetSpecifier = {
"map_f64_to_f32"
};
-Directive: Option<ast::Function<'input, &'input str, ast::Statement<ast::ParsedArgParams<'input>>>> = {
+Directive: Option<ast::Directive<'input, ast::ParsedArgParams<'input>>> = {
AddressSize => None,
- <f:Function> => Some(f),
+ <f:Function> => Some(ast::Directive::Method(f)),
File => None,
- Section => None
+ Section => None,
+ <v:GlobalVariable> ";" => Some(ast::Directive::Variable(v)),
};
AddressSize = {
@@ -242,9 +245,9 @@ FnArguments: Vec<ast::FnArgument<&'input str>> = {
};
KernelInput: ast::Variable<ast::VariableParamType, &'input str> = {
- <v:ParamVariable> => {
+ <v:ParamDeclaration> => {
let (align, v_type, name) = v;
- ast::Variable{ align, v_type, name }
+ ast::Variable{ align, v_type, name, array_init: Vec::new() }
}
}
@@ -252,12 +255,12 @@ FnInput: ast::Variable<ast::FnArgumentType, &'input str> = {
<v:RegVariable> => {
let (align, v_type, name) = v;
let v_type = ast::FnArgumentType::Reg(v_type);
- ast::Variable{ align, v_type, name }
+ ast::Variable{ align, v_type, name, array_init: Vec::new() }
},
- <v:ParamVariable> => {
+ <v:ParamDeclaration> => {
let (align, v_type, name) = v;
let v_type = ast::FnArgumentType::Param(v_type);
- ast::Variable{ align, v_type, name }
+ ast::Variable{ align, v_type, name, array_init: Vec::new() }
}
}
@@ -268,7 +271,6 @@ pub(crate) FunctionBody: Option<Vec<ast::Statement<ast::ParsedArgParams<'input>>
StateSpaceSpecifier: ast::StateSpace = {
".reg" => ast::StateSpace::Reg,
- ".sreg" => ast::StateSpace::Sreg,
".const" => ast::StateSpace::Const,
".global" => ast::StateSpace::Global,
".local" => ast::StateSpace::Local,
@@ -344,13 +346,13 @@ Variable: ast::Variable<ast::VariableType, &'input str> = {
<v:RegVariable> => {
let (align, v_type, name) = v;
let v_type = ast::VariableType::Reg(v_type);
- ast::Variable {align, v_type, name}
+ ast::Variable {align, v_type, name, array_init: Vec::new()}
},
LocalVariable,
<v:ParamVariable> => {
- let (align, v_type, name) = v;
+ let (align, array_init, v_type, name) = v;
let v_type = ast::VariableType::Param(v_type);
- ast::Variable {align, v_type, name}
+ ast::Variable {align, v_type, name, array_init}
},
};
@@ -366,32 +368,60 @@ RegVariable: (Option<u32>, ast::VariableRegType, &'input str) = {
}
LocalVariable: ast::Variable<ast::VariableType, &'input str> = {
- ".local" <align:Align?> <t:SizedScalarType> <name:ExtendedID> => {
- let v_type = ast::VariableType::Local(ast::VariableLocalType::Scalar(t));
- ast::Variable {align, v_type, name}
- },
- ".local" <align:Align?> <v_len:VectorPrefix> <t:SizedScalarType> <name:ExtendedID> => {
- let v_type = ast::VariableType::Local(ast::VariableLocalType::Vector(t, v_len));
- ast::Variable {align, v_type, name}
- },
- ".local" <align:Align?> <t:SizedScalarType> <name:ExtendedID> <arr:ArraySpecifier> => {
- let v_type = ast::VariableType::Local(ast::VariableLocalType::Array(t, arr));
- ast::Variable {align, v_type, name}
+ ".local" <def:LocalVariableDefinition> => {
+ let (align, array_init, v_type, name) = def;
+ ast::Variable { align, v_type: ast::VariableType::Local(v_type), name, array_init }
+ }
+}
+
+GlobalVariable: ast::Variable<ast::VariableType, &'input str> = {
+ ".global" <def:LocalVariableDefinition> => {
+ let (align, array_init, v_type, name) = def;
+ ast::Variable { align, v_type: ast::VariableType::Global(v_type), name, array_init }
}
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameter-state-space
-ParamVariable: (Option<u32>, ast::VariableParamType, &'input str) = {
+ParamVariable: (Option<u32>, Vec<u8>, ast::VariableParamType, &'input str) = {
+ ".param" <align:Align?> <t:ParamScalarType> <name:ExtendedID> => {
+ let v_type = ast::VariableParamType::Scalar(t);
+ (align, Vec::new(), v_type, name)
+ },
+ ".param" <align:Align?> <arr:ArrayDefinition> => {
+ let (array_init, name, (t, dimensions)) = arr;
+ let v_type = ast::VariableParamType::Array(t, dimensions);
+ (align, array_init, v_type, name)
+ }
+}
+
+ParamDeclaration: (Option<u32>, ast::VariableParamType, &'input str) = {
".param" <align:Align?> <t:ParamScalarType> <name:ExtendedID> => {
let v_type = ast::VariableParamType::Scalar(t);
(align, v_type, name)
},
- ".param" <align:Align?> <t:SizedScalarType> <name:ExtendedID> <arr:ArraySpecifier> => {
- let v_type = ast::VariableParamType::Array(t, arr);
+ ".param" <align:Align?> <arr:ArrayDeclaration> => {
+ let (name, (t, dimensions)) = arr;
+ let v_type = ast::VariableParamType::Array(t, dimensions);
(align, v_type, name)
}
}
+LocalVariableDefinition: (Option<u32>, Vec<u8>, ast::VariableLocalType, &'input str) = {
+ <align:Align?> <t:SizedScalarType> <name:ExtendedID> => {
+ let v_type = ast::VariableLocalType::Scalar(t);
+ (align, Vec::new(), v_type, name)
+ },
+ <align:Align?> <v_len:VectorPrefix> <t:SizedScalarType> <name:ExtendedID> => {
+ let v_type = ast::VariableLocalType::Vector(t, v_len);
+ (align, Vec::new(), v_type, name)
+ },
+ <align:Align?> <arr:ArrayDefinition> => {
+ let (array_init, name, (t, dimensions)) = arr;
+ let v_type = ast::VariableLocalType::Array(t, dimensions);
+ (align, array_init, v_type, name)
+ }
+}
+
#[inline]
SizedScalarType: ast::SizedScalarType = {
".b8" => ast::SizedScalarType::B8,
@@ -431,12 +461,59 @@ ParamScalarType: ast::ParamScalarType = {
".f64" => ast::ParamScalarType::F64,
}
-ArraySpecifier: u32 = {
- "[" <n:Num> "]" => {
- let size = n.parse::<u32>();
- size.unwrap_with(errors)
+ArrayDefinition: (Vec<u8>, &'input str, (ast::SizedScalarType, Vec<u32>)) = {
+ <typ:SizedScalarType> <name:ExtendedID> <dims:ArrayDimensions> <init:ArrayInitializer?> =>? {
+ let mut dims = dims;
+ let array_init = init.unwrap_or(ast::NumsOrArrays::Nums(Vec::new())).to_vec(typ, &mut dims)?;
+ Ok((
+ array_init,
+ name,
+ (typ, dims)
+ ))
}
-};
+}
+
+ArrayDeclaration: (&'input str, (ast::SizedScalarType, Vec<u32>)) = {
+ <typ:SizedScalarType> <name:ExtendedID> <dims:ArrayDimension+> =>? {
+ let dims = dims.into_iter().map(|x| if x > 0 { Ok(x) } else { Err(ParseError::User { error: ast::PtxError::ZeroDimensionArray }) }).collect::<Result<_,_>>()?;
+ Ok((name, (typ, dims)))
+ }
+}
+
+// [0] and [] are treated the same
+ArrayDimensions: Vec<u32> = {
+ ArrayEmptyDimension => vec![0u32],
+ ArrayEmptyDimension <dims:ArrayDimension+> => {
+ let mut dims = dims;
+ let mut result = vec![0u32];
+ result.append(&mut dims);
+ result
+ },
+ <dims:ArrayDimension+> => dims
+}
+
+ArrayEmptyDimension = {
+ "[" "]"
+}
+
+ArrayDimension: u32 = {
+ "[" <n:Num> "]" =>? {
+ str::parse::<u32>(n).map_err(|e| ParseError::User { error: ast::PtxError::ParseInt(e) })
+ }
+}
+
+ArrayInitializer: ast::NumsOrArrays<'input> = {
+ "=" <nums:NumsOrArraysBracket> => nums
+}
+
+NumsOrArraysBracket: ast::NumsOrArrays<'input> = {
+ "{" <nums:NumsOrArrays> "}" => nums
+}
+
+NumsOrArrays: ast::NumsOrArrays<'input> = {
+ <n:Comma<NumsOrArraysBracket>> => ast::NumsOrArrays::Arrays(n),
+ <n:CommaNonEmpty<Num>> => ast::NumsOrArrays::Nums(n),
+}
Instruction: ast::Instruction<ast::ParsedArgParams<'input>> = {
InstLd,
@@ -1244,3 +1321,11 @@ Comma<T>: Vec<T> = {
}
}
};
+
+CommaNonEmpty<T>: Vec<T> = {
+ <v:(<T> ",")*> <e:T> => {
+ let mut v = v;
+ v.push(e);
+ v
+ }
+};
diff --git a/ptx/src/test/spirv_run/global_array.ptx b/ptx/src/test/spirv_run/global_array.ptx
new file mode 100644
index 0000000..7ac8bce
--- /dev/null
+++ b/ptx/src/test/spirv_run/global_array.ptx
@@ -0,0 +1,22 @@
+.version 6.5
+.target sm_30
+.address_size 64
+
+.global .s32 foobar[4] = {1};
+
+.visible .entry global_array(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .u32 temp;
+
+ mov.u64 in_addr, foobar;
+ ld.param.u64 out_addr, [output];
+
+ ld.global.u32 temp, [in_addr];
+ st.global.u32 [out_addr], temp;
+ ret;
+} \ No newline at end of file
diff --git a/ptx/src/test/spirv_run/global_array.spvtxt b/ptx/src/test/spirv_run/global_array.spvtxt
new file mode 100644
index 0000000..a4ed91d
--- /dev/null
+++ b/ptx/src/test/spirv_run/global_array.spvtxt
@@ -0,0 +1,54 @@
+ OpCapability GenericPointer
+ OpCapability Linkage
+ OpCapability Addresses
+ OpCapability Kernel
+ OpCapability Int8
+ OpCapability Int16
+ OpCapability Int64
+ OpCapability Float16
+ OpCapability Float64
+ %22 = OpExtInstImport "OpenCL.std"
+ OpMemoryModel Physical64 OpenCL
+ OpEntryPoint Kernel %2 "global_array" %1
+ %void = OpTypeVoid
+ %uint = OpTypeInt 32 0
+ %uint_4 = OpConstant %uint 4
+%_arr_uint_uint_4 = OpTypeArray %uint %uint_4
+%_ptr_CrossWorkgroup__arr_uint_uint_4 = OpTypePointer CrossWorkgroup %_arr_uint_uint_4
+ %uint_4_0 = OpConstant %uint 4
+ %uint_1 = OpConstant %uint 1
+ %uint_0 = OpConstant %uint 0
+ %31 = OpConstantComposite %_arr_uint_uint_4 %uint_1 %uint_0 %uint_0 %uint_0
+ %1 = OpVariable %_ptr_CrossWorkgroup__arr_uint_uint_4 CrossWorkgroup %31
+ %ulong = OpTypeInt 64 0
+ %33 = OpTypeFunction %void %ulong %ulong
+%_ptr_Function_ulong = OpTypePointer Function %ulong
+%_ptr_Function_uint = OpTypePointer Function %uint
+%_ptr_CrossWorkgroup_uint = OpTypePointer CrossWorkgroup %uint
+ %2 = OpFunction %void None %33
+ %8 = OpFunctionParameter %ulong
+ %9 = OpFunctionParameter %ulong
+ %20 = OpLabel
+ %3 = OpVariable %_ptr_Function_ulong Function
+ %4 = OpVariable %_ptr_Function_ulong Function
+ %5 = OpVariable %_ptr_Function_ulong Function
+ %6 = OpVariable %_ptr_Function_ulong Function
+ %7 = OpVariable %_ptr_Function_uint Function
+ OpStore %3 %8
+ OpStore %4 %9
+ %17 = OpConvertPtrToU %ulong %1
+ %10 = OpCopyObject %ulong %17
+ OpStore %5 %10
+ %12 = OpLoad %ulong %4
+ %11 = OpCopyObject %ulong %12
+ OpStore %6 %11
+ %14 = OpLoad %ulong %5
+ %18 = OpConvertUToPtr %_ptr_CrossWorkgroup_uint %14
+ %13 = OpLoad %uint %18
+ OpStore %7 %13
+ %15 = OpLoad %ulong %6
+ %16 = OpLoad %uint %7
+ %19 = OpConvertUToPtr %_ptr_CrossWorkgroup_uint %15
+ OpStore %19 %16
+ OpReturn
+ OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs
index 8caf540..0c881d9 100644
--- a/ptx/src/test/spirv_run/mod.rs
+++ b/ptx/src/test/spirv_run/mod.rs
@@ -66,14 +66,18 @@ test_ptx!(b64tof64, [111u64], [111u64]);
test_ptx!(implicit_param, [34u32], [34u32]);
test_ptx!(pred_not, [10u64, 11u64], [2u64, 0u64]);
test_ptx!(mad_s32, [2i32, 3i32, 4i32], [10i32, 10i32, 10i32]);
-test_ptx!(mul_wide, [0x01_00_00_00__01_00_00_00i64], [0x1_00_00_00_00_00_00i64]);
+test_ptx!(
+ mul_wide,
+ [0x01_00_00_00__01_00_00_00i64],
+ [0x1_00_00_00_00_00_00i64]
+);
test_ptx!(vector_extract, [1u8, 2u8, 3u8, 4u8], [3u8, 4u8, 1u8, 2u8]);
test_ptx!(shr, [-2i32], [-1i32]);
test_ptx!(or, [1u64, 2u64], [3u64]);
test_ptx!(sub, [2u64], [1u64]);
test_ptx!(min, [555i32, 444i32], [444i32]);
test_ptx!(max, [555i32, 444i32], [555i32]);
-
+test_ptx!(global_array, [0xDEADu32], [1u32]);
struct DisplayError<T: Debug> {
err: T,
@@ -131,7 +135,15 @@ fn run_spirv<T: From<u8> + ze::SafeRepr + Copy + Debug>(
let mut devices = drv.devices()?;
let dev = devices.drain(0..1).next().unwrap();
let queue = ze::CommandQueue::new(&mut ctx, &dev)?;
- let module = ze::Module::new_spirv(&mut ctx, &dev, byte_il, None)?;
+ let (module, log) = ze::Module::new_spirv(&mut ctx, &dev, byte_il, None);
+ let module = match module {
+ Ok(m) => m,
+ Err(err) => {
+ let raw_err_string = log.get_cstring()?;
+ let err_string = raw_err_string.to_string_lossy();
+ panic!("{:?}\n{}", err, err_string);
+ }
+ };
let mut kernel = ze::Kernel::new_resident(&module, name)?;
kernel.set_indirect_access(
ze::sys::ze_kernel_indirect_access_flags_t::ZE_KERNEL_INDIRECT_ACCESS_FLAG_DEVICE,
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index 7c15744..a86ab3c 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -1,7 +1,7 @@
use crate::ast;
+use half::f16;
use rspirv::{binary::Disassemble, dr};
use std::collections::{hash_map, HashMap, HashSet};
-use std::convert::TryInto;
use std::{borrow::Cow, iter, mem};
use rspirv::binary::Assemble;
@@ -26,7 +26,7 @@ quick_error! {
enum SpirvType {
Base(SpirvScalarKey),
Vector(SpirvScalarKey, u8),
- Array(SpirvScalarKey, u32),
+ Array(SpirvScalarKey, Vec<u32>),
Pointer(Box<SpirvType>, spirv::StorageClass),
Func(Option<Box<SpirvType>>, Vec<SpirvType>),
Struct(Vec<SpirvScalarKey>),
@@ -62,6 +62,7 @@ impl From<ast::ScalarType> for SpirvType {
struct TypeWordMap {
void: spirv::Word,
complex: HashMap<SpirvType, spirv::Word>,
+ constants: HashMap<(SpirvType, u64), spirv::Word>,
}
// SPIR-V integer type definitions are signless, more below:
@@ -108,6 +109,7 @@ impl TypeWordMap {
TypeWordMap {
void: void,
complex: HashMap::<SpirvType, spirv::Word>::new(),
+ constants: HashMap::new(),
}
}
@@ -154,13 +156,25 @@ impl TypeWordMap {
.entry(t)
.or_insert_with(|| b.type_vector(base, len as u32))
}
- SpirvType::Array(typ, len) => {
- let base = self.get_or_add_spirv_scalar(b, typ);
+ SpirvType::Array(typ, array_dimensions) => {
let u32_type = self.get_or_add_scalar(b, ast::ScalarType::U32);
- *self.complex.entry(t).or_insert_with(|| {
- let len_word = b.constant_u32(u32_type, None, len);
- b.type_array(base, len_word)
- })
+ let (base_type, length) = match &*array_dimensions {
+ &[len] => {
+ let base = self.get_or_add_spirv_scalar(b, typ);
+ let len_const = b.constant_u32(u32_type, None, len);
+ (base, len_const)
+ }
+ array_dimensions => {
+ let base = self
+ .get_or_add(b, SpirvType::Array(typ, array_dimensions[1..].to_vec()));
+ let len_const = b.constant_u32(u32_type, None, array_dimensions[0]);
+ (base, len_const)
+ }
+ };
+ *self
+ .complex
+ .entry(SpirvType::Array(typ, array_dimensions))
+ .or_insert_with(|| b.type_array(base_type, length))
}
SpirvType::Func(ref out_params, ref in_params) => {
let out_t = match out_params {
@@ -211,16 +225,173 @@ impl TypeWordMap {
self.get_or_add(b, SpirvType::Func(out_args, in_params.collect::<Vec<_>>())),
)
}
+
+ fn get_or_add_constant(
+ &mut self,
+ b: &mut dr::Builder,
+ typ: &ast::Type,
+ init: &[u8],
+ ) -> Result<spirv::Word, TranslateError> {
+ Ok(match typ {
+ ast::Type::Scalar(t) => match t {
+ ast::ScalarType::B8 | ast::ScalarType::U8 | ast::ScalarType::S8 => self
+ .get_or_add_constant_single::<u8, _, _>(
+ b,
+ *t,
+ init,
+ |v| v as u64,
+ |b, result_type, v| b.constant_u32(result_type, None, v as u32),
+ ),
+ ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => self
+ .get_or_add_constant_single::<u16, _, _>(
+ b,
+ *t,
+ init,
+ |v| v as u64,
+ |b, result_type, v| b.constant_u32(result_type, None, v as u32),
+ ),
+ ast::ScalarType::B32 | ast::ScalarType::U32 | ast::ScalarType::S32 => self
+ .get_or_add_constant_single::<u32, _, _>(
+ b,
+ *t,
+ init,
+ |v| v as u64,
+ |b, result_type, v| b.constant_u32(result_type, None, v),
+ ),
+ ast::ScalarType::B64 | ast::ScalarType::U64 | ast::ScalarType::S64 => self
+ .get_or_add_constant_single::<u64, _, _>(
+ b,
+ *t,
+ init,
+ |v| v,
+ |b, result_type, v| b.constant_u64(result_type, None, v),
+ ),
+ ast::ScalarType::F16 => self.get_or_add_constant_single::<f16, _, _>(
+ b,
+ *t,
+ init,
+ |v| unsafe { mem::transmute::<_, u16>(v) } as u64,
+ |b, result_type, v| b.constant_f32(result_type, None, v.to_f32()),
+ ),
+ ast::ScalarType::F32 => self.get_or_add_constant_single::<f32, _, _>(
+ b,
+ *t,
+ init,
+ |v| unsafe { mem::transmute::<_, u32>(v) } as u64,
+ |b, result_type, v| b.constant_f32(result_type, None, v),
+ ),
+ ast::ScalarType::F64 => self.get_or_add_constant_single::<f64, _, _>(
+ b,
+ *t,
+ init,
+ |v| unsafe { mem::transmute::<_, u64>(v) },
+ |b, result_type, v| b.constant_f64(result_type, None, v),
+ ),
+ ast::ScalarType::F16x2 => return Err(TranslateError::Todo),
+ ast::ScalarType::Pred => self.get_or_add_constant_single::<u8, _, _>(
+ b,
+ *t,
+ init,
+ |v| v as u64,
+ |b, result_type, v| {
+ if v == 0 {
+ b.constant_false(result_type, None)
+ } else {
+ b.constant_true(result_type, None)
+ }
+ },
+ ),
+ },
+ ast::Type::Vector(typ, len) => {
+ let result_type =
+ self.get_or_add(b, SpirvType::Vector(SpirvScalarKey::from(*typ), *len));
+ let size_of_t = typ.size_of();
+ let components = (0..*len)
+ .map(|x| {
+ self.get_or_add_constant(
+ b,
+ &ast::Type::Scalar(*typ),
+ &init[((size_of_t as usize) * (x as usize))..],
+ )
+ })
+ .collect::<Result<Vec<_>, _>>()?;
+ b.constant_composite(result_type, None, &components)
+ }
+ ast::Type::Array(typ, dims) => match dims.as_slice() {
+ [] => return Err(TranslateError::Unreachable),
+ [dim] => {
+ let result_type = self
+ .get_or_add(b, SpirvType::Array(SpirvScalarKey::from(*typ), vec![*dim]));
+ let size_of_t = typ.size_of();
+ let components = (0..*dim)
+ .map(|x| {
+ self.get_or_add_constant(
+ b,
+ &ast::Type::Scalar(*typ),
+ &init[((size_of_t as usize) * (x as usize))..],
+ )
+ })
+ .collect::<Result<Vec<_>, _>>()?;
+ b.constant_composite(result_type, None, &components)
+ }
+ [first_dim, rest @ ..] => {
+ let result_type = self.get_or_add(
+ b,
+ SpirvType::Array(SpirvScalarKey::from(*typ), rest.to_vec()),
+ );
+ let size_of_t = rest
+ .iter()
+ .fold(typ.size_of() as u32, |x, y| (x as u32) * (*y));
+ let components = (0..*first_dim)
+ .map(|x| {
+ self.get_or_add_constant(
+ b,
+ &ast::Type::Array(*typ, rest.to_vec()),
+ &init[((size_of_t as usize) * (x as usize))..],
+ )
+ })
+ .collect::<Result<Vec<_>, _>>()?;
+ b.constant_composite(result_type, None, &components)
+ }
+ },
+ })
+ }
+
+ fn get_or_add_constant_single<
+ T: Copy,
+ CastAsU64: FnOnce(T) -> u64,
+ InsertConstant: FnOnce(&mut dr::Builder, spirv::Word, T) -> spirv::Word,
+ >(
+ &mut self,
+ b: &mut dr::Builder,
+ key: ast::ScalarType,
+ init: &[u8],
+ cast: CastAsU64,
+ f: InsertConstant,
+ ) -> spirv::Word {
+ let value = unsafe { *(init.as_ptr() as *const T) };
+ let value_64 = cast(value);
+ let ht_key = (SpirvType::Base(SpirvScalarKey::from(key)), value_64);
+ match self.constants.get(&ht_key) {
+ Some(value) => *value,
+ None => {
+ let spirv_type = self.get_or_add_scalar(b, key);
+ let result = f(b, spirv_type, value);
+ self.constants.insert(ht_key, result);
+ result
+ }
+ }
+ }
}
pub fn to_spirv_module<'a>(
ast: ast::Module<'a>,
) -> Result<(dr::Module, HashMap<String, Vec<usize>>), TranslateError> {
let mut id_defs = GlobalStringIdResolver::new(1);
- let ssa_functions = ast
- .functions
+ let directives = ast
+ .directives
.into_iter()
- .map(|f| to_ssa_function(&mut id_defs, f))
+ .map(|f| translate_directive(&mut id_defs, f))
.collect::<Result<Vec<_>, _>>()?;
let mut builder = dr::Builder::new();
builder.reserve_ids(id_defs.current_id());
@@ -233,21 +404,28 @@ pub fn to_spirv_module<'a>(
let mut map = TypeWordMap::new(&mut builder);
emit_builtins(&mut builder, &mut map, &id_defs);
let mut args_len = HashMap::new();
- for f in ssa_functions {
- let f_body = match f.body {
- Some(f) => f,
- None => continue,
- };
- emit_function_body_ops(&mut builder, &mut map, opencl_id, &f.globals)?;
- emit_function_header(
- &mut builder,
- &mut map,
- &id_defs,
- f.func_directive,
- &mut args_len,
- )?;
- emit_function_body_ops(&mut builder, &mut map, opencl_id, &f_body)?;
- builder.end_function()?;
+ for d in directives {
+ match d {
+ Directive::Variable(var) => {
+ emit_variable(&mut builder, &mut map, &var)?;
+ }
+ Directive::Method(f) => {
+ let f_body = match f.body {
+ Some(f) => f,
+ None => continue,
+ };
+ emit_function_body_ops(&mut builder, &mut map, opencl_id, &f.globals)?;
+ emit_function_header(
+ &mut builder,
+ &mut map,
+ &id_defs,
+ f.func_directive,
+ &mut args_len,
+ )?;
+ emit_function_body_ops(&mut builder, &mut map, opencl_id, &f_body)?;
+ builder.end_function()?;
+ }
+ }
}
Ok((builder.module(), args_len))
}
@@ -294,12 +472,18 @@ fn emit_function_header<'a>(
let fn_id = match func_directive {
ast::MethodDecl::Kernel(name, _) => {
let fn_id = global.get_id(name)?;
- let interface = global
+ let mut global_variables = global
+ .variables_type_check
+ .iter()
+ .filter_map(|(k, t)| t.as_ref().map(|_| *k))
+ .collect::<Vec<_>>();
+ let mut interface = global
.special_registers
.iter()
.map(|(_, id)| *id)
.collect::<Vec<_>>();
- builder.entry_point(spirv::ExecutionModel::Kernel, fn_id, name, interface);
+ global_variables.append(&mut interface);
+ builder.entry_point(spirv::ExecutionModel::Kernel, fn_id, name, global_variables);
fn_id
}
ast::MethodDecl::Func(_, name, _) => name,
@@ -311,7 +495,7 @@ fn emit_function_header<'a>(
func_type,
)?;
func_directive.visit_args(&mut |arg| {
- let result_type = map.get_or_add(builder, ast::Type::from(arg.v_type).into());
+ let result_type = map.get_or_add(builder, ast::Type::from(arg.v_type.clone()).into());
let inst = dr::Instruction::new(
spirv::Op::FunctionParameter,
Some(result_type),
@@ -355,7 +539,30 @@ fn emit_memory_model(builder: &mut dr::Builder) {
);
}
-fn to_ssa_function<'a>(
+fn translate_directive<'input>(
+ id_defs: &mut GlobalStringIdResolver<'input>,
+ d: ast::Directive<'input, ast::ParsedArgParams<'input>>,
+) -> Result<Directive<'input>, TranslateError> {
+ Ok(match d {
+ ast::Directive::Variable(v) => Directive::Variable(translate_variable(id_defs, v)?),
+ ast::Directive::Method(f) => Directive::Method(translate_function(id_defs, f)?),
+ })
+}
+
+fn translate_variable<'a>(
+ id_defs: &mut GlobalStringIdResolver<'a>,
+ var: ast::Variable<ast::VariableType, &'a str>,
+) -> Result<ast::Variable<ast::VariableType, spirv::Word>, TranslateError> {
+ let (state_space, typ) = var.v_type.to_type();
+ Ok(ast::Variable {
+ align: var.align,
+ v_type: var.v_type,
+ name: id_defs.get_or_add_def_typed(var.name, (state_space.into(), typ)),
+ array_init: var.array_init,
+ })
+}
+
+fn translate_function<'a>(
id_defs: &mut GlobalStringIdResolver<'a>,
f: ast::ParsedFunction<'a>,
) -> Result<Function<'a>, TranslateError> {
@@ -368,9 +575,13 @@ fn expand_kernel_params<'a, 'b>(
args: impl Iterator<Item = &'b ast::KernelArgument<&'a str>>,
) -> Vec<ast::KernelArgument<spirv::Word>> {
args.map(|a| ast::KernelArgument {
- name: fn_resolver.add_def(a.name, Some((StateSpace::Param, ast::Type::from(a.v_type)))),
- v_type: a.v_type,
+ name: fn_resolver.add_def(
+ a.name,
+ Some((StateSpace::Param, ast::Type::from(a.v_type.clone()))),
+ ),
+ v_type: a.v_type.clone(),
align: a.align,
+ array_init: Vec::new(),
})
.collect()
}
@@ -385,9 +596,10 @@ fn expand_fn_params<'a, 'b>(
ast::FnArgumentType::Param(_) => StateSpace::Param,
};
ast::FnArgument {
- name: fn_resolver.add_def(a.name, Some((ss, ast::Type::from(a.v_type)))),
- v_type: a.v_type,
+ name: fn_resolver.add_def(a.name, Some((ss, ast::Type::from(a.v_type.clone())))),
+ v_type: a.v_type.clone(),
align: a.align,
+ array_init: Vec::new(),
}
})
.collect()
@@ -628,7 +840,7 @@ fn to_resolved_fn_args<T>(
params
.into_iter()
.zip(params_decl.iter())
- .map(|(id, typ)| (id, *typ))
+ .map(|(id, typ)| (id, typ.clone()))
.collect::<Vec<_>>()
}
@@ -719,12 +931,13 @@ fn insert_mem_ssa_statements<'a, 'b>(
let out_param = match &mut f_args {
ast::MethodDecl::Kernel(_, in_params) => {
for p in in_params.iter_mut() {
- let typ = ast::Type::from(p.v_type);
- let new_id = id_def.new_id(typ);
+ let typ = ast::Type::from(p.v_type.clone());
+ let new_id = id_def.new_id(typ.clone());
result.push(Statement::Variable(ast::Variable {
align: p.align,
- v_type: ast::VariableType::Param(p.v_type),
+ v_type: ast::VariableType::Param(p.v_type.clone()),
name: p.name,
+ array_init: p.array_init.clone(),
}));
result.push(Statement::StoreVar(
ast::Arg2St {
@@ -739,20 +952,21 @@ fn insert_mem_ssa_statements<'a, 'b>(
}
ast::MethodDecl::Func(out_params, _, in_params) => {
for p in in_params.iter_mut() {
- let typ = ast::Type::from(p.v_type);
- let new_id = id_def.new_id(typ);
- let var_typ = ast::VariableType::from(p.v_type);
+ let typ = ast::Type::from(p.v_type.clone());
+ let new_id = id_def.new_id(typ.clone());
+ let var_typ = ast::VariableType::from(p.v_type.clone());
result.push(Statement::Variable(ast::Variable {
align: p.align,
v_type: var_typ,
name: p.name,
+ array_init: p.array_init.clone(),
}));
result.push(Statement::StoreVar(
ast::Arg2St {
src1: p.name,
src2: new_id,
},
- typ,
+ typ.clone(),
));
p.name = new_id;
}
@@ -760,8 +974,9 @@ fn insert_mem_ssa_statements<'a, 'b>(
[p] => {
result.push(Statement::Variable(ast::Variable {
align: p.align,
- v_type: ast::VariableType::from(p.v_type),
+ v_type: ast::VariableType::from(p.v_type.clone()),
name: p.name,
+ array_init: p.array_init.clone(),
}));
Some(p.name)
}
@@ -779,13 +994,13 @@ fn insert_mem_ssa_statements<'a, 'b>(
ast::Instruction::Ret(d) => {
if let Some(out_param) = out_param {
let typ = id_def.get_typed(out_param)?;
- let new_id = id_def.new_id(typ);
+ let new_id = id_def.new_id(typ.clone());
result.push(Statement::LoadVar(
ast::Arg2 {
dst: new_id,
src: out_param,
},
- typ,
+ typ.clone(),
));
result.push(Statement::RetValue(d, new_id));
} else {
@@ -824,7 +1039,7 @@ trait VisitVariable: Sized {
'a,
F: FnMut(
ArgumentDescriptor<spirv::Word>,
- Option<ast::Type>,
+ Option<&ast::Type>,
) -> Result<spirv::Word, TranslateError>,
>(
self,
@@ -835,7 +1050,7 @@ trait VisitVariableExpanded {
fn visit_variable_extended<
F: FnMut(
ArgumentDescriptor<spirv::Word>,
- Option<ast::Type>,
+ Option<&ast::Type>,
) -> Result<spirv::Word, TranslateError>,
>(
self,
@@ -861,7 +1076,7 @@ fn insert_mem_ssa_statement_default<'a, F: VisitVariable>(
| (t, ArgumentSemantics::DefaultRelaxed)
| (t, ArgumentSemantics::PhysicalPointer) => t,
};
- let generated_id = id_def.new_id(id_type);
+ let generated_id = id_def.new_id(id_type.clone());
if !desc.is_dst {
result.push(Statement::LoadVar(
Arg2 {
@@ -909,10 +1124,12 @@ fn expand_arguments<'a, 'b>(
align,
v_type,
name,
+ array_init,
}) => result.push(Statement::Variable(ast::Variable {
align,
v_type,
name,
+ array_init,
})),
Statement::Label(id) => result.push(Statement::Label(id)),
Statement::Conditional(bra) => result.push(Statement::Conditional(bra)),
@@ -969,7 +1186,7 @@ impl<'a, 'b> FlattenArguments<'a, 'b> {
fn reg(
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
- _: Option<ast::Type>,
+ _: Option<&ast::Type>,
) -> Result<spirv::Word, TranslateError> {
Ok(desc.op)
}
@@ -977,8 +1194,9 @@ impl<'a, 'b> FlattenArguments<'a, 'b> {
fn reg_offset(
&mut self,
desc: ArgumentDescriptor<(spirv::Word, i32)>,
- mut typ: ast::Type,
+ typ: &ast::Type,
) -> Result<spirv::Word, TranslateError> {
+ let mut typ = typ.clone();
let (reg, offset) = desc.op;
match desc.sema {
ArgumentSemantics::Default
@@ -997,7 +1215,7 @@ impl<'a, 'b> FlattenArguments<'a, 'b> {
ScalarKind::Float2 => return Err(TranslateError::MismatchedType),
ScalarKind::Pred => return Err(TranslateError::MismatchedType),
};
- (scalar_t.width(), kind)
+ (scalar_t.size_of(), kind)
}
_ => return Err(TranslateError::MismatchedType),
};
@@ -1009,7 +1227,7 @@ impl<'a, 'b> FlattenArguments<'a, 'b> {
} else {
ast::ArithDetails::Unsigned(ast::UIntType::from_size(width))
};
- let id_constant_stmt = self.id_def.new_id(typ);
+ let id_constant_stmt = self.id_def.new_id(typ.clone());
let result_id = self.id_def.new_id(typ);
// TODO: check for edge cases around min value/max value/wrapping
if offset < 0 && kind != ScalarKind::Signed {
@@ -1060,10 +1278,10 @@ impl<'a, 'b> FlattenArguments<'a, 'b> {
fn immediate(
&mut self,
desc: ArgumentDescriptor<u32>,
- typ: ast::Type,
+ typ: &ast::Type,
) -> Result<spirv::Word, TranslateError> {
let scalar_t = if let ast::Type::Scalar(scalar) = typ {
- scalar
+ *scalar
} else {
todo!()
};
@@ -1098,14 +1316,14 @@ impl<'a, 'b> FlattenArguments<'a, 'b> {
fn vector(
&mut self,
desc: ArgumentDescriptor<&Vec<spirv::Word>>,
- typ: ast::Type,
+ typ: &ast::Type,
) -> Result<spirv::Word, TranslateError> {
let (scalar_type, vec_len) = typ.get_vector()?;
if !desc.is_dst {
- let mut new_id = self.id_def.new_id(typ);
- self.func.push(Statement::Undef(typ, new_id));
+ let mut new_id = self.id_def.new_id(typ.clone());
+ self.func.push(Statement::Undef(typ.clone(), new_id));
for (idx, id) in desc.op.iter().enumerate() {
- let newer_id = self.id_def.new_id(typ);
+ let newer_id = self.id_def.new_id(typ.clone());
self.func.push(Statement::Instruction(ast::Instruction::Mov(
ast::MovDetails {
typ: ast::Type::Scalar(scalar_type),
@@ -1124,7 +1342,7 @@ impl<'a, 'b> FlattenArguments<'a, 'b> {
}
Ok(new_id)
} else {
- let new_id = self.id_def.new_id(typ);
+ let new_id = self.id_def.new_id(typ.clone());
for (idx, id) in desc.op.iter().enumerate() {
Self::insert_composite_read(
&mut self.post_stmts,
@@ -1144,7 +1362,7 @@ impl<'a, 'b> ArgumentMapVisitor<TypedArgParams, ExpandedArgParams> for FlattenAr
fn id(
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
- t: Option<ast::Type>,
+ t: Option<&ast::Type>,
) -> Result<spirv::Word, TranslateError> {
self.reg(desc, t)
}
@@ -1152,7 +1370,7 @@ impl<'a, 'b> ArgumentMapVisitor<TypedArgParams, ExpandedArgParams> for FlattenAr
fn operand(
&mut self,
desc: ArgumentDescriptor<ast::Operand<spirv::Word>>,
- typ: ast::Type,
+ typ: &ast::Type,
) -> Result<spirv::Word, TranslateError> {
match desc.op {
ast::Operand::Reg(r) => self.reg(desc.new_op(r), Some(typ)),
@@ -1166,7 +1384,7 @@ impl<'a, 'b> ArgumentMapVisitor<TypedArgParams, ExpandedArgParams> for FlattenAr
fn src_call_operand(
&mut self,
desc: ArgumentDescriptor<ast::CallOperand<spirv::Word>>,
- typ: ast::Type,
+ typ: &ast::Type,
) -> Result<spirv::Word, TranslateError> {
match desc.op {
ast::CallOperand::Reg(reg) => self.reg(desc.new_op(reg), Some(typ)),
@@ -1185,7 +1403,7 @@ impl<'a, 'b> ArgumentMapVisitor<TypedArgParams, ExpandedArgParams> for FlattenAr
fn id_or_vector(
&mut self,
desc: ArgumentDescriptor<ast::IdOrVector<spirv::Word>>,
- typ: ast::Type,
+ typ: &ast::Type,
) -> Result<spirv::Word, TranslateError> {
match desc.op {
ast::IdOrVector::Reg(r) => self.reg(desc.new_op(r), Some(typ)),
@@ -1196,7 +1414,7 @@ impl<'a, 'b> ArgumentMapVisitor<TypedArgParams, ExpandedArgParams> for FlattenAr
fn operand_or_vector(
&mut self,
desc: ArgumentDescriptor<ast::OperandOrVector<spirv::Word>>,
- typ: ast::Type,
+ typ: &ast::Type,
) -> Result<spirv::Word, TranslateError> {
match desc.op {
ast::OperandOrVector::Reg(r) => self.reg(desc.new_op(r), Some(typ)),
@@ -1236,8 +1454,8 @@ fn insert_implicit_conversions(
None,
)?,
Statement::Instruction(inst) => {
- let mut default_conversion_fn = should_bitcast_wrapper
- as fn(_, _, _) -> Result<Option<ConversionKind>, TranslateError>;
+ let mut default_conversion_fn =
+ should_bitcast_wrapper as for<'a> fn(&'a ast::Type, &'a ast::Type, _) -> _;
let mut state_space = None;
if let ast::Instruction::Ld(d, _) = &inst {
state_space = Some(d.state_space);
@@ -1281,9 +1499,9 @@ fn insert_implicit_conversions_impl(
func: &mut Vec<ExpandedStatement>,
id_def: &mut MutableNumericIdResolver,
stmt: impl VisitVariableExpanded,
- default_conversion_fn: fn(
- ast::Type,
- ast::Type,
+ default_conversion_fn: for<'a> fn(
+ &'a ast::Type,
+ &'a ast::Type,
Option<ast::LdStateSpace>,
) -> Result<Option<ConversionKind>, TranslateError>,
state_space: Option<ast::LdStateSpace>,
@@ -1315,16 +1533,16 @@ fn insert_implicit_conversions_impl(
conversion_fn = force_bitcast_ptr_to_bit;
}
};
- match conversion_fn(operand_type, instr_type, state_space)? {
+ match conversion_fn(&operand_type, instr_type, state_space)? {
Some(conv_kind) => {
let conv_output = if desc.is_dst {
&mut post_conv
} else {
&mut *func
};
- let mut from = instr_type;
+ let mut from = instr_type.clone();
let mut to = operand_type;
- let mut src = id_def.new_id(instr_type);
+ let mut src = id_def.new_id(instr_type.clone());
let mut dst = desc.op;
let result = Ok(src);
if !desc.is_dst {
@@ -1358,17 +1576,17 @@ fn get_function_type(
builder,
out_params
.iter()
- .map(|p| SpirvType::from(ast::Type::from(p.v_type))),
+ .map(|p| SpirvType::from(ast::Type::from(p.v_type.clone()))),
in_params
.iter()
- .map(|p| SpirvType::from(ast::Type::from(p.v_type))),
+ .map(|p| SpirvType::from(ast::Type::from(p.v_type.clone()))),
),
ast::MethodDecl::Kernel(_, params) => map.get_or_add_fn(
builder,
iter::empty(),
params
.iter()
- .map(|p| SpirvType::from(ast::Type::from(p.v_type))),
+ .map(|p| SpirvType::from(ast::Type::from(p.v_type.clone()))),
),
}
}
@@ -1398,7 +1616,7 @@ fn emit_function_body_ops(
Statement::Call(call) => {
let (result_type, result_id) = match &*call.ret_params {
[(id, typ)] => (
- map.get_or_add(builder, SpirvType::from(ast::Type::from(*typ))),
+ map.get_or_add(builder, SpirvType::from(ast::Type::from(typ.clone()))),
Some(*id),
),
[] => (map.void(), None),
@@ -1411,28 +1629,8 @@ fn emit_function_body_ops(
.collect::<Vec<_>>();
builder.function_call(result_type, result_id, call.func, arg_list)?;
}
- Statement::Variable(ast::Variable {
- align,
- v_type,
- name,
- }) => {
- let st_class = match v_type {
- ast::VariableType::Reg(_)
- | ast::VariableType::Param(_)
- | ast::VariableType::Local(_) => spirv::StorageClass::Function,
- };
- let type_id = map.get_or_add(
- builder,
- SpirvType::new_pointer(ast::Type::from(*v_type), st_class),
- );
- builder.variable(type_id, Some(*name), st_class, None);
- if let Some(align) = align {
- builder.decorate(
- *name,
- spirv::Decoration::Alignment,
- &[dr::Operand::LiteralInt32(*align)],
- );
- }
+ Statement::Variable(var) => {
+ emit_variable(builder, map, var)?;
}
Statement::Constant(cnst) => {
let typ_id = map.get_or_add_scalar(builder, cnst.typ);
@@ -1479,13 +1677,14 @@ fn emit_function_body_ops(
if data.qualifier != ast::LdStQualifier::Weak {
todo!()
}
- let result_type = map.get_or_add(builder, SpirvType::from(data.typ));
+ let result_type = map.get_or_add(builder, SpirvType::from(data.typ.clone()));
match data.state_space {
ast::LdStateSpace::Generic | ast::LdStateSpace::Global => {
builder.load(result_type, Some(arg.dst), arg.src, None, [])?;
}
ast::LdStateSpace::Param | ast::LdStateSpace::Local => {
- let result_type = map.get_or_add(builder, SpirvType::from(data.typ));
+ let result_type =
+ map.get_or_add(builder, SpirvType::from(data.typ.clone()));
builder.copy_object(result_type, Some(arg.dst), arg.src)?;
}
_ => todo!(),
@@ -1498,7 +1697,8 @@ fn emit_function_body_ops(
if data.state_space == ast::StStateSpace::Param
|| data.state_space == ast::StStateSpace::Local
{
- let result_type = map.get_or_add(builder, SpirvType::from(data.typ));
+ let result_type =
+ map.get_or_add(builder, SpirvType::from(data.typ.clone()));
builder.copy_object(result_type, Some(arg.src1), arg.src2)?;
} else if data.state_space == ast::StStateSpace::Generic
|| data.state_space == ast::StStateSpace::Global
@@ -1513,8 +1713,8 @@ fn emit_function_body_ops(
ast::Instruction::Mov(d, arg) => match arg {
ast::Arg2Mov::Normal(ast::Arg2MovNormal { dst, src })
| ast::Arg2Mov::Member(ast::Arg2MovMember::Src(dst, src)) => {
- let result_type =
- map.get_or_add(builder, SpirvType::from(ast::Type::from(d.typ)));
+ let result_type = map
+ .get_or_add(builder, SpirvType::from(ast::Type::from(d.typ.clone())));
builder.copy_object(result_type, Some(*dst), *src)?;
}
ast::Arg2Mov::Member(ast::Arg2MovMember::Dst(
@@ -1645,7 +1845,7 @@ fn emit_function_body_ops(
}
},
Statement::LoadVar(arg, typ) => {
- let type_id = map.get_or_add(builder, SpirvType::from(*typ));
+ let type_id = map.get_or_add(builder, SpirvType::from(typ.clone()));
builder.load(type_id, Some(arg.dst), arg.src, None, [])?;
}
Statement::StoreVar(arg, _) => {
@@ -1665,7 +1865,7 @@ fn emit_function_body_ops(
)?;
}
Statement::Undef(t, id) => {
- let result_type = map.get_or_add(builder, SpirvType::from(*t));
+ let result_type = map.get_or_add(builder, SpirvType::from(t.clone()));
builder.undef(result_type, Some(*id));
}
}
@@ -1673,6 +1873,41 @@ fn emit_function_body_ops(
Ok(())
}
+fn emit_variable(
+ builder: &mut dr::Builder,
+ map: &mut TypeWordMap,
+ var: &ast::Variable<ast::VariableType, spirv::Word>,
+) -> Result<(), TranslateError> {
+ let (should_init, st_class) = match var.v_type {
+ ast::VariableType::Reg(_) | ast::VariableType::Param(_) | ast::VariableType::Local(_) => {
+ (false, spirv::StorageClass::Function)
+ }
+ ast::VariableType::Global(_) => (true, spirv::StorageClass::CrossWorkgroup),
+ };
+ let type_id = map.get_or_add(
+ builder,
+ SpirvType::new_pointer(ast::Type::from(var.v_type.clone()), st_class),
+ );
+ let initalizer = if should_init {
+ Some(map.get_or_add_constant(
+ builder,
+ &ast::Type::from(var.v_type.clone()),
+ &*var.array_init,
+ )?)
+ } else {
+ None
+ };
+ builder.variable(type_id, Some(var.name), st_class, initalizer);
+ if let Some(align) = var.align {
+ builder.decorate(
+ var.name,
+ spirv::Decoration::Alignment,
+ &[dr::Operand::LiteralInt32(align)],
+ );
+ }
+ Ok(())
+}
+
fn emit_mad_uint(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
@@ -1876,7 +2111,7 @@ fn emit_cvt(
dst: new_dst,
from: ast::Type::Scalar(src_t),
to: ast::Type::Scalar(ast::ScalarType::from_parts(
- dest_t.width(),
+ dest_t.size_of(),
src_t.kind(),
)),
kind: ConversionKind::Default,
@@ -2041,7 +2276,7 @@ fn emit_mul_sint(
]);
let mul_ext_type_id = map.get_or_add(builder, mul_ext_type);
let mul = builder.s_mul_extended(mul_ext_type_id, None, arg.src1, arg.src2)?;
- let instr_width = instruction_type.width();
+ let instr_width = instruction_type.size_of();
let instr_kind = instruction_type.kind();
let dst_type = ast::ScalarType::from_parts(instr_width * 2, instr_kind);
let dst_type_id = map.get_or_add_scalar(builder, dst_type);
@@ -2088,7 +2323,7 @@ fn emit_mul_uint(
]);
let mul_ext_type_id = map.get_or_add(builder, mul_ext_type);
let mul = builder.u_mul_extended(mul_ext_type_id, None, arg.src1, arg.src2)?;
- let instr_width = instruction_type.width();
+ let instr_width = instruction_type.size_of();
let instr_kind = instruction_type.kind();
let dst_type = ast::ScalarType::from_parts(instr_width * 2, instr_kind);
let dst_type_id = map.get_or_add_scalar(builder, dst_type);
@@ -2193,13 +2428,13 @@ fn emit_implicit_conversion(
(_, _, ConversionKind::BitToPtr(space)) => {
let dst_type = map.get_or_add(
builder,
- SpirvType::Pointer(Box::new(SpirvType::from(cv.to)), space.to_spirv()),
+ SpirvType::Pointer(Box::new(SpirvType::from(cv.to.clone())), space.to_spirv()),
);
builder.convert_u_to_ptr(dst_type, Some(cv.dst), cv.src)?;
}
(TypeKind::Scalar, TypeKind::Scalar, ConversionKind::Default) => {
if from_parts.width == to_parts.width {
- let dst_type = map.get_or_add(builder, SpirvType::from(cv.to));
+ let dst_type = map.get_or_add(builder, SpirvType::from(cv.to.clone()));
if from_parts.scalar_kind != ScalarKind::Float
&& to_parts.scalar_kind != ScalarKind::Float
{
@@ -2222,7 +2457,8 @@ fn emit_implicit_conversion(
scalar_kind: ScalarKind::Bit,
..to_parts
});
- let wide_bit_type_spirv = map.get_or_add(builder, SpirvType::from(wide_bit_type));
+ let wide_bit_type_spirv =
+ map.get_or_add(builder, SpirvType::from(wide_bit_type.clone()));
if to_parts.scalar_kind == ScalarKind::Unsigned
|| to_parts.scalar_kind == ScalarKind::Bit
{
@@ -2237,7 +2473,7 @@ fn emit_implicit_conversion(
src: wide_bit_value,
dst: cv.dst,
from: wide_bit_type,
- to: cv.to,
+ to: cv.to.clone(),
kind: ConversionKind::Default,
},
)?;
@@ -2248,7 +2484,7 @@ fn emit_implicit_conversion(
(TypeKind::Vector, TypeKind::Scalar, ConversionKind::Default)
| (TypeKind::Scalar, TypeKind::Array, ConversionKind::Default)
| (TypeKind::Array, TypeKind::Scalar, ConversionKind::Default) => {
- let into_type = map.get_or_add(builder, SpirvType::from(cv.to));
+ let into_type = map.get_or_add(builder, SpirvType::from(cv.to.clone()));
builder.bitcast(into_type, Some(cv.dst), cv.src)?;
}
_ => unreachable!(),
@@ -2301,23 +2537,29 @@ fn expand_map_variables<'a, 'b>(
ast::VariableType::Reg(_) => StateSpace::Reg,
ast::VariableType::Local(_) => StateSpace::Local,
ast::VariableType::Param(_) => StateSpace::ParamReg,
+ ast::VariableType::Global(_) => todo!(),
};
match var.count {
Some(count) => {
- for new_id in id_defs.add_defs(var.var.name, count, ss, var.var.v_type.into()) {
+ for new_id in
+ id_defs.add_defs(var.var.name, count, ss, var.var.v_type.clone().into())
+ {
result.push(Statement::Variable(ast::Variable {
align: var.var.align,
- v_type: var.var.v_type,
+ v_type: var.var.v_type.clone(),
name: new_id,
+ array_init: var.var.array_init.clone(),
}))
}
}
None => {
- let new_id = id_defs.add_def(var.var.name, Some((ss, var.var.v_type.into())));
+ let new_id =
+ id_defs.add_def(var.var.name, Some((ss, var.var.v_type.clone().into())));
result.push(Statement::Variable(ast::Variable {
align: var.var.align,
- v_type: var.var.v_type,
+ v_type: var.var.v_type.clone(),
name: new_id,
+ array_init: var.var.array_init,
}));
}
}
@@ -2367,6 +2609,7 @@ impl PtxSpecialRegister {
struct GlobalStringIdResolver<'input> {
current_id: spirv::Word,
variables: HashMap<Cow<'input, str>, spirv::Word>,
+ variables_type_check: HashMap<u32, Option<(StateSpace, ast::Type)>>,
special_registers: HashMap<PtxSpecialRegister, spirv::Word>,
fns: HashMap<spirv::Word, FnDecl>,
}
@@ -2381,13 +2624,26 @@ impl<'a> GlobalStringIdResolver<'a> {
Self {
current_id: start_id,
variables: HashMap::new(),
+ variables_type_check: HashMap::new(),
special_registers: HashMap::new(),
fns: HashMap::new(),
}
}
fn get_or_add_def(&mut self, id: &'a str) -> spirv::Word {
- match self.variables.entry(Cow::Borrowed(id)) {
+ self.get_or_add_impl(id, None)
+ }
+
+ fn get_or_add_def_typed(&mut self, id: &'a str, typ: (StateSpace, ast::Type)) -> spirv::Word {
+ self.get_or_add_impl(id, Some(typ))
+ }
+
+ fn get_or_add_impl(
+ &mut self,
+ id: &'a str,
+ typ: Option<(StateSpace, ast::Type)>,
+ ) -> spirv::Word {
+ let id = match self.variables.entry(Cow::Borrowed(id)) {
hash_map::Entry::Occupied(e) => *(e.get()),
hash_map::Entry::Vacant(e) => {
let numeric_id = self.current_id;
@@ -2395,7 +2651,9 @@ impl<'a> GlobalStringIdResolver<'a> {
self.current_id += 1;
numeric_id
}
- }
+ };
+ self.variables_type_check.insert(id, typ);
+ id
}
fn get_id(&self, id: &str) -> Result<spirv::Word, TranslateError> {
@@ -2422,6 +2680,7 @@ impl<'a> GlobalStringIdResolver<'a> {
let mut fn_resolver = FnStringIdResolver {
current_id: &mut self.current_id,
global_variables: &self.variables,
+ global_type_check: &self.variables_type_check,
special_registers: &mut self.special_registers,
variables: vec![HashMap::new(); 1],
type_check: HashMap::new(),
@@ -2436,8 +2695,8 @@ impl<'a> GlobalStringIdResolver<'a> {
self.fns.insert(
name_id,
FnDecl {
- ret_vals: ret_params_ids.iter().map(|p| p.v_type).collect(),
- params: params_ids.iter().map(|p| p.v_type).collect(),
+ ret_vals: ret_params_ids.iter().map(|p| p.v_type.clone()).collect(),
+ params: params_ids.iter().map(|p| p.v_type.clone()).collect(),
},
);
ast::MethodDecl::Func(ret_params_ids, name_id, params_ids)
@@ -2475,6 +2734,7 @@ impl<'input, 'a> GlobalFnDeclResolver<'input, 'a> {
struct FnStringIdResolver<'input, 'b> {
current_id: &'b mut spirv::Word,
global_variables: &'b HashMap<Cow<'input, str>, spirv::Word>,
+ global_type_check: &'b HashMap<u32, Option<(StateSpace, ast::Type)>>,
special_registers: &'b mut HashMap<PtxSpecialRegister, spirv::Word>,
variables: Vec<HashMap<Cow<'input, str>, spirv::Word>>,
type_check: HashMap<u32, Option<(StateSpace, ast::Type)>>,
@@ -2484,6 +2744,7 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> {
fn finish(self) -> NumericIdResolver<'b> {
NumericIdResolver {
current_id: self.current_id,
+ global_type_check: self.global_type_check,
type_check: self.type_check,
special_registers: self
.special_registers
@@ -2551,7 +2812,8 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> {
.last_mut()
.unwrap()
.insert(Cow::Owned(format!("{}{}", base_id, i)), numeric_id + i);
- self.type_check.insert(numeric_id + i, Some((ss, typ)));
+ self.type_check
+ .insert(numeric_id + i, Some((ss, typ.clone())));
}
*self.current_id += count;
(0..count).into_iter().map(move |i| i + numeric_id)
@@ -2560,6 +2822,7 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> {
struct NumericIdResolver<'b> {
current_id: &'b mut spirv::Word,
+ global_type_check: &'b HashMap<u32, Option<(StateSpace, ast::Type)>>,
type_check: HashMap<u32, Option<(StateSpace, ast::Type)>>,
special_registers: HashMap<spirv::Word, PtxSpecialRegister>,
}
@@ -2571,11 +2834,14 @@ impl<'b> NumericIdResolver<'b> {
fn get_typed(&self, id: spirv::Word) -> Result<(StateSpace, ast::Type), TranslateError> {
match self.type_check.get(&id) {
- Some(Some(x)) => Ok(*x),
+ Some(Some(x)) => Ok(x.clone()),
Some(None) => Err(TranslateError::UntypedSymbol),
None => match self.special_registers.get(&id) {
Some(x) => Ok((StateSpace::Reg, x.get_type())),
- None => Err(TranslateError::UntypedSymbol),
+ None => match self.global_type_check.get(&id) {
+ Some(Some(x)) => Ok(x.clone()),
+ Some(None) | None => Err(TranslateError::UntypedSymbol),
+ },
},
}
}
@@ -2655,7 +2921,7 @@ impl<From: ArgParamsEx<Id = spirv::Word>> ResolvedCall<From> {
is_dst: true,
sema: ArgumentSemantics::Default,
},
- Some(typ.into()),
+ Some(&typ.clone().into()),
)?;
Ok((new_id, typ))
})
@@ -2678,7 +2944,7 @@ impl<From: ArgParamsEx<Id = spirv::Word>> ResolvedCall<From> {
is_dst: false,
sema: ArgumentSemantics::Default,
},
- typ.into(),
+ &typ.clone().into(),
)?;
Ok((new_id, typ))
})
@@ -2697,7 +2963,7 @@ impl VisitVariable for ResolvedCall<TypedArgParams> {
'a,
F: FnMut(
ArgumentDescriptor<spirv::Word>,
- Option<ast::Type>,
+ Option<&ast::Type>,
) -> Result<spirv::Word, TranslateError>,
>(
self,
@@ -2711,7 +2977,7 @@ impl VisitVariableExpanded for ResolvedCall<ExpandedArgParams> {
fn visit_variable_extended<
F: FnMut(
ArgumentDescriptor<spirv::Word>,
- Option<ast::Type>,
+ Option<&ast::Type>,
) -> Result<spirv::Word, TranslateError>,
>(
self,
@@ -2821,6 +3087,24 @@ pub enum StateSpace {
ParamReg,
}
+impl From<ast::StateSpace> for StateSpace {
+ fn from(ss: ast::StateSpace) -> Self {
+ match ss {
+ ast::StateSpace::Reg => StateSpace::Reg,
+ ast::StateSpace::Const => StateSpace::Const,
+ ast::StateSpace::Global => StateSpace::Global,
+ ast::StateSpace::Local => StateSpace::Local,
+ ast::StateSpace::Shared => StateSpace::Shared,
+ ast::StateSpace::Param => StateSpace::Param,
+ }
+ }
+}
+
+enum Directive<'input> {
+ Variable(ast::Variable<ast::VariableType, spirv::Word>),
+ Method(Function<'input>),
+}
+
struct Function<'input> {
pub func_directive: ast::MethodDecl<'input, spirv::Word>,
pub globals: Vec<ExpandedStatement>,
@@ -2831,27 +3115,27 @@ pub trait ArgumentMapVisitor<T: ArgParamsEx, U: ArgParamsEx> {
fn id(
&mut self,
desc: ArgumentDescriptor<T::Id>,
- typ: Option<ast::Type>,
+ typ: Option<&ast::Type>,
) -> Result<U::Id, TranslateError>;
fn operand(
&mut self,
desc: ArgumentDescriptor<T::Operand>,
- typ: ast::Type,
+ typ: &ast::Type,
) -> Result<U::Operand, TranslateError>;
fn id_or_vector(
&mut self,
desc: ArgumentDescriptor<T::IdOrVector>,
- typ: ast::Type,
+ typ: &ast::Type,
) -> Result<U::IdOrVector, TranslateError>;
fn operand_or_vector(
&mut self,
desc: ArgumentDescriptor<T::OperandOrVector>,
- typ: ast::Type,
+ typ: &ast::Type,
) -> Result<U::OperandOrVector, TranslateError>;
fn src_call_operand(
&mut self,
desc: ArgumentDescriptor<T::CallOperand>,
- typ: ast::Type,
+ typ: &ast::Type,
) -> Result<U::CallOperand, TranslateError>;
fn src_member_operand(
&mut self,
@@ -2864,13 +3148,13 @@ impl<T> ArgumentMapVisitor<ExpandedArgParams, ExpandedArgParams> for T
where
T: FnMut(
ArgumentDescriptor<spirv::Word>,
- Option<ast::Type>,
+ Option<&ast::Type>,
) -> Result<spirv::Word, TranslateError>,
{
fn id(
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
- t: Option<ast::Type>,
+ t: Option<&ast::Type>,
) -> Result<spirv::Word, TranslateError> {
self(desc, t)
}
@@ -2878,7 +3162,7 @@ where
fn operand(
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
- t: ast::Type,
+ t: &ast::Type,
) -> Result<spirv::Word, TranslateError> {
self(desc, Some(t))
}
@@ -2886,7 +3170,7 @@ where
fn id_or_vector(
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
- typ: ast::Type,
+ typ: &ast::Type,
) -> Result<spirv::Word, TranslateError> {
self(desc, Some(typ))
}
@@ -2894,7 +3178,7 @@ where
fn operand_or_vector(
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
- typ: ast::Type,
+ typ: &ast::Type,
) -> Result<spirv::Word, TranslateError> {
self(desc, Some(typ))
}
@@ -2902,7 +3186,7 @@ where
fn src_call_operand(
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
- t: ast::Type,
+ t: &ast::Type,
) -> Result<spirv::Word, TranslateError> {
self(desc, Some(t))
}
@@ -2912,7 +3196,7 @@ where
desc: ArgumentDescriptor<spirv::Word>,
(scalar_type, _): (ast::ScalarType, u8),
) -> Result<spirv::Word, TranslateError> {
- self(desc.new_op(desc.op), Some(ast::Type::Scalar(scalar_type)))
+ self(desc.new_op(desc.op), Some(&ast::Type::Scalar(scalar_type)))
}
}
@@ -2923,7 +3207,7 @@ where
fn id(
&mut self,
desc: ArgumentDescriptor<&str>,
- _: Option<ast::Type>,
+ _: Option<&ast::Type>,
) -> Result<spirv::Word, TranslateError> {
self(desc.op)
}
@@ -2931,7 +3215,7 @@ where
fn operand(
&mut self,
desc: ArgumentDescriptor<ast::Operand<&str>>,
- _: ast::Type,
+ _: &ast::Type,
) -> Result<ast::Operand<spirv::Word>, TranslateError> {
match desc.op {
ast::Operand::Reg(id) => Ok(ast::Operand::Reg(self(id)?)),
@@ -2943,7 +3227,7 @@ where
fn id_or_vector(
&mut self,
desc: ArgumentDescriptor<ast::IdOrVector<&'a str>>,
- _: ast::Type,
+ _: &ast::Type,
) -> Result<ast::IdOrVector<spirv::Word>, TranslateError> {
match desc.op {
ast::IdOrVector::Reg(id) => Ok(ast::IdOrVector::Reg(self(id)?)),
@@ -2956,7 +3240,7 @@ where
fn operand_or_vector(
&mut self,
desc: ArgumentDescriptor<ast::OperandOrVector<&'a str>>,
- _: ast::Type,
+ _: &ast::Type,
) -> Result<ast::OperandOrVector<spirv::Word>, TranslateError> {
match desc.op {
ast::OperandOrVector::Reg(id) => Ok(ast::OperandOrVector::Reg(self(id)?)),
@@ -2973,7 +3257,7 @@ where
fn src_call_operand(
&mut self,
desc: ArgumentDescriptor<ast::CallOperand<&str>>,
- _: ast::Type,
+ _: &ast::Type,
) -> Result<ast::CallOperand<spirv::Word>, TranslateError> {
match desc.op {
ast::CallOperand::Reg(id) => Ok(ast::CallOperand::Reg(self(id)?)),
@@ -3027,39 +3311,39 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
) -> Result<ast::Instruction<U>, TranslateError> {
Ok(match self {
ast::Instruction::Abs(d, arg) => {
- ast::Instruction::Abs(d, arg.map(visitor, false, ast::Type::Scalar(d.typ))?)
+ ast::Instruction::Abs(d, arg.map(visitor, false, &ast::Type::Scalar(d.typ))?)
}
// Call instruction is converted to a call statement early on
ast::Instruction::Call(_) => return Err(TranslateError::Unreachable),
ast::Instruction::Ld(d, a) => {
- let inst_type = d.typ;
let is_param = d.state_space == ast::LdStateSpace::Param
|| d.state_space == ast::LdStateSpace::Local;
- ast::Instruction::Ld(d, a.map(visitor, inst_type, is_param)?)
+ let new_args = a.map(visitor, &d.typ, is_param)?;
+ ast::Instruction::Ld(d, new_args)
}
ast::Instruction::Mov(d, a) => {
- let mapped = a.map(visitor, d)?;
+ let mapped = a.map(visitor, &d)?;
ast::Instruction::Mov(d, mapped)
}
ast::Instruction::Mul(d, a) => {
let inst_type = d.get_type();
let is_wide = d.is_wide();
- ast::Instruction::Mul(d, a.map_non_shift(visitor, inst_type, is_wide)?)
+ ast::Instruction::Mul(d, a.map_non_shift(visitor, &inst_type, is_wide)?)
}
ast::Instruction::Add(d, a) => {
let inst_type = d.get_type();
- ast::Instruction::Add(d, a.map_non_shift(visitor, inst_type, false)?)
+ ast::Instruction::Add(d, a.map_non_shift(visitor, &inst_type, false)?)
}
ast::Instruction::Setp(d, a) => {
let inst_type = d.typ;
- ast::Instruction::Setp(d, a.map(visitor, ast::Type::Scalar(inst_type))?)
+ ast::Instruction::Setp(d, a.map(visitor, &ast::Type::Scalar(inst_type))?)
}
ast::Instruction::SetpBool(d, a) => {
let inst_type = d.typ;
- ast::Instruction::SetpBool(d, a.map(visitor, ast::Type::Scalar(inst_type))?)
+ ast::Instruction::SetpBool(d, a.map(visitor, &ast::Type::Scalar(inst_type))?)
}
ast::Instruction::Not(t, a) => {
- ast::Instruction::Not(t, a.map(visitor, false, t.to_type())?)
+ ast::Instruction::Not(t, a.map(visitor, false, &t.to_type())?)
}
ast::Instruction::Cvt(d, a) => {
let (dst_t, src_t) = match &d {
@@ -3080,46 +3364,46 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
ast::Type::Scalar(desc.src.into()),
),
};
- ast::Instruction::Cvt(d, a.map_cvt(visitor, dst_t, src_t)?)
+ ast::Instruction::Cvt(d, a.map_cvt(visitor, &dst_t, &src_t)?)
}
ast::Instruction::Shl(t, a) => {
- ast::Instruction::Shl(t, a.map_shift(visitor, t.to_type())?)
+ ast::Instruction::Shl(t, a.map_shift(visitor, &t.to_type())?)
}
ast::Instruction::Shr(t, a) => {
- ast::Instruction::Shr(t, a.map_shift(visitor, ast::Type::Scalar(t.into()))?)
+ ast::Instruction::Shr(t, a.map_shift(visitor, &ast::Type::Scalar(t.into()))?)
}
ast::Instruction::St(d, a) => {
- let inst_type = d.typ;
let is_param = d.state_space == ast::StStateSpace::Param
|| d.state_space == ast::StStateSpace::Local;
- ast::Instruction::St(d, a.map(visitor, inst_type, is_param)?)
+ let new_args = a.map(visitor, &d.typ, is_param)?;
+ ast::Instruction::St(d, new_args)
}
ast::Instruction::Bra(d, a) => ast::Instruction::Bra(d, a.map(visitor, None)?),
ast::Instruction::Ret(d) => ast::Instruction::Ret(d),
ast::Instruction::Cvta(d, a) => {
let inst_type = ast::Type::Scalar(ast::ScalarType::B64);
- ast::Instruction::Cvta(d, a.map(visitor, false, inst_type)?)
+ ast::Instruction::Cvta(d, a.map(visitor, false, &inst_type)?)
}
ast::Instruction::Mad(d, a) => {
let inst_type = d.get_type();
let is_wide = d.is_wide();
- ast::Instruction::Mad(d, a.map(visitor, inst_type, is_wide)?)
+ ast::Instruction::Mad(d, a.map(visitor, &inst_type, is_wide)?)
}
ast::Instruction::Or(t, a) => ast::Instruction::Or(
t,
- a.map_non_shift(visitor, ast::Type::Scalar(t.into()), false)?,
+ a.map_non_shift(visitor, &ast::Type::Scalar(t.into()), false)?,
),
ast::Instruction::Sub(d, a) => {
let typ = d.get_type();
- ast::Instruction::Sub(d, a.map_non_shift(visitor, typ, false)?)
+ ast::Instruction::Sub(d, a.map_non_shift(visitor, &typ, false)?)
}
ast::Instruction::Min(d, a) => {
let typ = d.get_type();
- ast::Instruction::Min(d, a.map_non_shift(visitor, typ, false)?)
+ ast::Instruction::Min(d, a.map_non_shift(visitor, &typ, false)?)
}
ast::Instruction::Max(d, a) => {
let typ = d.get_type();
- ast::Instruction::Max(d, a.map_non_shift(visitor, typ, false)?)
+ ast::Instruction::Max(d, a.map_non_shift(visitor, &typ, false)?)
}
})
}
@@ -3130,7 +3414,7 @@ impl VisitVariable for ast::Instruction<TypedArgParams> {
'a,
F: FnMut(
ArgumentDescriptor<spirv_headers::Word>,
- Option<ast::Type>,
+ Option<&ast::Type>,
) -> Result<spirv_headers::Word, TranslateError>,
>(
self,
@@ -3144,13 +3428,13 @@ impl<T> ArgumentMapVisitor<TypedArgParams, TypedArgParams> for T
where
T: FnMut(
ArgumentDescriptor<spirv::Word>,
- Option<ast::Type>,
+ Option<&ast::Type>,
) -> Result<spirv::Word, TranslateError>,
{
fn id(
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
- t: Option<ast::Type>,
+ t: Option<&ast::Type>,
) -> Result<spirv::Word, TranslateError> {
self(desc, t)
}
@@ -3158,7 +3442,7 @@ where
fn operand(
&mut self,
desc: ArgumentDescriptor<ast::Operand<spirv::Word>>,
- t: ast::Type,
+ t: &ast::Type,
) -> Result<ast::Operand<spirv::Word>, TranslateError> {
match desc.op {
ast::Operand::Reg(id) => Ok(ast::Operand::Reg(self(desc.new_op(id), Some(t))?)),
@@ -3173,7 +3457,7 @@ where
fn src_call_operand(
&mut self,
desc: ArgumentDescriptor<ast::CallOperand<spirv::Word>>,
- t: ast::Type,
+ t: &ast::Type,
) -> Result<ast::CallOperand<spirv::Word>, TranslateError> {
match desc.op {
ast::CallOperand::Reg(id) => Ok(ast::CallOperand::Reg(self(desc.new_op(id), Some(t))?)),
@@ -3184,7 +3468,7 @@ where
fn id_or_vector(
&mut self,
desc: ArgumentDescriptor<ast::IdOrVector<spirv::Word>>,
- typ: ast::Type,
+ typ: &ast::Type,
) -> Result<ast::IdOrVector<spirv::Word>, TranslateError> {
match desc.op {
ast::IdOrVector::Reg(id) => Ok(ast::IdOrVector::Reg(self(desc.new_op(id), Some(typ))?)),
@@ -3199,7 +3483,7 @@ where
fn operand_or_vector(
&mut self,
desc: ArgumentDescriptor<ast::OperandOrVector<spirv::Word>>,
- typ: ast::Type,
+ typ: &ast::Type,
) -> Result<ast::OperandOrVector<spirv::Word>, TranslateError> {
match desc.op {
ast::OperandOrVector::Reg(id) => {
@@ -3226,7 +3510,7 @@ where
Ok((
self(
desc.new_op(desc.op.0),
- Some(ast::Type::Vector(scalar_type.into(), vector_len)),
+ Some(&ast::Type::Vector(scalar_type.into(), vector_len)),
)?,
desc.op.1,
))
@@ -3238,7 +3522,7 @@ impl ast::Type {
match self {
ast::Type::Scalar(scalar) => {
let kind = scalar.kind();
- let width = scalar.width();
+ let width = scalar.size_of();
if (kind != ScalarKind::Signed
&& kind != ScalarKind::Unsigned
&& kind != ScalarKind::Bit)
@@ -3255,25 +3539,25 @@ impl ast::Type {
}
}
- fn to_parts(self) -> TypeParts {
+ fn to_parts(&self) -> TypeParts {
match self {
ast::Type::Scalar(scalar) => TypeParts {
kind: TypeKind::Scalar,
scalar_kind: scalar.kind(),
- width: scalar.width(),
- components: 0,
+ width: scalar.size_of(),
+ components: Vec::new(),
},
ast::Type::Vector(scalar, components) => TypeParts {
kind: TypeKind::Vector,
scalar_kind: scalar.kind(),
- width: scalar.width(),
- components: components as u32,
+ width: scalar.size_of(),
+ components: vec![*components as u32],
},
ast::Type::Array(scalar, components) => TypeParts {
kind: TypeKind::Array,
scalar_kind: scalar.kind(),
- width: scalar.width(),
- components: components,
+ width: scalar.size_of(),
+ components: components.clone(),
},
}
}
@@ -3285,7 +3569,7 @@ impl ast::Type {
}
TypeKind::Vector => ast::Type::Vector(
ast::ScalarType::from_parts(t.width, t.scalar_kind),
- t.components as u8,
+ t.components[0] as u8,
),
TypeKind::Array => ast::Type::Array(
ast::ScalarType::from_parts(t.width, t.scalar_kind),
@@ -3295,12 +3579,12 @@ impl ast::Type {
}
}
-#[derive(Eq, PartialEq, Copy, Clone)]
+#[derive(Eq, PartialEq, Clone)]
struct TypeParts {
kind: TypeKind,
scalar_kind: ScalarKind,
width: u8,
- components: u32,
+ components: Vec<u32>,
}
#[derive(Eq, PartialEq, Copy, Clone)]
@@ -3342,7 +3626,7 @@ impl VisitVariableExpanded for ast::Instruction<ExpandedArgParams> {
fn visit_variable_extended<
F: FnMut(
ArgumentDescriptor<spirv_headers::Word>,
- Option<ast::Type>,
+ Option<&ast::Type>,
) -> Result<spirv_headers::Word, TranslateError>,
>(
self,
@@ -3368,7 +3652,7 @@ impl VisitVariableExpanded for CompositeRead {
fn visit_variable_extended<
F: FnMut(
ArgumentDescriptor<spirv_headers::Word>,
- Option<ast::Type>,
+ Option<&ast::Type>,
) -> Result<spirv_headers::Word, TranslateError>,
>(
self,
@@ -3384,7 +3668,7 @@ impl VisitVariableExpanded for CompositeRead {
is_dst: true,
sema: dst_sema,
},
- Some(ast::Type::Scalar(self.typ)),
+ Some(&ast::Type::Scalar(self.typ)),
)?,
src_composite: f(
ArgumentDescriptor {
@@ -3392,7 +3676,7 @@ impl VisitVariableExpanded for CompositeRead {
is_dst: false,
sema: ArgumentSemantics::Default,
},
- Some(ast::Type::Vector(self.typ, self.src_len as u8)),
+ Some(&ast::Type::Vector(self.typ, self.src_len as u8)),
)?,
..self
}))
@@ -3411,7 +3695,7 @@ struct BrachCondition {
if_false: spirv::Word,
}
-#[derive(Copy, Clone)]
+#[derive(Clone)]
struct ImplicitConversion {
src: spirv::Word,
dst: spirv::Word,
@@ -3471,11 +3755,12 @@ impl<'a> ast::Instruction<ast::ParsedArgParams<'a>> {
}
impl ast::VariableParamType {
- fn width(self) -> usize {
+ fn width(&self) -> usize {
match self {
- ast::VariableParamType::Scalar(t) => ast::ScalarType::from(t).width() as usize,
+ ast::VariableParamType::Scalar(t) => ast::ScalarType::from(*t).size_of() as usize,
ast::VariableParamType::Array(t, len) => {
- (ast::ScalarType::from(t).width() as usize) * (len as usize)
+ (ast::ScalarType::from(*t).size_of() as usize)
+ * (len.iter().fold(1, |x, y| x * (*y)) as usize)
}
}
}
@@ -3489,7 +3774,7 @@ impl<T: ArgParamsEx> ast::Arg1<T> {
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
- t: Option<ast::Type>,
+ t: Option<&ast::Type>,
) -> Result<ast::Arg1<U>, TranslateError> {
let new_src = visitor.id(
ArgumentDescriptor {
@@ -3515,7 +3800,7 @@ impl<T: ArgParamsEx> ast::Arg2<T> {
self,
visitor: &mut V,
src_is_addr: bool,
- t: ast::Type,
+ t: &ast::Type,
) -> Result<ast::Arg2<U>, TranslateError> {
let new_dst = visitor.id(
ArgumentDescriptor {
@@ -3546,8 +3831,8 @@ impl<T: ArgParamsEx> ast::Arg2<T> {
fn map_cvt<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
- dst_t: ast::Type,
- src_t: ast::Type,
+ dst_t: &ast::Type,
+ src_t: &ast::Type,
) -> Result<ast::Arg2<U>, TranslateError> {
let dst = visitor.id(
ArgumentDescriptor {
@@ -3582,7 +3867,7 @@ impl<T: ArgParamsEx> ast::Arg2Ld<T> {
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
- t: ast::Type,
+ t: &ast::Type,
is_param: bool,
) -> Result<ast::Arg2Ld<U>, TranslateError> {
let dst = visitor.id_or_vector(
@@ -3591,7 +3876,7 @@ impl<T: ArgParamsEx> ast::Arg2Ld<T> {
is_dst: true,
sema: ArgumentSemantics::DefaultRelaxed,
},
- t.into(),
+ &ast::Type::from(t.clone()),
)?;
let src = visitor.operand(
ArgumentDescriptor {
@@ -3622,7 +3907,7 @@ impl<T: ArgParamsEx> ast::Arg2St<T> {
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
- t: ast::Type,
+ t: &ast::Type,
is_param: bool,
) -> Result<ast::Arg2St<U>, TranslateError> {
let src1 = visitor.operand(
@@ -3653,7 +3938,7 @@ impl<T: ArgParamsEx> ast::Arg2Mov<T> {
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
- details: ast::MovDetails,
+ details: &ast::MovDetails,
) -> Result<ast::Arg2Mov<U>, TranslateError> {
Ok(match self {
ast::Arg2Mov::Normal(arg) => ast::Arg2Mov::Normal(arg.map(visitor, details)?),
@@ -3675,7 +3960,7 @@ impl<P: ArgParamsEx> ast::Arg2MovNormal<P> {
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<P, U>>(
self,
visitor: &mut V,
- details: ast::MovDetails,
+ details: &ast::MovDetails,
) -> Result<ast::Arg2MovNormal<U>, TranslateError> {
let dst = visitor.id_or_vector(
ArgumentDescriptor {
@@ -3683,7 +3968,7 @@ impl<P: ArgParamsEx> ast::Arg2MovNormal<P> {
is_dst: true,
sema: ArgumentSemantics::Default,
},
- details.typ.into(),
+ &details.typ.clone().into(),
)?;
let src = visitor.operand_or_vector(
ArgumentDescriptor {
@@ -3695,7 +3980,7 @@ impl<P: ArgParamsEx> ast::Arg2MovNormal<P> {
ArgumentSemantics::Default
},
},
- details.typ.into(),
+ &details.typ.clone().into(),
)?;
Ok(ast::Arg2MovNormal { dst, src })
}
@@ -3733,7 +4018,7 @@ impl<T: ArgParamsEx> ast::Arg2MovMember<T> {
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
- details: ast::MovDetails,
+ details: &ast::MovDetails,
) -> Result<ast::Arg2MovMember<U>, TranslateError> {
match self {
ast::Arg2MovMember::Dst((dst, len), composite_src, scalar_src) => {
@@ -3744,7 +4029,7 @@ impl<T: ArgParamsEx> ast::Arg2MovMember<T> {
is_dst: true,
sema: ArgumentSemantics::Default,
},
- Some(ast::Type::Vector(scalar_type, details.dst_width)),
+ Some(&ast::Type::Vector(scalar_type, details.dst_width)),
)?;
let src1 = visitor.id(
ArgumentDescriptor {
@@ -3752,7 +4037,7 @@ impl<T: ArgParamsEx> ast::Arg2MovMember<T> {
is_dst: false,
sema: ArgumentSemantics::Default,
},
- Some(ast::Type::Vector(scalar_type, details.dst_width)),
+ Some(&ast::Type::Vector(scalar_type, details.dst_width)),
)?;
let src2 = visitor.id(
ArgumentDescriptor {
@@ -3766,7 +4051,7 @@ impl<T: ArgParamsEx> ast::Arg2MovMember<T> {
ArgumentSemantics::Default
},
},
- Some(details.typ.into()),
+ Some(&details.typ.clone().into()),
)?;
Ok(ast::Arg2MovMember::Dst((dst, len), src1, src2))
}
@@ -3777,7 +4062,7 @@ impl<T: ArgParamsEx> ast::Arg2MovMember<T> {
is_dst: true,
sema: ArgumentSemantics::Default,
},
- Some(details.typ.into()),
+ Some(&details.typ.clone().into()),
)?;
let scalar_typ = details.typ.get_scalar()?;
let src = visitor.src_member_operand(
@@ -3798,7 +4083,7 @@ impl<T: ArgParamsEx> ast::Arg2MovMember<T> {
is_dst: true,
sema: ArgumentSemantics::Default,
},
- Some(ast::Type::Vector(scalar_type, details.dst_width)),
+ Some(&ast::Type::Vector(scalar_type, details.dst_width)),
)?;
let composite_src = visitor.id(
ArgumentDescriptor {
@@ -3806,7 +4091,7 @@ impl<T: ArgParamsEx> ast::Arg2MovMember<T> {
is_dst: false,
sema: ArgumentSemantics::Default,
},
- Some(ast::Type::Vector(scalar_type, details.dst_width)),
+ Some(&ast::Type::Vector(scalar_type, details.dst_width)),
)?;
let src = visitor.src_member_operand(
ArgumentDescriptor {
@@ -3838,16 +4123,21 @@ impl<T: ArgParamsEx> ast::Arg3<T> {
fn map_non_shift<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
- typ: ast::Type,
+ typ: &ast::Type,
is_wide: bool,
) -> Result<ast::Arg3<U>, TranslateError> {
+ let wide_type = if is_wide {
+ Some(typ.clone().widen()?)
+ } else {
+ None
+ };
let dst = visitor.id(
ArgumentDescriptor {
op: self.dst,
is_dst: true,
sema: ArgumentSemantics::Default,
},
- Some(if is_wide { typ.widen()? } else { typ }),
+ Some(wide_type.as_ref().unwrap_or(typ)),
)?;
let src1 = visitor.operand(
ArgumentDescriptor {
@@ -3871,7 +4161,7 @@ impl<T: ArgParamsEx> ast::Arg3<T> {
fn map_shift<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
- t: ast::Type,
+ t: &ast::Type,
) -> Result<ast::Arg3<U>, TranslateError> {
let dst = visitor.id(
ArgumentDescriptor {
@@ -3895,7 +4185,7 @@ impl<T: ArgParamsEx> ast::Arg3<T> {
is_dst: false,
sema: ArgumentSemantics::Default,
},
- ast::Type::Scalar(ast::ScalarType::U32),
+ &ast::Type::Scalar(ast::ScalarType::U32),
)?;
Ok(ast::Arg3 { dst, src1, src2 })
}
@@ -3914,16 +4204,21 @@ impl<T: ArgParamsEx> ast::Arg4<T> {
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
- t: ast::Type,
+ t: &ast::Type,
is_wide: bool,
) -> Result<ast::Arg4<U>, TranslateError> {
+ let wide_type = if is_wide {
+ Some(t.clone().widen()?)
+ } else {
+ None
+ };
let dst = visitor.id(
ArgumentDescriptor {
op: self.dst,
is_dst: true,
sema: ArgumentSemantics::Default,
},
- Some(if is_wide { t.widen()? } else { t }),
+ Some(wide_type.as_ref().unwrap_or(t)),
)?;
let src1 = visitor.operand(
ArgumentDescriptor {
@@ -3971,7 +4266,7 @@ impl<T: ArgParamsEx> ast::Arg4Setp<T> {
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
- t: ast::Type,
+ t: &ast::Type,
) -> Result<ast::Arg4Setp<U>, TranslateError> {
let dst1 = visitor.id(
ArgumentDescriptor {
@@ -3979,7 +4274,7 @@ impl<T: ArgParamsEx> ast::Arg4Setp<T> {
is_dst: true,
sema: ArgumentSemantics::Default,
},
- Some(ast::Type::Scalar(ast::ScalarType::Pred)),
+ Some(&ast::Type::Scalar(ast::ScalarType::Pred)),
)?;
let dst2 = self
.dst2
@@ -3990,7 +4285,7 @@ impl<T: ArgParamsEx> ast::Arg4Setp<T> {
is_dst: true,
sema: ArgumentSemantics::Default,
},
- Some(ast::Type::Scalar(ast::ScalarType::Pred)),
+ Some(&ast::Type::Scalar(ast::ScalarType::Pred)),
)
})
.transpose()?;
@@ -4033,7 +4328,7 @@ impl<T: ArgParamsEx> ast::Arg5<T> {
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
- t: ast::Type,
+ t: &ast::Type,
) -> Result<ast::Arg5<U>, TranslateError> {
let dst1 = visitor.id(
ArgumentDescriptor {
@@ -4041,7 +4336,7 @@ impl<T: ArgParamsEx> ast::Arg5<T> {
is_dst: true,
sema: ArgumentSemantics::Default,
},
- Some(ast::Type::Scalar(ast::ScalarType::Pred)),
+ Some(&ast::Type::Scalar(ast::ScalarType::Pred)),
)?;
let dst2 = self
.dst2
@@ -4052,7 +4347,7 @@ impl<T: ArgParamsEx> ast::Arg5<T> {
is_dst: true,
sema: ArgumentSemantics::Default,
},
- Some(ast::Type::Scalar(ast::ScalarType::Pred)),
+ Some(&ast::Type::Scalar(ast::ScalarType::Pred)),
)
})
.transpose()?;
@@ -4078,7 +4373,7 @@ impl<T: ArgParamsEx> ast::Arg5<T> {
is_dst: false,
sema: ArgumentSemantics::Default,
},
- ast::Type::Scalar(ast::ScalarType::Pred),
+ &ast::Type::Scalar(ast::ScalarType::Pred),
)?;
Ok(ast::Arg5 {
dst1,
@@ -4091,16 +4386,16 @@ impl<T: ArgParamsEx> ast::Arg5<T> {
}
impl ast::Type {
- fn get_vector(self) -> Result<(ast::ScalarType, u8), TranslateError> {
+ fn get_vector(&self) -> Result<(ast::ScalarType, u8), TranslateError> {
match self {
- ast::Type::Vector(t, len) => Ok((t, len)),
+ ast::Type::Vector(t, len) => Ok((*t, *len)),
_ => Err(TranslateError::MismatchedType),
}
}
- fn get_scalar(self) -> Result<ast::ScalarType, TranslateError> {
+ fn get_scalar(&self) -> Result<ast::ScalarType, TranslateError> {
match self {
- ast::Type::Scalar(t) => Ok(t),
+ ast::Type::Scalar(t) => Ok(*t),
_ => Err(TranslateError::MismatchedType),
}
}
@@ -4141,28 +4436,6 @@ enum ScalarKind {
}
impl ast::ScalarType {
- fn width(self) -> u8 {
- match self {
- ast::ScalarType::U8 => 1,
- ast::ScalarType::S8 => 1,
- ast::ScalarType::B8 => 1,
- ast::ScalarType::U16 => 2,
- ast::ScalarType::S16 => 2,
- ast::ScalarType::B16 => 2,
- ast::ScalarType::F16 => 2,
- ast::ScalarType::U32 => 4,
- ast::ScalarType::S32 => 4,
- ast::ScalarType::B32 => 4,
- ast::ScalarType::F32 => 4,
- ast::ScalarType::U64 => 8,
- ast::ScalarType::S64 => 8,
- ast::ScalarType::B64 => 8,
- ast::ScalarType::F64 => 8,
- ast::ScalarType::F16x2 => 4,
- ast::ScalarType::Pred => 1,
- }
- }
-
fn kind(self) -> ScalarKind {
match self {
ast::ScalarType::U8 => ScalarKind::Unsigned,
@@ -4283,20 +4556,6 @@ impl ast::MinMaxDetails {
}
}
-impl ast::IntType {
- fn try_new(t: ast::ScalarType) -> Option<Self> {
- match t {
- ast::ScalarType::U16 => Some(ast::IntType::U16),
- ast::ScalarType::U32 => Some(ast::IntType::U32),
- ast::ScalarType::U64 => Some(ast::IntType::U64),
- ast::ScalarType::S16 => Some(ast::IntType::S16),
- ast::ScalarType::S32 => Some(ast::IntType::S32),
- ast::ScalarType::S64 => Some(ast::IntType::S64),
- _ => None,
- }
- }
-}
-
impl ast::SIntType {
fn from_size(width: u8) -> Self {
match width {
@@ -4372,8 +4631,8 @@ impl ast::MulDetails {
}
fn force_bitcast(
- operand: ast::Type,
- instr: ast::Type,
+ operand: &ast::Type,
+ instr: &ast::Type,
_: Option<ast::LdStateSpace>,
) -> Result<Option<ConversionKind>, TranslateError> {
if instr != operand {
@@ -4384,8 +4643,8 @@ fn force_bitcast(
}
fn bitcast_physical_pointer(
- operand_type: ast::Type,
- _: ast::Type,
+ operand_type: &ast::Type,
+ _: &ast::Type,
ss: Option<ast::LdStateSpace>,
) -> Result<Option<ConversionKind>, TranslateError> {
match operand_type {
@@ -4403,17 +4662,17 @@ fn bitcast_physical_pointer(
}
fn force_bitcast_ptr_to_bit(
- _: ast::Type,
- _: ast::Type,
+ _: &ast::Type,
+ _: &ast::Type,
_: Option<ast::LdStateSpace>,
) -> Result<Option<ConversionKind>, TranslateError> {
Ok(Some(ConversionKind::PtrToBit))
}
-fn should_bitcast(instr: ast::Type, operand: ast::Type) -> bool {
+fn should_bitcast(instr: &ast::Type, operand: &ast::Type) -> bool {
match (instr, operand) {
(ast::Type::Scalar(inst), ast::Type::Scalar(operand)) => {
- if inst.width() != operand.width() {
+ if inst.size_of() != operand.size_of() {
return false;
}
match inst.kind() {
@@ -4431,22 +4690,22 @@ fn should_bitcast(instr: ast::Type, operand: ast::Type) -> bool {
}
(ast::Type::Vector(inst, _), ast::Type::Vector(operand, _))
| (ast::Type::Array(inst, _), ast::Type::Array(operand, _)) => {
- should_bitcast(ast::Type::Scalar(inst), ast::Type::Scalar(operand))
+ should_bitcast(&ast::Type::Scalar(*inst), &ast::Type::Scalar(*operand))
}
_ => false,
}
}
fn should_bitcast_packed(
- operand: ast::Type,
- instr: ast::Type,
+ operand: &ast::Type,
+ instr: &ast::Type,
ss: Option<ast::LdStateSpace>,
) -> Result<Option<ConversionKind>, TranslateError> {
if let (ast::Type::Vector(vec_underlying_type, vec_len), ast::Type::Scalar(scalar)) =
(operand, instr)
{
if scalar.kind() == ScalarKind::Bit
- && scalar.width() == (vec_underlying_type.width() * vec_len)
+ && scalar.size_of() == (vec_underlying_type.size_of() * vec_len)
{
return Ok(Some(ConversionKind::Default));
}
@@ -4455,8 +4714,8 @@ fn should_bitcast_packed(
}
fn should_bitcast_wrapper(
- operand: ast::Type,
- instr: ast::Type,
+ operand: &ast::Type,
+ instr: &ast::Type,
_: Option<ast::LdStateSpace>,
) -> Result<Option<ConversionKind>, TranslateError> {
if instr == operand {
@@ -4470,8 +4729,8 @@ fn should_bitcast_wrapper(
}
fn should_convert_relaxed_src_wrapper(
- src_type: ast::Type,
- instr_type: ast::Type,
+ src_type: &ast::Type,
+ instr_type: &ast::Type,
_: Option<ast::LdStateSpace>,
) -> Result<Option<ConversionKind>, TranslateError> {
if src_type == instr_type {
@@ -4485,8 +4744,8 @@ fn should_convert_relaxed_src_wrapper(
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-source-operands
fn should_convert_relaxed_src(
- src_type: ast::Type,
- instr_type: ast::Type,
+ src_type: &ast::Type,
+ instr_type: &ast::Type,
) -> Option<ConversionKind> {
if src_type == instr_type {
return None;
@@ -4494,21 +4753,24 @@ fn should_convert_relaxed_src(
match (src_type, instr_type) {
(ast::Type::Scalar(src_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() {
ScalarKind::Bit => {
- if instr_type.width() <= src_type.width() {
+ if instr_type.size_of() <= src_type.size_of() {
Some(ConversionKind::Default)
} else {
None
}
}
ScalarKind::Signed | ScalarKind::Unsigned => {
- if instr_type.width() <= src_type.width() && src_type.kind() != ScalarKind::Float {
+ if instr_type.size_of() <= src_type.size_of()
+ && src_type.kind() != ScalarKind::Float
+ {
Some(ConversionKind::Default)
} else {
None
}
}
ScalarKind::Float => {
- if instr_type.width() <= src_type.width() && src_type.kind() == ScalarKind::Bit {
+ if instr_type.size_of() <= src_type.size_of() && src_type.kind() == ScalarKind::Bit
+ {
Some(ConversionKind::Default)
} else {
None
@@ -4519,15 +4781,18 @@ fn should_convert_relaxed_src(
},
(ast::Type::Vector(dst_type, _), ast::Type::Vector(instr_type, _))
| (ast::Type::Array(dst_type, _), ast::Type::Array(instr_type, _)) => {
- should_convert_relaxed_src(ast::Type::Scalar(dst_type), ast::Type::Scalar(instr_type))
+ should_convert_relaxed_src(
+ &ast::Type::Scalar(*dst_type),
+ &ast::Type::Scalar(*instr_type),
+ )
}
_ => None,
}
}
fn should_convert_relaxed_dst_wrapper(
- dst_type: ast::Type,
- instr_type: ast::Type,
+ dst_type: &ast::Type,
+ instr_type: &ast::Type,
_: Option<ast::LdStateSpace>,
) -> Result<Option<ConversionKind>, TranslateError> {
if dst_type == instr_type {
@@ -4541,8 +4806,8 @@ fn should_convert_relaxed_dst_wrapper(
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-destination-operands
fn should_convert_relaxed_dst(
- dst_type: ast::Type,
- instr_type: ast::Type,
+ dst_type: &ast::Type,
+ instr_type: &ast::Type,
) -> Option<ConversionKind> {
if dst_type == instr_type {
return None;
@@ -4550,7 +4815,7 @@ fn should_convert_relaxed_dst(
match (dst_type, instr_type) {
(ast::Type::Scalar(dst_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() {
ScalarKind::Bit => {
- if instr_type.width() <= dst_type.width() {
+ if instr_type.size_of() <= dst_type.size_of() {
Some(ConversionKind::Default)
} else {
None
@@ -4558,9 +4823,9 @@ fn should_convert_relaxed_dst(
}
ScalarKind::Signed => {
if dst_type.kind() != ScalarKind::Float {
- if instr_type.width() == dst_type.width() {
+ if instr_type.size_of() == dst_type.size_of() {
Some(ConversionKind::Default)
- } else if instr_type.width() < dst_type.width() {
+ } else if instr_type.size_of() < dst_type.size_of() {
Some(ConversionKind::SignExtend)
} else {
None
@@ -4570,14 +4835,17 @@ fn should_convert_relaxed_dst(
}
}
ScalarKind::Unsigned => {
- if instr_type.width() <= dst_type.width() && dst_type.kind() != ScalarKind::Float {
+ if instr_type.size_of() <= dst_type.size_of()
+ && dst_type.kind() != ScalarKind::Float
+ {
Some(ConversionKind::Default)
} else {
None
}
}
ScalarKind::Float => {
- if instr_type.width() <= dst_type.width() && dst_type.kind() == ScalarKind::Bit {
+ if instr_type.size_of() <= dst_type.size_of() && dst_type.kind() == ScalarKind::Bit
+ {
Some(ConversionKind::Default)
} else {
None
@@ -4588,7 +4856,10 @@ fn should_convert_relaxed_dst(
},
(ast::Type::Vector(dst_type, _), ast::Type::Vector(instr_type, _))
| (ast::Type::Array(dst_type, _), ast::Type::Array(instr_type, _)) => {
- should_convert_relaxed_dst(ast::Type::Scalar(dst_type), ast::Type::Scalar(instr_type))
+ should_convert_relaxed_dst(
+ &ast::Type::Scalar(*dst_type),
+ &ast::Type::Scalar(*instr_type),
+ )
}
_ => None,
}
@@ -4611,7 +4882,8 @@ impl<'a> ast::MethodDecl<'a, spirv::Word> {
f(&ast::FnArgument {
align: arg.align,
name: arg.name,
- v_type: ast::FnArgumentType::Param(arg.v_type),
+ v_type: ast::FnArgumentType::Param(arg.v_type.clone()),
+ array_init: arg.array_init.clone(),
})
}),
}
@@ -4698,14 +4970,17 @@ mod tests {
.collect::<Vec<_>>()
}
- fn assert_conversion_table<F: Fn(ast::Type, ast::Type) -> Option<ConversionKind>>(
+ fn assert_conversion_table<F: Fn(&ast::Type, &ast::Type) -> Option<ConversionKind>>(
table: &'static str,
f: F,
) {
let conv_table = parse_conversion_table(table);
for (instr_idx, instr_type) in SCALAR_TYPES.iter().enumerate() {
for (op_idx, op_type) in SCALAR_TYPES.iter().enumerate() {
- let conversion = f(ast::Type::Scalar(*op_type), ast::Type::Scalar(*instr_type));
+ let conversion = f(
+ &ast::Type::Scalar(*op_type),
+ &ast::Type::Scalar(*instr_type),
+ );
if instr_idx == op_idx {
assert_eq!(conversion, None);
} else {