aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorAndrzej Janik <[email protected]>2021-05-05 22:56:58 +0200
committerAndrzej Janik <[email protected]>2021-05-05 22:56:58 +0200
commit9d92a6e284dce00b0b785a50f623d3715f8aeac4 (patch)
tree497a61b19168f2c5741dff6a348e4107605167df
parentd51aaaf5529dbfec0735c73768e468728112c26b (diff)
downloadZLUDA-9d92a6e284dce00b0b785a50f623d3715f8aeac4.tar.gz
ZLUDA-9d92a6e284dce00b0b785a50f623d3715f8aeac4.zip
Start converting the translation to one type type
-rw-r--r--ptx/src/ast.rs85
-rw-r--r--ptx/src/ptx.lalrpop74
-rw-r--r--ptx/src/translate.rs1187
3 files changed, 665 insertions, 681 deletions
diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs
index c7b9563..364ec01 100644
--- a/ptx/src/ast.rs
+++ b/ptx/src/ast.rs
@@ -1,6 +1,6 @@
use half::f16;
use lalrpop_util::{lexer::Token, ParseError};
-use std::{convert::From, convert::TryFrom, mem, num::ParseFloatError, str::FromStr};
+use std::{convert::From, mem, num::ParseFloatError, str::FromStr};
use std::{marker::PhantomData, num::ParseIntError};
#[derive(Debug, thiserror::Error)]
@@ -110,35 +110,7 @@ pub enum Type {
Scalar(ScalarType),
Vector(ScalarType, u8),
Array(ScalarType, Vec<u32>),
- Pointer(PointerType, LdStateSpace),
-}
-
-#[derive(PartialEq, Eq, Clone)]
-pub enum PointerType {
- Scalar(ScalarType),
- Vector(ScalarType, u8),
- Array(ScalarType, Vec<u32>),
- // Instances of this variant are generated during stateful conversion
- Pointer(ScalarType, LdStateSpace),
-}
-
-impl From<ScalarType> for PointerType {
- fn from(t: ScalarType) -> Self {
- PointerType::Scalar(t.into())
- }
-}
-
-impl TryFrom<PointerType> for ScalarType {
- type Error = ();
-
- fn try_from(value: PointerType) -> Result<Self, Self::Error> {
- match value {
- PointerType::Scalar(t) => Ok(t),
- PointerType::Vector(_, _) => Err(()),
- PointerType::Array(_, _) => Err(()),
- PointerType::Pointer(_, _) => Err(()),
- }
- }
+ Pointer(ScalarType),
}
#[derive(PartialEq, Eq, Hash, Clone, Copy)]
@@ -222,6 +194,7 @@ pub enum StateSpace {
Shared,
Param,
Generic,
+ Sreg,
}
pub struct PredAt<ID> {
@@ -397,9 +370,9 @@ pub enum VectorPrefix {
pub struct LdDetails {
pub qualifier: LdStQualifier,
- pub state_space: LdStateSpace,
+ pub state_space: StateSpace,
pub caching: LdCacheOperator,
- pub typ: PointerType,
+ pub typ: Type,
pub non_coherent: bool,
}
@@ -418,17 +391,6 @@ pub enum MemScope {
Sys,
}
-#[derive(Copy, Clone, PartialEq, Eq, Debug)]
-#[repr(u8)]
-pub enum LdStateSpace {
- Generic,
- Const,
- Global,
- Local,
- Param,
- Shared,
-}
-
#[derive(Copy, Clone, PartialEq, Eq)]
pub enum LdCacheOperator {
Cached,
@@ -612,20 +574,11 @@ impl CvtDetails {
}
pub struct CvtaDetails {
- pub to: CvtaStateSpace,
- pub from: CvtaStateSpace,
+ pub to: StateSpace,
+ pub from: StateSpace,
pub size: CvtaSize,
}
-#[derive(Copy, Clone, PartialEq, Eq)]
-pub enum CvtaStateSpace {
- Generic,
- Const,
- Global,
- Local,
- Shared,
-}
-
pub enum CvtaSize {
U32,
U64,
@@ -633,18 +586,9 @@ pub enum CvtaSize {
pub struct StData {
pub qualifier: LdStQualifier,
- pub state_space: StStateSpace,
+ pub state_space: StateSpace,
pub caching: StCacheOperator,
- pub typ: PointerType,
-}
-
-#[derive(PartialEq, Eq, Copy, Clone)]
-pub enum StStateSpace {
- Generic,
- Global,
- Local,
- Param,
- Shared,
+ pub typ: Type,
}
#[derive(PartialEq, Eq)]
@@ -717,7 +661,7 @@ pub struct MinMaxFloat {
pub struct AtomDetails {
pub semantics: AtomSemantics,
pub scope: MemScope,
- pub space: AtomSpace,
+ pub space: StateSpace,
pub inner: AtomInnerDetails,
}
@@ -730,13 +674,6 @@ pub enum AtomSemantics {
}
#[derive(Copy, Clone)]
-pub enum AtomSpace {
- Generic,
- Global,
- Shared,
-}
-
-#[derive(Copy, Clone)]
pub enum AtomInnerDetails {
Bit { op: AtomBitOp, typ: ScalarType },
Unsigned { op: AtomUIntOp, typ: ScalarType },
@@ -777,7 +714,7 @@ pub enum AtomFloatOp {
pub struct AtomCasDetails {
pub semantics: AtomSemantics,
pub scope: MemScope,
- pub space: AtomSpace,
+ pub space: StateSpace,
pub typ: ScalarType,
}
diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop
index dc439b7..8fee7c2 100644
--- a/ptx/src/ptx.lalrpop
+++ b/ptx/src/ptx.lalrpop
@@ -619,9 +619,9 @@ ModuleVariable: ast::Variable<&'input str> = {
return Err(ParseError::User { error: ast::PtxError::NonExternPointer });
}
if space == ".global" {
- (ast::Type::Pointer(ast::PointerType::Scalar(t), ast::LdStateSpace::Global), ast::StateSpace::Global, Vec::new())
+ (ast::Type::Pointer(t), ast::StateSpace::Global, Vec::new())
} else {
- (ast::Type::Pointer(ast::PointerType::Scalar(t), ast::LdStateSpace::Shared), ast::StateSpace::Shared, Vec::new())
+ (ast::Type::Pointer(t), ast::StateSpace::Shared, Vec::new())
}
}
};
@@ -643,7 +643,7 @@ ParamVariable: (Option<u32>, Vec<u8>, ast::Type, &'input str) = {
(ast::Type::Array(t, dimensions), init)
}
ast::ArrayOrPointer::Pointer => {
- (ast::Type::Pointer(ast::PointerType::Scalar(t), ast::LdStateSpace::Param), Vec::new())
+ (ast::Type::Pointer(t), Vec::new())
}
};
(align, array_init, v_type, name)
@@ -763,7 +763,7 @@ InstLd: ast::Instruction<ast::ParsedArgParams<'input>> = {
ast::Instruction::Ld(
ast::LdDetails {
qualifier: q.unwrap_or(ast::LdStQualifier::Weak),
- state_space: ss.unwrap_or(ast::LdStateSpace::Generic),
+ state_space: ss.unwrap_or(ast::StateSpace::Generic),
caching: cop.unwrap_or(ast::LdCacheOperator::Cached),
typ: t,
non_coherent: false
@@ -775,7 +775,7 @@ InstLd: ast::Instruction<ast::ParsedArgParams<'input>> = {
ast::Instruction::Ld(
ast::LdDetails {
qualifier: q.unwrap_or(ast::LdStQualifier::Weak),
- state_space: ast::LdStateSpace::Global,
+ state_space: ast::StateSpace::Global,
caching: cop.unwrap_or(ast::LdCacheOperator::Cached),
typ: t,
non_coherent: false
@@ -787,7 +787,7 @@ InstLd: ast::Instruction<ast::ParsedArgParams<'input>> = {
ast::Instruction::Ld(
ast::LdDetails {
qualifier: ast::LdStQualifier::Weak,
- state_space: ast::LdStateSpace::Global,
+ state_space: ast::StateSpace::Global,
caching: cop.unwrap_or(ast::LdCacheOperator::Cached),
typ: t,
non_coherent: true
@@ -797,9 +797,9 @@ InstLd: ast::Instruction<ast::ParsedArgParams<'input>> = {
}
};
-LdStType: ast::PointerType = {
- <v:VectorPrefix> <t:LdStScalarType> => ast::PointerType::Vector(t, v),
- <t:LdStScalarType> => ast::PointerType::Scalar(t),
+LdStType: ast::Type = {
+ <v:VectorPrefix> <t:LdStScalarType> => ast::Type::Vector(t, v),
+ <t:LdStScalarType> => ast::Type::Scalar(t),
}
LdStQualifier: ast::LdStQualifier = {
@@ -815,11 +815,11 @@ MemScope: ast::MemScope = {
".sys" => ast::MemScope::Sys
};
-LdNonGlobalStateSpace: ast::LdStateSpace = {
- ".const" => ast::LdStateSpace::Const,
- ".local" => ast::LdStateSpace::Local,
- ".param" => ast::LdStateSpace::Param,
- ".shared" => ast::LdStateSpace::Shared,
+LdNonGlobalStateSpace: ast::StateSpace = {
+ ".const" => ast::StateSpace::Const,
+ ".local" => ast::StateSpace::Local,
+ ".param" => ast::StateSpace::Param,
+ ".shared" => ast::StateSpace::Shared,
};
LdCacheOperator: ast::LdCacheOperator = {
@@ -1235,7 +1235,7 @@ InstSt: ast::Instruction<ast::ParsedArgParams<'input>> = {
ast::Instruction::St(
ast::StData {
qualifier: q.unwrap_or(ast::LdStQualifier::Weak),
- state_space: ss.unwrap_or(ast::StStateSpace::Generic),
+ state_space: ss.unwrap_or(ast::StateSpace::Generic),
caching: cop.unwrap_or(ast::StCacheOperator::Writeback),
typ: t
},
@@ -1249,11 +1249,11 @@ MemoryOperand: ast::Operand<&'input str> = {
"[" <o:Operand> "]" => o
}
-StStateSpace: ast::StStateSpace = {
- ".global" => ast::StStateSpace::Global,
- ".local" => ast::StStateSpace::Local,
- ".param" => ast::StStateSpace::Param,
- ".shared" => ast::StStateSpace::Shared,
+StStateSpace: ast::StateSpace = {
+ ".global" => ast::StateSpace::Global,
+ ".local" => ast::StateSpace::Local,
+ ".param" => ast::StateSpace::Param,
+ ".shared" => ast::StateSpace::Shared,
};
StCacheOperator: ast::StCacheOperator = {
@@ -1272,7 +1272,7 @@ InstRet: ast::Instruction<ast::ParsedArgParams<'input>> = {
InstCvta: ast::Instruction<ast::ParsedArgParams<'input>> = {
"cvta" <from:CvtaStateSpace> <s:CvtaSize> <a:Arg2> => {
ast::Instruction::Cvta(ast::CvtaDetails {
- to: ast::CvtaStateSpace::Generic,
+ to: ast::StateSpace::Generic,
from,
size: s
},
@@ -1281,18 +1281,18 @@ InstCvta: ast::Instruction<ast::ParsedArgParams<'input>> = {
"cvta" ".to" <to:CvtaStateSpace> <s:CvtaSize> <a:Arg2> => {
ast::Instruction::Cvta(ast::CvtaDetails {
to,
- from: ast::CvtaStateSpace::Generic,
+ from: ast::StateSpace::Generic,
size: s
},
a)
}
}
-CvtaStateSpace: ast::CvtaStateSpace = {
- ".const" => ast::CvtaStateSpace::Const,
- ".global" => ast::CvtaStateSpace::Global,
- ".local" => ast::CvtaStateSpace::Local,
- ".shared" => ast::CvtaStateSpace::Shared,
+CvtaStateSpace: ast::StateSpace = {
+ ".const" => ast::StateSpace::Const,
+ ".global" => ast::StateSpace::Global,
+ ".local" => ast::StateSpace::Local,
+ ".shared" => ast::StateSpace::Shared,
}
CvtaSize: ast::CvtaSize = {
@@ -1450,7 +1450,7 @@ InstAtom: ast::Instruction<ast::ParsedArgParams<'input>> = {
let details = ast::AtomDetails {
semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed),
scope: scope.unwrap_or(ast::MemScope::Gpu),
- space: space.unwrap_or(ast::AtomSpace::Generic),
+ space: space.unwrap_or(ast::StateSpace::Generic),
inner: ast::AtomInnerDetails::Bit { op, typ }
};
ast::Instruction::Atom(details,a)
@@ -1459,7 +1459,7 @@ InstAtom: ast::Instruction<ast::ParsedArgParams<'input>> = {
let details = ast::AtomDetails {
semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed),
scope: scope.unwrap_or(ast::MemScope::Gpu),
- space: space.unwrap_or(ast::AtomSpace::Generic),
+ space: space.unwrap_or(ast::StateSpace::Generic),
inner: ast::AtomInnerDetails::Unsigned {
op: ast::AtomUIntOp::Inc,
typ: ast::ScalarType::U32
@@ -1471,7 +1471,7 @@ InstAtom: ast::Instruction<ast::ParsedArgParams<'input>> = {
let details = ast::AtomDetails {
semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed),
scope: scope.unwrap_or(ast::MemScope::Gpu),
- space: space.unwrap_or(ast::AtomSpace::Generic),
+ space: space.unwrap_or(ast::StateSpace::Generic),
inner: ast::AtomInnerDetails::Unsigned {
op: ast::AtomUIntOp::Dec,
typ: ast::ScalarType::U32
@@ -1484,7 +1484,7 @@ InstAtom: ast::Instruction<ast::ParsedArgParams<'input>> = {
let details = ast::AtomDetails {
semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed),
scope: scope.unwrap_or(ast::MemScope::Gpu),
- space: space.unwrap_or(ast::AtomSpace::Generic),
+ space: space.unwrap_or(ast::StateSpace::Generic),
inner: ast::AtomInnerDetails::Float { op, typ }
};
ast::Instruction::Atom(details,a)
@@ -1493,7 +1493,7 @@ InstAtom: ast::Instruction<ast::ParsedArgParams<'input>> = {
let details = ast::AtomDetails {
semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed),
scope: scope.unwrap_or(ast::MemScope::Gpu),
- space: space.unwrap_or(ast::AtomSpace::Generic),
+ space: space.unwrap_or(ast::StateSpace::Generic),
inner: ast::AtomInnerDetails::Unsigned { op, typ }
};
ast::Instruction::Atom(details,a)
@@ -1502,7 +1502,7 @@ InstAtom: ast::Instruction<ast::ParsedArgParams<'input>> = {
let details = ast::AtomDetails {
semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed),
scope: scope.unwrap_or(ast::MemScope::Gpu),
- space: space.unwrap_or(ast::AtomSpace::Generic),
+ space: space.unwrap_or(ast::StateSpace::Generic),
inner: ast::AtomInnerDetails::Signed { op, typ }
};
ast::Instruction::Atom(details,a)
@@ -1514,7 +1514,7 @@ InstAtomCas: ast::Instruction<ast::ParsedArgParams<'input>> = {
let details = ast::AtomCasDetails {
semantics: sema.unwrap_or(ast::AtomSemantics::Relaxed),
scope: scope.unwrap_or(ast::MemScope::Gpu),
- space: space.unwrap_or(ast::AtomSpace::Generic),
+ space: space.unwrap_or(ast::StateSpace::Generic),
typ,
};
ast::Instruction::AtomCas(details,a)
@@ -1528,9 +1528,9 @@ AtomSemantics: ast::AtomSemantics = {
".acq_rel" => ast::AtomSemantics::AcquireRelease
}
-AtomSpace: ast::AtomSpace = {
- ".global" => ast::AtomSpace::Global,
- ".shared" => ast::AtomSpace::Shared
+AtomSpace: ast::StateSpace = {
+ ".global" => ast::StateSpace::Global,
+ ".shared" => ast::StateSpace::Shared
}
AtomBitOp: ast::AtomBitOp = {
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index 4ba5729..a743496 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -37,6 +37,12 @@ fn error_unreachable() -> TranslateError {
TranslateError::Unreachable
}
+macro_rules! new_todo {
+ () => {
+ todo!()
+ };
+}
+
#[derive(PartialEq, Eq, Hash, Clone)]
enum SpirvType {
Base(SpirvScalarKey),
@@ -48,52 +54,40 @@ enum SpirvType {
}
impl SpirvType {
- fn new_pointer(t: ast::Type, sc: spirv::StorageClass) -> Self {
- let key = t.into();
- SpirvType::Pointer(Box::new(key), sc)
- }
-}
-
-impl From<ast::Type> for SpirvType {
- fn from(t: ast::Type) -> Self {
+ fn new(t: ast::Type, decl_space: ast::StateSpace) -> Self {
match t {
ast::Type::Scalar(t) => SpirvType::Base(t.into()),
ast::Type::Vector(typ, len) => SpirvType::Vector(typ.into(), len),
ast::Type::Array(t, len) => SpirvType::Array(t.into(), len),
- ast::Type::Pointer(pointer_t, state_space) => SpirvType::Pointer(
- Box::new(SpirvType::from(ast::Type::from(pointer_t))),
- state_space.to_spirv(),
- ),
+ ast::Type::Pointer(pointer_t) => {
+ let spirv_space = match decl_space {
+ ast::StateSpace::Reg | ast::StateSpace::Param | ast::StateSpace::Local => {
+ spirv::StorageClass::Private
+ }
+ ast::StateSpace::Global => spirv::StorageClass::CrossWorkgroup,
+ ast::StateSpace::Const => spirv::StorageClass::UniformConstant,
+ ast::StateSpace::Shared => spirv::StorageClass::Workgroup,
+ ast::StateSpace::Generic => spirv::StorageClass::Generic,
+ ast::StateSpace::Sreg => spirv::StorageClass::Input,
+ };
+ SpirvType::Pointer(Box::new(SpirvType::Base(pointer_t.into())), spirv_space)
+ }
}
}
-}
-impl From<ast::PointerType> for ast::Type {
- fn from(t: ast::PointerType) -> Self {
- match t {
- ast::PointerType::Scalar(t) => ast::Type::Scalar(t),
- ast::PointerType::Vector(t, len) => ast::Type::Vector(t, len),
- ast::PointerType::Array(t, dims) => ast::Type::Array(t, dims),
- ast::PointerType::Pointer(t, space) => {
- ast::Type::Pointer(ast::PointerType::Scalar(t), space)
- }
- }
+ fn pointer_to(
+ t: ast::Type,
+ inner_space: ast::StateSpace,
+ outer_space: spirv::StorageClass,
+ ) -> Self {
+ let key = Self::new(t, inner_space);
+ SpirvType::Pointer(Box::new(key), outer_space)
}
}
impl ast::Type {
- fn param_pointer_to(self, space: ast::LdStateSpace) -> Result<Self, TranslateError> {
- Ok(match self {
- ast::Type::Scalar(t) => ast::Type::Pointer(ast::PointerType::Scalar(t), space),
- ast::Type::Vector(t, len) => {
- ast::Type::Pointer(ast::PointerType::Vector(t, len), space)
- }
- ast::Type::Array(t, _) => ast::Type::Pointer(ast::PointerType::Scalar(t), space),
- ast::Type::Pointer(ast::PointerType::Scalar(t), space) => {
- ast::Type::Pointer(ast::PointerType::Pointer(t, space), space)
- }
- ast::Type::Pointer(_, _) => return Err(error_unreachable()),
- })
+ fn param_pointer_to(self, space: ast::StateSpace) -> Result<Self, TranslateError> {
+ Ok(self)
}
}
@@ -398,18 +392,7 @@ impl TypeWordMap {
b.constant_composite(result_type, None, components.into_iter())
}
},
- ast::Type::Pointer(typ, state_space) => {
- let base_t = typ.clone().into();
- let base = self.get_or_add_constant(b, &base_t, &[])?;
- let result_type = self.get_or_add(
- b,
- SpirvType::Pointer(
- Box::new(SpirvType::from(base_t)),
- (*state_space).to_spirv(),
- ),
- );
- b.variable(result_type, None, (*state_space).to_spirv(), Some(base))
- }
+ ast::Type::Pointer(typ) => return Err(error_unreachable()),
})
}
@@ -702,11 +685,29 @@ fn multi_hash_map_append<K: Eq + std::hash::Hash, V>(m: &mut MultiHashMap<K, V>,
}
}
-// PTX represents dynamically allocated shared local memory as
-// .extern .shared .align 4 .b8 shared_mem[];
-// In SPIRV/OpenCL world this is expressed as an additional argument
-// This pass looks for all uses of .extern .shared and converts them to
-// an additional method argument
+/*
+ PTX represents dynamically allocated shared local memory as
+ .extern .shared .b32 shared_mem[];
+ In SPIRV/OpenCL world this is expressed as an additional argument
+ This pass looks for all uses of .extern .shared and converts them to
+ an additional method argument
+ The question is how this artificial argument should be expressed. There are
+ several options:
+ * Straight conversion:
+ .shared .b32 shared_mem[]
+ * Introduce .param_shared statespace:
+ .param_shared .b32 shared_mem
+ or
+ .param_shared .b32 shared_mem[]
+ * Introduce .shared_ptr <SCALAR> type:
+ .param .shared_ptr .b32 shared_mem
+ * Reuse .ptr hint:
+ .param .u64 .ptr shared_mem
+ This is the most tempting, but also the most nonsensical, .ptr is just a
+ hint, which has no semantical meaning (and the output of our
+ transformation has a semantical meaning - we emit additional
+ "OpFunctionParameter ..." with type "OpTypePointer Workgroup ...")
+*/
fn convert_dynamic_shared_memory_usage<'input>(
module: Vec<Directive<'input>>,
new_id: &mut impl FnMut() -> spirv::Word,
@@ -715,7 +716,7 @@ fn convert_dynamic_shared_memory_usage<'input>(
for dir in module.iter() {
match dir {
Directive::Variable(ast::Variable {
- v_type: ast::Type::Pointer(p_type, ast::LdStateSpace::Shared),
+ v_type: ast::Type::Pointer(p_type),
state_space: ast::StateSpace::Shared,
name,
..
@@ -799,48 +800,23 @@ fn convert_dynamic_shared_memory_usage<'input>(
ast::Variable {
name: shared_id_param,
align: None,
- v_type: ast::Type::Pointer(
- ast::PointerType::Scalar(ast::ScalarType::B8),
- ast::LdStateSpace::Shared,
- ),
- state_space: ast::StateSpace::Param,
+ v_type: ast::Type::Pointer(ast::ScalarType::B8),
+ state_space: ast::StateSpace::Shared,
array_init: Vec::new(),
}
});
spirv_decl.uses_shared_mem = true;
- let shared_var_id = new_id();
- let shared_var = ExpandedStatement::Variable(ast::Variable {
- name: shared_var_id,
- align: None,
- v_type: ast::Type::Pointer(
- ast::PointerType::Scalar(ast::ScalarType::B8),
- ast::LdStateSpace::Shared,
- ),
- state_space: ast::StateSpace::Reg,
- array_init: Vec::new(),
- });
- let shared_var_st = ExpandedStatement::StoreVar(StoreVarDetails {
- arg: ast::Arg2St {
- src1: shared_var_id,
- src2: shared_id_param,
- },
- typ: ast::Type::Scalar(ast::ScalarType::B8),
- member_index: None,
- });
- let mut new_statements = vec![shared_var, shared_var_st];
- replace_uses_of_shared_memory(
- &mut new_statements,
+ let statements = replace_uses_of_shared_memory(
new_id,
&extern_shared_decls,
&mut methods_using_extern_shared,
shared_id_param,
- shared_var_id,
statements,
);
Directive::Method(Function {
func_decl,
globals,
- body: Some(new_statements),
+ body: Some(statements),
import_as,
spirv_decl,
tuning,
@@ -852,14 +828,13 @@ fn convert_dynamic_shared_memory_usage<'input>(
}
fn replace_uses_of_shared_memory<'a>(
- result: &mut Vec<ExpandedStatement>,
new_id: &mut impl FnMut() -> spirv::Word,
- extern_shared_decls: &HashMap<spirv::Word, ast::PointerType>,
+ extern_shared_decls: &HashMap<spirv::Word, ast::ScalarType>,
methods_using_extern_shared: &mut HashSet<MethodName<'a>>,
shared_id_param: spirv::Word,
- shared_var_id: spirv::Word,
statements: Vec<ExpandedStatement>,
-) {
+) -> Vec<ExpandedStatement> {
+ let mut result = Vec::with_capacity(statements.len());
for statement in statements {
match statement {
Statement::Call(mut call) => {
@@ -877,22 +852,18 @@ fn replace_uses_of_shared_memory<'a>(
}
statement => {
let new_statement = statement.map_id(&mut |id, _| {
- if let Some(ast::PointerType::Scalar(typ)) = extern_shared_decls.get(&id) {
- if *typ == ast::ScalarType::B8 {
- return shared_var_id;
+ if let Some(scalar_type) = extern_shared_decls.get(&id) {
+ if *scalar_type == ast::ScalarType::B8 {
+ return shared_id_param;
}
let replacement_id = new_id();
result.push(Statement::Conversion(ImplicitConversion {
- src: shared_var_id,
+ src: shared_id_param,
dst: replacement_id,
- from: ast::Type::Pointer(
- ast::PointerType::Scalar(ast::ScalarType::B8),
- ast::LdStateSpace::Shared,
- ),
- to: ast::Type::Pointer(
- ast::PointerType::Scalar((*typ).into()),
- ast::LdStateSpace::Shared,
- ),
+ from_type: ast::Type::Pointer(ast::ScalarType::B8),
+ from_space: ast::StateSpace::Shared,
+ to_type: ast::Type::Pointer((*scalar_type).into()),
+ to_space: ast::StateSpace::Shared,
kind: ConversionKind::PtrToPtr { spirv_ptr: true },
src_sema: ArgumentSemantics::Default,
dst_sema: ArgumentSemantics::Default,
@@ -906,6 +877,7 @@ fn replace_uses_of_shared_memory<'a>(
}
}
}
+ result
}
fn get_callers_of_extern_shared<'a>(
@@ -1055,8 +1027,9 @@ fn emit_builtins(
for (reg, id) in id_defs.special_registers.builtins() {
let result_type = map.get_or_add(
builder,
- SpirvType::Pointer(
- Box::new(SpirvType::from(reg.get_type())),
+ SpirvType::pointer_to(
+ reg.get_type(),
+ ast::StateSpace::Reg,
spirv::StorageClass::Input,
),
);
@@ -1158,7 +1131,10 @@ fn emit_function_header<'a>(
}
*/
for input in &func_decl.input {
- let result_type = map.get_or_add(builder, SpirvType::from(input.v_type.clone()));
+ let result_type = map.get_or_add(
+ builder,
+ SpirvType::new(input.v_type.clone(), input.state_space),
+ );
builder.function_parameter(Some(input.name), result_type)?;
}
Ok(fn_id)
@@ -1219,26 +1195,26 @@ fn translate_variable<'a>(
is_variable = true;
var_type
}
- ast::StateSpace::Const => var_type.param_pointer_to(ast::LdStateSpace::Const)?,
- ast::StateSpace::Global => var_type.param_pointer_to(ast::LdStateSpace::Global)?,
- ast::StateSpace::Local => var_type.param_pointer_to(ast::LdStateSpace::Local)?,
+ ast::StateSpace::Const => var_type.param_pointer_to(ast::StateSpace::Const)?,
+ ast::StateSpace::Global => var_type.param_pointer_to(ast::StateSpace::Global)?,
+ ast::StateSpace::Local => var_type.param_pointer_to(ast::StateSpace::Local)?,
ast::StateSpace::Shared => {
// If it's a pointer it will be translated to a method parameter later
if let ast::Type::Pointer(..) = var_type {
is_variable = true;
var_type
} else {
- var_type.param_pointer_to(ast::LdStateSpace::Shared)?
+ var_type.param_pointer_to(ast::StateSpace::Shared)?
}
}
- ast::StateSpace::Param => var_type.param_pointer_to(ast::LdStateSpace::Param)?,
- ast::StateSpace::Generic => todo!(),
+ ast::StateSpace::Param => var_type.param_pointer_to(ast::StateSpace::Param)?,
+ ast::StateSpace::Generic | ast::StateSpace::Sreg => return Err(error_unreachable()),
};
Ok(ast::Variable {
align: var.align,
v_type: var.v_type,
state_space: var.state_space,
- name: id_defs.get_or_add_def_typed(var.name, var_type, is_variable),
+ name: id_defs.get_or_add_def_typed(var.name, var_type, var.state_space, is_variable),
array_init: var.array_init,
})
}
@@ -1283,7 +1259,10 @@ fn expand_kernel_params<'a, 'b>(
Ok(ast::KernelArgument {
name: fn_resolver.add_def(
a.name,
- Some(ast::Type::from(a.v_type.clone()).param_pointer_to(ast::LdStateSpace::Param)?),
+ Some((
+ ast::Type::from(a.v_type.clone()).param_pointer_to(ast::StateSpace::Param)?,
+ a.state_space,
+ )),
false,
),
v_type: a.v_type.clone(),
@@ -1302,7 +1281,7 @@ fn expand_fn_params<'a, 'b>(
args.map(|a| {
let is_variable = a.state_space == ast::StateSpace::Reg;
Ok(ast::FnArgument {
- name: fn_resolver.add_def(a.name, Some(a.v_type.clone()), is_variable),
+ name: fn_resolver.add_def(a.name, Some((a.v_type.clone(), a.state_space)), is_variable),
v_type: a.v_type.clone(),
state_space: a.state_space,
align: a.align,
@@ -1339,15 +1318,15 @@ fn to_ssa<'input, 'b>(
let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs)?;
let typed_statements =
convert_to_typed_statements(unadorned_statements, &fn_defs, &mut numeric_id_defs)?;
- let typed_statements =
- convert_to_stateful_memory_access(&mut spirv_decl, typed_statements, &mut numeric_id_defs)?;
+ //let typed_statements =
+ // convert_to_stateful_memory_access(&mut spirv_decl, typed_statements, &mut numeric_id_defs)?;
let ssa_statements = insert_mem_ssa_statements(
typed_statements,
&mut numeric_id_defs,
&f_args,
&mut spirv_decl,
)?;
- let ssa_statements = fix_builtins(ssa_statements, &mut numeric_id_defs)?;
+ let ssa_statements = fix_special_registers(ssa_statements, &mut numeric_id_defs)?;
let mut numeric_id_defs = numeric_id_defs.finish();
let expanded_statements = expand_arguments(ssa_statements, &mut numeric_id_defs)?;
let expanded_statements =
@@ -1366,7 +1345,7 @@ fn to_ssa<'input, 'b>(
})
}
-fn fix_builtins(
+fn fix_special_registers(
typed_statements: Vec<TypedStatement>,
numeric_id_defs: &mut NumericIdResolver,
) -> Result<Vec<TypedStatement>, TranslateError> {
@@ -1402,7 +1381,8 @@ fn fix_builtins(
continue;
}
};
- let temp_id = numeric_id_defs.new_non_variable(Some(details.typ.clone()));
+ let temp_id = numeric_id_defs
+ .register_intermediate(Some((details.typ.clone(), details.state_space)));
let real_dst = details.arg.dst;
details.arg.dst = temp_id;
result.push(Statement::LoadVar(LoadVarDetails {
@@ -1410,14 +1390,17 @@ fn fix_builtins(
src: sreg_src,
dst: temp_id,
},
+ state_space: ast::StateSpace::Sreg,
typ: ast::Type::Scalar(scalar_typ),
member_index: Some((index, Some(vector_width))),
}));
result.push(Statement::Conversion(ImplicitConversion {
src: temp_id,
dst: real_dst,
- from: ast::Type::Scalar(scalar_typ),
- to: ast::Type::Scalar(ast::ScalarType::U32),
+ from_type: ast::Type::Scalar(scalar_typ),
+ from_space: ast::StateSpace::Sreg,
+ to_type: ast::Type::Scalar(ast::ScalarType::U32),
+ to_space: ast::StateSpace::Sreg,
kind: ConversionKind::Default,
src_sema: ArgumentSemantics::Default,
dst_sema: ArgumentSemantics::Default,
@@ -1614,12 +1597,12 @@ fn convert_to_typed_statements(
}
ast::Instruction::Mov(mut d, ast::Arg2Mov { dst, src }) => {
if let Some(src_id) = src.underlying() {
- let (typ, _) = id_defs.get_typed(*src_id)?;
+ let (typ, _, _) = id_defs.get_typed(*src_id)?;
let take_address = match typ {
- ast::Type::Scalar(_) => false,
- ast::Type::Vector(_, _) => false,
- ast::Type::Array(_, _) => true,
- ast::Type::Pointer(_, _) => true,
+ ast::Type::Scalar(..) => false,
+ ast::Type::Vector(..) => false,
+ ast::Type::Array(..) => true,
+ ast::Type::Pointer(..) => true,
};
d.src_is_address = take_address;
}
@@ -1666,6 +1649,7 @@ impl<'a, 'b> VectorRepackVisitor<'a, 'b> {
is_dst: bool,
vector_sema: ArgumentSemantics,
typ: &ast::Type,
+ state_space: ast::StateSpace,
idx: Vec<spirv::Word>,
) -> Result<spirv::Word, TranslateError> {
// mov.u32 foobar, {a,b};
@@ -1673,7 +1657,9 @@ impl<'a, 'b> VectorRepackVisitor<'a, 'b> {
ast::Type::Vector(scalar_t, _) => *scalar_t,
_ => return Err(TranslateError::MismatchedType),
};
- let temp_vec = self.id_def.new_non_variable(Some(typ.clone()));
+ let temp_vec = self
+ .id_def
+ .register_intermediate(Some((typ.clone(), state_space)));
let statement = Statement::RepackVector(RepackVectorDetails {
is_extract: is_dst,
typ: scalar_t,
@@ -1696,7 +1682,7 @@ impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, TypedArgParams>
fn id(
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
- _: Option<&ast::Type>,
+ _: Option<(&ast::Type, ast::StateSpace)>,
) -> Result<spirv::Word, TranslateError> {
Ok(desc.op)
}
@@ -1705,15 +1691,20 @@ impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, TypedArgParams>
&mut self,
desc: ArgumentDescriptor<ast::Operand<spirv::Word>>,
typ: &ast::Type,
+ state_space: ast::StateSpace,
) -> Result<TypedOperand, TranslateError> {
Ok(match desc.op {
ast::Operand::Reg(reg) => TypedOperand::Reg(reg),
ast::Operand::RegOffset(reg, offset) => TypedOperand::RegOffset(reg, offset),
ast::Operand::Imm(x) => TypedOperand::Imm(x),
ast::Operand::VecMember(vec, idx) => TypedOperand::VecMember(vec, idx),
- ast::Operand::VecPack(vec) => {
- TypedOperand::Reg(self.convert_vector(desc.is_dst, desc.sema, typ, vec)?)
- }
+ ast::Operand::VecPack(vec) => TypedOperand::Reg(self.convert_vector(
+ desc.is_dst,
+ desc.sema,
+ typ,
+ state_space,
+ vec,
+ )?),
})
}
}
@@ -1735,37 +1726,33 @@ fn to_ptx_impl_atomic_call(
semantics, scope, space, op
);
// TODO: extract to a function
- let ptr_space = match details.space {
- ast::AtomSpace::Generic => ast::LdStateSpace::Generic,
- ast::AtomSpace::Global => ast::LdStateSpace::Global,
- ast::AtomSpace::Shared => ast::LdStateSpace::Shared,
- };
+ let ptr_space = details.space;
let scalar_typ = ast::ScalarType::from(typ);
let fn_id = match ptx_impl_imports.entry(fn_name) {
hash_map::Entry::Vacant(entry) => {
- let fn_id = id_defs.new_non_variable(None);
+ let fn_id = id_defs.register_intermediate(None);
let func_decl = ast::MethodDecl::Func::<spirv::Word>(
vec![ast::FnArgument {
align: None,
v_type: ast::Type::Scalar(scalar_typ),
state_space: ast::StateSpace::Reg,
- name: id_defs.new_non_variable(None),
+ name: id_defs.register_intermediate(None),
array_init: Vec::new(),
}],
fn_id,
vec![
ast::FnArgument {
align: None,
- v_type: ast::Type::Pointer(ast::PointerType::Scalar(typ), ptr_space),
- state_space: ast::StateSpace::Reg,
- name: id_defs.new_non_variable(None),
+ v_type: ast::Type::Pointer(typ),
+ state_space: ptr_space,
+ name: id_defs.register_intermediate(None),
array_init: Vec::new(),
},
ast::FnArgument {
align: None,
v_type: ast::Type::Scalar(scalar_typ),
state_space: ast::StateSpace::Reg,
- name: id_defs.new_non_variable(None),
+ name: id_defs.register_intermediate(None),
array_init: Vec::new(),
},
],
@@ -1795,11 +1782,7 @@ fn to_ptx_impl_atomic_call(
func: fn_id,
ret_params: vec![(arg.dst, ast::Type::Scalar(scalar_typ), ast::StateSpace::Reg)],
param_list: vec![
- (
- arg.src1,
- ast::Type::Pointer(ast::PointerType::Scalar(typ), ptr_space),
- ast::StateSpace::Reg,
- ),
+ (arg.src1, ast::Type::Pointer(typ), ptr_space),
(
arg.src2,
ast::Type::Scalar(scalar_typ),
@@ -1826,13 +1809,13 @@ fn to_ptx_impl_bfe_call(
let fn_name = format!("{}{}", prefix, suffix);
let fn_id = match ptx_impl_imports.entry(fn_name) {
hash_map::Entry::Vacant(entry) => {
- let fn_id = id_defs.new_non_variable(None);
+ let fn_id = id_defs.register_intermediate(None);
let func_decl = ast::MethodDecl::Func::<spirv::Word>(
vec![ast::FnArgument {
align: None,
v_type: ast::Type::Scalar(typ.into()),
state_space: ast::StateSpace::Reg,
- name: id_defs.new_non_variable(None),
+ name: id_defs.register_intermediate(None),
array_init: Vec::new(),
}],
fn_id,
@@ -1841,21 +1824,21 @@ fn to_ptx_impl_bfe_call(
align: None,
v_type: ast::Type::Scalar(typ.into()),
state_space: ast::StateSpace::Reg,
- name: id_defs.new_non_variable(None),
+ name: id_defs.register_intermediate(None),
array_init: Vec::new(),
},
ast::FnArgument {
align: None,
v_type: ast::Type::Scalar(ast::ScalarType::U32),
state_space: ast::StateSpace::Reg,
- name: id_defs.new_non_variable(None),
+ name: id_defs.register_intermediate(None),
array_init: Vec::new(),
},
ast::FnArgument {
align: None,
v_type: ast::Type::Scalar(ast::ScalarType::U32),
state_space: ast::StateSpace::Reg,
- name: id_defs.new_non_variable(None),
+ name: id_defs.register_intermediate(None),
array_init: Vec::new(),
},
],
@@ -1919,13 +1902,13 @@ fn to_ptx_impl_bfi_call(
let fn_name = format!("{}{}", prefix, suffix);
let fn_id = match ptx_impl_imports.entry(fn_name) {
hash_map::Entry::Vacant(entry) => {
- let fn_id = id_defs.new_non_variable(None);
+ let fn_id = id_defs.register_intermediate(None);
let func_decl = ast::MethodDecl::Func::<spirv::Word>(
vec![ast::FnArgument {
align: None,
v_type: ast::Type::Scalar(typ.into()),
state_space: ast::StateSpace::Reg,
- name: id_defs.new_non_variable(None),
+ name: id_defs.register_intermediate(None),
array_init: Vec::new(),
}],
fn_id,
@@ -1934,28 +1917,28 @@ fn to_ptx_impl_bfi_call(
align: None,
v_type: ast::Type::Scalar(typ.into()),
state_space: ast::StateSpace::Reg,
- name: id_defs.new_non_variable(None),
+ name: id_defs.register_intermediate(None),
array_init: Vec::new(),
},
ast::FnArgument {
align: None,
v_type: ast::Type::Scalar(typ.into()),
state_space: ast::StateSpace::Reg,
- name: id_defs.new_non_variable(None),
+ name: id_defs.register_intermediate(None),
array_init: Vec::new(),
},
ast::FnArgument {
align: None,
v_type: ast::Type::Scalar(ast::ScalarType::U32),
state_space: ast::StateSpace::Reg,
- name: id_defs.new_non_variable(None),
+ name: id_defs.register_intermediate(None),
array_init: Vec::new(),
},
ast::FnArgument {
align: None,
v_type: ast::Type::Scalar(ast::ScalarType::U32),
state_space: ast::StateSpace::Reg,
- name: id_defs.new_non_variable(None),
+ name: id_defs.register_intermediate(None),
array_init: Vec::new(),
},
],
@@ -2048,7 +2031,7 @@ fn normalize_labels(
| Statement::RepackVector(..) => {}
}
}
- iter::once(Statement::Label(id_def.new_non_variable(None)))
+ iter::once(Statement::Label(id_def.register_intermediate(None)))
.chain(func.into_iter().filter(|s| match s {
Statement::Label(i) => labels_in_use.contains(i),
_ => true,
@@ -2066,8 +2049,8 @@ fn normalize_predicates(
Statement::Label(id) => result.push(Statement::Label(id)),
Statement::Instruction((pred, inst)) => {
if let Some(pred) = pred {
- let if_true = id_def.new_non_variable(None);
- let if_false = id_def.new_non_variable(None);
+ let if_true = id_def.register_intermediate(None);
+ let if_false = id_def.register_intermediate(None);
let folded_bra = match &inst {
ast::Instruction::Bra(_, arg) => Some(arg.src),
_ => None,
@@ -2116,7 +2099,8 @@ fn insert_mem_ssa_statements<'a, 'b>(
}
for spirv_arg in fn_decl.input.iter_mut() {
let typ = spirv_arg.v_type.clone();
- let new_id = id_def.new_non_variable(Some(typ.clone()));
+ let state_space = spirv_arg.state_space;
+ let new_id = id_def.register_intermediate(Some((typ.clone(), state_space)));
result.push(Statement::Variable(ast::Variable {
align: spirv_arg.align,
v_type: spirv_arg.v_type.clone(),
@@ -2129,6 +2113,7 @@ fn insert_mem_ssa_statements<'a, 'b>(
src1: spirv_arg.name,
src2: new_id,
},
+ state_space,
typ,
member_index: None,
}));
@@ -2143,13 +2128,15 @@ fn insert_mem_ssa_statements<'a, 'b>(
ast::Instruction::Ret(d) => {
// TODO: handle multiple output args
if let &[out_param] = &fn_decl.output.as_slice() {
- let (typ, _) = id_def.get_typed(out_param.name)?;
- let new_id = id_def.new_non_variable(Some(typ.clone()));
+ let (typ, space, _) = id_def.get_typed(out_param.name)?;
+ let new_id = id_def.register_intermediate(Some((typ.clone(), space)));
result.push(Statement::LoadVar(LoadVarDetails {
arg: ast::Arg2 {
dst: new_id,
src: out_param.name,
},
+ // TODO: ret with stateful conversion
+ state_space: new_todo!(),
typ: typ.clone(),
member_index: None,
}));
@@ -2161,13 +2148,16 @@ fn insert_mem_ssa_statements<'a, 'b>(
inst => insert_mem_ssa_statement_default(id_def, &mut result, inst)?,
},
Statement::Conditional(mut bra) => {
- let generated_id =
- id_def.new_non_variable(Some(ast::Type::Scalar(ast::ScalarType::Pred)));
+ let generated_id = id_def.register_intermediate(Some((
+ ast::Type::Scalar(ast::ScalarType::Pred),
+ ast::StateSpace::Reg,
+ )));
result.push(Statement::LoadVar(LoadVarDetails {
arg: Arg2 {
dst: generated_id,
src: bra.predicate,
},
+ state_space: ast::StateSpace::Reg,
typ: ast::Type::Scalar(ast::ScalarType::Pred),
member_index: None,
}));
@@ -2204,6 +2194,7 @@ struct VisitArgumentDescriptor<
> {
desc: ArgumentDescriptor<spirv::Word>,
typ: &'a ast::Type,
+ state_space: ast::StateSpace,
stmt_ctor: Ctor,
}
@@ -2218,7 +2209,9 @@ impl<
self,
visitor: &mut impl ArgumentMapVisitor<T, U>,
) -> Result<Statement<ast::Instruction<U>, U>, TranslateError> {
- Ok((self.stmt_ctor)(visitor.id(self.desc, Some(self.typ))?))
+ Ok((self.stmt_ctor)(
+ visitor.id(self.desc, Some((self.typ, self.state_space)))?,
+ ))
}
}
@@ -2232,13 +2225,13 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
fn symbol(
&mut self,
desc: ArgumentDescriptor<(spirv::Word, Option<u8>)>,
- expected_type: Option<&ast::Type>,
+ expected_type: Option<(&ast::Type, ast::StateSpace)>,
) -> Result<spirv::Word, TranslateError> {
let symbol = desc.op.0;
if expected_type.is_none() {
return Ok(symbol);
};
- let (mut var_type, is_variable) = self.id_def.get_typed(symbol)?;
+ let (mut var_type, _, is_variable) = self.id_def.get_typed(symbol)?;
if !is_variable {
return Ok(symbol);
};
@@ -2262,13 +2255,16 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
}
None => None,
};
- let generated_id = self.id_def.new_non_variable(Some(var_type.clone()));
+ let generated_id = self
+ .id_def
+ .register_intermediate(Some((var_type.clone(), ast::StateSpace::Reg)));
if !desc.is_dst {
self.func.push(Statement::LoadVar(LoadVarDetails {
arg: Arg2 {
dst: generated_id,
src: symbol,
},
+ state_space: ast::StateSpace::Reg,
typ: var_type,
member_index,
}));
@@ -2279,6 +2275,7 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
src1: symbol,
src2: generated_id,
},
+ state_space: ast::StateSpace::Reg,
typ: var_type,
member_index: member_index.map(|(idx, _)| idx),
}));
@@ -2293,7 +2290,7 @@ impl<'a, 'input> ArgumentMapVisitor<TypedArgParams, TypedArgParams>
fn id(
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
- typ: Option<&ast::Type>,
+ typ: Option<(&ast::Type, ast::StateSpace)>,
) -> Result<spirv::Word, TranslateError> {
self.symbol(desc.new_op((desc.op, None)), typ)
}
@@ -2302,18 +2299,20 @@ impl<'a, 'input> ArgumentMapVisitor<TypedArgParams, TypedArgParams>
&mut self,
desc: ArgumentDescriptor<TypedOperand>,
typ: &ast::Type,
+ state_space: ast::StateSpace,
) -> Result<TypedOperand, TranslateError> {
Ok(match desc.op {
TypedOperand::Reg(reg) => {
- TypedOperand::Reg(self.symbol(desc.new_op((reg, None)), Some(typ))?)
- }
- TypedOperand::RegOffset(reg, offset) => {
- TypedOperand::RegOffset(self.symbol(desc.new_op((reg, None)), Some(typ))?, offset)
+ TypedOperand::Reg(self.symbol(desc.new_op((reg, None)), Some((typ, state_space)))?)
}
+ TypedOperand::RegOffset(reg, offset) => TypedOperand::RegOffset(
+ self.symbol(desc.new_op((reg, None)), Some((typ, state_space)))?,
+ offset,
+ ),
op @ TypedOperand::Imm(..) => op,
- TypedOperand::VecMember(symbol, index) => {
- TypedOperand::Reg(self.symbol(desc.new_op((symbol, Some(index))), Some(typ))?)
- }
+ TypedOperand::VecMember(symbol, index) => TypedOperand::Reg(
+ self.symbol(desc.new_op((symbol, Some(index))), Some((typ, state_space)))?,
+ ),
})
}
}
@@ -2411,7 +2410,7 @@ impl<'a, 'b> FlattenArguments<'a, 'b> {
fn reg(
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
- _: Option<&ast::Type>,
+ _: Option<(&ast::Type, ast::StateSpace)>,
) -> Result<spirv::Word, TranslateError> {
Ok(desc.op)
}
@@ -2420,30 +2419,31 @@ impl<'a, 'b> FlattenArguments<'a, 'b> {
&mut self,
desc: ArgumentDescriptor<(spirv::Word, i32)>,
typ: &ast::Type,
+ state_space: ast::StateSpace,
) -> Result<spirv::Word, TranslateError> {
let (reg, offset) = desc.op;
let add_type;
match typ {
- ast::Type::Pointer(underlying_type, state_space) => {
- let reg_typ = self.id_def.get_typed(reg)?;
- if let ast::Type::Pointer(_, _) = reg_typ {
- let id_constant_stmt = self.id_def.new_non_variable(typ.clone());
+ ast::Type::Pointer(underlying_type) => {
+ let (reg_typ, space) = self.id_def.get_typed(reg)?;
+ if let ast::Type::Pointer(..) = reg_typ {
+ let id_constant_stmt = self.id_def.register_intermediate(typ.clone(), space);
self.func.push(Statement::Constant(ConstantDefinition {
dst: id_constant_stmt,
typ: ast::ScalarType::S64,
value: ast::ImmediateValue::S64(offset as i64),
}));
- let dst = self.id_def.new_non_variable(typ.clone());
+ let dst = self.id_def.register_intermediate(typ.clone(), space);
self.func.push(Statement::PtrAccess(PtrAccess {
- underlying_type: underlying_type.clone(),
- state_space: *state_space,
+ underlying_type: *underlying_type,
+ state_space: state_space,
dst,
ptr_src: reg,
offset_src: id_constant_stmt,
}));
return Ok(dst);
} else {
- add_type = self.id_def.get_typed(reg)?;
+ add_type = self.id_def.get_typed(reg)?.0;
}
}
_ => {
@@ -2475,8 +2475,12 @@ impl<'a, 'b> FlattenArguments<'a, 'b> {
ast::ScalarKind::Unsigned,
))
};
- let id_constant_stmt = self.id_def.new_non_variable(add_type.clone());
- let result_id = self.id_def.new_non_variable(add_type);
+ let id_constant_stmt = self
+ .id_def
+ .register_intermediate(add_type.clone(), ast::StateSpace::Reg);
+ let result_id = self
+ .id_def
+ .register_intermediate(add_type, ast::StateSpace::Reg);
// TODO: check for edge cases around min value/max value/wrapping
if offset < 0 && kind != ast::ScalarKind::Signed {
self.func.push(Statement::Constant(ConstantDefinition {
@@ -2518,13 +2522,16 @@ impl<'a, 'b> FlattenArguments<'a, 'b> {
&mut self,
desc: ArgumentDescriptor<ast::ImmediateValue>,
typ: &ast::Type,
+ state_space: ast::StateSpace,
) -> Result<spirv::Word, TranslateError> {
let scalar_t = if let ast::Type::Scalar(scalar) = typ {
*scalar
} else {
todo!()
};
- let id = self.id_def.new_non_variable(ast::Type::Scalar(scalar_t));
+ let id = self
+ .id_def
+ .register_intermediate(ast::Type::Scalar(scalar_t), state_space);
self.func.push(Statement::Constant(ConstantDefinition {
dst: id,
typ: scalar_t,
@@ -2538,7 +2545,7 @@ impl<'a, 'b> ArgumentMapVisitor<TypedArgParams, ExpandedArgParams> for FlattenAr
fn id(
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
- t: Option<&ast::Type>,
+ t: Option<(&ast::Type, ast::StateSpace)>,
) -> Result<spirv::Word, TranslateError> {
self.reg(desc, t)
}
@@ -2547,12 +2554,13 @@ impl<'a, 'b> ArgumentMapVisitor<TypedArgParams, ExpandedArgParams> for FlattenAr
&mut self,
desc: ArgumentDescriptor<TypedOperand>,
typ: &ast::Type,
+ state_space: ast::StateSpace,
) -> Result<spirv::Word, TranslateError> {
match desc.op {
- TypedOperand::Reg(r) => self.reg(desc.new_op(r), Some(typ)),
- TypedOperand::Imm(x) => self.immediate(desc.new_op(x), typ),
+ TypedOperand::Reg(r) => self.reg(desc.new_op(r), Some((typ, state_space))),
+ TypedOperand::Imm(x) => self.immediate(desc.new_op(x), typ, state_space),
TypedOperand::RegOffset(reg, offset) => {
- self.reg_offset(desc.new_op((reg, offset)), typ)
+ self.reg_offset(desc.new_op((reg, offset)), typ, state_space)
}
TypedOperand::VecMember(..) => Err(error_unreachable()),
}
@@ -2580,39 +2588,29 @@ fn insert_implicit_conversions(
let mut result = Vec::with_capacity(func.len());
for s in func.into_iter() {
match s {
- Statement::Call(call) => insert_implicit_conversions_impl(
- &mut result,
- id_def,
- call,
- should_bitcast_wrapper,
- None,
- )?,
+ Statement::Call(call) => {
+ insert_implicit_conversions_impl(&mut result, id_def, call, should_bitcast_wrapper)?
+ }
Statement::Instruction(inst) => {
let mut default_conversion_fn =
- should_bitcast_wrapper as for<'a> fn(&'a ast::Type, &'a ast::Type, _) -> _;
+ should_bitcast_wrapper as for<'a> fn(&'a _, _, &'a _, _) -> _;
let mut state_space = None;
if let ast::Instruction::Ld(d, _) = &inst {
state_space = Some(d.state_space);
}
if let ast::Instruction::St(d, _) = &inst {
- state_space = Some(d.state_space.to_ld_ss());
+ state_space = Some(d.state_space);
}
if let ast::Instruction::Atom(d, _) = &inst {
- state_space = Some(d.space.to_ld_ss());
+ state_space = Some(d.space);
}
if let ast::Instruction::AtomCas(d, _) = &inst {
- state_space = Some(d.space.to_ld_ss());
+ state_space = Some(d.space);
}
if let ast::Instruction::Mov(..) = &inst {
default_conversion_fn = should_bitcast_packed;
}
- insert_implicit_conversions_impl(
- &mut result,
- id_def,
- inst,
- default_conversion_fn,
- state_space,
- )?;
+ insert_implicit_conversions_impl(&mut result, id_def, inst, default_conversion_fn)?;
}
Statement::PtrAccess(PtrAccess {
underlying_type,
@@ -2627,7 +2625,8 @@ fn insert_implicit_conversions(
is_dst: false,
sema: ArgumentSemantics::PhysicalPointer,
},
- typ: &ast::Type::Pointer(underlying_type.clone(), state_space),
+ typ: &ast::Type::Pointer(underlying_type),
+ state_space,
stmt_ctor: |new_ptr_src| {
Statement::PtrAccess(PtrAccess {
underlying_type,
@@ -2643,7 +2642,6 @@ fn insert_implicit_conversions(
id_def,
visit_desc,
bitcast_physical_pointer,
- Some(state_space),
)?;
}
Statement::RepackVector(repack) => insert_implicit_conversions_impl(
@@ -2651,7 +2649,6 @@ fn insert_implicit_conversions(
id_def,
repack,
should_bitcast_wrapper,
- None,
)?,
s @ Statement::Conditional(_)
| s @ Statement::Conversion(_)
@@ -2672,19 +2669,20 @@ fn insert_implicit_conversions_impl(
stmt: impl Visitable<ExpandedArgParams, ExpandedArgParams>,
default_conversion_fn: for<'a> fn(
&'a ast::Type,
+ ast::StateSpace,
&'a ast::Type,
- Option<ast::LdStateSpace>,
+ ast::StateSpace,
) -> Result<Option<ConversionKind>, TranslateError>,
- state_space: Option<ast::LdStateSpace>,
) -> Result<(), TranslateError> {
let mut post_conv = Vec::new();
- let statement = stmt.visit(
- &mut |desc: ArgumentDescriptor<spirv::Word>, typ: Option<&ast::Type>| {
- let instr_type = match typ {
+ let statement =
+ stmt.visit(&mut |desc: ArgumentDescriptor<spirv::Word>,
+ typ: Option<(&ast::Type, ast::StateSpace)>| {
+ let (instr_type, instruction_space) = match typ {
None => return Ok(desc.op),
Some(t) => t,
};
- let operand_type = id_def.get_typed(desc.op)?;
+ let (operand_type, operand_space) = id_def.get_typed(desc.op)?;
let mut conversion_fn = default_conversion_fn;
match desc.sema {
ArgumentSemantics::Default => {}
@@ -2705,27 +2703,33 @@ fn insert_implicit_conversions_impl(
conversion_fn = force_bitcast_ptr_to_bit;
}
};
- match conversion_fn(&operand_type, instr_type, state_space)? {
+ match conversion_fn(&operand_type, operand_space, instr_type, instruction_space)? {
Some(conv_kind) => {
let conv_output = if desc.is_dst {
&mut post_conv
} else {
&mut *func
};
- let mut from = instr_type.clone();
- let mut to = operand_type;
- let mut src = id_def.new_non_variable(instr_type.clone());
+ let mut from_type = instr_type.clone();
+ let mut from_space = instruction_space;
+ let mut to_type = operand_type;
+ let mut to_space = operand_space;
+ let mut src =
+ id_def.register_intermediate(instr_type.clone(), instruction_space);
let mut dst = desc.op;
let result = Ok(src);
if !desc.is_dst {
mem::swap(&mut src, &mut dst);
- mem::swap(&mut from, &mut to);
+ mem::swap(&mut from_type, &mut to_type);
+ mem::swap(&mut from_space, &mut to_space);
}
conv_output.push(Statement::Conversion(ImplicitConversion {
src,
dst,
- from,
- to,
+ from_type,
+ from_space,
+ to_type,
+ to_space,
kind: conv_kind,
src_sema: ArgumentSemantics::Default,
dst_sema: ArgumentSemantics::Default,
@@ -2734,8 +2738,7 @@ fn insert_implicit_conversions_impl(
}
None => Ok(desc.op),
}
- },
- )?;
+ })?;
func.push(statement);
func.append(&mut post_conv);
Ok(())
@@ -2751,10 +2754,10 @@ fn get_function_type(
builder,
spirv_input
.iter()
- .map(|var| SpirvType::from(var.v_type.clone())),
+ .map(|var| SpirvType::new(var.v_type.clone(), var.state_space)),
spirv_output
.iter()
- .map(|var| SpirvType::from(var.v_type.clone())),
+ .map(|var| SpirvType::new(var.v_type.clone(), var.state_space)),
)
}
@@ -2782,8 +2785,8 @@ fn emit_function_body_ops(
Statement::Label(_) => (),
Statement::Call(call) => {
let (result_type, result_id) = match &*call.ret_params {
- [(id, typ, _)] => (
- map.get_or_add(builder, SpirvType::from(typ.clone())),
+ [(id, typ, space)] => (
+ map.get_or_add(builder, SpirvType::new(typ.clone(), *space)),
Some(*id),
),
[] => (map.void(), None),
@@ -2915,8 +2918,10 @@ fn emit_function_body_ops(
if data.qualifier != ast::LdStQualifier::Weak {
todo!()
}
- let result_type =
- map.get_or_add(builder, SpirvType::from(ast::Type::from(data.typ.clone())));
+ let result_type = map.get_or_add(
+ builder,
+ SpirvType::new(ast::Type::from(data.typ.clone()), data.state_space),
+ );
builder.load(
result_type,
Some(arg.dst),
@@ -2947,8 +2952,10 @@ fn emit_function_body_ops(
// SPIR-V does not support ret as guaranteed-converged
ast::Instruction::Ret(_) => builder.ret()?,
ast::Instruction::Mov(d, arg) => {
- let result_type =
- map.get_or_add(builder, SpirvType::from(ast::Type::from(d.typ.clone())));
+ let result_type = map.get_or_add(
+ builder,
+ SpirvType::new(ast::Type::from(d.typ.clone()), ast::StateSpace::Reg),
+ );
builder.copy_object(result_type, Some(arg.dst), arg.src)?;
}
ast::Instruction::Mul(mul, arg) => match mul {
@@ -2989,7 +2996,8 @@ fn emit_function_body_ops(
ast::Instruction::Shl(t, a) => {
let full_type = ast::Type::Scalar(*t);
let size_of = full_type.size_of();
- let result_type = map.get_or_add(builder, SpirvType::from(full_type));
+ let result_type =
+ map.get_or_add(builder, SpirvType::new(full_type, ast::StateSpace::Reg));
let offset_src = insert_shift_hack(builder, map, a.src2, size_of)?;
builder.shift_left_logical(result_type, Some(a.dst), a.src1, offset_src)?;
}
@@ -3251,8 +3259,9 @@ fn emit_function_body_ops(
Some(index) => {
let result_ptr_type = map.get_or_add(
builder,
- SpirvType::new_pointer(
+ SpirvType::pointer_to(
details.typ.clone(),
+ details.state_space,
spirv::StorageClass::Function,
),
);
@@ -3284,14 +3293,11 @@ fn emit_function_body_ops(
}) => {
let u8_pointer = map.get_or_add(
builder,
- SpirvType::from(ast::Type::Pointer(
- ast::PointerType::Scalar(ast::ScalarType::U8),
- *state_space,
- )),
+ SpirvType::new(ast::Type::Pointer(ast::ScalarType::U8), *state_space),
);
let result_type = map.get_or_add(
builder,
- SpirvType::from(ast::Type::Pointer(underlying_type.clone(), *state_space)),
+ SpirvType::new(ast::Type::Pointer(*underlying_type), *state_space),
);
let ptr_src_u8 = builder.bitcast(u8_pointer, None, *ptr_src)?;
let temp = builder.in_bounds_ptr_access_chain(
@@ -3503,11 +3509,16 @@ fn ptx_scope_name(scope: ast::MemScope) -> &'static str {
}
}
-fn ptx_space_name(space: ast::AtomSpace) -> &'static str {
+fn ptx_space_name(space: ast::StateSpace) -> &'static str {
match space {
- ast::AtomSpace::Generic => "generic",
- ast::AtomSpace::Global => "global",
- ast::AtomSpace::Shared => "shared",
+ ast::StateSpace::Generic => "generic",
+ ast::StateSpace::Global => "global",
+ ast::StateSpace::Shared => "shared",
+ ast::StateSpace::Reg => "reg",
+ ast::StateSpace::Const => "const",
+ ast::StateSpace::Local => "local",
+ ast::StateSpace::Param => "param",
+ ast::StateSpace::Sreg => "sreg",
}
}
@@ -3572,6 +3583,7 @@ fn emit_variable(
ast::StateSpace::Shared => (false, spirv::StorageClass::Workgroup),
ast::StateSpace::Const => todo!(),
ast::StateSpace::Generic => todo!(),
+ ast::StateSpace::Sreg => todo!(),
};
let initalizer = if var.array_init.len() > 0 {
Some(map.get_or_add_constant(
@@ -3580,17 +3592,14 @@ fn emit_variable(
&*var.array_init,
)?)
} else if must_init {
- let type_id = map.get_or_add(
- builder,
- SpirvType::from(ast::Type::from(var.v_type.clone())),
- );
+ let type_id = map.get_or_add(builder, SpirvType::new(var.v_type.clone(), var.state_space));
Some(builder.constant_null(type_id, None))
} else {
None
};
let ptr_type_id = map.get_or_add(
builder,
- SpirvType::new_pointer(ast::Type::from(var.v_type.clone()), st_class),
+ SpirvType::pointer_to(var.v_type.clone(), var.state_space, st_class),
);
builder.variable(ptr_type_id, Some(var.name), st_class, initalizer);
if let Some(align) = var.align {
@@ -3729,7 +3738,10 @@ fn emit_min(
ast::MinMaxDetails::Unsigned(_) => spirv::CLOp::u_min,
ast::MinMaxDetails::Float(_) => spirv::CLOp::fmin,
};
- let inst_type = map.get_or_add(builder, SpirvType::from(desc.get_type()));
+ let inst_type = map.get_or_add(
+ builder,
+ SpirvType::new(desc.get_type(), ast::StateSpace::Reg),
+ );
builder.ext_inst(
inst_type,
Some(arg.dst),
@@ -3754,7 +3766,10 @@ fn emit_max(
ast::MinMaxDetails::Unsigned(_) => spirv::CLOp::u_max,
ast::MinMaxDetails::Float(_) => spirv::CLOp::fmax,
};
- let inst_type = map.get_or_add(builder, SpirvType::from(desc.get_type()));
+ let inst_type = map.get_or_add(
+ builder,
+ SpirvType::new(desc.get_type(), ast::StateSpace::Reg),
+ );
builder.ext_inst(
inst_type,
Some(arg.dst),
@@ -3865,11 +3880,13 @@ fn emit_cvt(
let cv = ImplicitConversion {
src: arg.src,
dst: new_dst,
- from: ast::Type::Scalar(src_t),
- to: ast::Type::Scalar(ast::ScalarType::from_parts(
+ from_type: ast::Type::Scalar(src_t),
+ from_space: ast::StateSpace::Reg,
+ to_type: ast::Type::Scalar(ast::ScalarType::from_parts(
dest_t.size_of(),
src_t.kind(),
)),
+ to_space: ast::StateSpace::Reg,
kind: ConversionKind::Default,
src_sema: ArgumentSemantics::Default,
dst_sema: ArgumentSemantics::Default,
@@ -4224,20 +4241,24 @@ fn emit_implicit_conversion(
map: &mut TypeWordMap,
cv: &ImplicitConversion,
) -> Result<(), TranslateError> {
- let from_parts = cv.from.to_parts();
- let to_parts = cv.to.to_parts();
+ let from_parts = cv.from_type.to_parts();
+ let to_parts = cv.to_type.to_parts();
match (from_parts.kind, to_parts.kind, cv.kind) {
(_, _, ConversionKind::PtrToBit(typ)) => {
let dst_type = map.get_or_add_scalar(builder, typ.into());
builder.convert_ptr_to_u(dst_type, Some(cv.dst), cv.src)?;
}
- (_, _, ConversionKind::BitToPtr(_)) => {
- let dst_type = map.get_or_add(builder, SpirvType::from(cv.to.clone()));
+ (_, _, ConversionKind::BitToPtr) => {
+ let dst_type = map.get_or_add(
+ builder,
+ SpirvType::pointer_to(cv.to_type.clone(), cv.from_space, cv.to_space.to_spirv()),
+ );
builder.convert_u_to_ptr(dst_type, Some(cv.dst), cv.src)?;
}
(TypeKind::Scalar, TypeKind::Scalar, ConversionKind::Default) => {
if from_parts.width == to_parts.width {
- let dst_type = map.get_or_add(builder, SpirvType::from(cv.to.clone()));
+ let dst_type =
+ map.get_or_add(builder, SpirvType::new(cv.to_type.clone(), cv.to_space));
if from_parts.scalar_kind != ast::ScalarKind::Float
&& to_parts.scalar_kind != ast::ScalarKind::Float
{
@@ -4247,13 +4268,16 @@ fn emit_implicit_conversion(
builder.bitcast(dst_type, Some(cv.dst), cv.src)?;
}
} else {
- // This block is safe because it's illegal to implictly convert between floating point instructions
+ // This block is safe because it's illegal to implictly convert between floating point values
let same_width_bit_type = map.get_or_add(
builder,
- SpirvType::from(ast::Type::from_parts(TypeParts {
- scalar_kind: ast::ScalarKind::Bit,
- ..from_parts
- })),
+ SpirvType::new(
+ ast::Type::from_parts(TypeParts {
+ scalar_kind: ast::ScalarKind::Bit,
+ ..from_parts
+ }),
+ cv.from_space,
+ ),
);
let same_width_bit_value = builder.bitcast(same_width_bit_type, None, cv.src)?;
let wide_bit_type = ast::Type::from_parts(TypeParts {
@@ -4261,7 +4285,7 @@ fn emit_implicit_conversion(
..to_parts
});
let wide_bit_type_spirv =
- map.get_or_add(builder, SpirvType::from(wide_bit_type.clone()));
+ map.get_or_add(builder, SpirvType::new(wide_bit_type.clone(), cv.to_space));
if to_parts.scalar_kind == ast::ScalarKind::Unsigned
|| to_parts.scalar_kind == ast::ScalarKind::Bit
{
@@ -4282,8 +4306,10 @@ fn emit_implicit_conversion(
&ImplicitConversion {
src: wide_bit_value,
dst: cv.dst,
- from: wide_bit_type,
- to: cv.to.clone(),
+ from_type: wide_bit_type,
+ from_space: new_todo!(),
+ to_type: cv.to_type.clone(),
+ to_space: new_todo!(),
kind: ConversionKind::Default,
src_sema: cv.src_sema,
dst_sema: cv.dst_sema,
@@ -4293,13 +4319,15 @@ fn emit_implicit_conversion(
}
}
(TypeKind::Scalar, TypeKind::Scalar, ConversionKind::SignExtend) => {
- let result_type = map.get_or_add(builder, SpirvType::from(cv.to.clone()));
+ let result_type =
+ map.get_or_add(builder, SpirvType::new(cv.to_type.clone(), cv.to_space));
builder.s_convert(result_type, Some(cv.dst), cv.src)?;
}
(TypeKind::Vector, TypeKind::Scalar, ConversionKind::Default)
| (TypeKind::Scalar, TypeKind::Array, ConversionKind::Default)
| (TypeKind::Array, TypeKind::Scalar, ConversionKind::Default) => {
- let into_type = map.get_or_add(builder, SpirvType::from(cv.to.clone()));
+ let into_type =
+ map.get_or_add(builder, SpirvType::new(cv.to_type.clone(), cv.to_space));
builder.bitcast(into_type, Some(cv.dst), cv.src)?;
}
(_, _, ConversionKind::PtrToPtr { spirv_ptr }) => {
@@ -4307,12 +4335,12 @@ fn emit_implicit_conversion(
map.get_or_add(
builder,
SpirvType::Pointer(
- Box::new(SpirvType::from(cv.to.clone())),
+ Box::new(SpirvType::new(cv.to_type.clone(), cv.to_space)),
spirv::StorageClass::Function,
),
)
} else {
- map.get_or_add(builder, SpirvType::from(cv.to.clone()))
+ map.get_or_add(builder, SpirvType::new(cv.to_type.clone(), cv.to_space))
};
builder.bitcast(result_type, Some(cv.dst), cv.src)?;
}
@@ -4326,14 +4354,18 @@ fn emit_load_var(
map: &mut TypeWordMap,
details: &LoadVarDetails,
) -> Result<(), TranslateError> {
- let result_type = map.get_or_add(builder, SpirvType::from(details.typ.clone()));
+ let result_type = map.get_or_add(
+ builder,
+ SpirvType::new(details.typ.clone(), details.state_space),
+ );
match details.member_index {
Some((index, Some(width))) => {
let vector_type = match details.typ {
ast::Type::Scalar(scalar_t) => ast::Type::Vector(scalar_t, width),
_ => return Err(TranslateError::MismatchedType),
};
- let vector_type_spirv = map.get_or_add(builder, SpirvType::from(vector_type));
+ let vector_type_spirv =
+ map.get_or_add(builder, SpirvType::new(vector_type, details.state_space));
let vector_temp = builder.load(
vector_type_spirv,
None,
@@ -4351,7 +4383,11 @@ fn emit_load_var(
Some((index, None)) => {
let result_ptr_type = map.get_or_add(
builder,
- SpirvType::new_pointer(details.typ.clone(), spirv::StorageClass::Function),
+ SpirvType::pointer_to(
+ details.typ.clone(),
+ details.state_space,
+ spirv::StorageClass::Function,
+ ),
);
let index_spirv = map.get_or_add_constant(
builder,
@@ -4433,18 +4469,25 @@ fn expand_map_variables<'a, 'b>(
is_variable = true;
var_type
} else {
- var_type.param_pointer_to(ast::LdStateSpace::Shared)?
+ var_type.param_pointer_to(ast::StateSpace::Shared)?
}
}
- ast::StateSpace::Global => var_type.param_pointer_to(ast::LdStateSpace::Global)?,
- ast::StateSpace::Param => var_type.param_pointer_to(ast::LdStateSpace::Param)?,
- ast::StateSpace::Local => var_type.param_pointer_to(ast::LdStateSpace::Local)?,
- ast::StateSpace::Const => todo!(),
- ast::StateSpace::Generic => todo!(),
+ ast::StateSpace::Global => var_type.param_pointer_to(ast::StateSpace::Global)?,
+ ast::StateSpace::Param => var_type.param_pointer_to(ast::StateSpace::Param)?,
+ ast::StateSpace::Local => var_type.param_pointer_to(ast::StateSpace::Local)?,
+ ast::StateSpace::Const => new_todo!(),
+ ast::StateSpace::Generic => new_todo!(),
+ ast::StateSpace::Sreg => new_todo!(),
};
match var.count {
Some(count) => {
- for new_id in id_defs.add_defs(var.var.name, count, var_type, is_variable) {
+ for new_id in id_defs.add_defs(
+ var.var.name,
+ count,
+ var_type,
+ var.var.state_space,
+ is_variable,
+ ) {
result.push(Statement::Variable(ast::Variable {
align: var.var.align,
v_type: var.var.v_type.clone(),
@@ -4455,7 +4498,11 @@ fn expand_map_variables<'a, 'b>(
}
}
None => {
- let new_id = id_defs.add_def(var.var.name, Some(var_type), is_variable);
+ let new_id = id_defs.add_def(
+ var.var.name,
+ Some((var_type, var.var.state_space)),
+ is_variable,
+ );
result.push(Statement::Variable(ast::Variable {
align: var.var.align,
v_type: var.var.v_type.clone(),
@@ -4470,11 +4517,42 @@ fn expand_map_variables<'a, 'b>(
Ok(())
}
+/*
+ Our goal here is to transform
+ .visible .entry foobar(.param .u64 input) {
+ .reg .b64 in_addr;
+ .reg .b64 in_addr2;
+ ld.param.u64 in_addr, [input];
+ cvta.to.global.u64 in_addr2, in_addr;
+ }
+ into:
+ .visible .entry foobar(.param .u8 input[]) {
+ .reg .u8 in_addr[];
+ .reg .u8 in_addr2[];
+ ld.param.u8[] in_addr, [input];
+ mov.u8[] in_addr2, in_addr;
+ }
+ or:
+ .visible .entry foobar(.reg .u8 input[]) {
+ .reg .u8 in_addr[];
+ .reg .u8 in_addr2[];
+ mov.u8[] in_addr, input;
+ mov.u8[] in_addr2, in_addr;
+ }
+ or:
+ .visible .entry foobar(.param ptr<u8, global> input) {
+ .reg ptr<u8, global> in_addr;
+ .reg ptr<u8, global> in_addr2;
+ ld.param.ptr<u8, global> in_addr, [input];
+ mov.ptr<u8, global> in_addr2, in_addr;
+ }
+*/
// TODO: detect more patterns (mov, call via reg, call via param)
// TODO: don't convert to ptr if the register is not ultimately used for ld/st
// TODO: once insert_mem_ssa_statements is moved to later, move this pass after
// argument expansion
// TODO: propagate through calls?
+/*
fn convert_to_stateful_memory_access<'a>(
func_args: &mut SpirvMethodDecl,
func_body: Vec<TypedStatement>,
@@ -4496,9 +4574,9 @@ fn convert_to_stateful_memory_access<'a>(
match statement {
Statement::Instruction(ast::Instruction::Cvta(
ast::CvtaDetails {
- to: ast::CvtaStateSpace::Global,
+ to: ast::StateSpace::Global,
size: ast::CvtaSize::U64,
- from: ast::CvtaStateSpace::Generic,
+ from: ast::StateSpace::Generic,
},
arg,
)) => {
@@ -4512,24 +4590,24 @@ fn convert_to_stateful_memory_access<'a>(
}
Statement::Instruction(ast::Instruction::Ld(
ast::LdDetails {
- state_space: ast::LdStateSpace::Param,
- typ: ast::PointerType::Scalar(ast::ScalarType::U64),
+ state_space: ast::StateSpace::Param,
+ typ: ast::Type::Scalar(ast::ScalarType::U64),
..
},
arg,
))
| Statement::Instruction(ast::Instruction::Ld(
ast::LdDetails {
- state_space: ast::LdStateSpace::Param,
- typ: ast::PointerType::Scalar(ast::ScalarType::S64),
+ state_space: ast::StateSpace::Param,
+ typ: ast::Type::Scalar(ast::ScalarType::S64),
..
},
arg,
))
| Statement::Instruction(ast::Instruction::Ld(
ast::LdDetails {
- state_space: ast::LdStateSpace::Param,
- typ: ast::PointerType::Scalar(ast::ScalarType::B64),
+ state_space: ast::StateSpace::Param,
+ typ: ast::Type::Scalar(ast::ScalarType::B64),
..
},
arg,
@@ -4611,19 +4689,16 @@ fn convert_to_stateful_memory_access<'a>(
let mut remapped_ids = HashMap::new();
let mut result = Vec::with_capacity(regs_ptr_seen.len() + func_body.len());
for reg in regs_ptr_seen {
- let new_id = id_defs.new_variable(ast::Type::Pointer(
- ast::PointerType::Scalar(ast::ScalarType::U8),
- ast::LdStateSpace::Global,
- ));
+ let new_id = id_defs.register_variable(
+ ast::Type::Pointer(ast::ScalarType::U8),
+ ast::StateSpace::Global,
+ );
result.push(Statement::Variable(ast::Variable {
align: None,
name: new_id,
array_init: Vec::new(),
- v_type: ast::Type::Pointer(
- ast::PointerType::Scalar(ast::ScalarType::U8),
- ast::LdStateSpace::Global,
- ),
- state_space: ast::StateSpace::Reg,
+ v_type: ast::Type::Pointer(ast::ScalarType::U8),
+ state_space: ast::StateSpace::Global,
}));
remapped_ids.insert(reg, new_id);
}
@@ -4658,8 +4733,8 @@ fn convert_to_stateful_memory_access<'a>(
};
let dst = arg.dst.upcast().unwrap_reg()?;
result.push(Statement::PtrAccess(PtrAccess {
- underlying_type: ast::PointerType::Scalar(ast::ScalarType::U8),
- state_space: ast::LdStateSpace::Global,
+ underlying_type: ast::ScalarType::U8,
+ state_space: ast::StateSpace::Global,
dst: *remapped_ids.get(&dst).unwrap(),
ptr_src: *ptr,
offset_src: offset,
@@ -4686,7 +4761,7 @@ fn convert_to_stateful_memory_access<'a>(
_ => return Err(error_unreachable()),
};
let offset_neg =
- id_defs.new_non_variable(Some(ast::Type::Scalar(ast::ScalarType::S64)));
+ id_defs.register_intermediate(Some(ast::Type::Scalar(ast::ScalarType::S64)));
result.push(Statement::Instruction(ast::Instruction::Neg(
ast::NegDetails {
typ: ast::ScalarType::S64,
@@ -4699,8 +4774,8 @@ fn convert_to_stateful_memory_access<'a>(
)));
let dst = arg.dst.upcast().unwrap_reg()?;
result.push(Statement::PtrAccess(PtrAccess {
- underlying_type: ast::PointerType::Scalar(ast::ScalarType::U8),
- state_space: ast::LdStateSpace::Global,
+ underlying_type: ast::ScalarType::U8,
+ state_space: ast::StateSpace::Global,
dst: *remapped_ids.get(&dst).unwrap(),
ptr_src: *ptr,
offset_src: TypedOperand::Reg(offset_neg),
@@ -4768,10 +4843,8 @@ fn convert_to_stateful_memory_access<'a>(
}
for arg in func_args.input.iter_mut() {
if func_args_ptr.contains(&arg.name) {
- arg.v_type = ast::Type::Pointer(
- ast::PointerType::Scalar(ast::ScalarType::U8),
- ast::LdStateSpace::Global,
- );
+ arg.v_type = ast::Type::Pointer(ast::ScalarType::U8);
+ arg.state_space = ast::StateSpace::Global;
}
}
Ok(result)
@@ -4790,21 +4863,21 @@ fn convert_to_stateful_memory_access_postprocess(
Some(new_id) => {
// We skip conversion here to trigger PtrAcces in a later pass
let old_type = match expected_type {
- Some(ast::Type::Pointer(_, ast::LdStateSpace::Global)) => return Ok(*new_id),
+ Some(ast::Type::Pointer(_, ast::StateSpace::Global)) => return Ok(*new_id),
_ => id_defs.get_typed(arg_desc.op)?.0,
};
let old_type_clone = old_type.clone();
- let converting_id = id_defs.new_non_variable(Some(old_type_clone));
+ let converting_id = id_defs.register_intermediate(Some(old_type_clone));
if arg_desc.is_dst {
post_statements.push(Statement::Conversion(ImplicitConversion {
src: converting_id,
dst: *new_id,
- from: old_type,
- to: ast::Type::Pointer(
+ from_type: old_type,
+ to_type: ast::Type::Pointer(
ast::PointerType::Scalar(ast::ScalarType::U8),
- ast::LdStateSpace::Global,
+ ast::StateSpace::Global,
),
- kind: ConversionKind::BitToPtr(ast::LdStateSpace::Global),
+ kind: ConversionKind::BitToPtr(ast::StateSpace::Global),
src_sema: ArgumentSemantics::Default,
dst_sema: arg_desc.sema,
}));
@@ -4813,11 +4886,11 @@ fn convert_to_stateful_memory_access_postprocess(
result.push(Statement::Conversion(ImplicitConversion {
src: *new_id,
dst: converting_id,
- from: ast::Type::Pointer(
+ from_type: ast::Type::Pointer(
ast::PointerType::Scalar(ast::ScalarType::U8),
- ast::LdStateSpace::Global,
+ ast::StateSpace::Global,
),
- to: old_type,
+ to_type: old_type,
kind: ConversionKind::PtrToBit(ast::ScalarType::U64),
src_sema: arg_desc.sema,
dst_sema: ArgumentSemantics::Default,
@@ -4832,19 +4905,19 @@ fn convert_to_stateful_memory_access_postprocess(
}
// We skip conversion here to trigger PtrAcces in a later pass
let old_type = match expected_type {
- Some(ast::Type::Pointer(_, ast::LdStateSpace::Global)) => return Ok(*new_id),
+ Some(ast::Type::Pointer(_, ast::StateSpace::Global)) => return Ok(*new_id),
_ => id_defs.get_typed(arg_desc.op)?.0,
};
let old_type_clone = old_type.clone();
- let converting_id = id_defs.new_non_variable(Some(old_type));
+ let converting_id = id_defs.register_intermediate(Some(old_type));
result.push(Statement::Conversion(ImplicitConversion {
src: *new_id,
dst: converting_id,
- from: ast::Type::Pointer(
- ast::PointerType::Pointer(ast::ScalarType::U8, ast::LdStateSpace::Global),
- ast::LdStateSpace::Param,
+ from_type: ast::Type::Pointer(
+ ast::PointerType::Pointer(ast::ScalarType::U8, ast::StateSpace::Global),
+ ast::StateSpace::Param,
),
- to: old_type_clone,
+ to_type: old_type_clone,
kind: ConversionKind::PtrToPtr { spirv_ptr: false },
src_sema: arg_desc.sema,
dst_sema: ArgumentSemantics::Default,
@@ -4855,6 +4928,7 @@ fn convert_to_stateful_memory_access_postprocess(
},
})
}
+*/
fn is_add_ptr_direct(remapped_ids: &HashMap<u32, u32>, arg: &ast::Arg3<TypedArgParams>) -> bool {
match arg.dst {
@@ -4876,9 +4950,9 @@ fn is_add_ptr_direct(remapped_ids: &HashMap<u32, u32>, arg: &ast::Arg3<TypedArgP
fn is_64_bit_integer(id_defs: &NumericIdResolver, id: spirv::Word) -> bool {
match id_defs.get_typed(id) {
- Ok((ast::Type::Scalar(ast::ScalarType::U64), _))
- | Ok((ast::Type::Scalar(ast::ScalarType::S64), _))
- | Ok((ast::Type::Scalar(ast::ScalarType::B64), _)) => true,
+ Ok((ast::Type::Scalar(ast::ScalarType::U64), _, _))
+ | Ok((ast::Type::Scalar(ast::ScalarType::S64), _, _))
+ | Ok((ast::Type::Scalar(ast::ScalarType::B64), _, _)) => true,
_ => false,
}
}
@@ -5007,7 +5081,7 @@ impl SpecialRegistersMap {
struct GlobalStringIdResolver<'input> {
current_id: spirv::Word,
variables: HashMap<Cow<'input, str>, spirv::Word>,
- variables_type_check: HashMap<u32, Option<(ast::Type, bool)>>,
+ variables_type_check: HashMap<u32, Option<(ast::Type, ast::StateSpace, bool)>>,
special_registers: SpecialRegistersMap,
fns: HashMap<spirv::Word, FnDecl>,
}
@@ -5036,12 +5110,17 @@ impl<'a> GlobalStringIdResolver<'a> {
&mut self,
id: &'a str,
typ: ast::Type,
+ state_space: ast::StateSpace,
is_variable: bool,
) -> spirv::Word {
- self.get_or_add_impl(id, Some((typ, is_variable)))
+ self.get_or_add_impl(id, Some((typ, state_space, is_variable)))
}
- fn get_or_add_impl(&mut self, id: &'a str, typ: Option<(ast::Type, bool)>) -> spirv::Word {
+ fn get_or_add_impl(
+ &mut self,
+ id: &'a str,
+ typ: Option<(ast::Type, ast::StateSpace, bool)>,
+ ) -> spirv::Word {
let id = match self.variables.entry(Cow::Borrowed(id)) {
hash_map::Entry::Occupied(e) => *(e.get()),
hash_map::Entry::Vacant(e) => {
@@ -5143,10 +5222,10 @@ impl<'input, 'a> GlobalFnDeclResolver<'input, 'a> {
struct FnStringIdResolver<'input, 'b> {
current_id: &'b mut spirv::Word,
global_variables: &'b HashMap<Cow<'input, str>, spirv::Word>,
- global_type_check: &'b HashMap<u32, Option<(ast::Type, bool)>>,
+ global_type_check: &'b HashMap<u32, Option<(ast::Type, ast::StateSpace, bool)>>,
special_registers: &'b mut SpecialRegistersMap,
variables: Vec<HashMap<Cow<'input, str>, spirv::Word>>,
- type_check: HashMap<u32, Option<(ast::Type, bool)>>,
+ type_check: HashMap<u32, Option<(ast::Type, ast::StateSpace, bool)>>,
}
impl<'a, 'b> FnStringIdResolver<'a, 'b> {
@@ -5184,14 +5263,21 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> {
}
}
- fn add_def(&mut self, id: &'a str, typ: Option<ast::Type>, is_variable: bool) -> spirv::Word {
+ fn add_def(
+ &mut self,
+ id: &'a str,
+ typ: Option<(ast::Type, ast::StateSpace)>,
+ is_variable: bool,
+ ) -> spirv::Word {
let numeric_id = *self.current_id;
self.variables
.last_mut()
.unwrap()
.insert(Cow::Borrowed(id), numeric_id);
- self.type_check
- .insert(numeric_id, typ.map(|t| (t, is_variable)));
+ self.type_check.insert(
+ numeric_id,
+ typ.map(|(typ, space)| (typ, space, is_variable)),
+ );
*self.current_id += 1;
numeric_id
}
@@ -5202,6 +5288,7 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> {
base_id: &'a str,
count: u32,
typ: ast::Type,
+ state_space: ast::StateSpace,
is_variable: bool,
) -> impl Iterator<Item = spirv::Word> {
let numeric_id = *self.current_id;
@@ -5210,8 +5297,10 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> {
.last_mut()
.unwrap()
.insert(Cow::Owned(format!("{}{}", base_id, i)), numeric_id + i);
- self.type_check
- .insert(numeric_id + i, Some((typ.clone(), is_variable)));
+ self.type_check.insert(
+ numeric_id + i,
+ Some((typ.clone(), state_space, is_variable)),
+ );
}
*self.current_id += count;
(0..count).into_iter().map(move |i| i + numeric_id)
@@ -5220,8 +5309,8 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> {
struct NumericIdResolver<'b> {
current_id: &'b mut spirv::Word,
- global_type_check: &'b HashMap<u32, Option<(ast::Type, bool)>>,
- type_check: HashMap<u32, Option<(ast::Type, bool)>>,
+ global_type_check: &'b HashMap<u32, Option<(ast::Type, ast::StateSpace, bool)>>,
+ type_check: HashMap<u32, Option<(ast::Type, ast::StateSpace, bool)>>,
special_registers: &'b mut SpecialRegistersMap,
}
@@ -5230,12 +5319,15 @@ impl<'b> NumericIdResolver<'b> {
MutableNumericIdResolver { base: self }
}
- fn get_typed(&self, id: spirv::Word) -> Result<(ast::Type, bool), TranslateError> {
+ fn get_typed(
+ &self,
+ id: spirv::Word,
+ ) -> Result<(ast::Type, ast::StateSpace, bool), TranslateError> {
match self.type_check.get(&id) {
Some(Some(x)) => Ok(x.clone()),
Some(None) => Err(TranslateError::UntypedSymbol),
None => match self.special_registers.get(id) {
- Some(x) => Ok((x.get_type(), true)),
+ Some(x) => Ok((x.get_type(), ast::StateSpace::Sreg, true)),
None => match self.global_type_check.get(&id) {
Some(Some(result)) => Ok(result.clone()),
Some(None) | None => Err(TranslateError::UntypedSymbol),
@@ -5246,16 +5338,18 @@ impl<'b> NumericIdResolver<'b> {
// This is for identifiers which will be emitted later as OpVariable
// They are candidates for insertion of LoadVar/StoreVar
- fn new_variable(&mut self, typ: ast::Type) -> spirv::Word {
+ fn register_variable(&mut self, typ: ast::Type, state_space: ast::StateSpace) -> spirv::Word {
let new_id = *self.current_id;
- self.type_check.insert(new_id, Some((typ, true)));
+ self.type_check
+ .insert(new_id, Some((typ, state_space, true)));
*self.current_id += 1;
new_id
}
- fn new_non_variable(&mut self, typ: Option<ast::Type>) -> spirv::Word {
+ fn register_intermediate(&mut self, typ: Option<(ast::Type, ast::StateSpace)>) -> spirv::Word {
let new_id = *self.current_id;
- self.type_check.insert(new_id, typ.map(|t| (t, false)));
+ self.type_check
+ .insert(new_id, typ.map(|(t, space)| (t, space, false)));
*self.current_id += 1;
new_id
}
@@ -5270,12 +5364,16 @@ impl<'b> MutableNumericIdResolver<'b> {
self.base
}
- fn get_typed(&self, id: spirv::Word) -> Result<ast::Type, TranslateError> {
- self.base.get_typed(id).map(|(t, _)| t)
+ fn get_typed(&self, id: spirv::Word) -> Result<(ast::Type, ast::StateSpace), TranslateError> {
+ self.base.get_typed(id).map(|(t, space, _)| (t, space))
}
- fn new_non_variable(&mut self, typ: ast::Type) -> spirv::Word {
- self.base.new_non_variable(Some(typ))
+ fn register_intermediate(
+ &mut self,
+ typ: ast::Type,
+ state_space: ast::StateSpace,
+ ) -> spirv::Word {
+ self.base.register_intermediate(Some((typ, state_space)))
}
}
@@ -5304,7 +5402,8 @@ impl ExpandedStatement {
Statement::Variable(var)
}
Statement::Instruction(inst) => inst
- .visit(&mut |arg: ArgumentDescriptor<_>, _: Option<&ast::Type>| {
+ .visit(&mut |arg: ArgumentDescriptor<_>,
+ _: Option<(&ast::Type, ast::StateSpace)>| {
Ok(f(arg.op, arg.is_dst))
})
.unwrap(),
@@ -5391,6 +5490,7 @@ impl ExpandedStatement {
struct LoadVarDetails {
arg: ast::Arg2<ExpandedArgParams>,
typ: ast::Type,
+ state_space: ast::StateSpace,
// (index, vector_width)
// HACK ALERT
// For some reason IGC explodes when you try to load from builtin vectors
@@ -5402,6 +5502,7 @@ struct LoadVarDetails {
struct StoreVarDetails {
arg: ast::Arg2St<ExpandedArgParams>,
typ: ast::Type,
+ state_space: ast::StateSpace,
member_index: Option<u8>,
}
@@ -5428,7 +5529,10 @@ impl RepackVectorDetails {
is_dst: !self.is_extract,
sema: ArgumentSemantics::Default,
},
- Some(&ast::Type::Vector(self.typ, self.unpacked.len() as u8)),
+ Some((
+ &ast::Type::Vector(self.typ, self.unpacked.len() as u8),
+ ast::StateSpace::Reg,
+ )),
)?;
let scalar_type = self.typ;
let is_extract = self.is_extract;
@@ -5443,7 +5547,7 @@ impl RepackVectorDetails {
is_dst: is_extract,
sema: vector_sema,
},
- Some(&ast::Type::Scalar(scalar_type)),
+ Some((&ast::Type::Scalar(scalar_type), ast::StateSpace::Reg)),
)
})
.collect::<Result<_, _>>()?;
@@ -5501,7 +5605,7 @@ impl<From: ArgParamsEx<Id = spirv::Word>> ResolvedCall<From> {
is_dst: space != ast::StateSpace::Param,
sema: space.semantics(),
},
- Some(&typ),
+ Some((&typ, space)),
)?;
Ok((new_id, typ, space))
})
@@ -5525,6 +5629,7 @@ impl<From: ArgParamsEx<Id = spirv::Word>> ResolvedCall<From> {
sema: space.semantics(),
},
&typ,
+ space,
)?;
Ok((new_id, typ, space))
})
@@ -5555,22 +5660,22 @@ impl<P: ArgParamsEx<Id = spirv::Word>> PtrAccess<P> {
visitor: &mut V,
) -> Result<PtrAccess<To>, TranslateError> {
let sema = match self.state_space {
- ast::LdStateSpace::Const
- | ast::LdStateSpace::Global
- | ast::LdStateSpace::Shared
- | ast::LdStateSpace::Generic => ArgumentSemantics::PhysicalPointer,
- ast::LdStateSpace::Local | ast::LdStateSpace::Param => {
- ArgumentSemantics::RegisterPointer
- }
+ ast::StateSpace::Const
+ | ast::StateSpace::Global
+ | ast::StateSpace::Shared
+ | ast::StateSpace::Generic => ArgumentSemantics::PhysicalPointer,
+ ast::StateSpace::Local | ast::StateSpace::Param => ArgumentSemantics::RegisterPointer,
+ ast::StateSpace::Reg => new_todo!(),
+ ast::StateSpace::Sreg => new_todo!(),
};
- let ptr_type = ast::Type::Pointer(self.underlying_type.clone(), self.state_space);
+ let ptr_type = ast::Type::Pointer(self.underlying_type.clone());
let new_dst = visitor.id(
ArgumentDescriptor {
op: self.dst,
is_dst: true,
sema,
},
- Some(&ptr_type),
+ Some((&ptr_type, self.state_space)),
)?;
let new_ptr_src = visitor.id(
ArgumentDescriptor {
@@ -5578,7 +5683,7 @@ impl<P: ArgParamsEx<Id = spirv::Word>> PtrAccess<P> {
is_dst: false,
sema,
},
- Some(&ptr_type),
+ Some((&ptr_type, self.state_space)),
)?;
let new_constant_src = visitor.operand(
ArgumentDescriptor {
@@ -5587,6 +5692,7 @@ impl<P: ArgParamsEx<Id = spirv::Word>> PtrAccess<P> {
sema: ArgumentSemantics::Default,
},
&ast::Type::Scalar(ast::ScalarType::S64),
+ self.state_space,
)?;
Ok(PtrAccess {
underlying_type: self.underlying_type,
@@ -5723,12 +5829,13 @@ pub trait ArgumentMapVisitor<T: ArgParamsEx, U: ArgParamsEx> {
fn id(
&mut self,
desc: ArgumentDescriptor<T::Id>,
- typ: Option<&ast::Type>,
+ typ: Option<(&ast::Type, ast::StateSpace)>,
) -> Result<U::Id, TranslateError>;
fn operand(
&mut self,
desc: ArgumentDescriptor<T::Operand>,
typ: &ast::Type,
+ state_space: ast::StateSpace,
) -> Result<U::Operand, TranslateError>;
}
@@ -5736,13 +5843,13 @@ impl<T> ArgumentMapVisitor<ExpandedArgParams, ExpandedArgParams> for T
where
T: FnMut(
ArgumentDescriptor<spirv::Word>,
- Option<&ast::Type>,
+ Option<(&ast::Type, ast::StateSpace)>,
) -> Result<spirv::Word, TranslateError>,
{
fn id(
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
- t: Option<&ast::Type>,
+ t: Option<(&ast::Type, ast::StateSpace)>,
) -> Result<spirv::Word, TranslateError> {
self(desc, t)
}
@@ -5751,8 +5858,9 @@ where
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
typ: &ast::Type,
+ state_space: ast::StateSpace,
) -> Result<spirv::Word, TranslateError> {
- self(desc, Some(typ))
+ self(desc, Some((typ, state_space)))
}
}
@@ -5763,7 +5871,7 @@ where
fn id(
&mut self,
desc: ArgumentDescriptor<&str>,
- _: Option<&ast::Type>,
+ _: Option<(&ast::Type, ast::StateSpace)>,
) -> Result<spirv::Word, TranslateError> {
self(desc.op)
}
@@ -5772,6 +5880,7 @@ where
&mut self,
desc: ArgumentDescriptor<ast::Operand<&str>>,
typ: &ast::Type,
+ state_space: ast::StateSpace,
) -> Result<ast::Operand<spirv::Word>, TranslateError> {
Ok(match desc.op {
ast::Operand::Reg(id) => ast::Operand::Reg(self(id)?),
@@ -5780,7 +5889,7 @@ where
ast::Operand::VecMember(id, member) => ast::Operand::VecMember(self(id)?, member),
ast::Operand::VecPack(ref ids) => ast::Operand::VecPack(
ids.into_iter()
- .map(|id| self.id(desc.new_op(id), Some(typ)))
+ .map(|id| self.id(desc.new_op(id), Some((typ, state_space))))
.collect::<Result<Vec<_>, _>>()?,
),
})
@@ -5794,8 +5903,8 @@ pub struct ArgumentDescriptor<Op> {
}
pub struct PtrAccess<P: ast::ArgParams> {
- underlying_type: ast::PointerType,
- state_space: ast::LdStateSpace,
+ underlying_type: ast::ScalarType,
+ state_space: ast::StateSpace,
dst: spirv::Word,
ptr_src: spirv::Word,
offset_src: P::Operand,
@@ -6061,7 +6170,7 @@ impl ImplicitConversion {
is_dst: true,
sema: self.dst_sema,
},
- Some(&self.to),
+ Some((&self.to_type, self.to_space)),
)?;
let new_src = visitor.id(
ArgumentDescriptor {
@@ -6069,7 +6178,7 @@ impl ImplicitConversion {
is_dst: false,
sema: self.src_sema,
},
- Some(&self.from),
+ Some((&self.from_type, self.from_space)),
)?;
Ok(Statement::Conversion({
ImplicitConversion {
@@ -6096,13 +6205,13 @@ impl<T> ArgumentMapVisitor<TypedArgParams, TypedArgParams> for T
where
T: FnMut(
ArgumentDescriptor<spirv::Word>,
- Option<&ast::Type>,
+ Option<(&ast::Type, ast::StateSpace)>,
) -> Result<spirv::Word, TranslateError>,
{
fn id(
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
- t: Option<&ast::Type>,
+ t: Option<(&ast::Type, ast::StateSpace)>,
) -> Result<spirv::Word, TranslateError> {
self(desc, t)
}
@@ -6111,12 +6220,15 @@ where
&mut self,
desc: ArgumentDescriptor<TypedOperand>,
typ: &ast::Type,
+ state_space: ast::StateSpace,
) -> Result<TypedOperand, TranslateError> {
Ok(match desc.op {
- TypedOperand::Reg(id) => TypedOperand::Reg(self(desc.new_op(id), Some(typ))?),
+ TypedOperand::Reg(id) => {
+ TypedOperand::Reg(self(desc.new_op(id), Some((typ, state_space)))?)
+ }
TypedOperand::Imm(imm) => TypedOperand::Imm(imm),
TypedOperand::RegOffset(id, imm) => {
- TypedOperand::RegOffset(self(desc.new_op(id), Some(typ))?, imm)
+ TypedOperand::RegOffset(self(desc.new_op(id), Some((typ, state_space)))?, imm)
}
TypedOperand::VecMember(reg, index) => {
let scalar_type = match typ {
@@ -6124,7 +6236,10 @@ where
_ => return Err(error_unreachable()),
};
let vec_type = ast::Type::Vector(scalar_type, index + 1);
- TypedOperand::VecMember(self(desc.new_op(reg), Some(&vec_type))?, index)
+ TypedOperand::VecMember(
+ self(desc.new_op(reg), Some((&vec_type, state_space)))?,
+ index,
+ )
}
})
}
@@ -6159,54 +6274,25 @@ impl ast::Type {
scalar_kind: scalar.kind(),
width: scalar.size_of(),
components: Vec::new(),
- state_space: ast::LdStateSpace::Global,
},
ast::Type::Vector(scalar, components) => TypeParts {
kind: TypeKind::Vector,
scalar_kind: scalar.kind(),
width: scalar.size_of(),
components: vec![*components as u32],
- state_space: ast::LdStateSpace::Global,
},
ast::Type::Array(scalar, components) => TypeParts {
kind: TypeKind::Array,
scalar_kind: scalar.kind(),
width: scalar.size_of(),
components: components.clone(),
- state_space: ast::LdStateSpace::Global,
},
- ast::Type::Pointer(ast::PointerType::Scalar(scalar), state_space) => TypeParts {
+ ast::Type::Pointer(scalar) => TypeParts {
kind: TypeKind::PointerScalar,
scalar_kind: scalar.kind(),
width: scalar.size_of(),
components: Vec::new(),
- state_space: *state_space,
- },
- ast::Type::Pointer(ast::PointerType::Vector(scalar, len), state_space) => TypeParts {
- kind: TypeKind::PointerVector,
- scalar_kind: scalar.kind(),
- width: scalar.size_of(),
- components: vec![*len as u32],
- state_space: *state_space,
},
- ast::Type::Pointer(ast::PointerType::Array(scalar, components), state_space) => {
- TypeParts {
- kind: TypeKind::PointerArray,
- scalar_kind: scalar.kind(),
- width: scalar.size_of(),
- components: components.clone(),
- state_space: *state_space,
- }
- }
- ast::Type::Pointer(ast::PointerType::Pointer(scalar, inner_space), state_space) => {
- TypeParts {
- kind: TypeKind::PointerPointer,
- scalar_kind: scalar.kind(),
- width: scalar.size_of(),
- components: vec![*inner_space as u32],
- state_space: *state_space,
- }
- }
}
}
@@ -6223,31 +6309,9 @@ impl ast::Type {
ast::ScalarType::from_parts(t.width, t.scalar_kind),
t.components,
),
- TypeKind::PointerScalar => ast::Type::Pointer(
- ast::PointerType::Scalar(ast::ScalarType::from_parts(t.width, t.scalar_kind)),
- t.state_space,
- ),
- TypeKind::PointerVector => ast::Type::Pointer(
- ast::PointerType::Vector(
- ast::ScalarType::from_parts(t.width, t.scalar_kind),
- t.components[0] as u8,
- ),
- t.state_space,
- ),
- TypeKind::PointerArray => ast::Type::Pointer(
- ast::PointerType::Array(
- ast::ScalarType::from_parts(t.width, t.scalar_kind),
- t.components,
- ),
- t.state_space,
- ),
- TypeKind::PointerPointer => ast::Type::Pointer(
- ast::PointerType::Pointer(
- ast::ScalarType::from_parts(t.width, t.scalar_kind),
- unsafe { mem::transmute::<_, ast::LdStateSpace>(t.components[0] as u8) },
- ),
- t.state_space,
- ),
+ TypeKind::PointerScalar => {
+ ast::Type::Pointer(ast::ScalarType::from_parts(t.width, t.scalar_kind))
+ }
}
}
@@ -6258,7 +6322,7 @@ impl ast::Type {
ast::Type::Array(typ, len) => len
.iter()
.fold(typ.size_of() as usize, |x, y| (x as usize) * (*y as usize)),
- ast::Type::Pointer(_, _) => mem::size_of::<usize>(),
+ ast::Type::Pointer(..) => mem::size_of::<usize>(),
}
}
}
@@ -6269,7 +6333,6 @@ struct TypeParts {
scalar_kind: ast::ScalarKind,
width: u8,
components: Vec<u32>,
- state_space: ast::LdStateSpace,
}
#[derive(Eq, PartialEq, Copy, Clone)]
@@ -6278,9 +6341,6 @@ enum TypeKind {
Vector,
Array,
PointerScalar,
- PointerVector,
- PointerArray,
- PointerPointer,
}
impl ast::Instruction<ExpandedArgParams> {
@@ -6408,8 +6468,10 @@ struct BrachCondition {
struct ImplicitConversion {
src: spirv::Word,
dst: spirv::Word,
- from: ast::Type,
- to: ast::Type,
+ from_type: ast::Type,
+ to_type: ast::Type,
+ from_space: ast::StateSpace,
+ to_space: ast::StateSpace,
kind: ConversionKind,
src_sema: ArgumentSemantics,
dst_sema: ArgumentSemantics,
@@ -6420,7 +6482,7 @@ enum ConversionKind {
Default,
// zero-extend/chop/bitcast depending on types
SignExtend,
- BitToPtr(ast::LdStateSpace),
+ BitToPtr,
PtrToBit(ast::ScalarType),
PtrToPtr { spirv_ptr: bool },
}
@@ -6470,7 +6532,7 @@ impl<T: ArgParamsEx> ast::Arg1<T> {
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
- t: Option<&ast::Type>,
+ t: Option<(&ast::Type, ast::StateSpace)>,
) -> Result<ast::Arg1<U>, TranslateError> {
let new_src = visitor.id(
ArgumentDescriptor {
@@ -6496,6 +6558,7 @@ impl<T: ArgParamsEx> ast::Arg1Bar<T> {
sema: ArgumentSemantics::Default,
},
&ast::Type::Scalar(ast::ScalarType::U32),
+ ast::StateSpace::Reg,
)?;
Ok(ast::Arg1Bar { src: new_src })
}
@@ -6514,6 +6577,7 @@ impl<T: ArgParamsEx> ast::Arg2<T> {
sema: ArgumentSemantics::Default,
},
t,
+ ast::StateSpace::Reg,
)?;
let new_src = visitor.operand(
ArgumentDescriptor {
@@ -6522,6 +6586,7 @@ impl<T: ArgParamsEx> ast::Arg2<T> {
sema: ArgumentSemantics::Default,
},
t,
+ ast::StateSpace::Reg,
)?;
Ok(ast::Arg2 {
dst: new_dst,
@@ -6542,6 +6607,7 @@ impl<T: ArgParamsEx> ast::Arg2<T> {
sema: ArgumentSemantics::Default,
},
dst_t,
+ ast::StateSpace::Reg,
)?;
let src = visitor.operand(
ArgumentDescriptor {
@@ -6550,6 +6616,7 @@ impl<T: ArgParamsEx> ast::Arg2<T> {
sema: ArgumentSemantics::Default,
},
src_t,
+ ast::StateSpace::Reg,
)?;
Ok(ast::Arg2 { dst, src })
}
@@ -6568,9 +6635,10 @@ impl<T: ArgParamsEx> ast::Arg2Ld<T> {
sema: ArgumentSemantics::DefaultRelaxed,
},
&ast::Type::from(details.typ.clone()),
+ ast::StateSpace::Reg,
)?;
- let is_logical_ptr = details.state_space == ast::LdStateSpace::Param
- || details.state_space == ast::LdStateSpace::Local;
+ let is_logical_ptr = details.state_space == ast::StateSpace::Param
+ || details.state_space == ast::StateSpace::Local;
let src = visitor.operand(
ArgumentDescriptor {
op: self.src,
@@ -6581,10 +6649,8 @@ impl<T: ArgParamsEx> ast::Arg2Ld<T> {
ArgumentSemantics::PhysicalPointer
},
},
- &ast::Type::Pointer(
- ast::PointerType::from(details.typ.clone()),
- details.state_space,
- ),
+ &details.typ,
+ details.state_space,
)?;
Ok(ast::Arg2Ld { dst, src })
}
@@ -6596,8 +6662,8 @@ impl<T: ArgParamsEx> ast::Arg2St<T> {
visitor: &mut V,
details: &ast::StData,
) -> Result<ast::Arg2St<U>, TranslateError> {
- let is_logical_ptr = details.state_space == ast::StStateSpace::Param
- || details.state_space == ast::StStateSpace::Local;
+ let is_logical_ptr = details.state_space == ast::StateSpace::Param
+ || details.state_space == ast::StateSpace::Local;
let src1 = visitor.operand(
ArgumentDescriptor {
op: self.src1,
@@ -6608,10 +6674,8 @@ impl<T: ArgParamsEx> ast::Arg2St<T> {
ArgumentSemantics::PhysicalPointer
},
},
- &ast::Type::Pointer(
- ast::PointerType::from(details.typ.clone()),
- details.state_space.to_ld_ss(),
- ),
+ &details.typ,
+ details.state_space,
)?;
let src2 = visitor.operand(
ArgumentDescriptor {
@@ -6620,6 +6684,7 @@ impl<T: ArgParamsEx> ast::Arg2St<T> {
sema: ArgumentSemantics::DefaultRelaxed,
},
&details.typ.clone().into(),
+ ast::StateSpace::Reg,
)?;
Ok(ast::Arg2St { src1, src2 })
}
@@ -6638,6 +6703,7 @@ impl<T: ArgParamsEx> ast::Arg2Mov<T> {
sema: ArgumentSemantics::Default,
},
&details.typ.clone().into(),
+ ast::StateSpace::Reg,
)?;
let src = visitor.operand(
ArgumentDescriptor {
@@ -6650,6 +6716,7 @@ impl<T: ArgParamsEx> ast::Arg2Mov<T> {
},
},
&details.typ.clone().into(),
+ ast::StateSpace::Reg,
)?;
Ok(ast::Arg2Mov { dst, src })
}
@@ -6674,6 +6741,7 @@ impl<T: ArgParamsEx> ast::Arg3<T> {
sema: ArgumentSemantics::Default,
},
wide_type.as_ref().unwrap_or(typ),
+ ast::StateSpace::Reg,
)?;
let src1 = visitor.operand(
ArgumentDescriptor {
@@ -6682,6 +6750,7 @@ impl<T: ArgParamsEx> ast::Arg3<T> {
sema: ArgumentSemantics::Default,
},
typ,
+ ast::StateSpace::Reg,
)?;
let src2 = visitor.operand(
ArgumentDescriptor {
@@ -6690,6 +6759,7 @@ impl<T: ArgParamsEx> ast::Arg3<T> {
sema: ArgumentSemantics::Default,
},
typ,
+ ast::StateSpace::Reg,
)?;
Ok(ast::Arg3 { dst, src1, src2 })
}
@@ -6706,6 +6776,7 @@ impl<T: ArgParamsEx> ast::Arg3<T> {
sema: ArgumentSemantics::Default,
},
t,
+ ast::StateSpace::Reg,
)?;
let src1 = visitor.operand(
ArgumentDescriptor {
@@ -6714,6 +6785,7 @@ impl<T: ArgParamsEx> ast::Arg3<T> {
sema: ArgumentSemantics::Default,
},
t,
+ ast::StateSpace::Reg,
)?;
let src2 = visitor.operand(
ArgumentDescriptor {
@@ -6722,6 +6794,7 @@ impl<T: ArgParamsEx> ast::Arg3<T> {
sema: ArgumentSemantics::Default,
},
&ast::Type::Scalar(ast::ScalarType::U32),
+ ast::StateSpace::Reg,
)?;
Ok(ast::Arg3 { dst, src1, src2 })
}
@@ -6730,7 +6803,7 @@ impl<T: ArgParamsEx> ast::Arg3<T> {
self,
visitor: &mut V,
t: ast::ScalarType,
- state_space: ast::AtomSpace,
+ state_space: ast::StateSpace,
) -> Result<ast::Arg3<U>, TranslateError> {
let scalar_type = ast::ScalarType::from(t);
let dst = visitor.operand(
@@ -6740,6 +6813,7 @@ impl<T: ArgParamsEx> ast::Arg3<T> {
sema: ArgumentSemantics::Default,
},
&ast::Type::Scalar(scalar_type),
+ ast::StateSpace::Reg,
)?;
let src1 = visitor.operand(
ArgumentDescriptor {
@@ -6747,10 +6821,8 @@ impl<T: ArgParamsEx> ast::Arg3<T> {
is_dst: false,
sema: ArgumentSemantics::PhysicalPointer,
},
- &ast::Type::Pointer(
- ast::PointerType::Scalar(scalar_type),
- state_space.to_ld_ss(),
- ),
+ &ast::Type::Scalar(scalar_type),
+ state_space,
)?;
let src2 = visitor.operand(
ArgumentDescriptor {
@@ -6759,6 +6831,7 @@ impl<T: ArgParamsEx> ast::Arg3<T> {
sema: ArgumentSemantics::Default,
},
&ast::Type::Scalar(scalar_type),
+ ast::StateSpace::Reg,
)?;
Ok(ast::Arg3 { dst, src1, src2 })
}
@@ -6783,6 +6856,7 @@ impl<T: ArgParamsEx> ast::Arg4<T> {
sema: ArgumentSemantics::Default,
},
wide_type.as_ref().unwrap_or(t),
+ ast::StateSpace::Reg,
)?;
let src1 = visitor.operand(
ArgumentDescriptor {
@@ -6791,6 +6865,7 @@ impl<T: ArgParamsEx> ast::Arg4<T> {
sema: ArgumentSemantics::Default,
},
t,
+ ast::StateSpace::Reg,
)?;
let src2 = visitor.operand(
ArgumentDescriptor {
@@ -6799,6 +6874,7 @@ impl<T: ArgParamsEx> ast::Arg4<T> {
sema: ArgumentSemantics::Default,
},
t,
+ ast::StateSpace::Reg,
)?;
let src3 = visitor.operand(
ArgumentDescriptor {
@@ -6807,6 +6883,7 @@ impl<T: ArgParamsEx> ast::Arg4<T> {
sema: ArgumentSemantics::Default,
},
t,
+ ast::StateSpace::Reg,
)?;
Ok(ast::Arg4 {
dst,
@@ -6828,6 +6905,7 @@ impl<T: ArgParamsEx> ast::Arg4<T> {
sema: ArgumentSemantics::Default,
},
&ast::Type::Scalar(t.into()),
+ ast::StateSpace::Reg,
)?;
let src1 = visitor.operand(
ArgumentDescriptor {
@@ -6836,6 +6914,7 @@ impl<T: ArgParamsEx> ast::Arg4<T> {
sema: ArgumentSemantics::Default,
},
&ast::Type::Scalar(t.into()),
+ ast::StateSpace::Reg,
)?;
let src2 = visitor.operand(
ArgumentDescriptor {
@@ -6844,6 +6923,7 @@ impl<T: ArgParamsEx> ast::Arg4<T> {
sema: ArgumentSemantics::Default,
},
&ast::Type::Scalar(t.into()),
+ ast::StateSpace::Reg,
)?;
let src3 = visitor.operand(
ArgumentDescriptor {
@@ -6852,6 +6932,7 @@ impl<T: ArgParamsEx> ast::Arg4<T> {
sema: ArgumentSemantics::Default,
},
&ast::Type::Scalar(ast::ScalarType::Pred),
+ ast::StateSpace::Reg,
)?;
Ok(ast::Arg4 {
dst,
@@ -6865,7 +6946,7 @@ impl<T: ArgParamsEx> ast::Arg4<T> {
self,
visitor: &mut V,
t: ast::ScalarType,
- state_space: ast::AtomSpace,
+ state_space: ast::StateSpace,
) -> Result<ast::Arg4<U>, TranslateError> {
let scalar_type = ast::ScalarType::from(t);
let dst = visitor.operand(
@@ -6875,6 +6956,7 @@ impl<T: ArgParamsEx> ast::Arg4<T> {
sema: ArgumentSemantics::Default,
},
&ast::Type::Scalar(scalar_type),
+ ast::StateSpace::Reg,
)?;
let src1 = visitor.operand(
ArgumentDescriptor {
@@ -6882,10 +6964,8 @@ impl<T: ArgParamsEx> ast::Arg4<T> {
is_dst: false,
sema: ArgumentSemantics::PhysicalPointer,
},
- &ast::Type::Pointer(
- ast::PointerType::Scalar(scalar_type),
- state_space.to_ld_ss(),
- ),
+ &ast::Type::Scalar(scalar_type),
+ ast::StateSpace::Reg,
)?;
let src2 = visitor.operand(
ArgumentDescriptor {
@@ -6894,6 +6974,7 @@ impl<T: ArgParamsEx> ast::Arg4<T> {
sema: ArgumentSemantics::Default,
},
&ast::Type::Scalar(scalar_type),
+ ast::StateSpace::Reg,
)?;
let src3 = visitor.operand(
ArgumentDescriptor {
@@ -6902,6 +6983,7 @@ impl<T: ArgParamsEx> ast::Arg4<T> {
sema: ArgumentSemantics::Default,
},
&ast::Type::Scalar(scalar_type),
+ ast::StateSpace::Reg,
)?;
Ok(ast::Arg4 {
dst,
@@ -6923,6 +7005,7 @@ impl<T: ArgParamsEx> ast::Arg4<T> {
sema: ArgumentSemantics::Default,
},
typ,
+ ast::StateSpace::Reg,
)?;
let src1 = visitor.operand(
ArgumentDescriptor {
@@ -6931,6 +7014,7 @@ impl<T: ArgParamsEx> ast::Arg4<T> {
sema: ArgumentSemantics::Default,
},
typ,
+ ast::StateSpace::Reg,
)?;
let u32_type = ast::Type::Scalar(ast::ScalarType::U32);
let src2 = visitor.operand(
@@ -6940,6 +7024,7 @@ impl<T: ArgParamsEx> ast::Arg4<T> {
sema: ArgumentSemantics::Default,
},
&u32_type,
+ ast::StateSpace::Reg,
)?;
let src3 = visitor.operand(
ArgumentDescriptor {
@@ -6948,6 +7033,7 @@ impl<T: ArgParamsEx> ast::Arg4<T> {
sema: ArgumentSemantics::Default,
},
&u32_type,
+ ast::StateSpace::Reg,
)?;
Ok(ast::Arg4 {
dst,
@@ -6970,7 +7056,10 @@ impl<T: ArgParamsEx> ast::Arg4Setp<T> {
is_dst: true,
sema: ArgumentSemantics::Default,
},
- Some(&ast::Type::Scalar(ast::ScalarType::Pred)),
+ Some((
+ &ast::Type::Scalar(ast::ScalarType::Pred),
+ ast::StateSpace::Reg,
+ )),
)?;
let dst2 = self
.dst2
@@ -6981,7 +7070,10 @@ impl<T: ArgParamsEx> ast::Arg4Setp<T> {
is_dst: true,
sema: ArgumentSemantics::Default,
},
- Some(&ast::Type::Scalar(ast::ScalarType::Pred)),
+ Some((
+ &ast::Type::Scalar(ast::ScalarType::Pred),
+ ast::StateSpace::Reg,
+ )),
)
})
.transpose()?;
@@ -6992,6 +7084,7 @@ impl<T: ArgParamsEx> ast::Arg4Setp<T> {
sema: ArgumentSemantics::Default,
},
t,
+ ast::StateSpace::Reg,
)?;
let src2 = visitor.operand(
ArgumentDescriptor {
@@ -7000,6 +7093,7 @@ impl<T: ArgParamsEx> ast::Arg4Setp<T> {
sema: ArgumentSemantics::Default,
},
t,
+ ast::StateSpace::Reg,
)?;
Ok(ast::Arg4Setp {
dst1,
@@ -7023,6 +7117,7 @@ impl<T: ArgParamsEx> ast::Arg5<T> {
sema: ArgumentSemantics::Default,
},
base_type,
+ ast::StateSpace::Reg,
)?;
let src1 = visitor.operand(
ArgumentDescriptor {
@@ -7031,6 +7126,7 @@ impl<T: ArgParamsEx> ast::Arg5<T> {
sema: ArgumentSemantics::Default,
},
base_type,
+ ast::StateSpace::Reg,
)?;
let src2 = visitor.operand(
ArgumentDescriptor {
@@ -7039,6 +7135,7 @@ impl<T: ArgParamsEx> ast::Arg5<T> {
sema: ArgumentSemantics::Default,
},
base_type,
+ ast::StateSpace::Reg,
)?;
let src3 = visitor.operand(
ArgumentDescriptor {
@@ -7047,6 +7144,7 @@ impl<T: ArgParamsEx> ast::Arg5<T> {
sema: ArgumentSemantics::Default,
},
&ast::Type::Scalar(ast::ScalarType::U32),
+ ast::StateSpace::Reg,
)?;
let src4 = visitor.operand(
ArgumentDescriptor {
@@ -7055,6 +7153,7 @@ impl<T: ArgParamsEx> ast::Arg5<T> {
sema: ArgumentSemantics::Default,
},
&ast::Type::Scalar(ast::ScalarType::U32),
+ ast::StateSpace::Reg,
)?;
Ok(ast::Arg5 {
dst,
@@ -7078,7 +7177,10 @@ impl<T: ArgParamsEx> ast::Arg5Setp<T> {
is_dst: true,
sema: ArgumentSemantics::Default,
},
- Some(&ast::Type::Scalar(ast::ScalarType::Pred)),
+ Some((
+ &ast::Type::Scalar(ast::ScalarType::Pred),
+ ast::StateSpace::Reg,
+ )),
)?;
let dst2 = self
.dst2
@@ -7089,7 +7191,10 @@ impl<T: ArgParamsEx> ast::Arg5Setp<T> {
is_dst: true,
sema: ArgumentSemantics::Default,
},
- Some(&ast::Type::Scalar(ast::ScalarType::Pred)),
+ Some((
+ &ast::Type::Scalar(ast::ScalarType::Pred),
+ ast::StateSpace::Reg,
+ )),
)
})
.transpose()?;
@@ -7100,6 +7205,7 @@ impl<T: ArgParamsEx> ast::Arg5Setp<T> {
sema: ArgumentSemantics::Default,
},
t,
+ ast::StateSpace::Reg,
)?;
let src2 = visitor.operand(
ArgumentDescriptor {
@@ -7108,6 +7214,7 @@ impl<T: ArgParamsEx> ast::Arg5Setp<T> {
sema: ArgumentSemantics::Default,
},
t,
+ ast::StateSpace::Reg,
)?;
let src3 = visitor.operand(
ArgumentDescriptor {
@@ -7116,6 +7223,7 @@ impl<T: ArgParamsEx> ast::Arg5Setp<T> {
sema: ArgumentSemantics::Default,
},
&ast::Type::Scalar(ast::ScalarType::Pred),
+ ast::StateSpace::Reg,
)?;
Ok(ast::Arg5Setp {
dst1,
@@ -7153,18 +7261,6 @@ impl ast::Operand<spirv::Word> {
}
}
-impl ast::StStateSpace {
- fn to_ld_ss(self) -> ast::LdStateSpace {
- match self {
- ast::StStateSpace::Generic => ast::LdStateSpace::Generic,
- ast::StStateSpace::Global => ast::LdStateSpace::Global,
- ast::StStateSpace::Local => ast::LdStateSpace::Local,
- ast::StStateSpace::Param => ast::LdStateSpace::Param,
- ast::StStateSpace::Shared => ast::LdStateSpace::Shared,
- }
- }
-}
-
impl ast::ScalarType {
fn from_parts(width: u8, kind: ast::ScalarKind) -> Self {
match kind {
@@ -7255,15 +7351,17 @@ impl ast::AtomInnerDetails {
}
}
-impl ast::LdStateSpace {
+impl ast::StateSpace {
fn to_spirv(self) -> spirv::StorageClass {
match self {
- ast::LdStateSpace::Const => spirv::StorageClass::UniformConstant,
- ast::LdStateSpace::Generic => spirv::StorageClass::Generic,
- ast::LdStateSpace::Global => spirv::StorageClass::CrossWorkgroup,
- ast::LdStateSpace::Local => spirv::StorageClass::Function,
- ast::LdStateSpace::Shared => spirv::StorageClass::Workgroup,
- ast::LdStateSpace::Param => spirv::StorageClass::Function,
+ ast::StateSpace::Const => spirv::StorageClass::UniformConstant,
+ ast::StateSpace::Generic => spirv::StorageClass::Generic,
+ ast::StateSpace::Global => spirv::StorageClass::CrossWorkgroup,
+ ast::StateSpace::Local => spirv::StorageClass::Function,
+ ast::StateSpace::Shared => spirv::StorageClass::Workgroup,
+ ast::StateSpace::Param => spirv::StorageClass::Function,
+ ast::StateSpace::Reg => spirv::StorageClass::Function,
+ ast::StateSpace::Sreg => spirv::StorageClass::Input,
}
}
}
@@ -7289,16 +7387,6 @@ impl ast::MulDetails {
}
}
-impl ast::AtomSpace {
- fn to_ld_ss(self) -> ast::LdStateSpace {
- match self {
- ast::AtomSpace::Generic => ast::LdStateSpace::Generic,
- ast::AtomSpace::Global => ast::LdStateSpace::Global,
- ast::AtomSpace::Shared => ast::LdStateSpace::Shared,
- }
- }
-}
-
impl ast::MemScope {
fn to_spirv(self) -> spirv::Scope {
match self {
@@ -7333,89 +7421,44 @@ impl ast::StateSpace {
fn bitcast_register_pointer(
operand_type: &ast::Type,
+ operand_space: ast::StateSpace,
instr_type: &ast::Type,
- ss: Option<ast::LdStateSpace>,
+ instruction_space: ast::StateSpace,
) -> Result<Option<ConversionKind>, TranslateError> {
- bitcast_physical_pointer(operand_type, instr_type, ss)
+ bitcast_physical_pointer(operand_type, operand_space, instr_type, instruction_space)
}
fn bitcast_physical_pointer(
operand_type: &ast::Type,
- instr_type: &ast::Type,
- ss: Option<ast::LdStateSpace>,
+ operand_space: ast::StateSpace,
+ instruction_type: &ast::Type,
+ instruction_space: ast::StateSpace,
) -> Result<Option<ConversionKind>, TranslateError> {
- match operand_type {
- // array decays to a pointer
- ast::Type::Array(op_scalar_t, _) => {
- if let ast::Type::Pointer(instr_scalar_t, instr_space) = instr_type {
- if ss == Some(*instr_space) {
- if ast::Type::Scalar(*op_scalar_t) == ast::Type::from(instr_scalar_t.clone()) {
- Ok(None)
- } else {
- Ok(Some(ConversionKind::PtrToPtr { spirv_ptr: false }))
- }
- } else {
- if ss == Some(ast::LdStateSpace::Generic)
- || *instr_space == ast::LdStateSpace::Generic
- {
- Ok(Some(ConversionKind::PtrToPtr { spirv_ptr: false }))
- } else {
- Err(TranslateError::MismatchedType)
- }
- }
- } else {
- Err(TranslateError::MismatchedType)
- }
- }
- ast::Type::Scalar(ast::ScalarType::B64)
- | ast::Type::Scalar(ast::ScalarType::U64)
- | ast::Type::Scalar(ast::ScalarType::S64) => {
- if let Some(space) = ss {
- Ok(Some(ConversionKind::BitToPtr(space)))
- } else {
- Err(error_unreachable())
- }
+ if operand_space == instruction_space {
+ if operand_type != instruction_type {
+ Ok(Some(ConversionKind::PtrToPtr { spirv_ptr: false }))
+ } else {
+ Ok(None)
}
- ast::Type::Scalar(ast::ScalarType::B32)
- | ast::Type::Scalar(ast::ScalarType::U32)
- | ast::Type::Scalar(ast::ScalarType::S32) => match ss {
- Some(ast::LdStateSpace::Shared)
- | Some(ast::LdStateSpace::Generic)
- | Some(ast::LdStateSpace::Param)
- | Some(ast::LdStateSpace::Local) => {
- Ok(Some(ConversionKind::BitToPtr(ast::LdStateSpace::Shared)))
- }
+ } else {
+ match operand_space {
+ ast::StateSpace::Reg | ast::StateSpace::Sreg => match instruction_space {
+ ast::StateSpace::Generic
+ | ast::StateSpace::Global
+ | ast::StateSpace::Shared
+ | ast::StateSpace::Local => Ok(Some(ConversionKind::BitToPtr)),
+ _ => Err(TranslateError::MismatchedType),
+ },
_ => Err(TranslateError::MismatchedType),
- },
- ast::Type::Pointer(op_scalar_t, op_space) => {
- if let ast::Type::Pointer(instr_scalar_t, instr_space) = instr_type {
- if op_space == instr_space {
- if op_scalar_t == instr_scalar_t {
- Ok(None)
- } else {
- Ok(Some(ConversionKind::PtrToPtr { spirv_ptr: false }))
- }
- } else {
- if *op_space == ast::LdStateSpace::Generic
- || *instr_space == ast::LdStateSpace::Generic
- {
- Ok(Some(ConversionKind::PtrToPtr { spirv_ptr: false }))
- } else {
- Err(TranslateError::MismatchedType)
- }
- }
- } else {
- Err(TranslateError::MismatchedType)
- }
}
- _ => Err(TranslateError::MismatchedType),
}
}
fn force_bitcast_ptr_to_bit(
_: &ast::Type,
+ _: ast::StateSpace,
instr_type: &ast::Type,
- _: Option<ast::LdStateSpace>,
+ _: ast::StateSpace,
) -> Result<Option<ConversionKind>, TranslateError> {
// TODO: verify this on f32, u16 and the like
if let ast::Type::Scalar(scalar_t) = instr_type {
@@ -7457,11 +7500,12 @@ fn should_bitcast(instr: &ast::Type, operand: &ast::Type) -> bool {
fn should_bitcast_packed(
operand: &ast::Type,
- instr: &ast::Type,
- ss: Option<ast::LdStateSpace>,
+ operand_space: ast::StateSpace,
+ instruction: &ast::Type,
+ instruction_space: ast::StateSpace,
) -> Result<Option<ConversionKind>, TranslateError> {
if let (ast::Type::Vector(vec_underlying_type, vec_len), ast::Type::Scalar(scalar)) =
- (operand, instr)
+ (operand, instruction)
{
if scalar.kind() == ast::ScalarKind::Bit
&& scalar.size_of() == (vec_underlying_type.size_of() * vec_len)
@@ -7469,13 +7513,14 @@ fn should_bitcast_packed(
return Ok(Some(ConversionKind::Default));
}
}
- should_bitcast_wrapper(operand, instr, ss)
+ should_bitcast_wrapper(operand, operand_space, instruction, instruction_space)
}
fn should_bitcast_wrapper(
operand: &ast::Type,
+ _: ast::StateSpace,
instr: &ast::Type,
- _: Option<ast::LdStateSpace>,
+ _: ast::StateSpace,
) -> Result<Option<ConversionKind>, TranslateError> {
if instr == operand {
return Ok(None);
@@ -7489,8 +7534,9 @@ fn should_bitcast_wrapper(
fn should_convert_relaxed_src_wrapper(
src_type: &ast::Type,
+ _: ast::StateSpace,
instr_type: &ast::Type,
- _: Option<ast::LdStateSpace>,
+ _: ast::StateSpace,
) -> Result<Option<ConversionKind>, TranslateError> {
if src_type == instr_type {
return Ok(None);
@@ -7552,8 +7598,9 @@ fn should_convert_relaxed_src(
fn should_convert_relaxed_dst_wrapper(
dst_type: &ast::Type,
+ _: ast::StateSpace,
instr_type: &ast::Type,
- _: Option<ast::LdStateSpace>,
+ _: ast::StateSpace,
) -> Result<Option<ConversionKind>, TranslateError> {
if dst_type == instr_type {
return Ok(None);