aboutsummaryrefslogtreecommitdiffhomepage
path: root/ptx
diff options
context:
space:
mode:
authorAndrzej Janik <[email protected]>2020-10-18 14:46:05 +0200
committerAndrzej Janik <[email protected]>2020-10-18 14:46:05 +0200
commit2b3ecc99e3b2a1c0a1989733da17b359e974951c (patch)
tree31ae774ab35c37e61063f8ea30488966c148feb3 /ptx
parent27d25865af2bf51ca55b223e634208234d1a141a (diff)
downloadZLUDA-2b3ecc99e3b2a1c0a1989733da17b359e974951c.tar.gz
ZLUDA-2b3ecc99e3b2a1c0a1989733da17b359e974951c.zip
Implement pass to handle .extern .shared and add parsing code for it
Diffstat (limited to 'ptx')
-rw-r--r--ptx/Cargo.toml1
-rw-r--r--ptx/src/ast.rs84
-rw-r--r--ptx/src/lib.rs3
-rw-r--r--ptx/src/ptx.lalrpop270
-rw-r--r--ptx/src/test/spirv_build/global_extern_array.ptx5
-rw-r--r--ptx/src/test/spirv_build/param_func_array_0.ptx10
-rw-r--r--ptx/src/test/spirv_fail/const_ptr.ptx5
-rw-r--r--ptx/src/test/spirv_fail/global_ptr.ptx5
-rw-r--r--ptx/src/test/spirv_fail/local_ptr.txt12
-rw-r--r--ptx/src/test/spirv_fail/param_entry_array_0.ptx10
-rw-r--r--ptx/src/test/spirv_fail/param_vector.ptx10
-rw-r--r--ptx/src/test/spirv_fail/shared_ptr.ptx5
-rw-r--r--ptx/src/test/spirv_fail/shared_ptr2.ptx13
-rw-r--r--ptx/src/test/spirv_run/extern_shared.ptx24
-rw-r--r--ptx/src/test/spirv_run/extern_shared.spvtxt53
-rw-r--r--ptx/src/test/spirv_run/extern_shared_call.ptx45
-rw-r--r--ptx/src/test/spirv_run/extern_shared_call.spvtxt53
-rw-r--r--ptx/src/test/spirv_run/mod.rs1
-rw-r--r--ptx/src/translate.rs391
19 files changed, 877 insertions, 123 deletions
diff --git a/ptx/Cargo.toml b/ptx/Cargo.toml
index 96ab9d0..409cd1f 100644
--- a/ptx/Cargo.toml
+++ b/ptx/Cargo.toml
@@ -14,6 +14,7 @@ spirv_headers = "~1.4.2"
quick-error = "1.2"
bit-vec = "0.6"
half ="1.6"
+bitflags = "1.2"
[build-dependencies.lalrpop]
version = "0.19"
diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs
index c6510da..1e90eba 100644
--- a/ptx/src/ast.rs
+++ b/ptx/src/ast.rs
@@ -1,4 +1,5 @@
-use std::{convert::From, mem, num::ParseFloatError, str::FromStr};
+use std::convert::TryInto;
+use std::{convert::From, convert::TryFrom, mem, num::ParseFloatError, str::FromStr};
use std::{marker::PhantomData, num::ParseIntError};
use half::f16;
@@ -22,6 +23,8 @@ quick_error! {
WrongVectorElement {}
MultiArrayVariable {}
ZeroDimensionArray {}
+ ArrayInitalizer {}
+ NonExternPointer {}
}
}
@@ -78,6 +81,21 @@ macro_rules! sub_type {
}
}
}
+
+ impl std::convert::TryFrom<Type> for $type_name {
+ type Error = ();
+
+ #[allow(non_snake_case)]
+ #[allow(unreachable_patterns)]
+ fn try_from(t: Type) -> Result<Self, Self::Error> {
+ match t {
+ $(
+ Type::$variant ( $($field_type),+ ) => Ok($type_name::$variant ( $($field_type.try_into().map_err(|_| ())? ),+ )),
+ )+
+ _ => Err(()),
+ }
+ }
+ }
};
}
@@ -98,14 +116,39 @@ sub_type! {
}
}
+impl TryFrom<VariableGlobalType> for VariableLocalType {
+ type Error = PtxError;
+
+ fn try_from(value: VariableGlobalType) -> Result<Self, Self::Error> {
+ match value {
+ VariableGlobalType::Scalar(t) => Ok(VariableLocalType::Scalar(t)),
+ VariableGlobalType::Vector(t, len) => Ok(VariableLocalType::Vector(t, len)),
+ VariableGlobalType::Array(t, len) => Ok(VariableLocalType::Array(t, len)),
+ VariableGlobalType::Pointer(_, _) => Err(PtxError::ZeroDimensionArray),
+ }
+ }
+}
+
+sub_type! {
+ VariableGlobalType {
+ Scalar(SizedScalarType),
+ Vector(SizedScalarType, u8),
+ Array(SizedScalarType, VecU32),
+ Pointer(SizedScalarType, PointerStateSpace),
+ }
+}
+
// For some weird reson this is illegal:
// .param .f16x2 foobar;
// but this is legal:
// .param .f16x2 foobar[1];
+// even more interestingly this is legal, but only in .func (not in .entry):
+// .param .b32 foobar[]
sub_type! {
VariableParamType {
Scalar(ParamScalarType),
Array(SizedScalarType, VecU32),
+ Pointer(SizedScalarType, PointerStateSpace),
}
}
@@ -193,7 +236,7 @@ pub enum MethodDecl<'a, ID> {
}
pub type FnArgument<ID> = Variable<FnArgumentType, ID>;
-pub type KernelArgument<ID> = Variable<VariableParamType, ID>;
+pub type KernelArgument<ID> = Variable<KernelArgumentType, ID>;
pub struct Function<'a, ID, S> {
pub func_directive: MethodDecl<'a, ID>,
@@ -206,6 +249,12 @@ pub type ParsedFunction<'a> = Function<'a, &'a str, Statement<ParsedArgParams<'a
pub enum FnArgumentType {
Reg(VariableRegType),
Param(VariableParamType),
+ Shared,
+}
+#[derive(PartialEq, Eq, Clone)]
+pub enum KernelArgumentType {
+ Normal(VariableParamType),
+ Shared,
}
impl From<FnArgumentType> for Type {
@@ -213,15 +262,25 @@ impl From<FnArgumentType> for Type {
match t {
FnArgumentType::Reg(x) => x.into(),
FnArgumentType::Param(x) => x.into(),
+ FnArgumentType::Shared => Type::Scalar(ScalarType::B64),
}
}
}
-#[derive(PartialEq, Eq, Hash, Clone)]
+#[derive(PartialEq, Eq, Clone, Copy)]
+pub enum PointerStateSpace {
+ Global,
+ Const,
+ Shared,
+ Param,
+}
+
+#[derive(PartialEq, Eq, Clone)]
pub enum Type {
Scalar(ScalarType),
Vector(ScalarType, u8),
Array(ScalarType, Vec<u32>),
+ Pointer(ScalarType, PointerStateSpace),
}
#[derive(PartialEq, Eq, Hash, Clone, Copy)]
@@ -343,7 +402,8 @@ pub enum VariableType {
Reg(VariableRegType),
Local(VariableLocalType),
Param(VariableParamType),
- Global(VariableLocalType),
+ Global(VariableGlobalType),
+ Shared(VariableGlobalType),
}
impl VariableType {
@@ -353,6 +413,7 @@ impl VariableType {
VariableType::Local(t) => (StateSpace::Local, t.clone().into()),
VariableType::Param(t) => (StateSpace::Param, t.clone().into()),
VariableType::Global(t) => (StateSpace::Global, t.clone().into()),
+ VariableType::Shared(t) => (StateSpace::Shared, t.clone().into()),
}
}
}
@@ -364,6 +425,7 @@ impl From<VariableType> for Type {
VariableType::Local(t) => t.into(),
VariableType::Param(t) => t.into(),
VariableType::Global(t) => t.into(),
+ VariableType::Shared(t) => t.into(),
}
}
}
@@ -1039,6 +1101,20 @@ impl<'a> NumsOrArrays<'a> {
}
}
+pub enum ArrayOrPointer {
+ Array { dimensions: Vec<u32>, init: Vec<u8> },
+ Pointer,
+}
+
+bitflags! {
+ pub struct LinkingDirective: u8 {
+ const NONE = 0b000;
+ const EXTERN = 0b001;
+ const VISIBLE = 0b10;
+ const WEAK = 0b100;
+ }
+}
+
#[cfg(test)]
mod tests {
use super::*;
diff --git a/ptx/src/lib.rs b/ptx/src/lib.rs
index 8ae1c6d..1aac8ab 100644
--- a/ptx/src/lib.rs
+++ b/ptx/src/lib.rs
@@ -17,6 +17,9 @@ extern crate spirv_headers as spirv;
#[cfg(test)]
extern crate spirv_tools_sys as spirv_tools;
+#[macro_use]
+extern crate bitflags;
+
lalrpop_mod!(
#[allow(warnings)]
ptx
diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop
index 0b6fa0f..4624580 100644
--- a/ptx/src/ptx.lalrpop
+++ b/ptx/src/ptx.lalrpop
@@ -3,6 +3,7 @@ use crate::ast::UnwrapWithVec;
use crate::{without_none, vector_index};
use lalrpop_util::ParseError;
+use std::convert::TryInto;
grammar<'a>(errors: &mut Vec<ast::PtxError>);
@@ -210,7 +211,7 @@ Directive: Option<ast::Directive<'input, ast::ParsedArgParams<'input>>> = {
<f:Function> => Some(ast::Directive::Method(f)),
File => None,
Section => None,
- <v:GlobalVariable> ";" => Some(ast::Directive::Variable(v)),
+ <v:ModuleVariable> ";" => Some(ast::Directive::Variable(v)),
};
AddressSize = {
@@ -218,17 +219,23 @@ AddressSize = {
};
Function: ast::Function<'input, &'input str, ast::Statement<ast::ParsedArgParams<'input>>> = {
- LinkingDirective*
+ LinkingDirectives
<func_directive:MethodDecl>
<body:FunctionBody> => ast::Function{<>}
};
-LinkingDirective = {
- ".extern",
- ".visible",
- ".weak"
+LinkingDirective: ast::LinkingDirective = {
+ ".extern" => ast::LinkingDirective::EXTERN,
+ ".visible" => ast::LinkingDirective::VISIBLE,
+ ".weak" => ast::LinkingDirective::WEAK,
};
+LinkingDirectives: ast::LinkingDirective = {
+ <ldirs:LinkingDirective*> => {
+ ldirs.into_iter().fold(ast::LinkingDirective::NONE, |x, y| x | y)
+ }
+}
+
MethodDecl: ast::MethodDecl<'input, &'input str> = {
".entry" <name:ExtendedID> <params:KernelArguments> => ast::MethodDecl::Kernel(name, params),
".func" <ret_vals:FnArguments?> <name:ExtendedID> <params:FnArguments> => {
@@ -244,10 +251,15 @@ FnArguments: Vec<ast::FnArgument<&'input str>> = {
"(" <args:Comma<FnInput>> ")" => args
};
-KernelInput: ast::Variable<ast::VariableParamType, &'input str> = {
+KernelInput: ast::Variable<ast::KernelArgumentType, &'input str> = {
<v:ParamDeclaration> => {
let (align, v_type, name) = v;
- ast::Variable{ align, v_type, name, array_init: Vec::new() }
+ ast::Variable {
+ align,
+ v_type: ast::KernelArgumentType::Normal(v_type),
+ name,
+ array_init: Vec::new()
+ }
}
}
@@ -357,69 +369,120 @@ Variable: ast::Variable<ast::VariableType, &'input str> = {
};
RegVariable: (Option<u32>, ast::VariableRegType, &'input str) = {
- ".reg" <align:Align?> <t:ScalarType> <name:ExtendedID> => {
+ ".reg" <var:VariableScalar<ScalarType>> => {
+ let (align, t, name) = var;
let v_type = ast::VariableRegType::Scalar(t);
(align, v_type, name)
},
- ".reg" <align:Align?> <v_len:VectorPrefix> <t:SizedScalarType> <name:ExtendedID> => {
+ ".reg" <var:VariableVector<SizedScalarType>> => {
+ let (align, v_len, t, name) = var;
let v_type = ast::VariableRegType::Vector(t, v_len);
(align, v_type, name)
}
}
LocalVariable: ast::Variable<ast::VariableType, &'input str> = {
- ".local" <def:LocalVariableDefinition> => {
- let (align, array_init, v_type, name) = def;
- ast::Variable { align, v_type: ast::VariableType::Local(v_type), name, array_init }
+ ".local" <var:VariableScalar<SizedScalarType>> => {
+ let (align, t, name) = var;
+ let v_type = ast::VariableType::Local(ast::VariableLocalType::Scalar(t));
+ ast::Variable { align, v_type, name, array_init: Vec::new() }
+ },
+ ".local" <var:VariableVector<SizedScalarType>> => {
+ let (align, v_len, t, name) = var;
+ let v_type = ast::VariableType::Local(ast::VariableLocalType::Vector(t, v_len));
+ ast::Variable { align, v_type, name, array_init: Vec::new() }
+ },
+ ".local" <var:VariableArrayOrPointer<SizedScalarType>> =>? {
+ let (align, t, name, arr_or_ptr) = var;
+ let (v_type, array_init) = match arr_or_ptr {
+ ast::ArrayOrPointer::Array { dimensions, init } => {
+ (ast::VariableLocalType::Array(t, dimensions), init)
+ }
+ ast::ArrayOrPointer::Pointer => {
+ return Err(ParseError::User { error: ast::PtxError::ZeroDimensionArray });
+ }
+ };
+ Ok(ast::Variable { align, v_type: ast::VariableType::Local(v_type), name, array_init })
}
}
-GlobalVariable: ast::Variable<ast::VariableType, &'input str> = {
- ".global" <def:LocalVariableDefinition> => {
- let (align, array_init, v_type, name) = def;
+ModuleVariable: ast::Variable<ast::VariableType, &'input str> = {
+ LinkingDirectives ".global" <def:GlobalVariableDefinitionNoArray> => {
+ let (align, v_type, name, array_init) = def;
ast::Variable { align, v_type: ast::VariableType::Global(v_type), name, array_init }
+ },
+ LinkingDirectives ".shared" <def:GlobalVariableDefinitionNoArray> => {
+ let (align, v_type, name, array_init) = def;
+ ast::Variable { align, v_type: ast::VariableType::Shared(v_type), name, array_init: Vec::new() }
+ },
+ <ldirs:LinkingDirectives> <space:Or<".global", ".shared">> <var:VariableArrayOrPointer<SizedScalarType>> =>? {
+ let (align, t, name, arr_or_ptr) = var;
+ let (v_type, array_init) = match arr_or_ptr {
+ ast::ArrayOrPointer::Array { dimensions, init } => {
+ if space == ".global" {
+ (ast::VariableType::Global(ast::VariableGlobalType::Array(t, dimensions)), init)
+ } else {
+ (ast::VariableType::Shared(ast::VariableGlobalType::Array(t, dimensions)), init)
+ }
+ }
+ ast::ArrayOrPointer::Pointer => {
+ if !ldirs.contains(ast::LinkingDirective::EXTERN) {
+ return Err(ParseError::User { error: ast::PtxError::NonExternPointer });
+ }
+ if space == ".global" {
+ (ast::VariableType::Global(ast::VariableGlobalType::Pointer(t, ast::PointerStateSpace::Global)), Vec::new())
+ } else {
+ (ast::VariableType::Shared(ast::VariableGlobalType::Pointer(t, ast::PointerStateSpace::Shared)), Vec::new())
+ }
+ }
+ };
+ Ok(ast::Variable{ align, array_init, v_type, name })
}
}
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parameter-state-space
ParamVariable: (Option<u32>, Vec<u8>, ast::VariableParamType, &'input str) = {
- ".param" <align:Align?> <t:ParamScalarType> <name:ExtendedID> => {
+ ".param" <var:VariableScalar<ParamScalarType>> => {
+ let (align, t, name) = var;
let v_type = ast::VariableParamType::Scalar(t);
(align, Vec::new(), v_type, name)
},
- ".param" <align:Align?> <arr:ArrayDefinition> => {
- let (array_init, name, (t, dimensions)) = arr;
- let v_type = ast::VariableParamType::Array(t, dimensions);
+ ".param" <var:VariableArrayOrPointer<SizedScalarType>> => {
+ let (align, t, name, arr_or_ptr) = var;
+ let (v_type, array_init) = match arr_or_ptr {
+ ast::ArrayOrPointer::Array { dimensions, init } => {
+ (ast::VariableParamType::Array(t, dimensions), init)
+ }
+ ast::ArrayOrPointer::Pointer => {
+ (ast::VariableParamType::Pointer(t, ast::PointerStateSpace::Param), Vec::new())
+ }
+ };
(align, array_init, v_type, name)
}
}
ParamDeclaration: (Option<u32>, ast::VariableParamType, &'input str) = {
- ".param" <align:Align?> <t:ParamScalarType> <name:ExtendedID> => {
- let v_type = ast::VariableParamType::Scalar(t);
- (align, v_type, name)
- },
- ".param" <align:Align?> <arr:ArrayDeclaration> => {
- let (name, (t, dimensions)) = arr;
- let v_type = ast::VariableParamType::Array(t, dimensions);
- (align, v_type, name)
+ <var:ParamVariable> =>? {
+ let (align, array_init, v_type, name) = var;
+ if array_init.len() > 0 {
+ Err(ParseError::User { error: ast::PtxError::ArrayInitalizer })
+ } else {
+ Ok((align, v_type, name))
+ }
}
}
-LocalVariableDefinition: (Option<u32>, Vec<u8>, ast::VariableLocalType, &'input str) = {
- <align:Align?> <t:SizedScalarType> <name:ExtendedID> => {
- let v_type = ast::VariableLocalType::Scalar(t);
- (align, Vec::new(), v_type, name)
+GlobalVariableDefinitionNoArray: (Option<u32>, ast::VariableGlobalType, &'input str, Vec<u8>) = {
+ <scalar:VariableScalar<SizedScalarType>> => {
+ let (align, t, name) = scalar;
+ let v_type = ast::VariableGlobalType::Scalar(t);
+ (align, v_type, name, Vec::new())
},
- <align:Align?> <v_len:VectorPrefix> <t:SizedScalarType> <name:ExtendedID> => {
- let v_type = ast::VariableLocalType::Vector(t, v_len);
- (align, Vec::new(), v_type, name)
+ <var:VariableVector<SizedScalarType>> => {
+ let (align, v_len, t, name) = var;
+ let v_type = ast::VariableGlobalType::Vector(t, v_len);
+ (align, v_type, name, Vec::new())
},
- <align:Align?> <arr:ArrayDefinition> => {
- let (array_init, name, (t, dimensions)) = arr;
- let v_type = ast::VariableLocalType::Array(t, dimensions);
- (align, array_init, v_type, name)
- }
}
#[inline]
@@ -461,60 +524,6 @@ ParamScalarType: ast::ParamScalarType = {
".f64" => ast::ParamScalarType::F64,
}
-ArrayDefinition: (Vec<u8>, &'input str, (ast::SizedScalarType, Vec<u32>)) = {
- <typ:SizedScalarType> <name:ExtendedID> <dims:ArrayDimensions> <init:ArrayInitializer?> =>? {
- let mut dims = dims;
- let array_init = init.unwrap_or(ast::NumsOrArrays::Nums(Vec::new())).to_vec(typ, &mut dims)?;
- Ok((
- array_init,
- name,
- (typ, dims)
- ))
- }
-}
-
-ArrayDeclaration: (&'input str, (ast::SizedScalarType, Vec<u32>)) = {
- <typ:SizedScalarType> <name:ExtendedID> <dims:ArrayDimension+> =>? {
- let dims = dims.into_iter().map(|x| if x > 0 { Ok(x) } else { Err(ParseError::User { error: ast::PtxError::ZeroDimensionArray }) }).collect::<Result<_,_>>()?;
- Ok((name, (typ, dims)))
- }
-}
-
-// [0] and [] are treated the same
-ArrayDimensions: Vec<u32> = {
- ArrayEmptyDimension => vec![0u32],
- ArrayEmptyDimension <dims:ArrayDimension+> => {
- let mut dims = dims;
- let mut result = vec![0u32];
- result.append(&mut dims);
- result
- },
- <dims:ArrayDimension+> => dims
-}
-
-ArrayEmptyDimension = {
- "[" "]"
-}
-
-ArrayDimension: u32 = {
- "[" <n:Num> "]" =>? {
- str::parse::<u32>(n).map_err(|e| ParseError::User { error: ast::PtxError::ParseInt(e) })
- }
-}
-
-ArrayInitializer: ast::NumsOrArrays<'input> = {
- "=" <nums:NumsOrArraysBracket> => nums
-}
-
-NumsOrArraysBracket: ast::NumsOrArrays<'input> = {
- "{" <nums:NumsOrArrays> "}" => nums
-}
-
-NumsOrArrays: ast::NumsOrArrays<'input> = {
- <n:Comma<NumsOrArraysBracket>> => ast::NumsOrArrays::Arrays(n),
- <n:CommaNonEmpty<Num>> => ast::NumsOrArrays::Nums(n),
-}
-
Instruction: ast::Instruction<ast::ParsedArgParams<'input>> = {
InstLd,
InstMov,
@@ -1311,6 +1320,73 @@ BitType = {
".b8", ".b16", ".b32", ".b64"
};
+VariableScalar<T>: (Option<u32>, T, &'input str) = {
+ <align:Align?> <v_type:T> <name:ExtendedID> => {
+ (align, v_type, name)
+ }
+}
+
+VariableVector<T>: (Option<u32>, u8, T, &'input str) = {
+ <align:Align?> <v_len:VectorPrefix> <v_type:T> <name:ExtendedID> => {
+ (align, v_len, v_type, name)
+ }
+}
+
+// empty dimensions [0] means it's a pointer
+VariableArrayOrPointer<T>: (Option<u32>, T, &'input str, ast::ArrayOrPointer) = {
+ <align:Align?> <typ:SizedScalarType> <name:ExtendedID> <dims:ArrayDimensions> <init:ArrayInitializer?> =>? {
+ let mut dims = dims;
+ let array_init = match init {
+ Some(init) => {
+ let init_vec = init.to_vec(typ, &mut dims)?;
+ ast::ArrayOrPointer::Array { dimensions: dims, init: init_vec }
+ }
+ None => {
+ if dims.len() > 1 && dims.contains(&0) {
+ return Err(ParseError::User { error: ast::PtxError::ZeroDimensionArray })
+ }
+ ast::ArrayOrPointer::Pointer
+ }
+ };
+ Ok((align, typ, name, array_init))
+ }
+}
+
+// [0] and [] are treated the same
+ArrayDimensions: Vec<u32> = {
+ ArrayEmptyDimension => vec![0u32],
+ ArrayEmptyDimension <dims:ArrayDimension+> => {
+ let mut dims = dims;
+ let mut result = vec![0u32];
+ result.append(&mut dims);
+ result
+ },
+ <dims:ArrayDimension+> => dims
+}
+
+ArrayEmptyDimension = {
+ "[" "]"
+}
+
+ArrayDimension: u32 = {
+ "[" <n:Num> "]" =>? {
+ str::parse::<u32>(n).map_err(|e| ParseError::User { error: ast::PtxError::ParseInt(e) })
+ }
+}
+
+ArrayInitializer: ast::NumsOrArrays<'input> = {
+ "=" <nums:NumsOrArraysBracket> => nums
+}
+
+NumsOrArraysBracket: ast::NumsOrArrays<'input> = {
+ "{" <nums:NumsOrArrays> "}" => nums
+}
+
+NumsOrArrays: ast::NumsOrArrays<'input> = {
+ <n:Comma<NumsOrArraysBracket>> => ast::NumsOrArrays::Arrays(n),
+ <n:CommaNonEmpty<Num>> => ast::NumsOrArrays::Nums(n),
+}
+
Comma<T>: Vec<T> = {
<v:(<T> ",")*> <e:T?> => match e {
None => v,
@@ -1329,3 +1405,9 @@ CommaNonEmpty<T>: Vec<T> = {
v
}
};
+
+#[inline]
+Or<T1, T2>: T1 = {
+ T1,
+ T2
+} \ No newline at end of file
diff --git a/ptx/src/test/spirv_build/global_extern_array.ptx b/ptx/src/test/spirv_build/global_extern_array.ptx
new file mode 100644
index 0000000..fe0f19f
--- /dev/null
+++ b/ptx/src/test/spirv_build/global_extern_array.ptx
@@ -0,0 +1,5 @@
+.version 6.5
+.target sm_30
+.address_size 64
+
+.extern .global .b32 foobar [1]; \ No newline at end of file
diff --git a/ptx/src/test/spirv_build/param_func_array_0.ptx b/ptx/src/test/spirv_build/param_func_array_0.ptx
new file mode 100644
index 0000000..005af52
--- /dev/null
+++ b/ptx/src/test/spirv_build/param_func_array_0.ptx
@@ -0,0 +1,10 @@
+.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .func foobar(
+ .param .b32 foobar[]
+)
+{
+ret;
+} \ No newline at end of file
diff --git a/ptx/src/test/spirv_fail/const_ptr.ptx b/ptx/src/test/spirv_fail/const_ptr.ptx
new file mode 100644
index 0000000..0efd729
--- /dev/null
+++ b/ptx/src/test/spirv_fail/const_ptr.ptx
@@ -0,0 +1,5 @@
+.version 6.5
+.target sm_30
+.address_size 64
+
+.const .b32 foobar []; \ No newline at end of file
diff --git a/ptx/src/test/spirv_fail/global_ptr.ptx b/ptx/src/test/spirv_fail/global_ptr.ptx
new file mode 100644
index 0000000..7ce4c83
--- /dev/null
+++ b/ptx/src/test/spirv_fail/global_ptr.ptx
@@ -0,0 +1,5 @@
+.version 6.5
+.target sm_30
+.address_size 64
+
+.global .b32 foobar []; \ No newline at end of file
diff --git a/ptx/src/test/spirv_fail/local_ptr.txt b/ptx/src/test/spirv_fail/local_ptr.txt
new file mode 100644
index 0000000..9375011
--- /dev/null
+++ b/ptx/src/test/spirv_fail/local_ptr.txt
@@ -0,0 +1,12 @@
+.version 6.5
+.target sm_30
+.address_size 64
+
+
+.visible .entry func()
+{
+
+ .local .b32 foobar [1];
+
+ ret;
+}
diff --git a/ptx/src/test/spirv_fail/param_entry_array_0.ptx b/ptx/src/test/spirv_fail/param_entry_array_0.ptx
new file mode 100644
index 0000000..86dd5eb
--- /dev/null
+++ b/ptx/src/test/spirv_fail/param_entry_array_0.ptx
@@ -0,0 +1,10 @@
+.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .entry foobar(
+ .param .b32 foobar[]
+)
+{
+ret;
+} \ No newline at end of file
diff --git a/ptx/src/test/spirv_fail/param_vector.ptx b/ptx/src/test/spirv_fail/param_vector.ptx
new file mode 100644
index 0000000..28895e2
--- /dev/null
+++ b/ptx/src/test/spirv_fail/param_vector.ptx
@@ -0,0 +1,10 @@
+.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .func foobar(
+ .param .b32 .v2 foobar
+)
+{
+ret;
+} \ No newline at end of file
diff --git a/ptx/src/test/spirv_fail/shared_ptr.ptx b/ptx/src/test/spirv_fail/shared_ptr.ptx
new file mode 100644
index 0000000..b1b815a
--- /dev/null
+++ b/ptx/src/test/spirv_fail/shared_ptr.ptx
@@ -0,0 +1,5 @@
+.version 6.5
+.target sm_30
+.address_size 64
+
+extern .shared .b32 foobar []; \ No newline at end of file
diff --git a/ptx/src/test/spirv_fail/shared_ptr2.ptx b/ptx/src/test/spirv_fail/shared_ptr2.ptx
new file mode 100644
index 0000000..fb2472a
--- /dev/null
+++ b/ptx/src/test/spirv_fail/shared_ptr2.ptx
@@ -0,0 +1,13 @@
+.version 6.5
+.target sm_30
+.address_size 64
+
+.extern .shared .b32 foobar1 [];
+
+.visible .func _Z4dupaPf(
+ .param .b64 _Z4dupaPf_param_0
+)
+{
+.shared .b32 foobar2 [];
+ret;
+} \ No newline at end of file
diff --git a/ptx/src/test/spirv_run/extern_shared.ptx b/ptx/src/test/spirv_run/extern_shared.ptx
new file mode 100644
index 0000000..ac5c256
--- /dev/null
+++ b/ptx/src/test/spirv_run/extern_shared.ptx
@@ -0,0 +1,24 @@
+.version 6.5
+.target sm_30
+.address_size 64
+
+.extern .shared .b32 shared_mem [];
+
+.visible .entry extern_shared(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .u64 temp;
+
+ ld.param.u64 in_addr, [input];
+ ld.param.u64 out_addr, [output];
+
+ ld.global.u64 temp, [in_addr];
+ st.shared.u64 [shared_mem], temp;
+ ld.shared.u64 temp, [shared_mem];
+ st.global.u64 [out_addr], temp;
+ ret;
+} \ No newline at end of file
diff --git a/ptx/src/test/spirv_run/extern_shared.spvtxt b/ptx/src/test/spirv_run/extern_shared.spvtxt
new file mode 100644
index 0000000..84e7eac
--- /dev/null
+++ b/ptx/src/test/spirv_run/extern_shared.spvtxt
@@ -0,0 +1,53 @@
+ OpCapability GenericPointer
+ OpCapability Linkage
+ OpCapability Addresses
+ OpCapability Kernel
+ OpCapability Int64
+ OpCapability Int8
+ %29 = OpExtInstImport "OpenCL.std"
+ OpMemoryModel Physical64 OpenCL
+ OpEntryPoint Kernel %1 "cvta"
+ %void = OpTypeVoid
+ %ulong = OpTypeInt 64 0
+ %32 = OpTypeFunction %void %ulong %ulong
+%_ptr_Function_ulong = OpTypePointer Function %ulong
+ %float = OpTypeFloat 32
+%_ptr_Function_float = OpTypePointer Function %float
+%_ptr_CrossWorkgroup_float = OpTypePointer CrossWorkgroup %float
+ %1 = OpFunction %void None %32
+ %7 = OpFunctionParameter %ulong
+ %8 = OpFunctionParameter %ulong
+ %27 = OpLabel
+ %2 = OpVariable %_ptr_Function_ulong Function
+ %3 = OpVariable %_ptr_Function_ulong Function
+ %4 = OpVariable %_ptr_Function_ulong Function
+ %5 = OpVariable %_ptr_Function_ulong Function
+ %6 = OpVariable %_ptr_Function_float Function
+ OpStore %2 %7
+ OpStore %3 %8
+ %10 = OpLoad %ulong %2
+ %9 = OpCopyObject %ulong %10
+ OpStore %4 %9
+ %12 = OpLoad %ulong %3
+ %11 = OpCopyObject %ulong %12
+ OpStore %5 %11
+ %14 = OpLoad %ulong %4
+ %22 = OpCopyObject %ulong %14
+ %21 = OpCopyObject %ulong %22
+ %13 = OpCopyObject %ulong %21
+ OpStore %4 %13
+ %16 = OpLoad %ulong %5
+ %24 = OpCopyObject %ulong %16
+ %23 = OpCopyObject %ulong %24
+ %15 = OpCopyObject %ulong %23
+ OpStore %5 %15
+ %18 = OpLoad %ulong %4
+ %25 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %18
+ %17 = OpLoad %float %25
+ OpStore %6 %17
+ %19 = OpLoad %ulong %5
+ %20 = OpLoad %float %6
+ %26 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %19
+ OpStore %26 %20
+ OpReturn
+ OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/extern_shared_call.ptx b/ptx/src/test/spirv_run/extern_shared_call.ptx
new file mode 100644
index 0000000..6626783
--- /dev/null
+++ b/ptx/src/test/spirv_run/extern_shared_call.ptx
@@ -0,0 +1,45 @@
+.version 6.5
+.target sm_30
+.address_size 64
+
+.extern .shared .align 4 .b32 shared_mem[];
+
+.func (.param .u64 output) incr_shared_2_param(
+ .param .u64 .ptr .shared shared_mem_addr
+)
+{
+ .reg .u64 temp;
+ ld.shared.u64 temp, [shared_mem_addr];
+ add.u64 temp, temp, 2;
+ st.param.u64 [output], temp;
+ ret;
+}
+
+.func (.param .u64 output) incr_shared_2_global()
+{
+ .reg .u64 temp;
+ ld.shared.u64 temp, [shared_mem];
+ add.u64 temp, temp, 2;
+ st.param.u64 [output], temp;
+ ret;
+}
+
+
+.visible .entry extern_shared(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .u64 temp;
+
+ ld.param.u64 in_addr, [input];
+ ld.param.u64 out_addr, [output];
+
+ ld.global.u64 temp, [in_addr];
+ st.shared.u64 [shared_mem], temp;
+ ld.shared.u64 temp, [shared_mem];
+ st.global.u64 [out_addr], temp;
+ ret;
+} \ No newline at end of file
diff --git a/ptx/src/test/spirv_run/extern_shared_call.spvtxt b/ptx/src/test/spirv_run/extern_shared_call.spvtxt
new file mode 100644
index 0000000..84e7eac
--- /dev/null
+++ b/ptx/src/test/spirv_run/extern_shared_call.spvtxt
@@ -0,0 +1,53 @@
+ OpCapability GenericPointer
+ OpCapability Linkage
+ OpCapability Addresses
+ OpCapability Kernel
+ OpCapability Int64
+ OpCapability Int8
+ %29 = OpExtInstImport "OpenCL.std"
+ OpMemoryModel Physical64 OpenCL
+ OpEntryPoint Kernel %1 "cvta"
+ %void = OpTypeVoid
+ %ulong = OpTypeInt 64 0
+ %32 = OpTypeFunction %void %ulong %ulong
+%_ptr_Function_ulong = OpTypePointer Function %ulong
+ %float = OpTypeFloat 32
+%_ptr_Function_float = OpTypePointer Function %float
+%_ptr_CrossWorkgroup_float = OpTypePointer CrossWorkgroup %float
+ %1 = OpFunction %void None %32
+ %7 = OpFunctionParameter %ulong
+ %8 = OpFunctionParameter %ulong
+ %27 = OpLabel
+ %2 = OpVariable %_ptr_Function_ulong Function
+ %3 = OpVariable %_ptr_Function_ulong Function
+ %4 = OpVariable %_ptr_Function_ulong Function
+ %5 = OpVariable %_ptr_Function_ulong Function
+ %6 = OpVariable %_ptr_Function_float Function
+ OpStore %2 %7
+ OpStore %3 %8
+ %10 = OpLoad %ulong %2
+ %9 = OpCopyObject %ulong %10
+ OpStore %4 %9
+ %12 = OpLoad %ulong %3
+ %11 = OpCopyObject %ulong %12
+ OpStore %5 %11
+ %14 = OpLoad %ulong %4
+ %22 = OpCopyObject %ulong %14
+ %21 = OpCopyObject %ulong %22
+ %13 = OpCopyObject %ulong %21
+ OpStore %4 %13
+ %16 = OpLoad %ulong %5
+ %24 = OpCopyObject %ulong %16
+ %23 = OpCopyObject %ulong %24
+ %15 = OpCopyObject %ulong %23
+ OpStore %5 %15
+ %18 = OpLoad %ulong %4
+ %25 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %18
+ %17 = OpLoad %float %25
+ OpStore %6 %17
+ %19 = OpLoad %ulong %5
+ %20 = OpLoad %float %6
+ %26 = OpConvertUToPtr %_ptr_CrossWorkgroup_float %19
+ OpStore %26 %20
+ OpReturn
+ OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs
index 0c881d9..14c3bc9 100644
--- a/ptx/src/test/spirv_run/mod.rs
+++ b/ptx/src/test/spirv_run/mod.rs
@@ -78,6 +78,7 @@ test_ptx!(sub, [2u64], [1u64]);
test_ptx!(min, [555i32, 444i32], [444i32]);
test_ptx!(max, [555i32, 444i32], [555i32]);
test_ptx!(global_array, [0xDEADu32], [1u32]);
+test_ptx!(extern_shared, [127u64], [127u64]);
struct DisplayError<T: Debug> {
err: T,
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index a86ab3c..09dd0bb 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -34,11 +34,7 @@ enum SpirvType {
impl SpirvType {
fn new_pointer(t: ast::Type, sc: spirv::StorageClass) -> Self {
- let key = match t {
- ast::Type::Scalar(typ) => SpirvType::Base(SpirvScalarKey::from(typ)),
- ast::Type::Vector(typ, len) => SpirvType::Vector(SpirvScalarKey::from(typ), len),
- ast::Type::Array(typ, len) => SpirvType::Array(SpirvScalarKey::from(typ), len),
- };
+ let key = t.into();
SpirvType::Pointer(Box::new(key), sc)
}
}
@@ -49,6 +45,20 @@ impl From<ast::Type> for SpirvType {
ast::Type::Scalar(t) => SpirvType::Base(t.into()),
ast::Type::Vector(typ, len) => SpirvType::Vector(typ.into(), len),
ast::Type::Array(t, len) => SpirvType::Array(t.into(), len),
+ ast::Type::Pointer(typ, state_space) => {
+ SpirvType::Pointer(Box::new(SpirvType::Base(typ.into())), state_space.into())
+ }
+ }
+ }
+}
+
+impl Into<spirv::StorageClass> for ast::PointerStateSpace {
+ fn into(self) -> spirv::StorageClass {
+ match self {
+ ast::PointerStateSpace::Const => spirv::StorageClass::UniformConstant,
+ ast::PointerStateSpace::Global => spirv::StorageClass::CrossWorkgroup,
+ ast::PointerStateSpace::Shared => spirv::StorageClass::Workgroup,
+ ast::PointerStateSpace::Param => spirv::StorageClass::Function,
}
}
}
@@ -354,6 +364,14 @@ impl TypeWordMap {
b.constant_composite(result_type, None, &components)
}
},
+ ast::Type::Pointer(typ, state_space) => {
+ let base = self.get_or_add_constant(b, &ast::Type::Scalar(*typ), &[])?;
+ let result_type = self.get_or_add(
+ b,
+ SpirvType::Pointer(Box::new(SpirvType::from(*typ)), (*state_space).into()),
+ );
+ b.variable(result_type, None, (*state_space).into(), Some(base))
+ }
})
}
@@ -415,13 +433,7 @@ pub fn to_spirv_module<'a>(
None => continue,
};
emit_function_body_ops(&mut builder, &mut map, opencl_id, &f.globals)?;
- emit_function_header(
- &mut builder,
- &mut map,
- &id_defs,
- f.func_directive,
- &mut args_len,
- )?;
+ emit_function_header(&mut builder, &mut map, &id_defs, f.func_decl, &mut args_len)?;
emit_function_body_ops(&mut builder, &mut map, opencl_id, &f_body)?;
builder.end_function()?;
}
@@ -430,6 +442,202 @@ pub fn to_spirv_module<'a>(
Ok((builder.module(), args_len))
}
+type MultiHashMap<K, V> = HashMap<K, Vec<V>>;
+
+fn multi_hash_map_append<K: Eq + std::hash::Hash, V>(m: &mut MultiHashMap<K, V>, key: K, value: V) {
+ match m.entry(key) {
+ hash_map::Entry::Occupied(mut entry) => {
+ entry.get_mut().push(value);
+ }
+ hash_map::Entry::Vacant(entry) => {
+ entry.insert(vec![value]);
+ }
+ }
+}
+
+// PTX represents dynamically allocated shared local memory as
+// .extern .shared .align 4 .b8 shared_mem[];
+// In SPIRV/OpenCL world this is expressed as an additional argument
+// This pass looks for all uses of .extern .shared and converts them to
+// an additional method argument
+fn convert_dynamic_shared_memory_usage<'input>(
+ new_id: &mut impl FnMut() -> spirv::Word,
+ id_defs: &mut GlobalStringIdResolver<'input>,
+ module: Vec<Directive<'input>>,
+) -> Vec<Directive<'input>> {
+ let mut extern_shared_decls = HashSet::new();
+ for dir in module.iter() {
+ match dir {
+ Directive::Variable(var) => {
+ if let ast::VariableType::Shared(_) = var.v_type {
+ extern_shared_decls.insert(var.name);
+ }
+ }
+ _ => {}
+ }
+ }
+ if extern_shared_decls.len() == 0 {
+ return module;
+ }
+ let mut methods_using_extern_shared = HashSet::new();
+ let mut directly_called_by = MultiHashMap::new();
+ let module = module
+ .into_iter()
+ .map(|directive| match directive {
+ Directive::Method(Function {
+ func_decl,
+ globals,
+ body: Some(statements),
+ }) => {
+ let call_key = match func_decl {
+ ast::MethodDecl::Kernel(name, _) => CallgraphKey::Kernel(name),
+ ast::MethodDecl::Func(_, id, _) => CallgraphKey::Func(id),
+ };
+ let statements = statements
+ .into_iter()
+ .map(|statement| match statement {
+ Statement::Call(call) => {
+ multi_hash_map_append(&mut directly_called_by, call.func, call_key);
+ Statement::Call(call)
+ }
+ statement => statement.map_id(&mut |id| {
+ if extern_shared_decls.contains(&id) {
+ methods_using_extern_shared.insert(call_key);
+ }
+ id
+ }),
+ })
+ .collect();
+ Directive::Method(Function {
+ func_decl,
+ globals,
+ body: Some(statements),
+ })
+ }
+ directive => directive,
+ })
+ .collect::<Vec<_>>();
+ // If there's a chain `kernel` -> `fn1` -> `fn2`, where only `fn2` uses extern shared,
+ // make sure it gets propagated to `fn1` and `kernel`
+ get_callers_of_extern_shared(&mut methods_using_extern_shared, &directly_called_by);
+ // now visit every method declaration and inject those additional arguments
+ module
+ .into_iter()
+ .map(|directive| match directive {
+ Directive::Method(Function {
+ mut func_decl,
+ globals,
+ body: Some(statements),
+ }) => {
+ let call_key = match func_decl {
+ ast::MethodDecl::Kernel(name, _) => CallgraphKey::Kernel(name),
+ ast::MethodDecl::Func(_, id, _) => CallgraphKey::Func(id),
+ };
+ if !methods_using_extern_shared.contains(&call_key) {
+ return Directive::Method(Function {
+ func_decl,
+ globals,
+ body: Some(statements),
+ });
+ }
+ let shared_id_param = new_id();
+ match &mut func_decl {
+ ast::MethodDecl::Func(_, _, input_args) => {
+ input_args.push(ast::Variable {
+ align: None,
+ v_type: ast::FnArgumentType::Shared,
+ array_init: Vec::new(),
+ name: shared_id_param,
+ });
+ }
+ ast::MethodDecl::Kernel(_, input_args) => {
+ input_args.push(ast::Variable {
+ align: None,
+ v_type: ast::KernelArgumentType::Shared,
+ array_init: Vec::new(),
+ name: shared_id_param,
+ });
+ }
+ }
+ let statements = statements
+ .into_iter()
+ .map(|statement| match statement {
+ Statement::Call(mut call) => {
+ // We can safely skip checking call arguments,
+ // because there's simply no way to pass shared ptr
+ // without converting it to .b64 first
+ if methods_using_extern_shared.contains(&CallgraphKey::Func(call.func))
+ {
+ call.param_list
+ .push((shared_id_param, ast::FnArgumentType::Shared));
+ }
+ Statement::Call(call)
+ }
+ statement => statement.map_id(&mut |id| {
+ if extern_shared_decls.contains(&id) {
+ shared_id_param
+ } else {
+ id
+ }
+ }),
+ })
+ .collect();
+ Directive::Method(Function {
+ func_decl,
+ globals,
+ body: Some(statements),
+ })
+ }
+ directive => directive,
+ })
+ .collect::<Vec<_>>()
+}
+
+fn get_callers_of_extern_shared<'a>(
+ methods_using_extern_shared: &mut HashSet<CallgraphKey<'a>>,
+ directly_called_by: &MultiHashMap<spirv::Word, CallgraphKey<'a>>,
+) {
+ let direct_uses_of_extern_shared = methods_using_extern_shared
+ .iter()
+ .filter_map(|method| {
+ if let CallgraphKey::Func(f_id) = method {
+ Some(*f_id)
+ } else {
+ None
+ }
+ })
+ .collect::<Vec<_>>();
+ for fn_id in direct_uses_of_extern_shared {
+ get_callers_of_extern_shared_single(methods_using_extern_shared, directly_called_by, fn_id);
+ }
+}
+
+fn get_callers_of_extern_shared_single<'a>(
+ methods_using_extern_shared: &mut HashSet<CallgraphKey<'a>>,
+ directly_called_by: &MultiHashMap<spirv::Word, CallgraphKey<'a>>,
+ fn_id: spirv::Word,
+) {
+ if let Some(callers) = directly_called_by.get(&fn_id) {
+ for caller in callers {
+ if methods_using_extern_shared.insert(*caller) {
+ if let CallgraphKey::Func(caller_fn) = caller {
+ get_callers_of_extern_shared_single(
+ methods_using_extern_shared,
+ directly_called_by,
+ *caller_fn,
+ );
+ }
+ }
+ }
+ }
+}
+
+#[derive(Hash, PartialEq, Eq, Copy, Clone)]
+enum CallgraphKey<'input> {
+ Kernel(&'input str),
+ Func(spirv::Word),
+}
+
fn emit_builtins(
builder: &mut dr::Builder,
map: &mut TypeWordMap,
@@ -594,6 +802,7 @@ fn expand_fn_params<'a, 'b>(
let ss = match a.v_type {
ast::FnArgumentType::Reg(_) => StateSpace::Reg,
ast::FnArgumentType::Param(_) => StateSpace::Param,
+ ast::FnArgumentType::Shared => StateSpace::Shared,
};
ast::FnArgument {
name: fn_resolver.add_def(a.name, Some((ss, ast::Type::from(a.v_type.clone())))),
@@ -615,7 +824,7 @@ fn to_ssa<'input, 'b>(
Some(vec) => vec,
None => {
return Ok(Function {
- func_directive: f_args,
+ func_decl: f_args,
body: None,
globals: Vec::new(),
})
@@ -637,7 +846,7 @@ fn to_ssa<'input, 'b>(
let sorted_statements = normalize_variable_decls(labeled_statements);
let (f_body, globals) = extract_globals(sorted_statements);
Ok(Function {
- func_directive: f_args,
+ func_decl: f_args,
globals: globals,
body: Some(f_body),
})
@@ -935,7 +1144,7 @@ fn insert_mem_ssa_statements<'a, 'b>(
let new_id = id_def.new_id(typ.clone());
result.push(Statement::Variable(ast::Variable {
align: p.align,
- v_type: ast::VariableType::Param(p.v_type.clone()),
+ v_type: ast::VariableType::Param(p.v_type.clone().to_param()),
name: p.name,
array_init: p.array_init.clone(),
}));
@@ -1878,26 +2087,33 @@ fn emit_variable(
map: &mut TypeWordMap,
var: &ast::Variable<ast::VariableType, spirv::Word>,
) -> Result<(), TranslateError> {
- let (should_init, st_class) = match var.v_type {
+ let (must_init, st_class) = match var.v_type {
ast::VariableType::Reg(_) | ast::VariableType::Param(_) | ast::VariableType::Local(_) => {
(false, spirv::StorageClass::Function)
}
ast::VariableType::Global(_) => (true, spirv::StorageClass::CrossWorkgroup),
+ ast::VariableType::Shared(_) => (false, spirv::StorageClass::Workgroup),
};
- let type_id = map.get_or_add(
- builder,
- SpirvType::new_pointer(ast::Type::from(var.v_type.clone()), st_class),
- );
- let initalizer = if should_init {
+ let initalizer = if var.array_init.len() > 0 {
Some(map.get_or_add_constant(
builder,
&ast::Type::from(var.v_type.clone()),
&*var.array_init,
)?)
+ } else if must_init {
+ let type_id = map.get_or_add(
+ builder,
+ SpirvType::from(ast::Type::from(var.v_type.clone())),
+ );
+ Some(builder.constant_null(type_id, None))
} else {
None
};
- builder.variable(type_id, Some(var.name), st_class, initalizer);
+ let ptr_type_id = map.get_or_add(
+ builder,
+ SpirvType::new_pointer(ast::Type::from(var.v_type.clone()), st_class),
+ );
+ builder.variable(ptr_type_id, Some(var.name), st_class, initalizer);
if let Some(align) = var.align {
builder.decorate(
var.name,
@@ -2537,7 +2753,8 @@ fn expand_map_variables<'a, 'b>(
ast::VariableType::Reg(_) => StateSpace::Reg,
ast::VariableType::Local(_) => StateSpace::Local,
ast::VariableType::Param(_) => StateSpace::ParamReg,
- ast::VariableType::Global(_) => todo!(),
+ ast::VariableType::Global(_) => StateSpace::Global,
+ ast::VariableType::Shared(_) => StateSpace::Shared,
};
match var.count {
Some(count) => {
@@ -2888,6 +3105,69 @@ enum Statement<I, P: ast::ArgParams> {
Undef(ast::Type, spirv::Word),
}
+impl ExpandedStatement {
+ fn map_id(self, f: &mut impl FnMut(spirv::Word) -> spirv::Word) -> ExpandedStatement {
+ match self {
+ Statement::Label(id) => Statement::Label(f(id)),
+ Statement::Variable(mut var) => {
+ var.name = f(var.name);
+ Statement::Variable(var)
+ }
+ Statement::Instruction(inst) => inst
+ .visit_variable_extended(&mut |arg: ArgumentDescriptor<_>, _| Ok(f(arg.op)))
+ .unwrap(),
+ Statement::LoadVar(mut arg, typ) => {
+ arg.dst = f(arg.dst);
+ arg.src = f(arg.src);
+ Statement::LoadVar(arg, typ)
+ }
+ Statement::StoreVar(mut arg, typ) => {
+ arg.src1 = f(arg.src1);
+ arg.src2 = f(arg.src2);
+ Statement::StoreVar(arg, typ)
+ }
+ Statement::Call(mut call) => {
+ for (id, _) in call.ret_params.iter_mut() {
+ *id = f(*id);
+ }
+ call.func = f(call.func);
+ for (id, _) in call.param_list.iter_mut() {
+ *id = f(*id);
+ }
+ Statement::Call(call)
+ }
+ Statement::Composite(mut composite) => {
+ composite.dst = f(composite.dst);
+ composite.src_composite = f(composite.src_composite);
+ Statement::Composite(composite)
+ }
+ Statement::Conditional(mut conditional) => {
+ conditional.predicate = f(conditional.predicate);
+ conditional.if_true = f(conditional.if_true);
+ conditional.if_false = f(conditional.if_false);
+ Statement::Conditional(conditional)
+ }
+ Statement::Conversion(mut conv) => {
+ conv.dst = f(conv.dst);
+ conv.src = f(conv.src);
+ Statement::Conversion(conv)
+ }
+ Statement::Constant(mut constant) => {
+ constant.dst = f(constant.dst);
+ Statement::Constant(constant)
+ }
+ Statement::RetValue(data, id) => {
+ let id = f(id);
+ Statement::RetValue(data, id)
+ }
+ Statement::Undef(typ, id) => {
+ let id = f(id);
+ Statement::Undef(typ, id)
+ }
+ }
+ }
+}
+
struct ResolvedCall<P: ast::ArgParams> {
pub uniform: bool,
pub ret_params: Vec<(spirv::Word, ast::FnArgumentType)>,
@@ -3106,7 +3386,7 @@ enum Directive<'input> {
}
struct Function<'input> {
- pub func_directive: ast::MethodDecl<'input, spirv::Word>,
+ pub func_decl: ast::MethodDecl<'input, spirv::Word>,
pub globals: Vec<ExpandedStatement>,
pub body: Option<Vec<ExpandedStatement>>,
}
@@ -3546,18 +3826,28 @@ impl ast::Type {
scalar_kind: scalar.kind(),
width: scalar.size_of(),
components: Vec::new(),
+ state_space: ast::PointerStateSpace::Global,
},
ast::Type::Vector(scalar, components) => TypeParts {
kind: TypeKind::Vector,
scalar_kind: scalar.kind(),
width: scalar.size_of(),
components: vec![*components as u32],
+ state_space: ast::PointerStateSpace::Global,
},
ast::Type::Array(scalar, components) => TypeParts {
kind: TypeKind::Array,
scalar_kind: scalar.kind(),
width: scalar.size_of(),
components: components.clone(),
+ state_space: ast::PointerStateSpace::Global,
+ },
+ ast::Type::Pointer(scalar, state_space) => TypeParts {
+ kind: TypeKind::Pointer,
+ scalar_kind: scalar.kind(),
+ width: scalar.size_of(),
+ components: Vec::new(),
+ state_space: *state_space,
},
}
}
@@ -3575,6 +3865,10 @@ impl ast::Type {
ast::ScalarType::from_parts(t.width, t.scalar_kind),
t.components,
),
+ TypeKind::Pointer => ast::Type::Pointer(
+ ast::ScalarType::from_parts(t.width, t.scalar_kind),
+ t.state_space,
+ ),
}
}
}
@@ -3585,6 +3879,7 @@ struct TypeParts {
scalar_kind: ScalarKind,
width: u8,
components: Vec<u32>,
+ state_space: ast::PointerStateSpace,
}
#[derive(Eq, PartialEq, Copy, Clone)]
@@ -3592,6 +3887,7 @@ enum TypeKind {
Scalar,
Vector,
Array,
+ Pointer,
}
impl ast::Instruction<ExpandedArgParams> {
@@ -3762,6 +4058,36 @@ impl ast::VariableParamType {
(ast::ScalarType::from(*t).size_of() as usize)
* (len.iter().fold(1, |x, y| x * (*y)) as usize)
}
+ ast::VariableParamType::Pointer(_, _) => mem::size_of::<usize>()
+ }
+ }
+}
+
+impl ast::KernelArgumentType {
+ fn width(&self) -> usize {
+ match self {
+ ast::KernelArgumentType::Normal(t) => t.width(),
+ ast::KernelArgumentType::Shared => mem::size_of::<usize>(),
+ }
+ }
+}
+
+impl From<ast::KernelArgumentType> for ast::Type {
+ fn from(this: ast::KernelArgumentType) -> Self {
+ match this {
+ ast::KernelArgumentType::Normal(typ) => typ.into(),
+ ast::KernelArgumentType::Shared => ast::Type::Scalar(ast::ScalarType::B64),
+ }
+ }
+}
+
+impl ast::KernelArgumentType {
+ fn to_param(self) -> ast::VariableParamType {
+ match self {
+ ast::KernelArgumentType::Normal(p) => p,
+ ast::KernelArgumentType::Shared => {
+ ast::VariableParamType::Scalar(ast::ParamScalarType::B64)
+ }
}
}
}
@@ -4598,6 +4924,7 @@ impl From<ast::FnArgumentType> for ast::VariableType {
match t {
ast::FnArgumentType::Reg(t) => ast::VariableType::Reg(t),
ast::FnArgumentType::Param(t) => ast::VariableType::Param(t),
+ ast::FnArgumentType::Shared => todo!(),
}
}
}
@@ -4648,6 +4975,17 @@ fn bitcast_physical_pointer(
ss: Option<ast::LdStateSpace>,
) -> Result<Option<ConversionKind>, TranslateError> {
match operand_type {
+ // array decays to a pointer
+ ast::Type::Array(_, vec) => {
+ if vec.len() != 0 {
+ return Err(TranslateError::MismatchedType);
+ }
+ if let Some(space) = ss {
+ Ok(Some(ConversionKind::BitToPtr(space)))
+ } else {
+ Err(TranslateError::Unreachable)
+ }
+ }
ast::Type::Scalar(ast::ScalarType::B64)
| ast::Type::Scalar(ast::ScalarType::U64)
| ast::Type::Scalar(ast::ScalarType::S64) => {
@@ -4882,7 +5220,10 @@ impl<'a> ast::MethodDecl<'a, spirv::Word> {
f(&ast::FnArgument {
align: arg.align,
name: arg.name,
- v_type: ast::FnArgumentType::Param(arg.v_type.clone()),
+ v_type: match arg.v_type.clone() {
+ ast::KernelArgumentType::Normal(typ) => ast::FnArgumentType::Param(typ),
+ ast::KernelArgumentType::Shared => ast::FnArgumentType::Shared,
+ },
array_init: arg.array_init.clone(),
})
}),