diff options
Diffstat (limited to 'ptx/src')
43 files changed, 2613 insertions, 3008 deletions
diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index 1c6d2fb..5432207 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -1,7 +1,6 @@ use half::f16; use lalrpop_util::{lexer::Token, ParseError}; -use std::convert::TryInto; -use std::{convert::From, convert::TryFrom, mem, num::ParseFloatError, str::FromStr}; +use std::{convert::From, mem, num::ParseFloatError, str::FromStr}; use std::{marker::PhantomData, num::ParseIntError}; #[derive(Debug, thiserror::Error)] @@ -34,195 +33,12 @@ pub enum PtxError { NonExternPointer, } -macro_rules! sub_enum { - ($name:ident { $($variant:ident),+ $(,)? }) => { - sub_enum!{ $name : ScalarType { $($variant),+ } } - }; - ($name:ident : $base_type:ident { $($variant:ident),+ $(,)? }) => { - #[derive(PartialEq, Eq, Clone, Copy)] - pub enum $name { - $( - $variant, - )+ - } - - impl From<$name> for $base_type { - fn from(t: $name) -> $base_type { - match t { - $( - $name::$variant => $base_type::$variant, - )+ - } - } - } - - impl std::convert::TryFrom<$base_type> for $name { - type Error = (); - - fn try_from(t: $base_type) -> Result<Self, Self::Error> { - match t { - $( - $base_type::$variant => Ok($name::$variant), - )+ - _ => Err(()), - } - } - } - }; -} - -macro_rules! sub_type { - ($type_name:ident { $($variant:ident ( $($field_type:ident),+ ) ),+ $(,)? } ) => { - sub_type! { $type_name : Type { - $( - $variant ($($field_type),+), - )+ - }} - }; - ($type_name:ident : $base_type:ident { $($variant:ident ( $($field_type:ident),+ ) ),+ $(,)? 
} ) => { - #[derive(PartialEq, Eq, Clone)] - pub enum $type_name { - $( - $variant ($($field_type),+), - )+ - } - - impl From<$type_name> for $base_type { - #[allow(non_snake_case)] - fn from(t: $type_name) -> $base_type { - match t { - $( - $type_name::$variant ( $($field_type),+ ) => <$base_type>::$variant ( $($field_type.into()),+), - )+ - } - } - } - - impl std::convert::TryFrom<$base_type> for $type_name { - type Error = (); - - #[allow(non_snake_case)] - #[allow(unreachable_patterns)] - fn try_from(t: $base_type) -> Result<Self, Self::Error> { - match t { - $( - $base_type::$variant ( $($field_type),+ ) => Ok($type_name::$variant ( $($field_type.try_into().map_err(|_| ())? ),+ )), - )+ - _ => Err(()), - } - } - } - }; -} - -sub_type! { - VariableRegType { - Scalar(ScalarType), - Vector(SizedScalarType, u8), - // Array type is used when emiting SSA statements at the start of a method - Array(ScalarType, VecU32), - // Pointer variant is used when passing around SLM pointer between - // function calls for dynamic SLM - Pointer(SizedScalarType, PointerStateSpace) - } -} - -type VecU32 = Vec<u32>; - -sub_type! { - VariableLocalType { - Scalar(SizedScalarType), - Vector(SizedScalarType, u8), - Array(SizedScalarType, VecU32), - } -} - -impl TryFrom<VariableGlobalType> for VariableLocalType { - type Error = PtxError; - - fn try_from(value: VariableGlobalType) -> Result<Self, Self::Error> { - match value { - VariableGlobalType::Scalar(t) => Ok(VariableLocalType::Scalar(t)), - VariableGlobalType::Vector(t, len) => Ok(VariableLocalType::Vector(t, len)), - VariableGlobalType::Array(t, len) => Ok(VariableLocalType::Array(t, len)), - VariableGlobalType::Pointer(_, _) => Err(PtxError::ZeroDimensionArray), - } - } -} - -sub_type! 
{ - VariableGlobalType { - Scalar(SizedScalarType), - Vector(SizedScalarType, u8), - Array(SizedScalarType, VecU32), - Pointer(SizedScalarType, PointerStateSpace), - } -} - // For some weird reson this is illegal: // .param .f16x2 foobar; // but this is legal: // .param .f16x2 foobar[1]; // even more interestingly this is legal, but only in .func (not in .entry): // .param .b32 foobar[] -sub_type! { - VariableParamType { - Scalar(LdStScalarType), - Array(SizedScalarType, VecU32), - Pointer(SizedScalarType, PointerStateSpace), - } -} - -sub_enum!(SizedScalarType { - B8, - B16, - B32, - B64, - U8, - U16, - U32, - U64, - S8, - S16, - S32, - S64, - F16, - F16x2, - F32, - F64, -}); - -sub_enum!(LdStScalarType { - B8, - B16, - B32, - B64, - U8, - U16, - U32, - U64, - S8, - S16, - S32, - S64, - F16, - F32, - F64, -}); - -sub_enum!(SelpType { - B16, - B32, - B64, - U16, - U32, - U64, - S16, - S32, - S64, - F32, - F64, -}); #[derive(Copy, Clone, Eq, PartialEq)] pub enum BarDetails { @@ -266,23 +82,25 @@ pub struct Module<'a> { } pub enum Directive<'a, P: ArgParams> { - Variable(Variable<VariableType, P::Id>), - Method(Function<'a, &'a str, Statement<P>>), + Variable(LinkingDirective, Variable<P::Id>), + Method(LinkingDirective, Function<'a, &'a str, Statement<P>>), } -pub enum MethodDecl<'a, ID> { - Func(Vec<FnArgument<ID>>, ID, Vec<FnArgument<ID>>), - Kernel { - name: &'a str, - in_args: Vec<KernelArgument<ID>>, - }, +#[derive(Hash, PartialEq, Eq, Copy, Clone)] +pub enum MethodName<'input, ID> { + Kernel(&'input str), + Func(ID), } -pub type FnArgument<ID> = Variable<FnArgumentType, ID>; -pub type KernelArgument<ID> = Variable<KernelArgumentType, ID>; +pub struct MethodDeclaration<'input, ID> { + pub return_arguments: Vec<Variable<ID>>, + pub name: MethodName<'input, ID>, + pub input_arguments: Vec<Variable<ID>>, + pub shared_mem: Option<ID>, +} pub struct Function<'a, ID, S> { - pub func_directive: MethodDecl<'a, ID>, + pub func_directive: MethodDeclaration<'a, ID>, pub 
tuning: Vec<TuningDirective>, pub body: Option<Vec<S>>, } @@ -290,118 +108,50 @@ pub struct Function<'a, ID, S> { pub type ParsedFunction<'a> = Function<'a, &'a str, Statement<ParsedArgParams<'a>>>; #[derive(PartialEq, Eq, Clone)] -pub enum FnArgumentType { - Reg(VariableRegType), - Param(VariableParamType), - Shared, -} -#[derive(PartialEq, Eq, Clone)] -pub enum KernelArgumentType { - Normal(VariableParamType), - Shared, -} - -impl From<KernelArgumentType> for Type { - fn from(this: KernelArgumentType) -> Self { - match this { - KernelArgumentType::Normal(typ) => typ.into(), - KernelArgumentType::Shared => { - Type::Pointer(PointerType::Scalar(ScalarType::B8), LdStateSpace::Shared) - } - } - } -} - -impl FnArgumentType { - pub fn to_type(&self, is_kernel: bool) -> Type { - if is_kernel { - self.to_kernel_type() - } else { - self.to_func_type() - } - } - - pub fn to_kernel_type(&self) -> Type { - match self { - FnArgumentType::Reg(x) => x.clone().into(), - FnArgumentType::Param(x) => x.clone().into(), - FnArgumentType::Shared => { - Type::Pointer(PointerType::Scalar(ScalarType::B8), LdStateSpace::Shared) - } - } - } - - pub fn to_func_type(&self) -> Type { - match self { - FnArgumentType::Reg(x) => x.clone().into(), - FnArgumentType::Param(VariableParamType::Scalar(t)) => { - Type::Pointer(PointerType::Scalar((*t).into()), LdStateSpace::Param) - } - FnArgumentType::Param(VariableParamType::Array(t, dims)) => Type::Pointer( - PointerType::Array((*t).into(), dims.clone()), - LdStateSpace::Param, - ), - FnArgumentType::Param(VariableParamType::Pointer(t, space)) => Type::Pointer( - PointerType::Pointer((*t).into(), (*space).into()), - LdStateSpace::Param, - ), - FnArgumentType::Shared => { - Type::Pointer(PointerType::Scalar(ScalarType::B8), LdStateSpace::Shared) - } - } - } - - pub fn is_param(&self) -> bool { - match self { - FnArgumentType::Param(_) => true, - _ => false, - } - } -} - -sub_enum!( - PointerStateSpace : LdStateSpace { - Generic, - Global, - Const, - 
Shared, - Param, - } -); - -#[derive(PartialEq, Eq, Clone)] pub enum Type { + // .param.b32 foo; + // -> OpTypeInt Scalar(ScalarType), + // .param.v2.b32 foo; + // -> OpTypeVector Vector(ScalarType, u8), + // .param.b32 foo[4]; + // -> OpTypeArray Array(ScalarType, Vec<u32>), - Pointer(PointerType, LdStateSpace), -} - -#[derive(PartialEq, Eq, Clone)] -pub enum PointerType { - Scalar(ScalarType), - Vector(ScalarType, u8), - Array(ScalarType, VecU32), - Pointer(ScalarType, LdStateSpace), -} - -impl From<SizedScalarType> for PointerType { - fn from(t: SizedScalarType) -> Self { - PointerType::Scalar(t.into()) - } -} - -impl TryFrom<PointerType> for SizedScalarType { - type Error = (); - - fn try_from(value: PointerType) -> Result<Self, Self::Error> { - match value { - PointerType::Scalar(t) => Ok(t.try_into()?), - PointerType::Vector(_, _) => Err(()), - PointerType::Array(_, _) => Err(()), - PointerType::Pointer(_, _) => Err(()), - } - } + /* + Variables of this type almost never exist in the original .ptx and are + usually artificially created. 
Some examples below: + - extern pointers to the .shared memory in the form: + .extern .shared .b32 shared_mem[]; + which we first parse as + .extern .shared .b32 shared_mem; + and then convert to an additional function parameter: + .param .ptr<.b32.shared> shared_mem; + and do a load at the start of the function (and renames inside fn): + .reg .ptr<.b32.shared> temp; + ld.param.ptr<.b32.shared> temp, [shared_mem]; + note, we don't support non-.shared extern pointers, because there's + zero use for them in the ptxas + - artifical pointers created by stateful conversion, which work + similiarly to the above + - function parameters: + foobar(.param .align 4 .b8 numbers[]) + which get parsed to + foobar(.param .align 4 .b8 numbers) + and then converted to + foobar(.reg .align 4 .ptr<.b8.param> numbers) + - ld/st with offset: + .reg.b32 x; + .param.b64 arg0; + st.param.b32 [arg0+4], x; + Yes, this code is legal and actually emitted by the NV compiler! + We convert the st to: + .reg ptr<.b64.param> temp = ptr_offset(arg0, 4); + st.param.b32 [temp], x; + */ + // .reg ptr<.b64.param> + // -> OpTypePointer Function + Pointer(ScalarType, StateSpace), } #[derive(PartialEq, Eq, Hash, Clone, Copy)] @@ -425,52 +175,6 @@ pub enum ScalarType { Pred, } -sub_enum!(IntType { - U8, - U16, - U32, - U64, - S8, - S16, - S32, - S64 -}); - -sub_enum!(BitType { B8, B16, B32, B64 }); - -sub_enum!(UIntType { U8, U16, U32, U64 }); - -sub_enum!(SIntType { S8, S16, S32, S64 }); - -impl IntType { - pub fn is_signed(self) -> bool { - match self { - IntType::U8 | IntType::U16 | IntType::U32 | IntType::U64 => false, - IntType::S8 | IntType::S16 | IntType::S32 | IntType::S64 => true, - } - } - - pub fn width(self) -> u8 { - match self { - IntType::U8 => 1, - IntType::U16 => 2, - IntType::U32 => 4, - IntType::U64 => 8, - IntType::S8 => 1, - IntType::S16 => 2, - IntType::S32 => 4, - IntType::S64 => 8, - } - } -} - -sub_enum!(FloatType { - F16, - F16x2, - F32, - F64 -}); - impl ScalarType { pub fn 
size_of(self) -> u8 { match self { @@ -509,51 +213,19 @@ pub enum Statement<P: ArgParams> { } pub struct MultiVariable<ID> { - pub var: Variable<VariableType, ID>, + pub var: Variable<ID>, pub count: Option<u32>, } #[derive(Clone)] -pub struct Variable<T, ID> { +pub struct Variable<ID> { pub align: Option<u32>, - pub v_type: T, + pub v_type: Type, + pub state_space: StateSpace, pub name: ID, pub array_init: Vec<u8>, } -#[derive(Eq, PartialEq, Clone)] -pub enum VariableType { - Reg(VariableRegType), - Local(VariableLocalType), - Param(VariableParamType), - Global(VariableGlobalType), - Shared(VariableGlobalType), -} - -impl VariableType { - pub fn to_type(&self) -> (StateSpace, Type) { - match self { - VariableType::Reg(t) => (StateSpace::Reg, t.clone().into()), - VariableType::Local(t) => (StateSpace::Local, t.clone().into()), - VariableType::Param(t) => (StateSpace::Param, t.clone().into()), - VariableType::Global(t) => (StateSpace::Global, t.clone().into()), - VariableType::Shared(t) => (StateSpace::Shared, t.clone().into()), - } - } -} - -impl From<VariableType> for Type { - fn from(t: VariableType) -> Self { - match t { - VariableType::Reg(t) => t.into(), - VariableType::Local(t) => t.into(), - VariableType::Param(t) => t.into(), - VariableType::Global(t) => t.into(), - VariableType::Shared(t) => t.into(), - } - } -} - #[derive(Copy, Clone, PartialEq, Eq)] pub enum StateSpace { Reg, @@ -562,6 +234,8 @@ pub enum StateSpace { Local, Shared, Param, + Generic, + Sreg, } pub struct PredAt<ID> { @@ -576,24 +250,24 @@ pub enum Instruction<P: ArgParams> { Add(ArithDetails, Arg3<P>), Setp(SetpData, Arg4Setp<P>), SetpBool(SetpBoolData, Arg5Setp<P>), - Not(BooleanType, Arg2<P>), + Not(ScalarType, Arg2<P>), Bra(BraData, Arg1<P>), Cvt(CvtDetails, Arg2<P>), Cvta(CvtaDetails, Arg2<P>), - Shl(ShlType, Arg3<P>), - Shr(ShrType, Arg3<P>), + Shl(ScalarType, Arg3<P>), + Shr(ScalarType, Arg3<P>), St(StData, Arg2St<P>), Ret(RetData), Call(CallInst<P>), Abs(AbsDetails, Arg2<P>), 
Mad(MulDetails, Arg4<P>), - Or(BooleanType, Arg3<P>), + Or(ScalarType, Arg3<P>), Sub(ArithDetails, Arg3<P>), Min(MinMaxDetails, Arg3<P>), Max(MinMaxDetails, Arg3<P>), Rcp(RcpDetails, Arg2<P>), - And(BooleanType, Arg3<P>), - Selp(SelpType, Arg4<P>), + And(ScalarType, Arg3<P>), + Selp(ScalarType, Arg4<P>), Bar(BarDetails, Arg1Bar<P>), Atom(AtomDetails, Arg3<P>), AtomCas(AtomCasDetails, Arg4<P>), @@ -605,13 +279,13 @@ pub enum Instruction<P: ArgParams> { Cos { flush_to_zero: bool, arg: Arg2<P> }, Lg2 { flush_to_zero: bool, arg: Arg2<P> }, Ex2 { flush_to_zero: bool, arg: Arg2<P> }, - Clz { typ: BitType, arg: Arg2<P> }, - Brev { typ: BitType, arg: Arg2<P> }, - Popc { typ: BitType, arg: Arg2<P> }, - Xor { typ: BooleanType, arg: Arg3<P> }, - Bfe { typ: IntType, arg: Arg4<P> }, - Bfi { typ: BitType, arg: Arg5<P> }, - Rem { typ: IntType, arg: Arg3<P> }, + Clz { typ: ScalarType, arg: Arg2<P> }, + Brev { typ: ScalarType, arg: Arg2<P> }, + Popc { typ: ScalarType, arg: Arg2<P> }, + Xor { typ: ScalarType, arg: Arg3<P> }, + Bfe { typ: ScalarType, arg: Arg4<P> }, + Bfi { typ: ScalarType, arg: Arg5<P> }, + Rem { typ: ScalarType, arg: Arg3<P> }, } #[derive(Copy, Clone)] @@ -737,34 +411,12 @@ pub enum VectorPrefix { pub struct LdDetails { pub qualifier: LdStQualifier, - pub state_space: LdStateSpace, + pub state_space: StateSpace, pub caching: LdCacheOperator, - pub typ: LdStType, + pub typ: Type, pub non_coherent: bool, } -sub_type! { - LdStType { - Scalar(LdStScalarType), - Vector(LdStScalarType, u8), - // Used in generated code - Pointer(PointerType, LdStateSpace), - } -} - -impl From<LdStType> for PointerType { - fn from(t: LdStType) -> Self { - match t { - LdStType::Scalar(t) => PointerType::Scalar(t.into()), - LdStType::Vector(t, len) => PointerType::Vector(t.into(), len), - LdStType::Pointer(PointerType::Scalar(scalar_type), space) => { - PointerType::Pointer(scalar_type, space) - } - LdStType::Pointer(..) 
=> unreachable!(), - } - } -} - #[derive(Copy, Clone, PartialEq, Eq)] pub enum LdStQualifier { Weak, @@ -780,17 +432,6 @@ pub enum MemScope { Sys, } -#[derive(Copy, Clone, PartialEq, Eq, Debug)] -#[repr(u8)] -pub enum LdStateSpace { - Generic, - Const, - Global, - Local, - Param, - Shared, -} - #[derive(Copy, Clone, PartialEq, Eq)] pub enum LdCacheOperator { Cached, @@ -825,7 +466,7 @@ impl MovDetails { #[derive(Copy, Clone)] pub struct MulIntDesc { - pub typ: IntType, + pub typ: ScalarType, pub control: MulIntControl, } @@ -845,7 +486,7 @@ pub enum RoundingMode { } pub struct AddIntDesc { - pub typ: IntType, + pub typ: ScalarType, pub saturate: bool, } @@ -892,39 +533,39 @@ pub struct BraData { pub enum CvtDetails { IntFromInt(CvtIntToIntDesc), - FloatFromFloat(CvtDesc<FloatType, FloatType>), - IntFromFloat(CvtDesc<IntType, FloatType>), - FloatFromInt(CvtDesc<FloatType, IntType>), + FloatFromFloat(CvtDesc), + IntFromFloat(CvtDesc), + FloatFromInt(CvtDesc), } pub struct CvtIntToIntDesc { - pub dst: IntType, - pub src: IntType, + pub dst: ScalarType, + pub src: ScalarType, pub saturate: bool, } -pub struct CvtDesc<Dst, Src> { +pub struct CvtDesc { pub rounding: Option<RoundingMode>, pub flush_to_zero: Option<bool>, pub saturate: bool, - pub dst: Dst, - pub src: Src, + pub dst: ScalarType, + pub src: ScalarType, } impl CvtDetails { pub fn new_int_from_int_checked<'err, 'input>( saturate: bool, - dst: IntType, - src: IntType, + dst: ScalarType, + src: ScalarType, err: &'err mut Vec<ParseError<usize, Token<'input>, PtxError>>, ) -> Self { if saturate { - if src.is_signed() { - if dst.is_signed() && dst.width() >= src.width() { + if src.kind() == ScalarKind::Signed { + if dst.kind() == ScalarKind::Signed && dst.size_of() >= src.size_of() { err.push(ParseError::from(PtxError::SyntaxError)); } } else { - if dst == src || dst.width() >= src.width() { + if dst == src || dst.size_of() >= src.size_of() { err.push(ParseError::from(PtxError::SyntaxError)); } } @@ -936,11 
+577,11 @@ impl CvtDetails { rounding: RoundingMode, flush_to_zero: bool, saturate: bool, - dst: FloatType, - src: IntType, + dst: ScalarType, + src: ScalarType, err: &'err mut Vec<ParseError<usize, Token<'input>, PtxError>>, ) -> Self { - if flush_to_zero && dst != FloatType::F32 { + if flush_to_zero && dst != ScalarType::F32 { err.push(ParseError::from(PtxError::NonF32Ftz)); } CvtDetails::FloatFromInt(CvtDesc { @@ -956,11 +597,11 @@ impl CvtDetails { rounding: RoundingMode, flush_to_zero: bool, saturate: bool, - dst: IntType, - src: FloatType, + dst: ScalarType, + src: ScalarType, err: &'err mut Vec<ParseError<usize, Token<'input>, PtxError>>, ) -> Self { - if flush_to_zero && src != FloatType::F32 { + if flush_to_zero && src != ScalarType::F32 { err.push(ParseError::from(PtxError::NonF32Ftz)); } CvtDetails::IntFromFloat(CvtDesc { @@ -974,58 +615,21 @@ impl CvtDetails { } pub struct CvtaDetails { - pub to: CvtaStateSpace, - pub from: CvtaStateSpace, + pub to: StateSpace, + pub from: StateSpace, pub size: CvtaSize, } -#[derive(Copy, Clone, PartialEq, Eq)] -pub enum CvtaStateSpace { - Generic, - Const, - Global, - Local, - Shared, -} - pub enum CvtaSize { U32, U64, } -#[derive(PartialEq, Eq, Copy, Clone)] -pub enum ShlType { - B16, - B32, - B64, -} - -sub_enum!(ShrType { - B16, - B32, - B64, - U16, - U32, - U64, - S16, - S32, - S64, -}); - pub struct StData { pub qualifier: LdStQualifier, - pub state_space: StStateSpace, + pub state_space: StateSpace, pub caching: StCacheOperator, - pub typ: LdStType, -} - -#[derive(PartialEq, Eq, Copy, Clone)] -pub enum StStateSpace { - Generic, - Global, - Local, - Param, - Shared, + pub typ: Type, } #[derive(PartialEq, Eq)] @@ -1040,13 +644,6 @@ pub struct RetData { pub uniform: bool, } -sub_enum!(BooleanType { - Pred, - B16, - B32, - B64, -}); - #[derive(Copy, Clone)] pub enum MulDetails { Unsigned(MulUInt), @@ -1056,32 +653,32 @@ pub enum MulDetails { #[derive(Copy, Clone)] pub struct MulUInt { - pub typ: UIntType, + pub typ: 
ScalarType, pub control: MulIntControl, } #[derive(Copy, Clone)] pub struct MulSInt { - pub typ: SIntType, + pub typ: ScalarType, pub control: MulIntControl, } #[derive(Copy, Clone)] pub enum ArithDetails { - Unsigned(UIntType), + Unsigned(ScalarType), Signed(ArithSInt), Float(ArithFloat), } #[derive(Copy, Clone)] pub struct ArithSInt { - pub typ: SIntType, + pub typ: ScalarType, pub saturate: bool, } #[derive(Copy, Clone)] pub struct ArithFloat { - pub typ: FloatType, + pub typ: ScalarType, pub rounding: Option<RoundingMode>, pub flush_to_zero: Option<bool>, pub saturate: bool, @@ -1089,8 +686,8 @@ pub struct ArithFloat { #[derive(Copy, Clone)] pub enum MinMaxDetails { - Signed(SIntType), - Unsigned(UIntType), + Signed(ScalarType), + Unsigned(ScalarType), Float(MinMaxFloat), } @@ -1098,14 +695,14 @@ pub enum MinMaxDetails { pub struct MinMaxFloat { pub flush_to_zero: Option<bool>, pub nan: bool, - pub typ: FloatType, + pub typ: ScalarType, } #[derive(Copy, Clone)] pub struct AtomDetails { pub semantics: AtomSemantics, pub scope: MemScope, - pub space: AtomSpace, + pub space: StateSpace, pub inner: AtomInnerDetails, } @@ -1118,18 +715,11 @@ pub enum AtomSemantics { } #[derive(Copy, Clone)] -pub enum AtomSpace { - Generic, - Global, - Shared, -} - -#[derive(Copy, Clone)] pub enum AtomInnerDetails { - Bit { op: AtomBitOp, typ: BitType }, - Unsigned { op: AtomUIntOp, typ: UIntType }, - Signed { op: AtomSIntOp, typ: SIntType }, - Float { op: AtomFloatOp, typ: FloatType }, + Bit { op: AtomBitOp, typ: ScalarType }, + Unsigned { op: AtomUIntOp, typ: ScalarType }, + Signed { op: AtomSIntOp, typ: ScalarType }, + Float { op: AtomFloatOp, typ: ScalarType }, } #[derive(Copy, Clone, Eq, PartialEq)] @@ -1165,20 +755,20 @@ pub enum AtomFloatOp { pub struct AtomCasDetails { pub semantics: AtomSemantics, pub scope: MemScope, - pub space: AtomSpace, - pub typ: BitType, + pub space: StateSpace, + pub typ: ScalarType, } #[derive(Copy, Clone)] pub enum DivDetails { - 
Unsigned(UIntType), - Signed(SIntType), + Unsigned(ScalarType), + Signed(ScalarType), Float(DivFloatDetails), } #[derive(Copy, Clone)] pub struct DivFloatDetails { - pub typ: FloatType, + pub typ: ScalarType, pub flush_to_zero: Option<bool>, pub kind: DivFloatKind, } @@ -1197,7 +787,7 @@ pub enum NumsOrArrays<'a> { #[derive(Copy, Clone)] pub struct SqrtDetails { - pub typ: FloatType, + pub typ: ScalarType, pub flush_to_zero: Option<bool>, pub kind: SqrtKind, } @@ -1210,7 +800,7 @@ pub enum SqrtKind { #[derive(Copy, Clone, Eq, PartialEq)] pub struct RsqrtDetails { - pub typ: FloatType, + pub typ: ScalarType, pub flush_to_zero: bool, } @@ -1221,7 +811,7 @@ pub struct NegDetails { } impl<'a> NumsOrArrays<'a> { - pub fn to_vec(self, typ: SizedScalarType, dimensions: &mut [u32]) -> Result<Vec<u8>, PtxError> { + pub fn to_vec(self, typ: ScalarType, dimensions: &mut [u32]) -> Result<Vec<u8>, PtxError> { self.normalize_dimensions(dimensions)?; let sizeof_t = ScalarType::from(typ).size_of() as usize; let result_size = dimensions.iter().fold(sizeof_t, |x, y| x * (*y as usize)); @@ -1252,7 +842,7 @@ impl<'a> NumsOrArrays<'a> { fn parse_and_copy( &self, - t: SizedScalarType, + t: ScalarType, size_of_t: usize, dimensions: &[u32], result: &mut [u8], @@ -1292,47 +882,48 @@ impl<'a> NumsOrArrays<'a> { } fn parse_and_copy_single( - t: SizedScalarType, + t: ScalarType, idx: usize, str_val: &str, radix: u32, output: &mut [u8], ) -> Result<(), PtxError> { match t { - SizedScalarType::B8 | SizedScalarType::U8 => { + ScalarType::B8 | ScalarType::U8 => { Self::parse_and_copy_single_t::<u8>(idx, str_val, radix, output)?; } - SizedScalarType::B16 | SizedScalarType::U16 => { + ScalarType::B16 | ScalarType::U16 => { Self::parse_and_copy_single_t::<u16>(idx, str_val, radix, output)?; } - SizedScalarType::B32 | SizedScalarType::U32 => { + ScalarType::B32 | ScalarType::U32 => { Self::parse_and_copy_single_t::<u32>(idx, str_val, radix, output)?; } - SizedScalarType::B64 | SizedScalarType::U64 => 
{ + ScalarType::B64 | ScalarType::U64 => { Self::parse_and_copy_single_t::<u64>(idx, str_val, radix, output)?; } - SizedScalarType::S8 => { + ScalarType::S8 => { Self::parse_and_copy_single_t::<i8>(idx, str_val, radix, output)?; } - SizedScalarType::S16 => { + ScalarType::S16 => { Self::parse_and_copy_single_t::<i16>(idx, str_val, radix, output)?; } - SizedScalarType::S32 => { + ScalarType::S32 => { Self::parse_and_copy_single_t::<i32>(idx, str_val, radix, output)?; } - SizedScalarType::S64 => { + ScalarType::S64 => { Self::parse_and_copy_single_t::<i64>(idx, str_val, radix, output)?; } - SizedScalarType::F16 => { + ScalarType::F16 => { Self::parse_and_copy_single_t::<f16>(idx, str_val, radix, output)?; } - SizedScalarType::F16x2 => todo!(), - SizedScalarType::F32 => { + ScalarType::F16x2 => todo!(), + ScalarType::F32 => { Self::parse_and_copy_single_t::<f32>(idx, str_val, radix, output)?; } - SizedScalarType::F64 => { + ScalarType::F64 => { Self::parse_and_copy_single_t::<f64>(idx, str_val, radix, output)?; } + ScalarType::Pred => todo!(), } Ok(()) } @@ -1379,6 +970,40 @@ pub enum TuningDirective { MinNCtaPerSm(u32), } +#[derive(Clone, Copy, PartialEq, Eq)] +pub enum ScalarKind { + Bit, + Unsigned, + Signed, + Float, + Float2, + Pred, +} + +impl ScalarType { + pub fn kind(self) -> ScalarKind { + match self { + ScalarType::U8 => ScalarKind::Unsigned, + ScalarType::U16 => ScalarKind::Unsigned, + ScalarType::U32 => ScalarKind::Unsigned, + ScalarType::U64 => ScalarKind::Unsigned, + ScalarType::S8 => ScalarKind::Signed, + ScalarType::S16 => ScalarKind::Signed, + ScalarType::S32 => ScalarKind::Signed, + ScalarType::S64 => ScalarKind::Signed, + ScalarType::B8 => ScalarKind::Bit, + ScalarType::B16 => ScalarKind::Bit, + ScalarType::B32 => ScalarKind::Bit, + ScalarType::B64 => ScalarKind::Bit, + ScalarType::F16 => ScalarKind::Float, + ScalarType::F32 => ScalarKind::Float, + ScalarType::F64 => ScalarKind::Float, + ScalarType::F16x2 => ScalarKind::Float2, + ScalarType::Pred 
=> ScalarKind::Pred, + } + } +} + #[cfg(test)] mod tests { use super::*; @@ -1386,13 +1011,13 @@ mod tests { #[test] fn array_fails_multiple_0_dmiensions() { let inp = NumsOrArrays::Nums(Vec::new()); - assert!(inp.to_vec(SizedScalarType::B8, &mut vec![0, 0]).is_err()); + assert!(inp.to_vec(ScalarType::B8, &mut vec![0, 0]).is_err()); } #[test] fn array_fails_on_empty() { let inp = NumsOrArrays::Nums(Vec::new()); - assert!(inp.to_vec(SizedScalarType::B8, &mut vec![0]).is_err()); + assert!(inp.to_vec(ScalarType::B8, &mut vec![0]).is_err()); } #[test] @@ -1404,7 +1029,7 @@ mod tests { let mut dimensions = vec![0u32, 2]; assert_eq!( vec![1u8, 2, 3, 4], - inp.to_vec(SizedScalarType::B8, &mut dimensions).unwrap() + inp.to_vec(ScalarType::B8, &mut dimensions).unwrap() ); assert_eq!(dimensions, vec![2u32, 2]); } @@ -1416,7 +1041,7 @@ mod tests { NumsOrArrays::Arrays(vec![NumsOrArrays::Nums(vec![("1", 10)])]), ]); let mut dimensions = vec![0u32, 2]; - assert!(inp.to_vec(SizedScalarType::B8, &mut dimensions).is_err()); + assert!(inp.to_vec(ScalarType::B8, &mut dimensions).is_err()); } #[test] @@ -1426,6 +1051,6 @@ mod tests { NumsOrArrays::Nums(vec![("4", 10), ("5", 10)]), ]); let mut dimensions = vec![0u32, 2]; - assert!(inp.to_vec(SizedScalarType::B8, &mut dimensions).is_err()); + assert!(inp.to_vec(ScalarType::B8, &mut dimensions).is_err()); } } diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index 423fd57..b697317 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -343,10 +343,16 @@ TargetSpecifier = { Directive: Option<ast::Directive<'input, ast::ParsedArgParams<'input>>> = { AddressSize => None, - <f:Function> => Some(ast::Directive::Method(f)), + <f:Function> => { + let (linking, func) = f; + Some(ast::Directive::Method(linking, func)) + }, File => None, Section => None, - <v:ModuleVariable> ";" => Some(ast::Directive::Variable(v)), + <v:ModuleVariable> ";" => { + let (linking, var) = v; + Some(ast::Directive::Variable(linking, var)) + }, ! 
=> { let err = <>; errors.push(err.error); @@ -358,11 +364,13 @@ AddressSize = { ".address_size" U8Num }; -Function: ast::Function<'input, &'input str, ast::Statement<ast::ParsedArgParams<'input>>> = { - LinkingDirectives - <func_directive:MethodDecl> +Function: (ast::LinkingDirective, ast::Function<'input, &'input str, ast::Statement<ast::ParsedArgParams<'input>>>) = { + <linking:LinkingDirectives> + <func_directive:MethodDeclaration> <tuning:TuningDirective*> - <body:FunctionBody> => ast::Function{<>} + <body:FunctionBody> => { + (linking, ast::Function{func_directive, tuning, body}) + } }; LinkingDirective: ast::LinkingDirective = { @@ -388,44 +396,50 @@ LinkingDirectives: ast::LinkingDirective = { } } -MethodDecl: ast::MethodDecl<'input, &'input str> = { - ".entry" <name:ExtendedID> <in_args:KernelArguments> => - ast::MethodDecl::Kernel{ name, in_args }, - ".func" <ret_vals:FnArguments?> <name:ExtendedID> <params:FnArguments> => { - ast::MethodDecl::Func(ret_vals.unwrap_or_else(|| Vec::new()), name, params) +MethodDeclaration: ast::MethodDeclaration<'input, &'input str> = { + ".entry" <name:ExtendedID> <input_arguments:KernelArguments> => { + let return_arguments = Vec::new(); + let name = ast::MethodName::Kernel(name); + ast::MethodDeclaration{ return_arguments, name, input_arguments, shared_mem: None } + }, + ".func" <return_arguments:FnArguments?> <name:ExtendedID> <input_arguments:FnArguments> => { + let return_arguments = return_arguments.unwrap_or_else(|| Vec::new()); + let name = ast::MethodName::Func(name); + ast::MethodDeclaration{ return_arguments, name, input_arguments, shared_mem: None } } }; -KernelArguments: Vec<ast::KernelArgument<&'input str>> = { +KernelArguments: Vec<ast::Variable<&'input str>> = { "(" <args:Comma<KernelInput>> ")" => args }; -FnArguments: Vec<ast::FnArgument<&'input str>> = { +FnArguments: Vec<ast::Variable<&'input str>> = { "(" <args:Comma<FnInput>> ")" => args }; -KernelInput: ast::Variable<ast::KernelArgumentType, &'input 
str> = { +KernelInput: ast::Variable<&'input str> = { <v:ParamDeclaration> => { let (align, v_type, name) = v; ast::Variable { align, - v_type: ast::KernelArgumentType::Normal(v_type), + v_type, + state_space: ast::StateSpace::Param, name, array_init: Vec::new() } } } -FnInput: ast::Variable<ast::FnArgumentType, &'input str> = { +FnInput: ast::Variable<&'input str> = { <v:RegVariable> => { let (align, v_type, name) = v; - let v_type = ast::FnArgumentType::Reg(v_type); - ast::Variable{ align, v_type, name, array_init: Vec::new() } + let state_space = ast::StateSpace::Reg; + ast::Variable{ align, v_type, state_space, name, array_init: Vec::new() } }, <v:ParamDeclaration> => { let (align, v_type, name) = v; - let v_type = ast::FnArgumentType::Param(v_type); - ast::Variable{ align, v_type, name, array_init: Vec::new() } + let state_space = ast::StateSpace::Param; + ast::Variable{ align, v_type, state_space, name, array_init: Vec::new() } } } @@ -508,141 +522,148 @@ VariableParam: u32 = { "<" <n:U32Num> ">" => n } -Variable: ast::Variable<ast::VariableType, &'input str> = { +Variable: ast::Variable<&'input str> = { <v:RegVariable> => { let (align, v_type, name) = v; - let v_type = ast::VariableType::Reg(v_type); - ast::Variable {align, v_type, name, array_init: Vec::new()} + let state_space = ast::StateSpace::Reg; + ast::Variable {align, v_type, state_space, name, array_init: Vec::new()} }, LocalVariable, <v:ParamVariable> => { let (align, array_init, v_type, name) = v; - let v_type = ast::VariableType::Param(v_type); - ast::Variable {align, v_type, name, array_init} + let state_space = ast::StateSpace::Param; + ast::Variable {align, v_type, state_space, name, array_init} }, SharedVariable, }; -RegVariable: (Option<u32>, ast::VariableRegType, &'input str) = { +RegVariable: (Option<u32>, ast::Type, &'input str) = { ".reg" <var:VariableScalar<ScalarType>> => { let (align, t, name) = var; - let v_type = ast::VariableRegType::Scalar(t); + let v_type = ast::Type::Scalar(t); 
(align, v_type, name) }, ".reg" <var:VariableVector<SizedScalarType>> => { let (align, v_len, t, name) = var; - let v_type = ast::VariableRegType::Vector(t, v_len); + let v_type = ast::Type::Vector(t, v_len); (align, v_type, name) } } -LocalVariable: ast::Variable<ast::VariableType, &'input str> = { +LocalVariable: ast::Variable<&'input str> = { ".local" <var:VariableScalar<SizedScalarType>> => { let (align, t, name) = var; - let v_type = ast::VariableType::Local(ast::VariableLocalType::Scalar(t)); - ast::Variable { align, v_type, name, array_init: Vec::new() } + let v_type = ast::Type::Scalar(t); + let state_space = ast::StateSpace::Local; + ast::Variable { align, v_type, state_space, name, array_init: Vec::new() } }, ".local" <var:VariableVector<SizedScalarType>> => { let (align, v_len, t, name) = var; - let v_type = ast::VariableType::Local(ast::VariableLocalType::Vector(t, v_len)); - ast::Variable { align, v_type, name, array_init: Vec::new() } + let v_type = ast::Type::Vector(t, v_len); + let state_space = ast::StateSpace::Local; + ast::Variable { align, v_type, state_space, name, array_init: Vec::new() } }, ".local" <var:VariableArrayOrPointer<SizedScalarType>> =>? 
{ let (align, t, name, arr_or_ptr) = var; + let state_space = ast::StateSpace::Local; let (v_type, array_init) = match arr_or_ptr { ast::ArrayOrPointer::Array { dimensions, init } => { - (ast::VariableLocalType::Array(t, dimensions), init) + (ast::Type::Array(t, dimensions), init) } ast::ArrayOrPointer::Pointer => { return Err(ParseError::User { error: ast::PtxError::ZeroDimensionArray }); } }; - Ok(ast::Variable { align, v_type: ast::VariableType::Local(v_type), name, array_init }) + Ok(ast::Variable { align, v_type, state_space, name, array_init }) } } -SharedVariable: ast::Variable<ast::VariableType, &'input str> = { +SharedVariable: ast::Variable<&'input str> = { ".shared" <var:VariableScalar<SizedScalarType>> => { let (align, t, name) = var; - let v_type = ast::VariableGlobalType::Scalar(t); - ast::Variable { align, v_type: ast::VariableType::Shared(v_type), name, array_init: Vec::new() } + let state_space = ast::StateSpace::Shared; + let v_type = ast::Type::Scalar(t); + ast::Variable { align, v_type, state_space, name, array_init: Vec::new() } }, ".shared" <var:VariableVector<SizedScalarType>> => { let (align, v_len, t, name) = var; - let v_type = ast::VariableGlobalType::Vector(t, v_len); - ast::Variable { align, v_type: ast::VariableType::Shared(v_type), name, array_init: Vec::new() } + let state_space = ast::StateSpace::Shared; + let v_type = ast::Type::Vector(t, v_len); + ast::Variable { align, v_type, state_space, name, array_init: Vec::new() } }, ".shared" <var:VariableArrayOrPointer<SizedScalarType>> =>? 
{ let (align, t, name, arr_or_ptr) = var; + let state_space = ast::StateSpace::Shared; let (v_type, array_init) = match arr_or_ptr { ast::ArrayOrPointer::Array { dimensions, init } => { - (ast::VariableGlobalType::Array(t, dimensions), init) + (ast::Type::Array(t, dimensions), init) } ast::ArrayOrPointer::Pointer => { return Err(ParseError::User { error: ast::PtxError::ZeroDimensionArray }); } }; - Ok(ast::Variable { align, v_type: ast::VariableType::Shared(v_type), name, array_init }) + Ok(ast::Variable { align, v_type, state_space, name, array_init }) } } - -ModuleVariable: ast::Variable<ast::VariableType, &'input str> = { - LinkingDirectives ".global" <def:GlobalVariableDefinitionNoArray> => { +ModuleVariable: (ast::LinkingDirective, ast::Variable<&'input str>) = { + <linking:LinkingDirectives> ".global" <def:GlobalVariableDefinitionNoArray> => { let (align, v_type, name, array_init) = def; - ast::Variable { align, v_type: ast::VariableType::Global(v_type), name, array_init } + let state_space = ast::StateSpace::Global; + (linking, ast::Variable { align, v_type, state_space, name, array_init }) }, - LinkingDirectives ".shared" <def:GlobalVariableDefinitionNoArray> => { + <linking:LinkingDirectives> ".shared" <def:GlobalVariableDefinitionNoArray> => { let (align, v_type, name, array_init) = def; - ast::Variable { align, v_type: ast::VariableType::Shared(v_type), name, array_init: Vec::new() } + let state_space = ast::StateSpace::Shared; + (linking, ast::Variable { align, v_type, state_space, name, array_init: Vec::new() }) }, - <ldirs:LinkingDirectives> <space:Or<".global", ".shared">> <var:VariableArrayOrPointer<SizedScalarType>> =>? { + <linking:LinkingDirectives> <space:Or<".global", ".shared">> <var:VariableArrayOrPointer<SizedScalarType>> =>? 
{ let (align, t, name, arr_or_ptr) = var; - let (v_type, array_init) = match arr_or_ptr { + let (v_type, state_space, array_init) = match arr_or_ptr { ast::ArrayOrPointer::Array { dimensions, init } => { if space == ".global" { - (ast::VariableType::Global(ast::VariableGlobalType::Array(t, dimensions)), init) + (ast::Type::Array(t, dimensions), ast::StateSpace::Global, init) } else { - (ast::VariableType::Shared(ast::VariableGlobalType::Array(t, dimensions)), init) + (ast::Type::Array(t, dimensions), ast::StateSpace::Shared, init) } } ast::ArrayOrPointer::Pointer => { - if !ldirs.contains(ast::LinkingDirective::EXTERN) { + if !linking.contains(ast::LinkingDirective::EXTERN) { return Err(ParseError::User { error: ast::PtxError::NonExternPointer }); } if space == ".global" { - (ast::VariableType::Global(ast::VariableGlobalType::Pointer(t, ast::PointerStateSpace::Global)), Vec::new()) + (ast::Type::Array(t, Vec::new()), ast::StateSpace::Global, Vec::new()) } else { - (ast::VariableType::Shared(ast::VariableGlobalType::Pointer(t, ast::PointerStateSpace::Shared)), Vec::new()) + (ast::Type::Array(t, Vec::new()), ast::StateSpace::Shared, Vec::new()) } } }; - Ok(ast::Variable{ align, array_init, v_type, name }) + Ok((linking, ast::Variable{ align, v_type, state_space, name, array_init })) } } // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameter-state-space -ParamVariable: (Option<u32>, Vec<u8>, ast::VariableParamType, &'input str) = { +ParamVariable: (Option<u32>, Vec<u8>, ast::Type, &'input str) = { ".param" <var:VariableScalar<LdStScalarType>> => { let (align, t, name) = var; - let v_type = ast::VariableParamType::Scalar(t); + let v_type = ast::Type::Scalar(t); (align, Vec::new(), v_type, name) }, ".param" <var:VariableArrayOrPointer<SizedScalarType>> => { let (align, t, name, arr_or_ptr) = var; let (v_type, array_init) = match arr_or_ptr { ast::ArrayOrPointer::Array { dimensions, init } => { - (ast::VariableParamType::Array(t, dimensions), 
init) + (ast::Type::Array(t, dimensions), init) } ast::ArrayOrPointer::Pointer => { - (ast::VariableParamType::Pointer(t, ast::PointerStateSpace::Param), Vec::new()) + (ast::Type::Scalar(t), Vec::new()) } }; (align, array_init, v_type, name) } } -ParamDeclaration: (Option<u32>, ast::VariableParamType, &'input str) = { +ParamDeclaration: (Option<u32>, ast::Type, &'input str) = { <var:ParamVariable> =>? { let (align, array_init, v_type, name) = var; if array_init.len() > 0 { @@ -653,56 +674,56 @@ ParamDeclaration: (Option<u32>, ast::VariableParamType, &'input str) = { } } -GlobalVariableDefinitionNoArray: (Option<u32>, ast::VariableGlobalType, &'input str, Vec<u8>) = { +GlobalVariableDefinitionNoArray: (Option<u32>, ast::Type, &'input str, Vec<u8>) = { <scalar:VariableScalar<SizedScalarType>> => { let (align, t, name) = scalar; - let v_type = ast::VariableGlobalType::Scalar(t); + let v_type = ast::Type::Scalar(t); (align, v_type, name, Vec::new()) }, <var:VariableVector<SizedScalarType>> => { let (align, v_len, t, name) = var; - let v_type = ast::VariableGlobalType::Vector(t, v_len); + let v_type = ast::Type::Vector(t, v_len); (align, v_type, name, Vec::new()) }, } #[inline] -SizedScalarType: ast::SizedScalarType = { - ".b8" => ast::SizedScalarType::B8, - ".b16" => ast::SizedScalarType::B16, - ".b32" => ast::SizedScalarType::B32, - ".b64" => ast::SizedScalarType::B64, - ".u8" => ast::SizedScalarType::U8, - ".u16" => ast::SizedScalarType::U16, - ".u32" => ast::SizedScalarType::U32, - ".u64" => ast::SizedScalarType::U64, - ".s8" => ast::SizedScalarType::S8, - ".s16" => ast::SizedScalarType::S16, - ".s32" => ast::SizedScalarType::S32, - ".s64" => ast::SizedScalarType::S64, - ".f16" => ast::SizedScalarType::F16, - ".f16x2" => ast::SizedScalarType::F16x2, - ".f32" => ast::SizedScalarType::F32, - ".f64" => ast::SizedScalarType::F64, +SizedScalarType: ast::ScalarType = { + ".b8" => ast::ScalarType::B8, + ".b16" => ast::ScalarType::B16, + ".b32" => ast::ScalarType::B32, + 
".b64" => ast::ScalarType::B64, + ".u8" => ast::ScalarType::U8, + ".u16" => ast::ScalarType::U16, + ".u32" => ast::ScalarType::U32, + ".u64" => ast::ScalarType::U64, + ".s8" => ast::ScalarType::S8, + ".s16" => ast::ScalarType::S16, + ".s32" => ast::ScalarType::S32, + ".s64" => ast::ScalarType::S64, + ".f16" => ast::ScalarType::F16, + ".f16x2" => ast::ScalarType::F16x2, + ".f32" => ast::ScalarType::F32, + ".f64" => ast::ScalarType::F64, } #[inline] -LdStScalarType: ast::LdStScalarType = { - ".b8" => ast::LdStScalarType::B8, - ".b16" => ast::LdStScalarType::B16, - ".b32" => ast::LdStScalarType::B32, - ".b64" => ast::LdStScalarType::B64, - ".u8" => ast::LdStScalarType::U8, - ".u16" => ast::LdStScalarType::U16, - ".u32" => ast::LdStScalarType::U32, - ".u64" => ast::LdStScalarType::U64, - ".s8" => ast::LdStScalarType::S8, - ".s16" => ast::LdStScalarType::S16, - ".s32" => ast::LdStScalarType::S32, - ".s64" => ast::LdStScalarType::S64, - ".f16" => ast::LdStScalarType::F16, - ".f32" => ast::LdStScalarType::F32, - ".f64" => ast::LdStScalarType::F64, +LdStScalarType: ast::ScalarType = { + ".b8" => ast::ScalarType::B8, + ".b16" => ast::ScalarType::B16, + ".b32" => ast::ScalarType::B32, + ".b64" => ast::ScalarType::B64, + ".u8" => ast::ScalarType::U8, + ".u16" => ast::ScalarType::U16, + ".u32" => ast::ScalarType::U32, + ".u64" => ast::ScalarType::U64, + ".s8" => ast::ScalarType::S8, + ".s16" => ast::ScalarType::S16, + ".s32" => ast::ScalarType::S32, + ".s64" => ast::ScalarType::S64, + ".f16" => ast::ScalarType::F16, + ".f32" => ast::ScalarType::F32, + ".f64" => ast::ScalarType::F64, } Instruction: ast::Instruction<ast::ParsedArgParams<'input>> = { @@ -755,7 +776,7 @@ InstLd: ast::Instruction<ast::ParsedArgParams<'input>> = { ast::Instruction::Ld( ast::LdDetails { qualifier: q.unwrap_or(ast::LdStQualifier::Weak), - state_space: ss.unwrap_or(ast::LdStateSpace::Generic), + state_space: ss.unwrap_or(ast::StateSpace::Generic), caching: cop.unwrap_or(ast::LdCacheOperator::Cached), 
typ: t, non_coherent: false @@ -767,7 +788,7 @@ InstLd: ast::Instruction<ast::ParsedArgParams<'input>> = { ast::Instruction::Ld( ast::LdDetails { qualifier: q.unwrap_or(ast::LdStQualifier::Weak), - state_space: ast::LdStateSpace::Global, + state_space: ast::StateSpace::Global, caching: cop.unwrap_or(ast::LdCacheOperator::Cached), typ: t, non_coherent: false @@ -779,7 +800,7 @@ InstLd: ast::Instruction<ast::ParsedArgParams<'input>> = { ast::Instruction::Ld( ast::LdDetails { qualifier: ast::LdStQualifier::Weak, - state_space: ast::LdStateSpace::Global, + state_space: ast::StateSpace::Global, caching: cop.unwrap_or(ast::LdCacheOperator::Cached), typ: t, non_coherent: true @@ -789,9 +810,9 @@ InstLd: ast::Instruction<ast::ParsedArgParams<'input>> = { } }; -LdStType: ast::LdStType = { - <v:VectorPrefix> <t:LdStScalarType> => ast::LdStType::Vector(t, v), - <t:LdStScalarType> => ast::LdStType::Scalar(t), +LdStType: ast::Type = { + <v:VectorPrefix> <t:LdStScalarType> => ast::Type::Vector(t, v), + <t:LdStScalarType> => ast::Type::Scalar(t), } LdStQualifier: ast::LdStQualifier = { @@ -807,11 +828,11 @@ MemScope: ast::MemScope = { ".sys" => ast::MemScope::Sys }; -LdNonGlobalStateSpace: ast::LdStateSpace = { - ".const" => ast::LdStateSpace::Const, - ".local" => ast::LdStateSpace::Local, - ".param" => ast::LdStateSpace::Param, - ".shared" => ast::LdStateSpace::Shared, +LdNonGlobalStateSpace: ast::StateSpace = { + ".const" => ast::StateSpace::Const, + ".local" => ast::StateSpace::Local, + ".param" => ast::StateSpace::Param, + ".shared" => ast::StateSpace::Shared, }; LdCacheOperator: ast::LdCacheOperator = { @@ -899,39 +920,39 @@ RoundingModeInt : ast::RoundingMode = { ".rpi" => ast::RoundingMode::PositiveInf, }; -IntType : ast::IntType = { - ".u16" => ast::IntType::U16, - ".u32" => ast::IntType::U32, - ".u64" => ast::IntType::U64, - ".s16" => ast::IntType::S16, - ".s32" => ast::IntType::S32, - ".s64" => ast::IntType::S64, +IntType : ast::ScalarType = { + ".u16" => 
ast::ScalarType::U16, + ".u32" => ast::ScalarType::U32, + ".u64" => ast::ScalarType::U64, + ".s16" => ast::ScalarType::S16, + ".s32" => ast::ScalarType::S32, + ".s64" => ast::ScalarType::S64, }; -IntType3264: ast::IntType = { - ".u32" => ast::IntType::U32, - ".u64" => ast::IntType::U64, - ".s32" => ast::IntType::S32, - ".s64" => ast::IntType::S64, +IntType3264: ast::ScalarType = { + ".u32" => ast::ScalarType::U32, + ".u64" => ast::ScalarType::U64, + ".s32" => ast::ScalarType::S32, + ".s64" => ast::ScalarType::S64, } -UIntType: ast::UIntType = { - ".u16" => ast::UIntType::U16, - ".u32" => ast::UIntType::U32, - ".u64" => ast::UIntType::U64, +UIntType: ast::ScalarType = { + ".u16" => ast::ScalarType::U16, + ".u32" => ast::ScalarType::U32, + ".u64" => ast::ScalarType::U64, }; -SIntType: ast::SIntType = { - ".s16" => ast::SIntType::S16, - ".s32" => ast::SIntType::S32, - ".s64" => ast::SIntType::S64, +SIntType: ast::ScalarType = { + ".s16" => ast::ScalarType::S16, + ".s32" => ast::ScalarType::S32, + ".s64" => ast::ScalarType::S64, }; -FloatType: ast::FloatType = { - ".f16" => ast::FloatType::F16, - ".f16x2" => ast::FloatType::F16x2, - ".f32" => ast::FloatType::F32, - ".f64" => ast::FloatType::F64, +FloatType: ast::ScalarType = { + ".f16" => ast::ScalarType::F16, + ".f16x2" => ast::ScalarType::F16x2, + ".f32" => ast::ScalarType::F32, + ".f64" => ast::ScalarType::F64, }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#integer-arithmetic-instructions-add @@ -1023,11 +1044,11 @@ InstNot: ast::Instruction<ast::ParsedArgParams<'input>> = { "not" <t:BooleanType> <a:Arg2> => ast::Instruction::Not(t, a) }; -BooleanType: ast::BooleanType = { - ".pred" => ast::BooleanType::Pred, - ".b16" => ast::BooleanType::B16, - ".b32" => ast::BooleanType::B32, - ".b64" => ast::BooleanType::B64, +BooleanType: ast::ScalarType = { + ".pred" => ast::ScalarType::Pred, + ".b16" => ast::ScalarType::B16, + ".b32" => ast::ScalarType::B32, + ".b64" => ast::ScalarType::B64, }; // 
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#control-flow-instructions-at @@ -1080,8 +1101,8 @@ InstCvt: ast::Instruction<ast::ParsedArgParams<'input>> = { rounding: r, flush_to_zero: None, saturate: s.is_some(), - dst: ast::FloatType::F16, - src: ast::FloatType::F16 + dst: ast::ScalarType::F16, + src: ast::ScalarType::F16 } ), a) }, @@ -1091,8 +1112,8 @@ InstCvt: ast::Instruction<ast::ParsedArgParams<'input>> = { rounding: None, flush_to_zero: Some(f.is_some()), saturate: s.is_some(), - dst: ast::FloatType::F32, - src: ast::FloatType::F16 + dst: ast::ScalarType::F32, + src: ast::ScalarType::F16 } ), a) }, @@ -1102,8 +1123,8 @@ InstCvt: ast::Instruction<ast::ParsedArgParams<'input>> = { rounding: None, flush_to_zero: None, saturate: s.is_some(), - dst: ast::FloatType::F64, - src: ast::FloatType::F16 + dst: ast::ScalarType::F64, + src: ast::ScalarType::F16 } ), a) }, @@ -1113,8 +1134,8 @@ InstCvt: ast::Instruction<ast::ParsedArgParams<'input>> = { rounding: Some(r), flush_to_zero: Some(f.is_some()), saturate: s.is_some(), - dst: ast::FloatType::F16, - src: ast::FloatType::F32 + dst: ast::ScalarType::F16, + src: ast::ScalarType::F32 } ), a) }, @@ -1124,8 +1145,8 @@ InstCvt: ast::Instruction<ast::ParsedArgParams<'input>> = { rounding: r, flush_to_zero: Some(f.is_some()), saturate: s.is_some(), - dst: ast::FloatType::F32, - src: ast::FloatType::F32 + dst: ast::ScalarType::F32, + src: ast::ScalarType::F32 } ), a) }, @@ -1135,8 +1156,8 @@ InstCvt: ast::Instruction<ast::ParsedArgParams<'input>> = { rounding: None, flush_to_zero: None, saturate: s.is_some(), - dst: ast::FloatType::F64, - src: ast::FloatType::F32 + dst: ast::ScalarType::F64, + src: ast::ScalarType::F32 } ), a) }, @@ -1146,8 +1167,8 @@ InstCvt: ast::Instruction<ast::ParsedArgParams<'input>> = { rounding: Some(r), flush_to_zero: None, saturate: s.is_some(), - dst: ast::FloatType::F16, - src: ast::FloatType::F64 + dst: ast::ScalarType::F16, + src: ast::ScalarType::F64 } ), a) }, @@ -1157,8 
+1178,8 @@ InstCvt: ast::Instruction<ast::ParsedArgParams<'input>> = { rounding: Some(r), flush_to_zero: Some(s.is_some()), saturate: s.is_some(), - dst: ast::FloatType::F32, - src: ast::FloatType::F64 + dst: ast::ScalarType::F32, + src: ast::ScalarType::F64 } ), a) }, @@ -1168,28 +1189,28 @@ InstCvt: ast::Instruction<ast::ParsedArgParams<'input>> = { rounding: r, flush_to_zero: None, saturate: s.is_some(), - dst: ast::FloatType::F64, - src: ast::FloatType::F64 + dst: ast::ScalarType::F64, + src: ast::ScalarType::F64 } ), a) }, }; -CvtTypeInt: ast::IntType = { - ".u8" => ast::IntType::U8, - ".u16" => ast::IntType::U16, - ".u32" => ast::IntType::U32, - ".u64" => ast::IntType::U64, - ".s8" => ast::IntType::S8, - ".s16" => ast::IntType::S16, - ".s32" => ast::IntType::S32, - ".s64" => ast::IntType::S64, +CvtTypeInt: ast::ScalarType = { + ".u8" => ast::ScalarType::U8, + ".u16" => ast::ScalarType::U16, + ".u32" => ast::ScalarType::U32, + ".u64" => ast::ScalarType::U64, + ".s8" => ast::ScalarType::S8, + ".s16" => ast::ScalarType::S16, + ".s32" => ast::ScalarType::S32, + ".s64" => ast::ScalarType::S64, }; -CvtTypeFloat: ast::FloatType = { - ".f16" => ast::FloatType::F16, - ".f32" => ast::FloatType::F32, - ".f64" => ast::FloatType::F64, +CvtTypeFloat: ast::ScalarType = { + ".f16" => ast::ScalarType::F16, + ".f32" => ast::ScalarType::F32, + ".f64" => ast::ScalarType::F64, }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-shl @@ -1197,10 +1218,10 @@ InstShl: ast::Instruction<ast::ParsedArgParams<'input>> = { "shl" <t:ShlType> <a:Arg3> => ast::Instruction::Shl(t, a) }; -ShlType: ast::ShlType = { - ".b16" => ast::ShlType::B16, - ".b32" => ast::ShlType::B32, - ".b64" => ast::ShlType::B64, +ShlType: ast::ScalarType = { + ".b16" => ast::ScalarType::B16, + ".b32" => ast::ScalarType::B32, + ".b64" => ast::ScalarType::B64, }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-shr @@ 
-1208,16 +1229,16 @@ InstShr: ast::Instruction<ast::ParsedArgParams<'input>> = { "shr" <t:ShrType> <a:Arg3> => ast::Instruction::Shr(t, a) }; -ShrType: ast::ShrType = { - ".b16" => ast::ShrType::B16, - ".b32" => ast::ShrType::B32, - ".b64" => ast::ShrType::B64, - ".u16" => ast::ShrType::U16, - ".u32" => ast::ShrType::U32, - ".u64" => ast::ShrType::U64, - ".s16" => ast::ShrType::S16, - ".s32" => ast::ShrType::S32, - ".s64" => ast::ShrType::S64, +ShrType: ast::ScalarType = { + ".b16" => ast::ScalarType::B16, + ".b32" => ast::ScalarType::B32, + ".b64" => ast::ScalarType::B64, + ".u16" => ast::ScalarType::U16, + ".u32" => ast::ScalarType::U32, + ".u64" => ast::ScalarType::U64, + ".s16" => ast::ScalarType::S16, + ".s32" => ast::ScalarType::S32, + ".s64" => ast::ScalarType::S64, }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-st @@ -1227,7 +1248,7 @@ InstSt: ast::Instruction<ast::ParsedArgParams<'input>> = { ast::Instruction::St( ast::StData { qualifier: q.unwrap_or(ast::LdStQualifier::Weak), - state_space: ss.unwrap_or(ast::StStateSpace::Generic), + state_space: ss.unwrap_or(ast::StateSpace::Generic), caching: cop.unwrap_or(ast::StCacheOperator::Writeback), typ: t }, @@ -1241,11 +1262,11 @@ MemoryOperand: ast::Operand<&'input str> = { "[" <o:Operand> "]" => o } -StStateSpace: ast::StStateSpace = { - ".global" => ast::StStateSpace::Global, - ".local" => ast::StStateSpace::Local, - ".param" => ast::StStateSpace::Param, - ".shared" => ast::StStateSpace::Shared, +StStateSpace: ast::StateSpace = { + ".global" => ast::StateSpace::Global, + ".local" => ast::StateSpace::Local, + ".param" => ast::StateSpace::Param, + ".shared" => ast::StateSpace::Shared, }; StCacheOperator: ast::StCacheOperator = { @@ -1264,7 +1285,7 @@ InstRet: ast::Instruction<ast::ParsedArgParams<'input>> = { InstCvta: ast::Instruction<ast::ParsedArgParams<'input>> = { "cvta" <from:CvtaStateSpace> <s:CvtaSize> <a:Arg2> => { 
ast::Instruction::Cvta(ast::CvtaDetails { - to: ast::CvtaStateSpace::Generic, + to: ast::StateSpace::Generic, from, size: s }, @@ -1273,18 +1294,18 @@ InstCvta: ast::Instruction<ast::ParsedArgParams<'input>> = { "cvta" ".to" <to:CvtaStateSpace> <s:CvtaSize> <a:Arg2> => { ast::Instruction::Cvta(ast::CvtaDetails { to, - from: ast::CvtaStateSpace::Generic, + from: ast::StateSpace::Generic, size: s }, a) } } -CvtaStateSpace: ast::CvtaStateSpace = { - ".const" => ast::CvtaStateSpace::Const, - ".global" => ast::CvtaStateSpace::Global, - ".local" => ast::CvtaStateSpace::Local, - ".shared" => ast::CvtaStateSpace::Shared, +CvtaStateSpace: ast::StateSpace = { + ".const" => ast::StateSpace::Const, + ".global" => ast::StateSpace::Global, + ".local" => ast::StateSpace::Local, + ".shared" => ast::StateSpace::Shared, } CvtaSize: ast::CvtaSize = { @@ -1393,16 +1414,16 @@ MinMaxDetails: ast::MinMaxDetails = { <t:UIntType> => ast::MinMaxDetails::Unsigned(t), <t:SIntType> => ast::MinMaxDetails::Signed(t), <ftz:".ftz"?> <nan:".NaN"?> ".f32" => ast::MinMaxDetails::Float( - ast::MinMaxFloat{ flush_to_zero: Some(ftz.is_some()), nan: nan.is_some(), typ: ast::FloatType::F32 } + ast::MinMaxFloat{ flush_to_zero: Some(ftz.is_some()), nan: nan.is_some(), typ: ast::ScalarType::F32 } ), ".f64" => ast::MinMaxDetails::Float( - ast::MinMaxFloat{ flush_to_zero: None, nan: false, typ: ast::FloatType::F64 } + ast::MinMaxFloat{ flush_to_zero: None, nan: false, typ: ast::ScalarType::F64 } ), <ftz:".ftz"?> <nan:".NaN"?> ".f16" => ast::MinMaxDetails::Float( - ast::MinMaxFloat{ flush_to_zero: Some(ftz.is_some()), nan: nan.is_some(), typ: ast::FloatType::F16 } + ast::MinMaxFloat{ flush_to_zero: Some(ftz.is_some()), nan: nan.is_some(), typ: ast::ScalarType::F16 } ), <ftz:".ftz"?> <nan:".NaN"?> ".f16x2" => ast::MinMaxDetails::Float( - ast::MinMaxFloat{ flush_to_zero: Some(ftz.is_some()), nan: nan.is_some(), typ: ast::FloatType::F16x2 } + ast::MinMaxFloat{ flush_to_zero: Some(ftz.is_some()), nan: 
nan.is_some(), typ: ast::ScalarType::F16x2 } ) } @@ -1411,18 +1432,18 @@ InstSelp: ast::Instruction<ast::ParsedArgParams<'input>> = { "selp" <t:SelpType> <a:Arg4> => ast::Instruction::Selp(t, a), }; -SelpType: ast::SelpType = { - ".b16" => ast::SelpType::B16, - ".b32" => ast::SelpType::B32, - ".b64" => ast::SelpType::B64, - ".u16" => ast::SelpType::U16, - ".u32" => ast::SelpType::U32, - ".u64" => ast::SelpType::U64, - ".s16" => ast::SelpType::S16, - ".s32" => ast::SelpType::S32, - ".s64" => ast::SelpType::S64, - ".f32" => ast::SelpType::F32, - ".f64" => ast::SelpType::F64, +SelpType: ast::ScalarType = { + ".b16" => ast::ScalarType::B16, + ".b32" => ast::ScalarType::B32, + ".b64" => ast::ScalarType::B64, + ".u16" => ast::ScalarType::U16, + ".u32" => ast::ScalarType::U32, + ".u64" => ast::ScalarType::U64, + ".s16" => ast::ScalarType::S16, + ".s32" => ast::ScalarType::S32, + ".s64" => ast::ScalarType::S64, + ".f32" => ast::ScalarType::F32, + ".f64" => ast::ScalarType::F64, }; // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-bar @@ -1442,7 +1463,7 @@ InstAtom: ast::Instruction<ast::ParsedArgParams<'input>> = { let details = ast::AtomDetails { semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed), scope: scope.unwrap_or(ast::MemScope::Gpu), - space: space.unwrap_or(ast::AtomSpace::Generic), + space: space.unwrap_or(ast::StateSpace::Generic), inner: ast::AtomInnerDetails::Bit { op, typ } }; ast::Instruction::Atom(details,a) @@ -1451,10 +1472,10 @@ InstAtom: ast::Instruction<ast::ParsedArgParams<'input>> = { let details = ast::AtomDetails { semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed), scope: scope.unwrap_or(ast::MemScope::Gpu), - space: space.unwrap_or(ast::AtomSpace::Generic), + space: space.unwrap_or(ast::StateSpace::Generic), inner: ast::AtomInnerDetails::Unsigned { op: ast::AtomUIntOp::Inc, - typ: ast::UIntType::U32 + typ: ast::ScalarType::U32 } }; ast::Instruction::Atom(details,a) 
@@ -1463,10 +1484,10 @@ InstAtom: ast::Instruction<ast::ParsedArgParams<'input>> = { let details = ast::AtomDetails { semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed), scope: scope.unwrap_or(ast::MemScope::Gpu), - space: space.unwrap_or(ast::AtomSpace::Generic), + space: space.unwrap_or(ast::StateSpace::Generic), inner: ast::AtomInnerDetails::Unsigned { op: ast::AtomUIntOp::Dec, - typ: ast::UIntType::U32 + typ: ast::ScalarType::U32 } }; ast::Instruction::Atom(details,a) @@ -1476,7 +1497,7 @@ InstAtom: ast::Instruction<ast::ParsedArgParams<'input>> = { let details = ast::AtomDetails { semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed), scope: scope.unwrap_or(ast::MemScope::Gpu), - space: space.unwrap_or(ast::AtomSpace::Generic), + space: space.unwrap_or(ast::StateSpace::Generic), inner: ast::AtomInnerDetails::Float { op, typ } }; ast::Instruction::Atom(details,a) @@ -1485,7 +1506,7 @@ InstAtom: ast::Instruction<ast::ParsedArgParams<'input>> = { let details = ast::AtomDetails { semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed), scope: scope.unwrap_or(ast::MemScope::Gpu), - space: space.unwrap_or(ast::AtomSpace::Generic), + space: space.unwrap_or(ast::StateSpace::Generic), inner: ast::AtomInnerDetails::Unsigned { op, typ } }; ast::Instruction::Atom(details,a) @@ -1494,7 +1515,7 @@ InstAtom: ast::Instruction<ast::ParsedArgParams<'input>> = { let details = ast::AtomDetails { semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed), scope: scope.unwrap_or(ast::MemScope::Gpu), - space: space.unwrap_or(ast::AtomSpace::Generic), + space: space.unwrap_or(ast::StateSpace::Generic), inner: ast::AtomInnerDetails::Signed { op, typ } }; ast::Instruction::Atom(details,a) @@ -1506,7 +1527,7 @@ InstAtomCas: ast::Instruction<ast::ParsedArgParams<'input>> = { let details = ast::AtomCasDetails { semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed), scope: scope.unwrap_or(ast::MemScope::Gpu), - space: space.unwrap_or(ast::AtomSpace::Generic), + space: 
space.unwrap_or(ast::StateSpace::Generic), typ, }; ast::Instruction::AtomCas(details,a) @@ -1520,9 +1541,9 @@ AtomSemantics: ast::AtomSemantics = { ".acq_rel" => ast::AtomSemantics::AcquireRelease } -AtomSpace: ast::AtomSpace = { - ".global" => ast::AtomSpace::Global, - ".shared" => ast::AtomSpace::Shared +AtomSpace: ast::StateSpace = { + ".global" => ast::StateSpace::Global, + ".shared" => ast::StateSpace::Shared } AtomBitOp: ast::AtomBitOp = { @@ -1544,19 +1565,19 @@ AtomSIntOp: ast::AtomSIntOp = { ".max" => ast::AtomSIntOp::Max, } -BitType: ast::BitType = { - ".b32" => ast::BitType::B32, - ".b64" => ast::BitType::B64, +BitType: ast::ScalarType = { + ".b32" => ast::ScalarType::B32, + ".b64" => ast::ScalarType::B64, } -UIntType3264: ast::UIntType = { - ".u32" => ast::UIntType::U32, - ".u64" => ast::UIntType::U64, +UIntType3264: ast::ScalarType = { + ".u32" => ast::ScalarType::U32, + ".u64" => ast::ScalarType::U64, } -SIntType3264: ast::SIntType = { - ".s32" => ast::SIntType::S32, - ".s64" => ast::SIntType::S64, +SIntType3264: ast::ScalarType = { + ".s32" => ast::ScalarType::S32, + ".s64" => ast::ScalarType::S64, } // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#floating-point-instructions-div @@ -1566,7 +1587,7 @@ InstDiv: ast::Instruction<ast::ParsedArgParams<'input>> = { "div" <t:SIntType> <a:Arg3> => ast::Instruction::Div(ast::DivDetails::Signed(t), a), "div" <kind:DivFloatKind> <ftz:".ftz"?> ".f32" <a:Arg3> => { let inner = ast::DivFloatDetails { - typ: ast::FloatType::F32, + typ: ast::ScalarType::F32, flush_to_zero: Some(ftz.is_some()), kind }; @@ -1574,7 +1595,7 @@ InstDiv: ast::Instruction<ast::ParsedArgParams<'input>> = { }, "div" <rnd:RoundingModeFloat> ".f64" <a:Arg3> => { let inner = ast::DivFloatDetails { - typ: ast::FloatType::F64, + typ: ast::ScalarType::F64, flush_to_zero: None, kind: ast::DivFloatKind::Rounding(rnd) }; @@ -1592,7 +1613,7 @@ DivFloatKind: ast::DivFloatKind = { InstSqrt: 
ast::Instruction<ast::ParsedArgParams<'input>> = { "sqrt" ".approx" <ftz:".ftz"?> ".f32" <a:Arg2> => { let details = ast::SqrtDetails { - typ: ast::FloatType::F32, + typ: ast::ScalarType::F32, flush_to_zero: Some(ftz.is_some()), kind: ast::SqrtKind::Approx, }; @@ -1600,7 +1621,7 @@ InstSqrt: ast::Instruction<ast::ParsedArgParams<'input>> = { }, "sqrt" <rnd:RoundingModeFloat> <ftz:".ftz"?> ".f32" <a:Arg2> => { let details = ast::SqrtDetails { - typ: ast::FloatType::F32, + typ: ast::ScalarType::F32, flush_to_zero: Some(ftz.is_some()), kind: ast::SqrtKind::Rounding(rnd), }; @@ -1608,7 +1629,7 @@ InstSqrt: ast::Instruction<ast::ParsedArgParams<'input>> = { }, "sqrt" <rnd:RoundingModeFloat> ".f64" <a:Arg2> => { let details = ast::SqrtDetails { - typ: ast::FloatType::F64, + typ: ast::ScalarType::F64, flush_to_zero: None, kind: ast::SqrtKind::Rounding(rnd), }; @@ -1621,14 +1642,14 @@ InstSqrt: ast::Instruction<ast::ParsedArgParams<'input>> = { InstRsqrt: ast::Instruction<ast::ParsedArgParams<'input>> = { "rsqrt" ".approx" <ftz:".ftz"?> ".f32" <a:Arg2> => { let details = ast::RsqrtDetails { - typ: ast::FloatType::F32, + typ: ast::ScalarType::F32, flush_to_zero: ftz.is_some(), }; ast::Instruction::Rsqrt(details, a) }, "rsqrt" ".approx" <ftz:".ftz"?> ".f64" <a:Arg2> => { let details = ast::RsqrtDetails { - typ: ast::FloatType::F64, + typ: ast::ScalarType::F64, flush_to_zero: ftz.is_some(), }; ast::Instruction::Rsqrt(details, a) @@ -1739,7 +1760,7 @@ ArithDetails: ast::ArithDetails = { saturate: false, }), ".sat" ".s32" => ast::ArithDetails::Signed(ast::ArithSInt { - typ: ast::SIntType::S32, + typ: ast::ScalarType::S32, saturate: true, }), <f:ArithFloat> => ast::ArithDetails::Float(f) @@ -1747,25 +1768,25 @@ ArithDetails: ast::ArithDetails = { ArithFloat: ast::ArithFloat = { <rn:RoundingModeFloat?> <ftz:".ftz"?> <sat:".sat"?> ".f32" => ast::ArithFloat { - typ: ast::FloatType::F32, + typ: ast::ScalarType::F32, rounding: rn, flush_to_zero: Some(ftz.is_some()), saturate: 
sat.is_some(), }, <rn:RoundingModeFloat?> ".f64" => ast::ArithFloat { - typ: ast::FloatType::F64, + typ: ast::ScalarType::F64, rounding: rn, flush_to_zero: None, saturate: false, }, <rn:".rn"?> <ftz:".ftz"?> <sat:".sat"?> ".f16" => ast::ArithFloat { - typ: ast::FloatType::F16, + typ: ast::ScalarType::F16, rounding: rn.map(|_| ast::RoundingMode::NearestEven), flush_to_zero: Some(ftz.is_some()), saturate: sat.is_some(), }, <rn:".rn"?> <ftz:".ftz"?> <sat:".sat"?> ".f16x2" => ast::ArithFloat { - typ: ast::FloatType::F16x2, + typ: ast::ScalarType::F16x2, rounding: rn.map(|_| ast::RoundingMode::NearestEven), flush_to_zero: Some(ftz.is_some()), saturate: sat.is_some(), @@ -1774,25 +1795,25 @@ ArithFloat: ast::ArithFloat = { ArithFloatMustRound: ast::ArithFloat = { <rn:RoundingModeFloat> <ftz:".ftz"?> <sat:".sat"?> ".f32" => ast::ArithFloat { - typ: ast::FloatType::F32, + typ: ast::ScalarType::F32, rounding: Some(rn), flush_to_zero: Some(ftz.is_some()), saturate: sat.is_some(), }, <rn:RoundingModeFloat> ".f64" => ast::ArithFloat { - typ: ast::FloatType::F64, + typ: ast::ScalarType::F64, rounding: Some(rn), flush_to_zero: None, saturate: false, }, ".rn" <ftz:".ftz"?> <sat:".sat"?> ".f16" => ast::ArithFloat { - typ: ast::FloatType::F16, + typ: ast::ScalarType::F16, rounding: Some(ast::RoundingMode::NearestEven), flush_to_zero: Some(ftz.is_some()), saturate: sat.is_some(), }, ".rn" <ftz:".ftz"?> <sat:".sat"?> ".f16x2" => ast::ArithFloat { - typ: ast::FloatType::F16x2, + typ: ast::ScalarType::F16x2, rounding: Some(ast::RoundingMode::NearestEven), flush_to_zero: Some(ftz.is_some()), saturate: sat.is_some(), diff --git a/ptx/src/test/spirv_run/and.spvtxt b/ptx/src/test/spirv_run/and.spvtxt index a378602..f66639a 100644 --- a/ptx/src/test/spirv_run/and.spvtxt +++ b/ptx/src/test/spirv_run/and.spvtxt @@ -18,6 +18,8 @@ %_ptr_Function_uint = OpTypePointer Function %uint %_ptr_Generic_uint = OpTypePointer Generic %uint %ulong_4 = OpConstant %ulong 4 + %uchar = OpTypeInt 8 0 
+%_ptr_Generic_uchar = OpTypePointer Generic %uchar %1 = OpFunction %void None %34 %8 = OpFunctionParameter %ulong %9 = OpFunctionParameter %ulong @@ -39,9 +41,11 @@ %12 = OpLoad %uint %23 Aligned 4 OpStore %6 %12 %15 = OpLoad %ulong %4 - %22 = OpIAdd %ulong %15 %ulong_4 - %24 = OpConvertUToPtr %_ptr_Generic_uint %22 - %14 = OpLoad %uint %24 Aligned 4 + %24 = OpConvertUToPtr %_ptr_Generic_uint %15 + %41 = OpBitcast %_ptr_Generic_uchar %24 + %42 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %41 %ulong_4 + %22 = OpBitcast %_ptr_Generic_uint %42 + %14 = OpLoad %uint %22 Aligned 4 OpStore %7 %14 %17 = OpLoad %uint %6 %18 = OpLoad %uint %7 diff --git a/ptx/src/test/spirv_run/atom_add.spvtxt b/ptx/src/test/spirv_run/atom_add.spvtxt index 3966da6..b4de00a 100644 --- a/ptx/src/test/spirv_run/atom_add.spvtxt +++ b/ptx/src/test/spirv_run/atom_add.spvtxt @@ -24,6 +24,7 @@ %_ptr_Function_uint = OpTypePointer Function %uint %_ptr_Generic_uint = OpTypePointer Generic %uint %ulong_4 = OpConstant %ulong 4 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %_ptr_Workgroup_uint = OpTypePointer Workgroup %uint %uint_1 = OpConstant %uint 1 %uint_0 = OpConstant %uint 0 @@ -49,9 +50,11 @@ %13 = OpLoad %uint %29 Aligned 4 OpStore %7 %13 %16 = OpLoad %ulong %5 - %26 = OpIAdd %ulong %16 %ulong_4 - %30 = OpConvertUToPtr %_ptr_Generic_uint %26 - %15 = OpLoad %uint %30 Aligned 4 + %30 = OpConvertUToPtr %_ptr_Generic_uint %16 + %51 = OpBitcast %_ptr_Generic_uchar %30 + %52 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %51 %ulong_4 + %26 = OpBitcast %_ptr_Generic_uint %52 + %15 = OpLoad %uint %26 Aligned 4 OpStore %8 %15 %17 = OpLoad %uint %7 %31 = OpBitcast %_ptr_Workgroup_uint %4 @@ -69,8 +72,10 @@ OpStore %34 %22 Aligned 4 %23 = OpLoad %ulong %6 %24 = OpLoad %uint %8 - %28 = OpIAdd %ulong %23 %ulong_4_0 - %35 = OpConvertUToPtr %_ptr_Generic_uint %28 - OpStore %35 %24 Aligned 4 + %35 = OpConvertUToPtr %_ptr_Generic_uint %23 + %56 = OpBitcast %_ptr_Generic_uchar %35 + %57 = 
OpInBoundsPtrAccessChain %_ptr_Generic_uchar %56 %ulong_4_0 + %28 = OpBitcast %_ptr_Generic_uint %57 + OpStore %28 %24 Aligned 4 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/atom_add_float.spvtxt b/ptx/src/test/spirv_run/atom_add_float.spvtxt index c2292f1..7d25632 100644 --- a/ptx/src/test/spirv_run/atom_add_float.spvtxt +++ b/ptx/src/test/spirv_run/atom_add_float.spvtxt @@ -28,6 +28,7 @@ %_ptr_Function_float = OpTypePointer Function %float %_ptr_Generic_float = OpTypePointer Generic %float %ulong_4 = OpConstant %ulong 4 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %ulong_4_0 = OpConstant %ulong 4 %37 = OpFunction %float None %46 %39 = OpFunctionParameter %_ptr_Workgroup_float @@ -54,9 +55,11 @@ %13 = OpLoad %float %29 Aligned 4 OpStore %7 %13 %16 = OpLoad %ulong %5 - %26 = OpIAdd %ulong %16 %ulong_4 - %30 = OpConvertUToPtr %_ptr_Generic_float %26 - %15 = OpLoad %float %30 Aligned 4 + %30 = OpConvertUToPtr %_ptr_Generic_float %16 + %58 = OpBitcast %_ptr_Generic_uchar %30 + %59 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %58 %ulong_4 + %26 = OpBitcast %_ptr_Generic_float %59 + %15 = OpLoad %float %26 Aligned 4 OpStore %8 %15 %17 = OpLoad %float %7 %31 = OpBitcast %_ptr_Workgroup_float %4 @@ -74,8 +77,10 @@ OpStore %34 %22 Aligned 4 %23 = OpLoad %ulong %6 %24 = OpLoad %float %8 - %28 = OpIAdd %ulong %23 %ulong_4_0 - %35 = OpConvertUToPtr %_ptr_Generic_float %28 - OpStore %35 %24 Aligned 4 + %35 = OpConvertUToPtr %_ptr_Generic_float %23 + %60 = OpBitcast %_ptr_Generic_uchar %35 + %61 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %60 %ulong_4_0 + %28 = OpBitcast %_ptr_Generic_float %61 + OpStore %28 %24 Aligned 4 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/atom_cas.spvtxt b/ptx/src/test/spirv_run/atom_cas.spvtxt index e1feb0a..7c2f4fa 100644 --- a/ptx/src/test/spirv_run/atom_cas.spvtxt +++ b/ptx/src/test/spirv_run/atom_cas.spvtxt @@ -18,6 +18,8 @@ %_ptr_Function_uint = OpTypePointer Function %uint %_ptr_Generic_uint = 
OpTypePointer Generic %uint %ulong_4 = OpConstant %ulong 4 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %uint_100 = OpConstant %uint 100 %uint_1 = OpConstant %uint 1 %uint_0 = OpConstant %uint 0 @@ -45,16 +47,20 @@ OpStore %6 %12 %15 = OpLoad %ulong %4 %16 = OpLoad %uint %6 - %24 = OpIAdd %ulong %15 %ulong_4 - %32 = OpConvertUToPtr %_ptr_Generic_uint %24 + %31 = OpConvertUToPtr %_ptr_Generic_uint %15 + %49 = OpBitcast %_ptr_Generic_uchar %31 + %50 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %49 %ulong_4 + %24 = OpBitcast %_ptr_Generic_uint %50 %33 = OpCopyObject %uint %16 - %31 = OpAtomicCompareExchange %uint %32 %uint_1 %uint_0 %uint_0 %uint_100 %33 - %14 = OpCopyObject %uint %31 + %32 = OpAtomicCompareExchange %uint %24 %uint_1 %uint_0 %uint_0 %uint_100 %33 + %14 = OpCopyObject %uint %32 OpStore %6 %14 %18 = OpLoad %ulong %4 - %27 = OpIAdd %ulong %18 %ulong_4_0 - %34 = OpConvertUToPtr %_ptr_Generic_uint %27 - %17 = OpLoad %uint %34 Aligned 4 + %34 = OpConvertUToPtr %_ptr_Generic_uint %18 + %53 = OpBitcast %_ptr_Generic_uchar %34 + %54 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %53 %ulong_4_0 + %27 = OpBitcast %_ptr_Generic_uint %54 + %17 = OpLoad %uint %27 Aligned 4 OpStore %7 %17 %19 = OpLoad %ulong %5 %20 = OpLoad %uint %6 @@ -62,8 +68,10 @@ OpStore %35 %20 Aligned 4 %21 = OpLoad %ulong %5 %22 = OpLoad %uint %7 - %29 = OpIAdd %ulong %21 %ulong_4_1 - %36 = OpConvertUToPtr %_ptr_Generic_uint %29 - OpStore %36 %22 Aligned 4 + %36 = OpConvertUToPtr %_ptr_Generic_uint %21 + %55 = OpBitcast %_ptr_Generic_uchar %36 + %56 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %55 %ulong_4_1 + %29 = OpBitcast %_ptr_Generic_uint %56 + OpStore %29 %22 Aligned 4 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/atom_inc.spvtxt b/ptx/src/test/spirv_run/atom_inc.spvtxt index 11b4243..4855cd4 100644 --- a/ptx/src/test/spirv_run/atom_inc.spvtxt +++ b/ptx/src/test/spirv_run/atom_inc.spvtxt @@ -10,14 +10,14 @@ %47 = OpExtInstImport 
"OpenCL.std" OpMemoryModel Physical64 OpenCL OpEntryPoint Kernel %1 "atom_inc" - OpDecorate %42 LinkageAttributes "__zluda_ptx_impl__atom_relaxed_gpu_global_inc" Import OpDecorate %38 LinkageAttributes "__zluda_ptx_impl__atom_relaxed_gpu_generic_inc" Import + OpDecorate %42 LinkageAttributes "__zluda_ptx_impl__atom_relaxed_gpu_global_inc" Import %void = OpTypeVoid %uint = OpTypeInt 32 0 -%_ptr_CrossWorkgroup_uint = OpTypePointer CrossWorkgroup %uint - %51 = OpTypeFunction %uint %_ptr_CrossWorkgroup_uint %uint %_ptr_Generic_uint = OpTypePointer Generic %uint - %53 = OpTypeFunction %uint %_ptr_Generic_uint %uint + %51 = OpTypeFunction %uint %_ptr_Generic_uint %uint +%_ptr_CrossWorkgroup_uint = OpTypePointer CrossWorkgroup %uint + %53 = OpTypeFunction %uint %_ptr_CrossWorkgroup_uint %uint %ulong = OpTypeInt 64 0 %55 = OpTypeFunction %void %ulong %ulong %_ptr_Function_ulong = OpTypePointer Function %ulong @@ -25,15 +25,17 @@ %uint_101 = OpConstant %uint 101 %uint_101_0 = OpConstant %uint 101 %ulong_4 = OpConstant %ulong 4 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %ulong_8 = OpConstant %ulong 8 - %42 = OpFunction %uint None %51 - %44 = OpFunctionParameter %_ptr_CrossWorkgroup_uint - %45 = OpFunctionParameter %uint - OpFunctionEnd - %38 = OpFunction %uint None %53 + %38 = OpFunction %uint None %51 %40 = OpFunctionParameter %_ptr_Generic_uint %41 = OpFunctionParameter %uint OpFunctionEnd + %42 = OpFunction %uint None %53 + %44 = OpFunctionParameter %_ptr_CrossWorkgroup_uint + %45 = OpFunctionParameter %uint + OpFunctionEnd %1 = OpFunction %void None %55 %9 = OpFunctionParameter %ulong %10 = OpFunctionParameter %ulong @@ -69,13 +71,17 @@ OpStore %34 %20 Aligned 4 %21 = OpLoad %ulong %5 %22 = OpLoad %uint %7 - %28 = OpIAdd %ulong %21 %ulong_4 - %35 = OpConvertUToPtr %_ptr_Generic_uint %28 - OpStore %35 %22 Aligned 4 + %35 = OpConvertUToPtr %_ptr_Generic_uint %21 + %60 = OpBitcast %_ptr_Generic_uchar %35 + %61 = OpInBoundsPtrAccessChain 
%_ptr_Generic_uchar %60 %ulong_4 + %28 = OpBitcast %_ptr_Generic_uint %61 + OpStore %28 %22 Aligned 4 %23 = OpLoad %ulong %5 %24 = OpLoad %uint %8 - %30 = OpIAdd %ulong %23 %ulong_8 - %36 = OpConvertUToPtr %_ptr_Generic_uint %30 - OpStore %36 %24 Aligned 4 + %36 = OpConvertUToPtr %_ptr_Generic_uint %23 + %62 = OpBitcast %_ptr_Generic_uchar %36 + %63 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %62 %ulong_8 + %30 = OpBitcast %_ptr_Generic_uint %63 + OpStore %30 %24 Aligned 4 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/bfe.spvtxt b/ptx/src/test/spirv_run/bfe.spvtxt index 535ede9..0001808 100644 --- a/ptx/src/test/spirv_run/bfe.spvtxt +++ b/ptx/src/test/spirv_run/bfe.spvtxt @@ -20,6 +20,8 @@ %_ptr_Function_uint = OpTypePointer Function %uint %_ptr_Generic_uint = OpTypePointer Generic %uint %ulong_4 = OpConstant %ulong 4 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %ulong_8 = OpConstant %ulong 8 %34 = OpFunction %uint None %43 %36 = OpFunctionParameter %uint @@ -48,14 +50,18 @@ %13 = OpLoad %uint %29 Aligned 4 OpStore %6 %13 %16 = OpLoad %ulong %4 - %26 = OpIAdd %ulong %16 %ulong_4 - %30 = OpConvertUToPtr %_ptr_Generic_uint %26 - %15 = OpLoad %uint %30 Aligned 4 + %30 = OpConvertUToPtr %_ptr_Generic_uint %16 + %51 = OpBitcast %_ptr_Generic_uchar %30 + %52 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %51 %ulong_4 + %26 = OpBitcast %_ptr_Generic_uint %52 + %15 = OpLoad %uint %26 Aligned 4 OpStore %7 %15 %18 = OpLoad %ulong %4 - %28 = OpIAdd %ulong %18 %ulong_8 - %31 = OpConvertUToPtr %_ptr_Generic_uint %28 - %17 = OpLoad %uint %31 Aligned 4 + %31 = OpConvertUToPtr %_ptr_Generic_uint %18 + %53 = OpBitcast %_ptr_Generic_uchar %31 + %54 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %53 %ulong_8 + %28 = OpBitcast %_ptr_Generic_uint %54 + %17 = OpLoad %uint %28 Aligned 4 OpStore %8 %17 %20 = OpLoad %uint %6 %21 = OpLoad %uint %7 diff --git a/ptx/src/test/spirv_run/bfi.spvtxt b/ptx/src/test/spirv_run/bfi.spvtxt index 
a226f78..1979939 100644 --- a/ptx/src/test/spirv_run/bfi.spvtxt +++ b/ptx/src/test/spirv_run/bfi.spvtxt @@ -20,6 +20,8 @@ %_ptr_Function_uint = OpTypePointer Function %uint %_ptr_Generic_uint = OpTypePointer Generic %uint %ulong_4 = OpConstant %ulong 4 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %ulong_8 = OpConstant %ulong 8 %ulong_12 = OpConstant %ulong 12 %44 = OpFunction %uint None %54 @@ -51,19 +53,25 @@ %14 = OpLoad %uint %35 Aligned 4 OpStore %6 %14 %17 = OpLoad %ulong %4 - %30 = OpIAdd %ulong %17 %ulong_4 - %36 = OpConvertUToPtr %_ptr_Generic_uint %30 - %16 = OpLoad %uint %36 Aligned 4 + %36 = OpConvertUToPtr %_ptr_Generic_uint %17 + %62 = OpBitcast %_ptr_Generic_uchar %36 + %63 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %62 %ulong_4 + %30 = OpBitcast %_ptr_Generic_uint %63 + %16 = OpLoad %uint %30 Aligned 4 OpStore %7 %16 %19 = OpLoad %ulong %4 - %32 = OpIAdd %ulong %19 %ulong_8 - %37 = OpConvertUToPtr %_ptr_Generic_uint %32 - %18 = OpLoad %uint %37 Aligned 4 + %37 = OpConvertUToPtr %_ptr_Generic_uint %19 + %64 = OpBitcast %_ptr_Generic_uchar %37 + %65 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %64 %ulong_8 + %32 = OpBitcast %_ptr_Generic_uint %65 + %18 = OpLoad %uint %32 Aligned 4 OpStore %8 %18 %21 = OpLoad %ulong %4 - %34 = OpIAdd %ulong %21 %ulong_12 - %38 = OpConvertUToPtr %_ptr_Generic_uint %34 - %20 = OpLoad %uint %38 Aligned 4 + %38 = OpConvertUToPtr %_ptr_Generic_uint %21 + %66 = OpBitcast %_ptr_Generic_uchar %38 + %67 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %66 %ulong_12 + %34 = OpBitcast %_ptr_Generic_uint %67 + %20 = OpLoad %uint %34 Aligned 4 OpStore %9 %20 %23 = OpLoad %uint %6 %24 = OpLoad %uint %7 @@ -71,7 +79,7 @@ %26 = OpLoad %uint %9 %40 = OpCopyObject %uint %23 %41 = OpCopyObject %uint %24 - %39 = OpFunctionCall %uint %44 %41 %40 %25 %26 + %39 = OpFunctionCall %uint %44 %40 %41 %25 %26 %22 = OpCopyObject %uint %39 OpStore %6 %22 %27 = OpLoad %ulong %5 diff --git 
a/ptx/src/test/spirv_run/call.spvtxt b/ptx/src/test/spirv_run/call.spvtxt index 5473234..6929b1e 100644 --- a/ptx/src/test/spirv_run/call.spvtxt +++ b/ptx/src/test/spirv_run/call.spvtxt @@ -42,7 +42,7 @@ %23 = OpBitcast %_ptr_Function_ulong %10 %24 = OpCopyObject %ulong %18 OpStore %23 %24 Aligned 8 - %43 = OpFunctionCall %void %1 %11 %10 + %43 = OpFunctionCall %void %1 %10 %11 %19 = OpLoad %ulong %11 Aligned 8 OpStore %9 %19 %20 = OpLoad %ulong %8 @@ -52,8 +52,8 @@ OpReturn OpFunctionEnd %1 = OpFunction %void None %44 - %27 = OpFunctionParameter %_ptr_Function_ulong %28 = OpFunctionParameter %_ptr_Function_ulong + %27 = OpFunctionParameter %_ptr_Function_ulong %35 = OpLabel %29 = OpVariable %_ptr_Function_ulong Function %30 = OpLoad %ulong %28 Aligned 8 diff --git a/ptx/src/test/spirv_run/cvt_rni.spvtxt b/ptx/src/test/spirv_run/cvt_rni.spvtxt index 288a939..e10999c 100644 --- a/ptx/src/test/spirv_run/cvt_rni.spvtxt +++ b/ptx/src/test/spirv_run/cvt_rni.spvtxt @@ -18,6 +18,8 @@ %_ptr_Function_float = OpTypePointer Function %float %_ptr_Generic_float = OpTypePointer Generic %float %ulong_4 = OpConstant %ulong 4 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %ulong_4_0 = OpConstant %ulong 4 %1 = OpFunction %void None %37 %8 = OpFunctionParameter %ulong @@ -40,9 +42,11 @@ %12 = OpLoad %float %28 Aligned 4 OpStore %6 %12 %15 = OpLoad %ulong %4 - %25 = OpIAdd %ulong %15 %ulong_4 - %29 = OpConvertUToPtr %_ptr_Generic_float %25 - %14 = OpLoad %float %29 Aligned 4 + %29 = OpConvertUToPtr %_ptr_Generic_float %15 + %44 = OpBitcast %_ptr_Generic_uchar %29 + %45 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %44 %ulong_4 + %25 = OpBitcast %_ptr_Generic_float %45 + %14 = OpLoad %float %25 Aligned 4 OpStore %7 %14 %17 = OpLoad %float %6 %16 = OpExtInst %float %34 rint %17 @@ -56,8 +60,10 @@ OpStore %30 %21 Aligned 4 %22 = OpLoad %ulong %5 %23 = OpLoad %float %7 - %27 = OpIAdd %ulong %22 %ulong_4_0 - %31 = OpConvertUToPtr %_ptr_Generic_float %27 - 
OpStore %31 %23 Aligned 4 + %31 = OpConvertUToPtr %_ptr_Generic_float %22 + %46 = OpBitcast %_ptr_Generic_uchar %31 + %47 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %46 %ulong_4_0 + %27 = OpBitcast %_ptr_Generic_float %47 + OpStore %27 %23 Aligned 4 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/cvt_rzi.spvtxt b/ptx/src/test/spirv_run/cvt_rzi.spvtxt index 68c12c6..7dda454 100644 --- a/ptx/src/test/spirv_run/cvt_rzi.spvtxt +++ b/ptx/src/test/spirv_run/cvt_rzi.spvtxt @@ -18,6 +18,8 @@ %_ptr_Function_float = OpTypePointer Function %float %_ptr_Generic_float = OpTypePointer Generic %float %ulong_4 = OpConstant %ulong 4 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %ulong_4_0 = OpConstant %ulong 4 %1 = OpFunction %void None %37 %8 = OpFunctionParameter %ulong @@ -40,9 +42,11 @@ %12 = OpLoad %float %28 Aligned 4 OpStore %6 %12 %15 = OpLoad %ulong %4 - %25 = OpIAdd %ulong %15 %ulong_4 - %29 = OpConvertUToPtr %_ptr_Generic_float %25 - %14 = OpLoad %float %29 Aligned 4 + %29 = OpConvertUToPtr %_ptr_Generic_float %15 + %44 = OpBitcast %_ptr_Generic_uchar %29 + %45 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %44 %ulong_4 + %25 = OpBitcast %_ptr_Generic_float %45 + %14 = OpLoad %float %25 Aligned 4 OpStore %7 %14 %17 = OpLoad %float %6 %16 = OpExtInst %float %34 trunc %17 @@ -56,8 +60,10 @@ OpStore %30 %21 Aligned 4 %22 = OpLoad %ulong %5 %23 = OpLoad %float %7 - %27 = OpIAdd %ulong %22 %ulong_4_0 - %31 = OpConvertUToPtr %_ptr_Generic_float %27 - OpStore %31 %23 Aligned 4 + %31 = OpConvertUToPtr %_ptr_Generic_float %22 + %46 = OpBitcast %_ptr_Generic_uchar %31 + %47 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %46 %ulong_4_0 + %27 = OpBitcast %_ptr_Generic_float %47 + OpStore %27 %23 Aligned 4 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/cvt_s32_f32.spvtxt b/ptx/src/test/spirv_run/cvt_s32_f32.spvtxt index d9ae053..c1229d4 100644 --- a/ptx/src/test/spirv_run/cvt_s32_f32.spvtxt +++ 
b/ptx/src/test/spirv_run/cvt_s32_f32.spvtxt @@ -21,8 +21,11 @@ %float = OpTypeFloat 32 %_ptr_Generic_float = OpTypePointer Generic %float %ulong_4 = OpConstant %ulong 4 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %_ptr_CrossWorkgroup_uint = OpTypePointer CrossWorkgroup %uint %ulong_4_0 = OpConstant %ulong 4 +%_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar %1 = OpFunction %void None %45 %8 = OpFunctionParameter %ulong %9 = OpFunctionParameter %ulong @@ -45,10 +48,12 @@ %12 = OpBitcast %uint %28 OpStore %6 %12 %15 = OpLoad %ulong %4 - %25 = OpIAdd %ulong %15 %ulong_4 - %31 = OpConvertUToPtr %_ptr_Generic_float %25 - %30 = OpLoad %float %31 Aligned 4 - %14 = OpBitcast %uint %30 + %30 = OpConvertUToPtr %_ptr_Generic_float %15 + %53 = OpBitcast %_ptr_Generic_uchar %30 + %54 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %53 %ulong_4 + %25 = OpBitcast %_ptr_Generic_float %54 + %31 = OpLoad %float %25 Aligned 4 + %14 = OpBitcast %uint %31 OpStore %7 %14 %17 = OpLoad %uint %6 %33 = OpBitcast %float %17 @@ -67,9 +72,11 @@ OpStore %36 %37 Aligned 4 %22 = OpLoad %ulong %5 %23 = OpLoad %uint %7 - %27 = OpIAdd %ulong %22 %ulong_4_0 - %38 = OpConvertUToPtr %_ptr_CrossWorkgroup_uint %27 + %38 = OpConvertUToPtr %_ptr_CrossWorkgroup_uint %22 + %57 = OpBitcast %_ptr_CrossWorkgroup_uchar %38 + %58 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %57 %ulong_4_0 + %27 = OpBitcast %_ptr_CrossWorkgroup_uint %58 %39 = OpCopyObject %uint %23 - OpStore %38 %39 Aligned 4 + OpStore %27 %39 Aligned 4 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/div_approx.spvtxt b/ptx/src/test/spirv_run/div_approx.spvtxt index 274f73e..858ec8d 100644 --- a/ptx/src/test/spirv_run/div_approx.spvtxt +++ b/ptx/src/test/spirv_run/div_approx.spvtxt @@ -19,6 +19,8 @@ %_ptr_Function_float = OpTypePointer Function %float %_ptr_Generic_float = OpTypePointer Generic %float %ulong_4 = OpConstant %ulong 4 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = 
OpTypePointer Generic %uchar %1 = OpFunction %void None %31 %8 = OpFunctionParameter %ulong %9 = OpFunctionParameter %ulong @@ -40,9 +42,11 @@ %12 = OpLoad %float %23 Aligned 4 OpStore %6 %12 %15 = OpLoad %ulong %4 - %22 = OpIAdd %ulong %15 %ulong_4 - %24 = OpConvertUToPtr %_ptr_Generic_float %22 - %14 = OpLoad %float %24 Aligned 4 + %24 = OpConvertUToPtr %_ptr_Generic_float %15 + %38 = OpBitcast %_ptr_Generic_uchar %24 + %39 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %38 %ulong_4 + %22 = OpBitcast %_ptr_Generic_float %39 + %14 = OpLoad %float %22 Aligned 4 OpStore %7 %14 %17 = OpLoad %float %6 %18 = OpLoad %float %7 diff --git a/ptx/src/test/spirv_run/extern_shared.spvtxt b/ptx/src/test/spirv_run/extern_shared.spvtxt index fb2987e..13587d5 100644 --- a/ptx/src/test/spirv_run/extern_shared.spvtxt +++ b/ptx/src/test/spirv_run/extern_shared.spvtxt @@ -7,37 +7,30 @@ OpCapability Int64 OpCapability Float16 OpCapability Float64 - %30 = OpExtInstImport "OpenCL.std" + %27 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL OpEntryPoint Kernel %2 "extern_shared" %1 %void = OpTypeVoid %uint = OpTypeInt 32 0 %_ptr_Workgroup_uint = OpTypePointer Workgroup %uint -%_ptr_Workgroup__ptr_Workgroup_uint = OpTypePointer Workgroup %_ptr_Workgroup_uint - %1 = OpVariable %_ptr_Workgroup__ptr_Workgroup_uint Workgroup + %1 = OpVariable %_ptr_Workgroup_uint Workgroup %ulong = OpTypeInt 64 0 %uchar = OpTypeInt 8 0 %_ptr_Workgroup_uchar = OpTypePointer Workgroup %uchar - %38 = OpTypeFunction %void %ulong %ulong %_ptr_Workgroup_uchar -%_ptr_Function__ptr_Workgroup_uchar = OpTypePointer Function %_ptr_Workgroup_uchar + %34 = OpTypeFunction %void %ulong %ulong %_ptr_Workgroup_uchar %_ptr_Function_ulong = OpTypePointer Function %ulong %_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong -%_ptr_Function__ptr_Workgroup_uint = OpTypePointer Function %_ptr_Workgroup_uint %_ptr_Workgroup_ulong = OpTypePointer Workgroup %ulong - %2 = OpFunction %void None %38 + %2 = 
OpFunction %void None %34 %8 = OpFunctionParameter %ulong %9 = OpFunctionParameter %ulong - %26 = OpFunctionParameter %_ptr_Workgroup_uchar - %39 = OpLabel - %27 = OpVariable %_ptr_Function__ptr_Workgroup_uchar Function + %24 = OpFunctionParameter %_ptr_Workgroup_uchar + %22 = OpLabel %3 = OpVariable %_ptr_Function_ulong Function %4 = OpVariable %_ptr_Function_ulong Function %5 = OpVariable %_ptr_Function_ulong Function %6 = OpVariable %_ptr_Function_ulong Function %7 = OpVariable %_ptr_Function_ulong Function - OpStore %27 %26 - OpBranch %24 - %24 = OpLabel OpStore %3 %8 OpStore %4 %9 %10 = OpLoad %ulong %3 Aligned 8 @@ -45,22 +38,20 @@ %11 = OpLoad %ulong %4 Aligned 8 OpStore %6 %11 %13 = OpLoad %ulong %5 - %20 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %13 - %12 = OpLoad %ulong %20 Aligned 8 + %18 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %13 + %12 = OpLoad %ulong %18 Aligned 8 OpStore %7 %12 - %28 = OpBitcast %_ptr_Function__ptr_Workgroup_uint %27 - %14 = OpLoad %_ptr_Workgroup_uint %28 - %15 = OpLoad %ulong %7 - %21 = OpBitcast %_ptr_Workgroup_ulong %14 - OpStore %21 %15 Aligned 8 - %29 = OpBitcast %_ptr_Function__ptr_Workgroup_uint %27 - %17 = OpLoad %_ptr_Workgroup_uint %29 - %22 = OpBitcast %_ptr_Workgroup_ulong %17 - %16 = OpLoad %ulong %22 Aligned 8 - OpStore %7 %16 - %18 = OpLoad %ulong %6 - %19 = OpLoad %ulong %7 - %23 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %18 - OpStore %23 %19 Aligned 8 + %14 = OpLoad %ulong %7 + %25 = OpBitcast %_ptr_Workgroup_uint %24 + %19 = OpBitcast %_ptr_Workgroup_ulong %25 + OpStore %19 %14 Aligned 8 + %26 = OpBitcast %_ptr_Workgroup_uint %24 + %20 = OpBitcast %_ptr_Workgroup_ulong %26 + %15 = OpLoad %ulong %20 Aligned 8 + OpStore %7 %15 + %16 = OpLoad %ulong %6 + %17 = OpLoad %ulong %7 + %21 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %16 + OpStore %21 %17 Aligned 8 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/extern_shared_call.spvtxt b/ptx/src/test/spirv_run/extern_shared_call.spvtxt index 
7043172..5af7168 100644 --- a/ptx/src/test/spirv_run/extern_shared_call.spvtxt +++ b/ptx/src/test/spirv_run/extern_shared_call.spvtxt @@ -7,87 +7,72 @@ OpCapability Int64 OpCapability Float16 OpCapability Float64 - %46 = OpExtInstImport "OpenCL.std" + %40 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL - OpEntryPoint Kernel %14 "extern_shared_call" %1 + OpEntryPoint Kernel %12 "extern_shared_call" %1 OpDecorate %1 Alignment 4 %void = OpTypeVoid %uint = OpTypeInt 32 0 %_ptr_Workgroup_uint = OpTypePointer Workgroup %uint -%_ptr_Workgroup__ptr_Workgroup_uint = OpTypePointer Workgroup %_ptr_Workgroup_uint - %1 = OpVariable %_ptr_Workgroup__ptr_Workgroup_uint Workgroup + %1 = OpVariable %_ptr_Workgroup_uint Workgroup %uchar = OpTypeInt 8 0 %_ptr_Workgroup_uchar = OpTypePointer Workgroup %uchar - %53 = OpTypeFunction %void %_ptr_Workgroup_uchar -%_ptr_Function__ptr_Workgroup_uchar = OpTypePointer Function %_ptr_Workgroup_uchar + %46 = OpTypeFunction %void %_ptr_Workgroup_uchar %ulong = OpTypeInt 64 0 %_ptr_Function_ulong = OpTypePointer Function %ulong -%_ptr_Function__ptr_Workgroup_uint = OpTypePointer Function %_ptr_Workgroup_uint %_ptr_Workgroup_ulong = OpTypePointer Workgroup %ulong %ulong_2 = OpConstant %ulong 2 - %60 = OpTypeFunction %void %ulong %ulong %_ptr_Workgroup_uchar + %50 = OpTypeFunction %void %ulong %ulong %_ptr_Workgroup_uchar %_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong - %2 = OpFunction %void None %53 - %38 = OpFunctionParameter %_ptr_Workgroup_uchar - %54 = OpLabel - %39 = OpVariable %_ptr_Function__ptr_Workgroup_uchar Function + %2 = OpFunction %void None %46 + %34 = OpFunctionParameter %_ptr_Workgroup_uchar + %11 = OpLabel %3 = OpVariable %_ptr_Function_ulong Function - OpStore %39 %38 - OpBranch %13 - %13 = OpLabel - %40 = OpBitcast %_ptr_Function__ptr_Workgroup_uint %39 - %5 = OpLoad %_ptr_Workgroup_uint %40 - %11 = OpBitcast %_ptr_Workgroup_ulong %5 - %4 = OpLoad %ulong %11 Aligned 8 + %35 = OpBitcast 
%_ptr_Workgroup_uint %34 + %9 = OpBitcast %_ptr_Workgroup_ulong %35 + %4 = OpLoad %ulong %9 Aligned 8 OpStore %3 %4 + %6 = OpLoad %ulong %3 + %5 = OpIAdd %ulong %6 %ulong_2 + OpStore %3 %5 %7 = OpLoad %ulong %3 - %6 = OpIAdd %ulong %7 %ulong_2 - OpStore %3 %6 - %41 = OpBitcast %_ptr_Function__ptr_Workgroup_uint %39 - %8 = OpLoad %_ptr_Workgroup_uint %41 - %9 = OpLoad %ulong %3 - %12 = OpBitcast %_ptr_Workgroup_ulong %8 - OpStore %12 %9 Aligned 8 + %36 = OpBitcast %_ptr_Workgroup_uint %34 + %10 = OpBitcast %_ptr_Workgroup_ulong %36 + OpStore %10 %7 Aligned 8 OpReturn OpFunctionEnd - %14 = OpFunction %void None %60 - %20 = OpFunctionParameter %ulong - %21 = OpFunctionParameter %ulong - %42 = OpFunctionParameter %_ptr_Workgroup_uchar - %61 = OpLabel - %43 = OpVariable %_ptr_Function__ptr_Workgroup_uchar Function + %12 = OpFunction %void None %50 + %18 = OpFunctionParameter %ulong + %19 = OpFunctionParameter %ulong + %37 = OpFunctionParameter %_ptr_Workgroup_uchar + %32 = OpLabel + %13 = OpVariable %_ptr_Function_ulong Function + %14 = OpVariable %_ptr_Function_ulong Function %15 = OpVariable %_ptr_Function_ulong Function %16 = OpVariable %_ptr_Function_ulong Function %17 = OpVariable %_ptr_Function_ulong Function - %18 = OpVariable %_ptr_Function_ulong Function - %19 = OpVariable %_ptr_Function_ulong Function - OpStore %43 %42 - OpBranch %36 - %36 = OpLabel + OpStore %13 %18 + OpStore %14 %19 + %20 = OpLoad %ulong %13 Aligned 8 OpStore %15 %20 + %21 = OpLoad %ulong %14 Aligned 8 OpStore %16 %21 - %22 = OpLoad %ulong %15 Aligned 8 + %23 = OpLoad %ulong %15 + %28 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %23 + %22 = OpLoad %ulong %28 Aligned 8 OpStore %17 %22 - %23 = OpLoad %ulong %16 Aligned 8 - OpStore %18 %23 - %25 = OpLoad %ulong %17 - %32 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %25 - %24 = OpLoad %ulong %32 Aligned 8 - OpStore %19 %24 - %44 = OpBitcast %_ptr_Function__ptr_Workgroup_uint %43 - %26 = OpLoad %_ptr_Workgroup_uint %44 - %27 = OpLoad %ulong %19 - 
%33 = OpBitcast %_ptr_Workgroup_ulong %26 - OpStore %33 %27 Aligned 8 - %63 = OpFunctionCall %void %2 %42 - %45 = OpBitcast %_ptr_Function__ptr_Workgroup_uint %43 - %29 = OpLoad %_ptr_Workgroup_uint %45 - %34 = OpBitcast %_ptr_Workgroup_ulong %29 - %28 = OpLoad %ulong %34 Aligned 8 - OpStore %19 %28 - %30 = OpLoad %ulong %18 - %31 = OpLoad %ulong %19 - %35 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %30 - OpStore %35 %31 Aligned 8 + %24 = OpLoad %ulong %17 + %38 = OpBitcast %_ptr_Workgroup_uint %37 + %29 = OpBitcast %_ptr_Workgroup_ulong %38 + OpStore %29 %24 Aligned 8 + %52 = OpFunctionCall %void %2 %37 + %39 = OpBitcast %_ptr_Workgroup_uint %37 + %30 = OpBitcast %_ptr_Workgroup_ulong %39 + %25 = OpLoad %ulong %30 Aligned 8 + OpStore %17 %25 + %26 = OpLoad %ulong %16 + %27 = OpLoad %ulong %17 + %31 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %26 + OpStore %31 %27 Aligned 8 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/fma.spvtxt b/ptx/src/test/spirv_run/fma.spvtxt index 300a328..8cc0e16 100644 --- a/ptx/src/test/spirv_run/fma.spvtxt +++ b/ptx/src/test/spirv_run/fma.spvtxt @@ -18,6 +18,8 @@ %_ptr_Function_float = OpTypePointer Function %float %_ptr_Generic_float = OpTypePointer Generic %float %ulong_4 = OpConstant %ulong 4 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %ulong_8 = OpConstant %ulong 8 %1 = OpFunction %void None %38 %9 = OpFunctionParameter %ulong @@ -41,14 +43,18 @@ %13 = OpLoad %float %29 Aligned 4 OpStore %6 %13 %16 = OpLoad %ulong %4 - %26 = OpIAdd %ulong %16 %ulong_4 - %30 = OpConvertUToPtr %_ptr_Generic_float %26 - %15 = OpLoad %float %30 Aligned 4 + %30 = OpConvertUToPtr %_ptr_Generic_float %16 + %45 = OpBitcast %_ptr_Generic_uchar %30 + %46 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %45 %ulong_4 + %26 = OpBitcast %_ptr_Generic_float %46 + %15 = OpLoad %float %26 Aligned 4 OpStore %7 %15 %18 = OpLoad %ulong %4 - %28 = OpIAdd %ulong %18 %ulong_8 - %31 = OpConvertUToPtr %_ptr_Generic_float %28 - 
%17 = OpLoad %float %31 Aligned 4 + %31 = OpConvertUToPtr %_ptr_Generic_float %18 + %47 = OpBitcast %_ptr_Generic_uchar %31 + %48 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %47 %ulong_8 + %28 = OpBitcast %_ptr_Generic_float %48 + %17 = OpLoad %float %28 Aligned 4 OpStore %8 %17 %20 = OpLoad %float %6 %21 = OpLoad %float %7 diff --git a/ptx/src/test/spirv_run/ld_st_offset.spvtxt b/ptx/src/test/spirv_run/ld_st_offset.spvtxt index 5e314a0..ea97222 100644 --- a/ptx/src/test/spirv_run/ld_st_offset.spvtxt +++ b/ptx/src/test/spirv_run/ld_st_offset.spvtxt @@ -18,6 +18,8 @@ %_ptr_Function_uint = OpTypePointer Function %uint %_ptr_Generic_uint = OpTypePointer Generic %uint %ulong_4 = OpConstant %ulong 4 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %ulong_4_0 = OpConstant %ulong 4 %1 = OpFunction %void None %33 %8 = OpFunctionParameter %ulong @@ -40,9 +42,11 @@ %12 = OpLoad %uint %24 Aligned 4 OpStore %6 %12 %15 = OpLoad %ulong %4 - %21 = OpIAdd %ulong %15 %ulong_4 - %25 = OpConvertUToPtr %_ptr_Generic_uint %21 - %14 = OpLoad %uint %25 Aligned 4 + %25 = OpConvertUToPtr %_ptr_Generic_uint %15 + %40 = OpBitcast %_ptr_Generic_uchar %25 + %41 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %40 %ulong_4 + %21 = OpBitcast %_ptr_Generic_uint %41 + %14 = OpLoad %uint %21 Aligned 4 OpStore %7 %14 %16 = OpLoad %ulong %5 %17 = OpLoad %uint %7 @@ -50,8 +54,10 @@ OpStore %26 %17 Aligned 4 %18 = OpLoad %ulong %5 %19 = OpLoad %uint %6 - %23 = OpIAdd %ulong %18 %ulong_4_0 - %27 = OpConvertUToPtr %_ptr_Generic_uint %23 - OpStore %27 %19 Aligned 4 + %27 = OpConvertUToPtr %_ptr_Generic_uint %18 + %42 = OpBitcast %_ptr_Generic_uchar %27 + %43 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %42 %ulong_4_0 + %23 = OpBitcast %_ptr_Generic_uint %43 + OpStore %23 %19 Aligned 4 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/mad_s32.spvtxt b/ptx/src/test/spirv_run/mad_s32.spvtxt index bb44af0..0ee3ca7 100644 --- a/ptx/src/test/spirv_run/mad_s32.spvtxt +++ 
b/ptx/src/test/spirv_run/mad_s32.spvtxt @@ -18,6 +18,8 @@ %_ptr_Function_uint = OpTypePointer Function %uint %_ptr_Generic_uint = OpTypePointer Generic %uint %ulong_4 = OpConstant %ulong 4 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %ulong_8 = OpConstant %ulong 8 %ulong_4_0 = OpConstant %ulong 4 %ulong_8_0 = OpConstant %ulong 8 @@ -44,20 +46,24 @@ %14 = OpLoad %uint %38 Aligned 4 OpStore %7 %14 %17 = OpLoad %ulong %4 - %31 = OpIAdd %ulong %17 %ulong_4 - %39 = OpConvertUToPtr %_ptr_Generic_uint %31 - %16 = OpLoad %uint %39 Aligned 4 + %39 = OpConvertUToPtr %_ptr_Generic_uint %17 + %56 = OpBitcast %_ptr_Generic_uchar %39 + %57 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %56 %ulong_4 + %31 = OpBitcast %_ptr_Generic_uint %57 + %16 = OpLoad %uint %31 Aligned 4 OpStore %8 %16 %19 = OpLoad %ulong %4 - %33 = OpIAdd %ulong %19 %ulong_8 - %40 = OpConvertUToPtr %_ptr_Generic_uint %33 - %18 = OpLoad %uint %40 Aligned 4 + %40 = OpConvertUToPtr %_ptr_Generic_uint %19 + %58 = OpBitcast %_ptr_Generic_uchar %40 + %59 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %58 %ulong_8 + %33 = OpBitcast %_ptr_Generic_uint %59 + %18 = OpLoad %uint %33 Aligned 4 OpStore %9 %18 %21 = OpLoad %uint %7 %22 = OpLoad %uint %8 %23 = OpLoad %uint %9 - %54 = OpIMul %uint %21 %22 - %20 = OpIAdd %uint %23 %54 + %60 = OpIMul %uint %21 %22 + %20 = OpIAdd %uint %23 %60 OpStore %6 %20 %24 = OpLoad %ulong %5 %25 = OpLoad %uint %6 @@ -65,13 +71,17 @@ OpStore %41 %25 Aligned 4 %26 = OpLoad %ulong %5 %27 = OpLoad %uint %6 - %35 = OpIAdd %ulong %26 %ulong_4_0 - %42 = OpConvertUToPtr %_ptr_Generic_uint %35 - OpStore %42 %27 Aligned 4 + %42 = OpConvertUToPtr %_ptr_Generic_uint %26 + %61 = OpBitcast %_ptr_Generic_uchar %42 + %62 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %61 %ulong_4_0 + %35 = OpBitcast %_ptr_Generic_uint %62 + OpStore %35 %27 Aligned 4 %28 = OpLoad %ulong %5 %29 = OpLoad %uint %6 - %37 = OpIAdd %ulong %28 %ulong_8_0 - %43 = OpConvertUToPtr %_ptr_Generic_uint 
%37 - OpStore %43 %29 Aligned 4 + %43 = OpConvertUToPtr %_ptr_Generic_uint %28 + %63 = OpBitcast %_ptr_Generic_uchar %43 + %64 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %63 %ulong_8_0 + %37 = OpBitcast %_ptr_Generic_uint %64 + OpStore %37 %29 Aligned 4 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/max.spvtxt b/ptx/src/test/spirv_run/max.spvtxt index d3ffa2f..86b732a 100644 --- a/ptx/src/test/spirv_run/max.spvtxt +++ b/ptx/src/test/spirv_run/max.spvtxt @@ -18,6 +18,8 @@ %_ptr_Function_uint = OpTypePointer Function %uint %_ptr_Generic_uint = OpTypePointer Generic %uint %ulong_4 = OpConstant %ulong 4 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %1 = OpFunction %void None %31 %8 = OpFunctionParameter %ulong %9 = OpFunctionParameter %ulong @@ -39,9 +41,11 @@ %12 = OpLoad %uint %23 Aligned 4 OpStore %6 %12 %15 = OpLoad %ulong %4 - %22 = OpIAdd %ulong %15 %ulong_4 - %24 = OpConvertUToPtr %_ptr_Generic_uint %22 - %14 = OpLoad %uint %24 Aligned 4 + %24 = OpConvertUToPtr %_ptr_Generic_uint %15 + %38 = OpBitcast %_ptr_Generic_uchar %24 + %39 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %38 %ulong_4 + %22 = OpBitcast %_ptr_Generic_uint %39 + %14 = OpLoad %uint %22 Aligned 4 OpStore %7 %14 %17 = OpLoad %uint %6 %18 = OpLoad %uint %7 diff --git a/ptx/src/test/spirv_run/min.spvtxt b/ptx/src/test/spirv_run/min.spvtxt index de2e35e..a187376 100644 --- a/ptx/src/test/spirv_run/min.spvtxt +++ b/ptx/src/test/spirv_run/min.spvtxt @@ -18,6 +18,8 @@ %_ptr_Function_uint = OpTypePointer Function %uint %_ptr_Generic_uint = OpTypePointer Generic %uint %ulong_4 = OpConstant %ulong 4 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %1 = OpFunction %void None %31 %8 = OpFunctionParameter %ulong %9 = OpFunctionParameter %ulong @@ -39,9 +41,11 @@ %12 = OpLoad %uint %23 Aligned 4 OpStore %6 %12 %15 = OpLoad %ulong %4 - %22 = OpIAdd %ulong %15 %ulong_4 - %24 = OpConvertUToPtr %_ptr_Generic_uint %22 - %14 = OpLoad %uint 
%24 Aligned 4 + %24 = OpConvertUToPtr %_ptr_Generic_uint %15 + %38 = OpBitcast %_ptr_Generic_uchar %24 + %39 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %38 %ulong_4 + %22 = OpBitcast %_ptr_Generic_uint %39 + %14 = OpLoad %uint %22 Aligned 4 OpStore %7 %14 %17 = OpLoad %uint %6 %18 = OpLoad %uint %7 diff --git a/ptx/src/test/spirv_run/mul_ftz.spvtxt b/ptx/src/test/spirv_run/mul_ftz.spvtxt index ed268fb..e7a4a56 100644 --- a/ptx/src/test/spirv_run/mul_ftz.spvtxt +++ b/ptx/src/test/spirv_run/mul_ftz.spvtxt @@ -18,6 +18,8 @@ %_ptr_Function_float = OpTypePointer Function %float %_ptr_Generic_float = OpTypePointer Generic %float %ulong_4 = OpConstant %ulong 4 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %1 = OpFunction %void None %31 %8 = OpFunctionParameter %ulong %9 = OpFunctionParameter %ulong @@ -39,9 +41,11 @@ %12 = OpLoad %float %23 Aligned 4 OpStore %6 %12 %15 = OpLoad %ulong %4 - %22 = OpIAdd %ulong %15 %ulong_4 - %24 = OpConvertUToPtr %_ptr_Generic_float %22 - %14 = OpLoad %float %24 Aligned 4 + %24 = OpConvertUToPtr %_ptr_Generic_float %15 + %38 = OpBitcast %_ptr_Generic_uchar %24 + %39 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %38 %ulong_4 + %22 = OpBitcast %_ptr_Generic_float %39 + %14 = OpLoad %float %22 Aligned 4 OpStore %7 %14 %17 = OpLoad %float %6 %18 = OpLoad %float %7 diff --git a/ptx/src/test/spirv_run/mul_non_ftz.spvtxt b/ptx/src/test/spirv_run/mul_non_ftz.spvtxt index 436aca1..5326baa 100644 --- a/ptx/src/test/spirv_run/mul_non_ftz.spvtxt +++ b/ptx/src/test/spirv_run/mul_non_ftz.spvtxt @@ -18,6 +18,8 @@ %_ptr_Function_float = OpTypePointer Function %float %_ptr_Generic_float = OpTypePointer Generic %float %ulong_4 = OpConstant %ulong 4 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %1 = OpFunction %void None %31 %8 = OpFunctionParameter %ulong %9 = OpFunctionParameter %ulong @@ -39,9 +41,11 @@ %12 = OpLoad %float %23 Aligned 4 OpStore %6 %12 %15 = OpLoad %ulong %4 - %22 = OpIAdd 
%ulong %15 %ulong_4 - %24 = OpConvertUToPtr %_ptr_Generic_float %22 - %14 = OpLoad %float %24 Aligned 4 + %24 = OpConvertUToPtr %_ptr_Generic_float %15 + %38 = OpBitcast %_ptr_Generic_uchar %24 + %39 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %38 %ulong_4 + %22 = OpBitcast %_ptr_Generic_float %39 + %14 = OpLoad %float %22 Aligned 4 OpStore %7 %14 %17 = OpLoad %float %6 %18 = OpLoad %float %7 diff --git a/ptx/src/test/spirv_run/mul_wide.spvtxt b/ptx/src/test/spirv_run/mul_wide.spvtxt index 7ac81cf..e96a964 100644 --- a/ptx/src/test/spirv_run/mul_wide.spvtxt +++ b/ptx/src/test/spirv_run/mul_wide.spvtxt @@ -18,7 +18,9 @@ %_ptr_Function_uint = OpTypePointer Function %uint %_ptr_CrossWorkgroup_uint = OpTypePointer CrossWorkgroup %uint %ulong_4 = OpConstant %ulong 4 - %_struct_38 = OpTypeStruct %uint %uint + %uchar = OpTypeInt 8 0 +%_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar + %_struct_42 = OpTypeStruct %uint %uint %v2uint = OpTypeVector %uint 2 %_ptr_Generic_ulong = OpTypePointer Generic %ulong %1 = OpFunction %void None %33 @@ -43,17 +45,19 @@ %13 = OpLoad %uint %24 Aligned 4 OpStore %6 %13 %16 = OpLoad %ulong %4 - %23 = OpIAdd %ulong %16 %ulong_4 - %25 = OpConvertUToPtr %_ptr_CrossWorkgroup_uint %23 - %15 = OpLoad %uint %25 Aligned 4 + %25 = OpConvertUToPtr %_ptr_CrossWorkgroup_uint %16 + %40 = OpBitcast %_ptr_CrossWorkgroup_uchar %25 + %41 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %40 %ulong_4 + %23 = OpBitcast %_ptr_CrossWorkgroup_uint %41 + %15 = OpLoad %uint %23 Aligned 4 OpStore %7 %15 %18 = OpLoad %uint %6 %19 = OpLoad %uint %7 - %39 = OpSMulExtended %_struct_38 %18 %19 - %40 = OpCompositeExtract %uint %39 0 - %41 = OpCompositeExtract %uint %39 1 - %43 = OpCompositeConstruct %v2uint %40 %41 - %17 = OpBitcast %ulong %43 + %43 = OpSMulExtended %_struct_42 %18 %19 + %44 = OpCompositeExtract %uint %43 0 + %45 = OpCompositeExtract %uint %43 1 + %47 = OpCompositeConstruct %v2uint %44 %45 + %17 = OpBitcast %ulong %47 OpStore %8 %17 
%20 = OpLoad %ulong %5 %21 = OpLoad %ulong %8 diff --git a/ptx/src/test/spirv_run/or.spvtxt b/ptx/src/test/spirv_run/or.spvtxt index fef3f40..82db00c 100644 --- a/ptx/src/test/spirv_run/or.spvtxt +++ b/ptx/src/test/spirv_run/or.spvtxt @@ -16,6 +16,8 @@ %_ptr_Function_ulong = OpTypePointer Function %ulong %_ptr_Generic_ulong = OpTypePointer Generic %ulong %ulong_8 = OpConstant %ulong 8 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %1 = OpFunction %void None %34 %8 = OpFunctionParameter %ulong %9 = OpFunctionParameter %ulong @@ -37,9 +39,11 @@ %12 = OpLoad %ulong %23 Aligned 8 OpStore %6 %12 %15 = OpLoad %ulong %4 - %22 = OpIAdd %ulong %15 %ulong_8 - %24 = OpConvertUToPtr %_ptr_Generic_ulong %22 - %14 = OpLoad %ulong %24 Aligned 8 + %24 = OpConvertUToPtr %_ptr_Generic_ulong %15 + %39 = OpBitcast %_ptr_Generic_uchar %24 + %40 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %39 %ulong_8 + %22 = OpBitcast %_ptr_Generic_ulong %40 + %14 = OpLoad %ulong %22 Aligned 8 OpStore %7 %14 %17 = OpLoad %ulong %6 %18 = OpLoad %ulong %7 diff --git a/ptx/src/test/spirv_run/pred_not.spvtxt b/ptx/src/test/spirv_run/pred_not.spvtxt index 18fde05..644731b 100644 --- a/ptx/src/test/spirv_run/pred_not.spvtxt +++ b/ptx/src/test/spirv_run/pred_not.spvtxt @@ -18,6 +18,8 @@ %_ptr_Function_bool = OpTypePointer Function %bool %_ptr_Generic_ulong = OpTypePointer Generic %ulong %ulong_8 = OpConstant %ulong 8 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %true = OpConstantTrue %bool %false = OpConstantFalse %bool %ulong_1 = OpConstant %ulong 1 @@ -45,9 +47,11 @@ %18 = OpLoad %ulong %37 Aligned 8 OpStore %6 %18 %21 = OpLoad %ulong %4 - %34 = OpIAdd %ulong %21 %ulong_8 - %38 = OpConvertUToPtr %_ptr_Generic_ulong %34 - %20 = OpLoad %ulong %38 Aligned 8 + %38 = OpConvertUToPtr %_ptr_Generic_ulong %21 + %52 = OpBitcast %_ptr_Generic_uchar %38 + %53 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %52 %ulong_8 + %34 = OpBitcast 
%_ptr_Generic_ulong %53 + %20 = OpLoad %ulong %34 Aligned 8 OpStore %7 %20 %23 = OpLoad %ulong %6 %24 = OpLoad %ulong %7 diff --git a/ptx/src/test/spirv_run/reg_local.spvtxt b/ptx/src/test/spirv_run/reg_local.spvtxt index 7bb5bd9..a0b957a 100644 --- a/ptx/src/test/spirv_run/reg_local.spvtxt +++ b/ptx/src/test/spirv_run/reg_local.spvtxt @@ -26,6 +26,7 @@ %ulong_0 = OpConstant %ulong 0 %_ptr_Generic_uchar = OpTypePointer Generic %uchar %ulong_0_0 = OpConstant %ulong 0 +%_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar %1 = OpFunction %void None %37 %8 = OpFunctionParameter %ulong %9 = OpFunctionParameter %ulong @@ -48,10 +49,10 @@ %12 = OpCopyObject %ulong %24 OpStore %7 %12 %14 = OpLoad %ulong %7 - %26 = OpCopyObject %ulong %14 - %19 = OpIAdd %ulong %26 %ulong_1 - %27 = OpBitcast %_ptr_Generic_ulong %4 - OpStore %27 %19 Aligned 8 + %19 = OpIAdd %ulong %14 %ulong_1 + %26 = OpBitcast %_ptr_Generic_ulong %4 + %27 = OpCopyObject %ulong %19 + OpStore %26 %27 Aligned 8 %28 = OpBitcast %_ptr_Generic_ulong %4 %47 = OpBitcast %_ptr_Generic_uchar %28 %48 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %47 %ulong_0 @@ -61,9 +62,11 @@ OpStore %7 %15 %16 = OpLoad %ulong %6 %17 = OpLoad %ulong %7 - %23 = OpIAdd %ulong %16 %ulong_0_0 - %30 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %23 + %30 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %16 + %50 = OpBitcast %_ptr_CrossWorkgroup_uchar %30 + %51 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %50 %ulong_0_0 + %23 = OpBitcast %_ptr_CrossWorkgroup_ulong %51 %31 = OpCopyObject %ulong %17 - OpStore %30 %31 Aligned 8 + OpStore %23 %31 Aligned 8 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/rem.spvtxt b/ptx/src/test/spirv_run/rem.spvtxt index ce1d3e6..2184523 100644 --- a/ptx/src/test/spirv_run/rem.spvtxt +++ b/ptx/src/test/spirv_run/rem.spvtxt @@ -18,6 +18,8 @@ %_ptr_Function_uint = OpTypePointer Function %uint %_ptr_Generic_uint = OpTypePointer Generic %uint %ulong_4 = OpConstant %ulong 4 + %uchar = 
OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %1 = OpFunction %void None %31 %8 = OpFunctionParameter %ulong %9 = OpFunctionParameter %ulong @@ -39,9 +41,11 @@ %12 = OpLoad %uint %23 Aligned 4 OpStore %6 %12 %15 = OpLoad %ulong %4 - %22 = OpIAdd %ulong %15 %ulong_4 - %24 = OpConvertUToPtr %_ptr_Generic_uint %22 - %14 = OpLoad %uint %24 Aligned 4 + %24 = OpConvertUToPtr %_ptr_Generic_uint %15 + %38 = OpBitcast %_ptr_Generic_uchar %24 + %39 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %38 %ulong_4 + %22 = OpBitcast %_ptr_Generic_uint %39 + %14 = OpLoad %uint %22 Aligned 4 OpStore %7 %14 %17 = OpLoad %uint %6 %18 = OpLoad %uint %7 diff --git a/ptx/src/test/spirv_run/selp.spvtxt b/ptx/src/test/spirv_run/selp.spvtxt index 9798758..40c0bce 100644 --- a/ptx/src/test/spirv_run/selp.spvtxt +++ b/ptx/src/test/spirv_run/selp.spvtxt @@ -18,6 +18,8 @@ %_ptr_Function_ushort = OpTypePointer Function %ushort %_ptr_Generic_ushort = OpTypePointer Generic %ushort %ulong_2 = OpConstant %ulong 2 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %bool = OpTypeBool %false = OpConstantFalse %bool %1 = OpFunction %void None %32 @@ -41,9 +43,11 @@ %12 = OpLoad %ushort %24 Aligned 2 OpStore %6 %12 %15 = OpLoad %ulong %4 - %22 = OpIAdd %ulong %15 %ulong_2 - %25 = OpConvertUToPtr %_ptr_Generic_ushort %22 - %14 = OpLoad %ushort %25 Aligned 2 + %25 = OpConvertUToPtr %_ptr_Generic_ushort %15 + %39 = OpBitcast %_ptr_Generic_uchar %25 + %40 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %39 %ulong_2 + %22 = OpBitcast %_ptr_Generic_ushort %40 + %14 = OpLoad %ushort %22 Aligned 2 OpStore %7 %14 %17 = OpLoad %ushort %6 %18 = OpLoad %ushort %7 diff --git a/ptx/src/test/spirv_run/selp_true.spvtxt b/ptx/src/test/spirv_run/selp_true.spvtxt index f7038e0..81b3b5f 100644 --- a/ptx/src/test/spirv_run/selp_true.spvtxt +++ b/ptx/src/test/spirv_run/selp_true.spvtxt @@ -18,6 +18,8 @@ %_ptr_Function_ushort = OpTypePointer Function %ushort %_ptr_Generic_ushort = 
OpTypePointer Generic %ushort %ulong_2 = OpConstant %ulong 2 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %bool = OpTypeBool %true = OpConstantTrue %bool %1 = OpFunction %void None %32 @@ -41,9 +43,11 @@ %12 = OpLoad %ushort %24 Aligned 2 OpStore %6 %12 %15 = OpLoad %ulong %4 - %22 = OpIAdd %ulong %15 %ulong_2 - %25 = OpConvertUToPtr %_ptr_Generic_ushort %22 - %14 = OpLoad %ushort %25 Aligned 2 + %25 = OpConvertUToPtr %_ptr_Generic_ushort %15 + %39 = OpBitcast %_ptr_Generic_uchar %25 + %40 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %39 %ulong_2 + %22 = OpBitcast %_ptr_Generic_ushort %40 + %14 = OpLoad %ushort %22 Aligned 2 OpStore %7 %14 %17 = OpLoad %ushort %6 %18 = OpLoad %ushort %7 diff --git a/ptx/src/test/spirv_run/setp.spvtxt b/ptx/src/test/spirv_run/setp.spvtxt index c3129e3..5868881 100644 --- a/ptx/src/test/spirv_run/setp.spvtxt +++ b/ptx/src/test/spirv_run/setp.spvtxt @@ -18,6 +18,8 @@ %_ptr_Function_bool = OpTypePointer Function %bool %_ptr_Generic_ulong = OpTypePointer Generic %ulong %ulong_8 = OpConstant %ulong 8 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %ulong_1 = OpConstant %ulong 1 %ulong_2 = OpConstant %ulong 2 %1 = OpFunction %void None %43 @@ -43,9 +45,11 @@ %18 = OpLoad %ulong %35 Aligned 8 OpStore %6 %18 %21 = OpLoad %ulong %4 - %32 = OpIAdd %ulong %21 %ulong_8 - %36 = OpConvertUToPtr %_ptr_Generic_ulong %32 - %20 = OpLoad %ulong %36 Aligned 8 + %36 = OpConvertUToPtr %_ptr_Generic_ulong %21 + %50 = OpBitcast %_ptr_Generic_uchar %36 + %51 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %50 %ulong_8 + %32 = OpBitcast %_ptr_Generic_ulong %51 + %20 = OpLoad %ulong %32 Aligned 8 OpStore %7 %20 %23 = OpLoad %ulong %6 %24 = OpLoad %ulong %7 diff --git a/ptx/src/test/spirv_run/setp_gt.spvtxt b/ptx/src/test/spirv_run/setp_gt.spvtxt index 77f6546..e9783f5 100644 --- a/ptx/src/test/spirv_run/setp_gt.spvtxt +++ b/ptx/src/test/spirv_run/setp_gt.spvtxt @@ -20,6 +20,8 @@ %_ptr_Function_bool 
= OpTypePointer Function %bool %_ptr_Generic_float = OpTypePointer Generic %float %ulong_4 = OpConstant %ulong 4 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %1 = OpFunction %void None %43 %14 = OpFunctionParameter %ulong %15 = OpFunctionParameter %ulong @@ -43,9 +45,11 @@ %18 = OpLoad %float %35 Aligned 4 OpStore %6 %18 %21 = OpLoad %ulong %4 - %34 = OpIAdd %ulong %21 %ulong_4 - %36 = OpConvertUToPtr %_ptr_Generic_float %34 - %20 = OpLoad %float %36 Aligned 4 + %36 = OpConvertUToPtr %_ptr_Generic_float %21 + %52 = OpBitcast %_ptr_Generic_uchar %36 + %53 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %52 %ulong_4 + %34 = OpBitcast %_ptr_Generic_float %53 + %20 = OpLoad %float %34 Aligned 4 OpStore %7 %20 %23 = OpLoad %float %6 %24 = OpLoad %float %7 diff --git a/ptx/src/test/spirv_run/setp_leu.spvtxt b/ptx/src/test/spirv_run/setp_leu.spvtxt index f80880a..1d2d781 100644 --- a/ptx/src/test/spirv_run/setp_leu.spvtxt +++ b/ptx/src/test/spirv_run/setp_leu.spvtxt @@ -20,6 +20,8 @@ %_ptr_Function_bool = OpTypePointer Function %bool %_ptr_Generic_float = OpTypePointer Generic %float %ulong_4 = OpConstant %ulong 4 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %1 = OpFunction %void None %43 %14 = OpFunctionParameter %ulong %15 = OpFunctionParameter %ulong @@ -43,9 +45,11 @@ %18 = OpLoad %float %35 Aligned 4 OpStore %6 %18 %21 = OpLoad %ulong %4 - %34 = OpIAdd %ulong %21 %ulong_4 - %36 = OpConvertUToPtr %_ptr_Generic_float %34 - %20 = OpLoad %float %36 Aligned 4 + %36 = OpConvertUToPtr %_ptr_Generic_float %21 + %52 = OpBitcast %_ptr_Generic_uchar %36 + %53 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %52 %ulong_4 + %34 = OpBitcast %_ptr_Generic_float %53 + %20 = OpLoad %float %34 Aligned 4 OpStore %7 %20 %23 = OpLoad %float %6 %24 = OpLoad %float %7 diff --git a/ptx/src/test/spirv_run/setp_nan.spvtxt b/ptx/src/test/spirv_run/setp_nan.spvtxt index 4a9fe11..2ee333a 100644 --- 
a/ptx/src/test/spirv_run/setp_nan.spvtxt +++ b/ptx/src/test/spirv_run/setp_nan.spvtxt @@ -22,6 +22,8 @@ %_ptr_Function_bool = OpTypePointer Function %bool %_ptr_Generic_float = OpTypePointer Generic %float %ulong_4 = OpConstant %ulong 4 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %ulong_8 = OpConstant %ulong 8 %ulong_12 = OpConstant %ulong 12 %ulong_16 = OpConstant %ulong 16 @@ -69,45 +71,59 @@ %36 = OpLoad %float %116 Aligned 4 OpStore %6 %36 %39 = OpLoad %ulong %4 - %89 = OpIAdd %ulong %39 %ulong_4 - %117 = OpConvertUToPtr %_ptr_Generic_float %89 - %38 = OpLoad %float %117 Aligned 4 + %117 = OpConvertUToPtr %_ptr_Generic_float %39 + %144 = OpBitcast %_ptr_Generic_uchar %117 + %145 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %144 %ulong_4 + %89 = OpBitcast %_ptr_Generic_float %145 + %38 = OpLoad %float %89 Aligned 4 OpStore %7 %38 %41 = OpLoad %ulong %4 - %91 = OpIAdd %ulong %41 %ulong_8 - %118 = OpConvertUToPtr %_ptr_Generic_float %91 - %40 = OpLoad %float %118 Aligned 4 + %118 = OpConvertUToPtr %_ptr_Generic_float %41 + %146 = OpBitcast %_ptr_Generic_uchar %118 + %147 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %146 %ulong_8 + %91 = OpBitcast %_ptr_Generic_float %147 + %40 = OpLoad %float %91 Aligned 4 OpStore %8 %40 %43 = OpLoad %ulong %4 - %93 = OpIAdd %ulong %43 %ulong_12 - %119 = OpConvertUToPtr %_ptr_Generic_float %93 - %42 = OpLoad %float %119 Aligned 4 + %119 = OpConvertUToPtr %_ptr_Generic_float %43 + %148 = OpBitcast %_ptr_Generic_uchar %119 + %149 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %148 %ulong_12 + %93 = OpBitcast %_ptr_Generic_float %149 + %42 = OpLoad %float %93 Aligned 4 OpStore %9 %42 %45 = OpLoad %ulong %4 - %95 = OpIAdd %ulong %45 %ulong_16 - %120 = OpConvertUToPtr %_ptr_Generic_float %95 - %44 = OpLoad %float %120 Aligned 4 + %120 = OpConvertUToPtr %_ptr_Generic_float %45 + %150 = OpBitcast %_ptr_Generic_uchar %120 + %151 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %150 %ulong_16 + %95 = 
OpBitcast %_ptr_Generic_float %151 + %44 = OpLoad %float %95 Aligned 4 OpStore %10 %44 %47 = OpLoad %ulong %4 - %97 = OpIAdd %ulong %47 %ulong_20 - %121 = OpConvertUToPtr %_ptr_Generic_float %97 - %46 = OpLoad %float %121 Aligned 4 + %121 = OpConvertUToPtr %_ptr_Generic_float %47 + %152 = OpBitcast %_ptr_Generic_uchar %121 + %153 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %152 %ulong_20 + %97 = OpBitcast %_ptr_Generic_float %153 + %46 = OpLoad %float %97 Aligned 4 OpStore %11 %46 %49 = OpLoad %ulong %4 - %99 = OpIAdd %ulong %49 %ulong_24 - %122 = OpConvertUToPtr %_ptr_Generic_float %99 - %48 = OpLoad %float %122 Aligned 4 + %122 = OpConvertUToPtr %_ptr_Generic_float %49 + %154 = OpBitcast %_ptr_Generic_uchar %122 + %155 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %154 %ulong_24 + %99 = OpBitcast %_ptr_Generic_float %155 + %48 = OpLoad %float %99 Aligned 4 OpStore %12 %48 %51 = OpLoad %ulong %4 - %101 = OpIAdd %ulong %51 %ulong_28 - %123 = OpConvertUToPtr %_ptr_Generic_float %101 - %50 = OpLoad %float %123 Aligned 4 + %123 = OpConvertUToPtr %_ptr_Generic_float %51 + %156 = OpBitcast %_ptr_Generic_uchar %123 + %157 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %156 %ulong_28 + %101 = OpBitcast %_ptr_Generic_float %157 + %50 = OpLoad %float %101 Aligned 4 OpStore %13 %50 %53 = OpLoad %float %6 %54 = OpLoad %float %7 - %142 = OpIsNan %bool %53 - %143 = OpIsNan %bool %54 - %52 = OpLogicalOr %bool %142 %143 + %158 = OpIsNan %bool %53 + %159 = OpIsNan %bool %54 + %52 = OpLogicalOr %bool %158 %159 OpStore %15 %52 %55 = OpLoad %bool %15 OpBranchConditional %55 %16 %17 @@ -129,9 +145,9 @@ OpStore %124 %60 Aligned 4 %62 = OpLoad %float %8 %63 = OpLoad %float %9 - %145 = OpIsNan %bool %62 - %146 = OpIsNan %bool %63 - %61 = OpLogicalOr %bool %145 %146 + %161 = OpIsNan %bool %62 + %162 = OpIsNan %bool %63 + %61 = OpLogicalOr %bool %161 %162 OpStore %15 %61 %64 = OpLoad %bool %15 OpBranchConditional %64 %20 %21 @@ -149,14 +165,16 @@ %23 = OpLabel %68 = OpLoad %ulong %5 
%69 = OpLoad %uint %14 - %107 = OpIAdd %ulong %68 %ulong_4_0 - %125 = OpConvertUToPtr %_ptr_Generic_uint %107 - OpStore %125 %69 Aligned 4 + %125 = OpConvertUToPtr %_ptr_Generic_uint %68 + %163 = OpBitcast %_ptr_Generic_uchar %125 + %164 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %163 %ulong_4_0 + %107 = OpBitcast %_ptr_Generic_uint %164 + OpStore %107 %69 Aligned 4 %71 = OpLoad %float %10 %72 = OpLoad %float %11 - %147 = OpIsNan %bool %71 - %148 = OpIsNan %bool %72 - %70 = OpLogicalOr %bool %147 %148 + %165 = OpIsNan %bool %71 + %166 = OpIsNan %bool %72 + %70 = OpLogicalOr %bool %165 %166 OpStore %15 %70 %73 = OpLoad %bool %15 OpBranchConditional %73 %24 %25 @@ -174,14 +192,16 @@ %27 = OpLabel %77 = OpLoad %ulong %5 %78 = OpLoad %uint %14 - %111 = OpIAdd %ulong %77 %ulong_8_0 - %126 = OpConvertUToPtr %_ptr_Generic_uint %111 - OpStore %126 %78 Aligned 4 + %126 = OpConvertUToPtr %_ptr_Generic_uint %77 + %167 = OpBitcast %_ptr_Generic_uchar %126 + %168 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %167 %ulong_8_0 + %111 = OpBitcast %_ptr_Generic_uint %168 + OpStore %111 %78 Aligned 4 %80 = OpLoad %float %12 %81 = OpLoad %float %13 - %149 = OpIsNan %bool %80 - %150 = OpIsNan %bool %81 - %79 = OpLogicalOr %bool %149 %150 + %169 = OpIsNan %bool %80 + %170 = OpIsNan %bool %81 + %79 = OpLogicalOr %bool %169 %170 OpStore %15 %79 %82 = OpLoad %bool %15 OpBranchConditional %82 %28 %29 @@ -199,8 +219,10 @@ %31 = OpLabel %86 = OpLoad %ulong %5 %87 = OpLoad %uint %14 - %115 = OpIAdd %ulong %86 %ulong_12_0 - %127 = OpConvertUToPtr %_ptr_Generic_uint %115 - OpStore %127 %87 Aligned 4 + %127 = OpConvertUToPtr %_ptr_Generic_uint %86 + %171 = OpBitcast %_ptr_Generic_uchar %127 + %172 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %171 %ulong_12_0 + %115 = OpBitcast %_ptr_Generic_uint %172 + OpStore %115 %87 Aligned 4 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/setp_num.spvtxt b/ptx/src/test/spirv_run/setp_num.spvtxt index 3ac6eab..c576a50 100644 --- 
a/ptx/src/test/spirv_run/setp_num.spvtxt +++ b/ptx/src/test/spirv_run/setp_num.spvtxt @@ -22,6 +22,8 @@ %_ptr_Function_bool = OpTypePointer Function %bool %_ptr_Generic_float = OpTypePointer Generic %float %ulong_4 = OpConstant %ulong 4 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %ulong_8 = OpConstant %ulong 8 %ulong_12 = OpConstant %ulong 12 %ulong_16 = OpConstant %ulong 16 @@ -77,46 +79,60 @@ %36 = OpLoad %float %116 Aligned 4 OpStore %6 %36 %39 = OpLoad %ulong %4 - %89 = OpIAdd %ulong %39 %ulong_4 - %117 = OpConvertUToPtr %_ptr_Generic_float %89 - %38 = OpLoad %float %117 Aligned 4 + %117 = OpConvertUToPtr %_ptr_Generic_float %39 + %144 = OpBitcast %_ptr_Generic_uchar %117 + %145 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %144 %ulong_4 + %89 = OpBitcast %_ptr_Generic_float %145 + %38 = OpLoad %float %89 Aligned 4 OpStore %7 %38 %41 = OpLoad %ulong %4 - %91 = OpIAdd %ulong %41 %ulong_8 - %118 = OpConvertUToPtr %_ptr_Generic_float %91 - %40 = OpLoad %float %118 Aligned 4 + %118 = OpConvertUToPtr %_ptr_Generic_float %41 + %146 = OpBitcast %_ptr_Generic_uchar %118 + %147 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %146 %ulong_8 + %91 = OpBitcast %_ptr_Generic_float %147 + %40 = OpLoad %float %91 Aligned 4 OpStore %8 %40 %43 = OpLoad %ulong %4 - %93 = OpIAdd %ulong %43 %ulong_12 - %119 = OpConvertUToPtr %_ptr_Generic_float %93 - %42 = OpLoad %float %119 Aligned 4 + %119 = OpConvertUToPtr %_ptr_Generic_float %43 + %148 = OpBitcast %_ptr_Generic_uchar %119 + %149 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %148 %ulong_12 + %93 = OpBitcast %_ptr_Generic_float %149 + %42 = OpLoad %float %93 Aligned 4 OpStore %9 %42 %45 = OpLoad %ulong %4 - %95 = OpIAdd %ulong %45 %ulong_16 - %120 = OpConvertUToPtr %_ptr_Generic_float %95 - %44 = OpLoad %float %120 Aligned 4 + %120 = OpConvertUToPtr %_ptr_Generic_float %45 + %150 = OpBitcast %_ptr_Generic_uchar %120 + %151 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %150 %ulong_16 + %95 = 
OpBitcast %_ptr_Generic_float %151 + %44 = OpLoad %float %95 Aligned 4 OpStore %10 %44 %47 = OpLoad %ulong %4 - %97 = OpIAdd %ulong %47 %ulong_20 - %121 = OpConvertUToPtr %_ptr_Generic_float %97 - %46 = OpLoad %float %121 Aligned 4 + %121 = OpConvertUToPtr %_ptr_Generic_float %47 + %152 = OpBitcast %_ptr_Generic_uchar %121 + %153 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %152 %ulong_20 + %97 = OpBitcast %_ptr_Generic_float %153 + %46 = OpLoad %float %97 Aligned 4 OpStore %11 %46 %49 = OpLoad %ulong %4 - %99 = OpIAdd %ulong %49 %ulong_24 - %122 = OpConvertUToPtr %_ptr_Generic_float %99 - %48 = OpLoad %float %122 Aligned 4 + %122 = OpConvertUToPtr %_ptr_Generic_float %49 + %154 = OpBitcast %_ptr_Generic_uchar %122 + %155 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %154 %ulong_24 + %99 = OpBitcast %_ptr_Generic_float %155 + %48 = OpLoad %float %99 Aligned 4 OpStore %12 %48 %51 = OpLoad %ulong %4 - %101 = OpIAdd %ulong %51 %ulong_28 - %123 = OpConvertUToPtr %_ptr_Generic_float %101 - %50 = OpLoad %float %123 Aligned 4 + %123 = OpConvertUToPtr %_ptr_Generic_float %51 + %156 = OpBitcast %_ptr_Generic_uchar %123 + %157 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %156 %ulong_28 + %101 = OpBitcast %_ptr_Generic_float %157 + %50 = OpLoad %float %101 Aligned 4 OpStore %13 %50 %53 = OpLoad %float %6 %54 = OpLoad %float %7 - %142 = OpIsNan %bool %53 - %143 = OpIsNan %bool %54 - %144 = OpLogicalOr %bool %142 %143 - %52 = OpSelect %bool %144 %false %true + %158 = OpIsNan %bool %53 + %159 = OpIsNan %bool %54 + %160 = OpLogicalOr %bool %158 %159 + %52 = OpSelect %bool %160 %false %true OpStore %15 %52 %55 = OpLoad %bool %15 OpBranchConditional %55 %16 %17 @@ -138,10 +154,10 @@ OpStore %124 %60 Aligned 4 %62 = OpLoad %float %8 %63 = OpLoad %float %9 - %148 = OpIsNan %bool %62 - %149 = OpIsNan %bool %63 - %150 = OpLogicalOr %bool %148 %149 - %61 = OpSelect %bool %150 %false_0 %true_0 + %164 = OpIsNan %bool %62 + %165 = OpIsNan %bool %63 + %166 = OpLogicalOr %bool %164 %165 
+ %61 = OpSelect %bool %166 %false_0 %true_0 OpStore %15 %61 %64 = OpLoad %bool %15 OpBranchConditional %64 %20 %21 @@ -159,15 +175,17 @@ %23 = OpLabel %68 = OpLoad %ulong %5 %69 = OpLoad %uint %14 - %107 = OpIAdd %ulong %68 %ulong_4_0 - %125 = OpConvertUToPtr %_ptr_Generic_uint %107 - OpStore %125 %69 Aligned 4 + %125 = OpConvertUToPtr %_ptr_Generic_uint %68 + %169 = OpBitcast %_ptr_Generic_uchar %125 + %170 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %169 %ulong_4_0 + %107 = OpBitcast %_ptr_Generic_uint %170 + OpStore %107 %69 Aligned 4 %71 = OpLoad %float %10 %72 = OpLoad %float %11 - %153 = OpIsNan %bool %71 - %154 = OpIsNan %bool %72 - %155 = OpLogicalOr %bool %153 %154 - %70 = OpSelect %bool %155 %false_1 %true_1 + %171 = OpIsNan %bool %71 + %172 = OpIsNan %bool %72 + %173 = OpLogicalOr %bool %171 %172 + %70 = OpSelect %bool %173 %false_1 %true_1 OpStore %15 %70 %73 = OpLoad %bool %15 OpBranchConditional %73 %24 %25 @@ -185,15 +203,17 @@ %27 = OpLabel %77 = OpLoad %ulong %5 %78 = OpLoad %uint %14 - %111 = OpIAdd %ulong %77 %ulong_8_0 - %126 = OpConvertUToPtr %_ptr_Generic_uint %111 - OpStore %126 %78 Aligned 4 + %126 = OpConvertUToPtr %_ptr_Generic_uint %77 + %176 = OpBitcast %_ptr_Generic_uchar %126 + %177 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %176 %ulong_8_0 + %111 = OpBitcast %_ptr_Generic_uint %177 + OpStore %111 %78 Aligned 4 %80 = OpLoad %float %12 %81 = OpLoad %float %13 - %158 = OpIsNan %bool %80 - %159 = OpIsNan %bool %81 - %160 = OpLogicalOr %bool %158 %159 - %79 = OpSelect %bool %160 %false_2 %true_2 + %178 = OpIsNan %bool %80 + %179 = OpIsNan %bool %81 + %180 = OpLogicalOr %bool %178 %179 + %79 = OpSelect %bool %180 %false_2 %true_2 OpStore %15 %79 %82 = OpLoad %bool %15 OpBranchConditional %82 %28 %29 @@ -211,8 +231,10 @@ %31 = OpLabel %86 = OpLoad %ulong %5 %87 = OpLoad %uint %14 - %115 = OpIAdd %ulong %86 %ulong_12_0 - %127 = OpConvertUToPtr %_ptr_Generic_uint %115 - OpStore %127 %87 Aligned 4 + %127 = OpConvertUToPtr 
%_ptr_Generic_uint %86 + %183 = OpBitcast %_ptr_Generic_uchar %127 + %184 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %183 %ulong_12_0 + %115 = OpBitcast %_ptr_Generic_uint %184 + OpStore %115 %87 Aligned 4 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/shared_ptr_32.spvtxt b/ptx/src/test/spirv_run/shared_ptr_32.spvtxt index 2ea964c..1b2e3dd 100644 --- a/ptx/src/test/spirv_run/shared_ptr_32.spvtxt +++ b/ptx/src/test/spirv_run/shared_ptr_32.spvtxt @@ -24,7 +24,8 @@ %_ptr_Function_uint = OpTypePointer Function %uint %_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong %_ptr_Workgroup_ulong = OpTypePointer Workgroup %ulong - %uint_0 = OpConstant %uint 0 + %ulong_0 = OpConstant %ulong 0 +%_ptr_Workgroup_uchar = OpTypePointer Workgroup %uchar %1 = OpFunction %void None %40 %10 = OpFunctionParameter %ulong %11 = OpFunctionParameter %ulong @@ -54,9 +55,11 @@ %27 = OpConvertUToPtr %_ptr_Workgroup_ulong %17 OpStore %27 %18 Aligned 8 %20 = OpLoad %uint %7 - %24 = OpIAdd %uint %20 %uint_0 - %28 = OpConvertUToPtr %_ptr_Workgroup_ulong %24 - %19 = OpLoad %ulong %28 Aligned 8 + %28 = OpConvertUToPtr %_ptr_Workgroup_ulong %20 + %46 = OpBitcast %_ptr_Workgroup_uchar %28 + %47 = OpInBoundsPtrAccessChain %_ptr_Workgroup_uchar %46 %ulong_0 + %24 = OpBitcast %_ptr_Workgroup_ulong %47 + %19 = OpLoad %ulong %24 Aligned 8 OpStore %9 %19 %21 = OpLoad %ulong %6 %22 = OpLoad %ulong %9 diff --git a/ptx/src/test/spirv_run/shared_ptr_take_address.spvtxt b/ptx/src/test/spirv_run/shared_ptr_take_address.spvtxt index 19d5a5a..fd4f893 100644 --- a/ptx/src/test/spirv_run/shared_ptr_take_address.spvtxt +++ b/ptx/src/test/spirv_run/shared_ptr_take_address.spvtxt @@ -7,27 +7,24 @@ OpCapability Int64 OpCapability Float16 OpCapability Float64 - %33 = OpExtInstImport "OpenCL.std" + %31 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL OpEntryPoint Kernel %2 "shared_ptr_take_address" %1 OpDecorate %1 Alignment 4 %void = OpTypeVoid %uchar = OpTypeInt 8 0 
%_ptr_Workgroup_uchar = OpTypePointer Workgroup %uchar -%_ptr_Workgroup__ptr_Workgroup_uchar = OpTypePointer Workgroup %_ptr_Workgroup_uchar - %1 = OpVariable %_ptr_Workgroup__ptr_Workgroup_uchar Workgroup + %1 = OpVariable %_ptr_Workgroup_uchar Workgroup %ulong = OpTypeInt 64 0 - %39 = OpTypeFunction %void %ulong %ulong %_ptr_Workgroup_uchar -%_ptr_Function__ptr_Workgroup_uchar = OpTypePointer Function %_ptr_Workgroup_uchar + %36 = OpTypeFunction %void %ulong %ulong %_ptr_Workgroup_uchar %_ptr_Function_ulong = OpTypePointer Function %ulong %_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong %_ptr_Workgroup_ulong = OpTypePointer Workgroup %ulong - %2 = OpFunction %void None %39 + %2 = OpFunction %void None %36 %10 = OpFunctionParameter %ulong %11 = OpFunctionParameter %ulong - %31 = OpFunctionParameter %_ptr_Workgroup_uchar - %40 = OpLabel - %32 = OpVariable %_ptr_Function__ptr_Workgroup_uchar Function + %30 = OpFunctionParameter %_ptr_Workgroup_uchar + %28 = OpLabel %3 = OpVariable %_ptr_Function_ulong Function %4 = OpVariable %_ptr_Function_ulong Function %5 = OpVariable %_ptr_Function_ulong Function @@ -35,34 +32,30 @@ %7 = OpVariable %_ptr_Function_ulong Function %8 = OpVariable %_ptr_Function_ulong Function %9 = OpVariable %_ptr_Function_ulong Function - OpStore %32 %31 - OpBranch %29 - %29 = OpLabel OpStore %3 %10 OpStore %4 %11 %12 = OpLoad %ulong %3 Aligned 8 OpStore %5 %12 %13 = OpLoad %ulong %4 Aligned 8 OpStore %6 %13 - %15 = OpLoad %_ptr_Workgroup_uchar %32 - %24 = OpConvertPtrToU %ulong %15 - %14 = OpCopyObject %ulong %24 + %23 = OpConvertPtrToU %ulong %30 + %14 = OpCopyObject %ulong %23 OpStore %7 %14 - %17 = OpLoad %ulong %5 - %25 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %17 - %16 = OpLoad %ulong %25 Aligned 8 - OpStore %8 %16 - %18 = OpLoad %ulong %7 - %19 = OpLoad %ulong %8 - %26 = OpConvertUToPtr %_ptr_Workgroup_ulong %18 - OpStore %26 %19 Aligned 8 - %21 = OpLoad %ulong %7 - %27 = OpConvertUToPtr %_ptr_Workgroup_ulong %21 - %20 = 
OpLoad %ulong %27 Aligned 8 - OpStore %9 %20 - %22 = OpLoad %ulong %6 - %23 = OpLoad %ulong %9 - %28 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %22 - OpStore %28 %23 Aligned 8 + %16 = OpLoad %ulong %5 + %24 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %16 + %15 = OpLoad %ulong %24 Aligned 8 + OpStore %8 %15 + %17 = OpLoad %ulong %7 + %18 = OpLoad %ulong %8 + %25 = OpConvertUToPtr %_ptr_Workgroup_ulong %17 + OpStore %25 %18 Aligned 8 + %20 = OpLoad %ulong %7 + %26 = OpConvertUToPtr %_ptr_Workgroup_ulong %20 + %19 = OpLoad %ulong %26 Aligned 8 + OpStore %9 %19 + %21 = OpLoad %ulong %6 + %22 = OpLoad %ulong %9 + %27 = OpConvertUToPtr %_ptr_CrossWorkgroup_ulong %21 + OpStore %27 %22 Aligned 8 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/stateful_ld_st_ntid.spvtxt b/ptx/src/test/spirv_run/stateful_ld_st_ntid.spvtxt index 33812f6..cf0d86e 100644 --- a/ptx/src/test/spirv_run/stateful_ld_st_ntid.spvtxt +++ b/ptx/src/test/spirv_run/stateful_ld_st_ntid.spvtxt @@ -7,7 +7,7 @@ OpCapability Int64 OpCapability Float16 OpCapability Float64 - %50 = OpExtInstImport "OpenCL.std" + %54 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL OpEntryPoint Kernel %1 "stateful_ld_st_ntid" %gl_LocalInvocationID OpDecorate %gl_LocalInvocationID BuiltIn LocalInvocationId @@ -18,34 +18,34 @@ %gl_LocalInvocationID = OpVariable %_ptr_Input_v3ulong Input %uchar = OpTypeInt 8 0 %_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar - %57 = OpTypeFunction %void %_ptr_CrossWorkgroup_uchar %_ptr_CrossWorkgroup_uchar + %61 = OpTypeFunction %void %_ptr_CrossWorkgroup_uchar %_ptr_CrossWorkgroup_uchar %_ptr_Function__ptr_CrossWorkgroup_uchar = OpTypePointer Function %_ptr_CrossWorkgroup_uchar %uint = OpTypeInt 32 0 %_ptr_Function_uint = OpTypePointer Function %uint %_ptr_Function_ulong = OpTypePointer Function %ulong %_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong - %1 = OpFunction %void None %57 + %1 = OpFunction %void None %61 %20 = 
OpFunctionParameter %_ptr_CrossWorkgroup_uchar %21 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar - %48 = OpLabel - %2 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function - %3 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %52 = OpLabel + %12 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %13 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function %10 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function %11 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function %6 = OpVariable %_ptr_Function_uint Function %7 = OpVariable %_ptr_Function_ulong Function %8 = OpVariable %_ptr_Function_ulong Function - OpStore %2 %20 - OpStore %3 %21 - %13 = OpBitcast %_ptr_Function_ulong %2 - %44 = OpLoad %ulong %13 Aligned 8 - %12 = OpCopyObject %ulong %44 - %22 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %12 + OpStore %12 %20 + OpStore %13 %21 + %45 = OpBitcast %_ptr_Function_ulong %12 + %44 = OpLoad %ulong %45 Aligned 8 + %14 = OpCopyObject %ulong %44 + %22 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %14 OpStore %10 %22 - %15 = OpBitcast %_ptr_Function_ulong %3 - %45 = OpLoad %ulong %15 Aligned 8 - %14 = OpCopyObject %ulong %45 - %23 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %14 + %47 = OpBitcast %_ptr_Function_ulong %13 + %46 = OpLoad %ulong %47 Aligned 8 + %15 = OpCopyObject %ulong %46 + %23 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %15 OpStore %11 %23 %24 = OpLoad %_ptr_CrossWorkgroup_uchar %10 %17 = OpConvertPtrToU %ulong %24 @@ -57,35 +57,37 @@ %18 = OpCopyObject %ulong %19 %27 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %18 OpStore %11 %27 - %62 = OpLoad %v3ulong %gl_LocalInvocationID - %43 = OpCompositeExtract %ulong %62 0 - %63 = OpBitcast %ulong %43 - %29 = OpUConvert %uint %63 + %66 = OpLoad %v3ulong %gl_LocalInvocationID + %43 = OpCompositeExtract %ulong %66 0 + %67 = OpBitcast %ulong %43 + %29 = OpUConvert %uint %67 %28 = OpCopyObject %uint %29 OpStore %6 %28 %31 = OpLoad %uint %6 - %64 = 
OpBitcast %uint %31 - %30 = OpUConvert %ulong %64 + %68 = OpBitcast %uint %31 + %30 = OpUConvert %ulong %68 OpStore %7 %30 %33 = OpLoad %_ptr_CrossWorkgroup_uchar %10 %34 = OpLoad %ulong %7 - %65 = OpBitcast %_ptr_CrossWorkgroup_uchar %33 - %66 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %65 %34 - %32 = OpBitcast %_ptr_CrossWorkgroup_uchar %66 + %48 = OpCopyObject %ulong %34 + %69 = OpBitcast %_ptr_CrossWorkgroup_uchar %33 + %70 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %69 %48 + %32 = OpBitcast %_ptr_CrossWorkgroup_uchar %70 OpStore %10 %32 %36 = OpLoad %_ptr_CrossWorkgroup_uchar %11 %37 = OpLoad %ulong %7 - %67 = OpBitcast %_ptr_CrossWorkgroup_uchar %36 - %68 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %67 %37 - %35 = OpBitcast %_ptr_CrossWorkgroup_uchar %68 + %49 = OpCopyObject %ulong %37 + %71 = OpBitcast %_ptr_CrossWorkgroup_uchar %36 + %72 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %71 %49 + %35 = OpBitcast %_ptr_CrossWorkgroup_uchar %72 OpStore %11 %35 %39 = OpLoad %_ptr_CrossWorkgroup_uchar %10 - %46 = OpBitcast %_ptr_CrossWorkgroup_ulong %39 - %38 = OpLoad %ulong %46 Aligned 8 + %50 = OpBitcast %_ptr_CrossWorkgroup_ulong %39 + %38 = OpLoad %ulong %50 Aligned 8 OpStore %8 %38 %40 = OpLoad %_ptr_CrossWorkgroup_uchar %11 %41 = OpLoad %ulong %8 - %47 = OpBitcast %_ptr_CrossWorkgroup_ulong %40 - OpStore %47 %41 Aligned 8 + %51 = OpBitcast %_ptr_CrossWorkgroup_ulong %40 + OpStore %51 %41 Aligned 8 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/stateful_ld_st_ntid_chain.spvtxt b/ptx/src/test/spirv_run/stateful_ld_st_ntid_chain.spvtxt index cb77d14..97bf000 100644 --- a/ptx/src/test/spirv_run/stateful_ld_st_ntid_chain.spvtxt +++ b/ptx/src/test/spirv_run/stateful_ld_st_ntid_chain.spvtxt @@ -7,7 +7,7 @@ OpCapability Int64 OpCapability Float16 OpCapability Float64 - %58 = OpExtInstImport "OpenCL.std" + %62 = OpExtInstImport "OpenCL.std" OpMemoryModel Physical64 OpenCL OpEntryPoint Kernel %1 
"stateful_ld_st_ntid_chain" %gl_LocalInvocationID OpDecorate %gl_LocalInvocationID BuiltIn LocalInvocationId @@ -18,18 +18,18 @@ %gl_LocalInvocationID = OpVariable %_ptr_Input_v3ulong Input %uchar = OpTypeInt 8 0 %_ptr_CrossWorkgroup_uchar = OpTypePointer CrossWorkgroup %uchar - %65 = OpTypeFunction %void %_ptr_CrossWorkgroup_uchar %_ptr_CrossWorkgroup_uchar + %69 = OpTypeFunction %void %_ptr_CrossWorkgroup_uchar %_ptr_CrossWorkgroup_uchar %_ptr_Function__ptr_CrossWorkgroup_uchar = OpTypePointer Function %_ptr_CrossWorkgroup_uchar %uint = OpTypeInt 32 0 %_ptr_Function_uint = OpTypePointer Function %uint %_ptr_Function_ulong = OpTypePointer Function %ulong %_ptr_CrossWorkgroup_ulong = OpTypePointer CrossWorkgroup %ulong - %1 = OpFunction %void None %65 + %1 = OpFunction %void None %69 %28 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar %29 = OpFunctionParameter %_ptr_CrossWorkgroup_uchar - %56 = OpLabel - %2 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function - %3 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %60 = OpLabel + %20 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function + %21 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function %14 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function %15 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function %16 = OpVariable %_ptr_Function__ptr_CrossWorkgroup_uchar Function @@ -39,17 +39,17 @@ %10 = OpVariable %_ptr_Function_uint Function %11 = OpVariable %_ptr_Function_ulong Function %12 = OpVariable %_ptr_Function_ulong Function - OpStore %2 %28 - OpStore %3 %29 - %21 = OpBitcast %_ptr_Function_ulong %2 - %52 = OpLoad %ulong %21 Aligned 8 - %20 = OpCopyObject %ulong %52 - %30 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %20 + OpStore %20 %28 + OpStore %21 %29 + %53 = OpBitcast %_ptr_Function_ulong %20 + %52 = OpLoad %ulong %53 Aligned 8 + %22 = OpCopyObject %ulong %52 + %30 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %22 OpStore %14 %30 - %23 = 
OpBitcast %_ptr_Function_ulong %3 - %53 = OpLoad %ulong %23 Aligned 8 - %22 = OpCopyObject %ulong %53 - %31 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %22 + %55 = OpBitcast %_ptr_Function_ulong %21 + %54 = OpLoad %ulong %55 Aligned 8 + %23 = OpCopyObject %ulong %54 + %31 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %23 OpStore %17 %31 %32 = OpLoad %_ptr_CrossWorkgroup_uchar %14 %25 = OpConvertPtrToU %ulong %32 @@ -61,35 +61,37 @@ %26 = OpCopyObject %ulong %27 %35 = OpConvertUToPtr %_ptr_CrossWorkgroup_uchar %26 OpStore %18 %35 - %70 = OpLoad %v3ulong %gl_LocalInvocationID - %51 = OpCompositeExtract %ulong %70 0 - %71 = OpBitcast %ulong %51 - %37 = OpUConvert %uint %71 + %74 = OpLoad %v3ulong %gl_LocalInvocationID + %51 = OpCompositeExtract %ulong %74 0 + %75 = OpBitcast %ulong %51 + %37 = OpUConvert %uint %75 %36 = OpCopyObject %uint %37 OpStore %10 %36 %39 = OpLoad %uint %10 - %72 = OpBitcast %uint %39 - %38 = OpUConvert %ulong %72 + %76 = OpBitcast %uint %39 + %38 = OpUConvert %ulong %76 OpStore %11 %38 %41 = OpLoad %_ptr_CrossWorkgroup_uchar %15 %42 = OpLoad %ulong %11 - %73 = OpBitcast %_ptr_CrossWorkgroup_uchar %41 - %74 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %73 %42 - %40 = OpBitcast %_ptr_CrossWorkgroup_uchar %74 + %56 = OpCopyObject %ulong %42 + %77 = OpBitcast %_ptr_CrossWorkgroup_uchar %41 + %78 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %77 %56 + %40 = OpBitcast %_ptr_CrossWorkgroup_uchar %78 OpStore %16 %40 %44 = OpLoad %_ptr_CrossWorkgroup_uchar %18 %45 = OpLoad %ulong %11 - %75 = OpBitcast %_ptr_CrossWorkgroup_uchar %44 - %76 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %75 %45 - %43 = OpBitcast %_ptr_CrossWorkgroup_uchar %76 + %57 = OpCopyObject %ulong %45 + %79 = OpBitcast %_ptr_CrossWorkgroup_uchar %44 + %80 = OpInBoundsPtrAccessChain %_ptr_CrossWorkgroup_uchar %79 %57 + %43 = OpBitcast %_ptr_CrossWorkgroup_uchar %80 OpStore %19 %43 %47 = OpLoad %_ptr_CrossWorkgroup_uchar %16 - %54 = OpBitcast 
%_ptr_CrossWorkgroup_ulong %47 - %46 = OpLoad %ulong %54 Aligned 8 + %58 = OpBitcast %_ptr_CrossWorkgroup_ulong %47 + %46 = OpLoad %ulong %58 Aligned 8 OpStore %12 %46 %48 = OpLoad %_ptr_CrossWorkgroup_uchar %19 %49 = OpLoad %ulong %12 - %55 = OpBitcast %_ptr_CrossWorkgroup_ulong %48 - OpStore %55 %49 Aligned 8 + %59 = OpBitcast %_ptr_CrossWorkgroup_ulong %48 + OpStore %59 %49 Aligned 8 OpReturn OpFunctionEnd diff --git a/ptx/src/test/spirv_run/vector.spvtxt b/ptx/src/test/spirv_run/vector.spvtxt index ecf2858..8253bf9 100644 --- a/ptx/src/test/spirv_run/vector.spvtxt +++ b/ptx/src/test/spirv_run/vector.spvtxt @@ -25,8 +25,8 @@ %1 = OpFunction %v2uint None %55 %7 = OpFunctionParameter %v2uint %24 = OpLabel - %2 = OpVariable %_ptr_Function_v2uint Function %3 = OpVariable %_ptr_Function_v2uint Function + %2 = OpVariable %_ptr_Function_v2uint Function %4 = OpVariable %_ptr_Function_v2uint Function %5 = OpVariable %_ptr_Function_uint Function %6 = OpVariable %_ptr_Function_uint Function diff --git a/ptx/src/test/spirv_run/verify.py b/ptx/src/test/spirv_run/verify.py new file mode 100644 index 0000000..dbfab00 --- /dev/null +++ b/ptx/src/test/spirv_run/verify.py @@ -0,0 +1,21 @@ +import os, sys, subprocess
+
+def main(path):
+ dirs = os.listdir(path)
+ for file in dirs:
+ if not file.endswith(".spvtxt"):
+ continue
+ full_file = os.path.join(path, file)
+ print(file)
+ spv_file = f"/tmp/{file}.spv"
+ # We nominally emit spv1.3, but use a spv1.4 feature (the OpEntryPoint interface rules changed in 1.4)
+ proc1 = subprocess.run(["spirv-as", "--target-env", "spv1.4", full_file, "-o", spv_file])
+ proc2 = subprocess.run(["spirv-dis", spv_file, "-o", f"{spv_file}.dis.txt"])
+ proc3 = subprocess.run(["spirv-val", spv_file ])
+ if proc1.returncode != 0 or proc2.returncode != 0 or proc3.returncode != 0:
+ print(proc1.returncode)
+ print(proc2.returncode)
+ print(proc3.returncode)
+
+if __name__ == "__main__":
+ main(sys.argv[1])
diff --git a/ptx/src/test/spirv_run/xor.spvtxt b/ptx/src/test/spirv_run/xor.spvtxt index 4cc8968..c3a1f6f 100644 --- a/ptx/src/test/spirv_run/xor.spvtxt +++ b/ptx/src/test/spirv_run/xor.spvtxt @@ -18,6 +18,8 @@ %_ptr_Function_uint = OpTypePointer Function %uint %_ptr_Generic_uint = OpTypePointer Generic %uint %ulong_4 = OpConstant %ulong 4 + %uchar = OpTypeInt 8 0 +%_ptr_Generic_uchar = OpTypePointer Generic %uchar %1 = OpFunction %void None %31 %8 = OpFunctionParameter %ulong %9 = OpFunctionParameter %ulong @@ -39,9 +41,11 @@ %12 = OpLoad %uint %23 Aligned 4 OpStore %6 %12 %15 = OpLoad %ulong %4 - %22 = OpIAdd %ulong %15 %ulong_4 - %24 = OpConvertUToPtr %_ptr_Generic_uint %22 - %14 = OpLoad %uint %24 Aligned 4 + %24 = OpConvertUToPtr %_ptr_Generic_uint %15 + %38 = OpBitcast %_ptr_Generic_uchar %24 + %39 = OpInBoundsPtrAccessChain %_ptr_Generic_uchar %38 %ulong_4 + %22 = OpBitcast %_ptr_Generic_uint %39 + %14 = OpLoad %uint %22 Aligned 4 OpStore %7 %14 %17 = OpLoad %uint %6 %18 = OpLoad %uint %7 diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 7170950..c2562c3 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -1,11 +1,9 @@ use crate::ast;
use half::f16;
use rspirv::dr;
-use std::{borrow::Cow, collections::BTreeSet, ffi::CString, hash::Hash, iter, mem};
-use std::{
- collections::{hash_map, HashMap, HashSet},
- convert::TryInto,
-};
+use std::cell::RefCell;
+use std::collections::{hash_map, HashMap, HashSet};
+use std::{borrow::Cow, collections::BTreeSet, ffi::CString, hash::Hash, iter, mem, rc::Rc};
use rspirv::binary::Assemble;
@@ -48,64 +46,21 @@ enum SpirvType { }
impl SpirvType {
- fn new_pointer(t: ast::Type, sc: spirv::StorageClass) -> Self {
- let key = t.into();
- SpirvType::Pointer(Box::new(key), sc)
- }
-}
-
-impl From<ast::Type> for SpirvType {
- fn from(t: ast::Type) -> Self {
+ fn new(t: ast::Type) -> Self {
match t {
ast::Type::Scalar(t) => SpirvType::Base(t.into()),
ast::Type::Vector(typ, len) => SpirvType::Vector(typ.into(), len),
ast::Type::Array(t, len) => SpirvType::Array(t.into(), len),
- ast::Type::Pointer(pointer_t, state_space) => SpirvType::Pointer(
- Box::new(SpirvType::from(ast::Type::from(pointer_t))),
- state_space.to_spirv(),
+ ast::Type::Pointer(pointer_t, space) => SpirvType::Pointer(
+ Box::new(SpirvType::Base(pointer_t.into())),
+ space.to_spirv(),
),
}
}
-}
-impl From<ast::PointerType> for ast::Type {
- fn from(t: ast::PointerType) -> Self {
- match t {
- ast::PointerType::Scalar(t) => ast::Type::Scalar(t),
- ast::PointerType::Vector(t, len) => ast::Type::Vector(t, len),
- ast::PointerType::Array(t, dims) => ast::Type::Array(t, dims),
- ast::PointerType::Pointer(t, space) => {
- ast::Type::Pointer(ast::PointerType::Scalar(t), space)
- }
- }
- }
-}
-
-impl ast::Type {
- fn param_pointer_to(self, space: ast::LdStateSpace) -> Result<Self, TranslateError> {
- Ok(match self {
- ast::Type::Scalar(t) => ast::Type::Pointer(ast::PointerType::Scalar(t), space),
- ast::Type::Vector(t, len) => {
- ast::Type::Pointer(ast::PointerType::Vector(t, len), space)
- }
- ast::Type::Array(t, _) => ast::Type::Pointer(ast::PointerType::Scalar(t), space),
- ast::Type::Pointer(ast::PointerType::Scalar(t), space) => {
- ast::Type::Pointer(ast::PointerType::Pointer(t, space), space)
- }
- ast::Type::Pointer(_, _) => return Err(error_unreachable()),
- })
- }
-}
-
-impl Into<spirv::StorageClass> for ast::PointerStateSpace {
- fn into(self) -> spirv::StorageClass {
- match self {
- ast::PointerStateSpace::Const => spirv::StorageClass::UniformConstant,
- ast::PointerStateSpace::Global => spirv::StorageClass::CrossWorkgroup,
- ast::PointerStateSpace::Shared => spirv::StorageClass::Workgroup,
- ast::PointerStateSpace::Param => spirv::StorageClass::Function,
- ast::PointerStateSpace::Generic => spirv::StorageClass::Generic,
- }
+ fn pointer_to(t: ast::Type, outer_space: spirv::StorageClass) -> Self {
+ let key = Self::new(t);
+ SpirvType::Pointer(Box::new(key), outer_space)
}
}
@@ -213,14 +168,18 @@ impl TypeWordMap { .or_insert_with(|| b.type_vector(None, base, len as u32))
}
SpirvType::Array(typ, array_dimensions) => {
- let u32_type = self.get_or_add_scalar(b, ast::ScalarType::U32);
let (base_type, length) = match &*array_dimensions {
+ &[] => {
+ return self.get_or_add(b, SpirvType::Base(typ));
+ }
&[len] => {
+ let u32_type = self.get_or_add_scalar(b, ast::ScalarType::U32);
let base = self.get_or_add_spirv_scalar(b, typ);
let len_const = b.constant_u32(u32_type, None, len);
(base, len_const)
}
array_dimensions => {
+ let u32_type = self.get_or_add_scalar(b, ast::ScalarType::U32);
let base = self
.get_or_add(b, SpirvType::Array(typ, array_dimensions[1..].to_vec()));
let len_const = b.constant_u32(u32_type, None, array_dimensions[0]);
@@ -262,7 +221,7 @@ impl TypeWordMap { fn get_or_add_fn(
&mut self,
b: &mut dr::Builder,
- in_params: impl ExactSizeIterator<Item = SpirvType>,
+ in_params: impl Iterator<Item = SpirvType>,
mut out_params: impl ExactSizeIterator<Item = SpirvType>,
) -> (spirv::Word, spirv::Word) {
let (out_args, out_spirv_type) = if out_params.len() == 0 {
@@ -274,6 +233,7 @@ impl TypeWordMap { self.get_or_add(b, arg_as_key),
)
} else {
+ // TODO: support multiple return values
todo!()
};
(
@@ -410,18 +370,7 @@ impl TypeWordMap { b.constant_composite(result_type, None, components.into_iter())
}
},
- ast::Type::Pointer(typ, state_space) => {
- let base_t = typ.clone().into();
- let base = self.get_or_add_constant(b, &base_t, &[])?;
- let result_type = self.get_or_add(
- b,
- SpirvType::Pointer(
- Box::new(SpirvType::from(base_t)),
- (*state_space).to_spirv(),
- ),
- );
- b.variable(result_type, None, (*state_space).to_spirv(), Some(base))
- }
+ ast::Type::Pointer(..) => return Err(error_unreachable()),
})
}
@@ -487,7 +436,7 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result<Module, TranslateErro .collect::<Vec<_>>();
let mut builder = dr::Builder::new();
builder.reserve_ids(id_defs.current_id());
- let call_map = get_call_map(&directives);
+ let call_map = get_kernels_call_map(&directives);
let mut directives = convert_dynamic_shared_memory_usage(directives, &mut || builder.id());
normalize_variable_decls(&mut directives);
let denorm_information = compute_denorm_information(&directives);
@@ -525,9 +474,12 @@ pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result<Module, TranslateErro }
// TODO: remove this once we have perf-function support for denorms
-fn emit_denorm_build_string(
+fn emit_denorm_build_string<'input>(
call_map: &HashMap<&str, HashSet<u32>>,
- denorm_information: &HashMap<MethodName, HashMap<u8, (spirv::FPDenormMode, isize)>>,
+ denorm_information: &HashMap<
+ ast::MethodName<'input, spirv::Word>,
+ HashMap<u8, (spirv::FPDenormMode, isize)>,
+ >,
) -> CString {
let denorm_counts = denorm_information
.iter()
@@ -545,10 +497,12 @@ fn emit_denorm_build_string( .collect::<HashMap<_, _>>();
let mut flush_over_preserve = 0;
for (kernel, children) in call_map {
- flush_over_preserve += *denorm_counts.get(&MethodName::Kernel(kernel)).unwrap_or(&0);
+ flush_over_preserve += *denorm_counts
+ .get(&ast::MethodName::Kernel(kernel))
+ .unwrap_or(&0);
for child_fn in children {
flush_over_preserve += *denorm_counts
- .get(&MethodName::Func(*child_fn))
+ .get(&ast::MethodName::Func(*child_fn))
.unwrap_or(&0);
}
}
@@ -564,15 +518,18 @@ fn emit_directives<'input>( map: &mut TypeWordMap,
id_defs: &GlobalStringIdResolver<'input>,
opencl_id: spirv::Word,
- denorm_information: &HashMap<MethodName<'input>, HashMap<u8, (spirv::FPDenormMode, isize)>>,
+ denorm_information: &HashMap<
+ ast::MethodName<'input, spirv::Word>,
+ HashMap<u8, (spirv::FPDenormMode, isize)>,
+ >,
call_map: &HashMap<&'input str, HashSet<spirv::Word>>,
- directives: Vec<Directive>,
+ directives: Vec<Directive<'input>>,
kernel_info: &mut HashMap<String, KernelInfo>,
) -> Result<(), TranslateError> {
let empty_body = Vec::new();
for d in directives.iter() {
match d {
- Directive::Variable(var) => {
+ Directive::Variable(_, var) => {
emit_variable(builder, map, &var)?;
}
Directive::Method(f) => {
@@ -589,12 +546,13 @@ fn emit_directives<'input>( for var in f.globals.iter() {
emit_variable(builder, map, var)?;
}
+ let func_decl = (*f.func_decl).borrow();
let fn_id = emit_function_header(
builder,
map,
&id_defs,
&f.globals,
- &f.spirv_decl,
+ &*func_decl,
&denorm_information,
call_map,
&directives,
@@ -623,8 +581,13 @@ fn emit_directives<'input>( }
emit_function_body_ops(builder, map, opencl_id, &f_body)?;
builder.end_function()?;
- if let (ast::MethodDecl::Func(_, fn_id, _), Some(name)) =
- (&f.func_decl, &f.import_as)
+ if let (
+ ast::MethodDeclaration {
+ name: ast::MethodName::Func(fn_id),
+ ..
+ },
+ Some(name),
+ ) = (&*func_decl, &f.import_as)
{
builder.decorate(
*fn_id,
@@ -643,7 +606,7 @@ fn emit_directives<'input>( Ok(())
}
-fn get_call_map<'input>(
+fn get_kernels_call_map<'input>(
module: &[Directive<'input>],
) -> HashMap<&'input str, HashSet<spirv::Word>> {
let mut directly_called_by = HashMap::new();
@@ -654,14 +617,14 @@ fn get_call_map<'input>( body: Some(statements),
..
}) => {
- let call_key = MethodName::new(&func_decl);
+ let call_key: ast::MethodName<_> = (**func_decl).borrow().name;
if let hash_map::Entry::Vacant(entry) = directly_called_by.entry(call_key) {
entry.insert(Vec::new());
}
for statement in statements {
match statement {
Statement::Call(call) => {
- multi_hash_map_append(&mut directly_called_by, call_key, call.func);
+ multi_hash_map_append(&mut directly_called_by, call_key, call.name);
}
_ => {}
}
@@ -673,28 +636,28 @@ fn get_call_map<'input>( let mut result = HashMap::new();
for (method_key, children) in directly_called_by.iter() {
match method_key {
- MethodName::Kernel(name) => {
+ ast::MethodName::Kernel(name) => {
let mut visited = HashSet::new();
for child in children {
add_call_map_single(&directly_called_by, &mut visited, *child);
}
result.insert(*name, visited);
}
- MethodName::Func(_) => {}
+ ast::MethodName::Func(_) => {}
}
}
result
}
fn add_call_map_single<'input>(
- directly_called_by: &MultiHashMap<MethodName<'input>, spirv::Word>,
+ directly_called_by: &MultiHashMap<ast::MethodName<'input, spirv::Word>, spirv::Word>,
visited: &mut HashSet<spirv::Word>,
current: spirv::Word,
) {
if !visited.insert(current) {
return;
}
- if let Some(children) = directly_called_by.get(&MethodName::Func(current)) {
+ if let Some(children) = directly_called_by.get(&ast::MethodName::Func(current)) {
for child in children {
add_call_map_single(directly_called_by, visited, *child);
}
@@ -714,11 +677,29 @@ fn multi_hash_map_append<K: Eq + std::hash::Hash, V>(m: &mut MultiHashMap<K, V>, }
}
-// PTX represents dynamically allocated shared local memory as
-// .extern .shared .align 4 .b8 shared_mem[];
-// In SPIRV/OpenCL world this is expressed as an additional argument
-// This pass looks for all uses of .extern .shared and converts them to
-// an additional method argument
+/*
+ PTX represents dynamically allocated shared local memory as
+ .extern .shared .b32 shared_mem[];
+ In the SPIR-V/OpenCL world this is expressed as an additional argument.
+ This pass looks for all uses of .extern .shared and converts them to
+ an additional method argument.
+ The question is how this artificial argument should be expressed. There are
+ several options:
+ * Straight conversion:
+ .shared .b32 shared_mem[]
+ * Introduce .param_shared statespace:
+ .param_shared .b32 shared_mem
+ or
+ .param_shared .b32 shared_mem[]
+ * Introduce .shared_ptr <SCALAR> type:
+ .param .shared_ptr .b32 shared_mem
+ * Reuse .ptr hint:
+ .param .u64 .ptr shared_mem
+ This is the most tempting, but also the most nonsensical: .ptr is just a
+ hint, which has no semantic meaning (whereas the output of our
+ transformation does have a semantic meaning - we emit an additional
+ "OpFunctionParameter ..." with type "OpTypePointer Workgroup ...")
+*/
fn convert_dynamic_shared_memory_usage<'input>(
module: Vec<Directive<'input>>,
new_id: &mut impl FnMut() -> spirv::Word,
@@ -726,12 +707,16 @@ fn convert_dynamic_shared_memory_usage<'input>( let mut extern_shared_decls = HashMap::new();
for dir in module.iter() {
match dir {
- Directive::Variable(var) => {
- if let ast::VariableType::Shared(ast::VariableGlobalType::Pointer(p_type, _)) =
- var.v_type
- {
- extern_shared_decls.insert(var.name, p_type);
- }
+ Directive::Variable(
+ linking,
+ ast::Variable {
+ v_type: ast::Type::Array(p_type, dims),
+ state_space: ast::StateSpace::Shared,
+ name,
+ ..
+ },
+ ) if linking.contains(ast::LinkingDirective::EXTERN) && dims.len() == 0 => {
+ extern_shared_decls.insert(*name, *p_type);
}
_ => {}
}
@@ -749,15 +734,14 @@ fn convert_dynamic_shared_memory_usage<'input>( globals,
body: Some(statements),
import_as,
- spirv_decl,
tuning,
}) => {
- let call_key = MethodName::new(&func_decl);
+ let call_key = (*func_decl).borrow().name;
let statements = statements
.into_iter()
.map(|statement| match statement {
Statement::Call(call) => {
- multi_hash_map_append(&mut directly_called_by, call.func, call_key);
+ multi_hash_map_append(&mut directly_called_by, call.name, call_key);
Statement::Call(call)
}
statement => statement.map_id(&mut |id, _| {
@@ -773,7 +757,6 @@ fn convert_dynamic_shared_memory_usage<'input>( globals,
body: Some(statements),
import_as,
- spirv_decl,
tuning,
})
}
@@ -792,66 +775,34 @@ fn convert_dynamic_shared_memory_usage<'input>( globals,
body: Some(statements),
import_as,
- mut spirv_decl,
tuning,
}) => {
- if !methods_using_extern_shared.contains(&spirv_decl.name) {
+ if !methods_using_extern_shared.contains(&(*func_decl).borrow().name) {
return Directive::Method(Function {
func_decl,
globals,
body: Some(statements),
import_as,
- spirv_decl,
tuning,
});
}
let shared_id_param = new_id();
- spirv_decl.input.push({
- ast::Variable {
- align: None,
- v_type: ast::Type::Pointer(
- ast::PointerType::Scalar(ast::ScalarType::U8),
- ast::LdStateSpace::Shared,
- ),
- array_init: Vec::new(),
- name: shared_id_param,
- }
- });
- spirv_decl.uses_shared_mem = true;
- let shared_var_id = new_id();
- let shared_var = ExpandedStatement::Variable(ast::Variable {
- align: None,
- name: shared_var_id,
- array_init: Vec::new(),
- v_type: ast::VariableType::Reg(ast::VariableRegType::Pointer(
- ast::SizedScalarType::B8,
- ast::PointerStateSpace::Shared,
- )),
- });
- let shared_var_st = ExpandedStatement::StoreVar(StoreVarDetails {
- arg: ast::Arg2St {
- src1: shared_var_id,
- src2: shared_id_param,
- },
- typ: ast::Type::Scalar(ast::ScalarType::B8),
- member_index: None,
- });
- let mut new_statements = vec![shared_var, shared_var_st];
- replace_uses_of_shared_memory(
- &mut new_statements,
+ {
+ let mut func_decl = (*func_decl).borrow_mut();
+ func_decl.shared_mem = Some(shared_id_param);
+ }
+ let statements = replace_uses_of_shared_memory(
new_id,
&extern_shared_decls,
&mut methods_using_extern_shared,
shared_id_param,
- shared_var_id,
statements,
);
Directive::Method(Function {
func_decl,
globals,
- body: Some(new_statements),
+ body: Some(statements),
import_as,
- spirv_decl,
tuning,
})
}
@@ -861,47 +812,43 @@ fn convert_dynamic_shared_memory_usage<'input>( }
fn replace_uses_of_shared_memory<'a>(
- result: &mut Vec<ExpandedStatement>,
new_id: &mut impl FnMut() -> spirv::Word,
- extern_shared_decls: &HashMap<spirv::Word, ast::SizedScalarType>,
- methods_using_extern_shared: &mut HashSet<MethodName<'a>>,
+ extern_shared_decls: &HashMap<spirv::Word, ast::ScalarType>,
+ methods_using_extern_shared: &mut HashSet<ast::MethodName<'a, spirv::Word>>,
shared_id_param: spirv::Word,
- shared_var_id: spirv::Word,
statements: Vec<ExpandedStatement>,
-) {
+) -> Vec<ExpandedStatement> {
+ let mut result = Vec::with_capacity(statements.len());
for statement in statements {
match statement {
Statement::Call(mut call) => {
// We can safely skip checking call arguments,
// because there's simply no way to pass shared ptr
// without converting it to .b64 first
- if methods_using_extern_shared.contains(&MethodName::Func(call.func)) {
- call.param_list
- .push((shared_id_param, ast::FnArgumentType::Shared));
+ if methods_using_extern_shared.contains(&ast::MethodName::Func(call.name)) {
+ call.input_arguments.push((
+ shared_id_param,
+ ast::Type::Scalar(ast::ScalarType::B8),
+ ast::StateSpace::Shared,
+ ));
}
result.push(Statement::Call(call))
}
statement => {
let new_statement = statement.map_id(&mut |id, _| {
- if let Some(typ) = extern_shared_decls.get(&id) {
- if *typ == ast::SizedScalarType::B8 {
- return shared_var_id;
+ if let Some(scalar_type) = extern_shared_decls.get(&id) {
+ if *scalar_type == ast::ScalarType::B8 {
+ return shared_id_param;
}
let replacement_id = new_id();
result.push(Statement::Conversion(ImplicitConversion {
- src: shared_var_id,
+ src: shared_id_param,
dst: replacement_id,
- from: ast::Type::Pointer(
- ast::PointerType::Scalar(ast::ScalarType::B8),
- ast::LdStateSpace::Shared,
- ),
- to: ast::Type::Pointer(
- ast::PointerType::Scalar((*typ).into()),
- ast::LdStateSpace::Shared,
- ),
- kind: ConversionKind::PtrToPtr { spirv_ptr: true },
- src_sema: ArgumentSemantics::Default,
- dst_sema: ArgumentSemantics::Default,
+ from_type: ast::Type::Scalar(ast::ScalarType::B8),
+ from_space: ast::StateSpace::Shared,
+ to_type: ast::Type::Scalar(*scalar_type),
+ to_space: ast::StateSpace::Shared,
+ kind: ConversionKind::PtrToPtr,
}));
replacement_id
} else {
@@ -912,16 +859,17 @@ fn replace_uses_of_shared_memory<'a>( }
}
}
+ result
}
fn get_callers_of_extern_shared<'a>(
- methods_using_extern_shared: &mut HashSet<MethodName<'a>>,
- directly_called_by: &MultiHashMap<spirv::Word, MethodName<'a>>,
+ methods_using_extern_shared: &mut HashSet<ast::MethodName<'a, spirv::Word>>,
+ directly_called_by: &MultiHashMap<spirv::Word, ast::MethodName<'a, spirv::Word>>,
) {
let direct_uses_of_extern_shared = methods_using_extern_shared
.iter()
.filter_map(|method| {
- if let MethodName::Func(f_id) = method {
+ if let ast::MethodName::Func(f_id) = method {
Some(*f_id)
} else {
None
@@ -934,14 +882,14 @@ fn get_callers_of_extern_shared<'a>( }
fn get_callers_of_extern_shared_single<'a>(
- methods_using_extern_shared: &mut HashSet<MethodName<'a>>,
- directly_called_by: &MultiHashMap<spirv::Word, MethodName<'a>>,
+ methods_using_extern_shared: &mut HashSet<ast::MethodName<'a, spirv::Word>>,
+ directly_called_by: &MultiHashMap<spirv::Word, ast::MethodName<'a, spirv::Word>>,
fn_id: spirv::Word,
) {
if let Some(callers) = directly_called_by.get(&fn_id) {
for caller in callers {
if methods_using_extern_shared.insert(*caller) {
- if let MethodName::Func(caller_fn) = caller {
+ if let ast::MethodName::Func(caller_fn) = caller {
get_callers_of_extern_shared_single(
methods_using_extern_shared,
directly_called_by,
@@ -983,18 +931,18 @@ fn denorm_count_map_update_impl<T: Eq + Hash>( // and emit suitable execution mode
fn compute_denorm_information<'input>(
module: &[Directive<'input>],
-) -> HashMap<MethodName<'input>, HashMap<u8, (spirv::FPDenormMode, isize)>> {
+) -> HashMap<ast::MethodName<'input, spirv::Word>, HashMap<u8, (spirv::FPDenormMode, isize)>> {
let mut denorm_methods = HashMap::new();
for directive in module {
match directive {
- Directive::Variable(_) | Directive::Method(Function { body: None, .. }) => {}
+ Directive::Variable(..) | Directive::Method(Function { body: None, .. }) => {}
Directive::Method(Function {
func_decl,
body: Some(statements),
..
}) => {
let mut flush_counter = DenormCountMap::new();
- let method_key = MethodName::new(func_decl);
+ let method_key = (**func_decl).borrow().name;
for statement in statements {
match statement {
Statement::Instruction(inst) => {
@@ -1038,21 +986,6 @@ fn compute_denorm_information<'input>( .collect()
}
-#[derive(Hash, PartialEq, Eq, Copy, Clone)]
-enum MethodName<'input> {
- Kernel(&'input str),
- Func(spirv::Word),
-}
-
-impl<'input> MethodName<'input> {
- fn new(decl: &ast::MethodDecl<'input, spirv::Word>) -> Self {
- match decl {
- ast::MethodDecl::Kernel { name, .. } => MethodName::Kernel(name),
- ast::MethodDecl::Func(_, id, _) => MethodName::Func(*id),
- }
- }
-}
-
fn emit_builtins(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
@@ -1061,10 +994,7 @@ fn emit_builtins( for (reg, id) in id_defs.special_registers.builtins() {
let result_type = map.get_or_add(
builder,
- SpirvType::Pointer(
- Box::new(SpirvType::from(reg.get_type())),
- spirv::StorageClass::Input,
- ),
+ SpirvType::pointer_to(reg.get_type(), spirv::StorageClass::Input),
);
builder.variable(result_type, Some(id), spirv::StorageClass::Input, None);
builder.decorate(
@@ -1079,18 +1009,21 @@ fn emit_function_header<'a>( builder: &mut dr::Builder,
map: &mut TypeWordMap,
defined_globals: &GlobalStringIdResolver<'a>,
- synthetic_globals: &[ast::Variable<ast::VariableType, spirv::Word>],
- func_decl: &SpirvMethodDecl<'a>,
- _denorm_information: &HashMap<MethodName<'a>, HashMap<u8, (spirv::FPDenormMode, isize)>>,
+ synthetic_globals: &[ast::Variable<spirv::Word>],
+ func_decl: &ast::MethodDeclaration<'a, spirv::Word>,
+ _denorm_information: &HashMap<
+ ast::MethodName<'a, spirv::Word>,
+ HashMap<u8, (spirv::FPDenormMode, isize)>,
+ >,
call_map: &HashMap<&'a str, HashSet<spirv::Word>>,
direcitves: &[Directive],
kernel_info: &mut HashMap<String, KernelInfo>,
) -> Result<spirv::Word, TranslateError> {
- if let MethodName::Kernel(name) = func_decl.name {
- let input_args = if !func_decl.uses_shared_mem {
- func_decl.input.as_slice()
+ if let ast::MethodName::Kernel(name) = func_decl.name {
+ let input_args = if func_decl.shared_mem.is_none() {
+ func_decl.input_arguments.as_slice()
} else {
- &func_decl.input[0..func_decl.input.len() - 1]
+ &func_decl.input_arguments[0..func_decl.input_arguments.len() - 1]
};
let args_lens = input_args
.iter()
@@ -1100,14 +1033,18 @@ fn emit_function_header<'a>( name.to_string(),
KernelInfo {
arguments_sizes: args_lens,
- uses_shared_mem: func_decl.uses_shared_mem,
+ uses_shared_mem: func_decl.shared_mem.is_some(),
},
);
}
- let (ret_type, func_type) =
- get_function_type(builder, map, &func_decl.input, &func_decl.output);
+ let (ret_type, func_type) = get_function_type(
+ builder,
+ map,
+ func_decl.effective_input_arguments().map(|(_, typ)| typ),
+ &func_decl.return_arguments,
+ );
let fn_id = match func_decl.name {
- MethodName::Kernel(name) => {
+ ast::MethodName::Kernel(name) => {
let fn_id = defined_globals.get_id(name)?;
let mut global_variables = defined_globals
.variables_type_check
@@ -1123,15 +1060,18 @@ fn emit_function_header<'a>( for directive in direcitves {
match directive {
Directive::Method(Function {
- func_decl: ast::MethodDecl::Func(_, name, _),
- globals,
- ..
+ func_decl, globals, ..
}) => {
- if child_fns.contains(name) {
- for var in globals {
- interface.push(var.name);
+ match (**func_decl).borrow().name {
+ ast::MethodName::Func(name) => {
+ if child_fns.contains(&name) {
+ for var in globals {
+ interface.push(var.name);
+ }
+ }
}
- }
+ ast::MethodName::Kernel(_) => {}
+ };
}
_ => {}
}
@@ -1140,7 +1080,7 @@ fn emit_function_header<'a>( builder.entry_point(spirv::ExecutionModel::Kernel, fn_id, name, global_variables);
fn_id
}
- MethodName::Func(name) => name,
+ ast::MethodName::Func(name) => name,
};
builder.begin_function(
ret_type,
@@ -1163,9 +1103,9 @@ fn emit_function_header<'a>( }
}
*/
- for input in &func_decl.input {
- let result_type = map.get_or_add(builder, SpirvType::from(input.v_type.clone()));
- builder.function_parameter(Some(input.name), result_type)?;
+ for (name, typ) in func_decl.effective_input_arguments() {
+ let result_type = map.get_or_add(builder, typ);
+ builder.function_parameter(Some(name), result_type)?;
}
Ok(fn_id)
}
@@ -1207,55 +1147,32 @@ fn translate_directive<'input>( d: ast::Directive<'input, ast::ParsedArgParams<'input>>,
) -> Result<Option<Directive<'input>>, TranslateError> {
Ok(match d {
- ast::Directive::Variable(v) => Some(Directive::Variable(translate_variable(id_defs, v)?)),
- ast::Directive::Method(f) => {
+ ast::Directive::Variable(linking, var) => Some(Directive::Variable(
+ linking,
+ ast::Variable {
+ align: var.align,
+ v_type: var.v_type.clone(),
+ state_space: var.state_space,
+ name: id_defs.get_or_add_def_typed(var.name, var.v_type, var.state_space, true),
+ array_init: var.array_init,
+ },
+ )),
+ ast::Directive::Method(_, f) => {
translate_function(id_defs, ptx_impl_imports, f)?.map(Directive::Method)
}
})
}
-fn translate_variable<'a>(
- id_defs: &mut GlobalStringIdResolver<'a>,
- var: ast::Variable<ast::VariableType, &'a str>,
-) -> Result<ast::Variable<ast::VariableType, spirv::Word>, TranslateError> {
- let (space, var_type) = var.v_type.to_type();
- let mut is_variable = false;
- let var_type = match space {
- ast::StateSpace::Reg => {
- is_variable = true;
- var_type
- }
- ast::StateSpace::Const => var_type.param_pointer_to(ast::LdStateSpace::Const)?,
- ast::StateSpace::Global => var_type.param_pointer_to(ast::LdStateSpace::Global)?,
- ast::StateSpace::Local => var_type.param_pointer_to(ast::LdStateSpace::Local)?,
- ast::StateSpace::Shared => {
- // If it's a pointer it will be translated to a method parameter later
- if let ast::Type::Pointer(..) = var_type {
- is_variable = true;
- var_type
- } else {
- var_type.param_pointer_to(ast::LdStateSpace::Shared)?
- }
- }
- ast::StateSpace::Param => var_type.param_pointer_to(ast::LdStateSpace::Param)?,
- };
- Ok(ast::Variable {
- align: var.align,
- v_type: var.v_type,
- name: id_defs.get_or_add_def_typed(var.name, var_type, is_variable),
- array_init: var.array_init,
- })
-}
-
fn translate_function<'a>(
id_defs: &mut GlobalStringIdResolver<'a>,
ptx_impl_imports: &mut HashMap<String, Directive<'a>>,
f: ast::ParsedFunction<'a>,
) -> Result<Option<Function<'a>>, TranslateError> {
let import_as = match &f.func_directive {
- ast::MethodDecl::Func(_, "__assertfail", _) => {
- Some("__zluda_ptx_impl____assertfail".to_owned())
- }
+ ast::MethodDeclaration {
+ name: ast::MethodName::Func("__assertfail"),
+ ..
+ } => Some("__zluda_ptx_impl____assertfail".to_owned()),
_ => None,
};
let (str_resolver, fn_resolver, fn_decl) = id_defs.start_fn(&f.func_directive)?;
@@ -1279,63 +1196,38 @@ fn translate_function<'a>( }
}
-fn expand_kernel_params<'a, 'b>(
+fn rename_fn_params<'a, 'b>(
fn_resolver: &mut FnStringIdResolver<'a, 'b>,
- args: impl Iterator<Item = &'b ast::KernelArgument<&'a str>>,
-) -> Result<Vec<ast::KernelArgument<spirv::Word>>, TranslateError> {
- args.map(|a| {
- Ok(ast::KernelArgument {
- name: fn_resolver.add_def(
- a.name,
- Some(ast::Type::from(a.v_type.clone()).param_pointer_to(ast::LdStateSpace::Param)?),
- false,
- ),
+ args: &'b [ast::Variable<&'a str>],
+) -> Vec<ast::Variable<spirv::Word>> {
+ args.iter()
+ .map(|a| ast::Variable {
+ name: fn_resolver.add_def(a.name, Some((a.v_type.clone(), a.state_space)), true),
v_type: a.v_type.clone(),
+ state_space: a.state_space,
align: a.align,
- array_init: Vec::new(),
+ array_init: a.array_init.clone(),
})
- })
- .collect::<Result<_, _>>()
-}
-
-fn expand_fn_params<'a, 'b>(
- fn_resolver: &mut FnStringIdResolver<'a, 'b>,
- args: impl Iterator<Item = &'b ast::FnArgument<&'a str>>,
-) -> Result<Vec<ast::FnArgument<spirv::Word>>, TranslateError> {
- args.map(|a| {
- let is_variable = match a.v_type {
- ast::FnArgumentType::Reg(_) => true,
- _ => false,
- };
- let var_type = a.v_type.to_func_type();
- Ok(ast::FnArgument {
- name: fn_resolver.add_def(a.name, Some(var_type), is_variable),
- v_type: a.v_type.clone(),
- align: a.align,
- array_init: Vec::new(),
- })
- })
- .collect()
+ .collect()
}
fn to_ssa<'input, 'b>(
ptx_impl_imports: &mut HashMap<String, Directive>,
mut id_defs: FnStringIdResolver<'input, 'b>,
fn_defs: GlobalFnDeclResolver<'input, 'b>,
- f_args: ast::MethodDecl<'input, spirv::Word>,
+ func_decl: Rc<RefCell<ast::MethodDeclaration<'input, spirv::Word>>>,
f_body: Option<Vec<ast::Statement<ast::ParsedArgParams<'input>>>>,
tuning: Vec<ast::TuningDirective>,
) -> Result<Function<'input>, TranslateError> {
- let mut spirv_decl = SpirvMethodDecl::new(&f_args);
+ //deparamize_function_decl(&func_decl)?;
let f_body = match f_body {
Some(vec) => vec,
None => {
return Ok(Function {
- func_decl: f_args,
+ func_decl: func_decl,
body: None,
globals: Vec::new(),
import_as: None,
- spirv_decl,
tuning,
})
}
@@ -1345,15 +1237,14 @@ fn to_ssa<'input, 'b>( let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs)?;
let typed_statements =
convert_to_typed_statements(unadorned_statements, &fn_defs, &mut numeric_id_defs)?;
- let typed_statements =
- convert_to_stateful_memory_access(&mut spirv_decl, typed_statements, &mut numeric_id_defs)?;
+ let (func_decl, typed_statements) =
+ convert_to_stateful_memory_access(func_decl, typed_statements, &mut numeric_id_defs)?;
let ssa_statements = insert_mem_ssa_statements(
typed_statements,
&mut numeric_id_defs,
- &f_args,
- &mut spirv_decl,
+ &mut (*func_decl).borrow_mut(),
)?;
- let ssa_statements = fix_builtins(ssa_statements, &mut numeric_id_defs)?;
+ let ssa_statements = fix_special_registers(ssa_statements, &mut numeric_id_defs)?;
let mut numeric_id_defs = numeric_id_defs.finish();
let expanded_statements = expand_arguments(ssa_statements, &mut numeric_id_defs)?;
let expanded_statements =
@@ -1363,16 +1254,15 @@ fn to_ssa<'input, 'b>( let (f_body, globals) =
extract_globals(labeled_statements, ptx_impl_imports, &mut numeric_id_defs);
Ok(Function {
- func_decl: f_args,
+ func_decl: func_decl,
globals: globals,
body: Some(f_body),
import_as: None,
- spirv_decl,
tuning,
})
}
-fn fix_builtins(
+fn fix_special_registers(
typed_statements: Vec<TypedStatement>,
numeric_id_defs: &mut NumericIdResolver,
) -> Result<Vec<TypedStatement>, TranslateError> {
@@ -1408,7 +1298,8 @@ fn fix_builtins( continue;
}
};
- let temp_id = numeric_id_defs.new_non_variable(Some(details.typ.clone()));
+ let temp_id = numeric_id_defs
+ .register_intermediate(Some((details.typ.clone(), details.state_space)));
let real_dst = details.arg.dst;
details.arg.dst = temp_id;
result.push(Statement::LoadVar(LoadVarDetails {
@@ -1416,17 +1307,18 @@ fn fix_builtins( src: sreg_src,
dst: temp_id,
},
+ state_space: ast::StateSpace::Sreg,
typ: ast::Type::Scalar(scalar_typ),
member_index: Some((index, Some(vector_width))),
}));
result.push(Statement::Conversion(ImplicitConversion {
src: temp_id,
dst: real_dst,
- from: ast::Type::Scalar(scalar_typ),
- to: ast::Type::Scalar(ast::ScalarType::U32),
+ from_type: ast::Type::Scalar(scalar_typ),
+ from_space: ast::StateSpace::Sreg,
+ to_type: ast::Type::Scalar(ast::ScalarType::U32),
+ to_space: ast::StateSpace::Sreg,
kind: ConversionKind::Default,
- src_sema: ArgumentSemantics::Default,
- dst_sema: ArgumentSemantics::Default,
}));
}
}
@@ -1456,10 +1348,7 @@ fn extract_globals<'input, 'b>( sorted_statements: Vec<ExpandedStatement>,
ptx_impl_imports: &mut HashMap<String, Directive>,
id_def: &mut NumericIdResolver,
-) -> (
- Vec<ExpandedStatement>,
- Vec<ast::Variable<ast::VariableType, spirv::Word>>,
-) {
+) -> (Vec<ExpandedStatement>, Vec<ast::Variable<spirv::Word>>) {
let mut local = Vec::with_capacity(sorted_statements.len());
let mut global = Vec::new();
for statement in sorted_statements {
@@ -1468,7 +1357,7 @@ fn extract_globals<'input, 'b>( var
@
ast::Variable {
- v_type: ast::VariableType::Shared(_),
+ state_space: ast::StateSpace::Shared,
..
},
)
@@ -1476,7 +1365,7 @@ fn extract_globals<'input, 'b>( var
@
ast::Variable {
- v_type: ast::VariableType::Global(_),
+ state_space: ast::StateSpace::Global,
..
},
) => global.push(var),
@@ -1505,7 +1394,7 @@ fn extract_globals<'input, 'b>( d,
a,
"inc",
- ast::SizedScalarType::U32,
+ ast::ScalarType::U32,
));
}
Statement::Instruction(ast::Instruction::Atom(
@@ -1527,7 +1416,7 @@ fn extract_globals<'input, 'b>( d,
a,
"dec",
- ast::SizedScalarType::U32,
+ ast::ScalarType::U32,
));
}
Statement::Instruction(ast::Instruction::Atom(
@@ -1553,10 +1442,9 @@ fn extract_globals<'input, 'b>( space,
};
let (op, typ) = match typ {
- ast::FloatType::F32 => ("add_f32", ast::SizedScalarType::F32),
- ast::FloatType::F64 => ("add_f64", ast::SizedScalarType::F64),
- ast::FloatType::F16 => unreachable!(),
- ast::FloatType::F16x2 => unreachable!(),
+ ast::ScalarType::F32 => ("add_f32", ast::ScalarType::F32),
+ ast::ScalarType::F64 => ("add_f64", ast::ScalarType::F64),
+ _ => unreachable!(),
};
local.push(to_ptx_impl_atomic_call(
id_def,
@@ -1599,47 +1487,13 @@ fn convert_to_typed_statements( match s {
Statement::Instruction(inst) => match inst {
ast::Instruction::Call(call) => {
- // TODO: error out if lengths don't match
- let fn_def = fn_defs.get_fn_decl(call.func)?;
- let out_args = to_resolved_fn_args(call.ret_params, &*fn_def.ret_vals);
- let in_args = to_resolved_fn_args(call.param_list, &*fn_def.params);
- let (out_params, out_non_params): (Vec<_>, Vec<_>) = out_args
- .into_iter()
- .partition(|(_, arg_type)| arg_type.is_param());
- let normalized_input_args = out_params
- .into_iter()
- .map(|(id, typ)| (ast::Operand::Reg(id), typ))
- .chain(in_args.into_iter())
- .collect();
- let resolved_call = ResolvedCall {
- uniform: call.uniform,
- ret_params: out_non_params,
- func: call.func,
- param_list: normalized_input_args,
- };
+ let resolver = fn_defs.get_fn_sig_resolver(call.func)?;
+ let resolved_call = resolver.resolve_in_spirv_repr(call)?;
let mut visitor = VectorRepackVisitor::new(&mut result, id_defs);
let reresolved_call = resolved_call.visit(&mut visitor)?;
visitor.func.push(reresolved_call);
visitor.func.extend(visitor.post_stmts);
}
- ast::Instruction::Mov(mut d, ast::Arg2Mov { dst, src }) => {
- if let Some(src_id) = src.underlying() {
- let (typ, _) = id_defs.get_typed(*src_id)?;
- let take_address = match typ {
- ast::Type::Scalar(_) => false,
- ast::Type::Vector(_, _) => false,
- ast::Type::Array(_, _) => true,
- ast::Type::Pointer(_, _) => true,
- };
- d.src_is_address = take_address;
- }
- let mut visitor = VectorRepackVisitor::new(&mut result, id_defs);
- let instruction = Statement::Instruction(
- ast::Instruction::Mov(d, ast::Arg2Mov { dst, src }).map(&mut visitor)?,
- );
- visitor.func.push(instruction);
- visitor.func.extend(visitor.post_stmts);
- }
inst => {
let mut visitor = VectorRepackVisitor::new(&mut result, id_defs);
let instruction = Statement::Instruction(inst.map(&mut visitor)?);
@@ -1674,8 +1528,14 @@ impl<'a, 'b> VectorRepackVisitor<'a, 'b> { fn convert_vector(
&mut self,
is_dst: bool,
- vector_sema: ArgumentSemantics,
+ non_default_implicit_conversion: Option<
+ fn(
+ (ast::StateSpace, &ast::Type),
+ (ast::StateSpace, &ast::Type),
+ ) -> Result<Option<ConversionKind>, TranslateError>,
+ >,
typ: &ast::Type,
+ state_space: ast::StateSpace,
idx: Vec<spirv::Word>,
) -> Result<spirv::Word, TranslateError> {
// mov.u32 foobar, {a,b};
@@ -1683,13 +1543,15 @@ impl<'a, 'b> VectorRepackVisitor<'a, 'b> { ast::Type::Vector(scalar_t, _) => *scalar_t,
_ => return Err(TranslateError::MismatchedType),
};
- let temp_vec = self.id_def.new_non_variable(Some(typ.clone()));
+ let temp_vec = self
+ .id_def
+ .register_intermediate(Some((typ.clone(), state_space)));
let statement = Statement::RepackVector(RepackVectorDetails {
is_extract: is_dst,
typ: scalar_t,
packed: temp_vec,
unpacked: idx,
- vector_sema,
+ non_default_implicit_conversion,
});
if is_dst {
self.post_stmts = Some(statement);
@@ -1706,7 +1568,7 @@ impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, TypedArgParams> fn id(
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
- _: Option<&ast::Type>,
+ _: Option<(&ast::Type, ast::StateSpace)>,
) -> Result<spirv::Word, TranslateError> {
Ok(desc.op)
}
@@ -1715,15 +1577,20 @@ impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, TypedArgParams> &mut self,
desc: ArgumentDescriptor<ast::Operand<spirv::Word>>,
typ: &ast::Type,
+ state_space: ast::StateSpace,
) -> Result<TypedOperand, TranslateError> {
Ok(match desc.op {
ast::Operand::Reg(reg) => TypedOperand::Reg(reg),
ast::Operand::RegOffset(reg, offset) => TypedOperand::RegOffset(reg, offset),
ast::Operand::Imm(x) => TypedOperand::Imm(x),
ast::Operand::VecMember(vec, idx) => TypedOperand::VecMember(vec, idx),
- ast::Operand::VecPack(vec) => {
- TypedOperand::Reg(self.convert_vector(desc.is_dst, desc.sema, typ, vec)?)
- }
+ ast::Operand::VecPack(vec) => TypedOperand::Reg(self.convert_vector(
+ desc.is_dst,
+ desc.non_default_implicit_conversion,
+ typ,
+ state_space,
+ vec,
+ )?),
})
}
}
@@ -1735,7 +1602,7 @@ fn to_ptx_impl_atomic_call( details: ast::AtomDetails,
arg: ast::Arg3<ExpandedArgParams>,
op: &'static str,
- typ: ast::SizedScalarType,
+ typ: ast::ScalarType,
) -> ExpandedStatement {
let semantics = ptx_semantics_name(details.semantics);
let scope = ptx_scope_name(details.scope);
@@ -1745,75 +1612,70 @@ fn to_ptx_impl_atomic_call( semantics, scope, space, op
);
// TODO: extract to a function
- let ptr_space = match details.space {
- ast::AtomSpace::Generic => ast::PointerStateSpace::Generic,
- ast::AtomSpace::Global => ast::PointerStateSpace::Global,
- ast::AtomSpace::Shared => ast::PointerStateSpace::Shared,
- };
+ let ptr_space = details.space;
let scalar_typ = ast::ScalarType::from(typ);
let fn_id = match ptx_impl_imports.entry(fn_name) {
hash_map::Entry::Vacant(entry) => {
- let fn_id = id_defs.new_non_variable(None);
- let func_decl = ast::MethodDecl::Func::<spirv::Word>(
- vec![ast::FnArgument {
+ let fn_id = id_defs.register_intermediate(None);
+ let func_decl = ast::MethodDeclaration::<spirv::Word> {
+ return_arguments: vec![ast::Variable {
align: None,
- v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(scalar_typ)),
- name: id_defs.new_non_variable(None),
+ v_type: ast::Type::Scalar(scalar_typ),
+ state_space: ast::StateSpace::Reg,
+ name: id_defs.register_intermediate(None),
array_init: Vec::new(),
}],
- fn_id,
- vec![
- ast::FnArgument {
+ name: ast::MethodName::Func(fn_id),
+ input_arguments: vec![
+ ast::Variable {
align: None,
- v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Pointer(
- typ, ptr_space,
- )),
- name: id_defs.new_non_variable(None),
+ v_type: ast::Type::Pointer(typ, ptr_space),
+ state_space: ast::StateSpace::Reg,
+ name: id_defs.register_intermediate(None),
array_init: Vec::new(),
},
- ast::FnArgument {
+ ast::Variable {
align: None,
- v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(scalar_typ)),
- name: id_defs.new_non_variable(None),
+ v_type: ast::Type::Scalar(scalar_typ),
+ state_space: ast::StateSpace::Reg,
+ name: id_defs.register_intermediate(None),
array_init: Vec::new(),
},
],
- );
- let spirv_decl = SpirvMethodDecl::new(&func_decl);
+ shared_mem: None,
+ };
let func = Function {
- func_decl,
+ func_decl: Rc::new(RefCell::new(func_decl)),
globals: Vec::new(),
body: None,
import_as: Some(entry.key().clone()),
- spirv_decl,
tuning: Vec::new(),
};
entry.insert(Directive::Method(func));
fn_id
}
hash_map::Entry::Occupied(entry) => match entry.get() {
- Directive::Method(Function {
- func_decl: ast::MethodDecl::Func(_, name, _),
- ..
- }) => *name,
+ Directive::Method(Function { func_decl, .. }) => match (**func_decl).borrow().name {
+ ast::MethodName::Func(fn_id) => fn_id,
+ ast::MethodName::Kernel(_) => unreachable!(),
+ },
_ => unreachable!(),
},
};
Statement::Call(ResolvedCall {
uniform: false,
- func: fn_id,
- ret_params: vec![(
- arg.dst,
- ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(scalar_typ)),
- )],
- param_list: vec![
+ name: fn_id,
+ return_arguments: vec![(arg.dst, ast::Type::Scalar(scalar_typ), ast::StateSpace::Reg)],
+ input_arguments: vec![
(
arg.src1,
- ast::FnArgumentType::Reg(ast::VariableRegType::Pointer(typ, ptr_space)),
+ ast::Type::Pointer(typ, ptr_space),
+ ast::StateSpace::Reg,
),
(
arg.src2,
- ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(scalar_typ)),
+ ast::Type::Scalar(scalar_typ),
+ ast::StateSpace::Reg,
),
],
})
@@ -1822,93 +1684,92 @@ fn to_ptx_impl_atomic_call( fn to_ptx_impl_bfe_call(
id_defs: &mut NumericIdResolver,
ptx_impl_imports: &mut HashMap<String, Directive>,
- typ: ast::IntType,
+ typ: ast::ScalarType,
arg: ast::Arg4<ExpandedArgParams>,
) -> ExpandedStatement {
let prefix = "__zluda_ptx_impl__";
let suffix = match typ {
- ast::IntType::U32 => "bfe_u32",
- ast::IntType::U64 => "bfe_u64",
- ast::IntType::S32 => "bfe_s32",
- ast::IntType::S64 => "bfe_s64",
+ ast::ScalarType::U32 => "bfe_u32",
+ ast::ScalarType::U64 => "bfe_u64",
+ ast::ScalarType::S32 => "bfe_s32",
+ ast::ScalarType::S64 => "bfe_s64",
_ => unreachable!(),
};
let fn_name = format!("{}{}", prefix, suffix);
let fn_id = match ptx_impl_imports.entry(fn_name) {
hash_map::Entry::Vacant(entry) => {
- let fn_id = id_defs.new_non_variable(None);
- let func_decl = ast::MethodDecl::Func::<spirv::Word>(
- vec![ast::FnArgument {
+ let fn_id = id_defs.register_intermediate(None);
+ let func_decl = ast::MethodDeclaration::<spirv::Word> {
+ return_arguments: vec![ast::Variable {
align: None,
- v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
- name: id_defs.new_non_variable(None),
+ v_type: ast::Type::Scalar(typ.into()),
+ state_space: ast::StateSpace::Reg,
+ name: id_defs.register_intermediate(None),
array_init: Vec::new(),
}],
- fn_id,
- vec![
- ast::FnArgument {
+ name: ast::MethodName::Func(fn_id),
+ input_arguments: vec![
+ ast::Variable {
align: None,
- v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
- name: id_defs.new_non_variable(None),
+ v_type: ast::Type::Scalar(typ.into()),
+ state_space: ast::StateSpace::Reg,
+ name: id_defs.register_intermediate(None),
array_init: Vec::new(),
},
- ast::FnArgument {
+ ast::Variable {
align: None,
- v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(
- ast::ScalarType::U32,
- )),
- name: id_defs.new_non_variable(None),
+ v_type: ast::Type::Scalar(ast::ScalarType::U32),
+ state_space: ast::StateSpace::Reg,
+ name: id_defs.register_intermediate(None),
array_init: Vec::new(),
},
- ast::FnArgument {
+ ast::Variable {
align: None,
- v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(
- ast::ScalarType::U32,
- )),
- name: id_defs.new_non_variable(None),
+ v_type: ast::Type::Scalar(ast::ScalarType::U32),
+ state_space: ast::StateSpace::Reg,
+ name: id_defs.register_intermediate(None),
array_init: Vec::new(),
},
],
- );
- let spirv_decl = SpirvMethodDecl::new(&func_decl);
+ shared_mem: None,
+ };
let func = Function {
- func_decl,
+ func_decl: Rc::new(RefCell::new(func_decl)),
globals: Vec::new(),
body: None,
import_as: Some(entry.key().clone()),
- spirv_decl,
tuning: Vec::new(),
};
entry.insert(Directive::Method(func));
fn_id
}
hash_map::Entry::Occupied(entry) => match entry.get() {
- Directive::Method(Function {
- func_decl: ast::MethodDecl::Func(_, name, _),
- ..
- }) => *name,
+ Directive::Method(Function { func_decl, .. }) => match (**func_decl).borrow().name {
+ ast::MethodName::Func(fn_id) => fn_id,
+ ast::MethodName::Kernel(_) => unreachable!(),
+ },
_ => unreachable!(),
},
};
Statement::Call(ResolvedCall {
uniform: false,
- func: fn_id,
- ret_params: vec![(
- arg.dst,
- ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
- )],
- param_list: vec![
+ name: fn_id,
+ return_arguments: vec![(arg.dst, ast::Type::Scalar(typ.into()), ast::StateSpace::Reg)],
+ input_arguments: vec![
(
arg.src1,
- ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
+ ast::Type::Scalar(typ.into()),
+ ast::StateSpace::Reg,
),
(
arg.src2,
- ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(ast::ScalarType::U32)),
+ ast::Type::Scalar(ast::ScalarType::U32),
+ ast::StateSpace::Reg,
),
(
arg.src3,
- ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(ast::ScalarType::U32)),
+ ast::Type::Scalar(ast::ScalarType::U32),
+ ast::StateSpace::Reg,
),
],
})
@@ -1917,117 +1778,107 @@ fn to_ptx_impl_bfe_call( fn to_ptx_impl_bfi_call(
id_defs: &mut NumericIdResolver,
ptx_impl_imports: &mut HashMap<String, Directive>,
- typ: ast::BitType,
+ typ: ast::ScalarType,
arg: ast::Arg5<ExpandedArgParams>,
) -> ExpandedStatement {
let prefix = "__zluda_ptx_impl__";
let suffix = match typ {
- ast::BitType::B32 => "bfi_b32",
- ast::BitType::B64 => "bfi_b64",
- ast::BitType::B8 | ast::BitType::B16 => unreachable!(),
+ ast::ScalarType::B32 => "bfi_b32",
+ ast::ScalarType::B64 => "bfi_b64",
+ _ => unreachable!(),
};
let fn_name = format!("{}{}", prefix, suffix);
let fn_id = match ptx_impl_imports.entry(fn_name) {
hash_map::Entry::Vacant(entry) => {
- let fn_id = id_defs.new_non_variable(None);
- let func_decl = ast::MethodDecl::Func::<spirv::Word>(
- vec![ast::FnArgument {
+ let fn_id = id_defs.register_intermediate(None);
+ let func_decl = ast::MethodDeclaration::<spirv::Word> {
+ return_arguments: vec![ast::Variable {
align: None,
- v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
- name: id_defs.new_non_variable(None),
+ v_type: ast::Type::Scalar(typ.into()),
+ state_space: ast::StateSpace::Reg,
+ name: id_defs.register_intermediate(None),
array_init: Vec::new(),
}],
- fn_id,
- vec![
- ast::FnArgument {
+ name: ast::MethodName::Func(fn_id),
+ input_arguments: vec![
+ ast::Variable {
align: None,
- v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
- name: id_defs.new_non_variable(None),
+ v_type: ast::Type::Scalar(typ.into()),
+ state_space: ast::StateSpace::Reg,
+ name: id_defs.register_intermediate(None),
array_init: Vec::new(),
},
- ast::FnArgument {
+ ast::Variable {
align: None,
- v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
- name: id_defs.new_non_variable(None),
+ v_type: ast::Type::Scalar(typ.into()),
+ state_space: ast::StateSpace::Reg,
+ name: id_defs.register_intermediate(None),
array_init: Vec::new(),
},
- ast::FnArgument {
+ ast::Variable {
align: None,
- v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(
- ast::ScalarType::U32,
- )),
- name: id_defs.new_non_variable(None),
+ v_type: ast::Type::Scalar(ast::ScalarType::U32),
+ state_space: ast::StateSpace::Reg,
+ name: id_defs.register_intermediate(None),
array_init: Vec::new(),
},
- ast::FnArgument {
+ ast::Variable {
align: None,
- v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(
- ast::ScalarType::U32,
- )),
- name: id_defs.new_non_variable(None),
+ v_type: ast::Type::Scalar(ast::ScalarType::U32),
+ state_space: ast::StateSpace::Reg,
+ name: id_defs.register_intermediate(None),
array_init: Vec::new(),
},
],
- );
- let spirv_decl = SpirvMethodDecl::new(&func_decl);
+ shared_mem: None,
+ };
let func = Function {
- func_decl,
+ func_decl: Rc::new(RefCell::new(func_decl)),
globals: Vec::new(),
body: None,
import_as: Some(entry.key().clone()),
- spirv_decl,
tuning: Vec::new(),
};
entry.insert(Directive::Method(func));
fn_id
}
hash_map::Entry::Occupied(entry) => match entry.get() {
- Directive::Method(Function {
- func_decl: ast::MethodDecl::Func(_, name, _),
- ..
- }) => *name,
+ Directive::Method(Function { func_decl, .. }) => match (**func_decl).borrow().name {
+ ast::MethodName::Func(fn_id) => fn_id,
+ ast::MethodName::Kernel(_) => unreachable!(),
+ },
_ => unreachable!(),
},
};
Statement::Call(ResolvedCall {
uniform: false,
- func: fn_id,
- ret_params: vec![(
- arg.dst,
- ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
- )],
- param_list: vec![
+ name: fn_id,
+ return_arguments: vec![(arg.dst, ast::Type::Scalar(typ.into()), ast::StateSpace::Reg)],
+ input_arguments: vec![
(
arg.src1,
- ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
+ ast::Type::Scalar(typ.into()),
+ ast::StateSpace::Reg,
),
(
arg.src2,
- ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())),
+ ast::Type::Scalar(typ.into()),
+ ast::StateSpace::Reg,
),
(
arg.src3,
- ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(ast::ScalarType::U32)),
+ ast::Type::Scalar(ast::ScalarType::U32),
+ ast::StateSpace::Reg,
),
(
arg.src4,
- ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(ast::ScalarType::U32)),
+ ast::Type::Scalar(ast::ScalarType::U32),
+ ast::StateSpace::Reg,
),
],
})
}
-fn to_resolved_fn_args<T>(
- params: Vec<T>,
- params_decl: &[ast::FnArgumentType],
-) -> Vec<(T, ast::FnArgumentType)> {
- params
- .into_iter()
- .zip(params_decl.iter())
- .map(|(id, typ)| (id, typ.clone()))
- .collect::<Vec<_>>()
-}
-
fn normalize_labels(
func: Vec<ExpandedStatement>,
id_def: &mut NumericIdResolver,
@@ -2056,7 +1907,7 @@ fn normalize_labels( | Statement::RepackVector(..) => {}
}
}
- iter::once(Statement::Label(id_def.new_non_variable(None)))
+ iter::once(Statement::Label(id_def.register_intermediate(None)))
.chain(func.into_iter().filter(|s| match s {
Statement::Label(i) => labels_in_use.contains(i),
_ => true,
@@ -2074,8 +1925,8 @@ fn normalize_predicates( Statement::Label(id) => result.push(Statement::Label(id)),
Statement::Instruction((pred, inst)) => {
if let Some(pred) = pred {
- let if_true = id_def.new_non_variable(None);
- let if_false = id_def.new_non_variable(None);
+ let if_true = id_def.register_intermediate(None);
+ let if_false = id_def.register_intermediate(None);
let folded_bra = match &inst {
ast::Instruction::Bra(_, arg) => Some(arg.src),
_ => None,
@@ -2106,53 +1957,52 @@ fn normalize_predicates( Ok(result)
}
+/*
+ How do we handle arguments:
+ - input .params in kernels
+ .param .b64 in_arg
+ get turned into this SPIR-V:
+ %1 = OpFunctionParameter %ulong
+ %2 = OpVariable %_ptr_Function_ulong Function
+ OpStore %2 %1
+ We do this for two reasons. One: it gives common treatment to argument-declared
+ .param variables and .param variables declared inside a function (we assume that
+ at the SPIR-V level every .param is a pointer in the Function storage class)
+ - input .params in functions
+ .param .b64 in_arg
+ get turned into this SPIR-V:
+ %1 = OpFunctionParameter %_ptr_Function_ulong
+ - input .regs
+ .reg .b64 in_arg
+ get turned into the same SPIR-V as kernel .params:
+ %1 = OpFunctionParameter %ulong
+ %2 = OpVariable %_ptr_Function_ulong Function
+ OpStore %2 %1
+ - output .regs
+ .reg .b64 out_arg
+ get just a variable declaration:
+ %2 = OpVariable %_ptr_Function_ulong Function
+ - output .params don't exist, they have been moved to input positions
+ by an earlier pass
+ Distinguishing between kernel .params and function .params is not the
+ cleanest solution. Alternatively, we could "deparamize" all kernel .param
+ arguments by turning them into .reg arguments like this:
+ .param .b64 arg -> .reg ptr<.b64,.param> arg
+ This has the massive downside that this transformation would have to run
+ very early and would muddy up the already difficult code. It's simpler to just
+ have an if here.
+*/
fn insert_mem_ssa_statements<'a, 'b>(
func: Vec<TypedStatement>,
id_def: &mut NumericIdResolver,
- ast_fn_decl: &'a ast::MethodDecl<'b, spirv::Word>,
- fn_decl: &mut SpirvMethodDecl,
+ fn_decl: &'a mut ast::MethodDeclaration<'b, spirv::Word>,
) -> Result<Vec<TypedStatement>, TranslateError> {
- let is_func = match ast_fn_decl {
- ast::MethodDecl::Func(..) => true,
- ast::MethodDecl::Kernel { .. } => false,
- };
let mut result = Vec::with_capacity(func.len());
- for arg in fn_decl.output.iter() {
- match type_to_variable_type(&arg.v_type, is_func)? {
- Some(var_type) => {
- result.push(Statement::Variable(ast::Variable {
- align: arg.align,
- v_type: var_type,
- name: arg.name,
- array_init: arg.array_init.clone(),
- }));
- }
- None => return Err(error_unreachable()),
- }
+ for arg in fn_decl.input_arguments.iter_mut() {
+ insert_mem_ssa_argument(id_def, &mut result, arg, fn_decl.name.is_kernel());
}
- for spirv_arg in fn_decl.input.iter_mut() {
- match type_to_variable_type(&spirv_arg.v_type, is_func)? {
- Some(var_type) => {
- let typ = spirv_arg.v_type.clone();
- let new_id = id_def.new_non_variable(Some(typ.clone()));
- result.push(Statement::Variable(ast::Variable {
- align: spirv_arg.align,
- v_type: var_type,
- name: spirv_arg.name,
- array_init: spirv_arg.array_init.clone(),
- }));
- result.push(Statement::StoreVar(StoreVarDetails {
- arg: ast::Arg2St {
- src1: spirv_arg.name,
- src2: new_id,
- },
- typ,
- member_index: None,
- }));
- spirv_arg.name = new_id;
- }
- None => {}
- }
+ for arg in fn_decl.return_arguments.iter() {
+ insert_mem_ssa_argument_reg_return(&mut result, arg);
}
for s in func {
match s {
@@ -2162,32 +2012,41 @@ fn insert_mem_ssa_statements<'a, 'b>( Statement::Instruction(inst) => match inst {
ast::Instruction::Ret(d) => {
// TODO: handle multiple output args
- if let &[out_param] = &fn_decl.output.as_slice() {
- let (typ, _) = id_def.get_typed(out_param.name)?;
- let new_id = id_def.new_non_variable(Some(typ.clone()));
- result.push(Statement::LoadVar(LoadVarDetails {
- arg: ast::Arg2 {
- dst: new_id,
- src: out_param.name,
- },
- typ: typ.clone(),
- member_index: None,
- }));
- result.push(Statement::RetValue(d, new_id));
- } else {
- result.push(Statement::Instruction(ast::Instruction::Ret(d)))
+ match &fn_decl.return_arguments[..] {
+ [return_reg] => {
+ let new_id = id_def.register_intermediate(Some((
+ return_reg.v_type.clone(),
+ ast::StateSpace::Reg,
+ )));
+ result.push(Statement::LoadVar(LoadVarDetails {
+ arg: ast::Arg2 {
+ dst: new_id,
+ src: return_reg.name,
+ },
+ // TODO: ret with stateful conversion
+ state_space: ast::StateSpace::Reg,
+ typ: return_reg.v_type.clone(),
+ member_index: None,
+ }));
+ result.push(Statement::RetValue(d, new_id));
+ }
+ [] => result.push(Statement::Instruction(ast::Instruction::Ret(d))),
+ _ => unimplemented!(),
}
}
inst => insert_mem_ssa_statement_default(id_def, &mut result, inst)?,
},
Statement::Conditional(mut bra) => {
- let generated_id =
- id_def.new_non_variable(Some(ast::Type::Scalar(ast::ScalarType::Pred)));
+ let generated_id = id_def.register_intermediate(Some((
+ ast::Type::Scalar(ast::ScalarType::Pred),
+ ast::StateSpace::Reg,
+ )));
result.push(Statement::LoadVar(LoadVarDetails {
arg: Arg2 {
dst: generated_id,
src: bra.predicate,
},
+ state_space: ast::StateSpace::Reg,
typ: ast::Type::Scalar(ast::ScalarType::Pred),
member_index: None,
}));
@@ -2210,39 +2069,45 @@ fn insert_mem_ssa_statements<'a, 'b>( Ok(result)
}
-fn type_to_variable_type(
- t: &ast::Type,
- is_func: bool,
-) -> Result<Option<ast::VariableType>, TranslateError> {
- Ok(match t {
- ast::Type::Scalar(typ) => Some(ast::VariableType::Reg(ast::VariableRegType::Scalar(*typ))),
- ast::Type::Vector(typ, len) => Some(ast::VariableType::Reg(ast::VariableRegType::Vector(
- (*typ)
- .try_into()
- .map_err(|_| TranslateError::MismatchedType)?,
- *len,
- ))),
- ast::Type::Array(typ, len) => Some(ast::VariableType::Reg(ast::VariableRegType::Array(
- (*typ)
- .try_into()
- .map_err(|_| TranslateError::MismatchedType)?,
- len.clone(),
- ))),
- ast::Type::Pointer(ast::PointerType::Scalar(scalar_type), space) => {
- if is_func {
- return Ok(None);
- }
- Some(ast::VariableType::Reg(ast::VariableRegType::Pointer(
- scalar_type
- .clone()
- .try_into()
- .map_err(|_| error_unreachable())?,
- (*space).try_into().map_err(|_| error_unreachable())?,
- )))
- }
- ast::Type::Pointer(_, ast::LdStateSpace::Shared) => None,
- _ => return Err(error_unreachable()),
- })
+fn insert_mem_ssa_argument(
+ id_def: &mut NumericIdResolver,
+ func: &mut Vec<TypedStatement>,
+ arg: &mut ast::Variable<spirv::Word>,
+ is_kernel: bool,
+) {
+ if !is_kernel && arg.state_space == ast::StateSpace::Param {
+ return;
+ }
+ let new_id = id_def.register_intermediate(Some((arg.v_type.clone(), arg.state_space)));
+ func.push(Statement::Variable(ast::Variable {
+ align: arg.align,
+ v_type: arg.v_type.clone(),
+ state_space: ast::StateSpace::Reg,
+ name: arg.name,
+ array_init: Vec::new(),
+ }));
+ func.push(Statement::StoreVar(StoreVarDetails {
+ arg: ast::Arg2St {
+ src1: arg.name,
+ src2: new_id,
+ },
+ typ: arg.v_type.clone(),
+ member_index: None,
+ }));
+ arg.name = new_id;
+}
+
+fn insert_mem_ssa_argument_reg_return(
+ func: &mut Vec<TypedStatement>,
+ arg: &ast::Variable<spirv::Word>,
+) {
+ func.push(Statement::Variable(ast::Variable {
+ align: arg.align,
+ v_type: arg.v_type.clone(),
+ state_space: arg.state_space,
+ name: arg.name,
+ array_init: arg.array_init.clone(),
+ }));
}
trait Visitable<From: ArgParamsEx, To: ArgParamsEx>: Sized {
@@ -2259,6 +2124,7 @@ struct VisitArgumentDescriptor< > {
desc: ArgumentDescriptor<spirv::Word>,
typ: &'a ast::Type,
+ state_space: ast::StateSpace,
stmt_ctor: Ctor,
}
@@ -2273,7 +2139,9 @@ impl< self,
visitor: &mut impl ArgumentMapVisitor<T, U>,
) -> Result<Statement<ast::Instruction<U>, U>, TranslateError> {
- Ok((self.stmt_ctor)(visitor.id(self.desc, Some(self.typ))?))
+ Ok((self.stmt_ctor)(
+ visitor.id(self.desc, Some((self.typ, self.state_space)))?,
+ ))
}
}
@@ -2287,14 +2155,14 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { fn symbol(
&mut self,
desc: ArgumentDescriptor<(spirv::Word, Option<u8>)>,
- expected_type: Option<&ast::Type>,
+ expected: Option<(&ast::Type, ast::StateSpace)>,
) -> Result<spirv::Word, TranslateError> {
let symbol = desc.op.0;
- if expected_type.is_none() {
+ if expected.is_none() {
return Ok(symbol);
};
- let (mut var_type, is_variable) = self.id_def.get_typed(symbol)?;
- if !is_variable {
+ let (mut var_type, var_space, is_variable) = self.id_def.get_typed(symbol)?;
+ if !var_space.is_compatible(ast::StateSpace::Reg) || !is_variable {
return Ok(symbol);
};
let member_index = match desc.op.1 {
@@ -2317,13 +2185,16 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { }
None => None,
};
- let generated_id = self.id_def.new_non_variable(Some(var_type.clone()));
+ let generated_id = self
+ .id_def
+ .register_intermediate(Some((var_type.clone(), ast::StateSpace::Reg)));
if !desc.is_dst {
self.func.push(Statement::LoadVar(LoadVarDetails {
arg: Arg2 {
dst: generated_id,
src: symbol,
},
+ state_space: ast::StateSpace::Reg,
typ: var_type,
member_index,
}));
@@ -2348,7 +2219,7 @@ impl<'a, 'input> ArgumentMapVisitor<TypedArgParams, TypedArgParams> fn id(
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
- typ: Option<&ast::Type>,
+ typ: Option<(&ast::Type, ast::StateSpace)>,
) -> Result<spirv::Word, TranslateError> {
self.symbol(desc.new_op((desc.op, None)), typ)
}
@@ -2357,18 +2228,20 @@ impl<'a, 'input> ArgumentMapVisitor<TypedArgParams, TypedArgParams> &mut self,
desc: ArgumentDescriptor<TypedOperand>,
typ: &ast::Type,
+ state_space: ast::StateSpace,
) -> Result<TypedOperand, TranslateError> {
Ok(match desc.op {
TypedOperand::Reg(reg) => {
- TypedOperand::Reg(self.symbol(desc.new_op((reg, None)), Some(typ))?)
- }
- TypedOperand::RegOffset(reg, offset) => {
- TypedOperand::RegOffset(self.symbol(desc.new_op((reg, None)), Some(typ))?, offset)
+ TypedOperand::Reg(self.symbol(desc.new_op((reg, None)), Some((typ, state_space)))?)
}
+ TypedOperand::RegOffset(reg, offset) => TypedOperand::RegOffset(
+ self.symbol(desc.new_op((reg, None)), Some((typ, state_space)))?,
+ offset,
+ ),
op @ TypedOperand::Imm(..) => op,
- TypedOperand::VecMember(symbol, index) => {
- TypedOperand::Reg(self.symbol(desc.new_op((symbol, Some(index))), Some(typ))?)
- }
+ TypedOperand::VecMember(symbol, index) => TypedOperand::Reg(
+ self.symbol(desc.new_op((symbol, Some(index))), Some((typ, state_space)))?,
+ ),
})
}
}
@@ -2411,11 +2284,13 @@ fn expand_arguments<'a, 'b>( Statement::Variable(ast::Variable {
align,
v_type,
+ state_space,
name,
array_init,
}) => result.push(Statement::Variable(ast::Variable {
align,
v_type,
+ state_space,
name,
array_init,
})),
@@ -2464,7 +2339,7 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { fn reg(
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
- _: Option<&ast::Type>,
+ _: Option<(&ast::Type, ast::StateSpace)>,
) -> Result<spirv::Word, TranslateError> {
Ok(desc.op)
}
@@ -2473,108 +2348,86 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { &mut self,
desc: ArgumentDescriptor<(spirv::Word, i32)>,
typ: &ast::Type,
+ state_space: ast::StateSpace,
) -> Result<spirv::Word, TranslateError> {
let (reg, offset) = desc.op;
- let add_type;
- match typ {
- ast::Type::Pointer(underlying_type, state_space) => {
- let reg_typ = self.id_def.get_typed(reg)?;
- if let ast::Type::Pointer(_, _) = reg_typ {
- let id_constant_stmt = self.id_def.new_non_variable(typ.clone());
- self.func.push(Statement::Constant(ConstantDefinition {
- dst: id_constant_stmt,
- typ: ast::ScalarType::S64,
- value: ast::ImmediateValue::S64(offset as i64),
- }));
- let dst = self.id_def.new_non_variable(typ.clone());
- self.func.push(Statement::PtrAccess(PtrAccess {
- underlying_type: underlying_type.clone(),
- state_space: *state_space,
- dst,
- ptr_src: reg,
- offset_src: id_constant_stmt,
- }));
- return Ok(dst);
- } else {
- add_type = self.id_def.get_typed(reg)?;
- }
- }
- _ => {
- add_type = typ.clone();
+ if !desc.is_memory_access {
+ let (reg_type, reg_space) = self.id_def.get_typed(reg)?;
+ if !reg_space.is_compatible(ast::StateSpace::Reg) {
+ return Err(TranslateError::MismatchedType);
}
- };
- let (width, kind) = match add_type {
- ast::Type::Scalar(scalar_t) => {
- let kind = match scalar_t.kind() {
- kind @ ScalarKind::Bit
- | kind @ ScalarKind::Unsigned
- | kind @ ScalarKind::Signed => kind,
- ScalarKind::Float => return Err(TranslateError::MismatchedType),
- ScalarKind::Float2 => return Err(TranslateError::MismatchedType),
- ScalarKind::Pred => return Err(TranslateError::MismatchedType),
- };
- (scalar_t.size_of(), kind)
- }
- _ => return Err(TranslateError::MismatchedType),
- };
- let arith_detail = if kind == ScalarKind::Signed {
- ast::ArithDetails::Signed(ast::ArithSInt {
- typ: ast::SIntType::from_size(width),
- saturate: false,
- })
- } else {
- ast::ArithDetails::Unsigned(ast::UIntType::from_size(width))
- };
- let id_constant_stmt = self.id_def.new_non_variable(add_type.clone());
- let result_id = self.id_def.new_non_variable(add_type);
- // TODO: check for edge cases around min value/max value/wrapping
- if offset < 0 && kind != ScalarKind::Signed {
+ let reg_scalar_type = match reg_type {
+ ast::Type::Scalar(underlying_type) => underlying_type,
+ _ => return Err(TranslateError::MismatchedType),
+ };
+ let id_constant_stmt = self
+ .id_def
+ .register_intermediate(reg_type.clone(), ast::StateSpace::Reg);
self.func.push(Statement::Constant(ConstantDefinition {
dst: id_constant_stmt,
- typ: ast::ScalarType::from_parts(width, kind),
- value: ast::ImmediateValue::U64(-(offset as i64) as u64),
+ typ: reg_scalar_type,
+ value: ast::ImmediateValue::S64(offset as i64),
}));
- self.func.push(Statement::Instruction(
- ast::Instruction::<ExpandedArgParams>::Sub(
- arith_detail,
- ast::Arg3 {
- dst: result_id,
- src1: reg,
- src2: id_constant_stmt,
- },
- ),
- ));
+ let arith_details = match reg_scalar_type.kind() {
+ ast::ScalarKind::Signed => ast::ArithDetails::Signed(ast::ArithSInt {
+ typ: reg_scalar_type,
+ saturate: false,
+ }),
+ ast::ScalarKind::Unsigned | ast::ScalarKind::Bit => {
+ ast::ArithDetails::Unsigned(reg_scalar_type)
+ }
+ _ => return Err(error_unreachable()),
+ };
+ let id_add_result = self.id_def.register_intermediate(reg_type, state_space);
+ self.func.push(Statement::Instruction(ast::Instruction::Add(
+ arith_details,
+ ast::Arg3 {
+ dst: id_add_result,
+ src1: reg,
+ src2: id_constant_stmt,
+ },
+ )));
+ Ok(id_add_result)
} else {
+ let scalar_type = match typ {
+ ast::Type::Scalar(underlying_type) => *underlying_type,
+ _ => return Err(error_unreachable()),
+ };
+ let id_constant_stmt = self.id_def.register_intermediate(
+ ast::Type::Scalar(ast::ScalarType::S64),
+ ast::StateSpace::Reg,
+ );
self.func.push(Statement::Constant(ConstantDefinition {
dst: id_constant_stmt,
- typ: ast::ScalarType::from_parts(width, kind),
+ typ: ast::ScalarType::S64,
value: ast::ImmediateValue::S64(offset as i64),
}));
- self.func.push(Statement::Instruction(
- ast::Instruction::<ExpandedArgParams>::Add(
- arith_detail,
- ast::Arg3 {
- dst: result_id,
- src1: reg,
- src2: id_constant_stmt,
- },
- ),
- ));
+ let dst = self.id_def.register_intermediate(typ.clone(), state_space);
+ self.func.push(Statement::PtrAccess(PtrAccess {
+ underlying_type: scalar_type,
+ state_space: state_space,
+ dst,
+ ptr_src: reg,
+ offset_src: id_constant_stmt,
+ }));
+ Ok(dst)
}
- Ok(result_id)
}
fn immediate(
&mut self,
desc: ArgumentDescriptor<ast::ImmediateValue>,
typ: &ast::Type,
+ state_space: ast::StateSpace,
) -> Result<spirv::Word, TranslateError> {
let scalar_t = if let ast::Type::Scalar(scalar) = typ {
*scalar
} else {
todo!()
};
- let id = self.id_def.new_non_variable(ast::Type::Scalar(scalar_t));
+ let id = self
+ .id_def
+ .register_intermediate(ast::Type::Scalar(scalar_t), state_space);
self.func.push(Statement::Constant(ConstantDefinition {
dst: id,
typ: scalar_t,
@@ -2588,7 +2441,7 @@ impl<'a, 'b> ArgumentMapVisitor<TypedArgParams, ExpandedArgParams> for FlattenAr fn id(
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
- t: Option<&ast::Type>,
+ t: Option<(&ast::Type, ast::StateSpace)>,
) -> Result<spirv::Word, TranslateError> {
self.reg(desc, t)
}
@@ -2597,12 +2450,13 @@ impl<'a, 'b> ArgumentMapVisitor<TypedArgParams, ExpandedArgParams> for FlattenAr &mut self,
desc: ArgumentDescriptor<TypedOperand>,
typ: &ast::Type,
+ state_space: ast::StateSpace,
) -> Result<spirv::Word, TranslateError> {
match desc.op {
- TypedOperand::Reg(r) => self.reg(desc.new_op(r), Some(typ)),
- TypedOperand::Imm(x) => self.immediate(desc.new_op(x), typ),
+ TypedOperand::Reg(r) => self.reg(desc.new_op(r), Some((typ, state_space))),
+ TypedOperand::Imm(x) => self.immediate(desc.new_op(x), typ, state_space),
TypedOperand::RegOffset(reg, offset) => {
- self.reg_offset(desc.new_op((reg, offset)), typ)
+ self.reg_offset(desc.new_op((reg, offset)), typ, state_space)
}
TypedOperand::VecMember(..) => Err(error_unreachable()),
}
@@ -2630,79 +2484,18 @@ fn insert_implicit_conversions( let mut result = Vec::with_capacity(func.len());
for s in func.into_iter() {
match s {
- Statement::Call(call) => insert_implicit_conversions_impl(
- &mut result,
- id_def,
- call,
- should_bitcast_wrapper,
- None,
- )?,
+ Statement::Call(call) => {
+ insert_implicit_conversions_impl(&mut result, id_def, call)?;
+ }
Statement::Instruction(inst) => {
- let mut default_conversion_fn =
- should_bitcast_wrapper as for<'a> fn(&'a ast::Type, &'a ast::Type, _) -> _;
- let mut state_space = None;
- if let ast::Instruction::Ld(d, _) = &inst {
- state_space = Some(d.state_space);
- }
- if let ast::Instruction::St(d, _) = &inst {
- state_space = Some(d.state_space.to_ld_ss());
- }
- if let ast::Instruction::Atom(d, _) = &inst {
- state_space = Some(d.space.to_ld_ss());
- }
- if let ast::Instruction::AtomCas(d, _) = &inst {
- state_space = Some(d.space.to_ld_ss());
- }
- if let ast::Instruction::Mov(..) = &inst {
- default_conversion_fn = should_bitcast_packed;
- }
- insert_implicit_conversions_impl(
- &mut result,
- id_def,
- inst,
- default_conversion_fn,
- state_space,
- )?;
+ insert_implicit_conversions_impl(&mut result, id_def, inst)?;
}
- Statement::PtrAccess(PtrAccess {
- underlying_type,
- state_space,
- dst,
- ptr_src,
- offset_src: constant_src,
- }) => {
- let visit_desc = VisitArgumentDescriptor {
- desc: ArgumentDescriptor {
- op: ptr_src,
- is_dst: false,
- sema: ArgumentSemantics::PhysicalPointer,
- },
- typ: &ast::Type::Pointer(underlying_type.clone(), state_space),
- stmt_ctor: |new_ptr_src| {
- Statement::PtrAccess(PtrAccess {
- underlying_type,
- state_space,
- dst,
- ptr_src: new_ptr_src,
- offset_src: constant_src,
- })
- },
- };
- insert_implicit_conversions_impl(
- &mut result,
- id_def,
- visit_desc,
- bitcast_physical_pointer,
- Some(state_space),
- )?;
+ Statement::PtrAccess(access) => {
+ insert_implicit_conversions_impl(&mut result, id_def, access)?;
+ }
+ Statement::RepackVector(repack) => {
+ insert_implicit_conversions_impl(&mut result, id_def, repack)?;
}
- Statement::RepackVector(repack) => insert_implicit_conversions_impl(
- &mut result,
- id_def,
- repack,
- should_bitcast_wrapper,
- None,
- )?,
s @ Statement::Conditional(_)
| s @ Statement::Conversion(_)
| s @ Statement::Label(_)
@@ -2720,72 +2513,56 @@ fn insert_implicit_conversions_impl( func: &mut Vec<ExpandedStatement>,
id_def: &mut MutableNumericIdResolver,
stmt: impl Visitable<ExpandedArgParams, ExpandedArgParams>,
- default_conversion_fn: for<'a> fn(
- &'a ast::Type,
- &'a ast::Type,
- Option<ast::LdStateSpace>,
- ) -> Result<Option<ConversionKind>, TranslateError>,
- state_space: Option<ast::LdStateSpace>,
) -> Result<(), TranslateError> {
let mut post_conv = Vec::new();
- let statement = stmt.visit(
- &mut |desc: ArgumentDescriptor<spirv::Word>, typ: Option<&ast::Type>| {
- let instr_type = match typ {
+ let statement =
+ stmt.visit(&mut |desc: ArgumentDescriptor<spirv::Word>,
+ typ: Option<(&ast::Type, ast::StateSpace)>| {
+ let (instr_type, instruction_space) = match typ {
None => return Ok(desc.op),
Some(t) => t,
};
- let operand_type = id_def.get_typed(desc.op)?;
- let mut conversion_fn = default_conversion_fn;
- match desc.sema {
- ArgumentSemantics::Default => {}
- ArgumentSemantics::DefaultRelaxed => {
- if desc.is_dst {
- conversion_fn = should_convert_relaxed_dst_wrapper;
- } else {
- conversion_fn = should_convert_relaxed_src_wrapper;
- }
- }
- ArgumentSemantics::PhysicalPointer => {
- conversion_fn = bitcast_physical_pointer;
- }
- ArgumentSemantics::RegisterPointer => {
- conversion_fn = bitcast_register_pointer;
- }
- ArgumentSemantics::Address => {
- conversion_fn = force_bitcast_ptr_to_bit;
- }
- };
- match conversion_fn(&operand_type, instr_type, state_space)? {
+ let (operand_type, operand_space) = id_def.get_typed(desc.op)?;
+ let conversion_fn = desc
+ .non_default_implicit_conversion
+ .unwrap_or(default_implicit_conversion);
+ match conversion_fn(
+ (operand_space, &operand_type),
+ (instruction_space, instr_type),
+ )? {
Some(conv_kind) => {
let conv_output = if desc.is_dst {
&mut post_conv
} else {
&mut *func
};
- let mut from = instr_type.clone();
- let mut to = operand_type;
- let mut src = id_def.new_non_variable(instr_type.clone());
+ let mut from_type = instr_type.clone();
+ let mut from_space = instruction_space;
+ let mut to_type = operand_type;
+ let mut to_space = operand_space;
+ let mut src =
+ id_def.register_intermediate(instr_type.clone(), instruction_space);
let mut dst = desc.op;
let result = Ok(src);
if !desc.is_dst {
mem::swap(&mut src, &mut dst);
- mem::swap(&mut from, &mut to);
+ mem::swap(&mut from_type, &mut to_type);
+ mem::swap(&mut from_space, &mut to_space);
}
conv_output.push(Statement::Conversion(ImplicitConversion {
src,
dst,
- from,
- to,
+ from_type,
+ from_space,
+ to_type,
+ to_space,
kind: conv_kind,
- src_sema: ArgumentSemantics::Default,
- dst_sema: ArgumentSemantics::Default,
}));
result
}
None => Ok(desc.op),
}
- },
- )?;
+ })?;
func.push(statement);
func.append(&mut post_conv);
Ok(())
@@ -2794,17 +2571,15 @@ fn insert_implicit_conversions_impl( fn get_function_type(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
- spirv_input: &[ast::Variable<ast::Type, spirv::Word>],
- spirv_output: &[ast::Variable<ast::Type, spirv::Word>],
+ spirv_input: impl Iterator<Item = SpirvType>,
+ spirv_output: &[ast::Variable<spirv::Word>],
) -> (spirv::Word, spirv::Word) {
map.get_or_add_fn(
builder,
- spirv_input
- .iter()
- .map(|var| SpirvType::from(var.v_type.clone())),
+ spirv_input,
spirv_output
.iter()
- .map(|var| SpirvType::from(var.v_type.clone())),
+ .map(|var| SpirvType::new(var.v_type.clone())),
)
}
@@ -2831,20 +2606,25 @@ fn emit_function_body_ops( match s {
Statement::Label(_) => (),
Statement::Call(call) => {
- let (result_type, result_id) = match &*call.ret_params {
- [(id, typ)] => (
- map.get_or_add(builder, SpirvType::from(typ.to_func_type())),
- Some(*id),
- ),
+ let (result_type, result_id) = match &*call.return_arguments {
+ [(id, typ, space)] => {
+ if *space != ast::StateSpace::Reg {
+ return Err(error_unreachable());
+ }
+ (
+ map.get_or_add(builder, SpirvType::new(typ.clone())),
+ Some(*id),
+ )
+ }
[] => (map.void(), None),
_ => todo!(),
};
let arg_list = call
- .param_list
+ .input_arguments
.iter()
- .map(|(id, _)| *id)
+ .map(|(id, _, _)| *id)
.collect::<Vec<_>>();
- builder.function_call(result_type, result_id, call.func, arg_list)?;
+ builder.function_call(result_type, result_id, call.name, arg_list)?;
}
Statement::Variable(var) => {
emit_variable(builder, map, var)?;
@@ -2966,7 +2746,7 @@ fn emit_function_body_ops( todo!()
}
let result_type =
- map.get_or_add(builder, SpirvType::from(ast::Type::from(data.typ.clone())));
+ map.get_or_add(builder, SpirvType::new(ast::Type::from(data.typ.clone())));
builder.load(
result_type,
Some(arg.dst),
@@ -2998,7 +2778,7 @@ fn emit_function_body_ops( ast::Instruction::Ret(_) => builder.ret()?,
ast::Instruction::Mov(d, arg) => {
let result_type =
- map.get_or_add(builder, SpirvType::from(ast::Type::from(d.typ.clone())));
+ map.get_or_add(builder, SpirvType::new(ast::Type::from(d.typ.clone())));
builder.copy_object(result_type, Some(arg.dst), arg.src)?;
}
ast::Instruction::Mul(mul, arg) => match mul {
@@ -3026,20 +2806,20 @@ fn emit_function_body_ops( emit_setp(builder, map, setp, arg)?;
}
ast::Instruction::Not(t, a) => {
- let result_type = map.get_or_add(builder, SpirvType::from(t.to_type()));
+ let result_type = map.get_or_add(builder, SpirvType::from(*t));
let result_id = Some(a.dst);
let operand = a.src;
match t {
- ast::BooleanType::Pred => {
+ ast::ScalarType::Pred => {
logical_not(builder, result_type, result_id, operand)
}
_ => builder.not(result_type, result_id, operand),
}?;
}
ast::Instruction::Shl(t, a) => {
- let full_type = t.to_type();
+ let full_type = ast::Type::Scalar(*t);
let size_of = full_type.size_of();
- let result_type = map.get_or_add(builder, SpirvType::from(full_type));
+ let result_type = map.get_or_add(builder, SpirvType::new(full_type));
let offset_src = insert_shift_hack(builder, map, a.src2, size_of)?;
builder.shift_left_logical(result_type, Some(a.dst), a.src1, offset_src)?;
}
@@ -3048,7 +2828,7 @@ fn emit_function_body_ops( let size_of = full_type.size_of();
let result_type = map.get_or_add_scalar(builder, full_type);
let offset_src = insert_shift_hack(builder, map, a.src2, size_of as usize)?;
- if t.signed() {
+ if t.kind() == ast::ScalarKind::Signed {
builder.shift_right_arithmetic(
result_type,
Some(a.dst),
@@ -3088,7 +2868,7 @@ fn emit_function_body_ops( },
ast::Instruction::Or(t, a) => {
let result_type = map.get_or_add_scalar(builder, ast::ScalarType::from(*t));
- if *t == ast::BooleanType::Pred {
+ if *t == ast::ScalarType::Pred {
builder.logical_or(result_type, Some(a.dst), a.src1, a.src2)?;
} else {
builder.bitwise_or(result_type, Some(a.dst), a.src1, a.src2)?;
@@ -3116,7 +2896,7 @@ fn emit_function_body_ops( }
ast::Instruction::And(t, a) => {
let result_type = map.get_or_add_scalar(builder, ast::ScalarType::from(*t));
- if *t == ast::BooleanType::Pred {
+ if *t == ast::ScalarType::Pred {
builder.logical_and(result_type, Some(a.dst), a.src1, a.src2)?;
} else {
builder.bitwise_and(result_type, Some(a.dst), a.src1, a.src2)?;
@@ -3202,7 +2982,7 @@ fn emit_function_body_ops( }
ast::Instruction::Neg(details, arg) => {
let result_type = map.get_or_add_scalar(builder, details.typ);
- let negate_func = if details.typ.kind() == ScalarKind::Float {
+ let negate_func = if details.typ.kind() == ast::ScalarKind::Float {
dr::Builder::f_negate
} else {
dr::Builder::s_negate
@@ -3269,7 +3049,7 @@ fn emit_function_body_ops( }
ast::Instruction::Xor { typ, arg } => {
let builder_fn = match typ {
- ast::BooleanType::Pred => emit_logical_xor_spirv,
+ ast::ScalarType::Pred => emit_logical_xor_spirv,
_ => dr::Builder::bitwise_xor,
};
let result_type = map.get_or_add_scalar(builder, (*typ).into());
@@ -3284,7 +3064,7 @@ fn emit_function_body_ops( return Err(error_unreachable());
}
ast::Instruction::Rem { typ, arg } => {
- let builder_fn = if typ.is_signed() {
+ let builder_fn = if typ.kind() == ast::ScalarKind::Signed {
dr::Builder::s_mod
} else {
dr::Builder::u_mod
@@ -3301,7 +3081,7 @@ fn emit_function_body_ops( Some(index) => {
let result_ptr_type = map.get_or_add(
builder,
- SpirvType::new_pointer(
+ SpirvType::pointer_to(
details.typ.clone(),
spirv::StorageClass::Function,
),
@@ -3334,14 +3114,11 @@ fn emit_function_body_ops( }) => {
let u8_pointer = map.get_or_add(
builder,
- SpirvType::from(ast::Type::Pointer(
- ast::PointerType::Scalar(ast::ScalarType::U8),
- *state_space,
- )),
+ SpirvType::new(ast::Type::Pointer(ast::ScalarType::U8, *state_space)),
);
let result_type = map.get_or_add(
builder,
- SpirvType::from(ast::Type::Pointer(underlying_type.clone(), *state_space)),
+ SpirvType::new(ast::Type::Pointer(*underlying_type, *state_space)),
);
let ptr_src_u8 = builder.bitcast(u8_pointer, None, *ptr_src)?;
let temp = builder.in_bounds_ptr_access_chain(
@@ -3553,11 +3330,16 @@ fn ptx_scope_name(scope: ast::MemScope) -> &'static str { }
}
-fn ptx_space_name(space: ast::AtomSpace) -> &'static str {
+fn ptx_space_name(space: ast::StateSpace) -> &'static str {
match space {
- ast::AtomSpace::Generic => "generic",
- ast::AtomSpace::Global => "global",
- ast::AtomSpace::Shared => "shared",
+ ast::StateSpace::Generic => "generic",
+ ast::StateSpace::Global => "global",
+ ast::StateSpace::Shared => "shared",
+ ast::StateSpace::Reg => "reg",
+ ast::StateSpace::Const => "const",
+ ast::StateSpace::Local => "local",
+ ast::StateSpace::Param => "param",
+ ast::StateSpace::Sreg => "sreg",
}
}
@@ -3612,14 +3394,17 @@ fn vec_repr<T: Copy>(t: T) -> Vec<u8> { fn emit_variable(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
- var: &ast::Variable<ast::VariableType, spirv::Word>,
+ var: &ast::Variable<spirv::Word>,
) -> Result<(), TranslateError> {
- let (must_init, st_class) = match var.v_type {
- ast::VariableType::Reg(_) | ast::VariableType::Param(_) | ast::VariableType::Local(_) => {
+ let (must_init, st_class) = match var.state_space {
+ ast::StateSpace::Reg | ast::StateSpace::Param | ast::StateSpace::Local => {
(false, spirv::StorageClass::Function)
}
- ast::VariableType::Global(_) => (true, spirv::StorageClass::CrossWorkgroup),
- ast::VariableType::Shared(_) => (false, spirv::StorageClass::Workgroup),
+ ast::StateSpace::Global => (true, spirv::StorageClass::CrossWorkgroup),
+ ast::StateSpace::Shared => (false, spirv::StorageClass::Workgroup),
+ ast::StateSpace::Const => todo!(),
+ ast::StateSpace::Generic => todo!(),
+ ast::StateSpace::Sreg => todo!(),
};
let initalizer = if var.array_init.len() > 0 {
Some(map.get_or_add_constant(
@@ -3628,18 +3413,12 @@ fn emit_variable( &*var.array_init,
)?)
} else if must_init {
- let type_id = map.get_or_add(
- builder,
- SpirvType::from(ast::Type::from(var.v_type.clone())),
- );
+ let type_id = map.get_or_add(builder, SpirvType::new(var.v_type.clone()));
Some(builder.constant_null(type_id, None))
} else {
None
};
- let ptr_type_id = map.get_or_add(
- builder,
- SpirvType::new_pointer(ast::Type::from(var.v_type.clone()), st_class),
- );
+ let ptr_type_id = map.get_or_add(builder, SpirvType::pointer_to(var.v_type.clone(), st_class));
builder.variable(ptr_type_id, Some(var.name), st_class, initalizer);
if let Some(align) = var.align {
builder.decorate(
@@ -3777,7 +3556,7 @@ fn emit_min( ast::MinMaxDetails::Unsigned(_) => spirv::CLOp::u_min,
ast::MinMaxDetails::Float(_) => spirv::CLOp::fmin,
};
- let inst_type = map.get_or_add(builder, SpirvType::from(desc.get_type()));
+ let inst_type = map.get_or_add(builder, SpirvType::new(desc.get_type()));
builder.ext_inst(
inst_type,
Some(arg.dst),
@@ -3802,7 +3581,7 @@ fn emit_max( ast::MinMaxDetails::Unsigned(_) => spirv::CLOp::u_max,
ast::MinMaxDetails::Float(_) => spirv::CLOp::fmax,
};
- let inst_type = map.get_or_add(builder, SpirvType::from(desc.get_type()));
+ let inst_type = map.get_or_add(builder, SpirvType::new(desc.get_type()));
builder.ext_inst(
inst_type,
Some(arg.dst),
@@ -3882,7 +3661,7 @@ fn emit_cvt( }
let dest_t: ast::ScalarType = desc.dst.into();
let result_type = map.get_or_add(builder, SpirvType::from(dest_t));
- if desc.src.is_signed() {
+ if desc.src.kind() == ast::ScalarKind::Signed {
builder.convert_s_to_f(result_type, Some(arg.dst), arg.src)?;
} else {
builder.convert_u_to_f(result_type, Some(arg.dst), arg.src)?;
@@ -3892,7 +3671,7 @@ fn emit_cvt( ast::CvtDetails::IntFromFloat(desc) => {
let dest_t: ast::ScalarType = desc.dst.into();
let result_type = map.get_or_add(builder, SpirvType::from(dest_t));
- if desc.dst.is_signed() {
+ if desc.dst.kind() == ast::ScalarKind::Signed {
builder.convert_f_to_s(result_type, Some(arg.dst), arg.src)?;
} else {
builder.convert_f_to_u(result_type, Some(arg.dst), arg.src)?;
@@ -3904,7 +3683,7 @@ fn emit_cvt( let dest_t: ast::ScalarType = desc.dst.into();
let src_t: ast::ScalarType = desc.src.into();
// first do shortening/widening
- let src = if desc.dst.width() != desc.src.width() {
+ let src = if desc.dst.size_of() != desc.src.size_of() {
let new_dst = if dest_t.kind() == src_t.kind() {
arg.dst
} else {
@@ -3913,14 +3692,14 @@ fn emit_cvt( let cv = ImplicitConversion {
src: arg.src,
dst: new_dst,
- from: ast::Type::Scalar(src_t),
- to: ast::Type::Scalar(ast::ScalarType::from_parts(
+ from_type: ast::Type::Scalar(src_t),
+ from_space: ast::StateSpace::Reg,
+ to_type: ast::Type::Scalar(ast::ScalarType::from_parts(
dest_t.size_of(),
src_t.kind(),
)),
+ to_space: ast::StateSpace::Reg,
kind: ConversionKind::Default,
- src_sema: ArgumentSemantics::Default,
- dst_sema: ArgumentSemantics::Default,
};
emit_implicit_conversion(builder, map, &cv)?;
new_dst
@@ -3933,7 +3712,7 @@ fn emit_cvt( // now do actual conversion
let result_type = map.get_or_add(builder, SpirvType::from(dest_t));
if desc.saturate {
- if desc.dst.is_signed() {
+ if desc.dst.kind() == ast::ScalarKind::Signed {
builder.sat_convert_u_to_s(result_type, Some(arg.dst), src)?;
} else {
builder.sat_convert_s_to_u(result_type, Some(arg.dst), src)?;
@@ -3989,60 +3768,60 @@ fn emit_setp( let operand_1 = arg.src1;
let operand_2 = arg.src2;
match (setp.cmp_op, setp.typ.kind()) {
- (ast::SetpCompareOp::Eq, ScalarKind::Signed)
- | (ast::SetpCompareOp::Eq, ScalarKind::Unsigned)
- | (ast::SetpCompareOp::Eq, ScalarKind::Bit) => {
+ (ast::SetpCompareOp::Eq, ast::ScalarKind::Signed)
+ | (ast::SetpCompareOp::Eq, ast::ScalarKind::Unsigned)
+ | (ast::SetpCompareOp::Eq, ast::ScalarKind::Bit) => {
builder.i_equal(result_type, result_id, operand_1, operand_2)
}
- (ast::SetpCompareOp::Eq, ScalarKind::Float) => {
+ (ast::SetpCompareOp::Eq, ast::ScalarKind::Float) => {
builder.f_ord_equal(result_type, result_id, operand_1, operand_2)
}
- (ast::SetpCompareOp::NotEq, ScalarKind::Signed)
- | (ast::SetpCompareOp::NotEq, ScalarKind::Unsigned)
- | (ast::SetpCompareOp::NotEq, ScalarKind::Bit) => {
+ (ast::SetpCompareOp::NotEq, ast::ScalarKind::Signed)
+ | (ast::SetpCompareOp::NotEq, ast::ScalarKind::Unsigned)
+ | (ast::SetpCompareOp::NotEq, ast::ScalarKind::Bit) => {
builder.i_not_equal(result_type, result_id, operand_1, operand_2)
}
- (ast::SetpCompareOp::NotEq, ScalarKind::Float) => {
+ (ast::SetpCompareOp::NotEq, ast::ScalarKind::Float) => {
builder.f_ord_not_equal(result_type, result_id, operand_1, operand_2)
}
- (ast::SetpCompareOp::Less, ScalarKind::Unsigned)
- | (ast::SetpCompareOp::Less, ScalarKind::Bit) => {
+ (ast::SetpCompareOp::Less, ast::ScalarKind::Unsigned)
+ | (ast::SetpCompareOp::Less, ast::ScalarKind::Bit) => {
builder.u_less_than(result_type, result_id, operand_1, operand_2)
}
- (ast::SetpCompareOp::Less, ScalarKind::Signed) => {
+ (ast::SetpCompareOp::Less, ast::ScalarKind::Signed) => {
builder.s_less_than(result_type, result_id, operand_1, operand_2)
}
- (ast::SetpCompareOp::Less, ScalarKind::Float) => {
+ (ast::SetpCompareOp::Less, ast::ScalarKind::Float) => {
builder.f_ord_less_than(result_type, result_id, operand_1, operand_2)
}
- (ast::SetpCompareOp::LessOrEq, ScalarKind::Unsigned)
- | (ast::SetpCompareOp::LessOrEq, ScalarKind::Bit) => {
+ (ast::SetpCompareOp::LessOrEq, ast::ScalarKind::Unsigned)
+ | (ast::SetpCompareOp::LessOrEq, ast::ScalarKind::Bit) => {
builder.u_less_than_equal(result_type, result_id, operand_1, operand_2)
}
- (ast::SetpCompareOp::LessOrEq, ScalarKind::Signed) => {
+ (ast::SetpCompareOp::LessOrEq, ast::ScalarKind::Signed) => {
builder.s_less_than_equal(result_type, result_id, operand_1, operand_2)
}
- (ast::SetpCompareOp::LessOrEq, ScalarKind::Float) => {
+ (ast::SetpCompareOp::LessOrEq, ast::ScalarKind::Float) => {
builder.f_ord_less_than_equal(result_type, result_id, operand_1, operand_2)
}
- (ast::SetpCompareOp::Greater, ScalarKind::Unsigned)
- | (ast::SetpCompareOp::Greater, ScalarKind::Bit) => {
+ (ast::SetpCompareOp::Greater, ast::ScalarKind::Unsigned)
+ | (ast::SetpCompareOp::Greater, ast::ScalarKind::Bit) => {
builder.u_greater_than(result_type, result_id, operand_1, operand_2)
}
- (ast::SetpCompareOp::Greater, ScalarKind::Signed) => {
+ (ast::SetpCompareOp::Greater, ast::ScalarKind::Signed) => {
builder.s_greater_than(result_type, result_id, operand_1, operand_2)
}
- (ast::SetpCompareOp::Greater, ScalarKind::Float) => {
+ (ast::SetpCompareOp::Greater, ast::ScalarKind::Float) => {
builder.f_ord_greater_than(result_type, result_id, operand_1, operand_2)
}
- (ast::SetpCompareOp::GreaterOrEq, ScalarKind::Unsigned)
- | (ast::SetpCompareOp::GreaterOrEq, ScalarKind::Bit) => {
+ (ast::SetpCompareOp::GreaterOrEq, ast::ScalarKind::Unsigned)
+ | (ast::SetpCompareOp::GreaterOrEq, ast::ScalarKind::Bit) => {
builder.u_greater_than_equal(result_type, result_id, operand_1, operand_2)
}
- (ast::SetpCompareOp::GreaterOrEq, ScalarKind::Signed) => {
+ (ast::SetpCompareOp::GreaterOrEq, ast::ScalarKind::Signed) => {
builder.s_greater_than_equal(result_type, result_id, operand_1, operand_2)
}
- (ast::SetpCompareOp::GreaterOrEq, ScalarKind::Float) => {
+ (ast::SetpCompareOp::GreaterOrEq, ast::ScalarKind::Float) => {
builder.f_ord_greater_than_equal(result_type, result_id, operand_1, operand_2)
}
(ast::SetpCompareOp::NanEq, _) => {
@@ -4222,7 +4001,7 @@ fn emit_abs( ) -> Result<(), dr::Error> {
let scalar_t = ast::ScalarType::from(d.typ);
let result_type = map.get_or_add(builder, SpirvType::from(scalar_t));
- let cl_abs = if scalar_t.kind() == ScalarKind::Signed {
+ let cl_abs = if scalar_t.kind() == ast::ScalarKind::Signed {
spirv::CLOp::s_abs
} else {
spirv::CLOp::fabs
@@ -4272,22 +4051,21 @@ fn emit_implicit_conversion( map: &mut TypeWordMap,
cv: &ImplicitConversion,
) -> Result<(), TranslateError> {
- let from_parts = cv.from.to_parts();
- let to_parts = cv.to.to_parts();
- match (from_parts.kind, to_parts.kind, cv.kind) {
- (_, _, ConversionKind::PtrToBit(typ)) => {
- let dst_type = map.get_or_add_scalar(builder, typ.into());
- builder.convert_ptr_to_u(dst_type, Some(cv.dst), cv.src)?;
- }
- (_, _, ConversionKind::BitToPtr(_)) => {
- let dst_type = map.get_or_add(builder, SpirvType::from(cv.to.clone()));
+ let from_parts = cv.from_type.to_parts();
+ let to_parts = cv.to_type.to_parts();
+ match (from_parts.kind, to_parts.kind, &cv.kind) {
+ (_, _, &ConversionKind::BitToPtr) => {
+ let dst_type = map.get_or_add(
+ builder,
+ SpirvType::pointer_to(cv.to_type.clone(), cv.to_space.to_spirv()),
+ );
builder.convert_u_to_ptr(dst_type, Some(cv.dst), cv.src)?;
}
- (TypeKind::Scalar, TypeKind::Scalar, ConversionKind::Default) => {
+ (TypeKind::Scalar, TypeKind::Scalar, &ConversionKind::Default) => {
if from_parts.width == to_parts.width {
- let dst_type = map.get_or_add(builder, SpirvType::from(cv.to.clone()));
- if from_parts.scalar_kind != ScalarKind::Float
- && to_parts.scalar_kind != ScalarKind::Float
+ let dst_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone()));
+ if from_parts.scalar_kind != ast::ScalarKind::Float
+ && to_parts.scalar_kind != ast::ScalarKind::Float
{
// It is noop, but another instruction expects result of this conversion
builder.copy_object(dst_type, Some(cv.dst), cv.src)?;
@@ -4295,28 +4073,28 @@ fn emit_implicit_conversion( builder.bitcast(dst_type, Some(cv.dst), cv.src)?;
}
} else {
- // This block is safe because it's illegal to implictly convert between floating point instructions
+ // This block is safe because it's illegal to implicitly convert between floating point values
let same_width_bit_type = map.get_or_add(
builder,
- SpirvType::from(ast::Type::from_parts(TypeParts {
- scalar_kind: ScalarKind::Bit,
+ SpirvType::new(ast::Type::from_parts(TypeParts {
+ scalar_kind: ast::ScalarKind::Bit,
..from_parts
})),
);
let same_width_bit_value = builder.bitcast(same_width_bit_type, None, cv.src)?;
let wide_bit_type = ast::Type::from_parts(TypeParts {
- scalar_kind: ScalarKind::Bit,
+ scalar_kind: ast::ScalarKind::Bit,
..to_parts
});
let wide_bit_type_spirv =
- map.get_or_add(builder, SpirvType::from(wide_bit_type.clone()));
- if to_parts.scalar_kind == ScalarKind::Unsigned
- || to_parts.scalar_kind == ScalarKind::Bit
+ map.get_or_add(builder, SpirvType::new(wide_bit_type.clone()));
+ if to_parts.scalar_kind == ast::ScalarKind::Unsigned
+ || to_parts.scalar_kind == ast::ScalarKind::Bit
{
builder.u_convert(wide_bit_type_spirv, Some(cv.dst), same_width_bit_value)?;
} else {
- let conversion_fn = if from_parts.scalar_kind == ScalarKind::Signed
- && to_parts.scalar_kind == ScalarKind::Signed
+ let conversion_fn = if from_parts.scalar_kind == ast::ScalarKind::Signed
+ && to_parts.scalar_kind == ast::ScalarKind::Signed
{
dr::Builder::s_convert
} else {
@@ -4330,40 +4108,48 @@ fn emit_implicit_conversion( &ImplicitConversion {
src: wide_bit_value,
dst: cv.dst,
- from: wide_bit_type,
- to: cv.to.clone(),
+ from_type: wide_bit_type,
+ from_space: cv.from_space,
+ to_type: cv.to_type.clone(),
+ to_space: cv.to_space,
kind: ConversionKind::Default,
- src_sema: cv.src_sema,
- dst_sema: cv.dst_sema,
},
)?;
}
}
}
- (TypeKind::Scalar, TypeKind::Scalar, ConversionKind::SignExtend) => {
- let result_type = map.get_or_add(builder, SpirvType::from(cv.to.clone()));
+ (TypeKind::Scalar, TypeKind::Scalar, &ConversionKind::SignExtend) => {
+ let result_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone()));
builder.s_convert(result_type, Some(cv.dst), cv.src)?;
}
- (TypeKind::Vector, TypeKind::Scalar, ConversionKind::Default)
- | (TypeKind::Scalar, TypeKind::Array, ConversionKind::Default)
- | (TypeKind::Array, TypeKind::Scalar, ConversionKind::Default) => {
- let into_type = map.get_or_add(builder, SpirvType::from(cv.to.clone()));
+ (TypeKind::Vector, TypeKind::Scalar, &ConversionKind::Default)
+ | (TypeKind::Scalar, TypeKind::Array, &ConversionKind::Default)
+ | (TypeKind::Array, TypeKind::Scalar, &ConversionKind::Default) => {
+ let into_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone()));
builder.bitcast(into_type, Some(cv.dst), cv.src)?;
}
- (_, _, ConversionKind::PtrToPtr { spirv_ptr }) => {
- let result_type = if spirv_ptr {
- map.get_or_add(
- builder,
- SpirvType::Pointer(
- Box::new(SpirvType::from(cv.to.clone())),
- spirv::StorageClass::Function,
- ),
- )
- } else {
- map.get_or_add(builder, SpirvType::from(cv.to.clone()))
- };
+ (_, _, &ConversionKind::PtrToPtr) => {
+ let result_type = map.get_or_add(
+ builder,
+ SpirvType::Pointer(
+ Box::new(SpirvType::new(cv.to_type.clone())),
+ cv.to_space.to_spirv(),
+ ),
+ );
builder.bitcast(result_type, Some(cv.dst), cv.src)?;
}
+ (_, _, &ConversionKind::AddressOf) => {
+ let dst_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone()));
+ builder.convert_ptr_to_u(dst_type, Some(cv.dst), cv.src)?;
+ }
+ (TypeKind::Pointer, TypeKind::Scalar, &ConversionKind::Default) => {
+ let result_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone()));
+ builder.convert_ptr_to_u(result_type, Some(cv.dst), cv.src)?;
+ }
+ (TypeKind::Scalar, TypeKind::Pointer, &ConversionKind::Default) => {
+ let result_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone()));
+ builder.convert_u_to_ptr(result_type, Some(cv.dst), cv.src)?;
+ }
_ => unreachable!(),
}
Ok(())
@@ -4374,14 +4160,14 @@ fn emit_load_var( map: &mut TypeWordMap,
details: &LoadVarDetails,
) -> Result<(), TranslateError> {
- let result_type = map.get_or_add(builder, SpirvType::from(details.typ.clone()));
+ let result_type = map.get_or_add(builder, SpirvType::new(details.typ.clone()));
match details.member_index {
Some((index, Some(width))) => {
let vector_type = match details.typ {
ast::Type::Scalar(scalar_t) => ast::Type::Vector(scalar_t, width),
_ => return Err(TranslateError::MismatchedType),
};
- let vector_type_spirv = map.get_or_add(builder, SpirvType::from(vector_type));
+ let vector_type_spirv = map.get_or_add(builder, SpirvType::new(vector_type));
let vector_temp = builder.load(
vector_type_spirv,
None,
@@ -4399,7 +4185,7 @@ fn emit_load_var( Some((index, None)) => {
let result_ptr_type = map.get_or_add(
builder,
- SpirvType::new_pointer(details.typ.clone(), spirv::StorageClass::Function),
+ SpirvType::pointer_to(details.typ.clone(), spirv::StorageClass::Function),
);
let index_spirv = map.get_or_add_constant(
builder,
@@ -4427,10 +4213,10 @@ fn emit_load_var( Ok(())
}
-fn normalize_identifiers<'a, 'b>(
- id_defs: &mut FnStringIdResolver<'a, 'b>,
- fn_defs: &GlobalFnDeclResolver<'a, 'b>,
- func: Vec<ast::Statement<ast::ParsedArgParams<'a>>>,
+fn normalize_identifiers<'input, 'b>(
+ id_defs: &mut FnStringIdResolver<'input, 'b>,
+ fn_defs: &GlobalFnDeclResolver<'input, 'b>,
+ func: Vec<ast::Statement<ast::ParsedArgParams<'input>>>,
) -> Result<Vec<NormalizedStatement>, TranslateError> {
for s in func.iter() {
match s {
@@ -4468,48 +4254,28 @@ fn expand_map_variables<'a, 'b>( i.map_variable(&mut |id| id_defs.get_id(id))?,
))),
ast::Statement::Variable(var) => {
- let mut var_type = ast::Type::from(var.var.v_type.clone());
- let mut is_variable = false;
- var_type = match var.var.v_type {
- ast::VariableType::Reg(_) => {
- is_variable = true;
- var_type
- }
- ast::VariableType::Shared(_) => {
- // If it's a pointer it will be translated to a method parameter later
- if let ast::Type::Pointer(..) = var_type {
- is_variable = true;
- var_type
- } else {
- var_type.param_pointer_to(ast::LdStateSpace::Shared)?
- }
- }
- ast::VariableType::Global(_) => {
- var_type.param_pointer_to(ast::LdStateSpace::Global)?
- }
- ast::VariableType::Param(_) => {
- var_type.param_pointer_to(ast::LdStateSpace::Param)?
- }
- ast::VariableType::Local(_) => {
- var_type.param_pointer_to(ast::LdStateSpace::Local)?
- }
- };
+ let var_type = var.var.v_type.clone();
match var.count {
Some(count) => {
- for new_id in id_defs.add_defs(var.var.name, count, var_type, is_variable) {
+ for new_id in
+ id_defs.add_defs(var.var.name, count, var_type, var.var.state_space, true)
+ {
result.push(Statement::Variable(ast::Variable {
align: var.var.align,
v_type: var.var.v_type.clone(),
+ state_space: var.var.state_space,
name: new_id,
array_init: var.var.array_init.clone(),
}))
}
}
None => {
- let new_id = id_defs.add_def(var.var.name, Some(var_type), is_variable);
+ let new_id =
+ id_defs.add_def(var.var.name, Some((var_type, var.var.state_space)), true);
result.push(Statement::Variable(ast::Variable {
align: var.var.align,
v_type: var.var.v_type.clone(),
+ state_space: var.var.state_space,
name: new_id,
array_init: var.var.array_init,
}));
@@ -4520,18 +4286,62 @@ fn expand_map_variables<'a, 'b>( Ok(())
}
+/*
+ Our goal here is to transform
+ .visible .entry foobar(.param .u64 input) {
+ .reg .b64 in_addr;
+ .reg .b64 in_addr2;
+ ld.param.u64 in_addr, [input];
+ cvta.to.global.u64 in_addr2, in_addr;
+ }
+ into:
+ .visible .entry foobar(.param .u8 input[]) {
+ .reg .u8 in_addr[];
+ .reg .u8 in_addr2[];
+ ld.param.u8[] in_addr, [input];
+ mov.u8[] in_addr2, in_addr;
+ }
+ or:
+ .visible .entry foobar(.reg .u8 input[]) {
+ .reg .u8 in_addr[];
+ .reg .u8 in_addr2[];
+ mov.u8[] in_addr, input;
+ mov.u8[] in_addr2, in_addr;
+ }
+ or:
+ .visible .entry foobar(.param ptr<u8, global> input) {
+ .reg ptr<u8, global> in_addr;
+ .reg ptr<u8, global> in_addr2;
+ ld.param.ptr<u8, global> in_addr, [input];
+ mov.ptr<u8, global> in_addr2, in_addr;
+ }
+*/
// TODO: detect more patterns (mov, call via reg, call via param)
// TODO: don't convert to ptr if the register is not ultimately used for ld/st
// TODO: once insert_mem_ssa_statements is moved to later, move this pass after
// argument expansion
-// TODO: propagate through calls?
-fn convert_to_stateful_memory_access<'a>(
- func_args: &mut SpirvMethodDecl,
+// TODO: propagate out of calls and into calls
+fn convert_to_stateful_memory_access<'a, 'input>(
+ func_args: Rc<RefCell<ast::MethodDeclaration<'input, spirv::Word>>>,
func_body: Vec<TypedStatement>,
id_defs: &mut NumericIdResolver<'a>,
-) -> Result<Vec<TypedStatement>, TranslateError> {
- let func_args_64bit = func_args
- .input
+) -> Result<
+ (
+ Rc<RefCell<ast::MethodDeclaration<'input, spirv::Word>>>,
+ Vec<TypedStatement>,
+ ),
+ TranslateError,
+> {
+ let mut method_decl = func_args.borrow_mut();
+ if !method_decl.name.is_kernel() {
+ drop(method_decl);
+ return Ok((func_args, func_body));
+ }
+ if Rc::strong_count(&func_args) != 1 {
+ return Err(error_unreachable());
+ }
+ let func_args_64bit = (*method_decl)
+ .input_arguments
.iter()
.filter_map(|arg| match arg.v_type {
ast::Type::Scalar(ast::ScalarType::U64)
@@ -4546,9 +4356,9 @@ fn convert_to_stateful_memory_access<'a>( match statement {
Statement::Instruction(ast::Instruction::Cvta(
ast::CvtaDetails {
- to: ast::CvtaStateSpace::Global,
+ to: ast::StateSpace::Global,
size: ast::CvtaSize::U64,
- from: ast::CvtaStateSpace::Generic,
+ from: ast::StateSpace::Generic,
},
arg,
)) => {
@@ -4562,24 +4372,24 @@ fn convert_to_stateful_memory_access<'a>( }
Statement::Instruction(ast::Instruction::Ld(
ast::LdDetails {
- state_space: ast::LdStateSpace::Param,
- typ: ast::LdStType::Scalar(ast::LdStScalarType::U64),
+ state_space: ast::StateSpace::Param,
+ typ: ast::Type::Scalar(ast::ScalarType::U64),
..
},
arg,
))
| Statement::Instruction(ast::Instruction::Ld(
ast::LdDetails {
- state_space: ast::LdStateSpace::Param,
- typ: ast::LdStType::Scalar(ast::LdStScalarType::S64),
+ state_space: ast::StateSpace::Param,
+ typ: ast::Type::Scalar(ast::ScalarType::S64),
..
},
arg,
))
| Statement::Instruction(ast::Instruction::Ld(
ast::LdDetails {
- state_space: ast::LdStateSpace::Param,
- typ: ast::LdStType::Scalar(ast::LdStScalarType::B64),
+ state_space: ast::StateSpace::Param,
+ typ: ast::Type::Scalar(ast::ScalarType::B64),
..
},
arg,
@@ -4595,6 +4405,10 @@ fn convert_to_stateful_memory_access<'a>( _ => {}
}
}
+ if stateful_markers.len() == 0 {
+ drop(method_decl);
+ return Ok((func_args, func_body));
+ }
let mut func_args_ptr = HashSet::new();
let mut regs_ptr_current = HashSet::new();
for (dst, src) in stateful_markers {
@@ -4614,23 +4428,23 @@ fn convert_to_stateful_memory_access<'a>( for statement in func_body.iter() {
match statement {
Statement::Instruction(ast::Instruction::Add(
- ast::ArithDetails::Unsigned(ast::UIntType::U64),
+ ast::ArithDetails::Unsigned(ast::ScalarType::U64),
arg,
))
| Statement::Instruction(ast::Instruction::Add(
ast::ArithDetails::Signed(ast::ArithSInt {
- typ: ast::SIntType::S64,
+ typ: ast::ScalarType::S64,
saturate: false,
}),
arg,
))
| Statement::Instruction(ast::Instruction::Sub(
- ast::ArithDetails::Unsigned(ast::UIntType::U64),
+ ast::ArithDetails::Unsigned(ast::ScalarType::U64),
arg,
))
| Statement::Instruction(ast::Instruction::Sub(
ast::ArithDetails::Signed(ast::ArithSInt {
- typ: ast::SIntType::S64,
+ typ: ast::ScalarType::S64,
saturate: false,
}),
arg,
@@ -4661,21 +4475,32 @@ fn convert_to_stateful_memory_access<'a>( let mut remapped_ids = HashMap::new();
let mut result = Vec::with_capacity(regs_ptr_seen.len() + func_body.len());
for reg in regs_ptr_seen {
- let new_id = id_defs.new_variable(ast::Type::Pointer(
- ast::PointerType::Scalar(ast::ScalarType::U8),
- ast::LdStateSpace::Global,
- ));
+ let new_id = id_defs.register_variable(
+ ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global),
+ ast::StateSpace::Reg,
+ );
result.push(Statement::Variable(ast::Variable {
align: None,
name: new_id,
array_init: Vec::new(),
- v_type: ast::VariableType::Reg(ast::VariableRegType::Pointer(
- ast::SizedScalarType::U8,
- ast::PointerStateSpace::Global,
- )),
+ v_type: ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global),
+ state_space: ast::StateSpace::Reg,
}));
remapped_ids.insert(reg, new_id);
}
+ for arg in (*method_decl).input_arguments.iter_mut() {
+ if !func_args_ptr.contains(&arg.name) {
+ continue;
+ }
+ let new_id = id_defs.register_variable(
+ ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global),
+ ast::StateSpace::Param,
+ );
+ let old_name = arg.name;
+ arg.v_type = ast::Type::Pointer(ast::ScalarType::U8, ast::StateSpace::Global);
+ arg.name = new_id;
+ remapped_ids.insert(old_name, new_id);
+ }
for statement in func_body {
match statement {
l @ Statement::Label(_) => result.push(l),
@@ -4686,12 +4511,12 @@ fn convert_to_stateful_memory_access<'a>( }
}
Statement::Instruction(ast::Instruction::Add(
- ast::ArithDetails::Unsigned(ast::UIntType::U64),
+ ast::ArithDetails::Unsigned(ast::ScalarType::U64),
arg,
))
| Statement::Instruction(ast::Instruction::Add(
ast::ArithDetails::Signed(ast::ArithSInt {
- typ: ast::SIntType::S64,
+ typ: ast::ScalarType::S64,
saturate: false,
}),
arg,
@@ -4707,20 +4532,20 @@ fn convert_to_stateful_memory_access<'a>( };
let dst = arg.dst.upcast().unwrap_reg()?;
result.push(Statement::PtrAccess(PtrAccess {
- underlying_type: ast::PointerType::Scalar(ast::ScalarType::U8),
- state_space: ast::LdStateSpace::Global,
+ underlying_type: ast::ScalarType::U8,
+ state_space: ast::StateSpace::Global,
dst: *remapped_ids.get(&dst).unwrap(),
ptr_src: *ptr,
offset_src: offset,
}))
}
Statement::Instruction(ast::Instruction::Sub(
- ast::ArithDetails::Unsigned(ast::UIntType::U64),
+ ast::ArithDetails::Unsigned(ast::ScalarType::U64),
arg,
))
| Statement::Instruction(ast::Instruction::Sub(
ast::ArithDetails::Signed(ast::ArithSInt {
- typ: ast::SIntType::S64,
+ typ: ast::ScalarType::S64,
saturate: false,
}),
arg,
@@ -4734,8 +4559,10 @@ fn convert_to_stateful_memory_access<'a>( }
_ => return Err(error_unreachable()),
};
- let offset_neg =
- id_defs.new_non_variable(Some(ast::Type::Scalar(ast::ScalarType::S64)));
+ let offset_neg = id_defs.register_intermediate(Some((
+ ast::Type::Scalar(ast::ScalarType::S64),
+ ast::StateSpace::Reg,
+ )));
result.push(Statement::Instruction(ast::Instruction::Neg(
ast::NegDetails {
typ: ast::ScalarType::S64,
@@ -4748,8 +4575,8 @@ fn convert_to_stateful_memory_access<'a>( )));
let dst = arg.dst.upcast().unwrap_reg()?;
result.push(Statement::PtrAccess(PtrAccess {
- underlying_type: ast::PointerType::Scalar(ast::ScalarType::U8),
- state_space: ast::LdStateSpace::Global,
+ underlying_type: ast::ScalarType::U8,
+ state_space: ast::StateSpace::Global,
dst: *remapped_ids.get(&dst).unwrap(),
ptr_src: *ptr,
offset_src: TypedOperand::Reg(offset_neg),
@@ -4757,151 +4584,116 @@ fn convert_to_stateful_memory_access<'a>( }
Statement::Instruction(inst) => {
let mut post_statements = Vec::new();
- let new_statement = inst.visit(
- &mut |arg_desc: ArgumentDescriptor<spirv::Word>,
- expected_type: Option<&ast::Type>| {
+ let new_statement =
+ inst.visit(&mut |arg_desc, expected_type: Option<(&ast::Type, _)>| {
convert_to_stateful_memory_access_postprocess(
id_defs,
&remapped_ids,
- &func_args_ptr,
&mut result,
&mut post_statements,
arg_desc,
expected_type,
)
- },
- )?;
+ })?;
result.push(new_statement);
result.extend(post_statements);
}
Statement::Call(call) => {
let mut post_statements = Vec::new();
- let new_statement = call.visit(
- &mut |arg_desc: ArgumentDescriptor<spirv::Word>,
- expected_type: Option<&ast::Type>| {
+ let new_statement =
+ call.visit(&mut |arg_desc, expected_type: Option<(&ast::Type, _)>| {
convert_to_stateful_memory_access_postprocess(
id_defs,
&remapped_ids,
- &func_args_ptr,
&mut result,
&mut post_statements,
arg_desc,
expected_type,
)
- },
- )?;
+ })?;
result.push(new_statement);
result.extend(post_statements);
}
Statement::RepackVector(pack) => {
let mut post_statements = Vec::new();
- let new_statement = pack.visit(
- &mut |arg_desc: ArgumentDescriptor<spirv::Word>,
- expected_type: Option<&ast::Type>| {
+ let new_statement =
+ pack.visit(&mut |arg_desc, expected_type: Option<(&ast::Type, _)>| {
convert_to_stateful_memory_access_postprocess(
id_defs,
&remapped_ids,
- &func_args_ptr,
&mut result,
&mut post_statements,
arg_desc,
expected_type,
)
- },
- )?;
+ })?;
result.push(new_statement);
result.extend(post_statements);
}
_ => return Err(error_unreachable()),
}
}
- for arg in func_args.input.iter_mut() {
- if func_args_ptr.contains(&arg.name) {
- arg.v_type = ast::Type::Pointer(
- ast::PointerType::Scalar(ast::ScalarType::U8),
- ast::LdStateSpace::Global,
- );
- }
- }
- Ok(result)
+ drop(method_decl);
+ Ok((func_args, result))
}
fn convert_to_stateful_memory_access_postprocess(
id_defs: &mut NumericIdResolver,
remapped_ids: &HashMap<spirv::Word, spirv::Word>,
- func_args_ptr: &HashSet<spirv::Word>,
result: &mut Vec<TypedStatement>,
post_statements: &mut Vec<TypedStatement>,
arg_desc: ArgumentDescriptor<spirv::Word>,
- expected_type: Option<&ast::Type>,
+ expected_type: Option<(&ast::Type, ast::StateSpace)>,
) -> Result<spirv::Word, TranslateError> {
Ok(match remapped_ids.get(&arg_desc.op) {
Some(new_id) => {
- // We skip conversion here to trigger PtrAcces in a later pass
- let old_type = match expected_type {
- Some(ast::Type::Pointer(_, ast::LdStateSpace::Global)) => return Ok(*new_id),
- _ => id_defs.get_typed(arg_desc.op)?.0,
+ let (new_operand_type, new_operand_space, _) = id_defs.get_typed(*new_id)?;
+ if let Some((expected_type, expected_space)) = expected_type {
+ let implicit_conversion = arg_desc
+ .non_default_implicit_conversion
+ .unwrap_or(default_implicit_conversion);
+ if implicit_conversion(
+ (new_operand_space, &new_operand_type),
+ (expected_space, expected_type),
+ )
+ .is_ok()
+ {
+ return Ok(*new_id);
+ }
+ }
+ let (old_operand_type, old_operand_space, _) = id_defs.get_typed(arg_desc.op)?;
+ let converting_id =
+ id_defs.register_intermediate(Some((old_operand_type.clone(), old_operand_space)));
+ let kind = if new_operand_space.is_compatible(ast::StateSpace::Reg) {
+ ConversionKind::Default
+ } else {
+ ConversionKind::PtrToPtr
};
- let old_type_clone = old_type.clone();
- let converting_id = id_defs.new_non_variable(Some(old_type_clone));
if arg_desc.is_dst {
post_statements.push(Statement::Conversion(ImplicitConversion {
src: converting_id,
dst: *new_id,
- from: old_type,
- to: ast::Type::Pointer(
- ast::PointerType::Scalar(ast::ScalarType::U8),
- ast::LdStateSpace::Global,
- ),
- kind: ConversionKind::BitToPtr(ast::LdStateSpace::Global),
- src_sema: ArgumentSemantics::Default,
- dst_sema: arg_desc.sema,
+ from_type: old_operand_type,
+ from_space: old_operand_space,
+ to_type: new_operand_type,
+ to_space: new_operand_space,
+ kind,
}));
converting_id
} else {
result.push(Statement::Conversion(ImplicitConversion {
src: *new_id,
dst: converting_id,
- from: ast::Type::Pointer(
- ast::PointerType::Scalar(ast::ScalarType::U8),
- ast::LdStateSpace::Global,
- ),
- to: old_type,
- kind: ConversionKind::PtrToBit(ast::UIntType::U64),
- src_sema: arg_desc.sema,
- dst_sema: ArgumentSemantics::Default,
+ from_type: new_operand_type,
+ from_space: new_operand_space,
+ to_type: old_operand_type,
+ to_space: old_operand_space,
+ kind,
}));
converting_id
}
}
- None => match func_args_ptr.get(&arg_desc.op) {
- Some(new_id) => {
- if arg_desc.is_dst {
- return Err(error_unreachable());
- }
- // We skip conversion here to trigger PtrAcces in a later pass
- let old_type = match expected_type {
- Some(ast::Type::Pointer(_, ast::LdStateSpace::Global)) => return Ok(*new_id),
- _ => id_defs.get_typed(arg_desc.op)?.0,
- };
- let old_type_clone = old_type.clone();
- let converting_id = id_defs.new_non_variable(Some(old_type));
- result.push(Statement::Conversion(ImplicitConversion {
- src: *new_id,
- dst: converting_id,
- from: ast::Type::Pointer(
- ast::PointerType::Pointer(ast::ScalarType::U8, ast::LdStateSpace::Global),
- ast::LdStateSpace::Param,
- ),
- to: old_type_clone,
- kind: ConversionKind::PtrToPtr { spirv_ptr: false },
- src_sema: arg_desc.sema,
- dst_sema: ArgumentSemantics::Default,
- }));
- converting_id
- }
- None => arg_desc.op,
- },
+ None => arg_desc.op,
})
}
@@ -4925,9 +4717,9 @@ fn is_add_ptr_direct(remapped_ids: &HashMap<u32, u32>, arg: &ast::Arg3<TypedArgP fn is_64_bit_integer(id_defs: &NumericIdResolver, id: spirv::Word) -> bool {
match id_defs.get_typed(id) {
- Ok((ast::Type::Scalar(ast::ScalarType::U64), _))
- | Ok((ast::Type::Scalar(ast::ScalarType::S64), _))
- | Ok((ast::Type::Scalar(ast::ScalarType::B64), _)) => true,
+ Ok((ast::Type::Scalar(ast::ScalarType::U64), _, _))
+ | Ok((ast::Type::Scalar(ast::ScalarType::S64), _, _))
+ | Ok((ast::Type::Scalar(ast::ScalarType::B64), _, _)) => true,
_ => false,
}
}
@@ -5055,20 +4847,95 @@ impl SpecialRegistersMap { }
}
+struct FnSigMapper<'input> {
+ // true - stays as return argument
+ // false - is moved to input argument
+ return_param_args: Vec<bool>,
+ func_decl: Rc<RefCell<ast::MethodDeclaration<'input, spirv::Word>>>,
+}
+
+impl<'input> FnSigMapper<'input> {
+ fn remap_to_spirv_repr(mut method: ast::MethodDeclaration<'input, spirv::Word>) -> Self {
+ let return_param_args = method
+ .return_arguments
+ .iter()
+ .map(|a| a.state_space != ast::StateSpace::Param)
+ .collect::<Vec<_>>();
+ let mut new_return_arguments = Vec::new();
+ for arg in method.return_arguments.into_iter() {
+ if arg.state_space == ast::StateSpace::Param {
+ method.input_arguments.push(arg);
+ } else {
+ new_return_arguments.push(arg);
+ }
+ }
+ method.return_arguments = new_return_arguments;
+ FnSigMapper {
+ return_param_args,
+ func_decl: Rc::new(RefCell::new(method)),
+ }
+ }
+
+ fn resolve_in_spirv_repr(
+ &self,
+ call_inst: ast::CallInst<NormalizedArgParams>,
+ ) -> Result<ResolvedCall<NormalizedArgParams>, TranslateError> {
+ let func_decl = (*self.func_decl).borrow();
+ let mut return_arguments = Vec::new();
+ let mut input_arguments = call_inst
+ .param_list
+ .into_iter()
+ .zip(func_decl.input_arguments.iter())
+ .map(|(id, var)| (id, var.v_type.clone(), var.state_space))
+ .collect::<Vec<_>>();
+ let mut func_decl_return_iter = func_decl.return_arguments.iter();
+ let mut func_decl_input_iter = func_decl.input_arguments[input_arguments.len()..].iter();
+ for (idx, id) in call_inst.ret_params.iter().enumerate() {
+ let stays_as_return = match self.return_param_args.get(idx) {
+ Some(x) => *x,
+ None => return Err(TranslateError::MismatchedType),
+ };
+ if stays_as_return {
+ if let Some(var) = func_decl_return_iter.next() {
+ return_arguments.push((*id, var.v_type.clone(), var.state_space));
+ } else {
+ return Err(TranslateError::MismatchedType);
+ }
+ } else {
+ if let Some(var) = func_decl_input_iter.next() {
+ input_arguments.push((
+ ast::Operand::Reg(*id),
+ var.v_type.clone(),
+ var.state_space,
+ ));
+ } else {
+ return Err(TranslateError::MismatchedType);
+ }
+ }
+ }
+ if return_arguments.len() != func_decl.return_arguments.len()
+ || input_arguments.len() != func_decl.input_arguments.len()
+ {
+ return Err(TranslateError::MismatchedType);
+ }
+ Ok(ResolvedCall {
+ return_arguments,
+ input_arguments,
+ uniform: call_inst.uniform,
+ name: call_inst.func,
+ })
+ }
+}
+
struct GlobalStringIdResolver<'input> {
current_id: spirv::Word,
variables: HashMap<Cow<'input, str>, spirv::Word>,
- variables_type_check: HashMap<u32, Option<(ast::Type, bool)>>,
+ variables_type_check: HashMap<u32, Option<(ast::Type, ast::StateSpace, bool)>>,
special_registers: SpecialRegistersMap,
- fns: HashMap<spirv::Word, FnDecl>,
-}
-
-pub struct FnDecl {
- ret_vals: Vec<ast::FnArgumentType>,
- params: Vec<ast::FnArgumentType>,
+ fns: HashMap<spirv::Word, FnSigMapper<'input>>,
}
-impl<'a> GlobalStringIdResolver<'a> {
+impl<'input> GlobalStringIdResolver<'input> {
fn new(start_id: spirv::Word) -> Self {
Self {
current_id: start_id,
@@ -5079,20 +4946,25 @@ impl<'a> GlobalStringIdResolver<'a> { }
}
- fn get_or_add_def(&mut self, id: &'a str) -> spirv::Word {
+ fn get_or_add_def(&mut self, id: &'input str) -> spirv::Word {
self.get_or_add_impl(id, None)
}
fn get_or_add_def_typed(
&mut self,
- id: &'a str,
+ id: &'input str,
typ: ast::Type,
+ state_space: ast::StateSpace,
is_variable: bool,
) -> spirv::Word {
- self.get_or_add_impl(id, Some((typ, is_variable)))
+ self.get_or_add_impl(id, Some((typ, state_space, is_variable)))
}
- fn get_or_add_impl(&mut self, id: &'a str, typ: Option<(ast::Type, bool)>) -> spirv::Word {
+ fn get_or_add_impl(
+ &mut self,
+ id: &'input str,
+ typ: Option<(ast::Type, ast::StateSpace, bool)>,
+ ) -> spirv::Word {
let id = match self.variables.entry(Cow::Borrowed(id)) {
hash_map::Entry::Occupied(e) => *(e.get()),
hash_map::Entry::Vacant(e) => {
@@ -5119,12 +4991,12 @@ impl<'a> GlobalStringIdResolver<'a> { fn start_fn<'b>(
&'b mut self,
- header: &'b ast::MethodDecl<'a, &'a str>,
+ header: &'b ast::MethodDeclaration<'input, &'input str>,
) -> Result<
(
- FnStringIdResolver<'a, 'b>,
- GlobalFnDeclResolver<'a, 'b>,
- ast::MethodDecl<'a, spirv::Word>,
+ FnStringIdResolver<'input, 'b>,
+ GlobalFnDeclResolver<'input, 'b>,
+ Rc<RefCell<ast::MethodDeclaration<'input, spirv::Word>>>,
),
TranslateError,
> {
@@ -5138,60 +5010,51 @@ impl<'a> GlobalStringIdResolver<'a> { variables: vec![HashMap::new(); 1],
type_check: HashMap::new(),
};
- let new_fn_decl = match header {
- ast::MethodDecl::Kernel { name, in_args } => ast::MethodDecl::Kernel {
- name,
- in_args: expand_kernel_params(&mut fn_resolver, in_args.iter())?,
- },
- ast::MethodDecl::Func(ret_params, _, params) => {
- let ret_params_ids = expand_fn_params(&mut fn_resolver, ret_params.iter())?;
- let params_ids = expand_fn_params(&mut fn_resolver, params.iter())?;
- self.fns.insert(
- name_id,
- FnDecl {
- ret_vals: ret_params_ids.iter().map(|p| p.v_type.clone()).collect(),
- params: params_ids.iter().map(|p| p.v_type.clone()).collect(),
- },
- );
- ast::MethodDecl::Func(ret_params_ids, name_id, params_ids)
- }
+ let return_arguments = rename_fn_params(&mut fn_resolver, &header.return_arguments);
+ let input_arguments = rename_fn_params(&mut fn_resolver, &header.input_arguments);
+ let name = match header.name {
+ ast::MethodName::Kernel(name) => ast::MethodName::Kernel(name),
+ ast::MethodName::Func(_) => ast::MethodName::Func(name_id),
+ };
+ let fn_decl = ast::MethodDeclaration {
+ return_arguments,
+ name,
+ input_arguments,
+ shared_mem: None,
+ };
+ let new_fn_decl = if !fn_decl.name.is_kernel() {
+ let resolver = FnSigMapper::remap_to_spirv_repr(fn_decl);
+ let new_fn_decl = resolver.func_decl.clone();
+ self.fns.insert(name_id, resolver);
+ new_fn_decl
+ } else {
+ Rc::new(RefCell::new(fn_decl))
};
Ok((
fn_resolver,
- GlobalFnDeclResolver {
- variables: &self.variables,
- fns: &self.fns,
- },
+ GlobalFnDeclResolver { fns: &self.fns },
new_fn_decl,
))
}
}
pub struct GlobalFnDeclResolver<'input, 'a> {
- variables: &'a HashMap<Cow<'input, str>, spirv::Word>,
- fns: &'a HashMap<spirv::Word, FnDecl>,
+ fns: &'a HashMap<spirv::Word, FnSigMapper<'input>>,
}
impl<'input, 'a> GlobalFnDeclResolver<'input, 'a> {
- fn get_fn_decl(&self, id: spirv::Word) -> Result<&FnDecl, TranslateError> {
+ fn get_fn_sig_resolver(&self, id: spirv::Word) -> Result<&FnSigMapper<'input>, TranslateError> {
self.fns.get(&id).ok_or(TranslateError::UnknownSymbol)
}
-
- fn get_fn_decl_str(&self, id: &str) -> Result<&'a FnDecl, TranslateError> {
- match self.variables.get(id).map(|var_id| self.fns.get(var_id)) {
- Some(Some(fn_d)) => Ok(fn_d),
- _ => Err(TranslateError::UnknownSymbol),
- }
- }
}
struct FnStringIdResolver<'input, 'b> {
current_id: &'b mut spirv::Word,
global_variables: &'b HashMap<Cow<'input, str>, spirv::Word>,
- global_type_check: &'b HashMap<u32, Option<(ast::Type, bool)>>,
+ global_type_check: &'b HashMap<u32, Option<(ast::Type, ast::StateSpace, bool)>>,
special_registers: &'b mut SpecialRegistersMap,
variables: Vec<HashMap<Cow<'input, str>, spirv::Word>>,
- type_check: HashMap<u32, Option<(ast::Type, bool)>>,
+ type_check: HashMap<u32, Option<(ast::Type, ast::StateSpace, bool)>>,
}
impl<'a, 'b> FnStringIdResolver<'a, 'b> {
@@ -5229,14 +5092,21 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> { }
}
- fn add_def(&mut self, id: &'a str, typ: Option<ast::Type>, is_variable: bool) -> spirv::Word {
+ fn add_def(
+ &mut self,
+ id: &'a str,
+ typ: Option<(ast::Type, ast::StateSpace)>,
+ is_variable: bool,
+ ) -> spirv::Word {
let numeric_id = *self.current_id;
self.variables
.last_mut()
.unwrap()
.insert(Cow::Borrowed(id), numeric_id);
- self.type_check
- .insert(numeric_id, typ.map(|t| (t, is_variable)));
+ self.type_check.insert(
+ numeric_id,
+ typ.map(|(typ, space)| (typ, space, is_variable)),
+ );
*self.current_id += 1;
numeric_id
}
@@ -5247,6 +5117,7 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> { base_id: &'a str,
count: u32,
typ: ast::Type,
+ state_space: ast::StateSpace,
is_variable: bool,
) -> impl Iterator<Item = spirv::Word> {
let numeric_id = *self.current_id;
@@ -5255,8 +5126,10 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> { .last_mut()
.unwrap()
.insert(Cow::Owned(format!("{}{}", base_id, i)), numeric_id + i);
- self.type_check
- .insert(numeric_id + i, Some((typ.clone(), is_variable)));
+ self.type_check.insert(
+ numeric_id + i,
+ Some((typ.clone(), state_space, is_variable)),
+ );
}
*self.current_id += count;
(0..count).into_iter().map(move |i| i + numeric_id)
@@ -5265,8 +5138,8 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> { struct NumericIdResolver<'b> {
current_id: &'b mut spirv::Word,
- global_type_check: &'b HashMap<u32, Option<(ast::Type, bool)>>,
- type_check: HashMap<u32, Option<(ast::Type, bool)>>,
+ global_type_check: &'b HashMap<u32, Option<(ast::Type, ast::StateSpace, bool)>>,
+ type_check: HashMap<u32, Option<(ast::Type, ast::StateSpace, bool)>>,
special_registers: &'b mut SpecialRegistersMap,
}
@@ -5275,12 +5148,15 @@ impl<'b> NumericIdResolver<'b> { MutableNumericIdResolver { base: self }
}
- fn get_typed(&self, id: spirv::Word) -> Result<(ast::Type, bool), TranslateError> {
+ fn get_typed(
+ &self,
+ id: spirv::Word,
+ ) -> Result<(ast::Type, ast::StateSpace, bool), TranslateError> {
match self.type_check.get(&id) {
Some(Some(x)) => Ok(x.clone()),
Some(None) => Err(TranslateError::UntypedSymbol),
None => match self.special_registers.get(id) {
- Some(x) => Ok((x.get_type(), true)),
+ Some(x) => Ok((x.get_type(), ast::StateSpace::Sreg, true)),
None => match self.global_type_check.get(&id) {
Some(Some(result)) => Ok(result.clone()),
Some(None) | None => Err(TranslateError::UntypedSymbol),
@@ -5291,16 +5167,18 @@ impl<'b> NumericIdResolver<'b> { // This is for identifiers which will be emitted later as OpVariable
// They are candidates for insertion of LoadVar/StoreVar
- fn new_variable(&mut self, typ: ast::Type) -> spirv::Word {
+ fn register_variable(&mut self, typ: ast::Type, state_space: ast::StateSpace) -> spirv::Word {
let new_id = *self.current_id;
- self.type_check.insert(new_id, Some((typ, true)));
+ self.type_check
+ .insert(new_id, Some((typ, state_space, true)));
*self.current_id += 1;
new_id
}
- fn new_non_variable(&mut self, typ: Option<ast::Type>) -> spirv::Word {
+ fn register_intermediate(&mut self, typ: Option<(ast::Type, ast::StateSpace)>) -> spirv::Word {
let new_id = *self.current_id;
- self.type_check.insert(new_id, typ.map(|t| (t, false)));
+ self.type_check
+ .insert(new_id, typ.map(|(t, space)| (t, space, false)));
*self.current_id += 1;
new_id
}
@@ -5315,18 +5193,22 @@ impl<'b> MutableNumericIdResolver<'b> { self.base
}
- fn get_typed(&self, id: spirv::Word) -> Result<ast::Type, TranslateError> {
- self.base.get_typed(id).map(|(t, _)| t)
+ fn get_typed(&self, id: spirv::Word) -> Result<(ast::Type, ast::StateSpace), TranslateError> {
+ self.base.get_typed(id).map(|(t, space, _)| (t, space))
}
- fn new_non_variable(&mut self, typ: ast::Type) -> spirv::Word {
- self.base.new_non_variable(Some(typ))
+ fn register_intermediate(
+ &mut self,
+ typ: ast::Type,
+ state_space: ast::StateSpace,
+ ) -> spirv::Word {
+ self.base.register_intermediate(Some((typ, state_space)))
}
}
enum Statement<I, P: ast::ArgParams> {
Label(u32),
- Variable(ast::Variable<ast::VariableType, P::Id>),
+ Variable(ast::Variable<P::Id>),
Instruction(I),
// SPIR-V compatible replacement for PTX predicates
Conditional(BrachCondition),
@@ -5349,7 +5231,8 @@ impl ExpandedStatement { Statement::Variable(var)
}
Statement::Instruction(inst) => inst
- .visit(&mut |arg: ArgumentDescriptor<_>, _: Option<&ast::Type>| {
+ .visit(&mut |arg: ArgumentDescriptor<_>,
+ _: Option<(&ast::Type, ast::StateSpace)>| {
Ok(f(arg.op, arg.is_dst))
})
.unwrap(),
@@ -5364,16 +5247,17 @@ impl ExpandedStatement { Statement::StoreVar(details)
}
Statement::Call(mut call) => {
- for (id, typ) in call.ret_params.iter_mut() {
- let is_dst = match typ {
- ast::FnArgumentType::Reg(_) => true,
- ast::FnArgumentType::Param(_) => false,
- ast::FnArgumentType::Shared => false,
+ for (id, _, space) in call.return_arguments.iter_mut() {
+ let is_dst = match space {
+ ast::StateSpace::Reg => true,
+ ast::StateSpace::Param => false,
+ ast::StateSpace::Shared => false,
+ _ => todo!(),
};
*id = f(*id, is_dst);
}
- call.func = f(call.func, false);
- for (id, _) in call.param_list.iter_mut() {
+ call.name = f(call.name, false);
+ for (id, _, _) in call.input_arguments.iter_mut() {
*id = f(*id, false);
}
Statement::Call(call)
@@ -5435,6 +5319,7 @@ impl ExpandedStatement { struct LoadVarDetails {
arg: ast::Arg2<ExpandedArgParams>,
typ: ast::Type,
+ state_space: ast::StateSpace,
// (index, vector_width)
// HACK ALERT
// For some reason IGC explodes when you try to load from builtin vectors
@@ -5454,7 +5339,12 @@ struct RepackVectorDetails { typ: ast::ScalarType,
packed: spirv::Word,
unpacked: Vec<spirv::Word>,
- vector_sema: ArgumentSemantics,
+ non_default_implicit_conversion: Option<
+ fn(
+ (ast::StateSpace, &ast::Type),
+ (ast::StateSpace, &ast::Type),
+ ) -> Result<Option<ConversionKind>, TranslateError>,
+ >,
}
impl RepackVectorDetails {
@@ -5470,13 +5360,17 @@ impl RepackVectorDetails { ArgumentDescriptor {
op: self.packed,
is_dst: !self.is_extract,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
- Some(&ast::Type::Vector(self.typ, self.unpacked.len() as u8)),
+ Some((
+ &ast::Type::Vector(self.typ, self.unpacked.len() as u8),
+ ast::StateSpace::Reg,
+ )),
)?;
let scalar_type = self.typ;
let is_extract = self.is_extract;
- let vector_sema = self.vector_sema;
+ let non_default_implicit_conversion = self.non_default_implicit_conversion;
let vector = self
.unpacked
.into_iter()
@@ -5485,9 +5379,10 @@ impl RepackVectorDetails { ArgumentDescriptor {
op: id,
is_dst: is_extract,
- sema: vector_sema,
+ is_memory_access: false,
+ non_default_implicit_conversion,
},
- Some(&ast::Type::Scalar(scalar_type)),
+ Some((&ast::Type::Scalar(scalar_type), ast::StateSpace::Reg)),
)
})
.collect::<Result<_, _>>()?;
@@ -5496,7 +5391,7 @@ impl RepackVectorDetails { typ: self.typ,
packed: scalar,
unpacked: vector,
- vector_sema,
+ non_default_implicit_conversion,
})
}
}
@@ -5514,18 +5409,18 @@ impl<T: ArgParamsEx<Id = spirv::Word>, U: ArgParamsEx<Id = spirv::Word>> Visitab struct ResolvedCall<P: ast::ArgParams> {
pub uniform: bool,
- pub ret_params: Vec<(P::Id, ast::FnArgumentType)>,
- pub func: P::Id,
- pub param_list: Vec<(P::Operand, ast::FnArgumentType)>,
+ pub return_arguments: Vec<(P::Id, ast::Type, ast::StateSpace)>,
+ pub name: P::Id,
+ pub input_arguments: Vec<(P::Operand, ast::Type, ast::StateSpace)>,
}
impl<T: ast::ArgParams> ResolvedCall<T> {
fn cast<U: ast::ArgParams<Id = T::Id, Operand = T::Operand>>(self) -> ResolvedCall<U> {
ResolvedCall {
uniform: self.uniform,
- ret_params: self.ret_params,
- func: self.func,
- param_list: self.param_list,
+ return_arguments: self.return_arguments,
+ name: self.name,
+ input_arguments: self.input_arguments,
}
}
}
@@ -5535,49 +5430,53 @@ impl<From: ArgParamsEx<Id = spirv::Word>> ResolvedCall<From> { self,
visitor: &mut V,
) -> Result<ResolvedCall<To>, TranslateError> {
- let ret_params = self
- .ret_params
+ let return_arguments = self
+ .return_arguments
.into_iter()
- .map::<Result<_, TranslateError>, _>(|(id, typ)| {
+ .map::<Result<_, TranslateError>, _>(|(id, typ, space)| {
let new_id = visitor.id(
ArgumentDescriptor {
op: id,
- is_dst: !typ.is_param(),
- sema: typ.semantics(),
+ is_dst: space != ast::StateSpace::Param,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
- Some(&typ.to_func_type()),
+ Some((&typ, space)),
)?;
- Ok((new_id, typ))
+ Ok((new_id, typ, space))
})
.collect::<Result<Vec<_>, _>>()?;
let func = visitor.id(
ArgumentDescriptor {
- op: self.func,
+ op: self.name,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
None,
)?;
- let param_list = self
- .param_list
+ let input_arguments = self
+ .input_arguments
.into_iter()
- .map::<Result<_, TranslateError>, _>(|(id, typ)| {
+ .map::<Result<_, TranslateError>, _>(|(id, typ, space)| {
let new_id = visitor.operand(
ArgumentDescriptor {
op: id,
is_dst: false,
- sema: typ.semantics(),
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
- &typ.to_func_type(),
+ &typ,
+ space,
)?;
- Ok((new_id, typ))
+ Ok((new_id, typ, space))
})
.collect::<Result<Vec<_>, _>>()?;
Ok(ResolvedCall {
uniform: self.uniform,
- ret_params,
- func,
- param_list,
+ return_arguments,
+ name: func,
+ input_arguments,
})
}
}
@@ -5598,39 +5497,34 @@ impl<P: ArgParamsEx<Id = spirv::Word>> PtrAccess<P> { self,
visitor: &mut V,
) -> Result<PtrAccess<To>, TranslateError> {
- let sema = match self.state_space {
- ast::LdStateSpace::Const
- | ast::LdStateSpace::Global
- | ast::LdStateSpace::Shared
- | ast::LdStateSpace::Generic => ArgumentSemantics::PhysicalPointer,
- ast::LdStateSpace::Local | ast::LdStateSpace::Param => {
- ArgumentSemantics::RegisterPointer
- }
- };
- let ptr_type = ast::Type::Pointer(self.underlying_type.clone(), self.state_space);
+ let ptr_type = ast::Type::Scalar(self.underlying_type.clone());
let new_dst = visitor.id(
ArgumentDescriptor {
op: self.dst,
is_dst: true,
- sema,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
- Some(&ptr_type),
+ Some((&ptr_type, self.state_space)),
)?;
let new_ptr_src = visitor.id(
ArgumentDescriptor {
op: self.ptr_src,
is_dst: false,
- sema,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
- Some(&ptr_type),
+ Some((&ptr_type, self.state_space)),
)?;
let new_constant_src = visitor.operand(
ArgumentDescriptor {
op: self.offset_src,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
&ast::Type::Scalar(ast::ScalarType::S64),
+ ast::StateSpace::Reg,
)?;
Ok(PtrAccess {
underlying_type: self.underlying_type,
@@ -5653,21 +5547,9 @@ impl<T: ArgParamsEx<Id = spirv::Word>, U: ArgParamsEx<Id = spirv::Word>> Visitab }
}
-pub trait ArgParamsEx: ast::ArgParams + Sized {
- fn get_fn_decl<'x, 'b>(
- id: &Self::Id,
- decl: &'b GlobalFnDeclResolver<'x, 'b>,
- ) -> Result<&'b FnDecl, TranslateError>;
-}
+pub trait ArgParamsEx: ast::ArgParams + Sized {}
-impl<'input> ArgParamsEx for ast::ParsedArgParams<'input> {
- fn get_fn_decl<'x, 'b>(
- id: &Self::Id,
- decl: &'b GlobalFnDeclResolver<'x, 'b>,
- ) -> Result<&'b FnDecl, TranslateError> {
- decl.get_fn_decl_str(id)
- }
-}
+impl<'input> ArgParamsEx for ast::ParsedArgParams<'input> {}
enum NormalizedArgParams {}
@@ -5676,14 +5558,7 @@ impl ast::ArgParams for NormalizedArgParams { type Operand = ast::Operand<spirv::Word>;
}
-impl ArgParamsEx for NormalizedArgParams {
- fn get_fn_decl<'a, 'b>(
- id: &Self::Id,
- decl: &'b GlobalFnDeclResolver<'a, 'b>,
- ) -> Result<&'b FnDecl, TranslateError> {
- decl.get_fn_decl(*id)
- }
-}
+impl ArgParamsEx for NormalizedArgParams {}
type NormalizedStatement = Statement<
(
@@ -5702,14 +5577,7 @@ impl ast::ArgParams for TypedArgParams { type Operand = TypedOperand;
}
-impl ArgParamsEx for TypedArgParams {
- fn get_fn_decl<'a, 'b>(
- id: &Self::Id,
- decl: &'b GlobalFnDeclResolver<'a, 'b>,
- ) -> Result<&'b FnDecl, TranslateError> {
- decl.get_fn_decl(*id)
- }
-}
+impl ArgParamsEx for TypedArgParams {}
#[derive(Copy, Clone)]
enum TypedOperand {
@@ -5740,24 +5608,16 @@ impl ast::ArgParams for ExpandedArgParams { type Operand = spirv::Word;
}
-impl ArgParamsEx for ExpandedArgParams {
- fn get_fn_decl<'a, 'b>(
- id: &Self::Id,
- decl: &'b GlobalFnDeclResolver<'a, 'b>,
- ) -> Result<&'b FnDecl, TranslateError> {
- decl.get_fn_decl(*id)
- }
-}
+impl ArgParamsEx for ExpandedArgParams {}
enum Directive<'input> {
- Variable(ast::Variable<ast::VariableType, spirv::Word>),
+ Variable(ast::LinkingDirective, ast::Variable<spirv::Word>),
Method(Function<'input>),
}
struct Function<'input> {
- pub func_decl: ast::MethodDecl<'input, spirv::Word>,
- pub spirv_decl: SpirvMethodDecl<'input>,
- pub globals: Vec<ast::Variable<ast::VariableType, spirv::Word>>,
+ pub func_decl: Rc<RefCell<ast::MethodDeclaration<'input, spirv::Word>>>,
+ pub globals: Vec<ast::Variable<spirv::Word>>,
pub body: Option<Vec<ExpandedStatement>>,
import_as: Option<String>,
tuning: Vec<ast::TuningDirective>,
@@ -5767,12 +5627,13 @@ pub trait ArgumentMapVisitor<T: ArgParamsEx, U: ArgParamsEx> { fn id(
&mut self,
desc: ArgumentDescriptor<T::Id>,
- typ: Option<&ast::Type>,
+ typ: Option<(&ast::Type, ast::StateSpace)>,
) -> Result<U::Id, TranslateError>;
fn operand(
&mut self,
desc: ArgumentDescriptor<T::Operand>,
typ: &ast::Type,
+ state_space: ast::StateSpace,
) -> Result<U::Operand, TranslateError>;
}
@@ -5780,13 +5641,13 @@ impl<T> ArgumentMapVisitor<ExpandedArgParams, ExpandedArgParams> for T where
T: FnMut(
ArgumentDescriptor<spirv::Word>,
- Option<&ast::Type>,
+ Option<(&ast::Type, ast::StateSpace)>,
) -> Result<spirv::Word, TranslateError>,
{
fn id(
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
- t: Option<&ast::Type>,
+ t: Option<(&ast::Type, ast::StateSpace)>,
) -> Result<spirv::Word, TranslateError> {
self(desc, t)
}
@@ -5795,8 +5656,9 @@ where &mut self,
desc: ArgumentDescriptor<spirv::Word>,
typ: &ast::Type,
+ state_space: ast::StateSpace,
) -> Result<spirv::Word, TranslateError> {
- self(desc, Some(typ))
+ self(desc, Some((typ, state_space)))
}
}
@@ -5807,7 +5669,7 @@ where fn id(
&mut self,
desc: ArgumentDescriptor<&str>,
- _: Option<&ast::Type>,
+ _: Option<(&ast::Type, ast::StateSpace)>,
) -> Result<spirv::Word, TranslateError> {
self(desc.op)
}
@@ -5816,6 +5678,7 @@ where &mut self,
desc: ArgumentDescriptor<ast::Operand<&str>>,
typ: &ast::Type,
+ state_space: ast::StateSpace,
) -> Result<ast::Operand<spirv::Word>, TranslateError> {
Ok(match desc.op {
ast::Operand::Reg(id) => ast::Operand::Reg(self(id)?),
@@ -5824,7 +5687,7 @@ where ast::Operand::VecMember(id, member) => ast::Operand::VecMember(self(id)?, member),
ast::Operand::VecPack(ref ids) => ast::Operand::VecPack(
ids.into_iter()
- .map(|id| self.id(desc.new_op(id), Some(typ)))
+ .map(|id| self.id(desc.new_op(id), Some((typ, state_space))))
.collect::<Result<Vec<_>, _>>()?,
),
})
@@ -5834,37 +5697,30 @@ where pub struct ArgumentDescriptor<Op> {
op: Op,
is_dst: bool,
- sema: ArgumentSemantics,
+ is_memory_access: bool,
+ non_default_implicit_conversion: Option<
+ fn(
+ (ast::StateSpace, &ast::Type),
+ (ast::StateSpace, &ast::Type),
+ ) -> Result<Option<ConversionKind>, TranslateError>,
+ >,
}
pub struct PtrAccess<P: ast::ArgParams> {
- underlying_type: ast::PointerType,
- state_space: ast::LdStateSpace,
+ underlying_type: ast::ScalarType,
+ state_space: ast::StateSpace,
dst: spirv::Word,
ptr_src: spirv::Word,
offset_src: P::Operand,
}
-#[derive(Copy, Clone, PartialEq, Eq, Debug)]
-pub enum ArgumentSemantics {
- // normal register access
- Default,
- // normal register access with relaxed conversion rules (ld/st)
- DefaultRelaxed,
- // st/ld global
- PhysicalPointer,
- // st/ld .param, .local
- RegisterPointer,
- // mov of .local/.global variables
- Address,
-}
-
impl<T> ArgumentDescriptor<T> {
fn new_op<U>(&self, u: U) -> ArgumentDescriptor<U> {
ArgumentDescriptor {
op: u,
is_dst: self.is_dst,
- sema: self.sema,
+ is_memory_access: self.is_memory_access,
+ non_default_implicit_conversion: self.non_default_implicit_conversion,
}
}
}
@@ -5905,7 +5761,9 @@ impl<T: ArgParamsEx> ast::Instruction<T> { let inst_type = d.typ;
ast::Instruction::SetpBool(d, a.map(visitor, &ast::Type::Scalar(inst_type))?)
}
- ast::Instruction::Not(t, a) => ast::Instruction::Not(t, a.map(visitor, &t.to_type())?),
+ ast::Instruction::Not(t, a) => {
+ ast::Instruction::Not(t, a.map(visitor, &ast::Type::Scalar(t))?)
+ }
ast::Instruction::Cvt(d, a) => {
let (dst_t, src_t) = match &d {
ast::CvtDetails::FloatFromFloat(desc) => (
@@ -5928,7 +5786,7 @@ impl<T: ArgParamsEx> ast::Instruction<T> { ast::Instruction::Cvt(d, a.map_different_types(visitor, &dst_t, &src_t)?)
}
ast::Instruction::Shl(t, a) => {
- ast::Instruction::Shl(t, a.map_shift(visitor, &t.to_type())?)
+ ast::Instruction::Shl(t, a.map_shift(visitor, &ast::Type::Scalar(t))?)
}
ast::Instruction::Shr(t, a) => {
ast::Instruction::Shr(t, a.map_shift(visitor, &ast::Type::Scalar(t.into()))?)
@@ -6101,17 +5959,19 @@ impl ImplicitConversion { ArgumentDescriptor {
op: self.dst,
is_dst: true,
- sema: self.dst_sema,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
- Some(&self.to),
+ Some((&self.to_type, self.to_space)),
)?;
let new_src = visitor.id(
ArgumentDescriptor {
op: self.src,
is_dst: false,
- sema: self.src_sema,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
- Some(&self.from),
+ Some((&self.from_type, self.from_space)),
)?;
Ok(Statement::Conversion({
ImplicitConversion {
@@ -6138,13 +5998,13 @@ impl<T> ArgumentMapVisitor<TypedArgParams, TypedArgParams> for T where
T: FnMut(
ArgumentDescriptor<spirv::Word>,
- Option<&ast::Type>,
+ Option<(&ast::Type, ast::StateSpace)>,
) -> Result<spirv::Word, TranslateError>,
{
fn id(
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
- t: Option<&ast::Type>,
+ t: Option<(&ast::Type, ast::StateSpace)>,
) -> Result<spirv::Word, TranslateError> {
self(desc, t)
}
@@ -6153,12 +6013,15 @@ where &mut self,
desc: ArgumentDescriptor<TypedOperand>,
typ: &ast::Type,
+ state_space: ast::StateSpace,
) -> Result<TypedOperand, TranslateError> {
Ok(match desc.op {
- TypedOperand::Reg(id) => TypedOperand::Reg(self(desc.new_op(id), Some(typ))?),
+ TypedOperand::Reg(id) => {
+ TypedOperand::Reg(self(desc.new_op(id), Some((typ, state_space)))?)
+ }
TypedOperand::Imm(imm) => TypedOperand::Imm(imm),
TypedOperand::RegOffset(id, imm) => {
- TypedOperand::RegOffset(self(desc.new_op(id), Some(typ))?, imm)
+ TypedOperand::RegOffset(self(desc.new_op(id), Some((typ, state_space)))?, imm)
}
TypedOperand::VecMember(reg, index) => {
let scalar_type = match typ {
@@ -6166,7 +6029,10 @@ where _ => return Err(error_unreachable()),
};
let vec_type = ast::Type::Vector(scalar_type, index + 1);
- TypedOperand::VecMember(self(desc.new_op(reg), Some(&vec_type))?, index)
+ TypedOperand::VecMember(
+ self(desc.new_op(reg), Some((&vec_type, state_space)))?,
+ index,
+ )
}
})
}
@@ -6178,9 +6044,9 @@ impl ast::Type { ast::Type::Scalar(scalar) => {
let kind = scalar.kind();
let width = scalar.size_of();
- if (kind != ScalarKind::Signed
- && kind != ScalarKind::Unsigned
- && kind != ScalarKind::Bit)
+ if (kind != ast::ScalarKind::Signed
+ && kind != ast::ScalarKind::Unsigned
+ && kind != ast::ScalarKind::Bit)
|| (width == 8)
{
return Err(TranslateError::MismatchedType);
@@ -6198,57 +6064,32 @@ impl ast::Type { match self {
ast::Type::Scalar(scalar) => TypeParts {
kind: TypeKind::Scalar,
+ state_space: ast::StateSpace::Reg,
scalar_kind: scalar.kind(),
width: scalar.size_of(),
components: Vec::new(),
- state_space: ast::LdStateSpace::Global,
},
ast::Type::Vector(scalar, components) => TypeParts {
kind: TypeKind::Vector,
+ state_space: ast::StateSpace::Reg,
scalar_kind: scalar.kind(),
width: scalar.size_of(),
components: vec![*components as u32],
- state_space: ast::LdStateSpace::Global,
},
ast::Type::Array(scalar, components) => TypeParts {
kind: TypeKind::Array,
+ state_space: ast::StateSpace::Reg,
scalar_kind: scalar.kind(),
width: scalar.size_of(),
components: components.clone(),
- state_space: ast::LdStateSpace::Global,
},
- ast::Type::Pointer(ast::PointerType::Scalar(scalar), state_space) => TypeParts {
- kind: TypeKind::PointerScalar,
+ ast::Type::Pointer(scalar, space) => TypeParts {
+ kind: TypeKind::Pointer,
+ state_space: *space,
scalar_kind: scalar.kind(),
width: scalar.size_of(),
components: Vec::new(),
- state_space: *state_space,
- },
- ast::Type::Pointer(ast::PointerType::Vector(scalar, len), state_space) => TypeParts {
- kind: TypeKind::PointerVector,
- scalar_kind: scalar.kind(),
- width: scalar.size_of(),
- components: vec![*len as u32],
- state_space: *state_space,
},
- ast::Type::Pointer(ast::PointerType::Array(scalar, components), state_space) => {
- TypeParts {
- kind: TypeKind::PointerArray,
- scalar_kind: scalar.kind(),
- width: scalar.size_of(),
- components: components.clone(),
- state_space: *state_space,
- }
- }
- ast::Type::Pointer(ast::PointerType::Pointer(scalar, inner_space), state_space) => {
- TypeParts {
- kind: TypeKind::PointerPointer,
- scalar_kind: scalar.kind(),
- width: scalar.size_of(),
- components: vec![*inner_space as u32],
- state_space: *state_space,
- }
- }
}
}
@@ -6265,29 +6106,8 @@ impl ast::Type { ast::ScalarType::from_parts(t.width, t.scalar_kind),
t.components,
),
- TypeKind::PointerScalar => ast::Type::Pointer(
- ast::PointerType::Scalar(ast::ScalarType::from_parts(t.width, t.scalar_kind)),
- t.state_space,
- ),
- TypeKind::PointerVector => ast::Type::Pointer(
- ast::PointerType::Vector(
- ast::ScalarType::from_parts(t.width, t.scalar_kind),
- t.components[0] as u8,
- ),
- t.state_space,
- ),
- TypeKind::PointerArray => ast::Type::Pointer(
- ast::PointerType::Array(
- ast::ScalarType::from_parts(t.width, t.scalar_kind),
- t.components,
- ),
- t.state_space,
- ),
- TypeKind::PointerPointer => ast::Type::Pointer(
- ast::PointerType::Pointer(
- ast::ScalarType::from_parts(t.width, t.scalar_kind),
- unsafe { mem::transmute::<_, ast::LdStateSpace>(t.components[0] as u8) },
- ),
+ TypeKind::Pointer => ast::Type::Pointer(
+ ast::ScalarType::from_parts(t.width, t.scalar_kind),
t.state_space,
),
}
@@ -6300,7 +6120,7 @@ impl ast::Type { ast::Type::Array(typ, len) => len
.iter()
.fold(typ.size_of() as usize, |x, y| (x as usize) * (*y as usize)),
- ast::Type::Pointer(_, _) => mem::size_of::<usize>(),
+ ast::Type::Pointer(..) => mem::size_of::<usize>(),
}
}
}
@@ -6308,10 +6128,10 @@ impl ast::Type { #[derive(Eq, PartialEq, Clone)]
struct TypeParts {
kind: TypeKind,
- scalar_kind: ScalarKind,
+ scalar_kind: ast::ScalarKind,
width: u8,
+ state_space: ast::StateSpace,
components: Vec<u32>,
- state_space: ast::LdStateSpace,
}
#[derive(Eq, PartialEq, Copy, Clone)]
@@ -6319,10 +6139,7 @@ enum TypeKind { Scalar,
Vector,
Array,
- PointerScalar,
- PointerVector,
- PointerArray,
- PointerPointer,
+ Pointer,
}
impl ast::Instruction<ExpandedArgParams> {
@@ -6450,21 +6267,21 @@ struct BrachCondition { struct ImplicitConversion {
src: spirv::Word,
dst: spirv::Word,
- from: ast::Type,
- to: ast::Type,
+ from_type: ast::Type,
+ to_type: ast::Type,
+ from_space: ast::StateSpace,
+ to_space: ast::StateSpace,
kind: ConversionKind,
- src_sema: ArgumentSemantics,
- dst_sema: ArgumentSemantics,
}
-#[derive(PartialEq, Copy, Clone)]
+#[derive(PartialEq, Clone)]
enum ConversionKind {
Default,
// zero-extend/chop/bitcast depending on types
SignExtend,
- BitToPtr(ast::LdStateSpace),
- PtrToBit(ast::UIntType),
- PtrToPtr { spirv_ptr: bool },
+ BitToPtr,
+ PtrToPtr,
+ AddressOf,
}
impl<T> ast::PredAt<T> {
@@ -6512,13 +6329,14 @@ impl<T: ArgParamsEx> ast::Arg1<T> { fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
- t: Option<&ast::Type>,
+ t: Option<(&ast::Type, ast::StateSpace)>,
) -> Result<ast::Arg1<U>, TranslateError> {
let new_src = visitor.id(
ArgumentDescriptor {
op: self.src,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
t,
)?;
@@ -6535,9 +6353,11 @@ impl<T: ArgParamsEx> ast::Arg1Bar<T> { ArgumentDescriptor {
op: self.src,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
&ast::Type::Scalar(ast::ScalarType::U32),
+ ast::StateSpace::Reg,
)?;
Ok(ast::Arg1Bar { src: new_src })
}
@@ -6553,17 +6373,21 @@ impl<T: ArgParamsEx> ast::Arg2<T> { ArgumentDescriptor {
op: self.dst,
is_dst: true,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
t,
+ ast::StateSpace::Reg,
)?;
let new_src = visitor.operand(
ArgumentDescriptor {
op: self.src,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
t,
+ ast::StateSpace::Reg,
)?;
Ok(ast::Arg2 {
dst: new_dst,
@@ -6581,17 +6405,21 @@ impl<T: ArgParamsEx> ast::Arg2<T> { ArgumentDescriptor {
op: self.dst,
is_dst: true,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
dst_t,
+ ast::StateSpace::Reg,
)?;
let src = visitor.operand(
ArgumentDescriptor {
op: self.src,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
src_t,
+ ast::StateSpace::Reg,
)?;
Ok(ast::Arg2 { dst, src })
}
@@ -6607,26 +6435,21 @@ impl<T: ArgParamsEx> ast::Arg2Ld<T> { ArgumentDescriptor {
op: self.dst,
is_dst: true,
- sema: ArgumentSemantics::DefaultRelaxed,
+ is_memory_access: false,
+ non_default_implicit_conversion: Some(should_convert_relaxed_dst_wrapper),
},
&ast::Type::from(details.typ.clone()),
+ ast::StateSpace::Reg,
)?;
- let is_logical_ptr = details.state_space == ast::LdStateSpace::Param
- || details.state_space == ast::LdStateSpace::Local;
let src = visitor.operand(
ArgumentDescriptor {
op: self.src,
is_dst: false,
- sema: if is_logical_ptr {
- ArgumentSemantics::RegisterPointer
- } else {
- ArgumentSemantics::PhysicalPointer
- },
+ is_memory_access: true,
+ non_default_implicit_conversion: None,
},
- &ast::Type::Pointer(
- ast::PointerType::from(details.typ.clone()),
- details.state_space,
- ),
+ &details.typ,
+ details.state_space,
)?;
Ok(ast::Arg2Ld { dst, src })
}
@@ -6638,30 +6461,25 @@ impl<T: ArgParamsEx> ast::Arg2St<T> { visitor: &mut V,
details: &ast::StData,
) -> Result<ast::Arg2St<U>, TranslateError> {
- let is_logical_ptr = details.state_space == ast::StStateSpace::Param
- || details.state_space == ast::StStateSpace::Local;
let src1 = visitor.operand(
ArgumentDescriptor {
op: self.src1,
is_dst: false,
- sema: if is_logical_ptr {
- ArgumentSemantics::RegisterPointer
- } else {
- ArgumentSemantics::PhysicalPointer
- },
+ is_memory_access: true,
+ non_default_implicit_conversion: None,
},
- &ast::Type::Pointer(
- ast::PointerType::from(details.typ.clone()),
- details.state_space.to_ld_ss(),
- ),
+ &details.typ,
+ details.state_space,
)?;
let src2 = visitor.operand(
ArgumentDescriptor {
op: self.src2,
is_dst: false,
- sema: ArgumentSemantics::DefaultRelaxed,
+ is_memory_access: false,
+ non_default_implicit_conversion: Some(should_convert_relaxed_src_wrapper),
},
&details.typ.clone().into(),
+ ast::StateSpace::Reg,
)?;
Ok(ast::Arg2St { src1, src2 })
}
@@ -6677,21 +6495,21 @@ impl<T: ArgParamsEx> ast::Arg2Mov<T> { ArgumentDescriptor {
op: self.dst,
is_dst: true,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
&details.typ.clone().into(),
+ ast::StateSpace::Reg,
)?;
let src = visitor.operand(
ArgumentDescriptor {
op: self.src,
is_dst: false,
- sema: if details.src_is_address {
- ArgumentSemantics::Address
- } else {
- ArgumentSemantics::Default
- },
+ is_memory_access: false,
+ non_default_implicit_conversion: Some(implicit_conversion_mov),
},
&details.typ.clone().into(),
+ ast::StateSpace::Reg,
)?;
Ok(ast::Arg2Mov { dst, src })
}
@@ -6713,25 +6531,31 @@ impl<T: ArgParamsEx> ast::Arg3<T> { ArgumentDescriptor {
op: self.dst,
is_dst: true,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
wide_type.as_ref().unwrap_or(typ),
+ ast::StateSpace::Reg,
)?;
let src1 = visitor.operand(
ArgumentDescriptor {
op: self.src1,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
typ,
+ ast::StateSpace::Reg,
)?;
let src2 = visitor.operand(
ArgumentDescriptor {
op: self.src2,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
typ,
+ ast::StateSpace::Reg,
)?;
Ok(ast::Arg3 { dst, src1, src2 })
}
@@ -6745,25 +6569,31 @@ impl<T: ArgParamsEx> ast::Arg3<T> { ArgumentDescriptor {
op: self.dst,
is_dst: true,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
t,
+ ast::StateSpace::Reg,
)?;
let src1 = visitor.operand(
ArgumentDescriptor {
op: self.src1,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
t,
+ ast::StateSpace::Reg,
)?;
let src2 = visitor.operand(
ArgumentDescriptor {
op: self.src2,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
&ast::Type::Scalar(ast::ScalarType::U32),
+ ast::StateSpace::Reg,
)?;
Ok(ast::Arg3 { dst, src1, src2 })
}
@@ -6772,35 +6602,38 @@ impl<T: ArgParamsEx> ast::Arg3<T> { self,
visitor: &mut V,
t: ast::ScalarType,
- state_space: ast::AtomSpace,
+ state_space: ast::StateSpace,
) -> Result<ast::Arg3<U>, TranslateError> {
let scalar_type = ast::ScalarType::from(t);
let dst = visitor.operand(
ArgumentDescriptor {
op: self.dst,
is_dst: true,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
&ast::Type::Scalar(scalar_type),
+ ast::StateSpace::Reg,
)?;
let src1 = visitor.operand(
ArgumentDescriptor {
op: self.src1,
is_dst: false,
- sema: ArgumentSemantics::PhysicalPointer,
+ is_memory_access: true,
+ non_default_implicit_conversion: None,
},
- &ast::Type::Pointer(
- ast::PointerType::Scalar(scalar_type),
- state_space.to_ld_ss(),
- ),
+ &ast::Type::Scalar(scalar_type),
+ state_space,
)?;
let src2 = visitor.operand(
ArgumentDescriptor {
op: self.src2,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
&ast::Type::Scalar(scalar_type),
+ ast::StateSpace::Reg,
)?;
Ok(ast::Arg3 { dst, src1, src2 })
}
@@ -6822,33 +6655,41 @@ impl<T: ArgParamsEx> ast::Arg4<T> { ArgumentDescriptor {
op: self.dst,
is_dst: true,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
wide_type.as_ref().unwrap_or(t),
+ ast::StateSpace::Reg,
)?;
let src1 = visitor.operand(
ArgumentDescriptor {
op: self.src1,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
t,
+ ast::StateSpace::Reg,
)?;
let src2 = visitor.operand(
ArgumentDescriptor {
op: self.src2,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
t,
+ ast::StateSpace::Reg,
)?;
let src3 = visitor.operand(
ArgumentDescriptor {
op: self.src3,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
t,
+ ast::StateSpace::Reg,
)?;
Ok(ast::Arg4 {
dst,
@@ -6861,39 +6702,47 @@ impl<T: ArgParamsEx> ast::Arg4<T> { fn map_selp<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
- t: ast::SelpType,
+ t: ast::ScalarType,
) -> Result<ast::Arg4<U>, TranslateError> {
let dst = visitor.operand(
ArgumentDescriptor {
op: self.dst,
is_dst: true,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
&ast::Type::Scalar(t.into()),
+ ast::StateSpace::Reg,
)?;
let src1 = visitor.operand(
ArgumentDescriptor {
op: self.src1,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
&ast::Type::Scalar(t.into()),
+ ast::StateSpace::Reg,
)?;
let src2 = visitor.operand(
ArgumentDescriptor {
op: self.src2,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
&ast::Type::Scalar(t.into()),
+ ast::StateSpace::Reg,
)?;
let src3 = visitor.operand(
ArgumentDescriptor {
op: self.src3,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
&ast::Type::Scalar(ast::ScalarType::Pred),
+ ast::StateSpace::Reg,
)?;
Ok(ast::Arg4 {
dst,
@@ -6906,44 +6755,49 @@ impl<T: ArgParamsEx> ast::Arg4<T> { fn map_atom<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
- t: ast::BitType,
- state_space: ast::AtomSpace,
+ t: ast::ScalarType,
+ state_space: ast::StateSpace,
) -> Result<ast::Arg4<U>, TranslateError> {
let scalar_type = ast::ScalarType::from(t);
let dst = visitor.operand(
ArgumentDescriptor {
op: self.dst,
is_dst: true,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
&ast::Type::Scalar(scalar_type),
+ ast::StateSpace::Reg,
)?;
let src1 = visitor.operand(
ArgumentDescriptor {
op: self.src1,
is_dst: false,
- sema: ArgumentSemantics::PhysicalPointer,
+ is_memory_access: true,
+ non_default_implicit_conversion: None,
},
- &ast::Type::Pointer(
- ast::PointerType::Scalar(scalar_type),
- state_space.to_ld_ss(),
- ),
+ &ast::Type::Scalar(scalar_type),
+ state_space,
)?;
let src2 = visitor.operand(
ArgumentDescriptor {
op: self.src2,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
&ast::Type::Scalar(scalar_type),
+ ast::StateSpace::Reg,
)?;
let src3 = visitor.operand(
ArgumentDescriptor {
op: self.src3,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
&ast::Type::Scalar(scalar_type),
+ ast::StateSpace::Reg,
)?;
Ok(ast::Arg4 {
dst,
@@ -6962,34 +6816,42 @@ impl<T: ArgParamsEx> ast::Arg4<T> { ArgumentDescriptor {
op: self.dst,
is_dst: true,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
typ,
+ ast::StateSpace::Reg,
)?;
let src1 = visitor.operand(
ArgumentDescriptor {
op: self.src1,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
typ,
+ ast::StateSpace::Reg,
)?;
let u32_type = ast::Type::Scalar(ast::ScalarType::U32);
let src2 = visitor.operand(
ArgumentDescriptor {
op: self.src2,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
&u32_type,
+ ast::StateSpace::Reg,
)?;
let src3 = visitor.operand(
ArgumentDescriptor {
op: self.src3,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
&u32_type,
+ ast::StateSpace::Reg,
)?;
Ok(ast::Arg4 {
dst,
@@ -7010,9 +6872,13 @@ impl<T: ArgParamsEx> ast::Arg4Setp<T> { ArgumentDescriptor {
op: self.dst1,
is_dst: true,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
- Some(&ast::Type::Scalar(ast::ScalarType::Pred)),
+ Some((
+ &ast::Type::Scalar(ast::ScalarType::Pred),
+ ast::StateSpace::Reg,
+ )),
)?;
let dst2 = self
.dst2
@@ -7021,9 +6887,13 @@ impl<T: ArgParamsEx> ast::Arg4Setp<T> { ArgumentDescriptor {
op: dst2,
is_dst: true,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
- Some(&ast::Type::Scalar(ast::ScalarType::Pred)),
+ Some((
+ &ast::Type::Scalar(ast::ScalarType::Pred),
+ ast::StateSpace::Reg,
+ )),
)
})
.transpose()?;
@@ -7031,17 +6901,21 @@ impl<T: ArgParamsEx> ast::Arg4Setp<T> { ArgumentDescriptor {
op: self.src1,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
t,
+ ast::StateSpace::Reg,
)?;
let src2 = visitor.operand(
ArgumentDescriptor {
op: self.src2,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
t,
+ ast::StateSpace::Reg,
)?;
Ok(ast::Arg4Setp {
dst1,
@@ -7062,41 +6936,51 @@ impl<T: ArgParamsEx> ast::Arg5<T> { ArgumentDescriptor {
op: self.dst,
is_dst: true,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
base_type,
+ ast::StateSpace::Reg,
)?;
let src1 = visitor.operand(
ArgumentDescriptor {
op: self.src1,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
base_type,
+ ast::StateSpace::Reg,
)?;
let src2 = visitor.operand(
ArgumentDescriptor {
op: self.src2,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
base_type,
+ ast::StateSpace::Reg,
)?;
let src3 = visitor.operand(
ArgumentDescriptor {
op: self.src3,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
&ast::Type::Scalar(ast::ScalarType::U32),
+ ast::StateSpace::Reg,
)?;
let src4 = visitor.operand(
ArgumentDescriptor {
op: self.src4,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
&ast::Type::Scalar(ast::ScalarType::U32),
+ ast::StateSpace::Reg,
)?;
Ok(ast::Arg5 {
dst,
@@ -7118,9 +7002,13 @@ impl<T: ArgParamsEx> ast::Arg5Setp<T> { ArgumentDescriptor {
op: self.dst1,
is_dst: true,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
- Some(&ast::Type::Scalar(ast::ScalarType::Pred)),
+ Some((
+ &ast::Type::Scalar(ast::ScalarType::Pred),
+ ast::StateSpace::Reg,
+ )),
)?;
let dst2 = self
.dst2
@@ -7129,9 +7017,13 @@ impl<T: ArgParamsEx> ast::Arg5Setp<T> { ArgumentDescriptor {
op: dst2,
is_dst: true,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
- Some(&ast::Type::Scalar(ast::ScalarType::Pred)),
+ Some((
+ &ast::Type::Scalar(ast::ScalarType::Pred),
+ ast::StateSpace::Reg,
+ )),
)
})
.transpose()?;
@@ -7139,25 +7031,31 @@ impl<T: ArgParamsEx> ast::Arg5Setp<T> { ArgumentDescriptor {
op: self.src1,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
t,
+ ast::StateSpace::Reg,
)?;
let src2 = visitor.operand(
ArgumentDescriptor {
op: self.src2,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
t,
+ ast::StateSpace::Reg,
)?;
let src3 = visitor.operand(
ArgumentDescriptor {
op: self.src3,
is_dst: false,
- sema: ArgumentSemantics::Default,
+ is_memory_access: false,
+ non_default_implicit_conversion: None,
},
&ast::Type::Scalar(ast::ScalarType::Pred),
+ ast::StateSpace::Reg,
)?;
Ok(ast::Arg5Setp {
dst1,
@@ -7195,115 +7093,41 @@ impl ast::Operand<spirv::Word> { }
}
-impl ast::StStateSpace {
- fn to_ld_ss(self) -> ast::LdStateSpace {
- match self {
- ast::StStateSpace::Generic => ast::LdStateSpace::Generic,
- ast::StStateSpace::Global => ast::LdStateSpace::Global,
- ast::StStateSpace::Local => ast::LdStateSpace::Local,
- ast::StStateSpace::Param => ast::LdStateSpace::Param,
- ast::StStateSpace::Shared => ast::LdStateSpace::Shared,
- }
- }
-}
-
-#[derive(Clone, Copy, PartialEq, Eq)]
-enum ScalarKind {
- Bit,
- Unsigned,
- Signed,
- Float,
- Float2,
- Pred,
-}
-
// Translation-side helpers on `ast::ScalarType`.
// In this hunk the local `kind()` method and the file-local `ScalarKind` enum usages are removed
// (`-` lines); `from_parts` is kept and now takes `ast::ScalarKind`.
impl ast::ScalarType {
- fn kind(self) -> ScalarKind {
- match self {
- ast::ScalarType::U8 => ScalarKind::Unsigned,
- ast::ScalarType::U16 => ScalarKind::Unsigned,
- ast::ScalarType::U32 => ScalarKind::Unsigned,
- ast::ScalarType::U64 => ScalarKind::Unsigned,
- ast::ScalarType::S8 => ScalarKind::Signed,
- ast::ScalarType::S16 => ScalarKind::Signed,
- ast::ScalarType::S32 => ScalarKind::Signed,
- ast::ScalarType::S64 => ScalarKind::Signed,
- ast::ScalarType::B8 => ScalarKind::Bit,
- ast::ScalarType::B16 => ScalarKind::Bit,
- ast::ScalarType::B32 => ScalarKind::Bit,
- ast::ScalarType::B64 => ScalarKind::Bit,
- ast::ScalarType::F16 => ScalarKind::Float,
- ast::ScalarType::F32 => ScalarKind::Float,
- ast::ScalarType::F64 => ScalarKind::Float,
- ast::ScalarType::F16x2 => ScalarKind::Float2,
- ast::ScalarType::Pred => ScalarKind::Pred,
- }
- }
-
// Builds the scalar type identified by a (width, kind) pair.
// `width` appears to be the size in bytes (2 => F16, 4 => F32, 8 => F64; 1/2/4/8 for the
// bit/signed/unsigned families; 4 for F16x2). `Pred` ignores `width`.
// Any width not listed for the given kind hits `unreachable!()` and panics.
- fn from_parts(width: u8, kind: ScalarKind) -> Self {
+ fn from_parts(width: u8, kind: ast::ScalarKind) -> Self {
match kind {
- ScalarKind::Float => match width {
+ ast::ScalarKind::Float => match width {
2 => ast::ScalarType::F16,
4 => ast::ScalarType::F32,
8 => ast::ScalarType::F64,
_ => unreachable!(),
},
- ScalarKind::Bit => match width {
+ ast::ScalarKind::Bit => match width {
1 => ast::ScalarType::B8,
2 => ast::ScalarType::B16,
4 => ast::ScalarType::B32,
8 => ast::ScalarType::B64,
_ => unreachable!(),
},
- ScalarKind::Signed => match width {
+ ast::ScalarKind::Signed => match width {
1 => ast::ScalarType::S8,
2 => ast::ScalarType::S16,
4 => ast::ScalarType::S32,
8 => ast::ScalarType::S64,
_ => unreachable!(),
},
- ScalarKind::Unsigned => match width {
+ ast::ScalarKind::Unsigned => match width {
1 => ast::ScalarType::U8,
2 => ast::ScalarType::U16,
4 => ast::ScalarType::U32,
8 => ast::ScalarType::U64,
_ => unreachable!(),
},
- ScalarKind::Float2 => match width {
+ ast::ScalarKind::Float2 => match width {
4 => ast::ScalarType::F16x2,
_ => unreachable!(),
},
- ScalarKind::Pred => ast::ScalarType::Pred,
- }
- }
-}
-
-impl ast::BooleanType {
- fn to_type(self) -> ast::Type {
- match self {
- ast::BooleanType::Pred => ast::Type::Scalar(ast::ScalarType::Pred),
- ast::BooleanType::B16 => ast::Type::Scalar(ast::ScalarType::B16),
- ast::BooleanType::B32 => ast::Type::Scalar(ast::ScalarType::B32),
- ast::BooleanType::B64 => ast::Type::Scalar(ast::ScalarType::B64),
- }
- }
-}
-
-impl ast::ShlType {
- fn to_type(self) -> ast::Type {
- match self {
- ast::ShlType::B16 => ast::Type::Scalar(ast::ScalarType::B16),
- ast::ShlType::B32 => ast::Type::Scalar(ast::ScalarType::B32),
- ast::ShlType::B64 => ast::Type::Scalar(ast::ScalarType::B64),
- }
- }
-}
-
-impl ast::ShrType {
- fn signed(&self) -> bool {
- match self {
- ast::ShrType::S16 | ast::ShrType::S32 | ast::ShrType::S64 => true,
- _ => false,
+ ast::ScalarKind::Pred => ast::ScalarType::Pred,
}
}
}
@@ -7359,49 +7183,47 @@ impl ast::AtomInnerDetails { }
}
-impl ast::SIntType {
- fn from_size(width: u8) -> Self {
- match width {
- 1 => ast::SIntType::S8,
- 2 => ast::SIntType::S16,
- 4 => ast::SIntType::S32,
- 8 => ast::SIntType::S64,
- _ => unreachable!(),
// Helpers on the unified `ast::StateSpace` enum. On the new side of this hunk they replace the
// removed per-instruction helpers (`SIntType`/`UIntType::from_size`, `LdStateSpace::to_spirv`,
// `From<FnArgumentType> for VariableType`), whose `-` lines are interleaved below.
+impl ast::StateSpace {
// Maps a state space to the SPIR-V storage class used for it.
// Note that `Local`, `Param` and `Reg` all map to `Function`.
+ fn to_spirv(self) -> spirv::StorageClass {
+ match self {
+ ast::StateSpace::Const => spirv::StorageClass::UniformConstant,
+ ast::StateSpace::Generic => spirv::StorageClass::Generic,
+ ast::StateSpace::Global => spirv::StorageClass::CrossWorkgroup,
+ ast::StateSpace::Local => spirv::StorageClass::Function,
+ ast::StateSpace::Shared => spirv::StorageClass::Workgroup,
+ ast::StateSpace::Param => spirv::StorageClass::Function,
+ ast::StateSpace::Reg => spirv::StorageClass::Function,
+ ast::StateSpace::Sreg => spirv::StorageClass::Input,
}
}
-}
-impl ast::UIntType {
- fn from_size(width: u8) -> Self {
- match width {
- 1 => ast::UIntType::U8,
- 2 => ast::UIntType::U16,
- 4 => ast::UIntType::U32,
- 8 => ast::UIntType::U64,
- _ => unreachable!(),
- }
// Two spaces are compatible when they are equal, or when one is `Reg` and the other is `Sreg`
// (in either order).
+ fn is_compatible(self, other: ast::StateSpace) -> bool {
+ self == other
+ || self == ast::StateSpace::Reg && other == ast::StateSpace::Sreg
+ || self == ast::StateSpace::Sreg && other == ast::StateSpace::Reg
}
-}
-impl ast::LdStateSpace {
- fn to_spirv(self) -> spirv::StorageClass {
// True for `Global`, `Const`, `Local` and `Shared`; false for `Reg`, `Param`, `Sreg` and for
// `Generic` itself. Used by `default_implicit_conversion_space` to allow pointer conversions
// between these spaces and `Generic`.
+ fn coerces_to_generic(self) -> bool {
match self {
- ast::LdStateSpace::Const => spirv::StorageClass::UniformConstant,
- ast::LdStateSpace::Generic => spirv::StorageClass::Generic,
- ast::LdStateSpace::Global => spirv::StorageClass::CrossWorkgroup,
- ast::LdStateSpace::Local => spirv::StorageClass::Function,
- ast::LdStateSpace::Shared => spirv::StorageClass::Workgroup,
- ast::LdStateSpace::Param => spirv::StorageClass::Function,
+ ast::StateSpace::Global
+ | ast::StateSpace::Const
+ | ast::StateSpace::Local
+ | ast::StateSpace::Shared => true,
+ ast::StateSpace::Reg
+ | ast::StateSpace::Param
+ | ast::StateSpace::Generic
+ | ast::StateSpace::Sreg => false,
}
}
-}
-impl From<ast::FnArgumentType> for ast::VariableType {
- fn from(t: ast::FnArgumentType) -> Self {
- match t {
- ast::FnArgumentType::Reg(t) => ast::VariableType::Reg(t),
- ast::FnArgumentType::Param(t) => ast::VariableType::Param(t),
- ast::FnArgumentType::Shared => todo!(),
// True for `Const`, `Generic`, `Global`, `Local` and `Shared`; false for `Param`, `Reg` and
// `Sreg`. `implicit_conversion_mov` uses this to decide when a `mov` source yields an address.
+ fn is_addressable(self) -> bool {
+ match self {
+ ast::StateSpace::Const
+ | ast::StateSpace::Generic
+ | ast::StateSpace::Global
+ | ast::StateSpace::Local
+ | ast::StateSpace::Shared => true,
+ ast::StateSpace::Param | ast::StateSpace::Reg | ast::StateSpace::Sreg => false,
}
}
}
@@ -7427,16 +7249,6 @@ impl ast::MulDetails { }
}
-impl ast::AtomSpace {
- fn to_ld_ss(self) -> ast::LdStateSpace {
- match self {
- ast::AtomSpace::Generic => ast::LdStateSpace::Generic,
- ast::AtomSpace::Global => ast::LdStateSpace::Global,
- ast::AtomSpace::Shared => ast::LdStateSpace::Shared,
- }
- }
-}
-
impl ast::MemScope {
fn to_spirv(self) -> spirv::Scope {
match self {
@@ -7458,109 +7270,96 @@ impl ast::AtomSemantics { }
}
-impl ast::FnArgumentType {
- fn semantics(&self) -> ArgumentSemantics {
- match self {
- ast::FnArgumentType::Reg(_) => ArgumentSemantics::Default,
- ast::FnArgumentType::Param(_) => ArgumentSemantics::RegisterPointer,
- ast::FnArgumentType::Shared => ArgumentSemantics::PhysicalPointer,
- }
- }
-}
-
-fn bitcast_register_pointer(
- operand_type: &ast::Type,
- instr_type: &ast::Type,
- ss: Option<ast::LdStateSpace>,
// Default rule for deciding which implicit conversion (if any) is needed to use an operand of
// `(operand_space, operand_type)` where the instruction expects
// `(instruction_space, instruction_type)`.
//
// Returns:
// * `Ok(None)` when spaces are compatible and the types are equal,
// * `Ok(Some(kind))` when a conversion is required,
// * `Err(..)` as produced by the space/type specific helpers it delegates to.
+fn default_implicit_conversion(
+ (operand_space, operand_type): (ast::StateSpace, &ast::Type),
+ (instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
) -> Result<Option<ConversionKind>, TranslateError> {
- bitcast_physical_pointer(operand_type, instr_type, ss)
// A space mismatch takes priority: the type is only compared once the spaces are compatible.
+ if !instruction_space.is_compatible(operand_space) {
+ default_implicit_conversion_space(
+ (operand_space, operand_type),
+ (instruction_space, instruction_type),
+ )
+ } else if instruction_type != operand_type {
+ default_implicit_conversion_type(instruction_space, operand_type, instruction_type)
+ } else {
+ Ok(None)
+ }
}
-fn bitcast_physical_pointer(
- operand_type: &ast::Type,
- instr_type: &ast::Type,
- ss: Option<ast::LdStateSpace>,
// Conversion rule used by `default_implicit_conversion` when the operand and instruction state
// spaces are NOT compatible. On the new (`+`) side the checks run in this order:
// 1. `Generic` on one side and a space that `coerces_to_generic()` on the other => `PtrToPtr`.
// 2. Operand lives in a register-compatible space:
//    * operand type is a pointer into the instruction space => `None` if the pointee scalar
//      equals the instruction type, otherwise `PtrToPtr`;
//    * 64-bit B/U/S scalar => `BitToPtr` for Global/Generic/Const/Local/Shared;
//    * 32-bit B/U/S scalar => `BitToPtr` for Const/Local/Shared only;
//    * anything else => `MismatchedType`.
// 3. Instruction space is register-compatible and the instruction type is a pointer into the
//    operand space => `None` if the pointee scalar equals the operand type, else `PtrToPtr`.
// 4. Otherwise => `MismatchedType`.
// The `-` lines are the removed `bitcast_physical_pointer` body, interleaved by the diff.
+// Space is different
+fn default_implicit_conversion_space(
+ (operand_space, operand_type): (ast::StateSpace, &ast::Type),
+ (instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
) -> Result<Option<ConversionKind>, TranslateError> {
- match operand_type {
- // array decays to a pointer
- ast::Type::Array(op_scalar_t, _) => {
- if let ast::Type::Pointer(instr_scalar_t, instr_space) = instr_type {
- if ss == Some(*instr_space) {
- if ast::Type::Scalar(*op_scalar_t) == ast::Type::from(instr_scalar_t.clone()) {
- Ok(None)
- } else {
- Ok(Some(ConversionKind::PtrToPtr { spirv_ptr: false }))
- }
+ if (instruction_space == ast::StateSpace::Generic && operand_space.coerces_to_generic())
+ || (operand_space == ast::StateSpace::Generic && instruction_space.coerces_to_generic())
+ {
+ Ok(Some(ConversionKind::PtrToPtr))
+ } else if operand_space.is_compatible(ast::StateSpace::Reg) {
+ match operand_type {
+ ast::Type::Pointer(operand_ptr_type, operand_ptr_space)
+ if *operand_ptr_space == instruction_space =>
+ {
+ if instruction_type != &ast::Type::Scalar(*operand_ptr_type) {
+ Ok(Some(ConversionKind::PtrToPtr))
} else {
- if ss == Some(ast::LdStateSpace::Generic)
- || *instr_space == ast::LdStateSpace::Generic
- {
- Ok(Some(ConversionKind::PtrToPtr { spirv_ptr: false }))
- } else {
- Err(TranslateError::MismatchedType)
- }
+ Ok(None)
}
- } else {
- Err(TranslateError::MismatchedType)
- }
- }
- ast::Type::Scalar(ast::ScalarType::B64)
- | ast::Type::Scalar(ast::ScalarType::U64)
- | ast::Type::Scalar(ast::ScalarType::S64) => {
- if let Some(space) = ss {
- Ok(Some(ConversionKind::BitToPtr(space)))
- } else {
- Err(error_unreachable())
- }
- }
- ast::Type::Scalar(ast::ScalarType::B32)
- | ast::Type::Scalar(ast::ScalarType::U32)
- | ast::Type::Scalar(ast::ScalarType::S32) => match ss {
- Some(ast::LdStateSpace::Shared)
- | Some(ast::LdStateSpace::Generic)
- | Some(ast::LdStateSpace::Param)
- | Some(ast::LdStateSpace::Local) => {
- Ok(Some(ConversionKind::BitToPtr(ast::LdStateSpace::Shared)))
}
// Integer/bit scalars held in registers may be reinterpreted as pointers; which target spaces
// are accepted depends on the scalar width (64-bit arms first, then 32-bit arms).
+ // TODO: 32 bit
+ ast::Type::Scalar(ast::ScalarType::B64)
+ | ast::Type::Scalar(ast::ScalarType::U64)
+ | ast::Type::Scalar(ast::ScalarType::S64) => match instruction_space {
+ ast::StateSpace::Global
+ | ast::StateSpace::Generic
+ | ast::StateSpace::Const
+ | ast::StateSpace::Local
+ | ast::StateSpace::Shared => Ok(Some(ConversionKind::BitToPtr)),
+ _ => Err(TranslateError::MismatchedType),
+ },
+ ast::Type::Scalar(ast::ScalarType::B32)
+ | ast::Type::Scalar(ast::ScalarType::U32)
+ | ast::Type::Scalar(ast::ScalarType::S32) => match instruction_space {
+ ast::StateSpace::Const | ast::StateSpace::Local | ast::StateSpace::Shared => {
+ Ok(Some(ConversionKind::BitToPtr))
+ }
+ _ => Err(TranslateError::MismatchedType),
+ },
_ => Err(TranslateError::MismatchedType),
- },
- ast::Type::Pointer(op_scalar_t, op_space) => {
- if let ast::Type::Pointer(instr_scalar_t, instr_space) = instr_type {
- if op_space == instr_space {
- if op_scalar_t == instr_scalar_t {
- Ok(None)
- } else {
- Ok(Some(ConversionKind::PtrToPtr { spirv_ptr: false }))
- }
+ }
// Mirror case: the instruction side is register-compatible and expects a pointer into the
// operand's space.
+ } else if instruction_space.is_compatible(ast::StateSpace::Reg) {
+ match instruction_type {
+ ast::Type::Pointer(instruction_ptr_type, instruction_ptr_space)
+ if operand_space == *instruction_ptr_space =>
+ {
+ if operand_type != &ast::Type::Scalar(*instruction_ptr_type) {
+ Ok(Some(ConversionKind::PtrToPtr))
} else {
- if *op_space == ast::LdStateSpace::Generic
- || *instr_space == ast::LdStateSpace::Generic
- {
- Ok(Some(ConversionKind::PtrToPtr { spirv_ptr: false }))
- } else {
- Err(TranslateError::MismatchedType)
- }
+ Ok(None)
}
- } else {
- Err(TranslateError::MismatchedType)
}
+ _ => Err(TranslateError::MismatchedType),
}
- _ => Err(TranslateError::MismatchedType),
+ } else {
+ Err(TranslateError::MismatchedType)
}
}
-fn force_bitcast_ptr_to_bit(
- _: &ast::Type,
- instr_type: &ast::Type,
- _: Option<ast::LdStateSpace>,
// Conversion rule used by `default_implicit_conversion` when the spaces are compatible but the
// types differ:
// * in a register-compatible space, `should_bitcast` decides between a `Default` conversion and
//   `MismatchedType`;
// * in any other space the result is always `PtrToPtr`.
// The `-` lines are the removed `force_bitcast_ptr_to_bit` body, interleaved by the diff.
+// Space is same, but type is different
+fn default_implicit_conversion_type(
+ space: ast::StateSpace,
+ operand_type: &ast::Type,
+ instruction_type: &ast::Type,
) -> Result<Option<ConversionKind>, TranslateError> {
- // TODO: verify this on f32, u16 and the like
- if let ast::Type::Scalar(scalar_t) = instr_type {
- if let Ok(int_type) = (*scalar_t).try_into() {
- return Ok(Some(ConversionKind::PtrToBit(int_type)));
+ if space.is_compatible(ast::StateSpace::Reg) {
+ if should_bitcast(instruction_type, operand_type) {
+ Ok(Some(ConversionKind::Default))
+ } else {
+ Err(TranslateError::MismatchedType)
}
+ } else {
+ Ok(Some(ConversionKind::PtrToPtr))
}
- Err(TranslateError::MismatchedType)
}
fn should_bitcast(instr: &ast::Type, operand: &ast::Type) -> bool {
@@ -7570,16 +7369,18 @@ fn should_bitcast(instr: &ast::Type, operand: &ast::Type) -> bool { return false;
}
match inst.kind() {
- ScalarKind::Bit => operand.kind() != ScalarKind::Bit,
- ScalarKind::Float => operand.kind() == ScalarKind::Bit,
- ScalarKind::Signed => {
- operand.kind() == ScalarKind::Bit || operand.kind() == ScalarKind::Unsigned
+ ast::ScalarKind::Bit => operand.kind() != ast::ScalarKind::Bit,
+ ast::ScalarKind::Float => operand.kind() == ast::ScalarKind::Bit,
+ ast::ScalarKind::Signed => {
+ operand.kind() == ast::ScalarKind::Bit
+ || operand.kind() == ast::ScalarKind::Unsigned
}
- ScalarKind::Unsigned => {
- operand.kind() == ScalarKind::Bit || operand.kind() == ScalarKind::Signed
+ ast::ScalarKind::Unsigned => {
+ operand.kind() == ast::ScalarKind::Bit
+ || operand.kind() == ast::ScalarKind::Signed
}
- ScalarKind::Float2 => false,
- ScalarKind::Pred => false,
+ ast::ScalarKind::Float2 => false,
+ ast::ScalarKind::Pred => false,
}
}
(ast::Type::Vector(inst, _), ast::Type::Vector(operand, _))
@@ -7590,47 +7391,45 @@ fn should_bitcast(instr: &ast::Type, operand: &ast::Type) -> bool { }
}
-fn should_bitcast_packed(
- operand: &ast::Type,
- instr: &ast::Type,
- ss: Option<ast::LdStateSpace>,
// Implicit-conversion rule passed as `non_default_implicit_conversion` for the `mov` source
// operand. On the new (`+`) side:
// * operand in a register-compatible space: a vector operand moved into a bit-kind scalar of the
//   same total size is accepted with a `Default` conversion;
// * operand in an addressable space (see `StateSpace::is_addressable`): the result is
//   `AddressOf`;
// * everything else falls through to `default_implicit_conversion`.
// The `-` lines are the removed `should_bitcast_packed` / `should_bitcast_wrapper` bodies.
+fn implicit_conversion_mov(
+ (operand_space, operand_type): (ast::StateSpace, &ast::Type),
+ (instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
) -> Result<Option<ConversionKind>, TranslateError> {
- if let (ast::Type::Vector(vec_underlying_type, vec_len), ast::Type::Scalar(scalar)) =
- (operand, instr)
- {
- if scalar.kind() == ScalarKind::Bit
- && scalar.size_of() == (vec_underlying_type.size_of() * vec_len)
+ // instruction_space is always reg
+ if operand_space.is_compatible(ast::StateSpace::Reg) {
+ if let (ast::Type::Vector(vec_underlying_type, vec_len), ast::Type::Scalar(scalar)) =
+ (operand_type, instruction_type)
{
- return Ok(Some(ConversionKind::Default));
- }
- }
- should_bitcast_wrapper(operand, instr, ss)
-}
-
-fn should_bitcast_wrapper(
- operand: &ast::Type,
- instr: &ast::Type,
- _: Option<ast::LdStateSpace>,
-) -> Result<Option<ConversionKind>, TranslateError> {
- if instr == operand {
- return Ok(None);
- }
- if should_bitcast(instr, operand) {
- Ok(Some(ConversionKind::Default))
- } else {
- Err(TranslateError::MismatchedType)
- }
+ if scalar.kind() == ast::ScalarKind::Bit
+ && scalar.size_of() == (vec_underlying_type.size_of() * vec_len)
+ {
+ return Ok(Some(ConversionKind::Default));
+ }
+ }
+ // TODO: verify .params addressability:
+ // * kernel arg
+ // * func arg
+ // * variable
+ } else if operand_space.is_addressable() {
+ return Ok(Some(ConversionKind::AddressOf));
+ }
// Reached for register operands that are not a packed-vector match, and for operands in
// non-register, non-addressable spaces.
+ default_implicit_conversion(
+ (operand_space, operand_type),
+ (instruction_space, instruction_type),
+ )
}
fn should_convert_relaxed_src_wrapper(
- src_type: &ast::Type,
- instr_type: &ast::Type,
- _: Option<ast::LdStateSpace>,
+ (operand_space, operand_type): (ast::StateSpace, &ast::Type),
+ (instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
) -> Result<Option<ConversionKind>, TranslateError> {
- if src_type == instr_type {
+ if !operand_space.is_compatible(instruction_space) {
+ return Err(TranslateError::MismatchedType);
+ }
+ if operand_type == instruction_type {
return Ok(None);
}
- match should_convert_relaxed_src(src_type, instr_type) {
+ match should_convert_relaxed_src(operand_type, instruction_type) {
conv @ Some(_) => Ok(conv),
None => Err(TranslateError::MismatchedType),
}
@@ -7646,32 +7445,33 @@ fn should_convert_relaxed_src( }
match (src_type, instr_type) {
(ast::Type::Scalar(src_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() {
- ScalarKind::Bit => {
+ ast::ScalarKind::Bit => {
if instr_type.size_of() <= src_type.size_of() {
Some(ConversionKind::Default)
} else {
None
}
}
- ScalarKind::Signed | ScalarKind::Unsigned => {
+ ast::ScalarKind::Signed | ast::ScalarKind::Unsigned => {
if instr_type.size_of() <= src_type.size_of()
- && src_type.kind() != ScalarKind::Float
+ && src_type.kind() != ast::ScalarKind::Float
{
Some(ConversionKind::Default)
} else {
None
}
}
- ScalarKind::Float => {
- if instr_type.size_of() <= src_type.size_of() && src_type.kind() == ScalarKind::Bit
+ ast::ScalarKind::Float => {
+ if instr_type.size_of() <= src_type.size_of()
+ && src_type.kind() == ast::ScalarKind::Bit
{
Some(ConversionKind::Default)
} else {
None
}
}
- ScalarKind::Float2 => todo!(),
- ScalarKind::Pred => None,
+ ast::ScalarKind::Float2 => todo!(),
+ ast::ScalarKind::Pred => None,
},
(ast::Type::Vector(dst_type, _), ast::Type::Vector(instr_type, _))
| (ast::Type::Array(dst_type, _), ast::Type::Array(instr_type, _)) => {
@@ -7685,14 +7485,16 @@ fn should_convert_relaxed_src( }
fn should_convert_relaxed_dst_wrapper(
- dst_type: &ast::Type,
- instr_type: &ast::Type,
- _: Option<ast::LdStateSpace>,
+ (operand_space, operand_type): (ast::StateSpace, &ast::Type),
+ (instruction_space, instruction_type): (ast::StateSpace, &ast::Type),
) -> Result<Option<ConversionKind>, TranslateError> {
- if dst_type == instr_type {
+ if !operand_space.is_compatible(instruction_space) {
+ return Err(TranslateError::MismatchedType);
+ }
+ if operand_type == instruction_type {
return Ok(None);
}
- match should_convert_relaxed_dst(dst_type, instr_type) {
+ match should_convert_relaxed_dst(operand_type, instruction_type) {
conv @ Some(_) => Ok(conv),
None => Err(TranslateError::MismatchedType),
}
@@ -7708,15 +7510,15 @@ fn should_convert_relaxed_dst( }
match (dst_type, instr_type) {
(ast::Type::Scalar(dst_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() {
- ScalarKind::Bit => {
+ ast::ScalarKind::Bit => {
if instr_type.size_of() <= dst_type.size_of() {
Some(ConversionKind::Default)
} else {
None
}
}
- ScalarKind::Signed => {
- if dst_type.kind() != ScalarKind::Float {
+ ast::ScalarKind::Signed => {
+ if dst_type.kind() != ast::ScalarKind::Float {
if instr_type.size_of() == dst_type.size_of() {
Some(ConversionKind::Default)
} else if instr_type.size_of() < dst_type.size_of() {
@@ -7728,25 +7530,26 @@ fn should_convert_relaxed_dst( None
}
}
- ScalarKind::Unsigned => {
+ ast::ScalarKind::Unsigned => {
if instr_type.size_of() <= dst_type.size_of()
- && dst_type.kind() != ScalarKind::Float
+ && dst_type.kind() != ast::ScalarKind::Float
{
Some(ConversionKind::Default)
} else {
None
}
}
- ScalarKind::Float => {
- if instr_type.size_of() <= dst_type.size_of() && dst_type.kind() == ScalarKind::Bit
+ ast::ScalarKind::Float => {
+ if instr_type.size_of() <= dst_type.size_of()
+ && dst_type.kind() == ast::ScalarKind::Bit
{
Some(ConversionKind::Default)
} else {
None
}
}
- ScalarKind::Float2 => todo!(),
- ScalarKind::Pred => None,
+ ast::ScalarKind::Float2 => todo!(),
+ ast::ScalarKind::Pred => None,
},
(ast::Type::Vector(dst_type, _), ast::Type::Vector(instr_type, _))
| (ast::Type::Array(dst_type, _), ast::Type::Array(instr_type, _)) => {
@@ -7759,77 +7562,46 @@ fn should_convert_relaxed_dst( }
}
-impl<'a> ast::MethodDecl<'a, &'a str> {
// Accessor for method declarations whose identifiers are still source strings (`&'a str`).
+impl<'a> ast::MethodDeclaration<'a, &'a str> {
// Returns the method's name regardless of whether it is a kernel or a regular function.
// On the new side the name is read from the `name: ast::MethodName` field instead of matching
// on the removed `ast::MethodDecl` enum.
fn name(&self) -> &'a str {
- match self {
- ast::MethodDecl::Kernel { name, .. } => name,
- ast::MethodDecl::Func(_, name, _) => name,
+ match self.name {
+ ast::MethodName::Kernel(name) => name,
+ ast::MethodName::Func(name) => name,
}
}
}
-struct SpirvMethodDecl<'input> {
- input: Vec<ast::Variable<ast::Type, spirv::Word>>,
- output: Vec<ast::Variable<ast::Type, spirv::Word>>,
- name: MethodName<'input>,
- uses_shared_mem: bool,
+impl<'a> ast::MethodDeclaration<'a, spirv::Word> {
+ fn effective_input_arguments(&self) -> impl Iterator<Item = (spirv::Word, SpirvType)> + '_ {
+ let is_kernel = self.name.is_kernel();
+ self.input_arguments
+ .iter()
+ .map(move |arg| {
+ if !is_kernel && arg.state_space != ast::StateSpace::Reg {
+ let spirv_type =
+ SpirvType::pointer_to(arg.v_type.clone(), arg.state_space.to_spirv());
+ (arg.name, spirv_type)
+ } else {
+ (arg.name, SpirvType::new(arg.v_type.clone()))
+ }
+ })
+ .chain(self.shared_mem.iter().map(|id| {
+ (
+ *id,
+ SpirvType::Pointer(
+ Box::new(SpirvType::Base(SpirvScalarKey::B8)),
+ spirv::StorageClass::Workgroup,
+ ),
+ )
+ }))
+ }
}
-impl<'input> SpirvMethodDecl<'input> {
- fn new(ast_decl: &ast::MethodDecl<'input, spirv::Word>) -> Self {
- let (input, output) = match ast_decl {
- ast::MethodDecl::Kernel { in_args, .. } => {
- let spirv_input = in_args
- .iter()
- .map(|var| {
- let v_type = match &var.v_type {
- ast::KernelArgumentType::Normal(t) => {
- ast::FnArgumentType::Param(t.clone())
- }
- ast::KernelArgumentType::Shared => ast::FnArgumentType::Shared,
- };
- ast::Variable {
- name: var.name,
- align: var.align,
- v_type: v_type.to_kernel_type(),
- array_init: var.array_init.clone(),
- }
- })
- .collect();
- (spirv_input, Vec::new())
- }
- ast::MethodDecl::Func(out_args, _, in_args) => {
- let (param_output, non_param_output): (Vec<_>, Vec<_>) =
- out_args.iter().partition(|var| var.v_type.is_param());
- let spirv_output = non_param_output
- .into_iter()
- .cloned()
- .map(|var| ast::Variable {
- name: var.name,
- align: var.align,
- v_type: var.v_type.to_func_type(),
- array_init: var.array_init.clone(),
- })
- .collect();
- let spirv_input = param_output
- .into_iter()
- .cloned()
- .chain(in_args.iter().cloned())
- .map(|var| ast::Variable {
- name: var.name,
- align: var.align,
- v_type: var.v_type.to_func_type(),
- array_init: var.array_init.clone(),
- })
- .collect();
- (spirv_input, spirv_output)
- }
- };
- SpirvMethodDecl {
- input,
- output,
- name: MethodName::new(ast_decl),
- uses_shared_mem: false,
+impl<'input, ID> ast::MethodName<'input, ID> {
+ fn is_kernel(&self) -> bool {
+ match self {
+ ast::MethodName::Kernel(..) => true,
+ ast::MethodName::Func(..) => false,
}
}
}
|