aboutsummaryrefslogtreecommitdiffhomepage
path: root/ptx/src
diff options
context:
space:
mode:
Diffstat (limited to 'ptx/src')
-rw-r--r--ptx/src/ast.rs725
-rw-r--r--ptx/src/ptx.lalrpop519
-rw-r--r--ptx/src/test/spirv_run/and.spvtxt10
-rw-r--r--ptx/src/test/spirv_run/atom_add.spvtxt17
-rw-r--r--ptx/src/test/spirv_run/atom_add_float.spvtxt17
-rw-r--r--ptx/src/test/spirv_run/atom_cas.spvtxt28
-rw-r--r--ptx/src/test/spirv_run/atom_inc.spvtxt36
-rw-r--r--ptx/src/test/spirv_run/bfe.spvtxt18
-rw-r--r--ptx/src/test/spirv_run/bfi.spvtxt28
-rw-r--r--ptx/src/test/spirv_run/call.spvtxt4
-rw-r--r--ptx/src/test/spirv_run/cvt_rni.spvtxt18
-rw-r--r--ptx/src/test/spirv_run/cvt_rzi.spvtxt18
-rw-r--r--ptx/src/test/spirv_run/cvt_s32_f32.spvtxt21
-rw-r--r--ptx/src/test/spirv_run/div_approx.spvtxt10
-rw-r--r--ptx/src/test/spirv_run/extern_shared.spvtxt49
-rw-r--r--ptx/src/test/spirv_run/extern_shared_call.spvtxt103
-rw-r--r--ptx/src/test/spirv_run/fma.spvtxt18
-rw-r--r--ptx/src/test/spirv_run/ld_st_offset.spvtxt18
-rw-r--r--ptx/src/test/spirv_run/mad_s32.spvtxt38
-rw-r--r--ptx/src/test/spirv_run/max.spvtxt10
-rw-r--r--ptx/src/test/spirv_run/min.spvtxt10
-rw-r--r--ptx/src/test/spirv_run/mul_ftz.spvtxt10
-rw-r--r--ptx/src/test/spirv_run/mul_non_ftz.spvtxt10
-rw-r--r--ptx/src/test/spirv_run/mul_wide.spvtxt22
-rw-r--r--ptx/src/test/spirv_run/or.spvtxt10
-rw-r--r--ptx/src/test/spirv_run/pred_not.spvtxt10
-rw-r--r--ptx/src/test/spirv_run/reg_local.spvtxt17
-rw-r--r--ptx/src/test/spirv_run/rem.spvtxt10
-rw-r--r--ptx/src/test/spirv_run/selp.spvtxt10
-rw-r--r--ptx/src/test/spirv_run/selp_true.spvtxt10
-rw-r--r--ptx/src/test/spirv_run/setp.spvtxt10
-rw-r--r--ptx/src/test/spirv_run/setp_gt.spvtxt10
-rw-r--r--ptx/src/test/spirv_run/setp_leu.spvtxt10
-rw-r--r--ptx/src/test/spirv_run/setp_nan.spvtxt106
-rw-r--r--ptx/src/test/spirv_run/setp_num.spvtxt114
-rw-r--r--ptx/src/test/spirv_run/shared_ptr_32.spvtxt11
-rw-r--r--ptx/src/test/spirv_run/shared_ptr_take_address.spvtxt55
-rw-r--r--ptx/src/test/spirv_run/stateful_ld_st_ntid.spvtxt66
-rw-r--r--ptx/src/test/spirv_run/stateful_ld_st_ntid_chain.spvtxt66
-rw-r--r--ptx/src/test/spirv_run/vector.spvtxt2
-rw-r--r--ptx/src/test/spirv_run/verify.py21
-rw-r--r--ptx/src/test/spirv_run/xor.spvtxt10
-rw-r--r--ptx/src/translate.rs3316
43 files changed, 2613 insertions, 3008 deletions
diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs
index 1c6d2fb..5432207 100644
--- a/ptx/src/ast.rs
+++ b/ptx/src/ast.rs
@@ -1,7 +1,6 @@
use half::f16;
use lalrpop_util::{lexer::Token, ParseError};
-use std::convert::TryInto;
-use std::{convert::From, convert::TryFrom, mem, num::ParseFloatError, str::FromStr};
+use std::{convert::From, mem, num::ParseFloatError, str::FromStr};
use std::{marker::PhantomData, num::ParseIntError};
#[derive(Debug, thiserror::Error)]
@@ -34,195 +33,12 @@ pub enum PtxError {
NonExternPointer,
}
-macro_rules! sub_enum {
- ($name:ident { $($variant:ident),+ $(,)? }) => {
- sub_enum!{ $name : ScalarType { $($variant),+ } }
- };
- ($name:ident : $base_type:ident { $($variant:ident),+ $(,)? }) => {
- #[derive(PartialEq, Eq, Clone, Copy)]
- pub enum $name {
- $(
- $variant,
- )+
- }
-
- impl From<$name> for $base_type {
- fn from(t: $name) -> $base_type {
- match t {
- $(
- $name::$variant => $base_type::$variant,
- )+
- }
- }
- }
-
- impl std::convert::TryFrom<$base_type> for $name {
- type Error = ();
-
- fn try_from(t: $base_type) -> Result<Self, Self::Error> {
- match t {
- $(
- $base_type::$variant => Ok($name::$variant),
- )+
- _ => Err(()),
- }
- }
- }
- };
-}
-
-macro_rules! sub_type {
- ($type_name:ident { $($variant:ident ( $($field_type:ident),+ ) ),+ $(,)? } ) => {
- sub_type! { $type_name : Type {
- $(
- $variant ($($field_type),+),
- )+
- }}
- };
- ($type_name:ident : $base_type:ident { $($variant:ident ( $($field_type:ident),+ ) ),+ $(,)? } ) => {
- #[derive(PartialEq, Eq, Clone)]
- pub enum $type_name {
- $(
- $variant ($($field_type),+),
- )+
- }
-
- impl From<$type_name> for $base_type {
- #[allow(non_snake_case)]
- fn from(t: $type_name) -> $base_type {
- match t {
- $(
- $type_name::$variant ( $($field_type),+ ) => <$base_type>::$variant ( $($field_type.into()),+),
- )+
- }
- }
- }
-
- impl std::convert::TryFrom<$base_type> for $type_name {
- type Error = ();
-
- #[allow(non_snake_case)]
- #[allow(unreachable_patterns)]
- fn try_from(t: $base_type) -> Result<Self, Self::Error> {
- match t {
- $(
- $base_type::$variant ( $($field_type),+ ) => Ok($type_name::$variant ( $($field_type.try_into().map_err(|_| ())? ),+ )),
- )+
- _ => Err(()),
- }
- }
- }
- };
-}
-
-sub_type! {
- VariableRegType {
- Scalar(ScalarType),
- Vector(SizedScalarType, u8),
- // Array type is used when emiting SSA statements at the start of a method
- Array(ScalarType, VecU32),
- // Pointer variant is used when passing around SLM pointer between
- // function calls for dynamic SLM
- Pointer(SizedScalarType, PointerStateSpace)
- }
-}
-
-type VecU32 = Vec<u32>;
-
-sub_type! {
- VariableLocalType {
- Scalar(SizedScalarType),
- Vector(SizedScalarType, u8),
- Array(SizedScalarType, VecU32),
- }
-}
-
-impl TryFrom<VariableGlobalType> for VariableLocalType {
- type Error = PtxError;
-
- fn try_from(value: VariableGlobalType) -> Result<Self, Self::Error> {
- match value {
- VariableGlobalType::Scalar(t) => Ok(VariableLocalType::Scalar(t)),
- VariableGlobalType::Vector(t, len) => Ok(VariableLocalType::Vector(t, len)),
- VariableGlobalType::Array(t, len) => Ok(VariableLocalType::Array(t, len)),
- VariableGlobalType::Pointer(_, _) => Err(PtxError::ZeroDimensionArray),
- }
- }
-}
-
-sub_type! {
- VariableGlobalType {
- Scalar(SizedScalarType),
- Vector(SizedScalarType, u8),
- Array(SizedScalarType, VecU32),
- Pointer(SizedScalarType, PointerStateSpace),
- }
-}
-
// For some weird reson this is illegal:
// .param .f16x2 foobar;
// but this is legal:
// .param .f16x2 foobar[1];
// even more interestingly this is legal, but only in .func (not in .entry):
// .param .b32 foobar[]
-sub_type! {
- VariableParamType {
- Scalar(LdStScalarType),
- Array(SizedScalarType, VecU32),
- Pointer(SizedScalarType, PointerStateSpace),
- }
-}
-
-sub_enum!(SizedScalarType {
- B8,
- B16,
- B32,
- B64,
- U8,
- U16,
- U32,
- U64,
- S8,
- S16,
- S32,
- S64,
- F16,
- F16x2,
- F32,
- F64,
-});
-
-sub_enum!(LdStScalarType {
- B8,
- B16,
- B32,
- B64,
- U8,
- U16,
- U32,
- U64,
- S8,
- S16,
- S32,
- S64,
- F16,
- F32,
- F64,
-});
-
-sub_enum!(SelpType {
- B16,
- B32,
- B64,
- U16,
- U32,
- U64,
- S16,
- S32,
- S64,
- F32,
- F64,
-});
#[derive(Copy, Clone, Eq, PartialEq)]
pub enum BarDetails {
@@ -266,23 +82,25 @@ pub struct Module<'a> {
}
pub enum Directive<'a, P: ArgParams> {
- Variable(Variable<VariableType, P::Id>),
- Method(Function<'a, &'a str, Statement<P>>),
+ Variable(LinkingDirective, Variable<P::Id>),
+ Method(LinkingDirective, Function<'a, &'a str, Statement<P>>),
}
-pub enum MethodDecl<'a, ID> {
- Func(Vec<FnArgument<ID>>, ID, Vec<FnArgument<ID>>),
- Kernel {
- name: &'a str,
- in_args: Vec<KernelArgument<ID>>,
- },
+#[derive(Hash, PartialEq, Eq, Copy, Clone)]
+pub enum MethodName<'input, ID> {
+ Kernel(&'input str),
+ Func(ID),
}
-pub type FnArgument<ID> = Variable<FnArgumentType, ID>;
-pub type KernelArgument<ID> = Variable<KernelArgumentType, ID>;
+pub struct MethodDeclaration<'input, ID> {
+ pub return_arguments: Vec<Variable<ID>>,
+ pub name: MethodName<'input, ID>,
+ pub input_arguments: Vec<Variable<ID>>,
+ pub shared_mem: Option<ID>,
+}
pub struct Function<'a, ID, S> {
- pub func_directive: MethodDecl<'a, ID>,
+ pub func_directive: MethodDeclaration<'a, ID>,
pub tuning: Vec<TuningDirective>,
pub body: Option<Vec<S>>,
}
@@ -290,118 +108,50 @@ pub struct Function<'a, ID, S> {
pub type ParsedFunction<'a> = Function<'a, &'a str, Statement<ParsedArgParams<'a>>>;
#[derive(PartialEq, Eq, Clone)]
-pub enum FnArgumentType {
- Reg(VariableRegType),
- Param(VariableParamType),
- Shared,
-}
-#[derive(PartialEq, Eq, Clone)]
-pub enum KernelArgumentType {
- Normal(VariableParamType),
- Shared,
-}
-
-impl From<KernelArgumentType> for Type {
- fn from(this: KernelArgumentType) -> Self {
- match this {
- KernelArgumentType::Normal(typ) => typ.into(),
- KernelArgumentType::Shared => {
- Type::Pointer(PointerType::Scalar(ScalarType::B8), LdStateSpace::Shared)
- }
- }
- }
-}
-
-impl FnArgumentType {
- pub fn to_type(&self, is_kernel: bool) -> Type {
- if is_kernel {
- self.to_kernel_type()
- } else {
- self.to_func_type()
- }
- }
-
- pub fn to_kernel_type(&self) -> Type {
- match self {
- FnArgumentType::Reg(x) => x.clone().into(),
- FnArgumentType::Param(x) => x.clone().into(),
- FnArgumentType::Shared => {
- Type::Pointer(PointerType::Scalar(ScalarType::B8), LdStateSpace::Shared)
- }
- }
- }
-
- pub fn to_func_type(&self) -> Type {
- match self {
- FnArgumentType::Reg(x) => x.clone().into(),
- FnArgumentType::Param(VariableParamType::Scalar(t)) => {
- Type::Pointer(PointerType::Scalar((*t).into()), LdStateSpace::Param)
- }
- FnArgumentType::Param(VariableParamType::Array(t, dims)) => Type::Pointer(
- PointerType::Array((*t).into(), dims.clone()),
- LdStateSpace::Param,
- ),
- FnArgumentType::Param(VariableParamType::Pointer(t, space)) => Type::Pointer(
- PointerType::Pointer((*t).into(), (*space).into()),
- LdStateSpace::Param,
- ),
- FnArgumentType::Shared => {
- Type::Pointer(PointerType::Scalar(ScalarType::B8), LdStateSpace::Shared)
- }
- }
- }
-
- pub fn is_param(&self) -> bool {
- match self {
- FnArgumentType::Param(_) => true,
- _ => false,
- }
- }
-}
-
-sub_enum!(
- PointerStateSpace : LdStateSpace {
- Generic,
- Global,
- Const,
- Shared,
- Param,
- }
-);
-
-#[derive(PartialEq, Eq, Clone)]
pub enum Type {
+ // .param.b32 foo;
+ // -> OpTypeInt
Scalar(ScalarType),
+ // .param.v2.b32 foo;
+ // -> OpTypeVector
Vector(ScalarType, u8),
+ // .param.b32 foo[4];
+ // -> OpTypeArray
Array(ScalarType, Vec<u32>),
- Pointer(PointerType, LdStateSpace),
-}
-
-#[derive(PartialEq, Eq, Clone)]
-pub enum PointerType {
- Scalar(ScalarType),
- Vector(ScalarType, u8),
- Array(ScalarType, VecU32),
- Pointer(ScalarType, LdStateSpace),
-}
-
-impl From<SizedScalarType> for PointerType {
- fn from(t: SizedScalarType) -> Self {
- PointerType::Scalar(t.into())
- }
-}
-
-impl TryFrom<PointerType> for SizedScalarType {
- type Error = ();
-
- fn try_from(value: PointerType) -> Result<Self, Self::Error> {
- match value {
- PointerType::Scalar(t) => Ok(t.try_into()?),
- PointerType::Vector(_, _) => Err(()),
- PointerType::Array(_, _) => Err(()),
- PointerType::Pointer(_, _) => Err(()),
- }
- }
+ /*
+ Variables of this type almost never exist in the original .ptx and are
+ usually artificially created. Some examples below:
+ - extern pointers to the .shared memory in the form:
+ .extern .shared .b32 shared_mem[];
+ which we first parse as
+ .extern .shared .b32 shared_mem;
+ and then convert to an additional function parameter:
+ .param .ptr<.b32.shared> shared_mem;
+ and do a load at the start of the function (and renames inside fn):
+ .reg .ptr<.b32.shared> temp;
+ ld.param.ptr<.b32.shared> temp, [shared_mem];
+ note, we don't support non-.shared extern pointers, because there's
+ zero use for them in the ptxas
+ - artifical pointers created by stateful conversion, which work
+ similiarly to the above
+ - function parameters:
+ foobar(.param .align 4 .b8 numbers[])
+ which get parsed to
+ foobar(.param .align 4 .b8 numbers)
+ and then converted to
+ foobar(.reg .align 4 .ptr<.b8.param> numbers)
+ - ld/st with offset:
+ .reg.b32 x;
+ .param.b64 arg0;
+ st.param.b32 [arg0+4], x;
+ Yes, this code is legal and actually emitted by the NV compiler!
+ We convert the st to:
+ .reg ptr<.b64.param> temp = ptr_offset(arg0, 4);
+ st.param.b32 [temp], x;
+ */
+ // .reg ptr<.b64.param>
+ // -> OpTypePointer Function
+ Pointer(ScalarType, StateSpace),
}
#[derive(PartialEq, Eq, Hash, Clone, Copy)]
@@ -425,52 +175,6 @@ pub enum ScalarType {
Pred,
}
-sub_enum!(IntType {
- U8,
- U16,
- U32,
- U64,
- S8,
- S16,
- S32,
- S64
-});
-
-sub_enum!(BitType { B8, B16, B32, B64 });
-
-sub_enum!(UIntType { U8, U16, U32, U64 });
-
-sub_enum!(SIntType { S8, S16, S32, S64 });
-
-impl IntType {
- pub fn is_signed(self) -> bool {
- match self {
- IntType::U8 | IntType::U16 | IntType::U32 | IntType::U64 => false,
- IntType::S8 | IntType::S16 | IntType::S32 | IntType::S64 => true,
- }
- }
-
- pub fn width(self) -> u8 {
- match self {
- IntType::U8 => 1,
- IntType::U16 => 2,
- IntType::U32 => 4,
- IntType::U64 => 8,
- IntType::S8 => 1,
- IntType::S16 => 2,
- IntType::S32 => 4,
- IntType::S64 => 8,
- }
- }
-}
-
-sub_enum!(FloatType {
- F16,
- F16x2,
- F32,
- F64
-});
-
impl ScalarType {
pub fn size_of(self) -> u8 {
match self {
@@ -509,51 +213,19 @@ pub enum Statement<P: ArgParams> {
}
pub struct MultiVariable<ID> {
- pub var: Variable<VariableType, ID>,
+ pub var: Variable<ID>,
pub count: Option<u32>,
}
#[derive(Clone)]
-pub struct Variable<T, ID> {
+pub struct Variable<ID> {
pub align: Option<u32>,
- pub v_type: T,
+ pub v_type: Type,
+ pub state_space: StateSpace,
pub name: ID,
pub array_init: Vec<u8>,
}
-#[derive(Eq, PartialEq, Clone)]
-pub enum VariableType {
- Reg(VariableRegType),
- Local(VariableLocalType),
- Param(VariableParamType),
- Global(VariableGlobalType),
- Shared(VariableGlobalType),
-}
-
-impl VariableType {
- pub fn to_type(&self) -> (StateSpace, Type) {
- match self {
- VariableType::Reg(t) => (StateSpace::Reg, t.clone().into()),
- VariableType::Local(t) => (StateSpace::Local, t.clone().into()),
- VariableType::Param(t) => (StateSpace::Param, t.clone().into()),
- VariableType::Global(t) => (StateSpace::Global, t.clone().into()),
- VariableType::Shared(t) => (StateSpace::Shared, t.clone().into()),
- }
- }
-}
-
-impl From<VariableType> for Type {
- fn from(t: VariableType) -> Self {
- match t {
- VariableType::Reg(t) => t.into(),
- VariableType::Local(t) => t.into(),
- VariableType::Param(t) => t.into(),
- VariableType::Global(t) => t.into(),
- VariableType::Shared(t) => t.into(),
- }
- }
-}
-
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum StateSpace {
Reg,
@@ -562,6 +234,8 @@ pub enum StateSpace {
Local,
Shared,
Param,
+ Generic,
+ Sreg,
}
pub struct PredAt<ID> {
@@ -576,24 +250,24 @@ pub enum Instruction<P: ArgParams> {
Add(ArithDetails, Arg3<P>),
Setp(SetpData, Arg4Setp<P>),
SetpBool(SetpBoolData, Arg5Setp<P>),
- Not(BooleanType, Arg2<P>),
+ Not(ScalarType, Arg2<P>),
Bra(BraData, Arg1<P>),
Cvt(CvtDetails, Arg2<P>),
Cvta(CvtaDetails, Arg2<P>),
- Shl(ShlType, Arg3<P>),
- Shr(ShrType, Arg3<P>),
+ Shl(ScalarType, Arg3<P>),
+ Shr(ScalarType, Arg3<P>),
St(StData, Arg2St<P>),
Ret(RetData),
Call(CallInst<P>),
Abs(AbsDetails, Arg2<P>),
Mad(MulDetails, Arg4<P>),
- Or(BooleanType, Arg3<P>),
+ Or(ScalarType, Arg3<P>),
Sub(ArithDetails, Arg3<P>),
Min(MinMaxDetails, Arg3<P>),
Max(MinMaxDetails, Arg3<P>),
Rcp(RcpDetails, Arg2<P>),
- And(BooleanType, Arg3<P>),
- Selp(SelpType, Arg4<P>),
+ And(ScalarType, Arg3<P>),
+ Selp(ScalarType, Arg4<P>),
Bar(BarDetails, Arg1Bar<P>),
Atom(AtomDetails, Arg3<P>),
AtomCas(AtomCasDetails, Arg4<P>),
@@ -605,13 +279,13 @@ pub enum Instruction<P: ArgParams> {
Cos { flush_to_zero: bool, arg: Arg2<P> },
Lg2 { flush_to_zero: bool, arg: Arg2<P> },
Ex2 { flush_to_zero: bool, arg: Arg2<P> },
- Clz { typ: BitType, arg: Arg2<P> },
- Brev { typ: BitType, arg: Arg2<P> },
- Popc { typ: BitType, arg: Arg2<P> },
- Xor { typ: BooleanType, arg: Arg3<P> },
- Bfe { typ: IntType, arg: Arg4<P> },
- Bfi { typ: BitType, arg: Arg5<P> },
- Rem { typ: IntType, arg: Arg3<P> },
+ Clz { typ: ScalarType, arg: Arg2<P> },
+ Brev { typ: ScalarType, arg: Arg2<P> },
+ Popc { typ: ScalarType, arg: Arg2<P> },
+ Xor { typ: ScalarType, arg: Arg3<P> },
+ Bfe { typ: ScalarType, arg: Arg4<P> },
+ Bfi { typ: ScalarType, arg: Arg5<P> },
+ Rem { typ: ScalarType, arg: Arg3<P> },
}
#[derive(Copy, Clone)]
@@ -737,34 +411,12 @@ pub enum VectorPrefix {
pub struct LdDetails {
pub qualifier: LdStQualifier,
- pub state_space: LdStateSpace,
+ pub state_space: StateSpace,
pub caching: LdCacheOperator,
- pub typ: LdStType,
+ pub typ: Type,
pub non_coherent: bool,
}
-sub_type! {
- LdStType {
- Scalar(LdStScalarType),
- Vector(LdStScalarType, u8),
- // Used in generated code
- Pointer(PointerType, LdStateSpace),
- }
-}
-
-impl From<LdStType> for PointerType {
- fn from(t: LdStType) -> Self {
- match t {
- LdStType::Scalar(t) => PointerType::Scalar(t.into()),
- LdStType::Vector(t, len) => PointerType::Vector(t.into(), len),
- LdStType::Pointer(PointerType::Scalar(scalar_type), space) => {
- PointerType::Pointer(scalar_type, space)
- }
- LdStType::Pointer(..) => unreachable!(),
- }
- }
-}
-
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum LdStQualifier {
Weak,
@@ -780,17 +432,6 @@ pub enum MemScope {
Sys,
}
-#[derive(Copy, Clone, PartialEq, Eq, Debug)]
-#[repr(u8)]
-pub enum LdStateSpace {
- Generic,
- Const,
- Global,
- Local,
- Param,
- Shared,
-}
-
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum LdCacheOperator {
Cached,
@@ -825,7 +466,7 @@ impl MovDetails {
#[derive(Copy, Clone)]
pub struct MulIntDesc {
- pub typ: IntType,
+ pub typ: ScalarType,
pub control: MulIntControl,
}
@@ -845,7 +486,7 @@ pub enum RoundingMode {
}
pub struct AddIntDesc {
- pub typ: IntType,
+ pub typ: ScalarType,
pub saturate: bool,
}
@@ -892,39 +533,39 @@ pub struct BraData {
pub enum CvtDetails {
IntFromInt(CvtIntToIntDesc),
- FloatFromFloat(CvtDesc<FloatType, FloatType>),
- IntFromFloat(CvtDesc<IntType, FloatType>),
- FloatFromInt(CvtDesc<FloatType, IntType>),
+ FloatFromFloat(CvtDesc),
+ IntFromFloat(CvtDesc),
+ FloatFromInt(CvtDesc),
}
pub struct CvtIntToIntDesc {
- pub dst: IntType,
- pub src: IntType,
+ pub dst: ScalarType,
+ pub src: ScalarType,
pub saturate: bool,
}
-pub struct CvtDesc<Dst, Src> {
+pub struct CvtDesc {
pub rounding: Option<RoundingMode>,
pub flush_to_zero: Option<bool>,
pub saturate: bool,
- pub dst: Dst,
- pub src: Src,
+ pub dst: ScalarType,
+ pub src: ScalarType,
}
impl CvtDetails {
pub fn new_int_from_int_checked<'err, 'input>(
saturate: bool,
- dst: IntType,
- src: IntType,
+ dst: ScalarType,
+ src: ScalarType,
err: &'err mut Vec<ParseError<usize, Token<'input>, PtxError>>,
) -> Self {
if saturate {
- if src.is_signed() {
- if dst.is_signed() && dst.width() >= src.width() {
+ if src.kind() == ScalarKind::Signed {
+ if dst.kind() == ScalarKind::Signed && dst.size_of() >= src.size_of() {
err.push(ParseError::from(PtxError::SyntaxError));
}
} else {
- if dst == src || dst.width() >= src.width() {
+ if dst == src || dst.size_of() >= src.size_of() {
err.push(ParseError::from(PtxError::SyntaxError));
}
}
@@ -936,11 +577,11 @@ impl CvtDetails {
rounding: RoundingMode,
flush_to_zero: bool,
saturate: bool,
- dst: FloatType,
- src: IntType,
+ dst: ScalarType,
+ src: ScalarType,
err: &'err mut Vec<ParseError<usize, Token<'input>, PtxError>>,
) -> Self {
- if flush_to_zero && dst != FloatType::F32 {
+ if flush_to_zero && dst != ScalarType::F32 {
err.push(ParseError::from(PtxError::NonF32Ftz));
}
CvtDetails::FloatFromInt(CvtDesc {
@@ -956,11 +597,11 @@ impl CvtDetails {
rounding: RoundingMode,
flush_to_zero: bool,
saturate: bool,
- dst: IntType,
- src: FloatType,
+ dst: ScalarType,
+ src: ScalarType,
err: &'err mut Vec<ParseError<usize, Token<'input>, PtxError>>,
) -> Self {
- if flush_to_zero && src != FloatType::F32 {
+ if flush_to_zero && src != ScalarType::F32 {
err.push(ParseError::from(PtxError::NonF32Ftz));
}
CvtDetails::IntFromFloat(CvtDesc {
@@ -974,58 +615,21 @@ impl CvtDetails {
}
pub struct CvtaDetails {
- pub to: CvtaStateSpace,
- pub from: CvtaStateSpace,
+ pub to: StateSpace,
+ pub from: StateSpace,
pub size: CvtaSize,
}
-#[derive(Copy, Clone, PartialEq, Eq)]
-pub enum CvtaStateSpace {
- Generic,
- Const,
- Global,
- Local,
- Shared,
-}
-
pub enum CvtaSize {
U32,
U64,
}
-#[derive(PartialEq, Eq, Copy, Clone)]
-pub enum ShlType {
- B16,
- B32,
- B64,
-}
-
-sub_enum!(ShrType {
- B16,
- B32,
- B64,
- U16,
- U32,
- U64,
- S16,
- S32,
- S64,
-});
-
pub struct StData {
pub qualifier: LdStQualifier,
- pub state_space: StStateSpace,
+ pub state_space: StateSpace,
pub caching: StCacheOperator,
- pub typ: LdStType,
-}
-
-#[derive(PartialEq, Eq, Copy, Clone)]
-pub enum StStateSpace {
- Generic,
- Global,
- Local,
- Param,
- Shared,
+ pub typ: Type,
}
#[derive(PartialEq, Eq)]
@@ -1040,13 +644,6 @@ pub struct RetData {
pub uniform: bool,
}
-sub_enum!(BooleanType {
- Pred,
- B16,
- B32,
- B64,
-});
-
#[derive(Copy, Clone)]
pub enum MulDetails {
Unsigned(MulUInt),
@@ -1056,32 +653,32 @@ pub enum MulDetails {
#[derive(Copy, Clone)]
pub struct MulUInt {
- pub typ: UIntType,
+ pub typ: ScalarType,
pub control: MulIntControl,
}
#[derive(Copy, Clone)]
pub struct MulSInt {
- pub typ: SIntType,
+ pub typ: ScalarType,
pub control: MulIntControl,
}
#[derive(Copy, Clone)]
pub enum ArithDetails {
- Unsigned(UIntType),
+ Unsigned(ScalarType),
Signed(ArithSInt),
Float(ArithFloat),
}
#[derive(Copy, Clone)]
pub struct ArithSInt {
- pub typ: SIntType,
+ pub typ: ScalarType,
pub saturate: bool,
}
#[derive(Copy, Clone)]
pub struct ArithFloat {
- pub typ: FloatType,
+ pub typ: ScalarType,
pub rounding: Option<RoundingMode>,
pub flush_to_zero: Option<bool>,
pub saturate: bool,
@@ -1089,8 +686,8 @@ pub struct ArithFloat {
#[derive(Copy, Clone)]
pub enum MinMaxDetails {
- Signed(SIntType),
- Unsigned(UIntType),
+ Signed(ScalarType),
+ Unsigned(ScalarType),
Float(MinMaxFloat),
}
@@ -1098,14 +695,14 @@ pub enum MinMaxDetails {
pub struct MinMaxFloat {
pub flush_to_zero: Option<bool>,
pub nan: bool,
- pub typ: FloatType,
+ pub typ: ScalarType,
}
#[derive(Copy, Clone)]
pub struct AtomDetails {
pub semantics: AtomSemantics,
pub scope: MemScope,
- pub space: AtomSpace,
+ pub space: StateSpace,
pub inner: AtomInnerDetails,
}
@@ -1118,18 +715,11 @@ pub enum AtomSemantics {
}
#[derive(Copy, Clone)]
-pub enum AtomSpace {
- Generic,
- Global,
- Shared,
-}
-
-#[derive(Copy, Clone)]
pub enum AtomInnerDetails {
- Bit { op: AtomBitOp, typ: BitType },
- Unsigned { op: AtomUIntOp, typ: UIntType },
- Signed { op: AtomSIntOp, typ: SIntType },
- Float { op: AtomFloatOp, typ: FloatType },
+ Bit { op: AtomBitOp, typ: ScalarType },
+ Unsigned { op: AtomUIntOp, typ: ScalarType },
+ Signed { op: AtomSIntOp, typ: ScalarType },
+ Float { op: AtomFloatOp, typ: ScalarType },
}
#[derive(Copy, Clone, Eq, PartialEq)]
@@ -1165,20 +755,20 @@ pub enum AtomFloatOp {
pub struct AtomCasDetails {
pub semantics: AtomSemantics,
pub scope: MemScope,
- pub space: AtomSpace,
- pub typ: BitType,
+ pub space: StateSpace,
+ pub typ: ScalarType,
}
#[derive(Copy, Clone)]
pub enum DivDetails {
- Unsigned(UIntType),
- Signed(SIntType),
+ Unsigned(ScalarType),
+ Signed(ScalarType),
Float(DivFloatDetails),
}
#[derive(Copy, Clone)]
pub struct DivFloatDetails {
- pub typ: FloatType,
+ pub typ: ScalarType,
pub flush_to_zero: Option<bool>,
pub kind: DivFloatKind,
}
@@ -1197,7 +787,7 @@ pub enum NumsOrArrays<'a> {
#[derive(Copy, Clone)]
pub struct SqrtDetails {
- pub typ: FloatType,
+ pub typ: ScalarType,
pub flush_to_zero: Option<bool>,
pub kind: SqrtKind,
}
@@ -1210,7 +800,7 @@ pub enum SqrtKind {
#[derive(Copy, Clone, Eq, PartialEq)]
pub struct RsqrtDetails {
- pub typ: FloatType,
+ pub typ: ScalarType,
pub flush_to_zero: bool,
}
@@ -1221,7 +811,7 @@ pub struct NegDetails {
}
impl<'a> NumsOrArrays<'a> {
- pub fn to_vec(self, typ: SizedScalarType, dimensions: &mut [u32]) -> Result<Vec<u8>, PtxError> {
+ pub fn to_vec(self, typ: ScalarType, dimensions: &mut [u32]) -> Result<Vec<u8>, PtxError> {
self.normalize_dimensions(dimensions)?;
let sizeof_t = ScalarType::from(typ).size_of() as usize;
let result_size = dimensions.iter().fold(sizeof_t, |x, y| x * (*y as usize));
@@ -1252,7 +842,7 @@ impl<'a> NumsOrArrays<'a> {
fn parse_and_copy(
&self,
- t: SizedScalarType,
+ t: ScalarType,
size_of_t: usize,
dimensions: &[u32],
result: &mut [u8],
@@ -1292,47 +882,48 @@ impl<'a> NumsOrArrays<'a> {
}
fn parse_and_copy_single(
- t: SizedScalarType,
+ t: ScalarType,
idx: usize,
str_val: &str,
radix: u32,
output: &mut [u8],
) -> Result<(), PtxError> {
match t {
- SizedScalarType::B8 | SizedScalarType::U8 => {
+ ScalarType::B8 | ScalarType::U8 => {
Self::parse_and_copy_single_t::<u8>(idx, str_val, radix, output)?;
}
- SizedScalarType::B16 | SizedScalarType::U16 => {
+ ScalarType::B16 | ScalarType::U16 => {
Self::parse_and_copy_single_t::<u16>(idx, str_val, radix, output)?;
}
- SizedScalarType::B32 | SizedScalarType::U32 => {
+ ScalarType::B32 | ScalarType::U32 => {
Self::parse_and_copy_single_t::<u32>(idx, str_val, radix, output)?;
}
- SizedScalarType::B64 | SizedScalarType::U64 => {
+ ScalarType::B64 | ScalarType::U64 => {
Self::parse_and_copy_single_t::<u64>(idx, str_val, radix, output)?;
}
- SizedScalarType::S8 => {
+ ScalarType::S8 => {
Self::parse_and_copy_single_t::<i8>(idx, str_val, radix, output)?;
}
- SizedScalarType::S16 => {
+ ScalarType::S16 => {
Self::parse_and_copy_single_t::<i16>(idx, str_val, radix, output)?;
}
- SizedScalarType::S32 => {
+ ScalarType::S32 => {
Self::parse_and_copy_single_t::<i32>(idx, str_val, radix, output)?;
}
- SizedScalarType::S64 => {
+ ScalarType::S64 => {
Self::parse_and_copy_single_t::<i64>(idx, str_val, radix, output)?;
}
- SizedScalarType::F16 => {
+ ScalarType::F16 => {
Self::parse_and_copy_single_t::<f16>(idx, str_val, radix, output)?;
}
- SizedScalarType::F16x2 => todo!(),
- SizedScalarType::F32 => {
+ ScalarType::F16x2 => todo!(),
+ ScalarType::F32 => {
Self::parse_and_copy_single_t::<f32>(idx, str_val, radix, output)?;
}
- SizedScalarType::F64 => {
+ ScalarType::F64 => {
Self::parse_and_copy_single_t::<f64>(idx, str_val, radix, output)?;
}
+ ScalarType::Pred => todo!(),
}
Ok(())
}
@@ -1379,6 +970,40 @@ pub enum TuningDirective {
MinNCtaPerSm(u32),
}
+#[derive(Clone, Copy, PartialEq, Eq)]
+pub enum ScalarKind {
+ Bit,
+ Unsigned,
+ Signed,
+ Float,
+ Float2,
+ Pred,
+}
+
+impl ScalarType {
+ pub fn kind(self) -> ScalarKind {
+ match self {
+ ScalarType::U8 => ScalarKind::Unsigned,
+ ScalarType::U16 => ScalarKind::Unsigned,
+ ScalarType::U32 => ScalarKind::Unsigned,
+ ScalarType::U64 => ScalarKind::Unsigned,
+ ScalarType::S8 => ScalarKind::Signed,
+ ScalarType::S16 => ScalarKind::Signed,
+ ScalarType::S32 => ScalarKind::Signed,
+ ScalarType::S64 => ScalarKind::Signed,
+ ScalarType::B8 => ScalarKind::Bit,
+ ScalarType::B16 => ScalarKind::Bit,
+ ScalarType::B32 => ScalarKind::Bit,
+ ScalarType::B64 => ScalarKind::Bit,
+ ScalarType::F16 => ScalarKind::Float,
+ ScalarType::F32 => ScalarKind::Float,
+ ScalarType::F64 => ScalarKind::Float,
+ ScalarType::F16x2 => ScalarKind::Float2,
+ ScalarType::Pred => ScalarKind::Pred,
+ }
+ }
+}
+
#[cfg(test)]
mod tests {
use super::*;
@@ -1386,13 +1011,13 @@ mod tests {
#[test]
fn array_fails_multiple_0_dmiensions() {
let inp = NumsOrArrays::Nums(Vec::new());
- assert!(inp.to_vec(SizedScalarType::B8, &mut vec![0, 0]).is_err());
+ assert!(inp.to_vec(ScalarType::B8, &mut vec![0, 0]).is_err());
}
#[test]
fn array_fails_on_empty() {
let inp = NumsOrArrays::Nums(Vec::new());
- assert!(inp.to_vec(SizedScalarType::B8, &mut vec![0]).is_err());
+ assert!(inp.to_vec(ScalarType::B8, &mut vec![0]).is_err());
}
#[test]
@@ -1404,7 +1029,7 @@ mod tests {
let mut dimensions = vec![0u32, 2];
assert_eq!(
vec![1u8, 2, 3, 4],
- inp.to_vec(SizedScalarType::B8, &mut dimensions).unwrap()
+ inp.to_vec(ScalarType::B8, &mut dimensions).unwrap()
);
assert_eq!(dimensions, vec![2u32, 2]);
}
@@ -1416,7 +1041,7 @@ mod tests {
NumsOrArrays::Arrays(vec![NumsOrArrays::Nums(vec![("1", 10)])]),
]);
let mut dimensions = vec![0u32, 2];
- assert!(inp.to_vec(SizedScalarType::B8, &mut dimensions).is_err());
+ assert!(inp.to_vec(ScalarType::B8, &mut dimensions).is_err());
}
#[test]
@@ -1426,6 +1051,6 @@ mod tests {
NumsOrArrays::Nums(vec![("4", 10), ("5", 10)]),
]);
let mut dimensions = vec![0u32, 2];
- assert!(inp.to_vec(SizedScalarType::B8, &mut dimensions).is_err());
+ assert!(inp.to_vec(ScalarType::B8, &mut dimensions).is_err());
}
}
diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop
index 423fd57..b697317 100644
--- a/ptx/src/ptx.lalrpop
+++ b/ptx/src/ptx.lalrpop
@@ -343,10 +343,16 @@ TargetSpecifier = {
Directive: Option<ast::Directive<'input, ast::ParsedArgParams<'input>>> = {
AddressSize => None,
- <f:Function> => Some(ast::Directive::Method(f)),
+ <f:Function> => {
+ let (linking, func) = f;
+ Some(ast::Directive::Method(linking, func))
+ },
File => None,
Section => None,
- <v:ModuleVariable> ";" => Some(ast::Directive::Variable(v)),
+ <v:ModuleVariable> ";" => {
+ let (linking, var) = v;
+ Some(ast::Directive::Variable(linking, var))
+ },
! => {
let err = <>;
errors.push(err.error);
@@ -358,11 +364,13 @@ AddressSize = {
".address_size" U8Num
};
-Function: ast::Function<'input, &'input str, ast::Statement<ast::ParsedArgParams<'input>>> = {
- LinkingDirectives
- <func_directive:MethodDecl>
+Function: (ast::LinkingDirective, ast::Function<'input, &'input str, ast::Statement<ast::ParsedArgParams<'input>>>) = {
+ <linking:LinkingDirectives>
+ <func_directive:MethodDeclaration>
<tuning:TuningDirective*>
- <body:FunctionBody> => ast::Function{<>}
+ <body:FunctionBody> => {
+ (linking, ast::Function{func_directive, tuning, body})
+ }
};
LinkingDirective: ast::LinkingDirective = {
@@ -388,44 +396,50 @@ LinkingDirectives: ast::LinkingDirective = {
}
}
-MethodDecl: ast::MethodDecl<'input, &'input str> = {
- ".entry" <name:ExtendedID> <in_args:KernelArguments> =>
- ast::MethodDecl::Kernel{ name, in_args },
- ".func" <ret_vals:FnArguments?> <name:ExtendedID> <params:FnArguments> => {
- ast::MethodDecl::Func(ret_vals.unwrap_or_else(|| Vec::new()), name, params)
+MethodDeclaration: ast::MethodDeclaration<'input, &'input str> = {
+ ".entry" <name:ExtendedID> <input_arguments:KernelArguments> => {
+ let return_arguments = Vec::new();
+ let name = ast::MethodName::Kernel(name);
+ ast::MethodDeclaration{ return_arguments, name, input_arguments, shared_mem: None }
+ },
+ ".func" <return_arguments:FnArguments?> <name:ExtendedID> <input_arguments:FnArguments> => {
+ let return_arguments = return_arguments.unwrap_or_else(|| Vec::new());
+ let name = ast::MethodName::Func(name);
+ ast::MethodDeclaration{ return_arguments, name, input_arguments, shared_mem: None }
}
};
-KernelArguments: Vec<ast::KernelArgument<&'input str>> = {
+KernelArguments: Vec<ast::Variable<&'input str>> = {
"(" <args:Comma<KernelInput>> ")" => args
};
-FnArguments: Vec<ast::FnArgument<&'input str>> = {
+FnArguments: Vec<ast::Variable<&'input str>> = {
"(" <args:Comma<FnInput>> ")" => args
};
-KernelInput: ast::Variable<ast::KernelArgumentType, &'input str> = {
+KernelInput: ast::Variable<&'input str> = {
<v:ParamDeclaration> => {
let (align, v_type, name) = v;
ast::Variable {
align,
- v_type: ast::KernelArgumentType::Normal(v_type),
+ v_type,
+ state_space: ast::StateSpace::Param,
name,
array_init: Vec::new()
}
}
}
-FnInput: ast::Variable<ast::FnArgumentType, &'input str> = {
+FnInput: ast::Variable<&'input str> = {
<v:RegVariable> => {
let (align, v_type, name) = v;
- let v_type = ast::FnArgumentType::Reg(v_type);
- ast::Variable{ align, v_type, name, array_init: Vec::new() }
+ let state_space = ast::StateSpace::Reg;
+ ast::Variable{ align, v_type, state_space, name, array_init: Vec::new() }
},
<v:ParamDeclaration> => {
let (align, v_type, name) = v;
- let v_type = ast::FnArgumentType::Param(v_type);
- ast::Variable{ align, v_type, name, array_init: Vec::new() }
+ let state_space = ast::StateSpace::Param;
+ ast::Variable{ align, v_type, state_space, name, array_init: Vec::new() }
}
}
@@ -508,141 +522,148 @@ VariableParam: u32 = {
"<" <n:U32Num> ">" => n
}
-Variable: ast::Variable<ast::VariableType, &'input str> = {
+Variable: ast::Variable<&'input str> = {
<v:RegVariable> => {
let (align, v_type, name) = v;
- let v_type = ast::VariableType::Reg(v_type);
- ast::Variable {align, v_type, name, array_init: Vec::new()}
+ let state_space = ast::StateSpace::Reg;
+ ast::Variable {align, v_type, state_space, name, array_init: Vec::new()}
},
LocalVariable,
<v:ParamVariable> => {
let (align, array_init, v_type, name) = v;
- let v_type = ast::VariableType::Param(v_type);
- ast::Variable {align, v_type, name, array_init}
+ let state_space = ast::StateSpace::Param;
+ ast::Variable {align, v_type, state_space, name, array_init}
},
SharedVariable,
};
-RegVariable: (Option<u32>, ast::VariableRegType, &'input str) = {
+RegVariable: (Option<u32>, ast::Type, &'input str) = {
".reg" <var:VariableScalar<ScalarType>> => {
let (align, t, name) = var;
- let v_type = ast::VariableRegType::Scalar(t);
+ let v_type = ast::Type::Scalar(t);
(align, v_type, name)
},
".reg" <var:VariableVector<SizedScalarType>> => {
let (align, v_len, t, name) = var;
- let v_type = ast::VariableRegType::Vector(t, v_len);
+ let v_type = ast::Type::Vector(t, v_len);
(align, v_type, name)
}
}
-LocalVariable: ast::Variable<ast::VariableType, &'input str> = {
+LocalVariable: ast::Variable<&'input str> = {
".local" <var:VariableScalar<SizedScalarType>> => {
let (align, t, name) = var;
- let v_type = ast::VariableType::Local(ast::VariableLocalType::Scalar(t));
- ast::Variable { align, v_type, name, array_init: Vec::new() }
+ let v_type = ast::Type::Scalar(t);
+ let state_space = ast::StateSpace::Local;
+ ast::Variable { align, v_type, state_space, name, array_init: Vec::new() }
},
".local" <var:VariableVector<SizedScalarType>> => {
let (align, v_len, t, name) = var;
- let v_type = ast::VariableType::Local(ast::VariableLocalType::Vector(t, v_len));
- ast::Variable { align, v_type, name, array_init: Vec::new() }
+ let v_type = ast::Type::Vector(t, v_len);
+ let state_space = ast::StateSpace::Local;
+ ast::Variable { align, v_type, state_space, name, array_init: Vec::new() }
},
".local" <var:VariableArrayOrPointer<SizedScalarType>> =>? {
let (align, t, name, arr_or_ptr) = var;
+ let state_space = ast::StateSpace::Local;
let (v_type, array_init) = match arr_or_ptr {
ast::ArrayOrPointer::Array { dimensions, init } => {
- (ast::VariableLocalType::Array(t, dimensions), init)
+ (ast::Type::Array(t, dimensions), init)
}
ast::ArrayOrPointer::Pointer => {
return Err(ParseError::User { error: ast::PtxError::ZeroDimensionArray });
}
};
- Ok(ast::Variable { align, v_type: ast::VariableType::Local(v_type), name, array_init })
+ Ok(ast::Variable { align, v_type, state_space, name, array_init })
}
}
-SharedVariable: ast::Variable<ast::VariableType, &'input str> = {
+SharedVariable: ast::Variable<&'input str> = {
".shared" <var:VariableScalar<SizedScalarType>> => {
let (align, t, name) = var;
- let v_type = ast::VariableGlobalType::Scalar(t);
- ast::Variable { align, v_type: ast::VariableType::Shared(v_type), name, array_init: Vec::new() }
+ let state_space = ast::StateSpace::Shared;
+ let v_type = ast::Type::Scalar(t);
+ ast::Variable { align, v_type, state_space, name, array_init: Vec::new() }
},
".shared" <var:VariableVector<SizedScalarType>> => {
let (align, v_len, t, name) = var;
- let v_type = ast::VariableGlobalType::Vector(t, v_len);
- ast::Variable { align, v_type: ast::VariableType::Shared(v_type), name, array_init: Vec::new() }
+ let state_space = ast::StateSpace::Shared;
+ let v_type = ast::Type::Vector(t, v_len);
+ ast::Variable { align, v_type, state_space, name, array_init: Vec::new() }
},
".shared" <var:VariableArrayOrPointer<SizedScalarType>> =>? {
let (align, t, name, arr_or_ptr) = var;
+ let state_space = ast::StateSpace::Shared;
let (v_type, array_init) = match arr_or_ptr {
ast::ArrayOrPointer::Array { dimensions, init } => {
- (ast::VariableGlobalType::Array(t, dimensions), init)
+ (ast::Type::Array(t, dimensions), init)
}
ast::ArrayOrPointer::Pointer => {
return Err(ParseError::User { error: ast::PtxError::ZeroDimensionArray });
}
};
- Ok(ast::Variable { align, v_type: ast::VariableType::Shared(v_type), name, array_init })
+ Ok(ast::Variable { align, v_type, state_space, name, array_init })
}
}
-
-ModuleVariable: ast::Variable<ast::VariableType, &'input str> = {
- LinkingDirectives ".global" <def:GlobalVariableDefinitionNoArray> => {
+ModuleVariable: (ast::LinkingDirective, ast::Variable<&'input str>) = {
+ <linking:LinkingDirectives> ".global" <def:GlobalVariableDefinitionNoArray> => {
let (align, v_type, name, array_init) = def;
- ast::Variable { align, v_type: ast::VariableType::Global(v_type), name, array_init }
+ let state_space = ast::StateSpace::Global;
+ (linking, ast::Variable { align, v_type, state_space, name, array_init })
},
- LinkingDirectives ".shared" <def:GlobalVariableDefinitionNoArray> => {
+ <linking:LinkingDirectives> ".shared" <def:GlobalVariableDefinitionNoArray> => {
let (align, v_type, name, array_init) = def;
- ast::Variable { align, v_type: ast::VariableType::Shared(v_type), name, array_init: Vec::new() }
+ let state_space = ast::StateSpace::Shared;
+ (linking, ast::Variable { align, v_type, state_space, name, array_init: Vec::new() })
},
- <ldirs:LinkingDirectives> <space:Or<".global", ".shared">> <var:VariableArrayOrPointer<SizedScalarType>> =>? {
+ <linking:LinkingDirectives> <space:Or<".global", ".shared">> <var:VariableArrayOrPointer<SizedScalarType>> =>? {
let (align, t, name, arr_or_ptr) = var;
- let (v_type, array_init) = match arr_or_ptr {
+ let (v_type, state_space, array_init) = match arr_or_ptr {
ast::ArrayOrPointer::Array { dimensions, init } => {
if space == ".global" {
- (ast::VariableType::Global(ast::VariableGlobalType::Array(t, dimensions)), init)
+ (ast::Type::Array(t, dimensions), ast::StateSpace::Global, init)
} else {
- (ast::VariableType::Shared(ast::VariableGlobalType::Array(t, dimensions)), init)
+ (ast::Type::Array(t, dimensions), ast::StateSpace::Shared, init)
}
}
ast::ArrayOrPointer::Pointer => {
- if !ldirs.contains(ast::LinkingDirective::EXTERN) {
+ if !linking.contains(ast::LinkingDirective::EXTERN) {
return Err(ParseError::User { error: ast::PtxError::NonExternPointer });
}
if space == ".global" {
- (ast::VariableType::Global(ast::VariableGlobalType::Pointer(t, ast::PointerStateSpace::Global)), Vec::new())
+ (ast::Type::Array(t, Vec::new()), ast::StateSpace::Global, Vec::new())
} else {
- (ast::VariableType::Shared(ast::VariableGlobalType::Pointer(t, ast::PointerStateSpace::Shared)), Vec::new())
+ (ast::Type::Array(t, Vec::new()), ast::StateSpace::Shared, Vec::new())
}
}
};
- Ok(ast::Variable{ align, array_init, v_type, name })
+ Ok((linking, ast::Variable{ align, v_type, state_space, name, array_init }))
}
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameter-state-space
-ParamVariable: (Option<u32>, Vec<u8>, ast::VariableParamType, &'input str) = {
+ParamVariable: (Option<u32>, Vec<u8>, ast::Type, &'input str) = {
".param" <var:VariableScalar<LdStScalarType>> => {
let (align, t, name) = var;
- let v_type = ast::VariableParamType::Scalar(t);
+ let v_type = ast::Type::Scalar(t);
(align, Vec::new(), v_type, name)
},
".param" <var:VariableArrayOrPointer<SizedScalarType>> => {
let (align, t, name, arr_or_ptr) = var;
let (v_type, array_init) = match arr_or_ptr {
ast::ArrayOrPointer::Array { dimensions, init } => {
- (ast::VariableParamType::Array(t, dimensions), init)
+ (ast::Type::Array(t, dimensions), init)
}
ast::ArrayOrPointer::Pointer => {
- (ast::VariableParamType::Pointer(t, ast::PointerStateSpace::Param), Vec::new())
+ (ast::Type::Scalar(t), Vec::new())
}
};
(align, array_init, v_type, name)
}
}
-ParamDeclaration: (Option<u32>, ast::VariableParamType, &'input str) = {
+ParamDeclaration: (Option<u32>, ast::Type, &'input str) = {
<var:ParamVariable> =>? {
let (align, array_init, v_type, name) = var;
if array_init.len() > 0 {
@@ -653,56 +674,56 @@ ParamDeclaration: (Option<u32>, ast::VariableParamType, &'input str) = {
}
}
-GlobalVariableDefinitionNoArray: (Option<u32>, ast::VariableGlobalType, &'input str, Vec<u8>) = {
+GlobalVariableDefinitionNoArray: (Option<u32>, ast::Type, &'input str, Vec<u8>) = {
<scalar:VariableScalar<SizedScalarType>> => {
let (align, t, name) = scalar;
- let v_type = ast::VariableGlobalType::Scalar(t);
+ let v_type = ast::Type::Scalar(t);
(align, v_type, name, Vec::new())
},
<var:VariableVector<SizedScalarType>> => {
let (align, v_len, t, name) = var;
- let v_type = ast::VariableGlobalType::Vector(t, v_len);
+ let v_type = ast::Type::Vector(t, v_len);
(align, v_type, name, Vec::new())
},
}
#[inline]
-SizedScalarType: ast::SizedScalarType = {
- ".b8" => ast::SizedScalarType::B8,
- ".b16" => ast::SizedScalarType::B16,
- ".b32" => ast::SizedScalarType::B32,
- ".b64" => ast::SizedScalarType::B64,
- ".u8" => ast::SizedScalarType::U8,
- ".u16" => ast::SizedScalarType::U16,
- ".u32" => ast::SizedScalarType::U32,
- ".u64" => ast::SizedScalarType::U64,
- ".s8" => ast::SizedScalarType::S8,
- ".s16" => ast::SizedScalarType::S16,
- ".s32" => ast::SizedScalarType::S32,
- ".s64" => ast::SizedScalarType::S64,
- ".f16" => ast::SizedScalarType::F16,
- ".f16x2" => ast::SizedScalarType::F16x2,
- ".f32" => ast::SizedScalarType::F32,
- ".f64" => ast::SizedScalarType::F64,
+SizedScalarType: ast::ScalarType = {
+ ".b8" => ast::ScalarType::B8,
+ ".b16" => ast::ScalarType::B16,
+ ".b32" => ast::ScalarType::B32,
+ ".b64" => ast::ScalarType::B64,
+ ".u8" => ast::ScalarType::U8,
+ ".u16" => ast::ScalarType::U16,
+ ".u32" => ast::ScalarType::U32,
+ ".u64" => ast::ScalarType::U64,
+ ".s8" => ast::ScalarType::S8,
+ ".s16" => ast::ScalarType::S16,
+ ".s32" => ast::ScalarType::S32,
+ ".s64" => ast::ScalarType::S64,
+ ".f16" => ast::ScalarType::F16,
+ ".f16x2" => ast::ScalarType::F16x2,
+ ".f32" => ast::ScalarType::F32,
+ ".f64" => ast::ScalarType::F64,
}
#[inline]
-LdStScalarType: ast::LdStScalarType = {
- ".b8" => ast::LdStScalarType::B8,
- ".b16" => ast::LdStScalarType::B16,
- ".b32" => ast::LdStScalarType::B32,
- ".b64" => ast::LdStScalarType::B64,
- ".u8" => ast::LdStScalarType::U8,
- ".u16" => ast::LdStScalarType::U16,
- ".u32" => ast::LdStScalarType::U32,
- ".u64" => ast::LdStScalarType::U64,
- ".s8" => ast::LdStScalarType::S8,
- ".s16" => ast::LdStScalarType::S16,
- ".s32" => ast::LdStScalarType::S32,
- ".s64" => ast::LdStScalarType::S64,
- ".f16" => ast::LdStScalarType::F16,
- ".f32" => ast::LdStScalarType::F32,
- ".f64" => ast::LdStScalarType::F64,
+LdStScalarType: ast::ScalarType = {
+ ".b8" => ast::ScalarType::B8,
+ ".b16" => ast::ScalarType::B16,
+ ".b32" => ast::ScalarType::B32,
+ ".b64" => ast::ScalarType::B64,
+ ".u8" => ast::ScalarType::U8,
+ ".u16" => ast::ScalarType::U16,
+ ".u32" => ast::ScalarType::U32,
+ ".u64" => ast::ScalarType::U64,
+ ".s8" => ast::ScalarType::S8,
+ ".s16" => ast::ScalarType::S16,
+ ".s32" => ast::ScalarType::S32,
+ ".s64" => ast::ScalarType::S64,
+ ".f16" => ast::ScalarType::F16,
+ ".f32" => ast::ScalarType::F32,
+ ".f64" => ast::ScalarType::F64,
}
Instruction: ast::Instruction<ast::ParsedArgParams<'input>> = {
@@ -755,7 +776,7 @@ InstLd: ast::Instruction<ast::ParsedArgParams<'input>> = {
ast::Instruction::Ld(
ast::LdDetails {
qualifier: q.unwrap_or(ast::LdStQualifier::Weak),
- state_space: ss.unwrap_or(ast::LdStateSpace::Generic),
+ state_space: ss.unwrap_or(ast::StateSpace::Generic),
caching: cop.unwrap_or(ast::LdCacheOperator::Cached),
typ: t,
non_coherent: false
@@ -767,7 +788,7 @@ InstLd: ast::Instruction<ast::ParsedArgParams<'input>> = {
ast::Instruction::Ld(
ast::LdDetails {
qualifier: q.unwrap_or(ast::LdStQualifier::Weak),
- state_space: ast::LdStateSpace::Global,
+ state_space: ast::StateSpace::Global,
caching: cop.unwrap_or(ast::LdCacheOperator::Cached),
typ: t,
non_coherent: false
@@ -779,7 +800,7 @@ InstLd: ast::Instruction<ast::ParsedArgParams<'input>> = {
ast::Instruction::Ld(
ast::LdDetails {
qualifier: ast::LdStQualifier::Weak,
- state_space: ast::LdStateSpace::Global,
+ state_space: ast::StateSpace::Global,
caching: cop.unwrap_or(ast::LdCacheOperator::Cached),
typ: t,
non_coherent: true
@@ -789,9 +810,9 @@ InstLd: ast::Instruction<ast::ParsedArgParams<'input>> = {
}
};
-LdStType: ast::LdStType = {
- <v:VectorPrefix> <t:LdStScalarType> => ast::LdStType::Vector(t, v),
- <t:LdStScalarType> => ast::LdStType::Scalar(t),
+LdStType: ast::Type = {
+ <v:VectorPrefix> <t:LdStScalarType> => ast::Type::Vector(t, v),
+ <t:LdStScalarType> => ast::Type::Scalar(t),
}
LdStQualifier: ast::LdStQualifier = {
@@ -807,11 +828,11 @@ MemScope: ast::MemScope = {
".sys" => ast::MemScope::Sys
};
-LdNonGlobalStateSpace: ast::LdStateSpace = {
- ".const" => ast::LdStateSpace::Const,
- ".local" => ast::LdStateSpace::Local,
- ".param" => ast::LdStateSpace::Param,
- ".shared" => ast::LdStateSpace::Shared,
+LdNonGlobalStateSpace: ast::StateSpace = {
+ ".const" => ast::StateSpace::Const,
+ ".local" => ast::StateSpace::Local,
+ ".param" => ast::StateSpace::Param,
+ ".shared" => ast::StateSpace::Shared,
};
LdCacheOperator: ast::LdCacheOperator = {
@@ -899,39 +920,39 @@ RoundingModeInt : ast::RoundingMode = {
".rpi" => ast::RoundingMode::PositiveInf,
};
-IntType : ast::IntType = {
- ".u16" => ast::IntType::U16,
- ".u32" => ast::IntType::U32,
- ".u64" => ast::IntType::U64,
- ".s16" => ast::IntType::S16,
- ".s32" => ast::IntType::S32,
- ".s64" => ast::IntType::S64,
+IntType : ast::ScalarType = {
+ ".u16" => ast::ScalarType::U16,
+ ".u32" => ast::ScalarType::U32,
+ ".u64" => ast::ScalarType::U64,
+ ".s16" => ast::ScalarType::S16,
+ ".s32" => ast::ScalarType::S32,
+ ".s64" => ast::ScalarType::S64,
};
-IntType3264: ast::IntType = {
- ".u32" => ast::IntType::U32,
- ".u64" => ast::IntType::U64,
- ".s32" => ast::IntType::S32,
- ".s64" => ast::IntType::S64,
+IntType3264: ast::ScalarType = {
+ ".u32" => ast::ScalarType::U32,
+ ".u64" => ast::ScalarType::U64,
+ ".s32" => ast::ScalarType::S32,
+ ".s64" => ast::ScalarType::S64,
}
-UIntType: ast::UIntType = {
- ".u16" => ast::UIntType::U16,
- ".u32" => ast::UIntType::U32,
- ".u64" => ast::UIntType::U64,
+UIntType: ast::ScalarType = {
+ ".u16" => ast::ScalarType::U16,
+ ".u32" => ast::ScalarType::U32,
+ ".u64" => ast::ScalarType::U64,
};
-SIntType: ast::SIntType = {
- ".s16" => ast::SIntType::S16,
- ".s32" => ast::SIntType::S32,
- ".s64" => ast::SIntType::S64,
+SIntType: ast::ScalarType = {
+ ".s16" => ast::ScalarType::S16,
+ ".s32" => ast::ScalarType::S32,
+ ".s64" => ast::ScalarType::S64,
};
-FloatType: ast::FloatType = {
- ".f16" => ast::FloatType::F16,
- ".f16x2" => ast::FloatType::F16x2,
- ".f32" => ast::FloatType::F32,
- ".f64" => ast::FloatType::F64,
+FloatType: ast::ScalarType = {
+ ".f16" => ast::ScalarType::F16,
+ ".f16x2" => ast::ScalarType::F16x2,
+ ".f32" => ast::ScalarType::F32,
+ ".f64" => ast::ScalarType::F64,
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-add
@@ -1023,11 +1044,11 @@ InstNot: ast::Instruction<ast::ParsedArgParams<'input>> = {
"not" <t:BooleanType> <a:Arg2> => ast::Instruction::Not(t, a)
};
-BooleanType: ast::BooleanType = {
- ".pred" => ast::BooleanType::Pred,
- ".b16" => ast::BooleanType::B16,
- ".b32" => ast::BooleanType::B32,
- ".b64" => ast::BooleanType::B64,
+BooleanType: ast::ScalarType = {
+ ".pred" => ast::ScalarType::Pred,
+ ".b16" => ast::ScalarType::B16,
+ ".b32" => ast::ScalarType::B32,
+ ".b64" => ast::ScalarType::B64,
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-at
@@ -1080,8 +1101,8 @@ InstCvt: ast::Instruction<ast::ParsedArgParams<'input>> = {
rounding: r,
flush_to_zero: None,
saturate: s.is_some(),
- dst: ast::FloatType::F16,
- src: ast::FloatType::F16
+ dst: ast::ScalarType::F16,
+ src: ast::ScalarType::F16
}
), a)
},
@@ -1091,8 +1112,8 @@ InstCvt: ast::Instruction<ast::ParsedArgParams<'input>> = {
rounding: None,
flush_to_zero: Some(f.is_some()),
saturate: s.is_some(),
- dst: ast::FloatType::F32,
- src: ast::FloatType::F16
+ dst: ast::ScalarType::F32,
+ src: ast::ScalarType::F16
}
), a)
},
@@ -1102,8 +1123,8 @@ InstCvt: ast::Instruction<ast::ParsedArgParams<'input>> = {
rounding: None,
flush_to_zero: None,
saturate: s.is_some(),
- dst: ast::FloatType::F64,
- src: ast::FloatType::F16
+ dst: ast::ScalarType::F64,
+ src: ast::ScalarType::F16
}
), a)
},
@@ -1113,8 +1134,8 @@ InstCvt: ast::Instruction<ast::ParsedArgParams<'input>> = {
rounding: Some(r),
flush_to_zero: Some(f.is_some()),
saturate: s.is_some(),
- dst: ast::FloatType::F16,
- src: ast::FloatType::F32
+ dst: ast::ScalarType::F16,
+ src: ast::ScalarType::F32
}
), a)
},
@@ -1124,8 +1145,8 @@ InstCvt: ast::Instruction<ast::ParsedArgParams<'input>> = {
rounding: r,
flush_to_zero: Some(f.is_some()),
saturate: s.is_some(),
- dst: ast::FloatType::F32,
- src: ast::FloatType::F32
+ dst: ast::ScalarType::F32,
+ src: ast::ScalarType::F32
}
), a)
},
@@ -1135,8 +1156,8 @@ InstCvt: ast::Instruction<ast::ParsedArgParams<'input>> = {
rounding: None,
flush_to_zero: None,
saturate: s.is_some(),
- dst: ast::FloatType::F64,
- src: ast::FloatType::F32
+ dst: ast::ScalarType::F64,
+ src: ast::ScalarType::F32
}
), a)
},
@@ -1146,8 +1167,8 @@ InstCvt: ast::Instruction<ast::ParsedArgParams<'input>> = {
rounding: Some(r),
flush_to_zero: None,
saturate: s.is_some(),
- dst: ast::FloatType::F16,
- src: ast::FloatType::F64
+ dst: ast::ScalarType::F16,
+ src: ast::ScalarType::F64
}
), a)
},
@@ -1157,8 +1178,8 @@ InstCvt: ast::Instruction<ast::ParsedArgParams<'input>> = {
rounding: Some(r),
flush_to_zero: Some(s.is_some()),
saturate: s.is_some(),
- dst: ast::FloatType::F32,
- src: ast::FloatType::F64
+ dst: ast::ScalarType::F32,
+ src: ast::ScalarType::F64
}
), a)
},
@@ -1168,28 +1189,28 @@ InstCvt: ast::Instruction<ast::ParsedArgParams<'input>> = {
rounding: r,
flush_to_zero: None,
saturate: s.is_some(),
- dst: ast::FloatType::F64,
- src: ast::FloatType::F64
+ dst: ast::ScalarType::F64,
+ src: ast::ScalarType::F64
}
), a)
},
};
-CvtTypeInt: ast::IntType = {
- ".u8" => ast::IntType::U8,
- ".u16" => ast::IntType::U16,
- ".u32" => ast::IntType::U32,
- ".u64" => ast::IntType::U64,
- ".s8" => ast::IntType::S8,
- ".s16" => ast::IntType::S16,
- ".s32" => ast::IntType::S32,
- ".s64" => ast::IntType::S64,
+CvtTypeInt: ast::ScalarType = {
+ ".u8" => ast::ScalarType::U8,
+ ".u16" => ast::ScalarType::U16,
+ ".u32" => ast::ScalarType::U32,
+ ".u64" => ast::ScalarType::U64,
+ ".s8" => ast::ScalarType::S8,
+ ".s16" => ast::ScalarType::S16,
+ ".s32" => ast::ScalarType::S32,
+ ".s64" => ast::ScalarType::S64,
};
-CvtTypeFloat: ast::FloatType = {
- ".f16" => ast::FloatType::F16,
- ".f32" => ast::FloatType::F32,
- ".f64" => ast::FloatType::F64,
+CvtTypeFloat: ast::ScalarType = {
+ ".f16" => ast::ScalarType::F16,
+ ".f32" => ast::ScalarType::F32,
+ ".f64" => ast::ScalarType::F64,
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-shl
@@ -1197,10 +1218,10 @@ InstShl: ast::Instruction<ast::ParsedArgParams<'input>> = {
"shl" <t:ShlType> <a:Arg3> => ast::Instruction::Shl(t, a)
};
-ShlType: ast::ShlType = {
- ".b16" => ast::ShlType::B16,
- ".b32" => ast::ShlType::B32,
- ".b64" => ast::ShlType::B64,
+ShlType: ast::ScalarType = {
+ ".b16" => ast::ScalarType::B16,
+ ".b32" => ast::ScalarType::B32,
+ ".b64" => ast::ScalarType::B64,
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-shr
@@ -1208,16 +1229,16 @@ InstShr: ast::Instruction<ast::ParsedArgParams<'input>> = {
"shr" <t:ShrType> <a:Arg3> => ast::Instruction::Shr(t, a)
};
-ShrType: ast::ShrType = {
- ".b16" => ast::ShrType::B16,
- ".b32" => ast::ShrType::B32,
- ".b64" => ast::ShrType::B64,
- ".u16" => ast::ShrType::U16,
- ".u32" => ast::ShrType::U32,
- ".u64" => ast::ShrType::U64,
- ".s16" => ast::ShrType::S16,
- ".s32" => ast::ShrType::S32,
- ".s64" => ast::ShrType::S64,
+ShrType: ast::ScalarType = {
+ ".b16" => ast::ScalarType::B16,
+ ".b32" => ast::ScalarType::B32,
+ ".b64" => ast::ScalarType::B64,
+ ".u16" => ast::ScalarType::U16,
+ ".u32" => ast::ScalarType::U32,
+ ".u64" => ast::ScalarType::U64,
+ ".s16" => ast::ScalarType::S16,
+ ".s32" => ast::ScalarType::S32,
+ ".s64" => ast::ScalarType::S64,
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-st
@@ -1227,7 +1248,7 @@ InstSt: ast::Instruction<ast::ParsedArgParams<'input>> = {
ast::Instruction::St(
ast::StData {
qualifier: q.unwrap_or(ast::LdStQualifier::Weak),
- state_space: ss.unwrap_or(ast::StStateSpace::Generic),
+ state_space: ss.unwrap_or(ast::StateSpace::Generic),
caching: cop.unwrap_or(ast::StCacheOperator::Writeback),
typ: t
},
@@ -1241,11 +1262,11 @@ MemoryOperand: ast::Operand<&'input str> = {
"[" <o:Operand> "]" => o
}
-StStateSpace: ast::StStateSpace = {
- ".global" => ast::StStateSpace::Global,
- ".local" => ast::StStateSpace::Local,
- ".param" => ast::StStateSpace::Param,
- ".shared" => ast::StStateSpace::Shared,
+StStateSpace: ast::StateSpace = {
+ ".global" => ast::StateSpace::Global,
+ ".local" => ast::StateSpace::Local,
+ ".param" => ast::StateSpace::Param,
+ ".shared" => ast::StateSpace::Shared,
};
StCacheOperator: ast::StCacheOperator = {
@@ -1264,7 +1285,7 @@ InstRet: ast::Instruction<ast::ParsedArgParams<'input>> = {
InstCvta: ast::Instruction<ast::ParsedArgParams<'input>> = {
"cvta" <from:CvtaStateSpace> <s:CvtaSize> <a:Arg2> => {
ast::Instruction::Cvta(ast::CvtaDetails {
- to: ast::CvtaStateSpace::Generic,
+ to: ast::StateSpace::Generic,
from,
size: s
},
@@ -1273,18 +1294,18 @@ InstCvta: ast::Instruction<ast::ParsedArgParams<'input>> = {
"cvta" ".to" <to:CvtaStateSpace> <s:CvtaSize> <a:Arg2> => {
ast::Instruction::Cvta(ast::CvtaDetails {
to,
- from: ast::CvtaStateSpace::Generic,
+ from: ast::StateSpace::Generic,
size: s
},
a)
}
}
-CvtaStateSpace: ast::CvtaStateSpace = {
- ".const" => ast::CvtaStateSpace::Const,
- ".global" => ast::CvtaStateSpace::Global,
- ".local" => ast::CvtaStateSpace::Local,
- ".shared" => ast::CvtaStateSpace::Shared,
+CvtaStateSpace: ast::StateSpace = {
+ ".const" => ast::StateSpace::Const,
+ ".global" => ast::StateSpace::Global,
+ ".local" => ast::StateSpace::Local,
+ ".shared" => ast::StateSpace::Shared,
}
CvtaSize: ast::CvtaSize = {
@@ -1393,16 +1414,16 @@ MinMaxDetails: ast::MinMaxDetails = {
<t:UIntType> => ast::MinMaxDetails::Unsigned(t),
<t:SIntType> => ast::MinMaxDetails::Signed(t),
<ftz:".ftz"?> <nan:".NaN"?> ".f32" => ast::MinMaxDetails::Float(
- ast::MinMaxFloat{ flush_to_zero: Some(ftz.is_some()), nan: nan.is_some(), typ: ast::FloatType::F32 }
+ ast::MinMaxFloat{ flush_to_zero: Some(ftz.is_some()), nan: nan.is_some(), typ: ast::ScalarType::F32 }
),
".f64" => ast::MinMaxDetails::Float(
- ast::MinMaxFloat{ flush_to_zero: None, nan: false, typ: ast::FloatType::F64 }
+ ast::MinMaxFloat{ flush_to_zero: None, nan: false, typ: ast::ScalarType::F64 }
),
<ftz:".ftz"?> <nan:".NaN"?> ".f16" => ast::MinMaxDetails::Float(
- ast::MinMaxFloat{ flush_to_zero: Some(ftz.is_some()), nan: nan.is_some(), typ: ast::FloatType::F16 }
+ ast::MinMaxFloat{ flush_to_zero: Some(ftz.is_some()), nan: nan.is_some(), typ: ast::ScalarType::F16 }
),
<ftz:".ftz"?> <nan:".NaN"?> ".f16x2" => ast::MinMaxDetails::Float(
- ast::MinMaxFloat{ flush_to_zero: Some(ftz.is_some()), nan: nan.is_some(), typ: ast::FloatType::F16x2 }
+ ast::MinMaxFloat{ flush_to_zero: Some(ftz.is_some()), nan: nan.is_some(), typ: ast::ScalarType::F16x2 }
)
}
@@ -1411,18 +1432,18 @@ InstSelp: ast::Instruction<ast::ParsedArgParams<'input>> = {
"selp" <t:SelpType> <a:Arg4> => ast::Instruction::Selp(t, a),
};
-SelpType: ast::SelpType = {
- ".b16" => ast::SelpType::B16,
- ".b32" => ast::SelpType::B32,
- ".b64" => ast::SelpType::B64,
- ".u16" => ast::SelpType::U16,
- ".u32" => ast::SelpType::U32,
- ".u64" => ast::SelpType::U64,
- ".s16" => ast::SelpType::S16,
- ".s32" => ast::SelpType::S32,
- ".s64" => ast::SelpType::S64,
- ".f32" => ast::SelpType::F32,
- ".f64" => ast::SelpType::F64,
+SelpType: ast::ScalarType = {
+ ".b16" => ast::ScalarType::B16,
+ ".b32" => ast::ScalarType::B32,
+ ".b64" => ast::ScalarType::B64,
+ ".u16" => ast::ScalarType::U16,
+ ".u32" => ast::ScalarType::U32,
+ ".u64" => ast::ScalarType::U64,
+ ".s16" => ast::ScalarType::S16,
+ ".s32" => ast::ScalarType::S32,
+ ".s64" => ast::ScalarType::S64,
+ ".f32" => ast::ScalarType::F32,
+ ".f64" => ast::ScalarType::F64,
};
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-bar
@@ -1442,7 +1463,7 @@ InstAtom: ast::Instruction<ast::ParsedArgParams<'input>> = {
let details = ast::AtomDetails {
semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed),
scope: scope.unwrap_or(ast::MemScope::Gpu),
- space: space.unwrap_or(ast::AtomSpace::Generic),
+ space: space.unwrap_or(ast::StateSpace::Generic),
inner: ast::AtomInnerDetails::Bit { op, typ }
};
ast::Instruction::Atom(details,a)
@@ -1451,10 +1472,10 @@ InstAtom: ast::Instruction<ast::ParsedArgParams<'input>> = {
let details = ast::AtomDetails {
semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed),
scope: scope.unwrap_or(ast::MemScope::Gpu),
- space: space.unwrap_or(ast::AtomSpace::Generic),
+ space: space.unwrap_or(ast::StateSpace::Generic),
inner: ast::AtomInnerDetails::Unsigned {
op: ast::AtomUIntOp::Inc,
- typ: ast::UIntType::U32
+ typ: ast::ScalarType::U32
}
};
ast::Instruction::Atom(details,a)
@@ -1463,10 +1484,10 @@ InstAtom: ast::Instruction<ast::ParsedArgParams<'input>> = {
let details = ast::AtomDetails {
semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed),
scope: scope.unwrap_or(ast::MemScope::Gpu),
- space: space.unwrap_or(ast::AtomSpace::Generic),
+ space: space.unwrap_or(ast::StateSpace::Generic),
inner: ast::AtomInnerDetails::Unsigned {
op: ast::AtomUIntOp::Dec,
- typ: ast::UIntType::U32
+ typ: ast::ScalarType::U32
}
};
ast::Instruction::Atom(details,a)
@@ -1476,7 +1497,7 @@ InstAtom: ast::Instruction<ast::ParsedArgParams<'input>> = {
let details = ast::AtomDetails {
semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed),
scope: scope.unwrap_or(ast::MemScope::Gpu),
- space: space.unwrap_or(ast::AtomSpace::Generic),
+ space: space.unwrap_or(ast::StateSpace::Generic),
inner: ast::AtomInnerDetails::Float { op, typ }
};
ast::Instruction::Atom(details,a)
@@ -1485,7 +1506,7 @@ InstAtom: ast::Instruction<ast::ParsedArgParams<'input>> = {
let details = ast::AtomDetails {
semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed),
scope: scope.unwrap_or(ast::MemScope::Gpu),
- space: space.unwrap_or(ast::AtomSpace::Generic),
+ space: space.unwrap_or(ast::StateSpace::Generic),
inner: ast::AtomInnerDetails::Unsigned { op, typ }
};
ast::Instruction::Atom(details,a)
@@ -1494,7 +1515,7 @@ InstAtom: ast::Instruction<ast::ParsedArgParams<'input>> = {
let details = ast::AtomDetails {
semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed),
scope: scope.unwrap_or(ast::MemScope::Gpu),
- space: space.unwrap_or(ast::AtomSpace::Generic),
+ space: space.unwrap_or(ast::StateSpace::Generic),
inner: ast::AtomInnerDetails::Signed { op, typ }
};
ast::Instruction::Atom(details,a)
@@ -1506,7 +1527,7 @@ InstAtomCas: ast::Instruction<ast::ParsedArgParams<'input>> = {
let details = ast::AtomCasDetails {
semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed),
scope: scope.unwrap_or(ast::MemScope::Gpu),
- space: space.unwrap_or(ast::AtomSpace::Generic),
+ space: space.unwrap_or(ast::StateSpace::Generic),
typ,
};
ast::Instruction::AtomCas(details,a)
@@ -1520,9 +1541,9 @@ AtomSemantics: ast::AtomSemantics = {
".acq_rel" => ast::AtomSemantics::AcquireRelease
}
-AtomSpace: ast::AtomSpace = {
- ".global" => ast::AtomSpace::Global,
- ".shared" => ast::AtomSpace::Shared
+AtomSpace: ast::StateSpace = {
+ ".global" => ast::StateSpace::Global,
+ ".shared" => ast::StateSpace::Shared
}
AtomBitOp: ast::AtomBitOp = {
@@ -1544,19 +1565,19 @@ AtomSIntOp: ast::AtomSIntOp = {
".max" => ast::AtomSIntOp::Max,
}
-BitType: ast::BitType = {
- ".b32" => ast::BitType::B32,
- ".b64" => ast::BitType::B64,
+BitType: ast::ScalarType = {
+ ".b32" => ast::ScalarType::B32,
+ ".b64" => ast::ScalarType::B64,
}
-UIntType3264: ast::UIntType = {
- ".u32" => ast::UIntType::U32,
- ".u64" => ast::UIntType::U64,
+UIntType3264: ast::ScalarType = {
+ ".u32" => ast::ScalarType::U32,
+ ".u64" => ast::ScalarType::U64,
}
-SIntType3264: ast::SIntType = {
- ".s32" => ast::SIntType::S32,
- ".s64" => ast::SIntType::S64,
+SIntType3264: ast::ScalarType = {
+ ".s32" => ast::ScalarType::S32,
+ ".s64" => ast::ScalarType::S64,
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-div
@@ -1566,7 +1587,7 @@ InstDiv: ast::Instruction<ast::ParsedArgParams<'input>> = {
"div" <t:SIntType> <a:Arg3> => ast::Instruction::Div(ast::DivDetails::Signed(t), a),
"div" <kind:DivFloatKind> <ftz:".ftz"?> ".f32" <a:Arg3> => {
let inner = ast::DivFloatDetails {
- typ: ast::FloatType::F32,
+ typ: ast::ScalarType::F32,
flush_to_zero: Some(ftz.is_some()),
kind
};
@@ -1574,7 +1595,7 @@ InstDiv: ast::Instruction<ast::ParsedArgParams<'input>> = {
},
"div" <rnd:RoundingModeFloat> ".f64" <a:Arg3> => {
let inner = ast::DivFloatDetails {
- typ: ast::FloatType::F64,
+ typ: ast::ScalarType::F64,
flush_to_zero: None,
kind: ast::DivFloatKind::Rounding(rnd)
};
@@ -1592,7 +1613,7 @@ DivFloatKind: ast::DivFloatKind = {
InstSqrt: ast::Instruction<ast::ParsedArgParams<'input>> = {
"sqrt" ".approx" <ftz:".ftz"?> ".f32" <a:Arg2> => {
let details = ast::SqrtDetails {
- typ: ast::FloatType::F32,
+ typ: ast::ScalarType::F32,
flush_to_zero: Some(ftz.is_some()),
kind: ast::SqrtKind::Approx,
};
@@ -1600,7 +1621,7 @@ InstSqrt: ast::Instruction<ast::ParsedArgParams<'input>> = {
},
"sqrt" <rnd:RoundingModeFloat> <ftz:".ftz"?> ".f32" <a:Arg2> => {
let details = ast::SqrtDetails {
- typ: ast::FloatType::F32,
+ typ: ast::ScalarType::F32,
flush_to_zero: Some(ftz.is_some()),
kind: ast::SqrtKind::Rounding(rnd),
};
@@ -1608,7 +1629,7 @@ InstSqrt: ast::Instruction<ast::ParsedArgParams<'input>> = {
},
"sqrt" <rnd:RoundingModeFloat> ".f64" <a:Arg2> => {
let details = ast::SqrtDetails {
- typ: ast::FloatType::F64,
+ typ: ast::ScalarType::F64,
flush_to_zero: None,
kind: ast::SqrtKind::Rounding(rnd),
};
@@ -1621,14 +1642,14 @@ InstSqrt: ast::Instruction<ast::ParsedArgParams<'input>> = {
InstRsqrt: ast::Instruction<ast::ParsedArgParams<'input>> = {
"rsqrt" ".approx" <ftz:".ftz"?> ".f32" <a:Arg2> => {
let details = ast::RsqrtDetails {
- typ: ast::FloatType::F32,
+ typ: ast::ScalarType::F32,
flush_to_zero: ftz.is_some(),
};
ast::Instruction::Rsqrt(details, a)
},
"rsqrt" ".approx" <ftz:".ftz"?> ".f64" <a:Arg2> => {
let details = ast::RsqrtDetails {
- typ: ast::FloatType::F64,
+ typ: ast::ScalarType::F64,
flush_to_zero: ftz.is_some(),
};
ast::Instruction::Rsqrt(details, a)
@@ -1739,7 +1760,7 @@ ArithDetails: ast::ArithDetails = {
saturate: false,
}),
".sat" ".s32" => ast::ArithDetails::Signed(ast::ArithSInt {
- typ: ast::SIntType::S32,
+ typ: ast::ScalarType::S32,
saturate: true,
}),
<f:ArithFloat> => ast::ArithDetails::Float(f)
@@ -1747,25 +1768,25 @@ ArithDetails: ast::ArithDetails = {
ArithFloat: ast::ArithFloat = {
<rn:RoundingModeFloat?> <ftz:".ftz"?> <sat:".sat"?> ".f32" => ast::ArithFloat {
- typ: ast::FloatType::F32,
+ typ: ast::ScalarType::F32,
rounding: rn,
flush_to_zero: Some(ftz.is_some()),
saturate: sat.is_some(),
},
<rn:RoundingModeFloat?> ".f64" => ast::ArithFloat {
- typ: ast::FloatType::F64,
+ typ: ast::ScalarType::F64,
rounding: rn,
flush_to_zero: None,
saturate: false,
},
<rn:".rn"?> <ftz:".ftz"?> <sat:".sat"?> ".f16" => ast::ArithFloat {
- typ: ast::FloatType::F16,
+ typ: ast::ScalarType::F16,
rounding: rn.map(|_| ast::RoundingMode::NearestEven),
flush_to_zero: Some(ftz.is_some()),
saturate: sat.is_some(),
},
<rn:".rn"?> <ftz:".ftz"?> <sat:".sat"?> ".f16x2" => ast::ArithFloat {
- typ: ast::FloatType::F16x2,
+ typ: ast::ScalarType::F16x2,
rounding: rn.map(|_| ast::RoundingMode::NearestEven),
flush_to_zero: Some(ftz.is_some()),
saturate: sat.is_some(),
@@ -1774,25 +1795,25 @@ ArithFloat: ast::ArithFloat = {
ArithFloatMustRound: ast::ArithFloat = {
<rn:RoundingModeFloat> <ftz:".ftz"?> <sat:".sat"?> ".f32" => ast::ArithFloat {
- typ: ast::FloatType::F32,
+ typ: ast::ScalarType::F32,
rounding: Some(rn),
flush_to_zero: Some(ftz.is_some()),
saturate: sat.is_some(),
},
<rn:RoundingModeFloat> ".f64" => ast::ArithFloat {
- typ: ast::FloatType::F64,
+ typ: ast::ScalarType::F64,
rounding: Some(rn),
flush_to_zero: None,
saturate: false,
},
".rn" <ftz:".ftz"?> <sat:".sat"?> ".f16" => ast::ArithFloat {
- typ: ast::FloatType::F16,
+ typ: ast::ScalarType::F16,
rounding: Some(ast::RoundingMode::NearestEven),
flush_to_zero: Some(ftz.is_some()),
saturate: sat.is_some(),
},
".rn" <ftz:".ftz"?> <sat:".sat"?> ".f16x2" => ast::ArithFloat {
- typ: ast::FloatType::F16x2,
+ typ: ast::ScalarType::F16x2,
rounding: Some(ast::RoundingMode::NearestEven),
flush_to_zero: Some(ftz.is_some()),
saturate: sat.is_some(),
diff --git a/ptx/src/test/spirv_run/and.spvtxt b/ptx/src/test/spirv_run/and.spvtxt
index a378602..f66639a 100644
--- a/ptx/src/test/spirv_run/and.spvtxt
+++ b/ptx/src/test/spirv_run/and.spvtxt
@@ -18,6 +18,8 @@
%_ptr_Function_uint = OpTypePointer Function %uint
%_ptr_Generic_uint = OpTypePointer Generic %uint
%ulong_4 = OpConstant %ulong 4
+ %uchar = OpTypeInt 8 0
+%_ptr_Generic_uchar = OpTypePointer Generic %uchar
%1 = OpFunction %void None %34
%8 = OpFunctionParameter %ulong
%9 = OpFunctionParameter %ulong
@@ -39,9 +41,11 @@
%12 = OpLoad %uint %23 Aligned 4
OpStore %6 %12
%15 = OpLoad %ulong %4
- %22 = OpIAdd %ulong %15 %ulong_4
- %24 = OpConvertUToPtr %_ptr_Generic_uint %22
- %14 = OpLoad %uint %24 Aligned 4
+ %24 = OpConvertUToPtr %_ptr_Generic_uint %15
+ %41 = OpBitcast %_ptr_Generic_uchar %24
+ %42 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %41 %ulong_4
+ %22 = OpBitcast %_ptr_Generic_uint %42
+ %14 = OpLoad %uint %22 Aligned 4
OpStore %7 %14
%17 = OpLoad %uint %6
%18 = OpLoad %uint %7
diff --git a/ptx/src/test/spirv_run/atom_add.spvtxt b/ptx/src/test/spirv_run/atom_add.spvtxt
index 3966da6..b4de00a 100644
--- a/ptx/src/test/spirv_run/atom_add.spvtxt
+++ b/ptx/src/test/spirv_run/atom_add.spvtxt
@@ -24,6 +24,7 @@
%_ptr_Function_uint = OpTypePointer Function %uint
%_ptr_Generic_uint = OpTypePointer Generic %uint
%ulong_4 = OpConstant %ulong 4
+%_ptr_Generic_uchar = OpTypePointer Generic %uchar
%_ptr_Workgroup_uint = OpTypePointer Workgroup %uint
%uint_1 = OpConstant %uint 1
%uint_0 = OpConstant %uint 0
@@ -49,9 +50,11 @@
%13 = OpLoad %uint %29 Aligned 4
OpStore %7 %13
%16 = OpLoad %ulong %5
- %26 = OpIAdd %ulong %16 %ulong_4
- %30 = OpConvertUToPtr %_ptr_Generic_uint %26
- %15 = OpLoad %uint %30 Aligned 4
+ %30 = OpConvertUToPtr %_ptr_Generic_uint %16
+ %51 = OpBitcast %_ptr_Generic_uchar %30
+ %52 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %51 %ulong_4
+ %26 = OpBitcast %_ptr_Generic_uint %52
+ %15 = OpLoad %uint %26 Aligned 4
OpStore %8 %15
%17 = OpLoad %uint %7
%31 = OpBitcast %_ptr_Workgroup_uint %4
@@ -69,8 +72,10 @@
OpStore %34 %22 Aligned 4
%23 = OpLoad %ulong %6
%24 = OpLoad %uint %8
- %28 = OpIAdd %ulong %23 %ulong_4_0
- %35 = OpConvertUToPtr %_ptr_Generic_uint %28
- OpStore %35 %24 Aligned 4
+ %35 = OpConvertUToPtr %_ptr_Generic_uint %23
+ %56 = OpBitcast %_ptr_Generic_uchar %35
+ %57 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %56 %ulong_4_0
+ %28 = OpBitcast %_ptr_Generic_uint %57
+ OpStore %28 %24 Aligned 4
OpReturn
OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/atom_add_float.spvtxt b/ptx/src/test/spirv_run/atom_add_float.spvtxt
index c2292f1..7d25632 100644
--- a/ptx/src/test/spirv_run/atom_add_float.spvtxt
+++ b/ptx/src/test/spirv_run/atom_add_float.spvtxt
@@ -28,6 +28,7 @@
%_ptr_Function_float = OpTypePointer Function %float
%_ptr_Generic_float = OpTypePointer Generic %float
%ulong_4 = OpConstant %ulong 4
+%_ptr_Generic_uchar = OpTypePointer Generic %uchar
%ulong_4_0 = OpConstant %ulong 4
%37 = OpFunction %float None %46
%39 = OpFunctionParameter %_ptr_Workgroup_float
@@ -54,9 +55,11 @@
%13 = OpLoad %float %29 Aligned 4
OpStore %7 %13
%16 = OpLoad %ulong %5
- %26 = OpIAdd %ulong %16 %ulong_4
- %30 = OpConvertUToPtr %_ptr_Generic_float %26
- %15 = OpLoad %float %30 Aligned 4
+ %30 = OpConvertUToPtr %_ptr_Generic_float %16
+ %58 = OpBitcast %_ptr_Generic_uchar %30
+ %59 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %58 %ulong_4
+ %26 = OpBitcast %_ptr_Generic_float %59
+ %15 = OpLoad %float %26 Aligned 4
OpStore %8 %15
%17 = OpLoad %float %7
%31 = OpBitcast %_ptr_Workgroup_float %4
@@ -74,8 +77,10 @@
OpStore %34 %22 Aligned 4
%23 = OpLoad %ulong %6
%24 = OpLoad %float %8
- %28 = OpIAdd %ulong %23 %ulong_4_0
- %35 = OpConvertUToPtr %_ptr_Generic_float %28
- OpStore %35 %24 Aligned 4
+ %35 = OpConvertUToPtr %_ptr_Generic_float %23
+ %60 = OpBitcast %_ptr_Generic_uchar %35
+ %61 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %60 %ulong_4_0
+ %28 = OpBitcast %_ptr_Generic_float %61
+ OpStore %28 %24 Aligned 4
OpReturn
OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/atom_cas.spvtxt b/ptx/src/test/spirv_run/atom_cas.spvtxt
index e1feb0a..7c2f4fa 100644
--- a/ptx/src/test/spirv_run/atom_cas.spvtxt
+++ b/ptx/src/test/spirv_run/atom_cas.spvtxt
@@ -18,6 +18,8 @@
%_ptr_Function_uint = OpTypePointer Function %uint
%_ptr_Generic_uint = OpTypePointer Generic %uint
%ulong_4 = OpConstant %ulong 4
+ %uchar = OpTypeInt 8 0
+%_ptr_Generic_uchar = OpTypePointer Generic %uchar
%uint_100 = OpConstant %uint 100
%uint_1 = OpConstant %uint 1
%uint_0 = OpConstant %uint 0
@@ -45,16 +47,20 @@
OpStore %6 %12
%15 = OpLoad %ulong %4
%16 = OpLoad %uint %6
- %24 = OpIAdd %ulong %15 %ulong_4
- %32 = OpConvertUToPtr %_ptr_Generic_uint %24
+ %31 = OpConvertUToPtr %_ptr_Generic_uint %15
+ %49 = OpBitcast %_ptr_Generic_uchar %31
+ %50 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %49 %ulong_4
+ %24 = OpBitcast %_ptr_Generic_uint %50
%33 = OpCopyObject %uint %16
- %31 = OpAtomicCompareExchange %uint %32 %uint_1 %uint_0 %uint_0 %uint_100 %33
- %14 = OpCopyObject %uint %31
+ %32 = OpAtomicCompareExchange %uint %24 %uint_1 %uint_0 %uint_0 %uint_100 %33
+ %14 = OpCopyObject %uint %32
OpStore %6 %14
%18 = OpLoad %ulong %4
- %27 = OpIAdd %ulong %18 %ulong_4_0
- %34 = OpConvertUToPtr %_ptr_Generic_uint %27
- %17 = OpLoad %uint %34 Aligned 4
+ %34 = OpConvertUToPtr %_ptr_Generic_uint %18
+ %53 = OpBitcast %_ptr_Generic_uchar %34
+ %54 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %53 %ulong_4_0
+ %27 = OpBitcast %_ptr_Generic_uint %54
+ %17 = OpLoad %uint %27 Aligned 4
OpStore %7 %17
%19 = OpLoad %ulong %5
%20 = OpLoad %uint %6
@@ -62,8 +68,10 @@
OpStore %35 %20 Aligned 4
%21 = OpLoad %ulong %5
%22 = OpLoad %uint %7
- %29 = OpIAdd %ulong %21 %ulong_4_1
- %36 = OpConvertUToPtr %_ptr_Generic_uint %29
- OpStore %36 %22 Aligned 4
+ %36 = OpConvertUToPtr %_ptr_Generic_uint %21
+ %55 = OpBitcast %_ptr_Generic_uchar %36
+ %56 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %55 %ulong_4_1
+ %29 = OpBitcast %_ptr_Generic_uint %56
+ OpStore %29 %22 Aligned 4
OpReturn
OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/atom_inc.spvtxt b/ptx/src/test/spirv_run/atom_inc.spvtxt
index 11b4243..4855cd4 100644
--- a/ptx/src/test/spirv_run/atom_inc.spvtxt
+++ b/ptx/src/test/spirv_run/atom_inc.spvtxt
@@ -10,14 +10,14 @@
%47 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %1 "atom_inc"
- OpDecorate %42 LinkageAttributes "__zluda_ptx_impl__atom_relaxed_gpu_global_inc" Import
OpDecorate %38 LinkageAttributes "__zluda_ptx_impl__atom_relaxed_gpu_generic_inc" Import
+ OpDecorate %42 LinkageAttributes "__zluda_ptx_impl__atom_relaxed_gpu_global_inc" Import
%void = OpTypeVoid
%uint = OpTypeInt 32 0
-%_ptr_CrossWorkgroup_uint = OpTypePointer CrossWorkgroup %uint
- %51 = OpTypeFunction %uint %_ptr_CrossWorkgroup_uint %uint
%_ptr_Generic_uint = OpTypePointer Generic %uint
- %53 = OpTypeFunction %uint %_ptr_Generic_uint %uint
+ %51 = OpTypeFunction %uint %_ptr_Generic_uint %uint
+%_ptr_CrossWorkgroup_uint = OpTypePointer CrossWorkgroup %uint
+ %53 = OpTypeFunction %uint %_ptr_CrossWorkgroup_uint %uint
%ulong = OpTypeInt 64 0
%55 = OpTypeFunction %void %ulong %ulong
%_ptr_Function_ulong = OpTypePointer Function %ulong
@@ -25,15 +25,17 @@
%uint_101 = OpConstant %uint 101
%uint_101_0 = OpConstant %uint 101
%ulong_4 = OpConstant %ulong 4
+ %uchar = OpTypeInt 8 0
+%_ptr_Generic_uchar = OpTypePointer Generic %uchar
%ulong_8 = OpConstant %ulong 8
- %42 = OpFunction %uint None %51
- %44 = OpFunctionParameter %_ptr_CrossWorkgroup_uint
- %45 = OpFunctionParameter %uint
- OpFunctionEnd
- %38 = OpFunction %uint None %53
+ %38 = OpFunction %uint None %51
%40 = OpFunctionParameter %_ptr_Generic_uint
%41 = OpFunctionParameter %uint
OpFunctionEnd
+ %42 = OpFunction %uint None %53
+ %44 = OpFunctionParameter %_ptr_CrossWorkgroup_uint
+ %45 = OpFunctionParameter %uint
+ OpFunctionEnd
%1 = OpFunction %void None %55
%9 = OpFunctionParameter %ulong
%10 = OpFunctionParameter %ulong
@@ -69,13 +71,17 @@
OpStore %34 %20 Aligned 4
%21 = OpLoad %ulong %5
%22 = OpLoad %uint %7
- %28 = OpIAdd %ulong %21 %ulong_4
- %35 = OpConvertUToPtr %_ptr_Generic_uint %28
- OpStore %35 %22 Aligned 4
+ %35 = OpConvertUToPtr %_ptr_Generic_uint %21
+ %60 = OpBitcast %_ptr_Generic_uchar %35
+ %61 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %60 %ulong_4
+ %28 = OpBitcast %_ptr_Generic_uint %61
+ OpStore %28 %22 Aligned 4
%23 = OpLoad %ulong %5
%24 = OpLoad %uint %8
- %30 = OpIAdd %ulong %23 %ulong_8
- %36 = OpConvertUToPtr %_ptr_Generic_uint %30
- OpStore %36 %24 Aligned 4
+ %36 = OpConvertUToPtr %_ptr_Generic_uint %23
+ %62 = OpBitcast %_ptr_Generic_uchar %36
+ %63 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %62 %ulong_8
+ %30 = OpBitcast %_ptr_Generic_uint %63
+ OpStore %30 %24 Aligned 4
OpReturn
OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/bfe.spvtxt b/ptx/src/test/spirv_run/bfe.spvtxt
index 535ede9..0001808 100644
--- a/ptx/src/test/spirv_run/bfe.spvtxt
+++ b/ptx/src/test/spirv_run/bfe.spvtxt
@@ -20,6 +20,8 @@
%_ptr_Function_uint = OpTypePointer Function %uint
%_ptr_Generic_uint = OpTypePointer Generic %uint
%ulong_4 = OpConstant %ulong 4
+ %uchar = OpTypeInt 8 0
+%_ptr_Generic_uchar = OpTypePointer Generic %uchar
%ulong_8 = OpConstant %ulong 8
%34 = OpFunction %uint None %43
%36 = OpFunctionParameter %uint
@@ -48,14 +50,18 @@
%13 = OpLoad %uint %29 Aligned 4
OpStore %6 %13
%16 = OpLoad %ulong %4
- %26 = OpIAdd %ulong %16 %ulong_4
- %30 = OpConvertUToPtr %_ptr_Generic_uint %26
- %15 = OpLoad %uint %30 Aligned 4
+ %30 = OpConvertUToPtr %_ptr_Generic_uint %16
+ %51 = OpBitcast %_ptr_Generic_uchar %30
+ %52 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %51 %ulong_4
+ %26 = OpBitcast %_ptr_Generic_uint %52
+ %15 = OpLoad %uint %26 Aligned 4
OpStore %7 %15
%18 = OpLoad %ulong %4
- %28 = OpIAdd %ulong %18 %ulong_8
- %31 = OpConvertUToPtr %_ptr_Generic_uint %28
- %17 = OpLoad %uint %31 Aligned 4
+ %31 = OpConvertUToPtr %_ptr_Generic_uint %18
+ %53 = OpBitcast %_ptr_Generic_uchar %31
+ %54 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %53 %ulong_8
+ %28 = OpBitcast %_ptr_Generic_uint %54
+ %17 = OpLoad %uint %28 Aligned 4
OpStore %8 %17
%20 = OpLoad %uint %6
%21 = OpLoad %uint %7
diff --git a/ptx/src/test/spirv_run/bfi.spvtxt b/ptx/src/test/spirv_run/bfi.spvtxt
index a226f78..1979939 100644
--- a/ptx/src/test/spirv_run/bfi.spvtxt
+++ b/ptx/src/test/spirv_run/bfi.spvtxt
@@ -20,6 +20,8 @@
%_ptr_Function_uint = OpTypePointer Function %uint
%_ptr_Generic_uint = OpTypePointer Generic %uint
%ulong_4 = OpConstant %ulong 4
+ %uchar = OpTypeInt 8 0
+%_ptr_Generic_uchar = OpTypePointer Generic %uchar
%ulong_8 = OpConstant %ulong 8
%ulong_12 = OpConstant %ulong 12
%44 = OpFunction %uint None %54
@@ -51,19 +53,25 @@
%14 = OpLoad %uint %35 Aligned 4
OpStore %6 %14
%17 = OpLoad %ulong %4
- %30 = OpIAdd %ulong %17 %ulong_4
- %36 = OpConvertUToPtr %_ptr_Generic_uint %30
- %16 = OpLoad %uint %36 Aligned 4
+ %36 = OpConvertUToPtr %_ptr_Generic_uint %17
+ %62 = OpBitcast %_ptr_Generic_uchar %36
+ %63 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %62 %ulong_4
+ %30 = OpBitcast %_ptr_Generic_uint %63
+ %16 = OpLoad %uint %30 Aligned 4
OpStore %7 %16
%19 = OpLoad %ulong %4
- %32 = OpIAdd %ulong %19 %ulong_8
- %37 = OpConvertUToPtr %_ptr_Generic_uint %32
- %18 = OpLoad %uint %37 Aligned 4
+ %37 = OpConvertUToPtr %_ptr_Generic_uint %19
+ %64 = OpBitcast %_ptr_Generic_uchar %37
+ %65 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %64 %ulong_8
+ %32 = OpBitcast %_ptr_Generic_uint %65
+ %18 = OpLoad %uint %32 Aligned 4
OpStore %8 %18
%21 = OpLoad %ulong %4
- %34 = OpIAdd %ulong %21 %ulong_12
- %38 = OpConvertUToPtr %_ptr_Generic_uint %34
- %20 = OpLoad %uint %38 Aligned 4
+ %38 = OpConvertUToPtr %_ptr_Generic_uint %21
+ %66 = OpBitcast %_ptr_Generic_uchar %38
+ %67 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %66 %ulong_12
+ %34 = OpBitcast %_ptr_Generic_uint %67
+ %20 = OpLoad %uint %34 Aligned 4
OpStore %9 %20
%23 = OpLoad %uint %6
%24 = OpLoad %uint %7
@@ -71,7 +79,7 @@
%26 = OpLoad %uint %9
%40 = OpCopyObject %uint %23
%41 = OpCopyObject %uint %24
- %39 = OpFunctionCall %uint %44 %41 %40 %25 %26
+ %39 = OpFunctionCall %uint %44 %40 %41 %25 %26
%22 = OpCopyObject %uint %39
OpStore %6 %22
%27 = OpLoad %ulong %5
diff --git a/ptx/src/test/spirv_run/call.spvtxt b/ptx/src/test/spirv_run/call.spvtxt
index 5473234..6929b1e 100644
--- a/ptx/src/test/spirv_run/call.spvtxt
+++ b/ptx/src/test/spirv_run/call.spvtxt
@@ -42,7 +42,7 @@
%23 = OpBitcast %_ptr_Function_ulong %10
%24 = OpCopyObject %ulong %18
OpStore %23 %24 Aligned 8
- %43 = OpFunctionCall %void %1 %11 %10
+ %43 = OpFunctionCall %void %1 %10 %11
%19 = OpLoad %ulong %11 Aligned 8
OpStore %9 %19
%20 = OpLoad %ulong %8
@@ -52,8 +52,8 @@
OpReturn
OpFunctionEnd
%1 = OpFunction %void None %44
- %27 = OpFunctionParameter %_ptr_Function_ulong
%28 = OpFunctionParameter %_ptr_Function_ulong
+ %27 = OpFunctionParameter %_ptr_Function_ulong
%35 = OpLabel
%29 = OpVariable %_ptr_Function_ulong Function
%30 = OpLoad %ulong %28 Aligned 8
diff --git a/ptx/src/test/spirv_run/cvt_rni.spvtxt b/ptx/src/test/spirv_run/cvt_rni.spvtxt
index 288a939..e10999c 100644
--- a/ptx/src/test/spirv_run/cvt_rni.spvtxt
+++ b/ptx/src/test/spirv_run/cvt_rni.spvtxt
@@ -18,6 +18,8 @@
%_ptr_Function_float = OpTypePointer Function %float
%_ptr_Generic_float = OpTypePointer Generic %float
%ulong_4 = OpConstant %ulong 4
+ %uchar = OpTypeInt 8 0
+%_ptr_Generic_uchar = OpTypePointer Generic %uchar
%ulong_4_0 = OpConstant %ulong 4
%1 = OpFunction %void None %37
%8 = OpFunctionParameter %ulong
@@ -40,9 +42,11 @@
%12 = OpLoad %float %28 Aligned 4
OpStore %6 %12
%15 = OpLoad %ulong %4
- %25 = OpIAdd %ulong %15 %ulong_4
- %29 = OpConvertUToPtr %_ptr_Generic_float %25
- %14 = OpLoad %float %29 Aligned 4
+ %29 = OpConvertUToPtr %_ptr_Generic_float %15
+ %44 = OpBitcast %_ptr_Generic_uchar %29
+ %45 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %44 %ulong_4
+ %25 = OpBitcast %_ptr_Generic_float %45
+ %14 = OpLoad %float %25 Aligned 4
OpStore %7 %14
%17 = OpLoad %float %6
%16 = OpExtInst %float %34 rint %17
@@ -56,8 +60,10 @@
OpStore %30 %21 Aligned 4
%22 = OpLoad %ulong %5
%23 = OpLoad %float %7
- %27 = OpIAdd %ulong %22 %ulong_4_0
- %31 = OpConvertUToPtr %_ptr_Generic_float %27
- OpStore %31 %23 Aligned 4
+ %31 = OpConvertUToPtr %_ptr_Generic_float %22
+ %46 = OpBitcast %_ptr_Generic_uchar %31
+ %47 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %46 %ulong_4_0
+ %27 = OpBitcast %_ptr_Generic_float %47
+ OpStore %27 %23 Aligned 4
OpReturn
OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/cvt_rzi.spvtxt b/ptx/src/test/spirv_run/cvt_rzi.spvtxt
index 68c12c6..7dda454 100644
--- a/ptx/src/test/spirv_run/cvt_rzi.spvtxt
+++ b/ptx/src/test/spirv_run/cvt_rzi.spvtxt
@@ -18,6 +18,8 @@
%_ptr_Function_float = OpTypePointer Function %float
%_ptr_Generic_float = OpTypePointer Generic %float
%ulong_4 = OpConstant %ulong 4
+ %uchar = OpTypeInt 8 0
+%_ptr_Generic_uchar = OpTypePointer Generic %uchar
%ulong_4_0 = OpConstant %ulong 4
%1 = OpFunction %void None %37
%8 = OpFunctionParameter %ulong
@@ -40,9 +42,11 @@
%12 = OpLoad %float %28 Aligned 4
OpStore %6 %12
%15 = OpLoad %ulong %4
- %25 = OpIAdd %ulong %15 %ulong_4
- %29 = OpConvertUToPtr %_ptr_Generic_float %25
- %14 = OpLoad %float %29 Aligned 4
+ %29 = OpConvertUToPtr %_ptr_Generic_float %15
+ %44 = OpBitcast %_ptr_Generic_uchar %29
+ %45 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %44 %ulong_4
+ %25 = OpBitcast %_ptr_Generic_float %45
+ %14 = OpLoad %float %25 Aligned 4
OpStore %7 %14
%17 = OpLoad %float %6
%16 = OpExtInst %float %34 trunc %17
@@ -56,8 +60,10 @@
OpStore %30 %21 Aligned 4
%22 = OpLoad %ulong %5
%23 = OpLoad %float %7
- %27 = OpIAdd %ulong %22 %ulong_4_0
- %31 = OpConvertUToPtr %_ptr_Generic_float %27
- OpStore %31 %23 Aligned 4
+ %31 = OpConvertUToPtr %_ptr_Generic_float %22
+ %46 = OpBitcast %_ptr_Generic_uchar %31
+ %47 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %46 %ulong_4_0
+ %27 = OpBitcast %_ptr_Generic_float %47
+ OpStore %27 %23 Aligned 4
OpReturn
OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/cvt_s32_f32.spvtxt b/ptx/src/test/spirv_run/cvt_s32_f32.spvtxt
index d9ae053..c1229d4 100644
--- a/ptx/src/test/spirv_run/cvt_s32_f32.spvtxt
+++ b/ptx/src/test/spirv_run/cvt_s32_f32.spvtxt
@@ -21,8 +21,11 @@
%float = OpTypeFloat 32
%_ptr_Generic_float = OpTypePointer Generic %float
%ulong_4 = OpConstant %ulong 4
+ %uchar = OpTypeInt 8 0
+%_ptr_Generic_uchar = OpTypePointer Generic %uchar
%_ptr_CrossWorkgroup_uint = OpTypePointer CrossWorkgroup %uint
%ulong_4_0 = OpConstant %ulong 4
+%_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar
%1 = OpFunction %void None %45
%8 = OpFunctionParameter %ulong
%9 = OpFunctionParameter %ulong
@@ -45,10 +48,12 @@
%12 = OpBitcast %uint %28
OpStore %6 %12
%15 = OpLoad %ulong %4
- %25 = OpIAdd %ulong %15 %ulong_4
- %31 = OpConvertUToPtr %_ptr_Generic_float %25
- %30 = OpLoad %float %31 Aligned 4
- %14 = OpBitcast %uint %30
+ %30 = OpConvertUToPtr %_ptr_Generic_float %15
+ %53 = OpBitcast %_ptr_Generic_uchar %30
+ %54 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %53 %ulong_4
+ %25 = OpBitcast %_ptr_Generic_float %54
+ %31 = OpLoad %float %25 Aligned 4
+ %14 = OpBitcast %uint %31
OpStore %7 %14
%17 = OpLoad %uint %6
%33 = OpBitcast %float %17
@@ -67,9 +72,11 @@
OpStore %36 %37 Aligned 4
%22 = OpLoad %ulong %5
%23 = OpLoad %uint %7
- %27 = OpIAdd %ulong %22 %ulong_4_0
- %38 = OpConvertUToPtr %_ptr_CrossWorkgroup_uint %27
+ %38 = OpConvertUToPtr %_ptr_CrossWorkgroup_uint %22
+ %57 = OpBitcast %_ptr_CrossWorkgroup_uchar %38
+ %58 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %57 %ulong_4_0
+ %27 = OpBitcast %_ptr_CrossWorkgroup_uint %58
%39 = OpCopyObject %uint %23
- OpStore %38 %39 Aligned 4
+ OpStore %27 %39 Aligned 4
OpReturn
OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/div_approx.spvtxt b/ptx/src/test/spirv_run/div_approx.spvtxt
index 274f73e..858ec8d 100644
--- a/ptx/src/test/spirv_run/div_approx.spvtxt
+++ b/ptx/src/test/spirv_run/div_approx.spvtxt
@@ -19,6 +19,8 @@
%_ptr_Function_float = OpTypePointer Function %float
%_ptr_Generic_float = OpTypePointer Generic %float
%ulong_4 = OpConstant %ulong 4
+ %uchar = OpTypeInt 8 0
+%_ptr_Generic_uchar = OpTypePointer Generic %uchar
%1 = OpFunction %void None %31
%8 = OpFunctionParameter %ulong
%9 = OpFunctionParameter %ulong
@@ -40,9 +42,11 @@
%12 = OpLoad %float %23 Aligned 4
OpStore %6 %12
%15 = OpLoad %ulong %4
- %22 = OpIAdd %ulong %15 %ulong_4
- %24 = OpConvertUToPtr %_ptr_Generic_float %22
- %14 = OpLoad %float %24 Aligned 4
+ %24 = OpConvertUToPtr %_ptr_Generic_float %15
+ %38 = OpBitcast %_ptr_Generic_uchar %24
+ %39 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %38 %ulong_4
+ %22 = OpBitcast %_ptr_Generic_float %39
+ %14 = OpLoad %float %22 Aligned 4
OpStore %7 %14
%17 = OpLoad %float %6
%18 = OpLoad %float %7
diff --git a/ptx/src/test/spirv_run/extern_shared.spvtxt b/ptx/src/test/spirv_run/extern_shared.spvtxt
index fb2987e..13587d5 100644
--- a/ptx/src/test/spirv_run/extern_shared.spvtxt
+++ b/ptx/src/test/spirv_run/extern_shared.spvtxt
@@ -7,37 +7,30 @@
OpCapability Int64
OpCapability Float16
OpCapability Float64
- %30 = OpExtInstImport "OpenCL.std"
+ %27 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %2 "extern_shared" %1
%void = OpTypeVoid
%uint = OpTypeInt 32 0
%_ptr_Workgroup_uint = OpTypePointer Workgroup %uint
-%_ptr_Workgroup__ptr_Workgroup_uint = OpTypePointer Workgroup %_ptr_Workgroup_uint
- %1 = OpVariable %_ptr_Workgroup__ptr_Workgroup_uint Workgroup
+ %1 = OpVariable %_ptr_Workgroup_uint Workgroup
%ulong = OpTypeInt 64 0
%uchar = OpTypeInt 8 0
%_ptr_Workgroup_uchar = OpTypePointer Workgroup %uchar
- %38 = OpTypeFunction %void %ulong %ulong %_ptr_Workgroup_uchar
-%_ptr_Function__ptr_Workgroup_uchar = OpTypePointer Function %_ptr_Workgroup_uchar
+ %34 = OpTypeFunction %void %ulong %ulong %_ptr_Workgroup_uchar
%_ptr_Function_ulong = OpTypePointer Function %ulong
%_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong
-%_ptr_Function__ptr_Workgroup_uint = OpTypePointer Function %_ptr_Workgroup_uint
%_ptr_Workgroup_ulong = OpTypePointer Workgroup %ulong
- %2 = OpFunction %void None %38
+ %2 = OpFunction %void None %34
%8 = OpFunctionParameter %ulong
%9 = OpFunctionParameter %ulong
- %26 = OpFunctionParameter %_ptr_Workgroup_uchar
- %39 = OpLabel
- %27 = OpVariable %_ptr_Function__ptr_Workgroup_uchar Function
+ %24 = OpFunctionParameter %_ptr_Workgroup_uchar
+ %22 = OpLabel
%3 = OpVariable %_ptr_Function_ulong Function
%4 = OpVariable %_ptr_Function_ulong Function
%5 = OpVariable %_ptr_Function_ulong Function
%6 = OpVariable %_ptr_Function_ulong Function
%7 = OpVariable %_ptr_Function_ulong Function
- OpStore %27 %26
- OpBranch %24
- %24 = OpLabel
OpStore %3 %8
OpStore %4 %9
%10 = OpLoad %ulong %3 Aligned 8
@@ -45,22 +38,20 @@
%11 = OpLoad %ulong %4 Aligned 8
OpStore %6 %11
%13 = OpLoad %ulong %5
- %20 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %13
- %12 = OpLoad %ulong %20 Aligned 8
+ %18 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %13
+ %12 = OpLoad %ulong %18 Aligned 8
OpStore %7 %12
- %28 = OpBitcast %_ptr_Function__ptr_Workgroup_uint %27
- %14 = OpLoad %_ptr_Workgroup_uint %28
- %15 = OpLoad %ulong %7
- %21 = OpBitcast %_ptr_Workgroup_ulong %14
- OpStore %21 %15 Aligned 8
- %29 = OpBitcast %_ptr_Function__ptr_Workgroup_uint %27
- %17 = OpLoad %_ptr_Workgroup_uint %29
- %22 = OpBitcast %_ptr_Workgroup_ulong %17
- %16 = OpLoad %ulong %22 Aligned 8
- OpStore %7 %16
- %18 = OpLoad %ulong %6
- %19 = OpLoad %ulong %7
- %23 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %18
- OpStore %23 %19 Aligned 8
+ %14 = OpLoad %ulong %7
+ %25 = OpBitcast %_ptr_Workgroup_uint %24
+ %19 = OpBitcast %_ptr_Workgroup_ulong %25
+ OpStore %19 %14 Aligned 8
+ %26 = OpBitcast %_ptr_Workgroup_uint %24
+ %20 = OpBitcast %_ptr_Workgroup_ulong %26
+ %15 = OpLoad %ulong %20 Aligned 8
+ OpStore %7 %15
+ %16 = OpLoad %ulong %6
+ %17 = OpLoad %ulong %7
+ %21 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %16
+ OpStore %21 %17 Aligned 8
OpReturn
OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/extern_shared_call.spvtxt b/ptx/src/test/spirv_run/extern_shared_call.spvtxt
index 7043172..5af7168 100644
--- a/ptx/src/test/spirv_run/extern_shared_call.spvtxt
+++ b/ptx/src/test/spirv_run/extern_shared_call.spvtxt
@@ -7,87 +7,72 @@
OpCapability Int64
OpCapability Float16
OpCapability Float64
- %46 = OpExtInstImport "OpenCL.std"
+ %40 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
- OpEntryPoint Kernel %14 "extern_shared_call" %1
+ OpEntryPoint Kernel %12 "extern_shared_call" %1
OpDecorate %1 Alignment 4
%void = OpTypeVoid
%uint = OpTypeInt 32 0
%_ptr_Workgroup_uint = OpTypePointer Workgroup %uint
-%_ptr_Workgroup__ptr_Workgroup_uint = OpTypePointer Workgroup %_ptr_Workgroup_uint
- %1 = OpVariable %_ptr_Workgroup__ptr_Workgroup_uint Workgroup
+ %1 = OpVariable %_ptr_Workgroup_uint Workgroup
%uchar = OpTypeInt 8 0
%_ptr_Workgroup_uchar = OpTypePointer Workgroup %uchar
- %53 = OpTypeFunction %void %_ptr_Workgroup_uchar
-%_ptr_Function__ptr_Workgroup_uchar = OpTypePointer Function %_ptr_Workgroup_uchar
+ %46 = OpTypeFunction %void %_ptr_Workgroup_uchar
%ulong = OpTypeInt 64 0
%_ptr_Function_ulong = OpTypePointer Function %ulong
-%_ptr_Function__ptr_Workgroup_uint = OpTypePointer Function %_ptr_Workgroup_uint
%_ptr_Workgroup_ulong = OpTypePointer Workgroup %ulong
%ulong_2 = OpConstant %ulong 2
- %60 = OpTypeFunction %void %ulong %ulong %_ptr_Workgroup_uchar
+ %50 = OpTypeFunction %void %ulong %ulong %_ptr_Workgroup_uchar
%_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong
- %2 = OpFunction %void None %53
- %38 = OpFunctionParameter %_ptr_Workgroup_uchar
- %54 = OpLabel
- %39 = OpVariable %_ptr_Function__ptr_Workgroup_uchar Function
+ %2 = OpFunction %void None %46
+ %34 = OpFunctionParameter %_ptr_Workgroup_uchar
+ %11 = OpLabel
%3 = OpVariable %_ptr_Function_ulong Function
- OpStore %39 %38
- OpBranch %13
- %13 = OpLabel
- %40 = OpBitcast %_ptr_Function__ptr_Workgroup_uint %39
- %5 = OpLoad %_ptr_Workgroup_uint %40
- %11 = OpBitcast %_ptr_Workgroup_ulong %5
- %4 = OpLoad %ulong %11 Aligned 8
+ %35 = OpBitcast %_ptr_Workgroup_uint %34
+ %9 = OpBitcast %_ptr_Workgroup_ulong %35
+ %4 = OpLoad %ulong %9 Aligned 8
OpStore %3 %4
+ %6 = OpLoad %ulong %3
+ %5 = OpIAdd %ulong %6 %ulong_2
+ OpStore %3 %5
%7 = OpLoad %ulong %3
- %6 = OpIAdd %ulong %7 %ulong_2
- OpStore %3 %6
- %41 = OpBitcast %_ptr_Function__ptr_Workgroup_uint %39
- %8 = OpLoad %_ptr_Workgroup_uint %41
- %9 = OpLoad %ulong %3
- %12 = OpBitcast %_ptr_Workgroup_ulong %8
- OpStore %12 %9 Aligned 8
+ %36 = OpBitcast %_ptr_Workgroup_uint %34
+ %10 = OpBitcast %_ptr_Workgroup_ulong %36
+ OpStore %10 %7 Aligned 8
OpReturn
OpFunctionEnd
- %14 = OpFunction %void None %60
- %20 = OpFunctionParameter %ulong
- %21 = OpFunctionParameter %ulong
- %42 = OpFunctionParameter %_ptr_Workgroup_uchar
- %61 = OpLabel
- %43 = OpVariable %_ptr_Function__ptr_Workgroup_uchar Function
+ %12 = OpFunction %void None %50
+ %18 = OpFunctionParameter %ulong
+ %19 = OpFunctionParameter %ulong
+ %37 = OpFunctionParameter %_ptr_Workgroup_uchar
+ %32 = OpLabel
+ %13 = OpVariable %_ptr_Function_ulong Function
+ %14 = OpVariable %_ptr_Function_ulong Function
%15 = OpVariable %_ptr_Function_ulong Function
%16 = OpVariable %_ptr_Function_ulong Function
%17 = OpVariable %_ptr_Function_ulong Function
- %18 = OpVariable %_ptr_Function_ulong Function
- %19 = OpVariable %_ptr_Function_ulong Function
- OpStore %43 %42
- OpBranch %36
- %36 = OpLabel
+ OpStore %13 %18
+ OpStore %14 %19
+ %20 = OpLoad %ulong %13 Aligned 8
OpStore %15 %20
+ %21 = OpLoad %ulong %14 Aligned 8
OpStore %16 %21
- %22 = OpLoad %ulong %15 Aligned 8
+ %23 = OpLoad %ulong %15
+ %28 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %23
+ %22 = OpLoad %ulong %28 Aligned 8
OpStore %17 %22
- %23 = OpLoad %ulong %16 Aligned 8
- OpStore %18 %23
- %25 = OpLoad %ulong %17
- %32 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %25
- %24 = OpLoad %ulong %32 Aligned 8
- OpStore %19 %24
- %44 = OpBitcast %_ptr_Function__ptr_Workgroup_uint %43
- %26 = OpLoad %_ptr_Workgroup_uint %44
- %27 = OpLoad %ulong %19
- %33 = OpBitcast %_ptr_Workgroup_ulong %26
- OpStore %33 %27 Aligned 8
- %63 = OpFunctionCall %void %2 %42
- %45 = OpBitcast %_ptr_Function__ptr_Workgroup_uint %43
- %29 = OpLoad %_ptr_Workgroup_uint %45
- %34 = OpBitcast %_ptr_Workgroup_ulong %29
- %28 = OpLoad %ulong %34 Aligned 8
- OpStore %19 %28
- %30 = OpLoad %ulong %18
- %31 = OpLoad %ulong %19
- %35 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %30
- OpStore %35 %31 Aligned 8
+ %24 = OpLoad %ulong %17
+ %38 = OpBitcast %_ptr_Workgroup_uint %37
+ %29 = OpBitcast %_ptr_Workgroup_ulong %38
+ OpStore %29 %24 Aligned 8
+ %52 = OpFunctionCall %void %2 %37
+ %39 = OpBitcast %_ptr_Workgroup_uint %37
+ %30 = OpBitcast %_ptr_Workgroup_ulong %39
+ %25 = OpLoad %ulong %30 Aligned 8
+ OpStore %17 %25
+ %26 = OpLoad %ulong %16
+ %27 = OpLoad %ulong %17
+ %31 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %26
+ OpStore %31 %27 Aligned 8
OpReturn
OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/fma.spvtxt b/ptx/src/test/spirv_run/fma.spvtxt
index 300a328..8cc0e16 100644
--- a/ptx/src/test/spirv_run/fma.spvtxt
+++ b/ptx/src/test/spirv_run/fma.spvtxt
@@ -18,6 +18,8 @@
%_ptr_Function_float = OpTypePointer Function %float
%_ptr_Generic_float = OpTypePointer Generic %float
%ulong_4 = OpConstant %ulong 4
+ %uchar = OpTypeInt 8 0
+%_ptr_Generic_uchar = OpTypePointer Generic %uchar
%ulong_8 = OpConstant %ulong 8
%1 = OpFunction %void None %38
%9 = OpFunctionParameter %ulong
@@ -41,14 +43,18 @@
%13 = OpLoad %float %29 Aligned 4
OpStore %6 %13
%16 = OpLoad %ulong %4
- %26 = OpIAdd %ulong %16 %ulong_4
- %30 = OpConvertUToPtr %_ptr_Generic_float %26
- %15 = OpLoad %float %30 Aligned 4
+ %30 = OpConvertUToPtr %_ptr_Generic_float %16
+ %45 = OpBitcast %_ptr_Generic_uchar %30
+ %46 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %45 %ulong_4
+ %26 = OpBitcast %_ptr_Generic_float %46
+ %15 = OpLoad %float %26 Aligned 4
OpStore %7 %15
%18 = OpLoad %ulong %4
- %28 = OpIAdd %ulong %18 %ulong_8
- %31 = OpConvertUToPtr %_ptr_Generic_float %28
- %17 = OpLoad %float %31 Aligned 4
+ %31 = OpConvertUToPtr %_ptr_Generic_float %18
+ %47 = OpBitcast %_ptr_Generic_uchar %31
+ %48 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %47 %ulong_8
+ %28 = OpBitcast %_ptr_Generic_float %48
+ %17 = OpLoad %float %28 Aligned 4
OpStore %8 %17
%20 = OpLoad %float %6
%21 = OpLoad %float %7
diff --git a/ptx/src/test/spirv_run/ld_st_offset.spvtxt b/ptx/src/test/spirv_run/ld_st_offset.spvtxt
index 5e314a0..ea97222 100644
--- a/ptx/src/test/spirv_run/ld_st_offset.spvtxt
+++ b/ptx/src/test/spirv_run/ld_st_offset.spvtxt
@@ -18,6 +18,8 @@
%_ptr_Function_uint = OpTypePointer Function %uint
%_ptr_Generic_uint = OpTypePointer Generic %uint
%ulong_4 = OpConstant %ulong 4
+ %uchar = OpTypeInt 8 0
+%_ptr_Generic_uchar = OpTypePointer Generic %uchar
%ulong_4_0 = OpConstant %ulong 4
%1 = OpFunction %void None %33
%8 = OpFunctionParameter %ulong
@@ -40,9 +42,11 @@
%12 = OpLoad %uint %24 Aligned 4
OpStore %6 %12
%15 = OpLoad %ulong %4
- %21 = OpIAdd %ulong %15 %ulong_4
- %25 = OpConvertUToPtr %_ptr_Generic_uint %21
- %14 = OpLoad %uint %25 Aligned 4
+ %25 = OpConvertUToPtr %_ptr_Generic_uint %15
+ %40 = OpBitcast %_ptr_Generic_uchar %25
+ %41 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %40 %ulong_4
+ %21 = OpBitcast %_ptr_Generic_uint %41
+ %14 = OpLoad %uint %21 Aligned 4
OpStore %7 %14
%16 = OpLoad %ulong %5
%17 = OpLoad %uint %7
@@ -50,8 +54,10 @@
OpStore %26 %17 Aligned 4
%18 = OpLoad %ulong %5
%19 = OpLoad %uint %6
- %23 = OpIAdd %ulong %18 %ulong_4_0
- %27 = OpConvertUToPtr %_ptr_Generic_uint %23
- OpStore %27 %19 Aligned 4
+ %27 = OpConvertUToPtr %_ptr_Generic_uint %18
+ %42 = OpBitcast %_ptr_Generic_uchar %27
+ %43 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %42 %ulong_4_0
+ %23 = OpBitcast %_ptr_Generic_uint %43
+ OpStore %23 %19 Aligned 4
OpReturn
OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/mad_s32.spvtxt b/ptx/src/test/spirv_run/mad_s32.spvtxt
index bb44af0..0ee3ca7 100644
--- a/ptx/src/test/spirv_run/mad_s32.spvtxt
+++ b/ptx/src/test/spirv_run/mad_s32.spvtxt
@@ -18,6 +18,8 @@
%_ptr_Function_uint = OpTypePointer Function %uint
%_ptr_Generic_uint = OpTypePointer Generic %uint
%ulong_4 = OpConstant %ulong 4
+ %uchar = OpTypeInt 8 0
+%_ptr_Generic_uchar = OpTypePointer Generic %uchar
%ulong_8 = OpConstant %ulong 8
%ulong_4_0 = OpConstant %ulong 4
%ulong_8_0 = OpConstant %ulong 8
@@ -44,20 +46,24 @@
%14 = OpLoad %uint %38 Aligned 4
OpStore %7 %14
%17 = OpLoad %ulong %4
- %31 = OpIAdd %ulong %17 %ulong_4
- %39 = OpConvertUToPtr %_ptr_Generic_uint %31
- %16 = OpLoad %uint %39 Aligned 4
+ %39 = OpConvertUToPtr %_ptr_Generic_uint %17
+ %56 = OpBitcast %_ptr_Generic_uchar %39
+ %57 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %56 %ulong_4
+ %31 = OpBitcast %_ptr_Generic_uint %57
+ %16 = OpLoad %uint %31 Aligned 4
OpStore %8 %16
%19 = OpLoad %ulong %4
- %33 = OpIAdd %ulong %19 %ulong_8
- %40 = OpConvertUToPtr %_ptr_Generic_uint %33
- %18 = OpLoad %uint %40 Aligned 4
+ %40 = OpConvertUToPtr %_ptr_Generic_uint %19
+ %58 = OpBitcast %_ptr_Generic_uchar %40
+ %59 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %58 %ulong_8
+ %33 = OpBitcast %_ptr_Generic_uint %59
+ %18 = OpLoad %uint %33 Aligned 4
OpStore %9 %18
%21 = OpLoad %uint %7
%22 = OpLoad %uint %8
%23 = OpLoad %uint %9
- %54 = OpIMul %uint %21 %22
- %20 = OpIAdd %uint %23 %54
+ %60 = OpIMul %uint %21 %22
+ %20 = OpIAdd %uint %23 %60
OpStore %6 %20
%24 = OpLoad %ulong %5
%25 = OpLoad %uint %6
@@ -65,13 +71,17 @@
OpStore %41 %25 Aligned 4
%26 = OpLoad %ulong %5
%27 = OpLoad %uint %6
- %35 = OpIAdd %ulong %26 %ulong_4_0
- %42 = OpConvertUToPtr %_ptr_Generic_uint %35
- OpStore %42 %27 Aligned 4
+ %42 = OpConvertUToPtr %_ptr_Generic_uint %26
+ %61 = OpBitcast %_ptr_Generic_uchar %42
+ %62 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %61 %ulong_4_0
+ %35 = OpBitcast %_ptr_Generic_uint %62
+ OpStore %35 %27 Aligned 4
%28 = OpLoad %ulong %5
%29 = OpLoad %uint %6
- %37 = OpIAdd %ulong %28 %ulong_8_0
- %43 = OpConvertUToPtr %_ptr_Generic_uint %37
- OpStore %43 %29 Aligned 4
+ %43 = OpConvertUToPtr %_ptr_Generic_uint %28
+ %63 = OpBitcast %_ptr_Generic_uchar %43
+ %64 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %63 %ulong_8_0
+ %37 = OpBitcast %_ptr_Generic_uint %64
+ OpStore %37 %29 Aligned 4
OpReturn
OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/max.spvtxt b/ptx/src/test/spirv_run/max.spvtxt
index d3ffa2f..86b732a 100644
--- a/ptx/src/test/spirv_run/max.spvtxt
+++ b/ptx/src/test/spirv_run/max.spvtxt
@@ -18,6 +18,8 @@
%_ptr_Function_uint = OpTypePointer Function %uint
%_ptr_Generic_uint = OpTypePointer Generic %uint
%ulong_4 = OpConstant %ulong 4
+ %uchar = OpTypeInt 8 0
+%_ptr_Generic_uchar = OpTypePointer Generic %uchar
%1 = OpFunction %void None %31
%8 = OpFunctionParameter %ulong
%9 = OpFunctionParameter %ulong
@@ -39,9 +41,11 @@
%12 = OpLoad %uint %23 Aligned 4
OpStore %6 %12
%15 = OpLoad %ulong %4
- %22 = OpIAdd %ulong %15 %ulong_4
- %24 = OpConvertUToPtr %_ptr_Generic_uint %22
- %14 = OpLoad %uint %24 Aligned 4
+ %24 = OpConvertUToPtr %_ptr_Generic_uint %15
+ %38 = OpBitcast %_ptr_Generic_uchar %24
+ %39 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %38 %ulong_4
+ %22 = OpBitcast %_ptr_Generic_uint %39
+ %14 = OpLoad %uint %22 Aligned 4
OpStore %7 %14
%17 = OpLoad %uint %6
%18 = OpLoad %uint %7
diff --git a/ptx/src/test/spirv_run/min.spvtxt b/ptx/src/test/spirv_run/min.spvtxt
index de2e35e..a187376 100644
--- a/ptx/src/test/spirv_run/min.spvtxt
+++ b/ptx/src/test/spirv_run/min.spvtxt
@@ -18,6 +18,8 @@
%_ptr_Function_uint = OpTypePointer Function %uint
%_ptr_Generic_uint = OpTypePointer Generic %uint
%ulong_4 = OpConstant %ulong 4
+ %uchar = OpTypeInt 8 0
+%_ptr_Generic_uchar = OpTypePointer Generic %uchar
%1 = OpFunction %void None %31
%8 = OpFunctionParameter %ulong
%9 = OpFunctionParameter %ulong
@@ -39,9 +41,11 @@
%12 = OpLoad %uint %23 Aligned 4
OpStore %6 %12
%15 = OpLoad %ulong %4
- %22 = OpIAdd %ulong %15 %ulong_4
- %24 = OpConvertUToPtr %_ptr_Generic_uint %22
- %14 = OpLoad %uint %24 Aligned 4
+ %24 = OpConvertUToPtr %_ptr_Generic_uint %15
+ %38 = OpBitcast %_ptr_Generic_uchar %24
+ %39 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %38 %ulong_4
+ %22 = OpBitcast %_ptr_Generic_uint %39
+ %14 = OpLoad %uint %22 Aligned 4
OpStore %7 %14
%17 = OpLoad %uint %6
%18 = OpLoad %uint %7
diff --git a/ptx/src/test/spirv_run/mul_ftz.spvtxt b/ptx/src/test/spirv_run/mul_ftz.spvtxt
index ed268fb..e7a4a56 100644
--- a/ptx/src/test/spirv_run/mul_ftz.spvtxt
+++ b/ptx/src/test/spirv_run/mul_ftz.spvtxt
@@ -18,6 +18,8 @@
%_ptr_Function_float = OpTypePointer Function %float
%_ptr_Generic_float = OpTypePointer Generic %float
%ulong_4 = OpConstant %ulong 4
+ %uchar = OpTypeInt 8 0
+%_ptr_Generic_uchar = OpTypePointer Generic %uchar
%1 = OpFunction %void None %31
%8 = OpFunctionParameter %ulong
%9 = OpFunctionParameter %ulong
@@ -39,9 +41,11 @@
%12 = OpLoad %float %23 Aligned 4
OpStore %6 %12
%15 = OpLoad %ulong %4
- %22 = OpIAdd %ulong %15 %ulong_4
- %24 = OpConvertUToPtr %_ptr_Generic_float %22
- %14 = OpLoad %float %24 Aligned 4
+ %24 = OpConvertUToPtr %_ptr_Generic_float %15
+ %38 = OpBitcast %_ptr_Generic_uchar %24
+ %39 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %38 %ulong_4
+ %22 = OpBitcast %_ptr_Generic_float %39
+ %14 = OpLoad %float %22 Aligned 4
OpStore %7 %14
%17 = OpLoad %float %6
%18 = OpLoad %float %7
diff --git a/ptx/src/test/spirv_run/mul_non_ftz.spvtxt b/ptx/src/test/spirv_run/mul_non_ftz.spvtxt
index 436aca1..5326baa 100644
--- a/ptx/src/test/spirv_run/mul_non_ftz.spvtxt
+++ b/ptx/src/test/spirv_run/mul_non_ftz.spvtxt
@@ -18,6 +18,8 @@
%_ptr_Function_float = OpTypePointer Function %float
%_ptr_Generic_float = OpTypePointer Generic %float
%ulong_4 = OpConstant %ulong 4
+ %uchar = OpTypeInt 8 0
+%_ptr_Generic_uchar = OpTypePointer Generic %uchar
%1 = OpFunction %void None %31
%8 = OpFunctionParameter %ulong
%9 = OpFunctionParameter %ulong
@@ -39,9 +41,11 @@
%12 = OpLoad %float %23 Aligned 4
OpStore %6 %12
%15 = OpLoad %ulong %4
- %22 = OpIAdd %ulong %15 %ulong_4
- %24 = OpConvertUToPtr %_ptr_Generic_float %22
- %14 = OpLoad %float %24 Aligned 4
+ %24 = OpConvertUToPtr %_ptr_Generic_float %15
+ %38 = OpBitcast %_ptr_Generic_uchar %24
+ %39 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %38 %ulong_4
+ %22 = OpBitcast %_ptr_Generic_float %39
+ %14 = OpLoad %float %22 Aligned 4
OpStore %7 %14
%17 = OpLoad %float %6
%18 = OpLoad %float %7
diff --git a/ptx/src/test/spirv_run/mul_wide.spvtxt b/ptx/src/test/spirv_run/mul_wide.spvtxt
index 7ac81cf..e96a964 100644
--- a/ptx/src/test/spirv_run/mul_wide.spvtxt
+++ b/ptx/src/test/spirv_run/mul_wide.spvtxt
@@ -18,7 +18,9 @@
%_ptr_Function_uint = OpTypePointer Function %uint
%_ptr_CrossWorkgroup_uint = OpTypePointer CrossWorkgroup %uint
%ulong_4 = OpConstant %ulong 4
- %_struct_38 = OpTypeStruct %uint %uint
+ %uchar = OpTypeInt 8 0
+%_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar
+ %_struct_42 = OpTypeStruct %uint %uint
%v2uint = OpTypeVector %uint 2
%_ptr_Generic_ulong = OpTypePointer Generic %ulong
%1 = OpFunction %void None %33
@@ -43,17 +45,19 @@
%13 = OpLoad %uint %24 Aligned 4
OpStore %6 %13
%16 = OpLoad %ulong %4
- %23 = OpIAdd %ulong %16 %ulong_4
- %25 = OpConvertUToPtr %_ptr_CrossWorkgroup_uint %23
- %15 = OpLoad %uint %25 Aligned 4
+ %25 = OpConvertUToPtr %_ptr_CrossWorkgroup_uint %16
+ %40 = OpBitcast %_ptr_CrossWorkgroup_uchar %25
+ %41 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %40 %ulong_4
+ %23 = OpBitcast %_ptr_CrossWorkgroup_uint %41
+ %15 = OpLoad %uint %23 Aligned 4
OpStore %7 %15
%18 = OpLoad %uint %6
%19 = OpLoad %uint %7
- %39 = OpSMulExtended %_struct_38 %18 %19
- %40 = OpCompositeExtract %uint %39 0
- %41 = OpCompositeExtract %uint %39 1
- %43 = OpCompositeConstruct %v2uint %40 %41
- %17 = OpBitcast %ulong %43
+ %43 = OpSMulExtended %_struct_42 %18 %19
+ %44 = OpCompositeExtract %uint %43 0
+ %45 = OpCompositeExtract %uint %43 1
+ %47 = OpCompositeConstruct %v2uint %44 %45
+ %17 = OpBitcast %ulong %47
OpStore %8 %17
%20 = OpLoad %ulong %5
%21 = OpLoad %ulong %8
diff --git a/ptx/src/test/spirv_run/or.spvtxt b/ptx/src/test/spirv_run/or.spvtxt
index fef3f40..82db00c 100644
--- a/ptx/src/test/spirv_run/or.spvtxt
+++ b/ptx/src/test/spirv_run/or.spvtxt
@@ -16,6 +16,8 @@
%_ptr_Function_ulong = OpTypePointer Function %ulong
%_ptr_Generic_ulong = OpTypePointer Generic %ulong
%ulong_8 = OpConstant %ulong 8
+ %uchar = OpTypeInt 8 0
+%_ptr_Generic_uchar = OpTypePointer Generic %uchar
%1 = OpFunction %void None %34
%8 = OpFunctionParameter %ulong
%9 = OpFunctionParameter %ulong
@@ -37,9 +39,11 @@
%12 = OpLoad %ulong %23 Aligned 8
OpStore %6 %12
%15 = OpLoad %ulong %4
- %22 = OpIAdd %ulong %15 %ulong_8
- %24 = OpConvertUToPtr %_ptr_Generic_ulong %22
- %14 = OpLoad %ulong %24 Aligned 8
+ %24 = OpConvertUToPtr %_ptr_Generic_ulong %15
+ %39 = OpBitcast %_ptr_Generic_uchar %24
+ %40 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %39 %ulong_8
+ %22 = OpBitcast %_ptr_Generic_ulong %40
+ %14 = OpLoad %ulong %22 Aligned 8
OpStore %7 %14
%17 = OpLoad %ulong %6
%18 = OpLoad %ulong %7
diff --git a/ptx/src/test/spirv_run/pred_not.spvtxt b/ptx/src/test/spirv_run/pred_not.spvtxt
index 18fde05..644731b 100644
--- a/ptx/src/test/spirv_run/pred_not.spvtxt
+++ b/ptx/src/test/spirv_run/pred_not.spvtxt
@@ -18,6 +18,8 @@
%_ptr_Function_bool = OpTypePointer Function %bool
%_ptr_Generic_ulong = OpTypePointer Generic %ulong
%ulong_8 = OpConstant %ulong 8
+ %uchar = OpTypeInt 8 0
+%_ptr_Generic_uchar = OpTypePointer Generic %uchar
%true = OpConstantTrue %bool
%false = OpConstantFalse %bool
%ulong_1 = OpConstant %ulong 1
@@ -45,9 +47,11 @@
%18 = OpLoad %ulong %37 Aligned 8
OpStore %6 %18
%21 = OpLoad %ulong %4
- %34 = OpIAdd %ulong %21 %ulong_8
- %38 = OpConvertUToPtr %_ptr_Generic_ulong %34
- %20 = OpLoad %ulong %38 Aligned 8
+ %38 = OpConvertUToPtr %_ptr_Generic_ulong %21
+ %52 = OpBitcast %_ptr_Generic_uchar %38
+ %53 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %52 %ulong_8
+ %34 = OpBitcast %_ptr_Generic_ulong %53
+ %20 = OpLoad %ulong %34 Aligned 8
OpStore %7 %20
%23 = OpLoad %ulong %6
%24 = OpLoad %ulong %7
diff --git a/ptx/src/test/spirv_run/reg_local.spvtxt b/ptx/src/test/spirv_run/reg_local.spvtxt
index 7bb5bd9..a0b957a 100644
--- a/ptx/src/test/spirv_run/reg_local.spvtxt
+++ b/ptx/src/test/spirv_run/reg_local.spvtxt
@@ -26,6 +26,7 @@
%ulong_0 = OpConstant %ulong 0
%_ptr_Generic_uchar = OpTypePointer Generic %uchar
%ulong_0_0 = OpConstant %ulong 0
+%_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar
%1 = OpFunction %void None %37
%8 = OpFunctionParameter %ulong
%9 = OpFunctionParameter %ulong
@@ -48,10 +49,10 @@
%12 = OpCopyObject %ulong %24
OpStore %7 %12
%14 = OpLoad %ulong %7
- %26 = OpCopyObject %ulong %14
- %19 = OpIAdd %ulong %26 %ulong_1
- %27 = OpBitcast %_ptr_Generic_ulong %4
- OpStore %27 %19 Aligned 8
+ %19 = OpIAdd %ulong %14 %ulong_1
+ %26 = OpBitcast %_ptr_Generic_ulong %4
+ %27 = OpCopyObject %ulong %19
+ OpStore %26 %27 Aligned 8
%28 = OpBitcast %_ptr_Generic_ulong %4
%47 = OpBitcast %_ptr_Generic_uchar %28
%48 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %47 %ulong_0
@@ -61,9 +62,11 @@
OpStore %7 %15
%16 = OpLoad %ulong %6
%17 = OpLoad %ulong %7
- %23 = OpIAdd %ulong %16 %ulong_0_0
- %30 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %23
+ %30 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %16
+ %50 = OpBitcast %_ptr_CrossWorkgroup_uchar %30
+ %51 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %50 %ulong_0_0
+ %23 = OpBitcast %_ptr_CrossWorkgroup_ulong %51
%31 = OpCopyObject %ulong %17
- OpStore %30 %31 Aligned 8
+ OpStore %23 %31 Aligned 8
OpReturn
OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/rem.spvtxt b/ptx/src/test/spirv_run/rem.spvtxt
index ce1d3e6..2184523 100644
--- a/ptx/src/test/spirv_run/rem.spvtxt
+++ b/ptx/src/test/spirv_run/rem.spvtxt
@@ -18,6 +18,8 @@
%_ptr_Function_uint = OpTypePointer Function %uint
%_ptr_Generic_uint = OpTypePointer Generic %uint
%ulong_4 = OpConstant %ulong 4
+ %uchar = OpTypeInt 8 0
+%_ptr_Generic_uchar = OpTypePointer Generic %uchar
%1 = OpFunction %void None %31
%8 = OpFunctionParameter %ulong
%9 = OpFunctionParameter %ulong
@@ -39,9 +41,11 @@
%12 = OpLoad %uint %23 Aligned 4
OpStore %6 %12
%15 = OpLoad %ulong %4
- %22 = OpIAdd %ulong %15 %ulong_4
- %24 = OpConvertUToPtr %_ptr_Generic_uint %22
- %14 = OpLoad %uint %24 Aligned 4
+ %24 = OpConvertUToPtr %_ptr_Generic_uint %15
+ %38 = OpBitcast %_ptr_Generic_uchar %24
+ %39 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %38 %ulong_4
+ %22 = OpBitcast %_ptr_Generic_uint %39
+ %14 = OpLoad %uint %22 Aligned 4
OpStore %7 %14
%17 = OpLoad %uint %6
%18 = OpLoad %uint %7
diff --git a/ptx/src/test/spirv_run/selp.spvtxt b/ptx/src/test/spirv_run/selp.spvtxt
index 9798758..40c0bce 100644
--- a/ptx/src/test/spirv_run/selp.spvtxt
+++ b/ptx/src/test/spirv_run/selp.spvtxt
@@ -18,6 +18,8 @@
%_ptr_Function_ushort = OpTypePointer Function %ushort
%_ptr_Generic_ushort = OpTypePointer Generic %ushort
%ulong_2 = OpConstant %ulong 2
+ %uchar = OpTypeInt 8 0
+%_ptr_Generic_uchar = OpTypePointer Generic %uchar
%bool = OpTypeBool
%false = OpConstantFalse %bool
%1 = OpFunction %void None %32
@@ -41,9 +43,11 @@
%12 = OpLoad %ushort %24 Aligned 2
OpStore %6 %12
%15 = OpLoad %ulong %4
- %22 = OpIAdd %ulong %15 %ulong_2
- %25 = OpConvertUToPtr %_ptr_Generic_ushort %22
- %14 = OpLoad %ushort %25 Aligned 2
+ %25 = OpConvertUToPtr %_ptr_Generic_ushort %15
+ %39 = OpBitcast %_ptr_Generic_uchar %25
+ %40 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %39 %ulong_2
+ %22 = OpBitcast %_ptr_Generic_ushort %40
+ %14 = OpLoad %ushort %22 Aligned 2
OpStore %7 %14
%17 = OpLoad %ushort %6
%18 = OpLoad %ushort %7
diff --git a/ptx/src/test/spirv_run/selp_true.spvtxt b/ptx/src/test/spirv_run/selp_true.spvtxt
index f7038e0..81b3b5f 100644
--- a/ptx/src/test/spirv_run/selp_true.spvtxt
+++ b/ptx/src/test/spirv_run/selp_true.spvtxt
@@ -18,6 +18,8 @@
%_ptr_Function_ushort = OpTypePointer Function %ushort
%_ptr_Generic_ushort = OpTypePointer Generic %ushort
%ulong_2 = OpConstant %ulong 2
+ %uchar = OpTypeInt 8 0
+%_ptr_Generic_uchar = OpTypePointer Generic %uchar
%bool = OpTypeBool
%true = OpConstantTrue %bool
%1 = OpFunction %void None %32
@@ -41,9 +43,11 @@
%12 = OpLoad %ushort %24 Aligned 2
OpStore %6 %12
%15 = OpLoad %ulong %4
- %22 = OpIAdd %ulong %15 %ulong_2
- %25 = OpConvertUToPtr %_ptr_Generic_ushort %22
- %14 = OpLoad %ushort %25 Aligned 2
+ %25 = OpConvertUToPtr %_ptr_Generic_ushort %15
+ %39 = OpBitcast %_ptr_Generic_uchar %25
+ %40 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %39 %ulong_2
+ %22 = OpBitcast %_ptr_Generic_ushort %40
+ %14 = OpLoad %ushort %22 Aligned 2
OpStore %7 %14
%17 = OpLoad %ushort %6
%18 = OpLoad %ushort %7
diff --git a/ptx/src/test/spirv_run/setp.spvtxt b/ptx/src/test/spirv_run/setp.spvtxt
index c3129e3..5868881 100644
--- a/ptx/src/test/spirv_run/setp.spvtxt
+++ b/ptx/src/test/spirv_run/setp.spvtxt
@@ -18,6 +18,8 @@
%_ptr_Function_bool = OpTypePointer Function %bool
%_ptr_Generic_ulong = OpTypePointer Generic %ulong
%ulong_8 = OpConstant %ulong 8
+ %uchar = OpTypeInt 8 0
+%_ptr_Generic_uchar = OpTypePointer Generic %uchar
%ulong_1 = OpConstant %ulong 1
%ulong_2 = OpConstant %ulong 2
%1 = OpFunction %void None %43
@@ -43,9 +45,11 @@
%18 = OpLoad %ulong %35 Aligned 8
OpStore %6 %18
%21 = OpLoad %ulong %4
- %32 = OpIAdd %ulong %21 %ulong_8
- %36 = OpConvertUToPtr %_ptr_Generic_ulong %32
- %20 = OpLoad %ulong %36 Aligned 8
+ %36 = OpConvertUToPtr %_ptr_Generic_ulong %21
+ %50 = OpBitcast %_ptr_Generic_uchar %36
+ %51 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %50 %ulong_8
+ %32 = OpBitcast %_ptr_Generic_ulong %51
+ %20 = OpLoad %ulong %32 Aligned 8
OpStore %7 %20
%23 = OpLoad %ulong %6
%24 = OpLoad %ulong %7
diff --git a/ptx/src/test/spirv_run/setp_gt.spvtxt b/ptx/src/test/spirv_run/setp_gt.spvtxt
index 77f6546..e9783f5 100644
--- a/ptx/src/test/spirv_run/setp_gt.spvtxt
+++ b/ptx/src/test/spirv_run/setp_gt.spvtxt
@@ -20,6 +20,8 @@
%_ptr_Function_bool = OpTypePointer Function %bool
%_ptr_Generic_float = OpTypePointer Generic %float
%ulong_4 = OpConstant %ulong 4
+ %uchar = OpTypeInt 8 0
+%_ptr_Generic_uchar = OpTypePointer Generic %uchar
%1 = OpFunction %void None %43
%14 = OpFunctionParameter %ulong
%15 = OpFunctionParameter %ulong
@@ -43,9 +45,11 @@
%18 = OpLoad %float %35 Aligned 4
OpStore %6 %18
%21 = OpLoad %ulong %4
- %34 = OpIAdd %ulong %21 %ulong_4
- %36 = OpConvertUToPtr %_ptr_Generic_float %34
- %20 = OpLoad %float %36 Aligned 4
+ %36 = OpConvertUToPtr %_ptr_Generic_float %21
+ %52 = OpBitcast %_ptr_Generic_uchar %36
+ %53 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %52 %ulong_4
+ %34 = OpBitcast %_ptr_Generic_float %53
+ %20 = OpLoad %float %34 Aligned 4
OpStore %7 %20
%23 = OpLoad %float %6
%24 = OpLoad %float %7
diff --git a/ptx/src/test/spirv_run/setp_leu.spvtxt b/ptx/src/test/spirv_run/setp_leu.spvtxt
index f80880a..1d2d781 100644
--- a/ptx/src/test/spirv_run/setp_leu.spvtxt
+++ b/ptx/src/test/spirv_run/setp_leu.spvtxt
@@ -20,6 +20,8 @@
%_ptr_Function_bool = OpTypePointer Function %bool
%_ptr_Generic_float = OpTypePointer Generic %float
%ulong_4 = OpConstant %ulong 4
+ %uchar = OpTypeInt 8 0
+%_ptr_Generic_uchar = OpTypePointer Generic %uchar
%1 = OpFunction %void None %43
%14 = OpFunctionParameter %ulong
%15 = OpFunctionParameter %ulong
@@ -43,9 +45,11 @@
%18 = OpLoad %float %35 Aligned 4
OpStore %6 %18
%21 = OpLoad %ulong %4
- %34 = OpIAdd %ulong %21 %ulong_4
- %36 = OpConvertUToPtr %_ptr_Generic_float %34
- %20 = OpLoad %float %36 Aligned 4
+ %36 = OpConvertUToPtr %_ptr_Generic_float %21
+ %52 = OpBitcast %_ptr_Generic_uchar %36
+ %53 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %52 %ulong_4
+ %34 = OpBitcast %_ptr_Generic_float %53
+ %20 = OpLoad %float %34 Aligned 4
OpStore %7 %20
%23 = OpLoad %float %6
%24 = OpLoad %float %7
diff --git a/ptx/src/test/spirv_run/setp_nan.spvtxt b/ptx/src/test/spirv_run/setp_nan.spvtxt
index 4a9fe11..2ee333a 100644
--- a/ptx/src/test/spirv_run/setp_nan.spvtxt
+++ b/ptx/src/test/spirv_run/setp_nan.spvtxt
@@ -22,6 +22,8 @@
%_ptr_Function_bool = OpTypePointer Function %bool
%_ptr_Generic_float = OpTypePointer Generic %float
%ulong_4 = OpConstant %ulong 4
+ %uchar = OpTypeInt 8 0
+%_ptr_Generic_uchar = OpTypePointer Generic %uchar
%ulong_8 = OpConstant %ulong 8
%ulong_12 = OpConstant %ulong 12
%ulong_16 = OpConstant %ulong 16
@@ -69,45 +71,59 @@
%36 = OpLoad %float %116 Aligned 4
OpStore %6 %36
%39 = OpLoad %ulong %4
- %89 = OpIAdd %ulong %39 %ulong_4
- %117 = OpConvertUToPtr %_ptr_Generic_float %89
- %38 = OpLoad %float %117 Aligned 4
+ %117 = OpConvertUToPtr %_ptr_Generic_float %39
+ %144 = OpBitcast %_ptr_Generic_uchar %117
+ %145 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %144 %ulong_4
+ %89 = OpBitcast %_ptr_Generic_float %145
+ %38 = OpLoad %float %89 Aligned 4
OpStore %7 %38
%41 = OpLoad %ulong %4
- %91 = OpIAdd %ulong %41 %ulong_8
- %118 = OpConvertUToPtr %_ptr_Generic_float %91
- %40 = OpLoad %float %118 Aligned 4
+ %118 = OpConvertUToPtr %_ptr_Generic_float %41
+ %146 = OpBitcast %_ptr_Generic_uchar %118
+ %147 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %146 %ulong_8
+ %91 = OpBitcast %_ptr_Generic_float %147
+ %40 = OpLoad %float %91 Aligned 4
OpStore %8 %40
%43 = OpLoad %ulong %4
- %93 = OpIAdd %ulong %43 %ulong_12
- %119 = OpConvertUToPtr %_ptr_Generic_float %93
- %42 = OpLoad %float %119 Aligned 4
+ %119 = OpConvertUToPtr %_ptr_Generic_float %43
+ %148 = OpBitcast %_ptr_Generic_uchar %119
+ %149 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %148 %ulong_12
+ %93 = OpBitcast %_ptr_Generic_float %149
+ %42 = OpLoad %float %93 Aligned 4
OpStore %9 %42
%45 = OpLoad %ulong %4
- %95 = OpIAdd %ulong %45 %ulong_16
- %120 = OpConvertUToPtr %_ptr_Generic_float %95
- %44 = OpLoad %float %120 Aligned 4
+ %120 = OpConvertUToPtr %_ptr_Generic_float %45
+ %150 = OpBitcast %_ptr_Generic_uchar %120
+ %151 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %150 %ulong_16
+ %95 = OpBitcast %_ptr_Generic_float %151
+ %44 = OpLoad %float %95 Aligned 4
OpStore %10 %44
%47 = OpLoad %ulong %4
- %97 = OpIAdd %ulong %47 %ulong_20
- %121 = OpConvertUToPtr %_ptr_Generic_float %97
- %46 = OpLoad %float %121 Aligned 4
+ %121 = OpConvertUToPtr %_ptr_Generic_float %47
+ %152 = OpBitcast %_ptr_Generic_uchar %121
+ %153 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %152 %ulong_20
+ %97 = OpBitcast %_ptr_Generic_float %153
+ %46 = OpLoad %float %97 Aligned 4
OpStore %11 %46
%49 = OpLoad %ulong %4
- %99 = OpIAdd %ulong %49 %ulong_24
- %122 = OpConvertUToPtr %_ptr_Generic_float %99
- %48 = OpLoad %float %122 Aligned 4
+ %122 = OpConvertUToPtr %_ptr_Generic_float %49
+ %154 = OpBitcast %_ptr_Generic_uchar %122
+ %155 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %154 %ulong_24
+ %99 = OpBitcast %_ptr_Generic_float %155
+ %48 = OpLoad %float %99 Aligned 4
OpStore %12 %48
%51 = OpLoad %ulong %4
- %101 = OpIAdd %ulong %51 %ulong_28
- %123 = OpConvertUToPtr %_ptr_Generic_float %101
- %50 = OpLoad %float %123 Aligned 4
+ %123 = OpConvertUToPtr %_ptr_Generic_float %51
+ %156 = OpBitcast %_ptr_Generic_uchar %123
+ %157 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %156 %ulong_28
+ %101 = OpBitcast %_ptr_Generic_float %157
+ %50 = OpLoad %float %101 Aligned 4
OpStore %13 %50
%53 = OpLoad %float %6
%54 = OpLoad %float %7
- %142 = OpIsNan %bool %53
- %143 = OpIsNan %bool %54
- %52 = OpLogicalOr %bool %142 %143
+ %158 = OpIsNan %bool %53
+ %159 = OpIsNan %bool %54
+ %52 = OpLogicalOr %bool %158 %159
OpStore %15 %52
%55 = OpLoad %bool %15
OpBranchConditional %55 %16 %17
@@ -129,9 +145,9 @@
OpStore %124 %60 Aligned 4
%62 = OpLoad %float %8
%63 = OpLoad %float %9
- %145 = OpIsNan %bool %62
- %146 = OpIsNan %bool %63
- %61 = OpLogicalOr %bool %145 %146
+ %161 = OpIsNan %bool %62
+ %162 = OpIsNan %bool %63
+ %61 = OpLogicalOr %bool %161 %162
OpStore %15 %61
%64 = OpLoad %bool %15
OpBranchConditional %64 %20 %21
@@ -149,14 +165,16 @@
%23 = OpLabel
%68 = OpLoad %ulong %5
%69 = OpLoad %uint %14
- %107 = OpIAdd %ulong %68 %ulong_4_0
- %125 = OpConvertUToPtr %_ptr_Generic_uint %107
- OpStore %125 %69 Aligned 4
+ %125 = OpConvertUToPtr %_ptr_Generic_uint %68
+ %163 = OpBitcast %_ptr_Generic_uchar %125
+ %164 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %163 %ulong_4_0
+ %107 = OpBitcast %_ptr_Generic_uint %164
+ OpStore %107 %69 Aligned 4
%71 = OpLoad %float %10
%72 = OpLoad %float %11
- %147 = OpIsNan %bool %71
- %148 = OpIsNan %bool %72
- %70 = OpLogicalOr %bool %147 %148
+ %165 = OpIsNan %bool %71
+ %166 = OpIsNan %bool %72
+ %70 = OpLogicalOr %bool %165 %166
OpStore %15 %70
%73 = OpLoad %bool %15
OpBranchConditional %73 %24 %25
@@ -174,14 +192,16 @@
%27 = OpLabel
%77 = OpLoad %ulong %5
%78 = OpLoad %uint %14
- %111 = OpIAdd %ulong %77 %ulong_8_0
- %126 = OpConvertUToPtr %_ptr_Generic_uint %111
- OpStore %126 %78 Aligned 4
+ %126 = OpConvertUToPtr %_ptr_Generic_uint %77
+ %167 = OpBitcast %_ptr_Generic_uchar %126
+ %168 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %167 %ulong_8_0
+ %111 = OpBitcast %_ptr_Generic_uint %168
+ OpStore %111 %78 Aligned 4
%80 = OpLoad %float %12
%81 = OpLoad %float %13
- %149 = OpIsNan %bool %80
- %150 = OpIsNan %bool %81
- %79 = OpLogicalOr %bool %149 %150
+ %169 = OpIsNan %bool %80
+ %170 = OpIsNan %bool %81
+ %79 = OpLogicalOr %bool %169 %170
OpStore %15 %79
%82 = OpLoad %bool %15
OpBranchConditional %82 %28 %29
@@ -199,8 +219,10 @@
%31 = OpLabel
%86 = OpLoad %ulong %5
%87 = OpLoad %uint %14
- %115 = OpIAdd %ulong %86 %ulong_12_0
- %127 = OpConvertUToPtr %_ptr_Generic_uint %115
- OpStore %127 %87 Aligned 4
+ %127 = OpConvertUToPtr %_ptr_Generic_uint %86
+ %171 = OpBitcast %_ptr_Generic_uchar %127
+ %172 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %171 %ulong_12_0
+ %115 = OpBitcast %_ptr_Generic_uint %172
+ OpStore %115 %87 Aligned 4
OpReturn
OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/setp_num.spvtxt b/ptx/src/test/spirv_run/setp_num.spvtxt
index 3ac6eab..c576a50 100644
--- a/ptx/src/test/spirv_run/setp_num.spvtxt
+++ b/ptx/src/test/spirv_run/setp_num.spvtxt
@@ -22,6 +22,8 @@
%_ptr_Function_bool = OpTypePointer Function %bool
%_ptr_Generic_float = OpTypePointer Generic %float
%ulong_4 = OpConstant %ulong 4
+ %uchar = OpTypeInt 8 0
+%_ptr_Generic_uchar = OpTypePointer Generic %uchar
%ulong_8 = OpConstant %ulong 8
%ulong_12 = OpConstant %ulong 12
%ulong_16 = OpConstant %ulong 16
@@ -77,46 +79,60 @@
%36 = OpLoad %float %116 Aligned 4
OpStore %6 %36
%39 = OpLoad %ulong %4
- %89 = OpIAdd %ulong %39 %ulong_4
- %117 = OpConvertUToPtr %_ptr_Generic_float %89
- %38 = OpLoad %float %117 Aligned 4
+ %117 = OpConvertUToPtr %_ptr_Generic_float %39
+ %144 = OpBitcast %_ptr_Generic_uchar %117
+ %145 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %144 %ulong_4
+ %89 = OpBitcast %_ptr_Generic_float %145
+ %38 = OpLoad %float %89 Aligned 4
OpStore %7 %38
%41 = OpLoad %ulong %4
- %91 = OpIAdd %ulong %41 %ulong_8
- %118 = OpConvertUToPtr %_ptr_Generic_float %91
- %40 = OpLoad %float %118 Aligned 4
+ %118 = OpConvertUToPtr %_ptr_Generic_float %41
+ %146 = OpBitcast %_ptr_Generic_uchar %118
+ %147 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %146 %ulong_8
+ %91 = OpBitcast %_ptr_Generic_float %147
+ %40 = OpLoad %float %91 Aligned 4
OpStore %8 %40
%43 = OpLoad %ulong %4
- %93 = OpIAdd %ulong %43 %ulong_12
- %119 = OpConvertUToPtr %_ptr_Generic_float %93
- %42 = OpLoad %float %119 Aligned 4
+ %119 = OpConvertUToPtr %_ptr_Generic_float %43
+ %148 = OpBitcast %_ptr_Generic_uchar %119
+ %149 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %148 %ulong_12
+ %93 = OpBitcast %_ptr_Generic_float %149
+ %42 = OpLoad %float %93 Aligned 4
OpStore %9 %42
%45 = OpLoad %ulong %4
- %95 = OpIAdd %ulong %45 %ulong_16
- %120 = OpConvertUToPtr %_ptr_Generic_float %95
- %44 = OpLoad %float %120 Aligned 4
+ %120 = OpConvertUToPtr %_ptr_Generic_float %45
+ %150 = OpBitcast %_ptr_Generic_uchar %120
+ %151 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %150 %ulong_16
+ %95 = OpBitcast %_ptr_Generic_float %151
+ %44 = OpLoad %float %95 Aligned 4
OpStore %10 %44
%47 = OpLoad %ulong %4
- %97 = OpIAdd %ulong %47 %ulong_20
- %121 = OpConvertUToPtr %_ptr_Generic_float %97
- %46 = OpLoad %float %121 Aligned 4
+ %121 = OpConvertUToPtr %_ptr_Generic_float %47
+ %152 = OpBitcast %_ptr_Generic_uchar %121
+ %153 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %152 %ulong_20
+ %97 = OpBitcast %_ptr_Generic_float %153
+ %46 = OpLoad %float %97 Aligned 4
OpStore %11 %46
%49 = OpLoad %ulong %4
- %99 = OpIAdd %ulong %49 %ulong_24
- %122 = OpConvertUToPtr %_ptr_Generic_float %99
- %48 = OpLoad %float %122 Aligned 4
+ %122 = OpConvertUToPtr %_ptr_Generic_float %49
+ %154 = OpBitcast %_ptr_Generic_uchar %122
+ %155 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %154 %ulong_24
+ %99 = OpBitcast %_ptr_Generic_float %155
+ %48 = OpLoad %float %99 Aligned 4
OpStore %12 %48
%51 = OpLoad %ulong %4
- %101 = OpIAdd %ulong %51 %ulong_28
- %123 = OpConvertUToPtr %_ptr_Generic_float %101
- %50 = OpLoad %float %123 Aligned 4
+ %123 = OpConvertUToPtr %_ptr_Generic_float %51
+ %156 = OpBitcast %_ptr_Generic_uchar %123
+ %157 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %156 %ulong_28
+ %101 = OpBitcast %_ptr_Generic_float %157
+ %50 = OpLoad %float %101 Aligned 4
OpStore %13 %50
%53 = OpLoad %float %6
%54 = OpLoad %float %7
- %142 = OpIsNan %bool %53
- %143 = OpIsNan %bool %54
- %144 = OpLogicalOr %bool %142 %143
- %52 = OpSelect %bool %144 %false %true
+ %158 = OpIsNan %bool %53
+ %159 = OpIsNan %bool %54
+ %160 = OpLogicalOr %bool %158 %159
+ %52 = OpSelect %bool %160 %false %true
OpStore %15 %52
%55 = OpLoad %bool %15
OpBranchConditional %55 %16 %17
@@ -138,10 +154,10 @@
OpStore %124 %60 Aligned 4
%62 = OpLoad %float %8
%63 = OpLoad %float %9
- %148 = OpIsNan %bool %62
- %149 = OpIsNan %bool %63
- %150 = OpLogicalOr %bool %148 %149
- %61 = OpSelect %bool %150 %false_0 %true_0
+ %164 = OpIsNan %bool %62
+ %165 = OpIsNan %bool %63
+ %166 = OpLogicalOr %bool %164 %165
+ %61 = OpSelect %bool %166 %false_0 %true_0
OpStore %15 %61
%64 = OpLoad %bool %15
OpBranchConditional %64 %20 %21
@@ -159,15 +175,17 @@
%23 = OpLabel
%68 = OpLoad %ulong %5
%69 = OpLoad %uint %14
- %107 = OpIAdd %ulong %68 %ulong_4_0
- %125 = OpConvertUToPtr %_ptr_Generic_uint %107
- OpStore %125 %69 Aligned 4
+ %125 = OpConvertUToPtr %_ptr_Generic_uint %68
+ %169 = OpBitcast %_ptr_Generic_uchar %125
+ %170 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %169 %ulong_4_0
+ %107 = OpBitcast %_ptr_Generic_uint %170
+ OpStore %107 %69 Aligned 4
%71 = OpLoad %float %10
%72 = OpLoad %float %11
- %153 = OpIsNan %bool %71
- %154 = OpIsNan %bool %72
- %155 = OpLogicalOr %bool %153 %154
- %70 = OpSelect %bool %155 %false_1 %true_1
+ %171 = OpIsNan %bool %71
+ %172 = OpIsNan %bool %72
+ %173 = OpLogicalOr %bool %171 %172
+ %70 = OpSelect %bool %173 %false_1 %true_1
OpStore %15 %70
%73 = OpLoad %bool %15
OpBranchConditional %73 %24 %25
@@ -185,15 +203,17 @@
%27 = OpLabel
%77 = OpLoad %ulong %5
%78 = OpLoad %uint %14
- %111 = OpIAdd %ulong %77 %ulong_8_0
- %126 = OpConvertUToPtr %_ptr_Generic_uint %111
- OpStore %126 %78 Aligned 4
+ %126 = OpConvertUToPtr %_ptr_Generic_uint %77
+ %176 = OpBitcast %_ptr_Generic_uchar %126
+ %177 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %176 %ulong_8_0
+ %111 = OpBitcast %_ptr_Generic_uint %177
+ OpStore %111 %78 Aligned 4
%80 = OpLoad %float %12
%81 = OpLoad %float %13
- %158 = OpIsNan %bool %80
- %159 = OpIsNan %bool %81
- %160 = OpLogicalOr %bool %158 %159
- %79 = OpSelect %bool %160 %false_2 %true_2
+ %178 = OpIsNan %bool %80
+ %179 = OpIsNan %bool %81
+ %180 = OpLogicalOr %bool %178 %179
+ %79 = OpSelect %bool %180 %false_2 %true_2
OpStore %15 %79
%82 = OpLoad %bool %15
OpBranchConditional %82 %28 %29
@@ -211,8 +231,10 @@
%31 = OpLabel
%86 = OpLoad %ulong %5
%87 = OpLoad %uint %14
- %115 = OpIAdd %ulong %86 %ulong_12_0
- %127 = OpConvertUToPtr %_ptr_Generic_uint %115
- OpStore %127 %87 Aligned 4
+ %127 = OpConvertUToPtr %_ptr_Generic_uint %86
+ %183 = OpBitcast %_ptr_Generic_uchar %127
+ %184 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %183 %ulong_12_0
+ %115 = OpBitcast %_ptr_Generic_uint %184
+ OpStore %115 %87 Aligned 4
OpReturn
OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/shared_ptr_32.spvtxt b/ptx/src/test/spirv_run/shared_ptr_32.spvtxt
index 2ea964c..1b2e3dd 100644
--- a/ptx/src/test/spirv_run/shared_ptr_32.spvtxt
+++ b/ptx/src/test/spirv_run/shared_ptr_32.spvtxt
@@ -24,7 +24,8 @@
%_ptr_Function_uint = OpTypePointer Function %uint
%_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong
%_ptr_Workgroup_ulong = OpTypePointer Workgroup %ulong
- %uint_0 = OpConstant %uint 0
+ %ulong_0 = OpConstant %ulong 0
+%_ptr_Workgroup_uchar = OpTypePointer Workgroup %uchar
%1 = OpFunction %void None %40
%10 = OpFunctionParameter %ulong
%11 = OpFunctionParameter %ulong
@@ -54,9 +55,11 @@
%27 = OpConvertUToPtr %_ptr_Workgroup_ulong %17
OpStore %27 %18 Aligned 8
%20 = OpLoad %uint %7
- %24 = OpIAdd %uint %20 %uint_0
- %28 = OpConvertUToPtr %_ptr_Workgroup_ulong %24
- %19 = OpLoad %ulong %28 Aligned 8
+ %28 = OpConvertUToPtr %_ptr_Workgroup_ulong %20
+ %46 = OpBitcast %_ptr_Workgroup_uchar %28
+ %47 = OpInBoundsPtrAccessChain %_ptr_Workgroup_uchar %46 %ulong_0
+ %24 = OpBitcast %_ptr_Workgroup_ulong %47
+ %19 = OpLoad %ulong %24 Aligned 8
OpStore %9 %19
%21 = OpLoad %ulong %6
%22 = OpLoad %ulong %9
diff --git a/ptx/src/test/spirv_run/shared_ptr_take_address.spvtxt b/ptx/src/test/spirv_run/shared_ptr_take_address.spvtxt
index 19d5a5a..fd4f893 100644
--- a/ptx/src/test/spirv_run/shared_ptr_take_address.spvtxt
+++ b/ptx/src/test/spirv_run/shared_ptr_take_address.spvtxt
@@ -7,27 +7,24 @@
OpCapability Int64
OpCapability Float16
OpCapability Float64
- %33 = OpExtInstImport "OpenCL.std"
+ %31 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %2 "shared_ptr_take_address" %1
OpDecorate %1 Alignment 4
%void = OpTypeVoid
%uchar = OpTypeInt 8 0
%_ptr_Workgroup_uchar = OpTypePointer Workgroup %uchar
-%_ptr_Workgroup__ptr_Workgroup_uchar = OpTypePointer Workgroup %_ptr_Workgroup_uchar
- %1 = OpVariable %_ptr_Workgroup__ptr_Workgroup_uchar Workgroup
+ %1 = OpVariable %_ptr_Workgroup_uchar Workgroup
%ulong = OpTypeInt 64 0
- %39 = OpTypeFunction %void %ulong %ulong %_ptr_Workgroup_uchar
-%_ptr_Function__ptr_Workgroup_uchar = OpTypePointer Function %_ptr_Workgroup_uchar
+ %36 = OpTypeFunction %void %ulong %ulong %_ptr_Workgroup_uchar
%_ptr_Function_ulong = OpTypePointer Function %ulong
%_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong
%_ptr_Workgroup_ulong = OpTypePointer Workgroup %ulong
- %2 = OpFunction %void None %39
+ %2 = OpFunction %void None %36
%10 = OpFunctionParameter %ulong
%11 = OpFunctionParameter %ulong
- %31 = OpFunctionParameter %_ptr_Workgroup_uchar
- %40 = OpLabel
- %32 = OpVariable %_ptr_Function__ptr_Workgroup_uchar Function
+ %30 = OpFunctionParameter %_ptr_Workgroup_uchar
+ %28 = OpLabel
%3 = OpVariable %_ptr_Function_ulong Function
%4 = OpVariable %_ptr_Function_ulong Function
%5 = OpVariable %_ptr_Function_ulong Function
@@ -35,34 +32,30 @@
%7 = OpVariable %_ptr_Function_ulong Function
%8 = OpVariable %_ptr_Function_ulong Function
%9 = OpVariable %_ptr_Function_ulong Function
- OpStore %32 %31
- OpBranch %29
- %29 = OpLabel
OpStore %3 %10
OpStore %4 %11
%12 = OpLoad %ulong %3 Aligned 8
OpStore %5 %12
%13 = OpLoad %ulong %4 Aligned 8
OpStore %6 %13
- %15 = OpLoad %_ptr_Workgroup_uchar %32
- %24 = OpConvertPtrToU %ulong %15
- %14 = OpCopyObject %ulong %24
+ %23 = OpConvertPtrToU %ulong %30
+ %14 = OpCopyObject %ulong %23
OpStore %7 %14
- %17 = OpLoad %ulong %5
- %25 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %17
- %16 = OpLoad %ulong %25 Aligned 8
- OpStore %8 %16
- %18 = OpLoad %ulong %7
- %19 = OpLoad %ulong %8
- %26 = OpConvertUToPtr %_ptr_Workgroup_ulong %18
- OpStore %26 %19 Aligned 8
- %21 = OpLoad %ulong %7
- %27 = OpConvertUToPtr %_ptr_Workgroup_ulong %21
- %20 = OpLoad %ulong %27 Aligned 8
- OpStore %9 %20
- %22 = OpLoad %ulong %6
- %23 = OpLoad %ulong %9
- %28 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %22
- OpStore %28 %23 Aligned 8
+ %16 = OpLoad %ulong %5
+ %24 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %16
+ %15 = OpLoad %ulong %24 Aligned 8
+ OpStore %8 %15
+ %17 = OpLoad %ulong %7
+ %18 = OpLoad %ulong %8
+ %25 = OpConvertUToPtr %_ptr_Workgroup_ulong %17
+ OpStore %25 %18 Aligned 8
+ %20 = OpLoad %ulong %7
+ %26 = OpConvertUToPtr %_ptr_Workgroup_ulong %20
+ %19 = OpLoad %ulong %26 Aligned 8
+ OpStore %9 %19
+ %21 = OpLoad %ulong %6
+ %22 = OpLoad %ulong %9
+ %27 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %21
+ OpStore %27 %22 Aligned 8
OpReturn
OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/stateful_ld_st_ntid.spvtxt b/ptx/src/test/spirv_run/stateful_ld_st_ntid.spvtxt
index 33812f6..cf0d86e 100644
--- a/ptx/src/test/spirv_run/stateful_ld_st_ntid.spvtxt
+++ b/ptx/src/test/spirv_run/stateful_ld_st_ntid.spvtxt
@@ -7,7 +7,7 @@
OpCapability Int64
OpCapability Float16
OpCapability Float64
- %50 = OpExtInstImport "OpenCL.std"
+ %54 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %1 "stateful_ld_st_ntid" %gl_LocalInvocationID
OpDecorate %gl_LocalInvocationID BuiltIn LocalInvocationId
@@ -18,34 +18,34 @@
%gl_LocalInvocationID = OpVariable %_ptr_Input_v3ulong Input
%uchar = OpTypeInt 8 0
%_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar
- %57 = OpTypeFunction %void %_ptr_CrossWorkgroup_uchar %_ptr_CrossWorkgroup_uchar
+ %61 = OpTypeFunction %void %_ptr_CrossWorkgroup_uchar %_ptr_CrossWorkgroup_uchar
%_ptr_Function__ptr_CrossWorkgroup_uchar = OpTypePointer Function %_ptr_CrossWorkgroup_uchar
%uint = OpTypeInt 32 0
%_ptr_Function_uint = OpTypePointer Function %uint
%_ptr_Function_ulong = OpTypePointer Function %ulong
%_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong
- %1 = OpFunction %void None %57
+ %1 = OpFunction %void None %61
%20 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar
%21 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar
- %48 = OpLabel
- %2 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
- %3 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
+ %52 = OpLabel
+ %12 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
+ %13 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
%10 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
%11 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
%6 = OpVariable %_ptr_Function_uint Function
%7 = OpVariable %_ptr_Function_ulong Function
%8 = OpVariable %_ptr_Function_ulong Function
- OpStore %2 %20
- OpStore %3 %21
- %13 = OpBitcast %_ptr_Function_ulong %2
- %44 = OpLoad %ulong %13 Aligned 8
- %12 = OpCopyObject %ulong %44
- %22 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %12
+ OpStore %12 %20
+ OpStore %13 %21
+ %45 = OpBitcast %_ptr_Function_ulong %12
+ %44 = OpLoad %ulong %45 Aligned 8
+ %14 = OpCopyObject %ulong %44
+ %22 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %14
OpStore %10 %22
- %15 = OpBitcast %_ptr_Function_ulong %3
- %45 = OpLoad %ulong %15 Aligned 8
- %14 = OpCopyObject %ulong %45
- %23 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %14
+ %47 = OpBitcast %_ptr_Function_ulong %13
+ %46 = OpLoad %ulong %47 Aligned 8
+ %15 = OpCopyObject %ulong %46
+ %23 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %15
OpStore %11 %23
%24 = OpLoad %_ptr_CrossWorkgroup_uchar %10
%17 = OpConvertPtrToU %ulong %24
@@ -57,35 +57,37 @@
%18 = OpCopyObject %ulong %19
%27 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %18
OpStore %11 %27
- %62 = OpLoad %v3ulong %gl_LocalInvocationID
- %43 = OpCompositeExtract %ulong %62 0
- %63 = OpBitcast %ulong %43
- %29 = OpUConvert %uint %63
+ %66 = OpLoad %v3ulong %gl_LocalInvocationID
+ %43 = OpCompositeExtract %ulong %66 0
+ %67 = OpBitcast %ulong %43
+ %29 = OpUConvert %uint %67
%28 = OpCopyObject %uint %29
OpStore %6 %28
%31 = OpLoad %uint %6
- %64 = OpBitcast %uint %31
- %30 = OpUConvert %ulong %64
+ %68 = OpBitcast %uint %31
+ %30 = OpUConvert %ulong %68
OpStore %7 %30
%33 = OpLoad %_ptr_CrossWorkgroup_uchar %10
%34 = OpLoad %ulong %7
- %65 = OpBitcast %_ptr_CrossWorkgroup_uchar %33
- %66 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %65 %34
- %32 = OpBitcast %_ptr_CrossWorkgroup_uchar %66
+ %48 = OpCopyObject %ulong %34
+ %69 = OpBitcast %_ptr_CrossWorkgroup_uchar %33
+ %70 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %69 %48
+ %32 = OpBitcast %_ptr_CrossWorkgroup_uchar %70
OpStore %10 %32
%36 = OpLoad %_ptr_CrossWorkgroup_uchar %11
%37 = OpLoad %ulong %7
- %67 = OpBitcast %_ptr_CrossWorkgroup_uchar %36
- %68 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %67 %37
- %35 = OpBitcast %_ptr_CrossWorkgroup_uchar %68
+ %49 = OpCopyObject %ulong %37
+ %71 = OpBitcast %_ptr_CrossWorkgroup_uchar %36
+ %72 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %71 %49
+ %35 = OpBitcast %_ptr_CrossWorkgroup_uchar %72
OpStore %11 %35
%39 = OpLoad %_ptr_CrossWorkgroup_uchar %10
- %46 = OpBitcast %_ptr_CrossWorkgroup_ulong %39
- %38 = OpLoad %ulong %46 Aligned 8
+ %50 = OpBitcast %_ptr_CrossWorkgroup_ulong %39
+ %38 = OpLoad %ulong %50 Aligned 8
OpStore %8 %38
%40 = OpLoad %_ptr_CrossWorkgroup_uchar %11
%41 = OpLoad %ulong %8
- %47 = OpBitcast %_ptr_CrossWorkgroup_ulong %40
- OpStore %47 %41 Aligned 8
+ %51 = OpBitcast %_ptr_CrossWorkgroup_ulong %40
+ OpStore %51 %41 Aligned 8
OpReturn
OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/stateful_ld_st_ntid_chain.spvtxt b/ptx/src/test/spirv_run/stateful_ld_st_ntid_chain.spvtxt
index cb77d14..97bf000 100644
--- a/ptx/src/test/spirv_run/stateful_ld_st_ntid_chain.spvtxt
+++ b/ptx/src/test/spirv_run/stateful_ld_st_ntid_chain.spvtxt
@@ -7,7 +7,7 @@
OpCapability Int64
OpCapability Float16
OpCapability Float64
- %58 = OpExtInstImport "OpenCL.std"
+ %62 = OpExtInstImport "OpenCL.std"
OpMemoryModel Physical64 OpenCL
OpEntryPoint Kernel %1 "stateful_ld_st_ntid_chain" %gl_LocalInvocationID
OpDecorate %gl_LocalInvocationID BuiltIn LocalInvocationId
@@ -18,18 +18,18 @@
%gl_LocalInvocationID = OpVariable %_ptr_Input_v3ulong Input
%uchar = OpTypeInt 8 0
%_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar
- %65 = OpTypeFunction %void %_ptr_CrossWorkgroup_uchar %_ptr_CrossWorkgroup_uchar
+ %69 = OpTypeFunction %void %_ptr_CrossWorkgroup_uchar %_ptr_CrossWorkgroup_uchar
%_ptr_Function__ptr_CrossWorkgroup_uchar = OpTypePointer Function %_ptr_CrossWorkgroup_uchar
%uint = OpTypeInt 32 0
%_ptr_Function_uint = OpTypePointer Function %uint
%_ptr_Function_ulong = OpTypePointer Function %ulong
%_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong
- %1 = OpFunction %void None %65
+ %1 = OpFunction %void None %69
%28 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar
%29 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar
- %56 = OpLabel
- %2 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
- %3 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
+ %60 = OpLabel
+ %20 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
+ %21 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
%14 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
%15 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
%16 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function
@@ -39,17 +39,17 @@
%10 = OpVariable %_ptr_Function_uint Function
%11 = OpVariable %_ptr_Function_ulong Function
%12 = OpVariable %_ptr_Function_ulong Function
- OpStore %2 %28
- OpStore %3 %29
- %21 = OpBitcast %_ptr_Function_ulong %2
- %52 = OpLoad %ulong %21 Aligned 8
- %20 = OpCopyObject %ulong %52
- %30 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %20
+ OpStore %20 %28
+ OpStore %21 %29
+ %53 = OpBitcast %_ptr_Function_ulong %20
+ %52 = OpLoad %ulong %53 Aligned 8
+ %22 = OpCopyObject %ulong %52
+ %30 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %22
OpStore %14 %30
- %23 = OpBitcast %_ptr_Function_ulong %3
- %53 = OpLoad %ulong %23 Aligned 8
- %22 = OpCopyObject %ulong %53
- %31 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %22
+ %55 = OpBitcast %_ptr_Function_ulong %21
+ %54 = OpLoad %ulong %55 Aligned 8
+ %23 = OpCopyObject %ulong %54
+ %31 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %23
OpStore %17 %31
%32 = OpLoad %_ptr_CrossWorkgroup_uchar %14
%25 = OpConvertPtrToU %ulong %32
@@ -61,35 +61,37 @@
%26 = OpCopyObject %ulong %27
%35 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %26
OpStore %18 %35
- %70 = OpLoad %v3ulong %gl_LocalInvocationID
- %51 = OpCompositeExtract %ulong %70 0
- %71 = OpBitcast %ulong %51
- %37 = OpUConvert %uint %71
+ %74 = OpLoad %v3ulong %gl_LocalInvocationID
+ %51 = OpCompositeExtract %ulong %74 0
+ %75 = OpBitcast %ulong %51
+ %37 = OpUConvert %uint %75
%36 = OpCopyObject %uint %37
OpStore %10 %36
%39 = OpLoad %uint %10
- %72 = OpBitcast %uint %39
- %38 = OpUConvert %ulong %72
+ %76 = OpBitcast %uint %39
+ %38 = OpUConvert %ulong %76
OpStore %11 %38
%41 = OpLoad %_ptr_CrossWorkgroup_uchar %15
%42 = OpLoad %ulong %11
- %73 = OpBitcast %_ptr_CrossWorkgroup_uchar %41
- %74 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %73 %42
- %40 = OpBitcast %_ptr_CrossWorkgroup_uchar %74
+ %56 = OpCopyObject %ulong %42
+ %77 = OpBitcast %_ptr_CrossWorkgroup_uchar %41
+ %78 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %77 %56
+ %40 = OpBitcast %_ptr_CrossWorkgroup_uchar %78
OpStore %16 %40
%44 = OpLoad %_ptr_CrossWorkgroup_uchar %18
%45 = OpLoad %ulong %11
- %75 = OpBitcast %_ptr_CrossWorkgroup_uchar %44
- %76 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %75 %45
- %43 = OpBitcast %_ptr_CrossWorkgroup_uchar %76
+ %57 = OpCopyObject %ulong %45
+ %79 = OpBitcast %_ptr_CrossWorkgroup_uchar %44
+ %80 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %79 %57
+ %43 = OpBitcast %_ptr_CrossWorkgroup_uchar %80
OpStore %19 %43
%47 = OpLoad %_ptr_CrossWorkgroup_uchar %16
- %54 = OpBitcast %_ptr_CrossWorkgroup_ulong %47
- %46 = OpLoad %ulong %54 Aligned 8
+ %58 = OpBitcast %_ptr_CrossWorkgroup_ulong %47
+ %46 = OpLoad %ulong %58 Aligned 8
OpStore %12 %46
%48 = OpLoad %_ptr_CrossWorkgroup_uchar %19
%49 = OpLoad %ulong %12
- %55 = OpBitcast %_ptr_CrossWorkgroup_ulong %48
- OpStore %55 %49 Aligned 8
+ %59 = OpBitcast %_ptr_CrossWorkgroup_ulong %48
+ OpStore %59 %49 Aligned 8
OpReturn
OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/vector.spvtxt b/ptx/src/test/spirv_run/vector.spvtxt
index ecf2858..8253bf9 100644
--- a/ptx/src/test/spirv_run/vector.spvtxt
+++ b/ptx/src/test/spirv_run/vector.spvtxt
@@ -25,8 +25,8 @@
%1 = OpFunction %v2uint None %55
%7 = OpFunctionParameter %v2uint
%24 = OpLabel
- %2 = OpVariable %_ptr_Function_v2uint Function
%3 = OpVariable %_ptr_Function_v2uint Function
+ %2 = OpVariable %_ptr_Function_v2uint Function
%4 = OpVariable %_ptr_Function_v2uint Function
%5 = OpVariable %_ptr_Function_uint Function
%6 = OpVariable %_ptr_Function_uint Function
diff --git a/ptx/src/test/spirv_run/verify.py b/ptx/src/test/spirv_run/verify.py
new file mode 100644
index 0000000..dbfab00
--- /dev/null
+++ b/ptx/src/test/spirv_run/verify.py
@@ -0,0 +1,21 @@
+import os, sys, subprocess
+
+def main(path):
+ dirs = os.listdir(path)
+ for file in dirs:
+ if not file.endswith(".spvtxt"):
+ continue
+ full_file = os.path.join(path, file)
+ print(file)
+ spv_file = f"/tmp/{file}.spv"
+ # We nominally emit spv1.3, but use spv1.4 feature (OpEntryPoint interface changes in 1.4)
+ proc1 = subprocess.run(["spirv-as", "--target-env", "spv1.4", full_file, "-o", spv_file])
+ proc2 = subprocess.run(["spirv-dis", spv_file, "-o", f"{spv_file}.dis.txt"])
+ proc3 = subprocess.run(["spirv-val", spv_file ])
+ if proc1.returncode != 0 or proc2.returncode != 0 or proc3.returncode != 0:
+ print(proc1.returncode)
+ print(proc2.returncode)
+ print(proc3.returncode)
+
+if __name__ == "__main__":
+ main(sys.argv[1])
diff --git a/ptx/src/test/spirv_run/xor.spvtxt b/ptx/src/test/spirv_run/xor.spvtxt
index 4cc8968..c3a1f6f 100644
--- a/ptx/src/test/spirv_run/xor.spvtxt
+++ b/ptx/src/test/spirv_run/xor.spvtxt
@@ -18,6 +18,8 @@
%_ptr_Function_uint = OpTypePointer Function %uint
%_ptr_Generic_uint = OpTypePointer Generic %uint
%ulong_4 = OpConstant %ulong 4
+ %uchar = OpTypeInt 8 0
+%_ptr_Generic_uchar = OpTypePointer Generic %uchar
%1 = OpFunction %void None %31
%8 = OpFunctionParameter %ulong
%9 = OpFunctionParameter %ulong
@@ -39,9 +41,11 @@
%12 = OpLoad %uint %23 Aligned 4
OpStore %6 %12
%15 = OpLoad %ulong %4
- %22 = OpIAdd %ulong %15 %ulong_4
- %24 = OpConvertUToPtr %_ptr_Generic_uint %22
- %14 = OpLoad %uint %24 Aligned 4
+ %24 = OpConvertUToPtr %_ptr_Generic_uint %15
+ %38 = OpBitcast %_ptr_Generic_uchar %24
+ %39 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %38 %ulong_4
+ %22 = OpBitcast %_ptr_Generic_uint %39
+ %14 = OpLoad %uint %22 Aligned 4
OpStore %7 %14
%17 = OpLoad %uint %6
%18 = OpLoad %uint %7
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index 7170950..c2562c3 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -1,11 +1,9 @@
use crate::ast;
use half::f16;
use rspirv::dr;
-use std::{borrow::Cow, collections::BTreeSet, ffi::CString, hash::Hash, iter, mem};
-use std::{
- collections::{hash_map, HashMap, HashSet},
- convert::TryInto,
-};
+use std::cell::RefCell;
+use std::collections::{hash_map, HashMap, HashSet};
+use std::{borrow::Cow, collections::BTreeSet, ffi::CString, hash::Hash, iter, mem, rc::Rc};
use rspirv::binary::Assemble;
@@ -48,64 +46,21 @@ enum SpirvType {
}
impl SpirvType {
- fn new_pointer(t: ast::Type, sc: spirv::StorageClass) -> Self {
- let key = t.into();
- SpirvType::Pointer(Box::new(key), sc)
- }
-}
-
-impl From<ast::Type> for SpirvType {
- fn from(t: ast::Type) -> Self {
+ fn new(t: ast::Type) -> Self {
match t {
ast::Type::Scalar(t) => SpirvType::Base(t.into()),
ast::Type::Vector(typ, len) => SpirvType::Vector(typ.into(), len),
ast::Type::Array(t, len) => SpirvType::Array(t.into(), len),
- ast::Type::Pointer(pointer_t, state_space) => SpirvType::Pointer(
- Box::new(SpirvType::from(ast::Type::from(pointer_t))),
- state_space.to_spirv(),
+ ast::Type::Pointer(pointer_t, space) => SpirvType::Pointer(
+ Box::new(SpirvType::Base(pointer_t.into())),
+ space.to_spirv(),
),
}
}
-}
-impl From<ast::PointerType> for ast::Type {
- fn from(t: ast::PointerType) -> Self {
- match t {
- ast::PointerType::Scalar(t) => ast::Type::Scalar(t),
- ast::PointerType::Vector(t, len) => ast::Type::Vector(t, len),
- ast::PointerType::Array(t, dims) => ast::Type::Array(t, dims),
- ast::PointerType::Pointer(t, space) => {
- ast::Type::Pointer(ast::PointerType::Scalar(t), space)
- }
- }
- }
-}
-
-impl ast::Type {
- fn param_pointer_to(self, space: ast::LdStateSpace) -> Result<Self, TranslateError> {
- Ok(match self {
- ast::Type::Scalar(t) => ast::Type::Pointer(ast::PointerType::Scalar(t), space),
- ast::Type::Vector(t, len) => {
- ast::Type::Pointer(ast::PointerType::Vector(t, len), space)
- }
- ast::Type::Array(t, _) => ast::Type::Pointer(ast::PointerType::Scalar(t), space),
- ast::Type::Pointer(ast::PointerType::Scalar(t), space) => {
- ast::Type::Pointer(ast::PointerType::Pointer(t, space), space)
- }
- ast::Type::Pointer(_, _) => return Err(error_unreachable()),
- })
- }
-}
-
-impl Into<spirv::StorageClass> for ast::PointerStateSpace {
- fn into(self) -> spirv::StorageClass {
- match self {
- ast::PointerStateSpace::Const => spirv::StorageClass::UniformConstant,
- ast::PointerStateSpace::Global => spirv::StorageClass::CrossWorkgroup,
- ast::PointerStateSpace::Shared => spirv::StorageClass::Workgroup,
- ast::PointerStateSpace::Param => spirv::StorageClass::Function,
- ast::PointerStateSpace::Generic => spirv::StorageClass::Generic,
- }
+ fn pointer_to(t: ast::Type, outer_space: spirv::StorageClass) -> Self {
+ let key = Self::new(t);
+ SpirvType::Pointer(Box::new(key), outer_space)
}
}
@@ -213,14 +168,18 @@ impl TypeWordMap {
.or_insert_with(|| b.type_vector(None, base, len as u32))
}
SpirvType::Array(typ, array_dimensions) => {
- let u32_type = self.get_or_add_scalar(b, ast::ScalarType::U32);
let (base_type, length) = match &*array_dimensions {
+ &[] => {
+ return self.get_or_add(b, SpirvType::Base(typ));
+ }
&[len] => {
+ let u32_type = self.get_or_add_scalar(b, ast::ScalarType::U32);
let base = self.get_or_add_spirv_scalar(b, typ);
let len_const = b.constant_u32(u32_type, None, len);
(base, len_const)
}
array_dimensions => {
+ let u32_type = self.get_or_add_scalar(b, ast::ScalarType::U32);
let base = self
.get_or_add(b, SpirvType::Array(typ, array_dimensions[1..].to_vec()));
let len_const = b.constant_u32(u32_type, None, array_dimensions[0]);
@@ -262,7 +221,7 @@ impl TypeWordMap {
fn get_or_add_fn(
&mut self,
b: &mut dr::Builder,
- in_params: impl ExactSizeIterator<Item = SpirvType>,
+ in_params: impl Iterator<Item = SpirvType>,
mut out_params: impl ExactSizeIterator<Item = SpirvType>,
) -> (spirv::Word, spirv::Word) {
let (out_args, out_spirv_type) = if out_params.len() == 0 {
@@ -274,6 +233,7 @@ impl TypeWordMap {
self.get_or_add(b, arg_as_key),
)
} else {
+ // TODO: support multiple return values
todo!()
};
(
@@ -410,18 +370,7 @@ impl TypeWordMap {
b.constant_composite(result_type, None, components.into_iter())
}
},
- ast::Type::Pointer(typ, state_space) => {
- let base_t = typ.clone().into();
- let base = self.get_or_add_constant(b, &base_t, &[])?;
- let result_type = self.get_or_add(
- b,
- SpirvType::Pointer(
- Box::new(SpirvType::from(base_t)),
- (*state_space).to_spirv(),
- ),
- );
- b.variable(result_type, None, (*state_space).to_spirv(), Some(base))
- }
+ ast::Type::Pointer(..) => return Err(error_unreachable()),
})
}
@@ -487,7 +436,7 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result<Module, TranslateErro
.collect::<Vec<_>>();
let mut builder = dr::Builder::new();
builder.reserve_ids(id_defs.current_id());
- let call_map = get_call_map(&directives);
+ let call_map = get_kernels_call_map(&directives);
let mut directives = convert_dynamic_shared_memory_usage(directives, &mut || builder.id());
normalize_variable_decls(&mut directives);
let denorm_information = compute_denorm_information(&directives);
@@ -525,9 +474,12 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result<Module, TranslateErro
}
// TODO: remove this once we have perf-function support for denorms
-fn emit_denorm_build_string(
+fn emit_denorm_build_string<'input>(
call_map: &HashMap<&str, HashSet<u32>>,
- denorm_information: &HashMap<MethodName, HashMap<u8, (spirv::FPDenormMode, isize)>>,
+ denorm_information: &HashMap<
+ ast::MethodName<'input, spirv::Word>,
+ HashMap<u8, (spirv::FPDenormMode, isize)>,
+ >,
) -> CString {
let denorm_counts = denorm_information
.iter()
@@ -545,10 +497,12 @@ fn emit_denorm_build_string(
.collect::<HashMap<_, _>>();
let mut flush_over_preserve = 0;
for (kernel, children) in call_map {
- flush_over_preserve += *denorm_counts.get(&MethodName::Kernel(kernel)).unwrap_or(&0);
+ flush_over_preserve += *denorm_counts
+ .get(&ast::MethodName::Kernel(kernel))
+ .unwrap_or(&0);
for child_fn in children {
flush_over_preserve += *denorm_counts
- .get(&MethodName::Func(*child_fn))
+ .get(&ast::MethodName::Func(*child_fn))
.unwrap_or(&0);
}
}
@@ -564,15 +518,18 @@ fn emit_directives<'input>(
map: &mut TypeWordMap,
id_defs: &GlobalStringIdResolver<'input>,
opencl_id: spirv::Word,
- denorm_information: &HashMap<MethodName<'input>, HashMap<u8, (spirv::FPDenormMode, isize)>>,
+ denorm_information: &HashMap<
+ ast::MethodName<'input, spirv::Word>,
+ HashMap<u8, (spirv::FPDenormMode, isize)>,
+ >,
call_map: &HashMap<&'input str, HashSet<spirv::Word>>,
- directives: Vec<Directive>,
+ directives: Vec<Directive<'input>>,
kernel_info: &mut HashMap<String, KernelInfo>,
) -> Result<(), TranslateError> {
let empty_body = Vec::new();
for d in directives.iter() {
match d {
- Directive::Variable(var) => {
+ Directive::Variable(_, var) => {
emit_variable(builder, map, &var)?;
}
Directive::Method(f) => {
@@ -589,12 +546,13 @@ fn emit_directives<'input>(
for var in f.globals.iter() {
emit_variable(builder, map, var)?;
}
+ let func_decl = (*f.func_decl).borrow();
let fn_id = emit_function_header(
builder,
map,
&id_defs,
&f.globals,
- &f.spirv_decl,
+ &*func_decl,
&denorm_information,
call_map,
&directives,
@@ -623,8 +581,13 @@ fn emit_directives<'input>(
}
emit_function_body_ops(builder, map, opencl_id, &f_body)?;
builder.end_function()?;
- if let (ast::MethodDecl::Func(_, fn_id, _), Some(name)) =
- (&f.func_decl, &f.import_as)
+ if let (
+ ast::MethodDeclaration {
+ name: ast::MethodName::Func(fn_id),
+ ..
+ },
+ Some(name),
+ ) = (&*func_decl, &f.import_as)
{
builder.decorate(
*fn_id,
@@ -643,7 +606,7 @@ fn emit_directives<'input>(
Ok(())
}
-fn get_call_map<'input>(
+fn get_kernels_call_map<'input>(
module: &[Directive<'input>],
) -> HashMap<&'input str, HashSet<spirv::Word>> {
let mut directly_called_by = HashMap::new();
@@ -654,14 +617,14 @@ fn get_call_map<'input>(
body: Some(statements),
..
}) => {
- let call_key = MethodName::new(&func_decl);
+ let call_key: ast::MethodName<_> = (**func_decl).borrow().name;
if let hash_map::Entry::Vacant(entry) = directly_called_by.entry(call_key) {
entry.insert(Vec::new());
}
for statement in statements {
match statement {
Statement::Call(call) => {
- multi_hash_map_append(&mut directly_called_by, call_key, call.func);
+ multi_hash_map_append(&mut directly_called_by, call_key, call.name);
}
_ => {}
}
@@ -673,28 +636,28 @@ fn get_call_map<'input>(
let mut result = HashMap::new();
for (method_key, children) in directly_called_by.iter() {
match method_key {
- MethodName::Kernel(name) => {
+ ast::MethodName::Kernel(name) => {
let mut visited = HashSet::new();
for child in children {
add_call_map_single(&directly_called_by, &mut visited, *child);
}
result.insert(*name, visited);
}
- MethodName::Func(_) => {}
+ ast::MethodName::Func(_) => {}
}
}
result
}
fn add_call_map_single<'input>(
- directly_called_by: &MultiHashMap<MethodName<'input>, spirv::Word>,
+ directly_called_by: &MultiHashMap<ast::MethodName<'input, spirv::Word>, spirv::Word>,
visited: &mut HashSet<spirv::Word>,
current: spirv::Word,
) {
if !visited.insert(current) {
return;
}
- if let Some(children) = directly_called_by.get(&MethodName::Func(current)) {
+ if let Some(children) = directly_called_by.get(&ast::MethodName::Func(current)) {
for child in children {
add_call_map_single(directly_called_by, visited, *child);
}
@@ -714,11 +677,29 @@ fn multi_hash_map_append<K: Eq + std::hash::Hash, V>(m: &mut MultiHashMap<K, V>,
}
}
-// PTX represents dynamically allocated shared local memory as
-// .extern .shared .align 4 .b8 shared_mem[];
-// In SPIRV/OpenCL world this is expressed as an additional argument
-// This pass looks for all uses of .extern .shared and converts them to
-// an additional method argument
+/*
+ PTX represents dynamically allocated shared local memory as
+ .extern .shared .b32 shared_mem[];
+ In SPIRV/OpenCL world this is expressed as an additional argument
+ This pass looks for all uses of .extern .shared and converts them to
+ an additional method argument
+ The question is how this artificial argument should be expressed. There are
+ several options:
+ * Straight conversion:
+ .shared .b32 shared_mem[]
+ * Introduce .param_shared statespace:
+ .param_shared .b32 shared_mem
+ or
+ .param_shared .b32 shared_mem[]
+ * Introduce .shared_ptr <SCALAR> type:
+ .param .shared_ptr .b32 shared_mem
+ * Reuse .ptr hint:
+ .param .u64 .ptr shared_mem
+ This is the most tempting, but also the most nonsensical, .ptr is just a
+ hint, which has no semantical meaning (and the output of our
+ transformation has a semantical meaning - we emit additional
+ "OpFunctionParameter ..." with type "OpTypePointer Workgroup ...")
+*/
fn convert_dynamic_shared_memory_usage<'input>(
module: Vec<Directive<'input>>,
new_id: &mut impl FnMut() -> spirv::Word,
@@ -726,12 +707,16 @@ fn convert_dynamic_shared_memory_usage<'input>(
let mut extern_shared_decls = HashMap::new();
for dir in module.iter() {
match dir {
- Directive::Variable(var) => {
- if let ast::VariableType::Shared(ast::VariableGlobalType::Pointer(p_type, _)) =
- var.v_type
- {
- extern_shared_decls.insert(var.name, p_type);
- }
+ Directive::Variable(
+ linking,
+ ast::Variable {
+ v_type: ast::Type::Array(p_type, dims),
+ state_space: ast::StateSpace::Shared,
+ name,
+ ..
+ },
+ ) if linking.contains(ast::LinkingDirective::EXTERN) && dims.len() == 0 => {
+ extern_shared_decls.insert(*name, *p_type);
}
_ => {}
}
@@ -749,15 +734,14 @@ fn convert_dynamic_shared_memory_usage<'input>(
globals,
body: Some(statements),
import_as,
- spirv_decl,
tuning,
}) => {
- let call_key = MethodName::new(&func_decl);
+ let call_key = (*func_decl).borrow().name;
let statements = statements
.into_iter()
.map(|statement| match statement {
Statement::Call(call) => {
- multi_hash_map_append(&mut directly_called_by, call.func, call_key);
+ multi_hash_map_append(&mut directly_called_by, call.name, call_key);
Statement::Call(call)
}
statement => statement.map_id(&mut |id, _| {
@@ -773,7 +757,6 @@ fn convert_dynamic_shared_memory_usage<'input>(
globals,
body: Some(statements),
import_as,
- spirv_decl,
tuning,
})
}
@@ -792,66 +775,34 @@ fn convert_dynamic_shared_memory_usage<'input>(
globals,
body: Some(statements),
import_as,
- mut spirv_decl,
tuning,
}) => {
- if !methods_using_extern_shared.contains(&spirv_decl.name) {
+ if !methods_using_extern_shared.contains(&(*func_decl).borrow().name) {
return Directive::Method(Function {
func_decl,
globals,
body: Some(statements),
import_as,
- spirv_decl,
tuning,
});
}
let shared_id_param = new_id();
- spirv_decl.input.push({
- ast::Variable {
- align: None,
- v_type: ast::Type::Pointer(
- ast::PointerType::Scalar(ast::ScalarType::U8),
- ast::LdStateSpace::Shared,
- ),
- array_init: Vec::new(),
- name: shared_id_param,
- }
- });
- spirv_decl.uses_shared_mem = true;
- let shared_var_id = new_id();
- let shared_var = ExpandedStatement::Variable(ast::Variable {
- align: None,
- name: shared_var_id,
- array_init: Vec::new(),
- v_type: ast::VariableType::Reg(ast::VariableRegType::Pointer(
- ast::SizedScalarType::B8,
- ast::PointerStateSpace::Shared,
- )),
- });
- let shared_var_st = ExpandedStatement::StoreVar(StoreVarDetails {
- arg: ast::Arg2St {
- src1: shared_var_id,
- src2: shared_id_param,
- },
- typ: ast::Type::Scalar(ast::ScalarType::B8),
- member_index: None,
- });
- let mut new_statements = vec![shared_var, shared_var_st];
- replace_uses_of_shared_memory(
- &mut new_statements,
+ {
+ let mut func_decl = (*func_decl).borrow_mut();
+ func_decl.shared_mem = Some(shared_id_param);
+ }
+ let statements = replace_uses_of_shared_memory(
new_id,
&extern_shared_decls,
&mut methods_using_extern_shared,
shared_id_param,
- shared_var_id,
statements,
);
Directive::Method(Function {
func_decl,
globals,
- body: Some(new_statements),
+ body: Some(statements),
import_as,
- spirv_decl,
tuning,
})
}
@@ -861,47 +812,43 @@ fn convert_dynamic_shared_memory_usage<'input>(
}
fn replace_uses_of_shared_memory<'a>(
- result: &mut Vec<ExpandedStatement>,
new_id: &mut impl FnMut() -> spirv::Word,
- extern_shared_decls: &HashMap<spirv::Word, ast::SizedScalarType>,
- methods_using_extern_shared: &mut HashSet<MethodName<'a>>,
+ extern_shared_decls: &HashMap<spirv::Word, ast::ScalarType>,
+ methods_using_extern_shared: &mut HashSet<ast::MethodName<'a, spirv::Word>>,
shared_id_param: spirv::Word,
- shared_var_id: spirv::Word,
statements: Vec<ExpandedStatement>,
-) {
+) -> Vec<ExpandedStatement> {
+ let mut result = Vec::with_capacity(statements.len());
for statement in statements {
match statement {
Statement::Call(mut call) => {
// We can safely skip checking call arguments,
// because there's simply no way to pass shared ptr
// without converting it to .b64 first
- if methods_using_extern_shared.contains(&MethodName::Func(call.func)) {
- call.param_list
- .push((shared_id_param, ast::FnArgumentType::Shared));
+ if methods_using_extern_shared.contains(&ast::MethodName::Func(call.name)) {
+ call.input_arguments.push((
+ shared_id_param,
+ ast::Type::Scalar(ast::ScalarType::B8),
+ ast::StateSpace::Shared,
+ ));
}
result.push(Statement::Call(call))
}
statement => {
let new_statement = statement.map_id(&mut |id, _| {
- if let Some(typ) = extern_shared_decls.get(&id) {
- if *typ == ast::SizedScalarType::B8 {
- return shared_var_id;
+ if let Some(scalar_type) = extern_shared_decls.get(&id) {
+ if *scalar_type == ast::ScalarType::B8 {
+ return shared_id_param;
}
let replacement_id = new_id();
result.push(Statement::Conversion(ImplicitConversion {
- src: shared_var_id,
+ src: shared_id_param,
dst: replacement_id,
- from: ast::Type::Pointer(
- ast::PointerType::Scalar(ast::ScalarType::B8),
- ast::LdStateSpace::Shared,
- ),
- to: ast::Type::Pointer(
- ast::PointerType::Scalar((*typ).into()),
- ast::LdStateSpace::Shared,
- ),
- kind: ConversionKind::PtrToPtr { spirv_ptr: true },
- src_sema: ArgumentSemantics::Default,
- dst_sema: ArgumentSemantics::Default,
+ from_type: ast::Type::Scalar(ast::ScalarType::B8),
+ from_space: ast::StateSpace::Shared,
+ to_type: ast::Type::Scalar(*scalar_type),
+ to_space: ast::StateSpace::Shared,
+ kind: ConversionKind::PtrToPtr,
}));
replacement_id
} else {
@@ -912,16 +859,17 @@ fn replace_uses_of_shared_memory<'a>(
}
}
}
+ result
}
fn get_callers_of_extern_shared<'a>(
- methods_using_extern_shared: &mut HashSet<MethodName<'a>>,
- directly_called_by: &MultiHashMap<spirv::Word, MethodName<'a>>,
+ methods_using_extern_shared: &mut HashSet<ast::MethodName<'a, spirv::Word>>,
+ directly_called_by: &MultiHashMap<spirv::Word, ast::MethodName<'a, spirv::Word>>,
) {
let direct_uses_of_extern_shared = methods_using_extern_shared
.iter()
.filter_map(|method| {
- if let MethodName::Func(f_id) = method {
+ if let ast::MethodName::Func(f_id) = method {
Some(*f_id)
} else {
None
@@ -934,14 +882,14 @@ fn get_callers_of_extern_shared<'a>(
}
fn get_callers_of_extern_shared_single<'a>(
- methods_using_extern_shared: &mut HashSet<MethodName<'a>>,
- directly_called_by: &MultiHashMap<spirv::Word, MethodName<'a>>,
+ methods_using_extern_shared: &mut HashSet<ast::MethodName<'a, spirv::Word>>,
+ directly_called_by: &MultiHashMap<spirv::Word, ast::MethodName<'a, spirv::Word>>,
fn_id: spirv::Word,
) {
if let Some(callers) = directly_called_by.get(&fn_id) {
for caller in callers {
if methods_using_extern_shared.insert(*caller) {
- if let MethodName::Func(caller_fn) = caller {
+ if let ast::MethodName::Func(caller_fn) = caller {
get_callers_of_extern_shared_single(
methods_using_extern_shared,
directly_called_by,
@@ -983,18 +931,18 @@ fn denorm_count_map_update_impl<T: Eq + Hash>(
// and emit suitable execution mode
fn compute_denorm_information<'input>(
module: &[Directive<'input>],
-) -> HashMap<MethodName<'input>, HashMap<u8, (spirv::FPDenormMode, isize)>> {
+) -> HashMap<ast::MethodName<'input, spirv::Word>, HashMap<u8, (spirv::FPDenormMode, isize)>> {
let mut denorm_methods = HashMap::new();
for directive in module {
match directive {
- Directive::Variable(_) | Directive::Method(Function { body: None, .. }) => {}
+ Directive::Variable(..) | Directive::Method(Function { body: None, .. }) => {}
Directive::Method(Function {
func_decl,
body: Some(statements),
..
}) => {
let mut flush_counter = DenormCountMap::new();
- let method_key = MethodName::new(func_decl);
+ let method_key = (**func_decl).borrow().name;
for statement in statements {
match statement {
Statement::Instruction(inst) => {
@@ -1038,21 +986,6 @@ fn compute_denorm_information<'input>(
.collect()
}
-#[derive(Hash, PartialEq, Eq, Copy, Clone)]
-enum MethodName<'input> {
- Kernel(&'input str),
- Func(spirv::Word),
-}
-
-impl<'input> MethodName<'input> {
- fn new(decl: &ast::MethodDecl<'input, spirv::Word>) -> Self {
- match decl {
- ast::MethodDecl::Kernel { name, .. } => MethodName::Kernel(name),
- ast::MethodDecl::Func(_, id, _) => MethodName::Func(*id),
- }
- }
-}
-
fn emit_builtins(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
@@ -1061,10 +994,7 @@ fn emit_builtins(
for (reg, id) in id_defs.special_registers.builtins() {
let result_type = map.get_or_add(
builder,
- SpirvType::Pointer(
- Box::new(SpirvType::from(reg.get_type())),
- spirv::StorageClass::Input,
- ),
+ SpirvType::pointer_to(reg.get_type(), spirv::StorageClass::Input),
);
builder.variable(result_type, Some(id), spirv::StorageClass::Input, None);
builder.decorate(
@@ -1079,18 +1009,21 @@ fn emit_function_header<'a>(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
defined_globals: &GlobalStringIdResolver<'a>,
- synthetic_globals: &[ast::Variable<ast::VariableType, spirv::Word>],
- func_decl: &SpirvMethodDecl<'a>,
- _denorm_information: &HashMap<MethodName<'a>, HashMap<u8, (spirv::FPDenormMode, isize)>>,
+ synthetic_globals: &[ast::Variable<spirv::Word>],
+ func_decl: &ast::MethodDeclaration<'a, spirv::Word>,
+ _denorm_information: &HashMap<
+ ast::MethodName<'a, spirv::Word>,
+ HashMap<u8, (spirv::FPDenormMode, isize)>,
+ >,
call_map: &HashMap<&'a str, HashSet<spirv::Word>>,
direcitves: &[Directive],
kernel_info: &mut HashMap<String, KernelInfo>,
) -> Result<spirv::Word, TranslateError> {
- if let MethodName::Kernel(name) = func_decl.name {
- let input_args = if !func_decl.uses_shared_mem {
- func_decl.input.as_slice()
+ if let ast::MethodName::Kernel(name) = func_decl.name {
+ let input_args = if func_decl.shared_mem.is_none() {
+ func_decl.input_arguments.as_slice()
} else {
- &func_decl.input[0..func_decl.input.len() - 1]
+ &func_decl.input_arguments[0..func_decl.input_arguments.len() - 1]
};
let args_lens = input_args
.iter()
@@ -1100,14 +1033,18 @@ fn emit_function_header<'a>(
name.to_string(),
KernelInfo {
arguments_sizes: args_lens,
- uses_shared_mem: func_decl.uses_shared_mem,
+ uses_shared_mem: func_decl.shared_mem.is_some(),
},
);
}
- let (ret_type, func_type) =
- get_function_type(builder, map, &func_decl.input, &func_decl.output);
+ let (ret_type, func_type) = get_function_type(
+ builder,
+ map,
+ func_decl.effective_input_arguments().map(|(_, typ)| typ),
+ &func_decl.return_arguments,
+ );
let fn_id = match func_decl.name {
- MethodName::Kernel(name) => {
+ ast::MethodName::Kernel(name) => {
let fn_id = defined_globals.get_id(name)?;
let mut global_variables = defined_globals
.variables_type_check
@@ -1123,15 +1060,18 @@ fn emit_function_header<'a>(
for directive in direcitves {
match directive {
Directive::Method(Function {
- func_decl: ast::MethodDecl::Func(_, name, _),
- globals,
- ..
+ func_decl, globals, ..
}) => {
- if child_fns.contains(name) {
- for var in globals {
- interface.push(var.name);
+ match (**func_decl).borrow().name {
+ ast::MethodName::Func(name) => {
+ if child_fns.contains(&name) {
+ for var in globals {
+ interface.push(var.name);
+ }
+ }
}
- }
+ ast::MethodName::Kernel(_) => {}
+ };
}
_ => {}
}
@@ -1140,7 +1080,7 @@ fn emit_function_header<'a>(
builder.entry_point(spirv::ExecutionModel::Kernel, fn_id, name, global_variables);
fn_id
}
- MethodName::Func(name) => name,
+ ast::MethodName::Func(name) => name,
};
builder.begin_function(
ret_type,
@@ -1163,9 +1103,9 @@ fn emit_function_header<'a>(
}
}
*/
- for input in &func_decl.input {
- let result_type = map.get_or_add(builder, SpirvType::from(input.v_type.clone()));
- builder.function_parameter(Some(input.name), result_type)?;
+ for (name, typ) in func_decl.effective_input_arguments() {
+ let result_type = map.get_or_add(builder, typ);
+ builder.function_parameter(Some(name), result_type)?;
}
Ok(fn_id)
}
@@ -1207,55 +1147,32 @@ fn translate_directive<'input>(
d: ast::Directive<'input, ast::ParsedArgParams<'input>>,
) -> Result<Option<Directive<'input>>, TranslateError> {
Ok(match d {
- ast::Directive::Variable(v) => Some(Directive::Variable(translate_variable(id_defs, v)?)),
- ast::Directive::Method(f) => {
+ ast::Directive::Variable(linking, var) => Some(Directive::Variable(
+ linking,
+ ast::Variable {
+ align: var.align,
+ v_type: var.v_type.clone(),
+ state_space: var.state_space,
+ name: id_defs.get_or_add_def_typed(var.name, var.v_type, var.state_space, true),
+ array_init: var.array_init,
+ },
+ )),
+ ast::Directive::Method(_, f) => {
translate_function(id_defs, ptx_impl_imports, f)?.map(Directive::Method)
}
})
}
-fn translate_variable<'a>(
- id_defs: &mut GlobalStringIdResolver<'a>,
- var: ast::Variable<ast::VariableType, &'a str>,
-) -> Result<ast::Variable<ast::VariableType, spirv::Word>, TranslateError> {
- let (space, var_type) = var.v_type.to_type();
- let mut is_variable = false;
- let var_type = match space {
- ast::StateSpace::Reg => {
- is_variable = true;
- var_type
- }
- ast::StateSpace::Const => var_type.param_pointer_to(ast::LdStateSpace::Const)?,
- ast::StateSpace::Global => var_type.param_pointer_to(ast::LdStateSpace::Global)?,
- ast::StateSpace::Local => var_type.param_pointer_to(ast::LdStateSpace::Local)?,
- ast::StateSpace::Shared => {
- // If it's a pointer it will be translated to a method parameter later
- if let ast::Type::Pointer(..) = var_type {
- is_variable = true;
- var_type
- } else {
- var_type.param_pointer_to(ast::LdStateSpace::Shared)?
- }
- }
- ast::StateSpace::Param => var_type.param_pointer_to(ast::LdStateSpace::Param)?,
- };
- Ok(ast::Variable {
- align: var.align,
- v_type: var.v_type,
- name: id_defs.get_or_add_def_typed(var.name, var_type, is_variable),
- array_init: var.array_init,
- })
-}
-
fn translate_function<'a>(
id_defs: &mut GlobalStringIdResolver<'a>,
ptx_impl_imports: &mut HashMap<String, Directive<'a>>,
f: ast::ParsedFunction<'a>,
) -> Result<Option<Function<'a>>, TranslateError> {
let import_as = match &f.func_directive {
- ast::MethodDecl::Func(_, "__assertfail", _) => {
- Some("__zluda_ptx_impl____assertfail".to_owned())
- }
+ ast::MethodDeclaration {
+ name: ast::MethodName::Func("__assertfail"),
+ ..
+ } => Some("__zluda_ptx_impl____assertfail".to_owned()),
_ => None,
};
let (str_resolver, fn_resolver, fn_decl) = id_defs.start_fn(&f.func_directive)?;
@@ -1279,63 +1196,38 @@ fn translate_function<'a>(
}
}
-fn expand_kernel_params<'a, 'b>(
+fn rename_fn_params<'a, 'b>(
fn_resolver: &mut FnStringIdResolver<'a, 'b>,
- args: impl Iterator<Item = &'b ast::KernelArgument<&'a str>>,
-) -> Result<Vec<ast::KernelArgument<spirv::Word>>, TranslateError> {
- args.map(|a| {
- Ok(ast::KernelArgument {
- name: fn_resolver.add_def(
- a.name,
- Some(ast::Type::from(a.v_type.clone()).param_pointer_to(ast::LdStateSpace::Param)?),
- false,
- ),
+ args: &'b [ast::Variable<&'a str>],
+) -> Vec<ast::Variable<spirv::Word>> {
+ args.iter()
+ .map(|a| ast::Variable {
+ name: fn_resolver.add_def(a.name, Some((a.v_type.clone(), a.state_space)), true),
v_type: a.v_type.clone(),
+ state_space: a.state_space,
align: a.align,
- array_init: Vec::new(),
+ array_init: a.array_init.clone(),
})
- })
- .collect::<Result<_, _>>()
-}
-
-fn expand_fn_params<'a, 'b>(
- fn_resolver: &mut FnStringIdResolver<'a, 'b>,
- args: impl Iterator<Item = &'b ast::FnArgument<&'a str>>,
-) -> Result<Vec<ast::FnArgument<spirv::Word>>, TranslateError> {
- args.map(|a| {
- let is_variable = match a.v_type {
- ast::FnArgumentType::Reg(_) => true,
- _ => false,
- };
- let var_type = a.v_type.to_func_type();
- Ok(ast::FnArgument {
- name: fn_resolver.add_def(a.name, Some(var_type), is_variable),
- v_type: a.v_type.clone(),
- align: a.align,
- array_init: Vec::new(),
- })
- })
- .collect()
+ .collect()
}
fn to_ssa<'input, 'b>(
ptx_impl_imports: &mut HashMap<String, Directive>,
mut id_defs: FnStringIdResolver<'input, 'b>,
fn_defs: GlobalFnDeclResolver<'input, 'b>,
- f_args: ast::MethodDecl<'input, spirv::Word>,
+ func_decl: Rc<RefCell<ast::MethodDeclaration<'input, spirv::Word>>>,
f_body: Option<Vec<ast::Statement<ast::ParsedArgParams<'input>>>>,
tuning: Vec<ast::TuningDirective>,
) -> Result<Function<'input>, TranslateError> {
- let mut spirv_decl = SpirvMethodDecl::new(&f_args);
+ //deparamize_function_decl(&func_decl)?;
let f_body = match f_body {
Some(vec) => vec,
None => {
return Ok(Function {
- func_decl: f_args,
+ func_decl: func_decl,
body: None,
globals: Vec::new(),
import_as: None,
- spirv_decl,
tuning,
})
}
@@ -1345,15 +1237,14 @@ fn to_ssa<'input, 'b>(
let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs)?;
let typed_statements =
convert_to_typed_statements(unadorned_statements, &fn_defs, &mut numeric_id_defs)?;
- let typed_statements =
- convert_to_stateful_memory_access(&mut spirv_decl, typed_statements, &mut numeric_id_defs)?;
+ let (func_decl, typed_statements) =
+ convert_to_stateful_memory_access(func_decl, typed_statements, &mut numeric_id_defs)?;
let ssa_statements = insert_mem_ssa_statements(
typed_statements,
&mut numeric_id_defs,
- &f_args,
- &mut spirv_decl,
+ &mut (*func_decl).borrow_mut(),
)?;
- let ssa_statements = fix_builtins(ssa_statements, &mut numeric_id_defs)?;
+ let ssa_statements = fix_special_registers(ssa_statements, &mut numeric_id_defs)?;
let mut numeric_id_defs = numeric_id_defs.finish();
let expanded_statements = expand_arguments(ssa_statements, &mut numeric_id_defs)?;
let expanded_statements =
@@ -1363,16 +1254,15 @@ fn to_ssa<'input, 'b>(
let (f_body, globals) =
extract_globals(labeled_statements, ptx_impl_imports, &mut numeric_id_defs);
Ok(Function {
- func_decl: f_args,
+ func_decl: func_decl,
globals: globals,
body: Some(f_body),
import_as: None,
- spirv_decl,
tuning,
})
}
-fn fix_builtins(
+fn fix_special_registers(
typed_statements: Vec<TypedStatement>,
numeric_id_defs: &mut NumericIdResolver,
) -> Result<Vec<TypedStatement>, TranslateError> {
@@ -1408,7 +1298,8 @@ fn fix_builtins(
continue;
}
};
- let temp_id = numeric_id_defs.new_non_variable(Some(details.typ.clone()));
+ let temp_id = numeric_id_defs
+ .register_intermediate(Some((details.typ.clone(), details.state_space)));
let real_dst = details.arg.dst;
details.arg.dst = temp_id;
result.push(Statement::LoadVar(LoadVarDetails {
@@ -1416,17 +1307,18 @@ fn fix_builtins(
src: sreg_src,
dst: temp_id,
},
+ state_space: ast::StateSpace::Sreg,
typ: ast::Type::Scalar(scalar_typ),
member_index: Some((index, Some(vector_width))),
}));
result.push(Statement::Conversion(ImplicitConversion {
src: temp_id,
dst: real_dst,
- from: ast::Type::Scalar(scalar_typ),
- to: ast::Type::Scalar(ast::ScalarType::U32),
+ from_type: ast::Type::Scalar(scalar_typ),
+ from_space: ast::StateSpace::Sreg,
+ to_type: ast::Type::Scalar(ast::ScalarType::U32),
+ to_space: ast::StateSpace::Sreg,
kind: ConversionKind::Default,
- src_sema: ArgumentSemantics::Default,
- dst_sema: ArgumentSemantics::Default,
}));
}
}
@@ -1456,10 +1348,7 @@ fn extract_globals<'input, 'b>(
sorted_statements: Vec<ExpandedStatement>,
ptx_impl_imports: &mut HashMap<String, Directive>,
id_def: &mut NumericIdResolver,
-) -> (
- Vec<ExpandedStatement>,
- Vec<ast::Variable<ast::VariableType, spirv::Word>>,
-) {
+) -> (Vec<ExpandedStatement>, Vec<ast::Variable<spirv::Word>>) {
let mut local = Vec::with_capacity(sorted_statements.len());
let mut global = Vec::new();
for statement in sorted_statements {
@@ -1468,7 +1357,7 @@ fn extract_globals<'input, 'b>(
var
@
ast::Variable {
- v_type: ast::VariableType::Shared(_),
+ state_space: ast::StateSpace::Shared,
..
},
)
@@ -1476,7 +1365,7 @@ fn extract_globals<'input, 'b>(
var
@
ast::Variable {
- v_type: ast::VariableType::Global(_),
+ state_space: ast::StateSpace::Global,
..
},
) => global.push(var),
@@ -1505,7 +1394,7 @@ fn extract_globals<'input, 'b>(
d,
a,
"inc",
- ast::SizedScalarType::U32,
+ ast::ScalarType::U32,
));
}
Statement::Instruction(ast::Instruction::Atom(
@@ -1527,7 +1416,7 @@ fn extract_globals<'input, 'b>(
d,
a,
"dec",
- ast::SizedScalarType::U32,
+ ast::ScalarType::U32,
));
}
Statement::Instruction(ast::Instruction::Atom(
@@ -1553,10 +1442,9 @@ fn extract_globals<'input, 'b>(
space,
};
let (op, typ) = match typ {
- ast::FloatType::F32 => ("add_f32", ast::SizedScalarType::F32),
- ast::FloatType::F64 => ("add_f64", ast::SizedScalarType::F64),
- ast::FloatType::F16 => unreachable!(),
- ast::FloatType::F16x2 => unreachable!(),
+ ast::ScalarType::F32 => ("add_f32", ast::ScalarType::F32),
+ ast::ScalarType::F64 => ("add_f64", ast::ScalarType::F64),
+ _ => unreachable!(),
};
local.push(to_ptx_impl_atomic_call(
id_def,
@@ -1599,47 +1487,13 @@ fn convert_to_typed_statements(
match s {
Statement::Instruction(inst) => match inst {
ast::Instruction::Call(call) => {
- // TODO: error out if lengths don't match
- let fn_def = fn_defs.get_fn_decl(call.func)?;
- let out_args = to_resolved_fn_args(call.ret_params, &*fn_def.ret_vals);
- let in_args = to_resolved_fn_args(call.param_list, &*fn_def.params);
- let (out_params, out_non_params): (Vec<_>, Vec<_>) = out_args
- .into_iter()
- .partition(|(_, arg_type)| arg_type.is_param());
- let normalized_input_args = out_params
- .into_iter()
- .map(|(id, typ)| (ast::Operand::Reg(id), typ))
- .chain(in_args.into_iter())
- .collect();
- let resolved_call = ResolvedCall {
- uniform: call.uniform,
- ret_params: out_non_params,
- func: call.func,
- param_list: normalized_input_args,
- };
+ let resolver = fn_defs.get_fn_sig_resolver(call.func)?;
+ let resolved_call = resolver.resolve_in_spirv_repr(call)?;
let mut visitor = VectorRepackVisitor::new(&mut result, id_defs);
let reresolved_call = resolved_call.visit(&mut visitor)?;
visitor.func.push(reresolved_call);
visitor.func.extend(visitor.post_stmts);
}
- ast::Instruction::Mov(mut d, ast::Arg2Mov { dst, src }) => {
- if let Some(src_id) = src.underlying() {
- let (typ, _) = id_defs.get_typed(*src_id)?;
- let take_address = match typ {
- ast::Type::Scalar(_) => false,
- ast::Type::Vector(_, _) => false,
- ast::Type::Array(_, _) => true,
- ast::Type::Pointer(_, _) => true,
- };
- d.src_is_address = take_address;
- }
- let mut visitor = VectorRepackVisitor::new(&mut result, id_defs);
- let instruction = Statement::Instruction(
- ast::Instruction::Mov(d, ast::Arg2Mov { dst, src }).map(&mut visitor)?,
- );
- visitor.func.push(instruction);
- visitor.func.extend(visitor.post_stmts);
- }
inst => {
let mut visitor = VectorRepackVisitor::new(&mut result, id_defs);
let instruction = Statement::Instruction(inst.map(&mut visitor)?);
@@ -1674,8 +1528,14 @@ impl<'a, 'b> VectorRepackVisitor<'a, 'b> {
fn convert_vector(
&mut self,
is_dst: bool,
- vector_sema: ArgumentSemantics,
+ non_default_implicit_conversion: Option<
+ fn(
+ (ast::StateSpace, &ast::Type),
+ (ast::StateSpace, &ast::Type),
+ ) -> Result<Option<ConversionKind>, TranslateError>,
+ >,
typ: &ast::Type,
+ state_space: ast::StateSpace,
idx: Vec<spirv::Word>,
) -> Result<spirv::Word, TranslateError> {
// mov.u32 foobar, {a,b};
@@ -1683,13 +1543,15 @@ impl<'a, 'b> VectorRepackVisitor<'a, 'b> {
ast::Type::Vector(scalar_t, _) => *scalar_t,
_ => return Err(TranslateError::MismatchedType),
};
- let temp_vec = self.id_def.new_non_variable(Some(typ.clone()));
+ let temp_vec = self
+ .id_def
+ .register_intermediate(Some((typ.clone(), state_space)));
let statement = Statement::RepackVector(RepackVectorDetails {
is_extract: is_dst,
typ: scalar_t,
packed: temp_vec,
unpacked: idx,
- vector_sema,
+ non_default_implicit_conversion,
});
if is_dst {
self.post_stmts = Some(statement);
@@ -1706,7 +1568,7 @@ impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, TypedArgParams>
fn id(
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
- _: Option<&ast::Type>,
+ _: Option<(&ast::Type, ast::StateSpace)>,
) -> Result<spirv::Word, TranslateError> {
Ok(desc.op)
}
@@ -1715,15 +1577,20 @@ impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, TypedArgParams>
&mut self,
desc: ArgumentDescriptor<ast::Operand<spirv::Word>>,
typ: &ast::Type,
+ state_space: ast::StateSpace,
) -> Result<TypedOperand, TranslateError> {
Ok(match desc.op {
ast::Operand::Reg(reg) => TypedOperand::Reg(reg),
ast::Operand::RegOffset(reg, offset) => TypedOperand::RegOffset(reg, offset),
ast::Operand::Imm(x) => TypedOperand::Imm(x),
ast::Operand::VecMember(vec, idx) => TypedOperand::VecMember(vec, idx),
- ast::Operand::VecPack(vec) => {
- TypedOperand::Reg(self.convert_vector(desc.is_dst, desc.sema, typ, vec)?)
- }
+ ast::Operand::VecPack(vec) => TypedOperand::Reg(self.convert_vector(
+ desc.is_dst,
+ desc.non_default_implicit_conversion,
+ typ,
+ state_space,
+ vec,
+ )?),
})
}
}
@@ -1735,7 +1602,7 @@ fn to_ptx_impl_atomic_call(
details: ast::AtomDetails,
arg: ast::Arg3<ExpandedArgParams>,
op: &'static str,
- typ: ast::SizedScalarType,
+ typ: ast::ScalarType,
) -> ExpandedStatement {
let semantics = ptx_semantics_name(details.semantics);
let scope = ptx_scope_name(details.scope);
@@ -1745,75 +1612,70 @@ fn to_ptx_impl_atomic_call(
semantics, scope, space, op
);
// TODO: extract to a function
- let ptr_space = match details.space {
- ast::AtomSpace::Generic => ast::PointerStateSpace::Generic,
- ast::AtomSpace::Global => ast::PointerStateSpace::Global,
- ast::AtomSpace::Shared => ast::PointerStateSpace::Shared,
- };
+ let ptr_space = details.space;
let scalar_typ = ast::ScalarType::from(typ);
let fn_id = match ptx_impl_imports.entry(fn_name) {
hash_map::Entry::Vacant(entry) => {
- let fn_id = id_defs.new_non_variable(None);
- let func_decl = ast::MethodDecl::Func::<spirv::Word>(
- vec![ast::FnArgument {
+ let fn_id = id_defs.register_intermediate(None);
+ let func_decl = ast::MethodDeclaration::<spirv::Word> {
+ return_arguments: vec![ast::Variable {
align: None,
- v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(scalar_typ)),
- name: id_defs.new_non_variable(None),
+ v_type: ast::Type::Scalar(scalar_typ),
+ state_space: ast::StateSpace::Reg,
+ name: id_defs.register_intermediate(None),
array_init: Vec::new(),
}],
- fn_id,
- vec![
- ast::FnArgument {
+ name: ast::MethodName::Func(fn_id),
+ input_arguments: vec![
+ ast::Variable {
align: None,
- v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Pointer(
- typ, ptr_space,
- )),
- name: id_defs.new_non_variable(None),
+ v_type: ast::Type::Pointer(typ, ptr_space),
+ state_space: ast::StateSpace::Reg,
+ name: id_defs.register_intermediate(None),
array_init: Vec::new(),
},
- ast::FnArgument {
+ ast::Variable {
align: None,
- v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(scalar_typ)),
- name: id_defs.new_non_variable(None),
+ v_type: ast::Type::Scalar(scalar_typ),
+ state_space: ast::StateSpace::Reg,
+ name: id_defs.register_intermediate(None),
array_init: Vec::new(),
},
],
- );
- let spirv_decl = SpirvMethodDecl::new(&func_decl);
+ shared_mem: None,
+ };
let func = Function {
- func_decl,
+ func_decl: Rc::new(RefCell::new(func_decl)),
globals: Vec::new(),
body: None,
import_as: Some(entry.key().clone()),
- spirv_decl,
tuning: Vec::new(),
};
entry.insert(Directive::Method(func));
fn_id
}
hash_map::Entry::Occupied(entry) => match entry.get() {
- Directive::Method(Function {
- func_decl: ast::MethodDecl::Func(_, name, _),
- ..
- }) => *name,
+ Directive::Method(Function { func_decl, .. }) => match (**func_decl).borrow().name {
+ ast::MethodName::Func(fn_id) => fn_id,
+ ast::MethodName::Kernel(_) => unreachable!(),
+ },
_ => unreachable!(),
},
};
Statement::Call(ResolvedCall {
uniform: false,
- func: fn_id,
- ret_params: vec![(
- arg.dst,
- ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(scalar_typ)),
- )],
- param_list: vec![
+ name: fn_id,
+ return_arguments: vec![(arg.dst, ast::Type::Scalar(scalar_typ), ast::StateSpace::Reg)],
+ input_arguments: vec![
(
arg.src1,
- ast::FnArgumentType::Reg(ast::VariableRegType::Pointer(typ, ptr_space)),
+ ast::Type::Pointer(typ, ptr_space),
+ ast::StateSpace::Reg,
),
(
arg.src2,
- ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(scalar_typ)),
+ ast::Type::Scalar(scalar_typ),
+ ast::StateSpace::Reg,
),
],
})
@@ -1822,93 +1684,92 @@ fn to_ptx_impl_atomic_call(
fn to_ptx_impl_bfe_call(
id_defs: &mut NumericIdResolver,
ptx_impl_imports: &mut HashMap<String, Directive>,
- typ: ast::IntType,
+ typ: ast::ScalarType,
arg: ast::Arg4<ExpandedArgParams>,
) -> ExpandedStatement {
let prefix = "__zluda_ptx_impl__";
let suffix = match typ {
- ast::IntType::U32 => "bfe_u32",
- ast::IntType::U64 => "bfe_u64",
- ast::IntType::S32 => "bfe_s32",
- ast::IntType::S64 => "bfe_s64",
+ ast::ScalarType::U32 => "bfe_u32",
+ ast::ScalarType::U64 => "bfe_u64",
+ ast::ScalarType::S32 => "bfe_s32",
+ ast::ScalarType::S64 => "bfe_s64",
_ => unreachable!(),
};
let fn_name = format!("{}{}", prefix, suffix);
let fn_id = match ptx_impl_imports.entry(fn_name) {
hash_map::Entry::Vacant(entry) => {
- let fn_id = id_defs.new_non_variable(None);
- let func_decl = ast::MethodDecl::Func::<spirv::Word>(
- vec![ast::FnArgument {
+ let fn_id = id_defs.register_intermediate(None);
+ let func_decl = ast::MethodDeclaration::<spirv::Word> {
+ return_arguments: vec![ast::Variable {
align: None,
- v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
- name: id_defs.new_non_variable(None),
+ v_type: ast::Type::Scalar(typ.into()),
+ state_space: ast::StateSpace::Reg,
+ name: id_defs.register_intermediate(None),
array_init: Vec::new(),
}],
- fn_id,
- vec![
- ast::FnArgument {
+ name: ast::MethodName::Func(fn_id),
+ input_arguments: vec![
+ ast::Variable {
align: None,
- v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
- name: id_defs.new_non_variable(None),
+ v_type: ast::Type::Scalar(typ.into()),
+ state_space: ast::StateSpace::Reg,
+ name: id_defs.register_intermediate(None),
array_init: Vec::new(),
},
- ast::FnArgument {
+ ast::Variable {
align: None,
- v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(
- ast::ScalarType::U32,
- )),
- name: id_defs.new_non_variable(None),
+ v_type: ast::Type::Scalar(ast::ScalarType::U32),
+ state_space: ast::StateSpace::Reg,
+ name: id_defs.register_intermediate(None),
array_init: Vec::new(),
},
- ast::FnArgument {
+ ast::Variable {
align: None,
- v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(
- ast::ScalarType::U32,
- )),
- name: id_defs.new_non_variable(None),
+ v_type: ast::Type::Scalar(ast::ScalarType::U32),
+ state_space: ast::StateSpace::Reg,
+ name: id_defs.register_intermediate(None),
array_init: Vec::new(),
},
],
- );
- let spirv_decl = SpirvMethodDecl::new(&func_decl);
+ shared_mem: None,
+ };
let func = Function {
- func_decl,
+ func_decl: Rc::new(RefCell::new(func_decl)),
globals: Vec::new(),
body: None,
import_as: Some(entry.key().clone()),
- spirv_decl,
tuning: Vec::new(),
};
entry.insert(Directive::Method(func));
fn_id
}
hash_map::Entry::Occupied(entry) => match entry.get() {
- Directive::Method(Function {
- func_decl: ast::MethodDecl::Func(_, name, _),
- ..
- }) => *name,
+ Directive::Method(Function { func_decl, .. }) => match (**func_decl).borrow().name {
+ ast::MethodName::Func(fn_id) => fn_id,
+ ast::MethodName::Kernel(_) => unreachable!(),
+ },
_ => unreachable!(),
},
};
Statement::Call(ResolvedCall {
uniform: false,
- func: fn_id,
- ret_params: vec![(
- arg.dst,
- ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
- )],
- param_list: vec![
+ name: fn_id,
+ return_arguments: vec![(arg.dst, ast::Type::Scalar(typ.into()), ast::StateSpace::Reg)],
+ input_arguments: vec![
(
arg.src1,
- ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
+ ast::Type::Scalar(typ.into()),
+ ast::StateSpace::Reg,
),
(
arg.src2,
- ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(ast::ScalarType::U32)),
+ ast::Type::Scalar(ast::ScalarType::U32),
+ ast::StateSpace::Reg,
),
(
arg.src3,
- ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(ast::ScalarType::U32)),
+ ast::Type::Scalar(ast::ScalarType::U32),
+ ast::StateSpace::Reg,
),
],
})
@@ -1917,117 +1778,107 @@ fn to_ptx_impl_bfe_call(
fn to_ptx_impl_bfi_call(
id_defs: &mut NumericIdResolver,
ptx_impl_imports: &mut HashMap<String, Directive>,
- typ: ast::BitType,
+ typ: ast::ScalarType,
arg: ast::Arg5<ExpandedArgParams>,
) -> ExpandedStatement {
let prefix = "__zluda_ptx_impl__";
let suffix = match typ {
- ast::BitType::B32 => "bfi_b32",
- ast::BitType::B64 => "bfi_b64",
- ast::BitType::B8 | ast::BitType::B16 => unreachable!(),
+ ast::ScalarType::B32 => "bfi_b32",
+ ast::ScalarType::B64 => "bfi_b64",
+ _ => unreachable!(),
};
let fn_name = format!("{}{}", prefix, suffix);
let fn_id = match ptx_impl_imports.entry(fn_name) {
hash_map::Entry::Vacant(entry) => {
- let fn_id = id_defs.new_non_variable(None);
- let func_decl = ast::MethodDecl::Func::<spirv::Word>(
- vec![ast::FnArgument {
+ let fn_id = id_defs.register_intermediate(None);
+ let func_decl = ast::MethodDeclaration::<spirv::Word> {
+ return_arguments: vec![ast::Variable {
align: None,
- v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
- name: id_defs.new_non_variable(None),
+ v_type: ast::Type::Scalar(typ.into()),
+ state_space: ast::StateSpace::Reg,
+ name: id_defs.register_intermediate(None),
array_init: Vec::new(),
}],
- fn_id,
- vec![
- ast::FnArgument {
+ name: ast::MethodName::Func(fn_id),
+ input_arguments: vec![
+ ast::Variable {
align: None,
- v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
- name: id_defs.new_non_variable(None),
+ v_type: ast::Type::Scalar(typ.into()),
+ state_space: ast::StateSpace::Reg,
+ name: id_defs.register_intermediate(None),
array_init: Vec::new(),
},
- ast::FnArgument {
+ ast::Variable {
align: None,
- v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
- name: id_defs.new_non_variable(None),
+ v_type: ast::Type::Scalar(typ.into()),
+ state_space: ast::StateSpace::Reg,
+ name: id_defs.register_intermediate(None),
array_init: Vec::new(),
},
- ast::FnArgument {
+ ast::Variable {
align: None,
- v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(
- ast::ScalarType::U32,
- )),
- name: id_defs.new_non_variable(None),
+ v_type: ast::Type::Scalar(ast::ScalarType::U32),
+ state_space: ast::StateSpace::Reg,
+ name: id_defs.register_intermediate(None),
array_init: Vec::new(),
},
- ast::FnArgument {
+ ast::Variable {
align: None,
- v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(
- ast::ScalarType::U32,
- )),
- name: id_defs.new_non_variable(None),
+ v_type: ast::Type::Scalar(ast::ScalarType::U32),
+ state_space: ast::StateSpace::Reg,
+ name: id_defs.register_intermediate(None),
array_init: Vec::new(),
},
],
- );
- let spirv_decl = SpirvMethodDecl::new(&func_decl);
+ shared_mem: None,
+ };
let func = Function {
- func_decl,
+ func_decl: Rc::new(RefCell::new(func_decl)),
globals: Vec::new(),
body: None,
import_as: Some(entry.key().clone()),
- spirv_decl,
tuning: Vec::new(),
};
entry.insert(Directive::Method(func));
fn_id
}
hash_map::Entry::Occupied(entry) => match entry.get() {
- Directive::Method(Function {
- func_decl: ast::MethodDecl::Func(_, name, _),
- ..
- }) => *name,
+ Directive::Method(Function { func_decl, .. }) => match (**func_decl).borrow().name {
+ ast::MethodName::Func(fn_id) => fn_id,
+ ast::MethodName::Kernel(_) => unreachable!(),
+ },
_ => unreachable!(),
},
};
Statement::Call(ResolvedCall {
uniform: false,
- func: fn_id,
- ret_params: vec![(
- arg.dst,
- ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
- )],
- param_list: vec![
+ name: fn_id,
+ return_arguments: vec![(arg.dst, ast::Type::Scalar(typ.into()), ast::StateSpace::Reg)],
+ input_arguments: vec![
(
arg.src1,
- ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
+ ast::Type::Scalar(typ.into()),
+ ast::StateSpace::Reg,
),
(
arg.src2,
- ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
+ ast::Type::Scalar(typ.into()),
+ ast::StateSpace::Reg,
),
(
arg.src3,
- ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(ast::ScalarType::U32)),
+ ast::Type::Scalar(ast::ScalarType::U32),
+ ast::StateSpace::Reg,
),
(
arg.src4,
- ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(ast::ScalarType::U32)),
+ ast::Type::Scalar(ast::ScalarType::U32),
+ ast::StateSpace::Reg,
),
],
})
}
-fn to_resolved_fn_args<T>(
- params: Vec<T>,
- params_decl: &[ast::FnArgumentType],
-) -> Vec<(T, ast::FnArgumentType)> {
- params
- .into_iter()
- .zip(params_decl.iter())
- .map(|(id, typ)| (id, typ.clone()))
- .collect::<Vec<_>>()
-}
-
fn normalize_labels(
func: Vec<ExpandedStatement>,
id_def: &mut NumericIdResolver,
@@ -2056,7 +1907,7 @@ fn normalize_labels(
| Statement::RepackVector(..) => {}
}
}
- iter::once(Statement::Label(id_def.new_non_variable(None)))
+ iter::once(Statement::Label(id_def.register_intermediate(None)))
.chain(func.into_iter().filter(|s| match s {
Statement::Label(i) => labels_in_use.contains(i),
_ => true,
@@ -2074,8 +1925,8 @@ fn normalize_predicates(
Statement::Label(id) => result.push(Statement::Label(id)),
Statement::Instruction((pred, inst)) => {
if let Some(pred) = pred {
- let if_true = id_def.new_non_variable(None);
- let if_false = id_def.new_non_variable(None);
+ let if_true = id_def.register_intermediate(None);
+ let if_false = id_def.register_intermediate(None);
let folded_bra = match &inst {
ast::Instruction::Bra(_, arg) => Some(arg.src),
_ => None,
@@ -2106,53 +1957,52 @@ fn normalize_predicates(
Ok(result)
}
+/*
+ How do we handle arguments:
+ - input .params in kernels
+ .param .b64 in_arg
+ get turned into this SPIR-V:
+ %1 = OpFunctionParameter %ulong
+ %2 = OpVariable %_ptr_Function_ulong Function
+ OpStore %2 %1
+ We do this for two reasons. One, common treatment for argument-declared
+ .param variables and .param variables inside function (we assume that
+ at SPIR-V level every .param is a pointer in Function storage class)
+ - input .params in functions
+ .param .b64 in_arg
+ get turned into this SPIR-V:
+ %1 = OpFunctionParameter %_ptr_Function_ulong
+ - input .regs
+ .reg .b64 in_arg
+ get turned into the same SPIR-V as kernel .params:
+ %1 = OpFunctionParameter %ulong
+ %2 = OpVariable %_ptr_Function_ulong Function
+ OpStore %2 %1
+ - output .regs
+ .reg .b64 out_arg
+ get just a variable declaration:
+ %2 = OpVariable %%_ptr_Function_ulong Function
+ - output .params don't exist, they have been moved to input positions
+ by an earlier pass
+ Distinguishing betweem kernel .params and function .params is not the
+ cleanest solution. Alternatively, we could "deparamize" all kernel .param
+ arguments by turning them into .reg arguments like this:
+ .param .b64 arg -> .reg ptr<.b64,.param> arg
+ This has the massive downside that this transformation would have to run
+ very early and would muddy up already difficult code. It's simpler to just
+ have an if here
+*/
fn insert_mem_ssa_statements<'a, 'b>(
func: Vec<TypedStatement>,
id_def: &mut NumericIdResolver,
- ast_fn_decl: &'a ast::MethodDecl<'b, spirv::Word>,
- fn_decl: &mut SpirvMethodDecl,
+ fn_decl: &'a mut ast::MethodDeclaration<'b, spirv::Word>,
) -> Result<Vec<TypedStatement>, TranslateError> {
- let is_func = match ast_fn_decl {
- ast::MethodDecl::Func(..) => true,
- ast::MethodDecl::Kernel { .. } => false,
- };
let mut result = Vec::with_capacity(func.len());
- for arg in fn_decl.output.iter() {
- match type_to_variable_type(&arg.v_type, is_func)? {
- Some(var_type) => {
- result.push(Statement::Variable(ast::Variable {
- align: arg.align,
- v_type: var_type,
- name: arg.name,
- array_init: arg.array_init.clone(),
- }));
- }
- None => return Err(error_unreachable()),
- }
+ for arg in fn_decl.input_arguments.iter_mut() {
+ insert_mem_ssa_argument(id_def, &mut result, arg, fn_decl.name.is_kernel());
}
- for spirv_arg in fn_decl.input.iter_mut() {
- match type_to_variable_type(&spirv_arg.v_type, is_func)? {
- Some(var_type) => {
- let typ = spirv_arg.v_type.clone();
- let new_id = id_def.new_non_variable(Some(typ.clone()));
- result.push(Statement::Variable(ast::Variable {
- align: spirv_arg.align,
- v_type: var_type,
- name: spirv_arg.name,
- array_init: spirv_arg.array_init.clone(),
- }));
- result.push(Statement::StoreVar(StoreVarDetails {
- arg: ast::Arg2St {
- src1: spirv_arg.name,
- src2: new_id,
- },
- typ,
- member_index: None,
- }));
- spirv_arg.name = new_id;
- }
- None => {}
- }
+ for arg in fn_decl.return_arguments.iter() {
+ insert_mem_ssa_argument_reg_return(&mut result, arg);
}
for s in func {
match s {
@@ -2162,32 +2012,41 @@ fn insert_mem_ssa_statements<'a, 'b>(
Statement::Instruction(inst) => match inst {
ast::Instruction::Ret(d) => {
// TODO: handle multiple output args
- if let &[out_param] = &fn_decl.output.as_slice() {
- let (typ, _) = id_def.get_typed(out_param.name)?;
- let new_id = id_def.new_non_variable(Some(typ.clone()));
- result.push(Statement::LoadVar(LoadVarDetails {
- arg: ast::Arg2 {
- dst: new_id,
- src: out_param.name,
- },
- typ: typ.clone(),
- member_index: None,
- }));
- result.push(Statement::RetValue(d, new_id));
- } else {
- result.push(Statement::Instruction(ast::Instruction::Ret(d)))
+ match &fn_decl.return_arguments[..] {
+ [return_reg] => {
+ let new_id = id_def.register_intermediate(Some((
+ return_reg.v_type.clone(),
+ ast::StateSpace::Reg,
+ )));
+ result.push(Statement::LoadVar(LoadVarDetails {
+ arg: ast::Arg2 {
+ dst: new_id,
+ src: return_reg.name,
+ },
+ // TODO: ret with stateful conversion
+ state_space: ast::StateSpace::Reg,
+ typ: return_reg.v_type.clone(),
+ member_index: None,
+ }));
+ result.push(Statement::RetValue(d, new_id));
+ }
+ [] => result.push(Statement::Instruction(ast::Instruction::Ret(d))),
+ _ => unimplemented!(),
}
}
inst => insert_mem_ssa_statement_default(id_def, &mut result, inst)?,
},
Statement::Conditional(mut bra) => {
- let generated_id =
- id_def.new_non_variable(Some(ast::Type::Scalar(ast::ScalarType::Pred)));
+ let generated_id = id_def.register_intermediate(Some((
+ ast::Type::Scalar(ast::ScalarType::Pred),
+ ast::StateSpace::Reg,
+ )));
result.push(Statement::LoadVar(LoadVarDetails {
arg: Arg2 {
dst: generated_id,
src: bra.predicate,
},
+ state_space: ast::StateSpace::Reg,
typ: ast::Type::Scalar(ast::ScalarType::Pred),
member_index: None,
}));
@@ -2210,39 +2069,45 @@ fn insert_mem_ssa_statements<'a, 'b>(
Ok(result)
}
-fn type_to_variable_type(
- t: &ast::Type,
- is_func: bool,
-) -> Result<Option<ast::VariableType>, TranslateError> {
- Ok(match t {
- ast::Type::Scalar(typ) => Some(ast::VariableType::Reg(ast::VariableRegType::Scalar(*typ))),
- ast::Type::Vector(typ, len) => Some(ast::VariableType::Reg(ast::VariableRegType::Vector(
- (*typ)
- .try_into()
- .map_err(|_| TranslateError::MismatchedType)?,
- *len,
- ))),
- ast::Type::Array(typ, len) => Some(ast::VariableType::Reg(ast::VariableRegType::Array(
- (*typ)
- .try_into()
- .map_err(|_| TranslateError::MismatchedType)?,
- len.clone(),
- ))),
- ast::Type::Pointer(ast::PointerType::Scalar(scalar_type), space) => {
- if is_func {
- return Ok(None);
- }
- Some(ast::VariableType::Reg(ast::VariableRegType::Pointer(
- scalar_type
- .clone()
- .try_into()
- .map_err(|_| error_unreachable())?,
- (*space).try_into().map_err(|_| error_unreachable())?,
- )))
- }
- ast::Type::Pointer(_, ast::LdStateSpace::Shared) => None,
- _ => return Err(error_unreachable()),
- })
+fn insert_mem_ssa_argument(
+ id_def: &mut NumericIdResolver,
+ func: &mut Vec<TypedStatement>,
+ arg: &mut ast::Variable<spirv::Word>,
+ is_kernel: bool,
+) {
+ if !is_kernel && arg.state_space == ast::StateSpace::Param {
+ return;
+ }
+ let new_id = id_def.register_intermediate(Some((arg.v_type.clone(), arg.state_space)));
+ func.push(Statement::Variable(ast::Variable {
+ align: arg.align,
+ v_type: arg.v_type.clone(),
+ state_space: ast::StateSpace::Reg,
+ name: arg.name,
+ array_init: Vec::new(),
+ }));
+ func.push(Statement::StoreVar(StoreVarDetails {
+ arg: ast::Arg2St {
+ src1: arg.name,
+ src2: new_id,
+ },
+ typ: arg.v_type.clone(),
+ member_index: None,
+ }));
+ arg.name = new_id;
+}
+
+fn insert_mem_ssa_argument_reg_return(
+ func: &mut Vec<TypedStatement>,
+ arg: &ast::Variable<spirv::Word>,
+) {
+ func.push(Statement::Variable(ast::Variable {
+ align: arg.align,
+ v_type: arg.v_type.clone(),
+ state_space: arg.state_space,
+ name: arg.name,
+ array_init: arg.array_init.clone(),
+ }));
}
trait Visitable<From: ArgParamsEx, To: ArgParamsEx>: Sized {
@@ -2259,6 +2124,7 @@ struct VisitArgumentDescriptor<
> {
desc: ArgumentDescriptor<spirv::Word>,
typ: &'a ast::Type,
+ state_space: ast::StateSpace,
stmt_ctor: Ctor,
}
@@ -2273,7 +2139,9 @@ impl<
self,
visitor: &mut impl ArgumentMapVisitor<T, U>,
) -> Result<Statement<ast::Instruction<U>, U>, TranslateError> {
- Ok((self.stmt_ctor)(visitor.id(self.desc, Some(self.typ))?))
+ Ok((self.stmt_ctor)(
+ visitor.id(self.desc, Some((self.typ, self.state_space)))?,
+ ))
}
}
@@ -2287,14 +2155,14 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
fn symbol(
&mut self,
desc: ArgumentDescriptor<(spirv::Word, Option<u8>)>,
- expected_type: Option<&ast::Type>,
+ expected: Option<(&ast::Type, ast::StateSpace)>,
) -> Result<spirv::Word, TranslateError> {
let symbol = desc.op.0;
- if expected_type.is_none() {
+ if expected.is_none() {
return Ok(symbol);
};
- let (mut var_type, is_variable) = self.id_def.get_typed(symbol)?;
- if !is_variable {
+ let (mut var_type, var_space, is_variable) = self.id_def.get_typed(symbol)?;
+ if !var_space.is_compatible(ast::StateSpace::Reg) || !is_variable {
return Ok(symbol);
};
let member_index = match desc.op.1 {
@@ -2317,13 +2185,16 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
}
None => None,
};
- let generated_id = self.id_def.new_non_variable(Some(var_type.clone()));
+ let generated_id = self
+ .id_def
+ .register_intermediate(Some((var_type.clone(), ast::StateSpace::Reg)));
if !desc.is_dst {
self.func.push(Statement::LoadVar(LoadVarDetails {
arg: Arg2 {
dst: generated_id,
src: symbol,
},
+ state_space: ast::StateSpace::Reg,
typ: var_type,
member_index,
}));
@@ -2348,7 +2219,7 @@ impl<'a, 'input> ArgumentMapVisitor<TypedArgParams, TypedArgParams>
fn id(
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
- typ: Option<&ast::Type>,
+ typ: Option<(&ast::Type, ast::StateSpace)>,
) -> Result<spirv::Word, TranslateError> {
self.symbol(desc.new_op((desc.op, None)), typ)
}
@@ -2357,18 +2228,20 @@ impl<'a, 'input> ArgumentMapVisitor<TypedArgParams, TypedArgParams>
&mut self,
desc: ArgumentDescriptor<TypedOperand>,
typ: &ast::Type,
+ state_space: ast::StateSpace,
) -> Result<TypedOperand, TranslateError> {
Ok(match desc.op {
TypedOperand::Reg(reg) => {
- TypedOperand::Reg(self.symbol(desc.new_op((reg, None)), Some(typ))?)
- }
- TypedOperand::RegOffset(reg, offset) => {
- TypedOperand::RegOffset(self.symbol(desc.new_op((reg, None)), Some(typ))?, offset)
+ TypedOperand::Reg(self.symbol(desc.new_op((reg, None)), Some((typ, state_space)))?)
}
+ TypedOperand::RegOffset(reg, offset) => TypedOperand::RegOffset(
+ self.symbol(desc.new_op((reg, None)), Some((typ, state_space)))?,
+ offset,
+ ),
op @ TypedOperand::Imm(..) => op,
- TypedOperand::VecMember(symbol, index) => {
- TypedOperand::Reg(self.symbol(desc.new_op((symbol, Some(index))), Some(typ))?)
- }
+ TypedOperand::VecMember(symbol, index) => TypedOperand::Reg(
+ self.symbol(desc.new_op((symbol, Some(index))), Some((typ, state_space)))?,
+ ),
})
}
}
@@ -2411,11 +2284,13 @@ fn expand_arguments<'a, 'b>(
Statement::Variable(ast::Variable {
align,
v_type,
+ state_space,
name,
array_init,
}) => result.push(Statement::Variable(ast::Variable {
align,
v_type,
+ state_space,
name,
array_init,
})),
@@ -2464,7 +2339,7 @@ impl<'a, 'b> FlattenArguments<'a, 'b> {
fn reg(
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
- _: Option<&ast::Type>,
+ _: Option<(&ast::Type, ast::StateSpace)>,
) -> Result<spirv::Word, TranslateError> {
Ok(desc.op)
}
@@ -2473,108 +2348,86 @@ impl<'a, 'b> FlattenArguments<'a, 'b> {
&mut self,
desc: ArgumentDescriptor<(spirv::Word, i32)>,
typ: &ast::Type,
+ state_space: ast::StateSpace,
) -> Result<spirv::Word, TranslateError> {
let (reg, offset) = desc.op;
- let add_type;
- match typ {
- ast::Type::Pointer(underlying_type, state_space) => {
- let reg_typ = self.id_def.get_typed(reg)?;
- if let ast::Type::Pointer(_, _) = reg_typ {
- let id_constant_stmt = self.id_def.new_non_variable(typ.clone());
- self.func.push(Statement::Constant(ConstantDefinition {
- dst: id_constant_stmt,
- typ: ast::ScalarType::S64,
- value: ast::ImmediateValue::S64(offset as i64),
- }));
- let dst = self.id_def.new_non_variable(typ.clone());
- self.func.push(Statement::PtrAccess(PtrAccess {
- underlying_type: underlying_type.clone(),
- state_space: *state_space,
- dst,
- ptr_src: reg,
- offset_src: id_constant_stmt,
- }));
- return Ok(dst);
- } else {
- add_type = self.id_def.get_typed(reg)?;
- }
- }
- _ => {
- add_type = typ.clone();
+ if !desc.is_memory_access {
+ let (reg_type, reg_space) = self.id_def.get_typed(reg)?;
+ if !reg_space.is_compatible(ast::StateSpace::Reg) {
+ return Err(TranslateError::MismatchedType);
}
- };
- let (width, kind) = match add_type {
- ast::Type::Scalar(scalar_t) => {
- let kind = match scalar_t.kind() {
- kind @ ScalarKind::Bit
- | kind @ ScalarKind::Unsigned
- | kind @ ScalarKind::Signed => kind,
- ScalarKind::Float => return Err(TranslateError::MismatchedType),
- ScalarKind::Float2 => return Err(TranslateError::MismatchedType),
- ScalarKind::Pred => return Err(TranslateError::MismatchedType),
- };
- (scalar_t.size_of(), kind)
- }
- _ => return Err(TranslateError::MismatchedType),
- };
- let arith_detail = if kind == ScalarKind::Signed {
- ast::ArithDetails::Signed(ast::ArithSInt {
- typ: ast::SIntType::from_size(width),
- saturate: false,
- })
- } else {
- ast::ArithDetails::Unsigned(ast::UIntType::from_size(width))
- };
- let id_constant_stmt = self.id_def.new_non_variable(add_type.clone());
- let result_id = self.id_def.new_non_variable(add_type);
- // TODO: check for edge cases around min value/max value/wrapping
- if offset < 0 && kind != ScalarKind::Signed {
+ let reg_scalar_type = match reg_type {
+ ast::Type::Scalar(underlying_type) => underlying_type,
+ _ => return Err(TranslateError::MismatchedType),
+ };
+ let id_constant_stmt = self
+ .id_def
+ .register_intermediate(reg_type.clone(), ast::StateSpace::Reg);
self.func.push(Statement::Constant(ConstantDefinition {
dst: id_constant_stmt,
- typ: ast::ScalarType::from_parts(width, kind),
- value: ast::ImmediateValue::U64(-(offset as i64) as u64),
+ typ: reg_scalar_type,
+ value: ast::ImmediateValue::S64(offset as i64),
}));
- self.func.push(Statement::Instruction(
- ast::Instruction::<ExpandedArgParams>::Sub(
- arith_detail,
- ast::Arg3 {
- dst: result_id,
- src1: reg,
- src2: id_constant_stmt,
- },
- ),
- ));
+ let arith_details = match reg_scalar_type.kind() {
+ ast::ScalarKind::Signed => ast::ArithDetails::Signed(ast::ArithSInt {
+ typ: reg_scalar_type,
+ saturate: false,
+ }),
+ ast::ScalarKind::Unsigned | ast::ScalarKind::Bit => {
+ ast::ArithDetails::Unsigned(reg_scalar_type)
+ }
+ _ => return Err(error_unreachable()),
+ };
+ let id_add_result = self.id_def.register_intermediate(reg_type, state_space);
+ self.func.push(Statement::Instruction(ast::Instruction::Add(
+ arith_details,
+ ast::Arg3 {
+ dst: id_add_result,
+ src1: reg,
+ src2: id_constant_stmt,
+ },
+ )));
+ Ok(id_add_result)
} else {
+ let scalar_type = match typ {
+ ast::Type::Scalar(underlying_type) => *underlying_type,
+ _ => return Err(error_unreachable()),
+ };
+ let id_constant_stmt = self.id_def.register_intermediate(
+ ast::Type::Scalar(ast::ScalarType::S64),
+ ast::StateSpace::Reg,
+ );
self.func.push(Statement::Constant(ConstantDefinition {
dst: id_constant_stmt,
- typ: ast::ScalarType::from_parts(width, kind),
+ typ: ast::ScalarType::S64,
value: ast::ImmediateValue::S64(offset as i64),
}));
- self.func.push(Statement::Instruction(
- ast::Instruction::<ExpandedArgParams>::Add(
- arith_detail,
- ast::Arg3 {
- dst: result_id,
- src1: reg,
- src2: id_constant_stmt,
- },
- ),
- ));
+ let dst = self.id_def.register_intermediate(typ.clone(), state_space);
+ self.func.push(Statement::PtrAccess(PtrAccess {
+ underlying_type: scalar_type,
+ state_space: state_space,
+ dst,
+ ptr_src: reg,
+ offset_src: id_constant_stmt,
+ }));
+ Ok(dst)
}
- Ok(result_id)
}
fn immediate(
&mut self,
desc: ArgumentDescriptor<ast::ImmediateValue>,
typ: &ast::Type,
+ state_space: ast::StateSpace,
) -> Result<spirv::Word, TranslateError> {
let scalar_t = if let ast::Type::Scalar(scalar) = typ {
*scalar
} else {
todo!()
};
- let id = self.id_def.new_non_variable(ast::Type::Scalar(scalar_t));
+ let id = self
+ .id_def
+ .register_intermediate(ast::Type::Scalar(scalar_t), state_space);
self.func.push(Statement::Constant(ConstantDefinition {
dst: id,
typ: scalar_t,
@@ -2588,7 +2441,7 @@ impl<'a, 'b> ArgumentMapVisitor<TypedArgParams, ExpandedArgParams> for FlattenAr
fn id(
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
- t: Option<&ast::Type>,
+ t: Option<(&ast::Type, ast::StateSpace)>,
) -> Result<spirv::Word, TranslateError> {
self.reg(desc, t)
}
@@ -2597,12 +2450,13 @@ impl<'a, 'b> ArgumentMapVisitor<TypedArgParams, ExpandedArgParams> for FlattenAr
&mut self,
desc: ArgumentDescriptor<TypedOperand>,
typ: &ast::Type,
+ state_space: ast::StateSpace,
) -> Result<spirv::Word, TranslateError> {
match desc.op {
- TypedOperand::Reg(r) => self.reg(desc.new_op(r), Some(typ)),
- TypedOperand::Imm(x) => self.immediate(desc.new_op(x), typ),
+ TypedOperand::Reg(r) => self.reg(desc.new_op(r), Some((typ, state_space))),
+ TypedOperand::Imm(x) => self.immediate(desc.new_op(x), typ, state_space),
TypedOperand::RegOffset(reg, offset) => {
- self.reg_offset(desc.new_op((reg, offset)), typ)
+ self.reg_offset(desc.new_op((reg, offset)), typ, state_space)
}
TypedOperand::VecMember(..) => Err(error_unreachable()),
}
@@ -2630,79 +2484,18 @@ fn insert_implicit_conversions(
let mut result = Vec::with_capacity(func.len());
for s in func.into_iter() {
match s {
- Statement::Call(call) => insert_implicit_conversions_impl(
- &mut result,
- id_def,
- call,
- should_bitcast_wrapper,
- None,
- )?,
+ Statement::Call(call) => {
+ insert_implicit_conversions_impl(&mut result, id_def, call)?;
+ }
Statement::Instruction(inst) => {
- let mut default_conversion_fn =
- should_bitcast_wrapper as for<'a> fn(&'a ast::Type, &'a ast::Type, _) -> _;
- let mut state_space = None;
- if let ast::Instruction::Ld(d, _) = &inst {
- state_space = Some(d.state_space);
- }
- if let ast::Instruction::St(d, _) = &inst {
- state_space = Some(d.state_space.to_ld_ss());
- }
- if let ast::Instruction::Atom(d, _) = &inst {
- state_space = Some(d.space.to_ld_ss());
- }
- if let ast::Instruction::AtomCas(d, _) = &inst {
- state_space = Some(d.space.to_ld_ss());
- }
- if let ast::Instruction::Mov(..) = &inst {
- default_conversion_fn = should_bitcast_packed;
- }
- insert_implicit_conversions_impl(
- &mut result,
- id_def,
- inst,
- default_conversion_fn,
- state_space,
- )?;
+ insert_implicit_conversions_impl(&mut result, id_def, inst)?;
}
- Statement::PtrAccess(PtrAccess {
- underlying_type,
- state_space,
- dst,
- ptr_src,
- offset_src: constant_src,
- }) => {
- let visit_desc = VisitArgumentDescriptor {
- desc: ArgumentDescriptor {
- op: ptr_src,
- is_dst: false,
- sema: ArgumentSemantics::PhysicalPointer,
- },
- typ: &ast::Type::Pointer(underlying_type.clone(), state_space),
- stmt_ctor: |new_ptr_src| {
- Statement::PtrAccess(PtrAccess {
- underlying_type,
- state_space,
- dst,
- ptr_src: new_ptr_src,
- offset_src: constant_src,
- })
- },
- };
- insert_implicit_conversions_impl(
- &mut result,
- id_def,
- visit_desc,
- bitcast_physical_pointer,
- Some(state_space),
- )?;
+ Statement::PtrAccess(access) => {
+ insert_implicit_conversions_impl(&mut result, id_def, access)?;
+ }
+ Statement::RepackVector(repack) => {
+ insert_implicit_conversions_impl(&mut result, id_def, repack)?;
}
- Statement::RepackVector(repack) => insert_implicit_conversions_impl(
- &mut result,
- id_def,
- repack,
- should_bitcast_wrapper,
- None,
- )?,
s @ Statement::Conditional(_)
| s @ Statement::Conversion(_)
| s @ Statement::Label(_)
@@ -2720,72 +2513,56 @@ fn insert_implicit_conversions_impl(
func: &mut Vec<ExpandedStatement>,
id_def: &mut MutableNumericIdResolver,
stmt: impl Visitable<ExpandedArgParams, ExpandedArgParams>,
- default_conversion_fn: for<'a> fn(
- &'a ast::Type,
- &'a ast::Type,
- Option<ast::LdStateSpace>,
- ) -> Result<Option<ConversionKind>, TranslateError>,
- state_space: Option<ast::LdStateSpace>,
) -> Result<(), TranslateError> {
let mut post_conv = Vec::new();
- let statement = stmt.visit(
- &mut |desc: ArgumentDescriptor<spirv::Word>, typ: Option<&ast::Type>| {
- let instr_type = match typ {
+ let statement =
+ stmt.visit(&mut |desc: ArgumentDescriptor<spirv::Word>,
+ typ: Option<(&ast::Type, ast::StateSpace)>| {
+ let (instr_type, instruction_space) = match typ {
None => return Ok(desc.op),
Some(t) => t,
};
- let operand_type = id_def.get_typed(desc.op)?;
- let mut conversion_fn = default_conversion_fn;
- match desc.sema {
- ArgumentSemantics::Default => {}
- ArgumentSemantics::DefaultRelaxed => {
- if desc.is_dst {
- conversion_fn = should_convert_relaxed_dst_wrapper;
- } else {
- conversion_fn = should_convert_relaxed_src_wrapper;
- }
- }
- ArgumentSemantics::PhysicalPointer => {
- conversion_fn = bitcast_physical_pointer;
- }
- ArgumentSemantics::RegisterPointer => {
- conversion_fn = bitcast_register_pointer;
- }
- ArgumentSemantics::Address => {
- conversion_fn = force_bitcast_ptr_to_bit;
- }
- };
- match conversion_fn(&operand_type, instr_type, state_space)? {
+ let (operand_type, operand_space) = id_def.get_typed(desc.op)?;
+ let conversion_fn = desc
+ .non_default_implicit_conversion
+ .unwrap_or(default_implicit_conversion);
+ match conversion_fn(
+ (operand_space, &operand_type),
+ (instruction_space, instr_type),
+ )? {
Some(conv_kind) => {
let conv_output = if desc.is_dst {
&mut post_conv
} else {
&mut *func
};
- let mut from = instr_type.clone();
- let mut to = operand_type;
- let mut src = id_def.new_non_variable(instr_type.clone());
+ let mut from_type = instr_type.clone();
+ let mut from_space = instruction_space;
+ let mut to_type = operand_type;
+ let mut to_space = operand_space;
+ let mut src =
+ id_def.register_intermediate(instr_type.clone(), instruction_space);
let mut dst = desc.op;
let result = Ok(src);
if !desc.is_dst {
mem::swap(&mut src, &mut dst);
- mem::swap(&mut from, &mut to);
+ mem::swap(&mut from_type, &mut to_type);
+ mem::swap(&mut from_space, &mut to_space);
}
conv_output.push(Statement::Conversion(ImplicitConversion {
src,
dst,
- from,
- to,
+ from_type,
+ from_space,
+ to_type,
+ to_space,
kind: conv_kind,
- src_sema: ArgumentSemantics::Default,
- dst_sema: ArgumentSemantics::Default,
}));
result
}
None => Ok(desc.op),
}
- },
- )?;
+ })?;
func.push(statement);
func.append(&mut post_conv);
Ok(())
@@ -2794,17 +2571,15 @@ fn insert_implicit_conversions_impl(
fn get_function_type(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
- spirv_input: &[ast::Variable<ast::Type, spirv::Word>],
- spirv_output: &[ast::Variable<ast::Type, spirv::Word>],
+ spirv_input: impl Iterator<Item = SpirvType>,
+ spirv_output: &[ast::Variable<spirv::Word>],
) -> (spirv::Word, spirv::Word) {
map.get_or_add_fn(
builder,
- spirv_input
- .iter()
- .map(|var| SpirvType::from(var.v_type.clone())),
+ spirv_input,
spirv_output
.iter()
- .map(|var| SpirvType::from(var.v_type.clone())),
+ .map(|var| SpirvType::new(var.v_type.clone())),
)
}
@@ -2831,20 +2606,25 @@ fn emit_function_body_ops(
match s {
Statement::Label(_) => (),
Statement::Call(call) => {
- let (result_type, result_id) = match &*call.ret_params {
- [(id, typ)] => (
- map.get_or_add(builder, SpirvType::from(typ.to_func_type())),
- Some(*id),
- ),
+ let (result_type, result_id) = match &*call.return_arguments {
+ [(id, typ, space)] => {
+ if *space != ast::StateSpace::Reg {
+ return Err(error_unreachable());
+ }
+ (
+ map.get_or_add(builder, SpirvType::new(typ.clone())),
+ Some(*id),
+ )
+ }
[] => (map.void(), None),
_ => todo!(),
};
let arg_list = call
- .param_list
+ .input_arguments
.iter()
- .map(|(id, _)| *id)
+ .map(|(id, _, _)| *id)
.collect::<Vec<_>>();
- builder.function_call(result_type, result_id, call.func, arg_list)?;
+ builder.function_call(result_type, result_id, call.name, arg_list)?;
}
Statement::Variable(var) => {
emit_variable(builder, map, var)?;
@@ -2966,7 +2746,7 @@ fn emit_function_body_ops(
todo!()
}
let result_type =
- map.get_or_add(builder, SpirvType::from(ast::Type::from(data.typ.clone())));
+ map.get_or_add(builder, SpirvType::new(ast::Type::from(data.typ.clone())));
builder.load(
result_type,
Some(arg.dst),
@@ -2998,7 +2778,7 @@ fn emit_function_body_ops(
ast::Instruction::Ret(_) => builder.ret()?,
ast::Instruction::Mov(d, arg) => {
let result_type =
- map.get_or_add(builder, SpirvType::from(ast::Type::from(d.typ.clone())));
+ map.get_or_add(builder, SpirvType::new(ast::Type::from(d.typ.clone())));
builder.copy_object(result_type, Some(arg.dst), arg.src)?;
}
ast::Instruction::Mul(mul, arg) => match mul {
@@ -3026,20 +2806,20 @@ fn emit_function_body_ops(
emit_setp(builder, map, setp, arg)?;
}
ast::Instruction::Not(t, a) => {
- let result_type = map.get_or_add(builder, SpirvType::from(t.to_type()));
+ let result_type = map.get_or_add(builder, SpirvType::from(*t));
let result_id = Some(a.dst);
let operand = a.src;
match t {
- ast::BooleanType::Pred => {
+ ast::ScalarType::Pred => {
logical_not(builder, result_type, result_id, operand)
}
_ => builder.not(result_type, result_id, operand),
}?;
}
ast::Instruction::Shl(t, a) => {
- let full_type = t.to_type();
+ let full_type = ast::Type::Scalar(*t);
let size_of = full_type.size_of();
- let result_type = map.get_or_add(builder, SpirvType::from(full_type));
+ let result_type = map.get_or_add(builder, SpirvType::new(full_type));
let offset_src = insert_shift_hack(builder, map, a.src2, size_of)?;
builder.shift_left_logical(result_type, Some(a.dst), a.src1, offset_src)?;
}
@@ -3048,7 +2828,7 @@ fn emit_function_body_ops(
let size_of = full_type.size_of();
let result_type = map.get_or_add_scalar(builder, full_type);
let offset_src = insert_shift_hack(builder, map, a.src2, size_of as usize)?;
- if t.signed() {
+ if t.kind() == ast::ScalarKind::Signed {
builder.shift_right_arithmetic(
result_type,
Some(a.dst),
@@ -3088,7 +2868,7 @@ fn emit_function_body_ops(
},
ast::Instruction::Or(t, a) => {
let result_type = map.get_or_add_scalar(builder, ast::ScalarType::from(*t));
- if *t == ast::BooleanType::Pred {
+ if *t == ast::ScalarType::Pred {
builder.logical_or(result_type, Some(a.dst), a.src1, a.src2)?;
} else {
builder.bitwise_or(result_type, Some(a.dst), a.src1, a.src2)?;
@@ -3116,7 +2896,7 @@ fn emit_function_body_ops(
}
ast::Instruction::And(t, a) => {
let result_type = map.get_or_add_scalar(builder, ast::ScalarType::from(*t));
- if *t == ast::BooleanType::Pred {
+ if *t == ast::ScalarType::Pred {
builder.logical_and(result_type, Some(a.dst), a.src1, a.src2)?;
} else {
builder.bitwise_and(result_type, Some(a.dst), a.src1, a.src2)?;
@@ -3202,7 +2982,7 @@ fn emit_function_body_ops(
}
ast::Instruction::Neg(details, arg) => {
let result_type = map.get_or_add_scalar(builder, details.typ);
- let negate_func = if details.typ.kind() == ScalarKind::Float {
+ let negate_func = if details.typ.kind() == ast::ScalarKind::Float {
dr::Builder::f_negate
} else {
dr::Builder::s_negate
@@ -3269,7 +3049,7 @@ fn emit_function_body_ops(
}
ast::Instruction::Xor { typ, arg } => {
let builder_fn = match typ {
- ast::BooleanType::Pred => emit_logical_xor_spirv,
+ ast::ScalarType::Pred => emit_logical_xor_spirv,
_ => dr::Builder::bitwise_xor,
};
let result_type = map.get_or_add_scalar(builder, (*typ).into());
@@ -3284,7 +3064,7 @@ fn emit_function_body_ops(
return Err(error_unreachable());
}
ast::Instruction::Rem { typ, arg } => {
- let builder_fn = if typ.is_signed() {
+ let builder_fn = if typ.kind() == ast::ScalarKind::Signed {
dr::Builder::s_mod
} else {
dr::Builder::u_mod
@@ -3301,7 +3081,7 @@ fn emit_function_body_ops(
Some(index) => {
let result_ptr_type = map.get_or_add(
builder,
- SpirvType::new_pointer(
+ SpirvType::pointer_to(
details.typ.clone(),
spirv::StorageClass::Function,
),
@@ -3334,14 +3114,11 @@ fn emit_function_body_ops(
}) => {
let u8_pointer = map.get_or_add(
builder,
- SpirvType::from(ast::Type::Pointer(
- ast::PointerType::Scalar(ast::ScalarType::U8),
- *state_space,
- )),
+ SpirvType::new(ast::Type::Pointer(ast::ScalarType::U8, *state_space)),
);
let result_type = map.get_or_add(
builder,
- SpirvType::from(ast::Type::Pointer(underlying_type.clone(), *state_space)),
+ SpirvType::new(ast::Type::Pointer(*underlying_type, *state_space)),
);
let ptr_src_u8 = builder.bitcast(u8_pointer, None, *ptr_src)?;
let temp = builder.in_bounds_ptr_access_chain(
@@ -3553,11 +3330,16 @@ fn ptx_scope_name(scope: ast::MemScope) -> &'static str {
}
}
-fn ptx_space_name(space: ast::AtomSpace) -> &'static str {
+fn ptx_space_name(space: ast::StateSpace) -> &'static str {
match space {
- ast::AtomSpace::Generic => "generic",
- ast::AtomSpace::Global => "global",
- ast::AtomSpace::Shared => "shared",
+ ast::StateSpace::Generic => "generic",
+ ast::StateSpace::Global => "global",
+ ast::StateSpace::Shared => "shared",
+ ast::StateSpace::Reg => "reg",
+ ast::StateSpace::Const => "const",
+ ast::StateSpace::Local => "local",
+ ast::StateSpace::Param => "param",
+ ast::StateSpace::Sreg => "sreg",
}
}
@@ -3612,14 +3394,17 @@ fn vec_repr<T: Copy>(t: T) -> Vec<u8> {
fn emit_variable(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
- var: &ast::Variable<ast::VariableType, spirv::Word>,
+ var: &ast::Variable<spirv::Word>,
) -> Result<(), TranslateError> {
- let (must_init, st_class) = match var.v_type {
- ast::VariableType::Reg(_) | ast::VariableType::Param(_) | ast::VariableType::Local(_) => {
+ let (must_init, st_class) = match var.state_space {
+ ast::StateSpace::Reg | ast::StateSpace::Param | ast::StateSpace::Local => {
(false, spirv::StorageClass::Function)
}
- ast::VariableType::Global(_) => (true, spirv::StorageClass::CrossWorkgroup),
- ast::VariableType::Shared(_) => (false, spirv::StorageClass::Workgroup),
+ ast::StateSpace::Global => (true, spirv::StorageClass::CrossWorkgroup),
+ ast::StateSpace::Shared => (false, spirv::StorageClass::Workgroup),
+ ast::StateSpace::Const => todo!(),
+ ast::StateSpace::Generic => todo!(),
+ ast::StateSpace::Sreg => todo!(),
};
let initalizer = if var.array_init.len() > 0 {
Some(map.get_or_add_constant(
@@ -3628,18 +3413,12 @@ fn emit_variable(
&*var.array_init,
)?)
} else if must_init {
- let type_id = map.get_or_add(
- builder,
- SpirvType::from(ast::Type::from(var.v_type.clone())),
- );
+ let type_id = map.get_or_add(builder, SpirvType::new(var.v_type.clone()));
Some(builder.constant_null(type_id, None))
} else {
None
};
- let ptr_type_id = map.get_or_add(
- builder,
- SpirvType::new_pointer(ast::Type::from(var.v_type.clone()), st_class),
- );
+ let ptr_type_id = map.get_or_add(builder, SpirvType::pointer_to(var.v_type.clone(), st_class));
builder.variable(ptr_type_id, Some(var.name), st_class, initalizer);
if let Some(align) = var.align {
builder.decorate(
@@ -3777,7 +3556,7 @@ fn emit_min(
ast::MinMaxDetails::Unsigned(_) => spirv::CLOp::u_min,
ast::MinMaxDetails::Float(_) => spirv::CLOp::fmin,
};
- let inst_type = map.get_or_add(builder, SpirvType::from(desc.get_type()));
+ let inst_type = map.get_or_add(builder, SpirvType::new(desc.get_type()));
builder.ext_inst(
inst_type,
Some(arg.dst),
@@ -3802,7 +3581,7 @@ fn emit_max(
ast::MinMaxDetails::Unsigned(_) => spirv::CLOp::u_max,
ast::MinMaxDetails::Float(_) => spirv::CLOp::fmax,
};
- let inst_type = map.get_or_add(builder, SpirvType::from(desc.get_type()));
+ let inst_type = map.get_or_add(builder, SpirvType::new(desc.get_type()));
builder.ext_inst(
inst_type,
Some(arg.dst),
@@ -3882,7 +3661,7 @@ fn emit_cvt(
}
let dest_t: ast::ScalarType = desc.dst.into();
let result_type = map.get_or_add(builder, SpirvType::from(dest_t));
- if desc.src.is_signed() {
+ if desc.src.kind() == ast::ScalarKind::Signed {
builder.convert_s_to_f(result_type, Some(arg.dst), arg.src)?;
} else {
builder.convert_u_to_f(result_type, Some(arg.dst), arg.src)?;
@@ -3892,7 +3671,7 @@ fn emit_cvt(
ast::CvtDetails::IntFromFloat(desc) => {
let dest_t: ast::ScalarType = desc.dst.into();
let result_type = map.get_or_add(builder, SpirvType::from(dest_t));
- if desc.dst.is_signed() {
+ if desc.dst.kind() == ast::ScalarKind::Signed {
builder.convert_f_to_s(result_type, Some(arg.dst), arg.src)?;
} else {
builder.convert_f_to_u(result_type, Some(arg.dst), arg.src)?;
@@ -3904,7 +3683,7 @@ fn emit_cvt(
let dest_t: ast::ScalarType = desc.dst.into();
let src_t: ast::ScalarType = desc.src.into();
// first do shortening/widening
- let src = if desc.dst.width() != desc.src.width() {
+ let src = if desc.dst.size_of() != desc.src.size_of() {
let new_dst = if dest_t.kind() == src_t.kind() {
arg.dst
} else {
@@ -3913,14 +3692,14 @@ fn emit_cvt(
let cv = ImplicitConversion {
src: arg.src,
dst: new_dst,
- from: ast::Type::Scalar(src_t),
- to: ast::Type::Scalar(ast::ScalarType::from_parts(
+ from_type: ast::Type::Scalar(src_t),
+ from_space: ast::StateSpace::Reg,
+ to_type: ast::Type::Scalar(ast::ScalarType::from_parts(
dest_t.size_of(),
src_t.kind(),
)),
+ to_space: ast::StateSpace::Reg,
kind: ConversionKind::Default,
- src_sema: ArgumentSemantics::Default,
- dst_sema: ArgumentSemantics::Default,
};
emit_implicit_conversion(builder, map, &cv)?;
new_dst
@@ -3933,7 +3712,7 @@ fn emit_cvt(
// now do actual conversion
let result_type = map.get_or_add(builder, SpirvType::from(dest_t));
if desc.saturate {
- if desc.dst.is_signed() {
+ if desc.dst.kind() == ast::ScalarKind::Signed {
builder.sat_convert_u_to_s(result_type, Some(arg.dst), src)?;
} else {
builder.sat_convert_s_to_u(result_type, Some(arg.dst), src)?;
@@ -3989,60 +3768,60 @@ fn emit_setp(
let operand_1 = arg.src1;
let operand_2 = arg.src2;
match (setp.cmp_op, setp.typ.kind()) {
- (ast::SetpCompareOp::Eq, ScalarKind::Signed)
- | (ast::SetpCompareOp::Eq, ScalarKind::Unsigned)
- | (ast::SetpCompareOp::Eq, ScalarKind::Bit) => {
+ (ast::SetpCompareOp::Eq, ast::ScalarKind::Signed)
+ | (ast::SetpCompareOp::Eq, ast::ScalarKind::Unsigned)
+ | (ast::SetpCompareOp::Eq, ast::ScalarKind::Bit) => {
builder.i_equal(result_type, result_id, operand_1, operand_2)
}
- (ast::SetpCompareOp::Eq, ScalarKind::Float) => {
+ (ast::SetpCompareOp::Eq, ast::ScalarKind::Float) => {
builder.f_ord_equal(result_type, result_id, operand_1, operand_2)
}
- (ast::SetpCompareOp::NotEq, ScalarKind::Signed)
- | (ast::SetpCompareOp::NotEq, ScalarKind::Unsigned)
- | (ast::SetpCompareOp::NotEq, ScalarKind::Bit) => {
+ (ast::SetpCompareOp::NotEq, ast::ScalarKind::Signed)
+ | (ast::SetpCompareOp::NotEq, ast::ScalarKind::Unsigned)
+ | (ast::SetpCompareOp::NotEq, ast::ScalarKind::Bit) => {
builder.i_not_equal(result_type, result_id, operand_1, operand_2)
}
- (ast::SetpCompareOp::NotEq, ScalarKind::Float) => {
+ (ast::SetpCompareOp::NotEq, ast::ScalarKind::Float) => {
builder.f_ord_not_equal(result_type, result_id, operand_1, operand_2)
}
- (ast::SetpCompareOp::Less, ScalarKind::Unsigned)
- | (ast::SetpCompareOp::Less, ScalarKind::Bit) => {
+ (ast::SetpCompareOp::Less, ast::ScalarKind::Unsigned)
+ | (ast::SetpCompareOp::Less, ast::ScalarKind::Bit) => {
builder.u_less_than(result_type, result_id, operand_1, operand_2)
}
- (ast::SetpCompareOp::Less, ScalarKind::Signed) => {
+ (ast::SetpCompareOp::Less, ast::ScalarKind::Signed) => {
builder.s_less_than(result_type, result_id, operand_1, operand_2)
}
- (ast::SetpCompareOp::Less, ScalarKind::Float) => {
+ (ast::SetpCompareOp::Less, ast::ScalarKind::Float) => {
builder.f_ord_less_than(result_type, result_id, operand_1, operand_2)
}
- (ast::SetpCompareOp::LessOrEq, ScalarKind::Unsigned)
- | (ast::SetpCompareOp::LessOrEq, ScalarKind::Bit) => {
+ (ast::SetpCompareOp::LessOrEq, ast::ScalarKind::Unsigned)
+ | (ast::SetpCompareOp::LessOrEq, ast::ScalarKind::Bit) => {
builder.u_less_than_equal(result_type, result_id, operand_1, operand_2)
}
- (ast::SetpCompareOp::LessOrEq, ScalarKind::Signed) => {
+ (ast::SetpCompareOp::LessOrEq, ast::ScalarKind::Signed) => {
builder.s_less_than_equal(result_type, result_id, operand_1, operand_2)
}
- (ast::SetpCompareOp::LessOrEq, ScalarKind::Float) => {
+ (ast::SetpCompareOp::LessOrEq, ast::ScalarKind::Float) => {
builder.f_ord_less_than_equal(result_type, result_id, operand_1, operand_2)
}
- (ast::SetpCompareOp::Greater, ScalarKind::Unsigned)
- | (ast::SetpCompareOp::Greater, ScalarKind::Bit) => {
+ (ast::SetpCompareOp::Greater, ast::ScalarKind::Unsigned)
+ | (ast::SetpCompareOp::Greater, ast::ScalarKind::Bit) => {
builder.u_greater_than(result_type, result_id, operand_1, operand_2)
}
- (ast::SetpCompareOp::Greater, ScalarKind::Signed) => {
+ (ast::SetpCompareOp::Greater, ast::ScalarKind::Signed) => {
builder.s_greater_than(result_type, result_id, operand_1, operand_2)
}
- (ast::SetpCompareOp::Greater, ScalarKind::Float) => {
+ (ast::SetpCompareOp::Greater, ast::ScalarKind::Float) => {
builder.f_ord_greater_than(result_type, result_id, operand_1, operand_2)
}
- (ast::SetpCompareOp::GreaterOrEq, ScalarKind::Unsigned)
- | (ast::SetpCompareOp::GreaterOrEq, ScalarKind::Bit) => {
+ (ast::SetpCompareOp::GreaterOrEq, ast::ScalarKind::Unsigned)
+ | (ast::SetpCompareOp::GreaterOrEq, ast::ScalarKind::Bit) => {
builder.u_greater_than_equal(result_type, result_id, operand_1, operand_2)
}
- (ast::SetpCompareOp::GreaterOrEq, ScalarKind::Signed) => {
+ (ast::SetpCompareOp::GreaterOrEq, ast::ScalarKind::Signed) => {
builder.s_greater_than_equal(result_type, result_id, operand_1, operand_2)
}
- (ast::SetpCompareOp::GreaterOrEq, ScalarKind::Float) => {
+ (ast::SetpCompareOp::GreaterOrEq, ast::ScalarKind::Float) => {
builder.f_ord_greater_than_equal(result_type, result_id, operand_1, operand_2)
}
(ast::SetpCompareOp::NanEq, _) => {
@@ -4222,7 +4001,7 @@ fn emit_abs(
) -> Result<(), dr::Error> {
let scalar_t = ast::ScalarType::from(d.typ);
let result_type = map.get_or_add(builder, SpirvType::from(scalar_t));
- let cl_abs = if scalar_t.kind() == ScalarKind::Signed {
+ let cl_abs = if scalar_t.kind() == ast::ScalarKind::Signed {
spirv::CLOp::s_abs
} else {
spirv::CLOp::fabs
@@ -4272,22 +4051,21 @@ fn emit_implicit_conversion(
map: &mut TypeWordMap,
cv: &ImplicitConversion,
) -> Result<(), TranslateError> {
- let from_parts = cv.from.to_parts();
- let to_parts = cv.to.to_parts();
- match (from_parts.kind, to_parts.kind, cv.kind) {
- (_, _, ConversionKind::PtrToBit(typ)) => {
- let dst_type = map.get_or_add_scalar(builder, typ.into());
- builder.convert_ptr_to_u(dst_type, Some(cv.dst), cv.src)?;
- }
- (_, _, ConversionKind::BitToPtr(_)) => {
- let dst_type = map.get_or_add(builder, SpirvType::from(cv.to.clone()));
+ let from_parts = cv.from_type.to_parts();
+ let to_parts = cv.to_type.to_parts();
+ match (from_parts.kind, to_parts.kind, &cv.kind) {
+ (_, _, &ConversionKind::BitToPtr) => {
+ let dst_type = map.get_or_add(
+ builder,
+ SpirvType::pointer_to(cv.to_type.clone(), cv.to_space.to_spirv()),
+ );
builder.convert_u_to_ptr(dst_type, Some(cv.dst), cv.src)?;
}
- (TypeKind::Scalar, TypeKind::Scalar, ConversionKind::Default) => {
+ (TypeKind::Scalar, TypeKind::Scalar, &ConversionKind::Default) => {
if from_parts.width == to_parts.width {
- let dst_type = map.get_or_add(builder, SpirvType::from(cv.to.clone()));
- if from_parts.scalar_kind != ScalarKind::Float
- && to_parts.scalar_kind != ScalarKind::Float
+ let dst_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone()));
+ if from_parts.scalar_kind != ast::ScalarKind::Float
+ && to_parts.scalar_kind != ast::ScalarKind::Float
{
// It is noop, but another instruction expects result of this conversion
builder.copy_object(dst_type, Some(cv.dst), cv.src)?;
@@ -4295,28 +4073,28 @@ fn emit_implicit_conversion(
builder.bitcast(dst_type, Some(cv.dst), cv.src)?;
}
} else {
- // This block is safe because it's illegal to implictly convert between floating point instructions
+ // This block is safe because it's illegal to implictly convert between floating point values
let same_width_bit_type = map.get_or_add(
builder,
- SpirvType::from(ast::Type::from_parts(TypeParts {
- scalar_kind: ScalarKind::Bit,
+ SpirvType::new(ast::Type::from_parts(TypeParts {
+ scalar_kind: ast::ScalarKind::Bit,
..from_parts
})),
);
let same_width_bit_value = builder.bitcast(same_width_bit_type, None, cv.src)?;
let wide_bit_type = ast::Type::from_parts(TypeParts {
- scalar_kind: ScalarKind::Bit,
+ scalar_kind: ast::ScalarKind::Bit,
..to_parts
});
let wide_bit_type_spirv =
- map.get_or_add(builder, SpirvType::from(wide_bit_type.clone()));
- if to_parts.scalar_kind == ScalarKind::Unsigned
- || to_parts.scalar_kind == ScalarKind::Bit
+ map.get_or_add(builder, SpirvType::new(wide_bit_type.clone()));
+ if to_parts.scalar_kind == ast::ScalarKind::Unsigned
+ || to_parts.scalar_kind == ast::ScalarKind::Bit
{
builder.u_convert(wide_bit_type_spirv, Some(cv.dst), same_width_bit_value)?;
} else {
- let conversion_fn = if from_parts.scalar_kind == ScalarKind::Signed
- && to_parts.scalar_kind == ScalarKind::Signed
+ let conversion_fn = if from_parts.scalar_kind == ast::ScalarKind::Signed
+ && to_parts.scalar_kind == ast::ScalarKind::Signed
{
dr::Builder::s_convert
} else {
@@ -4330,40 +4108,48 @@ fn emit_implicit_conversion(
&ImplicitConversion {
src: wide_bit_value,
dst: cv.dst,
- from: wide_bit_type,
- to: cv.to.clone(),
+ from_type: wide_bit_type,
+ from_space: cv.from_space,
+ to_type: cv.to_type.clone(),
+ to_space: cv.to_space,
kind: ConversionKind::Default,
- src_sema: cv.src_sema,
- dst_sema: cv.dst_sema,
},
)?;
}
}
}
- (TypeKind::Scalar, TypeKind::Scalar, ConversionKind::SignExtend) => {
- let result_type = map.get_or_add(builder, SpirvType::from(cv.to.clone()));
+ (TypeKind::Scalar, TypeKind::Scalar, &ConversionKind::SignExtend) => {
+ let result_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone()));
builder.s_convert(result_type, Some(cv.dst), cv.src)?;
}
- (TypeKind::Vector, TypeKind::Scalar, ConversionKind::Default)
- | (TypeKind::Scalar, TypeKind::Array, ConversionKind::Default)
- | (TypeKind::Array, TypeKind::Scalar, ConversionKind::Default) => {
- let into_type = map.get_or_add(builder, SpirvType::from(cv.to.clone()));
+ (TypeKind::Vector, TypeKind::Scalar, &ConversionKind::Default)
+ | (TypeKind::Scalar, TypeKind::Array, &ConversionKind::Default)
+ | (TypeKind::Array, TypeKind::Scalar, &ConversionKind::Default) => {
+ let into_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone()));
builder.bitcast(into_type, Some(cv.dst), cv.src)?;
}
- (_, _, ConversionKind::PtrToPtr { spirv_ptr }) => {
- let result_type = if spirv_ptr {
- map.get_or_add(
- builder,
- SpirvType::Pointer(
- Box::new(SpirvType::from(cv.to.clone())),
- spirv::StorageClass::Function,
- ),
- )
- } else {
- map.get_or_add(builder, SpirvType::from(cv.to.clone()))
- };
+ (_, _, &ConversionKind::PtrToPtr) => {
+ let result_type = map.get_or_add(
+ builder,
+ SpirvType::Pointer(
+ Box::new(SpirvType::new(cv.to_type.clone())),
+ cv.to_space.to_spirv(),
+ ),
+ );
builder.bitcast(result_type, Some(cv.dst), cv.src)?;
}
+ (_, _, &ConversionKind::AddressOf) => {
+ let dst_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone()));
+ builder.convert_ptr_to_u(dst_type, Some(cv.dst), cv.src)?;
+ }
+ (TypeKind::Pointer, TypeKind::Scalar, &ConversionKind::Default) => {
+ let result_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone()));
+ builder.convert_ptr_to_u(result_type, Some(cv.dst), cv.src)?;
+ }
+ (TypeKind::Scalar, TypeKind::Pointer, &ConversionKind::Default) => {
+ let result_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone()));
+ builder.convert_u_to_ptr(result_type, Some(cv.dst), cv.src)?;
+ }
_ => unreachable!(),
}
Ok(())
@@ -4374,14 +4160,14 @@ fn emit_load_var(
map: &mut TypeWordMap,
details: &LoadVarDetails,
) -> Result<(), TranslateError> {
- let result_type = map.get_or_add(builder, SpirvType::from(details.typ.clone()));
+ let result_type = map.get_or_add(builder, SpirvType::new(details.typ.clone()));
match details.member_index {
Some((index, Some(width))) => {
let vector_type = match details.typ {
ast::Type::Scalar(scalar_t) => ast::Type::Vector(scalar_t, width),
_ => return Err(TranslateError::MismatchedType),
};
- let vector_type_spirv = map.get_or_add(builder, SpirvType::from(vector_type));
+ let vector_type_spirv = map.get_or_add(builder, SpirvType::new(vector_type));
let vector_temp = builder.load(
vector_type_spirv,
None,
@@ -4399,7 +4185,7 @@ fn emit_load_var(
Some((index, None)) => {
let result_ptr_type = map.get_or_add(
builder,
- SpirvType::new_pointer(details.typ.clone(), spirv::StorageClass::Function),
+ SpirvType::pointer_to(details.typ.clone(), spirv::StorageClass::Function),
);
let index_spirv = map.get_or_add_constant(
builder,
@@ -4427,10 +4213,10 @@ fn emit_load_var(
Ok(())
}
-fn normalize_identifiers<'a, 'b>(
- id_defs: &mut FnStringIdResolver<'a, 'b>,
- fn_defs: &GlobalFnDeclResolver<'a, 'b>,
- func: Vec<ast::Statement<ast::ParsedArgParams<'a>>>,
+fn normalize_identifiers<'input, 'b>(
+ id_defs: &mut FnStringIdResolver<'input, 'b>,
+ fn_defs: &GlobalFnDeclResolver<'input, 'b>,
+ func: Vec<ast::Statement<ast::ParsedArgParams<'input>>>,
) -> Result<Vec<NormalizedStatement>, TranslateError> {
for s in func.iter() {
match s {
@@ -4468,48 +4254,28 @@ fn expand_map_variables<'a, 'b>(
i.map_variable(&mut |id| id_defs.get_id(id))?,
))),
ast::Statement::Variable(var) => {
- let mut var_type = ast::Type::from(var.var.v_type.clone());
- let mut is_variable = false;
- var_type = match var.var.v_type {
- ast::VariableType::Reg(_) => {
- is_variable = true;
- var_type
- }
- ast::VariableType::Shared(_) => {
- // If it's a pointer it will be translated to a method parameter later
- if let ast::Type::Pointer(..) = var_type {
- is_variable = true;
- var_type
- } else {
- var_type.param_pointer_to(ast::LdStateSpace::Shared)?
- }
- }
- ast::VariableType::Global(_) => {
- var_type.param_pointer_to(ast::LdStateSpace::Global)?
- }
- ast::VariableType::Param(_) => {
- var_type.param_pointer_to(ast::LdStateSpace::Param)?
- }
- ast::VariableType::Local(_) => {
- var_type.param_pointer_to(ast::LdStateSpace::Local)?
- }
- };
+ let var_type = var.var.v_type.clone();
match var.count {
Some(count) => {
- for new_id in id_defs.add_defs(var.var.name, count, var_type, is_variable) {
+ for new_id in
+ id_defs.add_defs(var.var.name, count, var_type, var.var.state_space, true)
+ {
result.push(Statement::Variable(ast::Variable {
align: var.var.align,
v_type: var.var.v_type.clone(),
+ state_space: var.var.state_space,
name: new_id,
array_init: var.var.array_init.clone(),
}))
}
}
None => {
- let new_id = id_defs.add_def(var.var.name, Some(var_type), is_variable);
+ let new_id =
+ id_defs.add_def(var.var.name, Some((var_type, var.var.state_space)), true);
result.push(Statement::Variable(ast::Variable {
align: var.var.align,
v_type: var.var.v_type.clone(),
+ state_space: var.var.state_space,
name: new_id,
array_init: var.var.array_init,
}));
@@ -4520,18 +4286,62 @@ fn expand_map_variables<'a, 'b>(
Ok(())
}
+/*
+ Our goal here is to transform
+ .visible .entry foobar(.param .u64 input) {
+ .reg .b64 in_addr;
+ .reg .b64 in_addr2;
+ ld.param.u64 in_addr, [input];
+ cvta.to.global.u64 in_addr2, in_addr;
+ }
+ into:
+ .visible .entry foobar(.param .u8 input[]) {
+ .reg .u8 in_addr[];
+ .reg .u8 in_addr2[];
+ ld.param.u8[] in_addr, [input];
+ mov.u8[] in_addr2, in_addr;
+ }
+ or:
+ .visible .entry foobar(.reg .u8 input[]) {
+ .reg .u8 in_addr[];
+ .reg .u8 in_addr2[];
+ mov.u8[] in_addr, input;
+ mov.u8[] in_addr2, in_addr;
+ }
+ or:
+ .visible .entry foobar(.param ptr<u8, global> input) {
+ .reg ptr<u8, global> in_addr;
+ .reg ptr<u8, global> in_addr2;
+ ld.param.ptr<u8, global> in_addr, [input];
+ mov.ptr<u8, global> in_addr2, in_addr;
+ }
+*/
// TODO: detect more patterns (mov, call via reg, call via param)
// TODO: don't convert to ptr if the register is not ultimately used for ld/st
// TODO: once insert_mem_ssa_statements is moved to later, move this pass after
// argument expansion
-// TODO: propagate through calls?
-fn convert_to_stateful_memory_access<'a>(
- func_args: &mut SpirvMethodDecl,
+// TODO: propagate out of calls and into calls
+fn convert_to_stateful_memory_access<'a, 'input>(
+ func_args: Rc<RefCell<ast::MethodDeclaration<'input, spirv::Word>>>,
func_body: Vec<TypedStatement>,
id_defs: &mut NumericIdResolver<'a>,
-) -> Result<Vec<TypedStatement>, TranslateError> {
- let func_args_64bit = func_args
- .input
+) -> Result<
+ (
+ Rc<RefCell<ast::MethodDeclaration<'input, spirv::Word>>>,
+ Vec<TypedStatement>,
+ ),
+ TranslateError,
+> {
+ let mut method_decl = func_args.borrow_mut();
+ if !method_decl.name.is_kernel() {
+ drop(method_decl);
+ return Ok((func_args, func_body));
+ }
+ if Rc::strong_count(&func_args) != 1 {
+ return Err(error_unreachable());
+ }
+ let func_args_64bit = (*method_decl)
+ .input_arguments
.iter()
.filter_map(|arg| match arg.v_type {
ast::Type::Scalar(ast::ScalarType::U64)
@@ -4546,9 +4356,9 @@ fn convert_to_stateful_memory_access<'a>(
match statement {
Statement::Instruction(ast::Instruction::Cvta(
ast::CvtaDetails {
- to: ast::CvtaStateSpace::Global,
+ to: ast::StateSpace::Global,
size: ast::CvtaSize::U64,
- from: ast::CvtaStateSpace::Generic,
+ from: ast::StateSpace::Generic,
},
arg,
)) => {
@@ -4562,24 +4372,24 @@ fn convert_to_stateful_memory_access<'a>(
}
Statement::Instruction(ast::Instruction::Ld(
ast::LdDetails {
- state_space: ast::LdStateSpace::Param,
- typ: ast::LdStType::Scalar(ast::LdStScalarType::U64),
+ state_space: ast::StateSpace::Param,
+ typ: ast::Type::Scalar(ast::ScalarType::U64),
..
},
arg,
))
| Statement::Instruction(ast::Instruction::Ld(
ast::LdDetails {
- state_space: ast::LdStateSpace::Param,
- typ: ast::LdStType::Scalar(ast::LdStScalarType::S64),
+ state_space: ast::StateSpace::Param,
+ typ: ast::Type::Scalar(ast::ScalarType::S64),
..
},
arg,
))
| Statement::Instruction(ast::Instruction::Ld(
ast::LdDetails {
- state_space: ast::LdStateSpace::Param,
- typ: ast::LdStType::Scalar(ast::LdStScalarType::B64),
+ state_space: ast::StateSpace::Param,
+ typ: ast::Type::Scalar(ast::ScalarType::B64),
..
},
arg,
@@ -4595,6 +4405,10 @@ fn convert_to_stateful_memory_access<'a>(
_ => {}
}
}
+ if stateful_markers.len() == 0 {
+ drop(method_decl);
+ return Ok((func_args, func_body));
+ }
let mut func_args_ptr = HashSet::new();
let mut regs_ptr_current = HashSet::new();
for (dst, src) in stateful_markers {
@@ -4614,23 +4428,23 @@ fn convert_to_stateful_memory_access<'a>(
for statement in func_body.iter() {
match statement {
Statement::Instruction(ast::Instruction::Add(
- ast::ArithDetails::Unsigned(ast::UIntType::U64),
+ ast::ArithDetails::Unsigned(ast::ScalarType::U64),
arg,
))
| Statement::Instruction(ast::Instruction::Add(
ast::ArithDetails::Signed(ast::ArithSInt {
- typ: ast::SIntType::S64,
+ typ: ast::ScalarType::S64,
saturate: false,
}),
arg,
))
| Statement::Instruction(ast::Instruction::Sub(
- ast::ArithDetails::Unsigned(ast::UIntType::U64),
+ ast::ArithDetails::Unsigned(ast::ScalarType::U64),
arg,
))
| Statement::Instruction(ast::Instruction::Sub(
ast::ArithDetails::Signed(ast::ArithSInt {
- typ: ast::SIntType::S64,
+ typ: ast::ScalarType::S64,
saturate: false,
}),
arg,
@@ -4661,21 +4475,32 @@ fn convert_to_stateful_memory_access<'a>(
let mut remapped_ids = HashMap::new();
let mut result = Vec::with_capacity(regs_ptr_seen.len() + func_body.len());
for reg in regs_ptr_seen {
- let new_id = id_defs.new_variable(ast::Type::Pointer(
- ast::PointerType::Scalar(ast::ScalarType::U8),
- ast::LdStateSpace::Global,
- ));
+ let new_id = id_defs.register_variable(
+ ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global),
+ ast::StateSpace::Reg,
+ );
result.push(Statement::Variable(ast::Variable {
align: None,
name: new_id,
array_init: Vec::new(),
- v_type: ast::VariableType::Reg(ast::VariableRegType::Pointer(
- ast::SizedScalarType::U8,
- ast::PointerStateSpace::Global,
- )),
+ v_type: ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global),
+ state_space: ast::StateSpace::Reg,
}));
remapped_ids.insert(reg, new_id);
}
+ for arg in (*method_decl).input_arguments.iter_mut() {
+ if !func_args_ptr.contains(&arg.name) {
+ continue;
+ }
+ let new_id = id_defs.register_variable(
+ ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global),
+ ast::StateSpace::Param,
+ );
+ let old_name = arg.name;
+ arg.v_type = ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global);
+ arg.name = new_id;
+ remapped_ids.insert(old_name, new_id);
+ }
for statement in func_body {
match statement {
l @ Statement::Label(_) => result.push(l),
@@ -4686,12 +4511,12 @@ fn convert_to_stateful_memory_access<'a>(
}
}
Statement::Instruction(ast::Instruction::Add(
- ast::ArithDetails::Unsigned(ast::UIntType::U64),
+ ast::ArithDetails::Unsigned(ast::ScalarType::U64),
arg,
))
| Statement::Instruction(ast::Instruction::Add(
ast::ArithDetails::Signed(ast::ArithSInt {
- typ: ast::SIntType::S64,
+ typ: ast::ScalarType::S64,
saturate: false,
}),
arg,
@@ -4707,20 +4532,20 @@ fn convert_to_stateful_memory_access<'a>(
};
let dst = arg.dst.upcast().unwrap_reg()?;
result.push(Statement::PtrAccess(PtrAccess {
- underlying_type: ast::PointerType::Scalar(ast::ScalarType::U8),
- state_space: ast::LdStateSpace::Global,
+ underlying_type: ast::ScalarType::U8,
+ state_space: ast::StateSpace::Global,
dst: *remapped_ids.get(&dst).unwrap(),
ptr_src: *ptr,
offset_src: offset,
}))
}
Statement::Instruction(ast::Instruction::Sub(
- ast::ArithDetails::Unsigned(ast::UIntType::U64),
+ ast::ArithDetails::Unsigned(ast::ScalarType::U64),
arg,
))
| Statement::Instruction(ast::Instruction::Sub(
ast::ArithDetails::Signed(ast::ArithSInt {
- typ: ast::SIntType::S64,
+ typ: ast::ScalarType::S64,
saturate: false,
}),
arg,
@@ -4734,8 +4559,10 @@ fn convert_to_stateful_memory_access<'a>(
}
_ => return Err(error_unreachable()),
};
- let offset_neg =
- id_defs.new_non_variable(Some(ast::Type::Scalar(ast::ScalarType::S64)));
+ let offset_neg = id_defs.register_intermediate(Some((
+ ast::Type::Scalar(ast::ScalarType::S64),
+ ast::StateSpace::Reg,
+ )));
result.push(Statement::Instruction(ast::Instruction::Neg(
ast::NegDetails {
typ: ast::ScalarType::S64,
@@ -4748,8 +4575,8 @@ fn convert_to_stateful_memory_access<'a>(
)));
let dst = arg.dst.upcast().unwrap_reg()?;
result.push(Statement::PtrAccess(PtrAccess {
- underlying_type: ast::PointerType::Scalar(ast::ScalarType::U8),
- state_space: ast::LdStateSpace::Global,
+ underlying_type: ast::ScalarType::U8,
+ state_space: ast::StateSpace::Global,
dst: *remapped_ids.get(&dst).unwrap(),
ptr_src: *ptr,
offset_src: TypedOperand::Reg(offset_neg),
@@ -4757,151 +4584,116 @@ fn convert_to_stateful_memory_access<'a>(
}
Statement::Instruction(inst) => {
let mut post_statements = Vec::new();
- let new_statement = inst.visit(
- &mut |arg_desc: ArgumentDescriptor<spirv::Word>,
- expected_type: Option<&ast::Type>| {
+ let new_statement =
+ inst.visit(&mut |arg_desc, expected_type: Option<(&ast::Type, _)>| {
convert_to_stateful_memory_access_postprocess(
id_defs,
&remapped_ids,
- &func_args_ptr,
&mut result,
&mut post_statements,
arg_desc,
expected_type,
)
- },
- )?;
+ })?;
result.push(new_statement);
result.extend(post_statements);
}
Statement::Call(call) => {
let mut post_statements = Vec::new();
- let new_statement = call.visit(
- &mut |arg_desc: ArgumentDescriptor<spirv::Word>,
- expected_type: Option<&ast::Type>| {
+ let new_statement =
+ call.visit(&mut |arg_desc, expected_type: Option<(&ast::Type, _)>| {
convert_to_stateful_memory_access_postprocess(
id_defs,
&remapped_ids,
- &func_args_ptr,
&mut result,
&mut post_statements,
arg_desc,
expected_type,
)
- },
- )?;
+ })?;
result.push(new_statement);
result.extend(post_statements);
}
Statement::RepackVector(pack) => {
let mut post_statements = Vec::new();
- let new_statement = pack.visit(
- &mut |arg_desc: ArgumentDescriptor<spirv::Word>,
- expected_type: Option<&ast::Type>| {
+ let new_statement =
+ pack.visit(&mut |arg_desc, expected_type: Option<(&ast::Type, _)>| {
convert_to_stateful_memory_access_postprocess(
id_defs,
&remapped_ids,
- &func_args_ptr,
&mut result,
&mut post_statements,
arg_desc,
expected_type,
)
- },
- )?;
+ })?;
result.push(new_statement);
result.extend(post_statements);
}
_ => return Err(error_unreachable()),
}
}
- for arg in func_args.input.iter_mut() {
- if func_args_ptr.contains(&arg.name) {
- arg.v_type = ast::Type::Pointer(
- ast::PointerType::Scalar(ast::ScalarType::U8),
- ast::LdStateSpace::Global,
- );
- }
- }
- Ok(result)
+ drop(method_decl);
+ Ok((func_args, result))
}
fn convert_to_stateful_memory_access_postprocess(
id_defs: &mut NumericIdResolver,
remapped_ids: &HashMap<spirv::Word, spirv::Word>,
- func_args_ptr: &HashSet<spirv::Word>,
result: &mut Vec<TypedStatement>,
post_statements: &mut Vec<TypedStatement>,
arg_desc: ArgumentDescriptor<spirv::Word>,
- expected_type: Option<&ast::Type>,
+ expected_type: Option<(&ast::Type, ast::StateSpace)>,
) -> Result<spirv::Word, TranslateError> {
Ok(match remapped_ids.get(&arg_desc.op) {
Some(new_id) => {
- // We skip conversion here to trigger PtrAcces in a later pass
- let old_type = match expected_type {
- Some(ast::Type::Pointer(_, ast::LdStateSpace::Global)) => return Ok(*new_id),
- _ => id_defs.get_typed(arg_desc.op)?.0,
+ let (new_operand_type, new_operand_space, _) = id_defs.get_typed(*new_id)?;
+ if let Some((expected_type, expected_space)) = expected_type {
+ let implicit_conversion = arg_desc
+ .non_default_implicit_conversion
+ .unwrap_or(default_implicit_conversion);
+ if implicit_conversion(
+ (new_operand_space, &new_operand_type),
+ (expected_space, expected_type),
+ )
+ .is_ok()
+ {
+ return Ok(*new_id);
+ }
+ }
+ let (old_operand_type, old_operand_space, _) = id_defs.get_typed(arg_desc.op)?;
+ let converting_id =
+ id_defs.register_intermediate(Some((old_operand_type.clone(), old_operand_space)));
+ let kind = if new_operand_space.is_compatible(ast::StateSpace::Reg) {
+ ConversionKind::Default
+ } else {
+ ConversionKind::PtrToPtr
};
- let old_type_clone = old_type.clone();
- let converting_id = id_defs.new_non_variable(Some(old_type_clone));
if arg_desc.is_dst {
post_statements.push(Statement::Conversion(ImplicitConversion {
src: converting_id,
dst: *new_id,
- from: old_type,
- to: ast::Type::Pointer(
- ast::PointerType::Scalar(ast::ScalarType::U8),
- ast::LdStateSpace::Global,
- ),
- kind: ConversionKind::BitToPtr(ast::LdStateSpace::Global),
- src_sema: ArgumentSemantics::Default,
- dst_sema: arg_desc.sema,
+ from_type: old_operand_type,
+ from_space: old_operand_space,
+ to_type: new_operand_type,
+ to_space: new_operand_space,
+ kind,
}));
converting_id
} else {
result.push(Statement::Conversion(ImplicitConversion {
src: *new_id,
dst: converting_id,
- from: ast::Type::Pointer(
- ast::PointerType::Scalar(ast::ScalarType::U8),
- ast::LdStateSpace::Global,
- ),
- to: old_type,
- kind: ConversionKind::PtrToBit(ast::UIntType::U64),
- src_sema: arg_desc.sema,
- dst_sema: ArgumentSemantics::Default,
+ from_type: new_operand_type,
+ from_space: new_operand_space,
+ to_type: old_operand_type,
+ to_space: old_operand_space,
+ kind,
}));
converting_id
}
}
- None => match func_args_ptr.get(&arg_desc.op) {
- Some(new_id) => {
- if arg_desc.is_dst {
- return Err(error_unreachable());
- }
- // We skip conversion here to trigger PtrAcces in a later pass
- let old_type = match expected_type {
- Some(ast::Type::Pointer(_, ast::LdStateSpace::Global)) => return Ok(*new_id),
- _ => id_defs.get_typed(arg_desc.op)?.0,
- };
- let old_type_clone = old_type.clone();
- let converting_id = id_defs.new_non_variable(Some(old_type));
- result.push(Statement::Conversion(ImplicitConversion {
- src: *new_id,
- dst: converting_id,
- from: ast::Type::Pointer(
- ast::PointerType::Pointer(ast::ScalarType::U8, ast::LdStateSpace::Global),
- ast::LdStateSpace::Param,
- ),
- to: old_type_clone,
- kind: ConversionKind::PtrToPtr { spirv_ptr: false },
- src_sema: arg_desc.sema,
- dst_sema: ArgumentSemantics::Default,
- }));
- converting_id
- }
- None => arg_desc.op,
- },
+ None => arg_desc.op,
})
}
@@ -4925,9 +4717,9 @@ fn is_add_ptr_direct(remapped_ids: &HashMap<u32, u32>, arg: &ast::Arg3<TypedArgP
fn is_64_bit_integer(id_defs: &NumericIdResolver, id: spirv::Word) -> bool {
match id_defs.get_typed(id) {
- Ok((ast::Type::Scalar(ast::ScalarType::U64), _))
- | Ok((ast::Type::Scalar(ast::ScalarType::S64), _))
- | Ok((ast::Type::Scalar(ast::ScalarType::B64), _)) => true,
+ Ok((ast::Type::Scalar(ast::ScalarType::U64), _, _))
+ | Ok((ast::Type::Scalar(ast::ScalarType::S64), _, _))
+ | Ok((ast::Type::Scalar(ast::ScalarType::B64), _, _)) => true,
_ => false,
}
}
@@ -5055,20 +4847,95 @@ impl SpecialRegistersMap {
}
}
+struct FnSigMapper<'input> {
+ // true - stays as return argument
+ // false - is moved to input argument
+ return_param_args: Vec<bool>,
+ func_decl: Rc<RefCell<ast::MethodDeclaration<'input, spirv::Word>>>,
+}
+
+impl<'input> FnSigMapper<'input> {
+ fn remap_to_spirv_repr(mut method: ast::MethodDeclaration<'input, spirv::Word>) -> Self {
+ let return_param_args = method
+ .return_arguments
+ .iter()
+ .map(|a| a.state_space != ast::StateSpace::Param)
+ .collect::<Vec<_>>();
+ let mut new_return_arguments = Vec::new();
+ for arg in method.return_arguments.into_iter() {
+ if arg.state_space == ast::StateSpace::Param {
+ method.input_arguments.push(arg);
+ } else {
+ new_return_arguments.push(arg);
+ }
+ }
+ method.return_arguments = new_return_arguments;
+ FnSigMapper {
+ return_param_args,
+ func_decl: Rc::new(RefCell::new(method)),
+ }
+ }
+
+ fn resolve_in_spirv_repr(
+ &self,
+ call_inst: ast::CallInst<NormalizedArgParams>,
+ ) -> Result<ResolvedCall<NormalizedArgParams>, TranslateError> {
+ let func_decl = (*self.func_decl).borrow();
+ let mut return_arguments = Vec::new();
+ let mut input_arguments = call_inst
+ .param_list
+ .into_iter()
+ .zip(func_decl.input_arguments.iter())
+ .map(|(id, var)| (id, var.v_type.clone(), var.state_space))
+ .collect::<Vec<_>>();
+ let mut func_decl_return_iter = func_decl.return_arguments.iter();
+ let mut func_decl_input_iter = func_decl.input_arguments[input_arguments.len()..].iter();
+ for (idx, id) in call_inst.ret_params.iter().enumerate() {
+ let stays_as_return = match self.return_param_args.get(idx) {
+ Some(x) => *x,
+ None => return Err(TranslateError::MismatchedType),
+ };
+ if stays_as_return {
+ if let Some(var) = func_decl_return_iter.next() {
+ return_arguments.push((*id, var.v_type.clone(), var.state_space));
+ } else {
+ return Err(TranslateError::MismatchedType);
+ }
+ } else {
+ if let Some(var) = func_decl_input_iter.next() {
+ input_arguments.push((
+ ast::Operand::Reg(*id),
+ var.v_type.clone(),
+ var.state_space,
+ ));
+ } else {
+ return Err(TranslateError::MismatchedType);
+ }
+ }
+ }
+ if return_arguments.len() != func_decl.return_arguments.len()
+ || input_arguments.len() != func_decl.input_arguments.len()
+ {
+ return Err(TranslateError::MismatchedType);
+ }
+ Ok(ResolvedCall {
+ return_arguments,
+ input_arguments,
+ uniform: call_inst.uniform,
+ name: call_inst.func,
+ })
+ }
+}
+
struct GlobalStringIdResolver<'input> {
current_id: spirv::Word,
variables: HashMap<Cow<'input, str>, spirv::Word>,
- variables_type_check: HashMap<u32, Option<(ast::Type, bool)>>,
+ variables_type_check: HashMap<u32, Option<(ast::Type, ast::StateSpace, bool)>>,
special_registers: SpecialRegistersMap,
- fns: HashMap<spirv::Word, FnDecl>,
-}
-
-pub struct FnDecl {
- ret_vals: Vec<ast::FnArgumentType>,
- params: Vec<ast::FnArgumentType>,
+ fns: HashMap<spirv::Word, FnSigMapper<'input>>,
}
-impl<'a> GlobalStringIdResolver<'a> {
+impl<'input> GlobalStringIdResolver<'input> {
fn new(start_id: spirv::Word) -> Self {
Self {
current_id: start_id,
@@ -5079,20 +4946,25 @@ impl<'a> GlobalStringIdResolver<'a> {
}
}
- fn get_or_add_def(&mut self, id: &'a str) -> spirv::Word {
+ fn get_or_add_def(&mut self, id: &'input str) -> spirv::Word {
self.get_or_add_impl(id, None)
}
fn get_or_add_def_typed(
&mut self,
- id: &'a str,
+ id: &'input str,
typ: ast::Type,
+ state_space: ast::StateSpace,
is_variable: bool,
) -> spirv::Word {
- self.get_or_add_impl(id, Some((typ, is_variable)))
+ self.get_or_add_impl(id, Some((typ, state_space, is_variable)))
}
- fn get_or_add_impl(&mut self, id: &'a str, typ: Option<(ast::Type, bool)>) -> spirv::Word {
+ fn get_or_add_impl(
+ &mut self,
+ id: &'input str,
+ typ: Option<(ast::Type, ast::StateSpace, bool)>,
+ ) -> spirv::Word {
let id = match self.variables.entry(Cow::Borrowed(id)) {
hash_map::Entry::Occupied(e) => *(e.get()),
hash_map::Entry::Vacant(e) => {
@@ -5119,12 +4991,12 @@ impl<'a> GlobalStringIdResolver<'a> {
fn start_fn<'b>(
&'b mut self,
- header: &'b ast::MethodDecl<'a, &'a str>,
+ header: &'b ast::MethodDeclaration<'input, &'input str>,
) -> Result<
(
- FnStringIdResolver<'a, 'b>,
- GlobalFnDeclResolver<'a, 'b>,
- ast::MethodDecl<'a, spirv::Word>,
+ FnStringIdResolver<'input, 'b>,
+ GlobalFnDeclResolver<'input, 'b>,
+ Rc<RefCell<ast::MethodDeclaration<'input, spirv::Word>>>,
),
TranslateError,
> {
@@ -5138,60 +5010,51 @@ impl<'a> GlobalStringIdResolver<'a> {
variables: vec![HashMap::new(); 1],
type_check: HashMap::new(),
};
- let new_fn_decl = match header {
- ast::MethodDecl::Kernel { name, in_args } => ast::MethodDecl::Kernel {
- name,
- in_args: expand_kernel_params(&mut fn_resolver, in_args.iter())?,
- },
- ast::MethodDecl::Func(ret_params, _, params) => {
- let ret_params_ids = expand_fn_params(&mut fn_resolver, ret_params.iter())?;
- let params_ids = expand_fn_params(&mut fn_resolver, params.iter())?;
- self.fns.insert(
- name_id,
- FnDecl {
- ret_vals: ret_params_ids.iter().map(|p| p.v_type.clone()).collect(),
- params: params_ids.iter().map(|p| p.v_type.clone()).collect(),
- },
- );
- ast::MethodDecl::Func(ret_params_ids, name_id, params_ids)
- }
+ let return_arguments = rename_fn_params(&mut fn_resolver, &header.return_arguments);
+ let input_arguments = rename_fn_params(&mut fn_resolver, &header.input_arguments);
+ let name = match header.name {
+ ast::MethodName::Kernel(name) => ast::MethodName::Kernel(name),
+ ast::MethodName::Func(_) => ast::MethodName::Func(name_id),
+ };
+ let fn_decl = ast::MethodDeclaration {
+ return_arguments,
+ name,
+ input_arguments,
+ shared_mem: None,
+ };
+ let new_fn_decl = if !fn_decl.name.is_kernel() {
+ let resolver = FnSigMapper::remap_to_spirv_repr(fn_decl);
+ let new_fn_decl = resolver.func_decl.clone();
+ self.fns.insert(name_id, resolver);
+ new_fn_decl
+ } else {
+ Rc::new(RefCell::new(fn_decl))
};
Ok((
fn_resolver,
- GlobalFnDeclResolver {
- variables: &self.variables,
- fns: &self.fns,
- },
+ GlobalFnDeclResolver { fns: &self.fns },
new_fn_decl,
))
}
}
pub struct GlobalFnDeclResolver<'input, 'a> {
- variables: &'a HashMap<Cow<'input, str>, spirv::Word>,
- fns: &'a HashMap<spirv::Word, FnDecl>,
+ fns: &'a HashMap<spirv::Word, FnSigMapper<'input>>,
}
impl<'input, 'a> GlobalFnDeclResolver<'input, 'a> {
- fn get_fn_decl(&self, id: spirv::Word) -> Result<&FnDecl, TranslateError> {
+ fn get_fn_sig_resolver(&self, id: spirv::Word) -> Result<&FnSigMapper<'input>, TranslateError> {
self.fns.get(&id).ok_or(TranslateError::UnknownSymbol)
}
-
- fn get_fn_decl_str(&self, id: &str) -> Result<&'a FnDecl, TranslateError> {
- match self.variables.get(id).map(|var_id| self.fns.get(var_id)) {
- Some(Some(fn_d)) => Ok(fn_d),
- _ => Err(TranslateError::UnknownSymbol),
- }
- }
}
struct FnStringIdResolver<'input, 'b> {
current_id: &'b mut spirv::Word,
global_variables: &'b HashMap<Cow<'input, str>, spirv::Word>,
- global_type_check: &'b HashMap<u32, Option<(ast::Type, bool)>>,
+ global_type_check: &'b HashMap<u32, Option<(ast::Type, ast::StateSpace, bool)>>,
special_registers: &'b mut SpecialRegistersMap,
variables: Vec<HashMap<Cow<'input, str>, spirv::Word>>,
- type_check: HashMap<u32, Option<(ast::Type, bool)>>,
+ type_check: HashMap<u32, Option<(ast::Type, ast::StateSpace, bool)>>,
}
impl<'a, 'b> FnStringIdResolver<'a, 'b> {
@@ -5229,14 +5092,21 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> {
}
}
- fn add_def(&mut self, id: &'a str, typ: Option<ast::Type>, is_variable: bool) -> spirv::Word {
+ fn add_def(
+ &mut self,
+ id: &'a str,
+ typ: Option<(ast::Type, ast::StateSpace)>,
+ is_variable: bool,
+ ) -> spirv::Word {
let numeric_id = *self.current_id;
self.variables
.last_mut()
.unwrap()
.insert(Cow::Borrowed(id), numeric_id);
- self.type_check
- .insert(numeric_id, typ.map(|t| (t, is_variable)));
+ self.type_check.insert(
+ numeric_id,
+ typ.map(|(typ, space)| (typ, space, is_variable)),
+ );
*self.current_id += 1;
numeric_id
}
@@ -5247,6 +5117,7 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> {
base_id: &'a str,
count: u32,
typ: ast::Type,
+ state_space: ast::StateSpace,
is_variable: bool,
) -> impl Iterator<Item = spirv::Word> {
let numeric_id = *self.current_id;
@@ -5255,8 +5126,10 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> {
.last_mut()
.unwrap()
.insert(Cow::Owned(format!("{}{}", base_id, i)), numeric_id + i);
- self.type_check
- .insert(numeric_id + i, Some((typ.clone(), is_variable)));
+ self.type_check.insert(
+ numeric_id + i,
+ Some((typ.clone(), state_space, is_variable)),
+ );
}
*self.current_id += count;
(0..count).into_iter().map(move |i| i + numeric_id)
@@ -5265,8 +5138,8 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> {
struct NumericIdResolver<'b> {
current_id: &'b mut spirv::Word,
- global_type_check: &'b HashMap<u32, Option<(ast::Type, bool)>>,
- type_check: HashMap<u32, Option<(ast::Type, bool)>>,
+ global_type_check: &'b HashMap<u32, Option<(ast::Type, ast::StateSpace, bool)>>,
+ type_check: HashMap<u32, Option<(ast::Type, ast::StateSpace, bool)>>,
special_registers: &'b mut SpecialRegistersMap,
}
@@ -5275,12 +5148,15 @@ impl<'b> NumericIdResolver<'b> {
MutableNumericIdResolver { base: self }
}
- fn get_typed(&self, id: spirv::Word) -> Result<(ast::Type, bool), TranslateError> {
+ fn get_typed(
+ &self,
+ id: spirv::Word,
+ ) -> Result<(ast::Type, ast::StateSpace, bool), TranslateError> {
match self.type_check.get(&id) {
Some(Some(x)) => Ok(x.clone()),
Some(None) => Err(TranslateError::UntypedSymbol),
None => match self.special_registers.get(id) {
- Some(x) => Ok((x.get_type(), true)),
+ Some(x) => Ok((x.get_type(), ast::StateSpace::Sreg, true)),
None => match self.global_type_check.get(&id) {
Some(Some(result)) => Ok(result.clone()),
Some(None) | None => Err(TranslateError::UntypedSymbol),
@@ -5291,16 +5167,18 @@ impl<'b> NumericIdResolver<'b> {
// This is for identifiers which will be emitted later as OpVariable
// They are candidates for insertion of LoadVar/StoreVar
- fn new_variable(&mut self, typ: ast::Type) -> spirv::Word {
+ fn register_variable(&mut self, typ: ast::Type, state_space: ast::StateSpace) -> spirv::Word {
let new_id = *self.current_id;
- self.type_check.insert(new_id, Some((typ, true)));
+ self.type_check
+ .insert(new_id, Some((typ, state_space, true)));
*self.current_id += 1;
new_id
}
- fn new_non_variable(&mut self, typ: Option<ast::Type>) -> spirv::Word {
+ fn register_intermediate(&mut self, typ: Option<(ast::Type, ast::StateSpace)>) -> spirv::Word {
let new_id = *self.current_id;
- self.type_check.insert(new_id, typ.map(|t| (t, false)));
+ self.type_check
+ .insert(new_id, typ.map(|(t, space)| (t, space, false)));
*self.current_id += 1;
new_id
}
@@ -5315,18 +5193,22 @@ impl<'b> MutableNumericIdResolver<'b> {
self.base
}
- fn get_typed(&self, id: spirv::Word) -> Result<ast::Type, TranslateError> {
- self.base.get_typed(id).map(|(t, _)| t)
+ fn get_typed(&self, id: spirv::Word) -> Result<(ast::Type, ast::StateSpace), TranslateError> {
+ self.base.get_typed(id).map(|(t, space, _)| (t, space))
}
- fn new_non_variable(&mut self, typ: ast::Type) -> spirv::Word {
- self.base.new_non_variable(Some(typ))
+ fn register_intermediate(
+ &mut self,
+ typ: ast::Type,
+ state_space: ast::StateSpace,
+ ) -> spirv::Word {
+ self.base.register_intermediate(Some((typ, state_space)))
}
}
enum Statement<I, P: ast::ArgParams> {
Label(u32),
- Variable(ast::Variable<ast::VariableType, P::Id>),
+ Variable(ast::Variable<P::Id>),
Instruction(I),
// SPIR-V compatible replacement for PTX predicates
Conditional(BrachCondition),
@@ -5349,7 +5231,8 @@ impl ExpandedStatement {
Statement::Variable(var)
}
Statement::Instruction(inst) => inst
- .visit(&mut |arg: ArgumentDescriptor<_>, _: Option<&ast::Type>| {
+ .visit(&mut |arg: ArgumentDescriptor<_>,
+ _: Option<(&ast::Type, ast::StateSpace)>| {
Ok(f(arg.op, arg.is_dst))
})
.unwrap(),
@@ -5364,16 +5247,17 @@ impl ExpandedStatement {
Statement::StoreVar(details)
}
Statement::Call(mut call) => {
- for (id, typ) in call.ret_params.iter_mut() {
- let is_dst = match typ {
- ast::FnArgumentType::Reg(_) => true,
- ast::FnArgumentType::Param(_) => false,
- ast::FnArgumentType::Shared => false,
+ for (id, _, space) in call.return_arguments.iter_mut() {
+ let is_dst = match space {
+ ast::StateSpace::Reg => true,
+ ast::StateSpace::Param => false,
+ ast::StateSpace::Shared => false,
+ _ => todo!(),
};
*id = f(*id, is_dst);
}
- call.func = f(call.func, false);
- for (id, _) in call.param_list.iter_mut() {
+ call.name = f(call.name, false);
+ for (id, _, _) in call.input_arguments.iter_mut() {
*id = f(*id, false);
}
Statement::Call(call)
@@ -5435,6 +5319,7 @@ impl ExpandedStatement {
struct LoadVarDetails {
arg: ast::Arg2<ExpandedArgParams>,
typ: ast::Type,
+ state_space: ast::StateSpace,
// (index, vector_width)
// HACK ALERT
// For some reason IGC explodes when you try to load from builtin vectors
@@ -5454,7 +5339,12 @@ struct RepackVectorDetails {
typ: ast::ScalarType,
packed: spirv::Word,
unpacked: Vec<spirv::Word>,
- vector_sema: ArgumentSemantics,
+ non_default_implicit_conversion: Option<
+ fn(
+ (ast::StateSpace, &ast::Type),
+ (ast::StateSpace, &ast::Type),
+ ) -> Result<Option<ConversionKind>, TranslateError>,
+ >,
}
impl RepackVectorDetails {
@@ -5470,13 +5360,17 @@ impl RepackVectorDetails {
ArgumentDescriptor {
op: self.packed,
is_dst: !self.is_extract,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
- Some(&ast::Type::Vector(self.typ, self.unpacked.len() as u8)),
+ Some((
+ &ast::Type::Vector(self.typ, self.unpacked.len() as u8),
+ ast::StateSpace::Reg,
+ )),
)?;
let scalar_type = self.typ;
let is_extract = self.is_extract;
- let vector_sema = self.vector_sema;
+ let non_default_implicit_conversion = self.non_default_implicit_conversion;
let vector = self
.unpacked
.into_iter()
@@ -5485,9 +5379,10 @@ impl RepackVectorDetails {
ArgumentDescriptor {
op: id,
is_dst: is_extract,
- sema: vector_sema,
+ is_memory_access: false,
+ non_default_implicit_conversion,
},
- Some(&ast::Type::Scalar(scalar_type)),
+ Some((&ast::Type::Scalar(scalar_type), ast::StateSpace::Reg)),
)
})
.collect::<Result<_, _>>()?;
@@ -5496,7 +5391,7 @@ impl RepackVectorDetails {
typ: self.typ,
packed: scalar,
unpacked: vector,
- vector_sema,
+ non_default_implicit_conversion,
})
}
}
@@ -5514,18 +5409,18 @@ impl<T: ArgParamsEx<Id = spirv::Word>, U: ArgParamsEx<Id = spirv::Word>> Visitab
struct ResolvedCall<P: ast::ArgParams> {
pub uniform: bool,
- pub ret_params: Vec<(P::Id, ast::FnArgumentType)>,
- pub func: P::Id,
- pub param_list: Vec<(P::Operand, ast::FnArgumentType)>,
+ pub return_arguments: Vec<(P::Id, ast::Type, ast::StateSpace)>,
+ pub name: P::Id,
+ pub input_arguments: Vec<(P::Operand, ast::Type, ast::StateSpace)>,
}
impl<T: ast::ArgParams> ResolvedCall<T> {
fn cast<U: ast::ArgParams<Id = T::Id, Operand = T::Operand>>(self) -> ResolvedCall<U> {
ResolvedCall {
uniform: self.uniform,
- ret_params: self.ret_params,
- func: self.func,
- param_list: self.param_list,
+ return_arguments: self.return_arguments,
+ name: self.name,
+ input_arguments: self.input_arguments,
}
}
}
@@ -5535,49 +5430,53 @@ impl<From: ArgParamsEx<Id = spirv::Word>> ResolvedCall<From> {
self,
visitor: &mut V,
) -> Result<ResolvedCall<To>, TranslateError> {
- let ret_params = self
- .ret_params
+ let return_arguments = self
+ .return_arguments
.into_iter()
- .map::<Result<_, TranslateError>, _>(|(id, typ)| {
+ .map::<Result<_, TranslateError>, _>(|(id, typ, space)| {
let new_id = visitor.id(
ArgumentDescriptor {
op: id,
- is_dst: !typ.is_param(),
- sema: typ.semantics(),
+ is_dst: space != ast::StateSpace::Param,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
- Some(&typ.to_func_type()),
+ Some((&typ, space)),
)?;
- Ok((new_id, typ))
+ Ok((new_id, typ, space))
})
.collect::<Result<Vec<_>, _>>()?;
let func = visitor.id(
ArgumentDescriptor {
- op: self.func,
+ op: self.name,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
None,
)?;
- let param_list = self
- .param_list
+ let input_arguments = self
+ .input_arguments
.into_iter()
- .map::<Result<_, TranslateError>, _>(|(id, typ)| {
+ .map::<Result<_, TranslateError>, _>(|(id, typ, space)| {
let new_id = visitor.operand(
ArgumentDescriptor {
op: id,
is_dst: false,
- sema: typ.semantics(),
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
- &typ.to_func_type(),
+ &typ,
+ space,
)?;
- Ok((new_id, typ))
+ Ok((new_id, typ, space))
})
.collect::<Result<Vec<_>, _>>()?;
Ok(ResolvedCall {
uniform: self.uniform,
- ret_params,
- func,
- param_list,
+ return_arguments,
+ name: func,
+ input_arguments,
})
}
}
@@ -5598,39 +5497,34 @@ impl<P: ArgParamsEx<Id = spirv::Word>> PtrAccess<P> {
self,
visitor: &mut V,
) -> Result<PtrAccess<To>, TranslateError> {
- let sema = match self.state_space {
- ast::LdStateSpace::Const
- | ast::LdStateSpace::Global
- | ast::LdStateSpace::Shared
- | ast::LdStateSpace::Generic => ArgumentSemantics::PhysicalPointer,
- ast::LdStateSpace::Local | ast::LdStateSpace::Param => {
- ArgumentSemantics::RegisterPointer
- }
- };
- let ptr_type = ast::Type::Pointer(self.underlying_type.clone(), self.state_space);
+ let ptr_type = ast::Type::Scalar(self.underlying_type.clone());
let new_dst = visitor.id(
ArgumentDescriptor {
op: self.dst,
is_dst: true,
- sema,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
- Some(&ptr_type),
+ Some((&ptr_type, self.state_space)),
)?;
let new_ptr_src = visitor.id(
ArgumentDescriptor {
op: self.ptr_src,
is_dst: false,
- sema,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
- Some(&ptr_type),
+ Some((&ptr_type, self.state_space)),
)?;
let new_constant_src = visitor.operand(
ArgumentDescriptor {
op: self.offset_src,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
&ast::Type::Scalar(ast::ScalarType::S64),
+ ast::StateSpace::Reg,
)?;
Ok(PtrAccess {
underlying_type: self.underlying_type,
@@ -5653,21 +5547,9 @@ impl<T: ArgParamsEx<Id = spirv::Word>, U: ArgParamsEx<Id = spirv::Word>> Visitab
}
}
-pub trait ArgParamsEx: ast::ArgParams + Sized {
- fn get_fn_decl<'x, 'b>(
- id: &Self::Id,
- decl: &'b GlobalFnDeclResolver<'x, 'b>,
- ) -> Result<&'b FnDecl, TranslateError>;
-}
+pub trait ArgParamsEx: ast::ArgParams + Sized {}
-impl<'input> ArgParamsEx for ast::ParsedArgParams<'input> {
- fn get_fn_decl<'x, 'b>(
- id: &Self::Id,
- decl: &'b GlobalFnDeclResolver<'x, 'b>,
- ) -> Result<&'b FnDecl, TranslateError> {
- decl.get_fn_decl_str(id)
- }
-}
+impl<'input> ArgParamsEx for ast::ParsedArgParams<'input> {}
enum NormalizedArgParams {}
@@ -5676,14 +5558,7 @@ impl ast::ArgParams for NormalizedArgParams {
type Operand = ast::Operand<spirv::Word>;
}
-impl ArgParamsEx for NormalizedArgParams {
- fn get_fn_decl<'a, 'b>(
- id: &Self::Id,
- decl: &'b GlobalFnDeclResolver<'a, 'b>,
- ) -> Result<&'b FnDecl, TranslateError> {
- decl.get_fn_decl(*id)
- }
-}
+impl ArgParamsEx for NormalizedArgParams {}
type NormalizedStatement = Statement<
(
@@ -5702,14 +5577,7 @@ impl ast::ArgParams for TypedArgParams {
type Operand = TypedOperand;
}
-impl ArgParamsEx for TypedArgParams {
- fn get_fn_decl<'a, 'b>(
- id: &Self::Id,
- decl: &'b GlobalFnDeclResolver<'a, 'b>,
- ) -> Result<&'b FnDecl, TranslateError> {
- decl.get_fn_decl(*id)
- }
-}
+impl ArgParamsEx for TypedArgParams {}
#[derive(Copy, Clone)]
enum TypedOperand {
@@ -5740,24 +5608,16 @@ impl ast::ArgParams for ExpandedArgParams {
type Operand = spirv::Word;
}
-impl ArgParamsEx for ExpandedArgParams {
- fn get_fn_decl<'a, 'b>(
- id: &Self::Id,
- decl: &'b GlobalFnDeclResolver<'a, 'b>,
- ) -> Result<&'b FnDecl, TranslateError> {
- decl.get_fn_decl(*id)
- }
-}
+impl ArgParamsEx for ExpandedArgParams {}
enum Directive<'input> {
- Variable(ast::Variable<ast::VariableType, spirv::Word>),
+ Variable(ast::LinkingDirective, ast::Variable<spirv::Word>),
Method(Function<'input>),
}
struct Function<'input> {
- pub func_decl: ast::MethodDecl<'input, spirv::Word>,
- pub spirv_decl: SpirvMethodDecl<'input>,
- pub globals: Vec<ast::Variable<ast::VariableType, spirv::Word>>,
+ pub func_decl: Rc<RefCell<ast::MethodDeclaration<'input, spirv::Word>>>,
+ pub globals: Vec<ast::Variable<spirv::Word>>,
pub body: Option<Vec<ExpandedStatement>>,
import_as: Option<String>,
tuning: Vec<ast::TuningDirective>,
@@ -5767,12 +5627,13 @@ pub trait ArgumentMapVisitor<T: ArgParamsEx, U: ArgParamsEx> {
fn id(
&mut self,
desc: ArgumentDescriptor<T::Id>,
- typ: Option<&ast::Type>,
+ typ: Option<(&ast::Type, ast::StateSpace)>,
) -> Result<U::Id, TranslateError>;
fn operand(
&mut self,
desc: ArgumentDescriptor<T::Operand>,
typ: &ast::Type,
+ state_space: ast::StateSpace,
) -> Result<U::Operand, TranslateError>;
}
@@ -5780,13 +5641,13 @@ impl<T> ArgumentMapVisitor<ExpandedArgParams, ExpandedArgParams> for T
where
T: FnMut(
ArgumentDescriptor<spirv::Word>,
- Option<&ast::Type>,
+ Option<(&ast::Type, ast::StateSpace)>,
) -> Result<spirv::Word, TranslateError>,
{
fn id(
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
- t: Option<&ast::Type>,
+ t: Option<(&ast::Type, ast::StateSpace)>,
) -> Result<spirv::Word, TranslateError> {
self(desc, t)
}
@@ -5795,8 +5656,9 @@ where
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
typ: &ast::Type,
+ state_space: ast::StateSpace,
) -> Result<spirv::Word, TranslateError> {
- self(desc, Some(typ))
+ self(desc, Some((typ, state_space)))
}
}
@@ -5807,7 +5669,7 @@ where
fn id(
&mut self,
desc: ArgumentDescriptor<&str>,
- _: Option<&ast::Type>,
+ _: Option<(&ast::Type, ast::StateSpace)>,
) -> Result<spirv::Word, TranslateError> {
self(desc.op)
}
@@ -5816,6 +5678,7 @@ where
&mut self,
desc: ArgumentDescriptor<ast::Operand<&str>>,
typ: &ast::Type,
+ state_space: ast::StateSpace,
) -> Result<ast::Operand<spirv::Word>, TranslateError> {
Ok(match desc.op {
ast::Operand::Reg(id) => ast::Operand::Reg(self(id)?),
@@ -5824,7 +5687,7 @@ where
ast::Operand::VecMember(id, member) => ast::Operand::VecMember(self(id)?, member),
ast::Operand::VecPack(ref ids) => ast::Operand::VecPack(
ids.into_iter()
- .map(|id| self.id(desc.new_op(id), Some(typ)))
+ .map(|id| self.id(desc.new_op(id), Some((typ, state_space))))
.collect::<Result<Vec<_>, _>>()?,
),
})
@@ -5834,37 +5697,30 @@ where
pub struct ArgumentDescriptor<Op> {
op: Op,
is_dst: bool,
- sema: ArgumentSemantics,
+ is_memory_access: bool,
+ non_default_implicit_conversion: Option<
+ fn(
+ (ast::StateSpace, &ast::Type),
+ (ast::StateSpace, &ast::Type),
+ ) -> Result<Option<ConversionKind>, TranslateError>,
+ >,
}
pub struct PtrAccess<P: ast::ArgParams> {
- underlying_type: ast::PointerType,
- state_space: ast::LdStateSpace,
+ underlying_type: ast::ScalarType,
+ state_space: ast::StateSpace,
dst: spirv::Word,
ptr_src: spirv::Word,
offset_src: P::Operand,
}
-#[derive(Copy, Clone, PartialEq, Eq, Debug)]
-pub enum ArgumentSemantics {
- // normal register access
- Default,
- // normal register access with relaxed conversion rules (ld/st)
- DefaultRelaxed,
- // st/ld global
- PhysicalPointer,
- // st/ld .param, .local
- RegisterPointer,
- // mov of .local/.global variables
- Address,
-}
-
impl<T> ArgumentDescriptor<T> {
fn new_op<U>(&self, u: U) -> ArgumentDescriptor<U> {
ArgumentDescriptor {
op: u,
is_dst: self.is_dst,
- sema: self.sema,
+ is_memory_access: self.is_memory_access,
+ non_default_implicit_conversion: self.non_default_implicit_conversion,
}
}
}
@@ -5905,7 +5761,9 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
let inst_type = d.typ;
ast::Instruction::SetpBool(d, a.map(visitor, &ast::Type::Scalar(inst_type))?)
}
- ast::Instruction::Not(t, a) => ast::Instruction::Not(t, a.map(visitor, &t.to_type())?),
+ ast::Instruction::Not(t, a) => {
+ ast::Instruction::Not(t, a.map(visitor, &ast::Type::Scalar(t))?)
+ }
ast::Instruction::Cvt(d, a) => {
let (dst_t, src_t) = match &d {
ast::CvtDetails::FloatFromFloat(desc) => (
@@ -5928,7 +5786,7 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
ast::Instruction::Cvt(d, a.map_different_types(visitor, &dst_t, &src_t)?)
}
ast::Instruction::Shl(t, a) => {
- ast::Instruction::Shl(t, a.map_shift(visitor, &t.to_type())?)
+ ast::Instruction::Shl(t, a.map_shift(visitor, &ast::Type::Scalar(t))?)
}
ast::Instruction::Shr(t, a) => {
ast::Instruction::Shr(t, a.map_shift(visitor, &ast::Type::Scalar(t.into()))?)
@@ -6101,17 +5959,19 @@ impl ImplicitConversion {
ArgumentDescriptor {
op: self.dst,
is_dst: true,
- sema: self.dst_sema,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
- Some(&self.to),
+ Some((&self.to_type, self.to_space)),
)?;
let new_src = visitor.id(
ArgumentDescriptor {
op: self.src,
is_dst: false,
- sema: self.src_sema,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
- Some(&self.from),
+ Some((&self.from_type, self.from_space)),
)?;
Ok(Statement::Conversion({
ImplicitConversion {
@@ -6138,13 +5998,13 @@ impl<T> ArgumentMapVisitor<TypedArgParams, TypedArgParams> for T
where
T: FnMut(
ArgumentDescriptor<spirv::Word>,
- Option<&ast::Type>,
+ Option<(&ast::Type, ast::StateSpace)>,
) -> Result<spirv::Word, TranslateError>,
{
fn id(
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
- t: Option<&ast::Type>,
+ t: Option<(&ast::Type, ast::StateSpace)>,
) -> Result<spirv::Word, TranslateError> {
self(desc, t)
}
@@ -6153,12 +6013,15 @@ where
&mut self,
desc: ArgumentDescriptor<TypedOperand>,
typ: &ast::Type,
+ state_space: ast::StateSpace,
) -> Result<TypedOperand, TranslateError> {
Ok(match desc.op {
- TypedOperand::Reg(id) => TypedOperand::Reg(self(desc.new_op(id), Some(typ))?),
+ TypedOperand::Reg(id) => {
+ TypedOperand::Reg(self(desc.new_op(id), Some((typ, state_space)))?)
+ }
TypedOperand::Imm(imm) => TypedOperand::Imm(imm),
TypedOperand::RegOffset(id, imm) => {
- TypedOperand::RegOffset(self(desc.new_op(id), Some(typ))?, imm)
+ TypedOperand::RegOffset(self(desc.new_op(id), Some((typ, state_space)))?, imm)
}
TypedOperand::VecMember(reg, index) => {
let scalar_type = match typ {
@@ -6166,7 +6029,10 @@ where
_ => return Err(error_unreachable()),
};
let vec_type = ast::Type::Vector(scalar_type, index + 1);
- TypedOperand::VecMember(self(desc.new_op(reg), Some(&vec_type))?, index)
+ TypedOperand::VecMember(
+ self(desc.new_op(reg), Some((&vec_type, state_space)))?,
+ index,
+ )
}
})
}
@@ -6178,9 +6044,9 @@ impl ast::Type {
ast::Type::Scalar(scalar) => {
let kind = scalar.kind();
let width = scalar.size_of();
- if (kind != ScalarKind::Signed
- && kind != ScalarKind::Unsigned
- && kind != ScalarKind::Bit)
+ if (kind != ast::ScalarKind::Signed
+ && kind != ast::ScalarKind::Unsigned
+ && kind != ast::ScalarKind::Bit)
|| (width == 8)
{
return Err(TranslateError::MismatchedType);
@@ -6198,57 +6064,32 @@ impl ast::Type {
match self {
ast::Type::Scalar(scalar) => TypeParts {
kind: TypeKind::Scalar,
+ state_space: ast::StateSpace::Reg,
scalar_kind: scalar.kind(),
width: scalar.size_of(),
components: Vec::new(),
- state_space: ast::LdStateSpace::Global,
},
ast::Type::Vector(scalar, components) => TypeParts {
kind: TypeKind::Vector,
+ state_space: ast::StateSpace::Reg,
scalar_kind: scalar.kind(),
width: scalar.size_of(),
components: vec![*components as u32],
- state_space: ast::LdStateSpace::Global,
},
ast::Type::Array(scalar, components) => TypeParts {
kind: TypeKind::Array,
+ state_space: ast::StateSpace::Reg,
scalar_kind: scalar.kind(),
width: scalar.size_of(),
components: components.clone(),
- state_space: ast::LdStateSpace::Global,
},
- ast::Type::Pointer(ast::PointerType::Scalar(scalar), state_space) => TypeParts {
- kind: TypeKind::PointerScalar,
+ ast::Type::Pointer(scalar, space) => TypeParts {
+ kind: TypeKind::Pointer,
+ state_space: *space,
scalar_kind: scalar.kind(),
width: scalar.size_of(),
components: Vec::new(),
- state_space: *state_space,
- },
- ast::Type::Pointer(ast::PointerType::Vector(scalar, len), state_space) => TypeParts {
- kind: TypeKind::PointerVector,
- scalar_kind: scalar.kind(),
- width: scalar.size_of(),
- components: vec![*len as u32],
- state_space: *state_space,
},
- ast::Type::Pointer(ast::PointerType::Array(scalar, components), state_space) => {
- TypeParts {
- kind: TypeKind::PointerArray,
- scalar_kind: scalar.kind(),
- width: scalar.size_of(),
- components: components.clone(),
- state_space: *state_space,
- }
- }
- ast::Type::Pointer(ast::PointerType::Pointer(scalar, inner_space), state_space) => {
- TypeParts {
- kind: TypeKind::PointerPointer,
- scalar_kind: scalar.kind(),
- width: scalar.size_of(),
- components: vec![*inner_space as u32],
- state_space: *state_space,
- }
- }
}
}
@@ -6265,29 +6106,8 @@ impl ast::Type {
ast::ScalarType::from_parts(t.width, t.scalar_kind),
t.components,
),
- TypeKind::PointerScalar => ast::Type::Pointer(
- ast::PointerType::Scalar(ast::ScalarType::from_parts(t.width, t.scalar_kind)),
- t.state_space,
- ),
- TypeKind::PointerVector => ast::Type::Pointer(
- ast::PointerType::Vector(
- ast::ScalarType::from_parts(t.width, t.scalar_kind),
- t.components[0] as u8,
- ),
- t.state_space,
- ),
- TypeKind::PointerArray => ast::Type::Pointer(
- ast::PointerType::Array(
- ast::ScalarType::from_parts(t.width, t.scalar_kind),
- t.components,
- ),
- t.state_space,
- ),
- TypeKind::PointerPointer => ast::Type::Pointer(
- ast::PointerType::Pointer(
- ast::ScalarType::from_parts(t.width, t.scalar_kind),
- unsafe { mem::transmute::<_, ast::LdStateSpace>(t.components[0] as u8) },
- ),
+ TypeKind::Pointer => ast::Type::Pointer(
+ ast::ScalarType::from_parts(t.width, t.scalar_kind),
t.state_space,
),
}
@@ -6300,7 +6120,7 @@ impl ast::Type {
ast::Type::Array(typ, len) => len
.iter()
.fold(typ.size_of() as usize, |x, y| (x as usize) * (*y as usize)),
- ast::Type::Pointer(_, _) => mem::size_of::<usize>(),
+ ast::Type::Pointer(..) => mem::size_of::<usize>(),
}
}
}
@@ -6308,10 +6128,10 @@ impl ast::Type {
#[derive(Eq, PartialEq, Clone)]
struct TypeParts {
kind: TypeKind,
- scalar_kind: ScalarKind,
+ scalar_kind: ast::ScalarKind,
width: u8,
+ state_space: ast::StateSpace,
components: Vec<u32>,
- state_space: ast::LdStateSpace,
}
#[derive(Eq, PartialEq, Copy, Clone)]
@@ -6319,10 +6139,7 @@ enum TypeKind {
Scalar,
Vector,
Array,
- PointerScalar,
- PointerVector,
- PointerArray,
- PointerPointer,
+ Pointer,
}
impl ast::Instruction<ExpandedArgParams> {
@@ -6450,21 +6267,21 @@ struct BrachCondition {
struct ImplicitConversion {
src: spirv::Word,
dst: spirv::Word,
- from: ast::Type,
- to: ast::Type,
+ from_type: ast::Type,
+ to_type: ast::Type,
+ from_space: ast::StateSpace,
+ to_space: ast::StateSpace,
kind: ConversionKind,
- src_sema: ArgumentSemantics,
- dst_sema: ArgumentSemantics,
}
-#[derive(PartialEq, Copy, Clone)]
+#[derive(PartialEq, Clone)]
enum ConversionKind {
Default,
// zero-extend/chop/bitcast depending on types
SignExtend,
- BitToPtr(ast::LdStateSpace),
- PtrToBit(ast::UIntType),
- PtrToPtr { spirv_ptr: bool },
+ BitToPtr,
+ PtrToPtr,
+ AddressOf,
}
impl<T> ast::PredAt<T> {
@@ -6512,13 +6329,14 @@ impl<T: ArgParamsEx> ast::Arg1<T> {
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
- t: Option<&ast::Type>,
+ t: Option<(&ast::Type, ast::StateSpace)>,
) -> Result<ast::Arg1<U>, TranslateError> {
let new_src = visitor.id(
ArgumentDescriptor {
op: self.src,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
t,
)?;
@@ -6535,9 +6353,11 @@ impl<T: ArgParamsEx> ast::Arg1Bar<T> {
ArgumentDescriptor {
op: self.src,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
&ast::Type::Scalar(ast::ScalarType::U32),
+ ast::StateSpace::Reg,
)?;
Ok(ast::Arg1Bar { src: new_src })
}
@@ -6553,17 +6373,21 @@ impl<T: ArgParamsEx> ast::Arg2<T> {
ArgumentDescriptor {
op: self.dst,
is_dst: true,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
t,
+ ast::StateSpace::Reg,
)?;
let new_src = visitor.operand(
ArgumentDescriptor {
op: self.src,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
t,
+ ast::StateSpace::Reg,
)?;
Ok(ast::Arg2 {
dst: new_dst,
@@ -6581,17 +6405,21 @@ impl<T: ArgParamsEx> ast::Arg2<T> {
ArgumentDescriptor {
op: self.dst,
is_dst: true,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
dst_t,
+ ast::StateSpace::Reg,
)?;
let src = visitor.operand(
ArgumentDescriptor {
op: self.src,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
src_t,
+ ast::StateSpace::Reg,
)?;
Ok(ast::Arg2 { dst, src })
}
@@ -6607,26 +6435,21 @@ impl<T: ArgParamsEx> ast::Arg2Ld<T> {
ArgumentDescriptor {
op: self.dst,
is_dst: true,
- sema: ArgumentSemantics::DefaultRelaxed,
+ is_memory_access: false,
+ non_default_implicit_conversion: Some(should_convert_relaxed_dst_wrapper),
},
&ast::Type::from(details.typ.clone()),
+ ast::StateSpace::Reg,
)?;
- let is_logical_ptr = details.state_space == ast::LdStateSpace::Param
- || details.state_space == ast::LdStateSpace::Local;
let src = visitor.operand(
ArgumentDescriptor {
op: self.src,
is_dst: false,
- sema: if is_logical_ptr {
- ArgumentSemantics::RegisterPointer
- } else {
- ArgumentSemantics::PhysicalPointer
- },
+ is_memory_access: true,
+ non_default_implicit_conversion: None,
},
- &ast::Type::Pointer(
- ast::PointerType::from(details.typ.clone()),
- details.state_space,
- ),
+ &details.typ,
+ details.state_space,
)?;
Ok(ast::Arg2Ld { dst, src })
}
@@ -6638,30 +6461,25 @@ impl<T: ArgParamsEx> ast::Arg2St<T> {
visitor: &mut V,
details: &ast::StData,
) -> Result<ast::Arg2St<U>, TranslateError> {
- let is_logical_ptr = details.state_space == ast::StStateSpace::Param
- || details.state_space == ast::StStateSpace::Local;
let src1 = visitor.operand(
ArgumentDescriptor {
op: self.src1,
is_dst: false,
- sema: if is_logical_ptr {
- ArgumentSemantics::RegisterPointer
- } else {
- ArgumentSemantics::PhysicalPointer
- },
+ is_memory_access: true,
+ non_default_implicit_conversion: None,
},
- &ast::Type::Pointer(
- ast::PointerType::from(details.typ.clone()),
- details.state_space.to_ld_ss(),
- ),
+ &details.typ,
+ details.state_space,
)?;
let src2 = visitor.operand(
ArgumentDescriptor {
op: self.src2,
is_dst: false,
- sema: ArgumentSemantics::DefaultRelaxed,
+ is_memory_access: false,
+ non_default_implicit_conversion: Some(should_convert_relaxed_src_wrapper),
},
&details.typ.clone().into(),
+ ast::StateSpace::Reg,
)?;
Ok(ast::Arg2St { src1, src2 })
}
@@ -6677,21 +6495,21 @@ impl<T: ArgParamsEx> ast::Arg2Mov<T> {
ArgumentDescriptor {
op: self.dst,
is_dst: true,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
&details.typ.clone().into(),
+ ast::StateSpace::Reg,
)?;
let src = visitor.operand(
ArgumentDescriptor {
op: self.src,
is_dst: false,
- sema: if details.src_is_address {
- ArgumentSemantics::Address
- } else {
- ArgumentSemantics::Default
- },
+ is_memory_access: false,
+ non_default_implicit_conversion: Some(implicit_conversion_mov),
},
&details.typ.clone().into(),
+ ast::StateSpace::Reg,
)?;
Ok(ast::Arg2Mov { dst, src })
}
@@ -6713,25 +6531,31 @@ impl<T: ArgParamsEx> ast::Arg3<T> {
ArgumentDescriptor {
op: self.dst,
is_dst: true,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
wide_type.as_ref().unwrap_or(typ),
+ ast::StateSpace::Reg,
)?;
let src1 = visitor.operand(
ArgumentDescriptor {
op: self.src1,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
typ,
+ ast::StateSpace::Reg,
)?;
let src2 = visitor.operand(
ArgumentDescriptor {
op: self.src2,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
typ,
+ ast::StateSpace::Reg,
)?;
Ok(ast::Arg3 { dst, src1, src2 })
}
@@ -6745,25 +6569,31 @@ impl<T: ArgParamsEx> ast::Arg3<T> {
ArgumentDescriptor {
op: self.dst,
is_dst: true,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
t,
+ ast::StateSpace::Reg,
)?;
let src1 = visitor.operand(
ArgumentDescriptor {
op: self.src1,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
t,
+ ast::StateSpace::Reg,
)?;
let src2 = visitor.operand(
ArgumentDescriptor {
op: self.src2,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
&ast::Type::Scalar(ast::ScalarType::U32),
+ ast::StateSpace::Reg,
)?;
Ok(ast::Arg3 { dst, src1, src2 })
}
@@ -6772,35 +6602,38 @@ impl<T: ArgParamsEx> ast::Arg3<T> {
self,
visitor: &mut V,
t: ast::ScalarType,
- state_space: ast::AtomSpace,
+ state_space: ast::StateSpace,
) -> Result<ast::Arg3<U>, TranslateError> {
let scalar_type = ast::ScalarType::from(t);
let dst = visitor.operand(
ArgumentDescriptor {
op: self.dst,
is_dst: true,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
&ast::Type::Scalar(scalar_type),
+ ast::StateSpace::Reg,
)?;
let src1 = visitor.operand(
ArgumentDescriptor {
op: self.src1,
is_dst: false,
- sema: ArgumentSemantics::PhysicalPointer,
+ is_memory_access: true,
+ non_default_implicit_conversion: None,
},
- &ast::Type::Pointer(
- ast::PointerType::Scalar(scalar_type),
- state_space.to_ld_ss(),
- ),
+ &ast::Type::Scalar(scalar_type),
+ state_space,
)?;
let src2 = visitor.operand(
ArgumentDescriptor {
op: self.src2,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
&ast::Type::Scalar(scalar_type),
+ ast::StateSpace::Reg,
)?;
Ok(ast::Arg3 { dst, src1, src2 })
}
@@ -6822,33 +6655,41 @@ impl<T: ArgParamsEx> ast::Arg4<T> {
ArgumentDescriptor {
op: self.dst,
is_dst: true,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
wide_type.as_ref().unwrap_or(t),
+ ast::StateSpace::Reg,
)?;
let src1 = visitor.operand(
ArgumentDescriptor {
op: self.src1,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
t,
+ ast::StateSpace::Reg,
)?;
let src2 = visitor.operand(
ArgumentDescriptor {
op: self.src2,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
t,
+ ast::StateSpace::Reg,
)?;
let src3 = visitor.operand(
ArgumentDescriptor {
op: self.src3,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
t,
+ ast::StateSpace::Reg,
)?;
Ok(ast::Arg4 {
dst,
@@ -6861,39 +6702,47 @@ impl<T: ArgParamsEx> ast::Arg4<T> {
fn map_selp<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
- t: ast::SelpType,
+ t: ast::ScalarType,
) -> Result<ast::Arg4<U>, TranslateError> {
let dst = visitor.operand(
ArgumentDescriptor {
op: self.dst,
is_dst: true,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
&ast::Type::Scalar(t.into()),
+ ast::StateSpace::Reg,
)?;
let src1 = visitor.operand(
ArgumentDescriptor {
op: self.src1,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
&ast::Type::Scalar(t.into()),
+ ast::StateSpace::Reg,
)?;
let src2 = visitor.operand(
ArgumentDescriptor {
op: self.src2,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
&ast::Type::Scalar(t.into()),
+ ast::StateSpace::Reg,
)?;
let src3 = visitor.operand(
ArgumentDescriptor {
op: self.src3,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
&ast::Type::Scalar(ast::ScalarType::Pred),
+ ast::StateSpace::Reg,
)?;
Ok(ast::Arg4 {
dst,
@@ -6906,44 +6755,49 @@ impl<T: ArgParamsEx> ast::Arg4<T> {
fn map_atom<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
- t: ast::BitType,
- state_space: ast::AtomSpace,
+ t: ast::ScalarType,
+ state_space: ast::StateSpace,
) -> Result<ast::Arg4<U>, TranslateError> {
let scalar_type = ast::ScalarType::from(t);
let dst = visitor.operand(
ArgumentDescriptor {
op: self.dst,
is_dst: true,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
&ast::Type::Scalar(scalar_type),
+ ast::StateSpace::Reg,
)?;
let src1 = visitor.operand(
ArgumentDescriptor {
op: self.src1,
is_dst: false,
- sema: ArgumentSemantics::PhysicalPointer,
+ is_memory_access: true,
+ non_default_implicit_conversion: None,
},
- &ast::Type::Pointer(
- ast::PointerType::Scalar(scalar_type),
- state_space.to_ld_ss(),
- ),
+ &ast::Type::Scalar(scalar_type),
+ state_space,
)?;
let src2 = visitor.operand(
ArgumentDescriptor {
op: self.src2,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
&ast::Type::Scalar(scalar_type),
+ ast::StateSpace::Reg,
)?;
let src3 = visitor.operand(
ArgumentDescriptor {
op: self.src3,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
&ast::Type::Scalar(scalar_type),
+ ast::StateSpace::Reg,
)?;
Ok(ast::Arg4 {
dst,
@@ -6962,34 +6816,42 @@ impl<T: ArgParamsEx> ast::Arg4<T> {
ArgumentDescriptor {
op: self.dst,
is_dst: true,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
typ,
+ ast::StateSpace::Reg,
)?;
let src1 = visitor.operand(
ArgumentDescriptor {
op: self.src1,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
typ,
+ ast::StateSpace::Reg,
)?;
let u32_type = ast::Type::Scalar(ast::ScalarType::U32);
let src2 = visitor.operand(
ArgumentDescriptor {
op: self.src2,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
&u32_type,
+ ast::StateSpace::Reg,
)?;
let src3 = visitor.operand(
ArgumentDescriptor {
op: self.src3,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
&u32_type,
+ ast::StateSpace::Reg,
)?;
Ok(ast::Arg4 {
dst,
@@ -7010,9 +6872,13 @@ impl<T: ArgParamsEx> ast::Arg4Setp<T> {
ArgumentDescriptor {
op: self.dst1,
is_dst: true,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
- Some(&ast::Type::Scalar(ast::ScalarType::Pred)),
+ Some((
+ &ast::Type::Scalar(ast::ScalarType::Pred),
+ ast::StateSpace::Reg,
+ )),
)?;
let dst2 = self
.dst2
@@ -7021,9 +6887,13 @@ impl<T: ArgParamsEx> ast::Arg4Setp<T> {
ArgumentDescriptor {
op: dst2,
is_dst: true,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
- Some(&ast::Type::Scalar(ast::ScalarType::Pred)),
+ Some((
+ &ast::Type::Scalar(ast::ScalarType::Pred),
+ ast::StateSpace::Reg,
+ )),
)
})
.transpose()?;
@@ -7031,17 +6901,21 @@ impl<T: ArgParamsEx> ast::Arg4Setp<T> {
ArgumentDescriptor {
op: self.src1,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
t,
+ ast::StateSpace::Reg,
)?;
let src2 = visitor.operand(
ArgumentDescriptor {
op: self.src2,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
t,
+ ast::StateSpace::Reg,
)?;
Ok(ast::Arg4Setp {
dst1,
@@ -7062,41 +6936,51 @@ impl<T: ArgParamsEx> ast::Arg5<T> {
ArgumentDescriptor {
op: self.dst,
is_dst: true,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
base_type,
+ ast::StateSpace::Reg,
)?;
let src1 = visitor.operand(
ArgumentDescriptor {
op: self.src1,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
base_type,
+ ast::StateSpace::Reg,
)?;
let src2 = visitor.operand(
ArgumentDescriptor {
op: self.src2,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
base_type,
+ ast::StateSpace::Reg,
)?;
let src3 = visitor.operand(
ArgumentDescriptor {
op: self.src3,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
&ast::Type::Scalar(ast::ScalarType::U32),
+ ast::StateSpace::Reg,
)?;
let src4 = visitor.operand(
ArgumentDescriptor {
op: self.src4,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
&ast::Type::Scalar(ast::ScalarType::U32),
+ ast::StateSpace::Reg,
)?;
Ok(ast::Arg5 {
dst,
@@ -7118,9 +7002,13 @@ impl<T: ArgParamsEx> ast::Arg5Setp<T> {
ArgumentDescriptor {
op: self.dst1,
is_dst: true,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
- Some(&ast::Type::Scalar(ast::ScalarType::Pred)),
+ Some((
+ &ast::Type::Scalar(ast::ScalarType::Pred),
+ ast::StateSpace::Reg,
+ )),
)?;
let dst2 = self
.dst2
@@ -7129,9 +7017,13 @@ impl<T: ArgParamsEx> ast::Arg5Setp<T> {
ArgumentDescriptor {
op: dst2,
is_dst: true,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
- Some(&ast::Type::Scalar(ast::ScalarType::Pred)),
+ Some((
+ &ast::Type::Scalar(ast::ScalarType::Pred),
+ ast::StateSpace::Reg,
+ )),
)
})
.transpose()?;
@@ -7139,25 +7031,31 @@ impl<T: ArgParamsEx> ast::Arg5Setp<T> {
ArgumentDescriptor {
op: self.src1,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
t,
+ ast::StateSpace::Reg,
)?;
let src2 = visitor.operand(
ArgumentDescriptor {
op: self.src2,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
t,
+ ast::StateSpace::Reg,
)?;
let src3 = visitor.operand(
ArgumentDescriptor {
op: self.src3,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
&ast::Type::Scalar(ast::ScalarType::Pred),
+ ast::StateSpace::Reg,
)?;
Ok(ast::Arg5Setp {
dst1,
@@ -7195,115 +7093,41 @@ impl ast::Operand<spirv::Word> {
}
}
-impl ast::StStateSpace {
- fn to_ld_ss(self) -> ast::LdStateSpace {
- match self {
- ast::StStateSpace::Generic => ast::LdStateSpace::Generic,
- ast::StStateSpace::Global => ast::LdStateSpace::Global,
- ast::StStateSpace::Local => ast::LdStateSpace::Local,
- ast::StStateSpace::Param => ast::LdStateSpace::Param,
- ast::StStateSpace::Shared => ast::LdStateSpace::Shared,
- }
- }
-}
-
-#[derive(Clone, Copy, PartialEq, Eq)]
-enum ScalarKind {
- Bit,
- Unsigned,
- Signed,
- Float,
- Float2,
- Pred,
-}
-
impl ast::ScalarType {
- fn kind(self) -> ScalarKind {
- match self {
- ast::ScalarType::U8 => ScalarKind::Unsigned,
- ast::ScalarType::U16 => ScalarKind::Unsigned,
- ast::ScalarType::U32 => ScalarKind::Unsigned,
- ast::ScalarType::U64 => ScalarKind::Unsigned,
- ast::ScalarType::S8 => ScalarKind::Signed,
- ast::ScalarType::S16 => ScalarKind::Signed,
- ast::ScalarType::S32 => ScalarKind::Signed,
- ast::ScalarType::S64 => ScalarKind::Signed,
- ast::ScalarType::B8 => ScalarKind::Bit,
- ast::ScalarType::B16 => ScalarKind::Bit,
- ast::ScalarType::B32 => ScalarKind::Bit,
- ast::ScalarType::B64 => ScalarKind::Bit,
- ast::ScalarType::F16 => ScalarKind::Float,
- ast::ScalarType::F32 => ScalarKind::Float,
- ast::ScalarType::F64 => ScalarKind::Float,
- ast::ScalarType::F16x2 => ScalarKind::Float2,
- ast::ScalarType::Pred => ScalarKind::Pred,
- }
- }
-
- fn from_parts(width: u8, kind: ScalarKind) -> Self {
+ fn from_parts(width: u8, kind: ast::ScalarKind) -> Self {
match kind {
- ScalarKind::Float => match width {
+ ast::ScalarKind::Float => match width {
2 => ast::ScalarType::F16,
4 => ast::ScalarType::F32,
8 => ast::ScalarType::F64,
_ => unreachable!(),
},
- ScalarKind::Bit => match width {
+ ast::ScalarKind::Bit => match width {
1 => ast::ScalarType::B8,
2 => ast::ScalarType::B16,
4 => ast::ScalarType::B32,
8 => ast::ScalarType::B64,
_ => unreachable!(),
},
- ScalarKind::Signed => match width {
+ ast::ScalarKind::Signed => match width {
1 => ast::ScalarType::S8,
2 => ast::ScalarType::S16,
4 => ast::ScalarType::S32,
8 => ast::ScalarType::S64,
_ => unreachable!(),
},
- ScalarKind::Unsigned => match width {
+ ast::ScalarKind::Unsigned => match width {
1 => ast::ScalarType::U8,
2 => ast::ScalarType::U16,
4 => ast::ScalarType::U32,
8 => ast::ScalarType::U64,
_ => unreachable!(),
},
- ScalarKind::Float2 => match width {
+ ast::ScalarKind::Float2 => match width {
4 => ast::ScalarType::F16x2,
_ => unreachable!(),
},
- ScalarKind::Pred => ast::ScalarType::Pred,
- }
- }
-}
-
-impl ast::BooleanType {
- fn to_type(self) -> ast::Type {
- match self {
- ast::BooleanType::Pred => ast::Type::Scalar(ast::ScalarType::Pred),
- ast::BooleanType::B16 => ast::Type::Scalar(ast::ScalarType::B16),
- ast::BooleanType::B32 => ast::Type::Scalar(ast::ScalarType::B32),
- ast::BooleanType::B64 => ast::Type::Scalar(ast::ScalarType::B64),
- }
- }
-}
-
-impl ast::ShlType {
- fn to_type(self) -> ast::Type {
- match self {
- ast::ShlType::B16 => ast::Type::Scalar(ast::ScalarType::B16),
- ast::ShlType::B32 => ast::Type::Scalar(ast::ScalarType::B32),
- ast::ShlType::B64 => ast::Type::Scalar(ast::ScalarType::B64),
- }
- }
-}
-
-impl ast::ShrType {
- fn signed(&self) -> bool {
- match self {
- ast::ShrType::S16 | ast::ShrType::S32 | ast::ShrType::S64 => true,
- _ => false,
+ ast::ScalarKind::Pred => ast::ScalarType::Pred,
}
}
}
@@ -7359,49 +7183,47 @@ impl ast::AtomInnerDetails {
}
}
-impl ast::SIntType {
- fn from_size(width: u8) -> Self {
- match width {
- 1 => ast::SIntType::S8,
- 2 => ast::SIntType::S16,
- 4 => ast::SIntType::S32,
- 8 => ast::SIntType::S64,
- _ => unreachable!(),
+impl ast::StateSpace {
+ fn to_spirv(self) -> spirv::StorageClass {
+ match self {
+ ast::StateSpace::Const => spirv::StorageClass::UniformConstant,
+ ast::StateSpace::Generic => spirv::StorageClass::Generic,
+ ast::StateSpace::Global => spirv::StorageClass::CrossWorkgroup,
+ ast::StateSpace::Local => spirv::StorageClass::Function,
+ ast::StateSpace::Shared => spirv::StorageClass::Workgroup,
+ ast::StateSpace::Param => spirv::StorageClass::Function,
+ ast::StateSpace::Reg => spirv::StorageClass::Function,
+ ast::StateSpace::Sreg => spirv::StorageClass::Input,
}
}
-}
-impl ast::UIntType {
- fn from_size(width: u8) -> Self {
- match width {
- 1 => ast::UIntType::U8,
- 2 => ast::UIntType::U16,
- 4 => ast::UIntType::U32,
- 8 => ast::UIntType::U64,
- _ => unreachable!(),
- }
+ fn is_compatible(self, other: ast::StateSpace) -> bool {
+ self == other
+ || self == ast::StateSpace::Reg && other == ast::StateSpace::Sreg
+ || self == ast::StateSpace::Sreg && other == ast::StateSpace::Reg
}
-}
-impl ast::LdStateSpace {
- fn to_spirv(self) -> spirv::StorageClass {
+ fn coerces_to_generic(self) -> bool {
match self {
- ast::LdStateSpace::Const => spirv::StorageClass::UniformConstant,
- ast::LdStateSpace::Generic => spirv::StorageClass::Generic,
- ast::LdStateSpace::Global => spirv::StorageClass::CrossWorkgroup,
- ast::LdStateSpace::Local => spirv::StorageClass::Function,
- ast::LdStateSpace::Shared => spirv::StorageClass::Workgroup,
- ast::LdStateSpace::Param => spirv::StorageClass::Function,
+ ast::StateSpace::Global
+ | ast::StateSpace::Const
+ | ast::StateSpace::Local
+ | ast::StateSpace::Shared => true,
+ ast::StateSpace::Reg
+ | ast::StateSpace::Param
+ | ast::StateSpace::Generic
+ | ast::StateSpace::Sreg => false,
}
}
-}
-impl From<ast::FnArgumentType> for ast::VariableType {
- fn from(t: ast::FnArgumentType) -> Self {
- match t {
- ast::FnArgumentType::Reg(t) => ast::VariableType::Reg(t),
- ast::FnArgumentType::Param(t) => ast::VariableType::Param(t),
- ast::FnArgumentType::Shared => todo!(),
+ fn is_addressable(self) -> bool {
+ match self {
+ ast::StateSpace::Const
+ | ast::StateSpace::Generic
+ | ast::StateSpace::Global
+ | ast::StateSpace::Local
+ | ast::StateSpace::Shared => true,
+ ast::StateSpace::Param | ast::StateSpace::Reg | ast::StateSpace::Sreg => false,
}
}
}
@@ -7427,16 +7249,6 @@ impl ast::MulDetails {
}
}
-impl ast::AtomSpace {
- fn to_ld_ss(self) -> ast::LdStateSpace {
- match self {
- ast::AtomSpace::Generic => ast::LdStateSpace::Generic,
- ast::AtomSpace::Global => ast::LdStateSpace::Global,
- ast::AtomSpace::Shared => ast::LdStateSpace::Shared,
- }
- }
-}
-
impl ast::MemScope {
fn to_spirv(self) -> spirv::Scope {
match self {
@@ -7458,109 +7270,96 @@ impl ast::AtomSemantics {
}
}
-impl ast::FnArgumentType {
- fn semantics(&self) -> ArgumentSemantics {
- match self {
- ast::FnArgumentType::Reg(_) => ArgumentSemantics::Default,
- ast::FnArgumentType::Param(_) => ArgumentSemantics::RegisterPointer,
- ast::FnArgumentType::Shared => ArgumentSemantics::PhysicalPointer,
- }
- }
-}
-
-fn bitcast_register_pointer(
- operand_type: &ast::Type,
- instr_type: &ast::Type,
- ss: Option<ast::LdStateSpace>,
+fn default_implicit_conversion(
+ (operand_space, operand_type): (ast::StateSpace, &ast::Type),
+ (instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
) -> Result<Option<ConversionKind>, TranslateError> {
- bitcast_physical_pointer(operand_type, instr_type, ss)
+ if !instruction_space.is_compatible(operand_space) {
+ default_implicit_conversion_space(
+ (operand_space, operand_type),
+ (instruction_space, instruction_type),
+ )
+ } else if instruction_type != operand_type {
+ default_implicit_conversion_type(instruction_space, operand_type, instruction_type)
+ } else {
+ Ok(None)
+ }
}
-fn bitcast_physical_pointer(
- operand_type: &ast::Type,
- instr_type: &ast::Type,
- ss: Option<ast::LdStateSpace>,
+// Space is different
+fn default_implicit_conversion_space(
+ (operand_space, operand_type): (ast::StateSpace, &ast::Type),
+ (instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
) -> Result<Option<ConversionKind>, TranslateError> {
- match operand_type {
- // array decays to a pointer
- ast::Type::Array(op_scalar_t, _) => {
- if let ast::Type::Pointer(instr_scalar_t, instr_space) = instr_type {
- if ss == Some(*instr_space) {
- if ast::Type::Scalar(*op_scalar_t) == ast::Type::from(instr_scalar_t.clone()) {
- Ok(None)
- } else {
- Ok(Some(ConversionKind::PtrToPtr { spirv_ptr: false }))
- }
+ if (instruction_space == ast::StateSpace::Generic && operand_space.coerces_to_generic())
+ || (operand_space == ast::StateSpace::Generic && instruction_space.coerces_to_generic())
+ {
+ Ok(Some(ConversionKind::PtrToPtr))
+ } else if operand_space.is_compatible(ast::StateSpace::Reg) {
+ match operand_type {
+ ast::Type::Pointer(operand_ptr_type, operand_ptr_space)
+ if *operand_ptr_space == instruction_space =>
+ {
+ if instruction_type != &ast::Type::Scalar(*operand_ptr_type) {
+ Ok(Some(ConversionKind::PtrToPtr))
} else {
- if ss == Some(ast::LdStateSpace::Generic)
- || *instr_space == ast::LdStateSpace::Generic
- {
- Ok(Some(ConversionKind::PtrToPtr { spirv_ptr: false }))
- } else {
- Err(TranslateError::MismatchedType)
- }
+ Ok(None)
}
- } else {
- Err(TranslateError::MismatchedType)
- }
- }
- ast::Type::Scalar(ast::ScalarType::B64)
- | ast::Type::Scalar(ast::ScalarType::U64)
- | ast::Type::Scalar(ast::ScalarType::S64) => {
- if let Some(space) = ss {
- Ok(Some(ConversionKind::BitToPtr(space)))
- } else {
- Err(error_unreachable())
- }
- }
- ast::Type::Scalar(ast::ScalarType::B32)
- | ast::Type::Scalar(ast::ScalarType::U32)
- | ast::Type::Scalar(ast::ScalarType::S32) => match ss {
- Some(ast::LdStateSpace::Shared)
- | Some(ast::LdStateSpace::Generic)
- | Some(ast::LdStateSpace::Param)
- | Some(ast::LdStateSpace::Local) => {
- Ok(Some(ConversionKind::BitToPtr(ast::LdStateSpace::Shared)))
}
+ // TODO: 32 bit
+ ast::Type::Scalar(ast::ScalarType::B64)
+ | ast::Type::Scalar(ast::ScalarType::U64)
+ | ast::Type::Scalar(ast::ScalarType::S64) => match instruction_space {
+ ast::StateSpace::Global
+ | ast::StateSpace::Generic
+ | ast::StateSpace::Const
+ | ast::StateSpace::Local
+ | ast::StateSpace::Shared => Ok(Some(ConversionKind::BitToPtr)),
+ _ => Err(TranslateError::MismatchedType),
+ },
+ ast::Type::Scalar(ast::ScalarType::B32)
+ | ast::Type::Scalar(ast::ScalarType::U32)
+ | ast::Type::Scalar(ast::ScalarType::S32) => match instruction_space {
+ ast::StateSpace::Const | ast::StateSpace::Local | ast::StateSpace::Shared => {
+ Ok(Some(ConversionKind::BitToPtr))
+ }
+ _ => Err(TranslateError::MismatchedType),
+ },
_ => Err(TranslateError::MismatchedType),
- },
- ast::Type::Pointer(op_scalar_t, op_space) => {
- if let ast::Type::Pointer(instr_scalar_t, instr_space) = instr_type {
- if op_space == instr_space {
- if op_scalar_t == instr_scalar_t {
- Ok(None)
- } else {
- Ok(Some(ConversionKind::PtrToPtr { spirv_ptr: false }))
- }
+ }
+ } else if instruction_space.is_compatible(ast::StateSpace::Reg) {
+ match instruction_type {
+ ast::Type::Pointer(instruction_ptr_type, instruction_ptr_space)
+ if operand_space == *instruction_ptr_space =>
+ {
+ if operand_type != &ast::Type::Scalar(*instruction_ptr_type) {
+ Ok(Some(ConversionKind::PtrToPtr))
} else {
- if *op_space == ast::LdStateSpace::Generic
- || *instr_space == ast::LdStateSpace::Generic
- {
- Ok(Some(ConversionKind::PtrToPtr { spirv_ptr: false }))
- } else {
- Err(TranslateError::MismatchedType)
- }
+ Ok(None)
}
- } else {
- Err(TranslateError::MismatchedType)
}
+ _ => Err(TranslateError::MismatchedType),
}
- _ => Err(TranslateError::MismatchedType),
+ } else {
+ Err(TranslateError::MismatchedType)
}
}
-fn force_bitcast_ptr_to_bit(
- _: &ast::Type,
- instr_type: &ast::Type,
- _: Option<ast::LdStateSpace>,
+// Space is same, but type is different
+fn default_implicit_conversion_type(
+ space: ast::StateSpace,
+ operand_type: &ast::Type,
+ instruction_type: &ast::Type,
) -> Result<Option<ConversionKind>, TranslateError> {
- // TODO: verify this on f32, u16 and the like
- if let ast::Type::Scalar(scalar_t) = instr_type {
- if let Ok(int_type) = (*scalar_t).try_into() {
- return Ok(Some(ConversionKind::PtrToBit(int_type)));
+ if space.is_compatible(ast::StateSpace::Reg) {
+ if should_bitcast(instruction_type, operand_type) {
+ Ok(Some(ConversionKind::Default))
+ } else {
+ Err(TranslateError::MismatchedType)
}
+ } else {
+ Ok(Some(ConversionKind::PtrToPtr))
}
- Err(TranslateError::MismatchedType)
}
fn should_bitcast(instr: &ast::Type, operand: &ast::Type) -> bool {
@@ -7570,16 +7369,18 @@ fn should_bitcast(instr: &ast::Type, operand: &ast::Type) -> bool {
return false;
}
match inst.kind() {
- ScalarKind::Bit => operand.kind() != ScalarKind::Bit,
- ScalarKind::Float => operand.kind() == ScalarKind::Bit,
- ScalarKind::Signed => {
- operand.kind() == ScalarKind::Bit || operand.kind() == ScalarKind::Unsigned
+ ast::ScalarKind::Bit => operand.kind() != ast::ScalarKind::Bit,
+ ast::ScalarKind::Float => operand.kind() == ast::ScalarKind::Bit,
+ ast::ScalarKind::Signed => {
+ operand.kind() == ast::ScalarKind::Bit
+ || operand.kind() == ast::ScalarKind::Unsigned
}
- ScalarKind::Unsigned => {
- operand.kind() == ScalarKind::Bit || operand.kind() == ScalarKind::Signed
+ ast::ScalarKind::Unsigned => {
+ operand.kind() == ast::ScalarKind::Bit
+ || operand.kind() == ast::ScalarKind::Signed
}
- ScalarKind::Float2 => false,
- ScalarKind::Pred => false,
+ ast::ScalarKind::Float2 => false,
+ ast::ScalarKind::Pred => false,
}
}
(ast::Type::Vector(inst, _), ast::Type::Vector(operand, _))
@@ -7590,47 +7391,45 @@ fn should_bitcast(instr: &ast::Type, operand: &ast::Type) -> bool {
}
}
-fn should_bitcast_packed(
- operand: &ast::Type,
- instr: &ast::Type,
- ss: Option<ast::LdStateSpace>,
+fn implicit_conversion_mov(
+ (operand_space, operand_type): (ast::StateSpace, &ast::Type),
+ (instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
) -> Result<Option<ConversionKind>, TranslateError> {
- if let (ast::Type::Vector(vec_underlying_type, vec_len), ast::Type::Scalar(scalar)) =
- (operand, instr)
- {
- if scalar.kind() == ScalarKind::Bit
- && scalar.size_of() == (vec_underlying_type.size_of() * vec_len)
+ // instruction_space is always reg
+ if operand_space.is_compatible(ast::StateSpace::Reg) {
+ if let (ast::Type::Vector(vec_underlying_type, vec_len), ast::Type::Scalar(scalar)) =
+ (operand_type, instruction_type)
{
- return Ok(Some(ConversionKind::Default));
- }
- }
- should_bitcast_wrapper(operand, instr, ss)
-}
-
-fn should_bitcast_wrapper(
- operand: &ast::Type,
- instr: &ast::Type,
- _: Option<ast::LdStateSpace>,
-) -> Result<Option<ConversionKind>, TranslateError> {
- if instr == operand {
- return Ok(None);
- }
- if should_bitcast(instr, operand) {
- Ok(Some(ConversionKind::Default))
- } else {
- Err(TranslateError::MismatchedType)
- }
+ if scalar.kind() == ast::ScalarKind::Bit
+ && scalar.size_of() == (vec_underlying_type.size_of() * vec_len)
+ {
+ return Ok(Some(ConversionKind::Default));
+ }
+ }
+ // TODO: verify .params addressability:
+ // * kernel arg
+ // * func arg
+ // * variable
+ } else if operand_space.is_addressable() {
+ return Ok(Some(ConversionKind::AddressOf));
+ }
+ default_implicit_conversion(
+ (operand_space, operand_type),
+ (instruction_space, instruction_type),
+ )
}
fn should_convert_relaxed_src_wrapper(
- src_type: &ast::Type,
- instr_type: &ast::Type,
- _: Option<ast::LdStateSpace>,
+ (operand_space, operand_type): (ast::StateSpace, &ast::Type),
+ (instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
) -> Result<Option<ConversionKind>, TranslateError> {
- if src_type == instr_type {
+ if !operand_space.is_compatible(instruction_space) {
+ return Err(TranslateError::MismatchedType);
+ }
+ if operand_type == instruction_type {
return Ok(None);
}
- match should_convert_relaxed_src(src_type, instr_type) {
+ match should_convert_relaxed_src(operand_type, instruction_type) {
conv @ Some(_) => Ok(conv),
None => Err(TranslateError::MismatchedType),
}
@@ -7646,32 +7445,33 @@ fn should_convert_relaxed_src(
}
match (src_type, instr_type) {
(ast::Type::Scalar(src_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() {
- ScalarKind::Bit => {
+ ast::ScalarKind::Bit => {
if instr_type.size_of() <= src_type.size_of() {
Some(ConversionKind::Default)
} else {
None
}
}
- ScalarKind::Signed | ScalarKind::Unsigned => {
+ ast::ScalarKind::Signed | ast::ScalarKind::Unsigned => {
if instr_type.size_of() <= src_type.size_of()
- && src_type.kind() != ScalarKind::Float
+ && src_type.kind() != ast::ScalarKind::Float
{
Some(ConversionKind::Default)
} else {
None
}
}
- ScalarKind::Float => {
- if instr_type.size_of() <= src_type.size_of() && src_type.kind() == ScalarKind::Bit
+ ast::ScalarKind::Float => {
+ if instr_type.size_of() <= src_type.size_of()
+ && src_type.kind() == ast::ScalarKind::Bit
{
Some(ConversionKind::Default)
} else {
None
}
}
- ScalarKind::Float2 => todo!(),
- ScalarKind::Pred => None,
+ ast::ScalarKind::Float2 => todo!(),
+ ast::ScalarKind::Pred => None,
},
(ast::Type::Vector(dst_type, _), ast::Type::Vector(instr_type, _))
| (ast::Type::Array(dst_type, _), ast::Type::Array(instr_type, _)) => {
@@ -7685,14 +7485,16 @@ fn should_convert_relaxed_src(
}
fn should_convert_relaxed_dst_wrapper(
- dst_type: &ast::Type,
- instr_type: &ast::Type,
- _: Option<ast::LdStateSpace>,
+ (operand_space, operand_type): (ast::StateSpace, &ast::Type),
+ (instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
) -> Result<Option<ConversionKind>, TranslateError> {
- if dst_type == instr_type {
+ if !operand_space.is_compatible(instruction_space) {
+ return Err(TranslateError::MismatchedType);
+ }
+ if operand_type == instruction_type {
return Ok(None);
}
- match should_convert_relaxed_dst(dst_type, instr_type) {
+ match should_convert_relaxed_dst(operand_type, instruction_type) {
conv @ Some(_) => Ok(conv),
None => Err(TranslateError::MismatchedType),
}
@@ -7708,15 +7510,15 @@ fn should_convert_relaxed_dst(
}
match (dst_type, instr_type) {
(ast::Type::Scalar(dst_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() {
- ScalarKind::Bit => {
+ ast::ScalarKind::Bit => {
if instr_type.size_of() <= dst_type.size_of() {
Some(ConversionKind::Default)
} else {
None
}
}
- ScalarKind::Signed => {
- if dst_type.kind() != ScalarKind::Float {
+ ast::ScalarKind::Signed => {
+ if dst_type.kind() != ast::ScalarKind::Float {
if instr_type.size_of() == dst_type.size_of() {
Some(ConversionKind::Default)
} else if instr_type.size_of() < dst_type.size_of() {
@@ -7728,25 +7530,26 @@ fn should_convert_relaxed_dst(
None
}
}
- ScalarKind::Unsigned => {
+ ast::ScalarKind::Unsigned => {
if instr_type.size_of() <= dst_type.size_of()
- && dst_type.kind() != ScalarKind::Float
+ && dst_type.kind() != ast::ScalarKind::Float
{
Some(ConversionKind::Default)
} else {
None
}
}
- ScalarKind::Float => {
- if instr_type.size_of() <= dst_type.size_of() && dst_type.kind() == ScalarKind::Bit
+ ast::ScalarKind::Float => {
+ if instr_type.size_of() <= dst_type.size_of()
+ && dst_type.kind() == ast::ScalarKind::Bit
{
Some(ConversionKind::Default)
} else {
None
}
}
- ScalarKind::Float2 => todo!(),
- ScalarKind::Pred => None,
+ ast::ScalarKind::Float2 => todo!(),
+ ast::ScalarKind::Pred => None,
},
(ast::Type::Vector(dst_type, _), ast::Type::Vector(instr_type, _))
| (ast::Type::Array(dst_type, _), ast::Type::Array(instr_type, _)) => {
@@ -7759,77 +7562,46 @@ fn should_convert_relaxed_dst(
}
}
-impl<'a> ast::MethodDecl<'a, &'a str> {
+impl<'a> ast::MethodDeclaration<'a, &'a str> {
fn name(&self) -> &'a str {
- match self {
- ast::MethodDecl::Kernel { name, .. } => name,
- ast::MethodDecl::Func(_, name, _) => name,
+ match self.name {
+ ast::MethodName::Kernel(name) => name,
+ ast::MethodName::Func(name) => name,
}
}
}
-struct SpirvMethodDecl<'input> {
- input: Vec<ast::Variable<ast::Type, spirv::Word>>,
- output: Vec<ast::Variable<ast::Type, spirv::Word>>,
- name: MethodName<'input>,
- uses_shared_mem: bool,
+impl<'a> ast::MethodDeclaration<'a, spirv::Word> {
+ fn effective_input_arguments(&self) -> impl Iterator<Item = (spirv::Word, SpirvType)> + '_ {
+ let is_kernel = self.name.is_kernel();
+ self.input_arguments
+ .iter()
+ .map(move |arg| {
+ if !is_kernel && arg.state_space != ast::StateSpace::Reg {
+ let spirv_type =
+ SpirvType::pointer_to(arg.v_type.clone(), arg.state_space.to_spirv());
+ (arg.name, spirv_type)
+ } else {
+ (arg.name, SpirvType::new(arg.v_type.clone()))
+ }
+ })
+ .chain(self.shared_mem.iter().map(|id| {
+ (
+ *id,
+ SpirvType::Pointer(
+ Box::new(SpirvType::Base(SpirvScalarKey::B8)),
+ spirv::StorageClass::Workgroup,
+ ),
+ )
+ }))
+ }
}
-impl<'input> SpirvMethodDecl<'input> {
- fn new(ast_decl: &ast::MethodDecl<'input, spirv::Word>) -> Self {
- let (input, output) = match ast_decl {
- ast::MethodDecl::Kernel { in_args, .. } => {
- let spirv_input = in_args
- .iter()
- .map(|var| {
- let v_type = match &var.v_type {
- ast::KernelArgumentType::Normal(t) => {
- ast::FnArgumentType::Param(t.clone())
- }
- ast::KernelArgumentType::Shared => ast::FnArgumentType::Shared,
- };
- ast::Variable {
- name: var.name,
- align: var.align,
- v_type: v_type.to_kernel_type(),
- array_init: var.array_init.clone(),
- }
- })
- .collect();
- (spirv_input, Vec::new())
- }
- ast::MethodDecl::Func(out_args, _, in_args) => {
- let (param_output, non_param_output): (Vec<_>, Vec<_>) =
- out_args.iter().partition(|var| var.v_type.is_param());
- let spirv_output = non_param_output
- .into_iter()
- .cloned()
- .map(|var| ast::Variable {
- name: var.name,
- align: var.align,
- v_type: var.v_type.to_func_type(),
- array_init: var.array_init.clone(),
- })
- .collect();
- let spirv_input = param_output
- .into_iter()
- .cloned()
- .chain(in_args.iter().cloned())
- .map(|var| ast::Variable {
- name: var.name,
- align: var.align,
- v_type: var.v_type.to_func_type(),
- array_init: var.array_init.clone(),
- })
- .collect();
- (spirv_input, spirv_output)
- }
- };
- SpirvMethodDecl {
- input,
- output,
- name: MethodName::new(ast_decl),
- uses_shared_mem: false,
+impl<'input, ID> ast::MethodName<'input, ID> {
+ fn is_kernel(&self) -> bool {
+ match self {
+ ast::MethodName::Kernel(..) => true,
+ ast::MethodName::Func(..) => false,
}
}
}