aboutsummaryrefslogtreecommitdiffhomepage
path: root/ptx/src
diff options
context:
space:
mode:
authorAndrzej Janik <[email protected]>2020-09-18 02:25:20 +0200
committerAndrzej Janik <[email protected]>2020-09-18 02:25:20 +0200
commit952ed5d5049462c60abf4149ee0ddbcb9cdb8cdc (patch)
tree138e409c3ec519602bed55521efdacdfd3d86963 /ptx/src
parent42bad8fcc22d3fd66bcdbfea7ce9a41268772e50 (diff)
downloadZLUDA-952ed5d5049462c60abf4149ee0ddbcb9cdb8cdc.tar.gz
ZLUDA-952ed5d5049462c60abf4149ee0ddbcb9cdb8cdc.zip
[BROKEN] Start implementing better support for addressable arguments
Diffstat (limited to 'ptx/src')
-rw-r--r--ptx/src/ast.rs13
-rw-r--r--ptx/src/ptx.lalrpop13
-rw-r--r--ptx/src/test/spirv_run/mod.rs2
-rw-r--r--ptx/src/test/spirv_run/ntid.ptx23
-rw-r--r--ptx/src/test/spirv_run/ntid.spvtxt56
-rw-r--r--ptx/src/test/spirv_run/reg_slm.ptx26
-rw-r--r--ptx/src/test/spirv_run/reg_slm.spvtxt46
-rw-r--r--ptx/src/translate.rs250
8 files changed, 350 insertions, 79 deletions
diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs
index 9214944..7ac9d18 100644
--- a/ptx/src/ast.rs
+++ b/ptx/src/ast.rs
@@ -354,6 +354,7 @@ pub struct CallInst<P: ArgParams> {
pub trait ArgParams {
type ID;
type Operand;
+ type MemoryOperand;
type CallOperand;
type VecOperand;
}
@@ -365,6 +366,7 @@ pub struct ParsedArgParams<'a> {
impl<'a> ArgParams for ParsedArgParams<'a> {
type ID = &'a str;
type Operand = Operand<&'a str>;
+ type MemoryOperand = Operand<&'a str>;
type CallOperand = CallOperand<&'a str>;
type VecOperand = (&'a str, u8);
}
@@ -378,8 +380,13 @@ pub struct Arg2<P: ArgParams> {
pub src: P::Operand,
}
+pub struct Arg2Ld<P: ArgParams> {
+ pub dst: P::ID,
+ pub src: P::MemoryOperand,
+}
+
pub struct Arg2St<P: ArgParams> {
- pub src1: P::Operand,
+ pub src1: P::MemoryOperand,
pub src2: P::Operand,
}
@@ -416,13 +423,13 @@ pub struct Arg5<P: ArgParams> {
pub enum Operand<ID> {
Reg(ID),
RegOffset(ID, i32),
- Imm(i128),
+ Imm(u32),
}
#[derive(Copy, Clone)]
pub enum CallOperand<ID> {
Reg(ID),
- Imm(i128),
+ Imm(u32),
}
pub enum VectorPrefix {
diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop
index 1ffbca2..44f29a5 100644
--- a/ptx/src/ptx.lalrpop
+++ b/ptx/src/ptx.lalrpop
@@ -446,7 +446,7 @@ Instruction: ast::Instruction<ast::ParsedArgParams<'input>> = {
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-ld
InstLd: ast::Instruction<ast::ParsedArgParams<'input>> = {
- "ld" <q:LdStQualifier?> <ss:LdStateSpace?> <cop:LdCacheOperator?> <t:LdStType> <dst:ExtendedID> "," "[" <src:Operand> "]" => {
+ "ld" <q:LdStQualifier?> <ss:LdStateSpace?> <cop:LdCacheOperator?> <t:LdStType> <dst:ExtendedID> "," <src:MemoryOperand> => {
ast::Instruction::Ld(
ast::LdData {
qualifier: q.unwrap_or(ast::LdStQualifier::Weak),
@@ -899,7 +899,7 @@ ShlType: ast::ShlType = {
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-st
// Warning: NVIDIA documentation is incorrect, you can specify scope only once
InstSt: ast::Instruction<ast::ParsedArgParams<'input>> = {
- "st" <q:LdStQualifier?> <ss:StStateSpace?> <cop:StCacheOperator?> <t:LdStType> "[" <src1:Operand> "]" "," <src2:Operand> => {
+ "st" <q:LdStQualifier?> <ss:StStateSpace?> <cop:StCacheOperator?> <t:LdStType> <src1:MemoryOperand> "," <src2:Operand> => {
ast::Instruction::St(
ast::StData {
qualifier: q.unwrap_or(ast::LdStQualifier::Weak),
@@ -912,6 +912,11 @@ InstSt: ast::Instruction<ast::ParsedArgParams<'input>> = {
}
};
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#using-addresses-arrays-and-vectors
+MemoryOperand: ast::Operand<&'input str> = {
+ "[" <o:Operand> "]" => o
+}
+
StStateSpace: ast::StStateSpace = {
".global" => ast::StStateSpace::Global,
".local" => ast::StStateSpace::Local,
@@ -1006,7 +1011,7 @@ Operand: ast::Operand<&'input str> = {
// TODO: start parsing whole constants sub-language:
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#constants
<o:Num> => {
- let offset = o.parse::<i128>();
+ let offset = o.parse::<u32>();
let offset = offset.unwrap_with(errors);
ast::Operand::Imm(offset)
}
@@ -1015,7 +1020,7 @@ Operand: ast::Operand<&'input str> = {
CallOperand: ast::CallOperand<&'input str> = {
<r:ExtendedID> => ast::CallOperand::Reg(r),
<o:Num> => {
- let offset = o.parse::<i128>();
+ let offset = o.parse::<u32>();
let offset = offset.unwrap_with(errors);
ast::CallOperand::Imm(offset)
}
diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs
index f1c3194..d251f77 100644
--- a/ptx/src/test/spirv_run/mod.rs
+++ b/ptx/src/test/spirv_run/mod.rs
@@ -59,6 +59,8 @@ test_ptx!(local_align, [1u64], [1u64]);
test_ptx!(call, [1u64], [2u64]);
test_ptx!(vector, [1u32, 2u32], [3u32, 3u32]);
test_ptx!(ld_st_offset, [1u32, 2u32], [2u32, 1u32]);
+test_ptx!(ntid, [3u32], [4u32]);
+test_ptx!(reg_slm, [12u64], [12u64]);
struct DisplayError<T: Debug> {
err: T,
diff --git a/ptx/src/test/spirv_run/ntid.ptx b/ptx/src/test/spirv_run/ntid.ptx
new file mode 100644
index 0000000..2961197
--- /dev/null
+++ b/ptx/src/test/spirv_run/ntid.ptx
@@ -0,0 +1,23 @@
+.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .entry ntid(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .u32 in_val;
+ .reg .u32 global_count;
+
+ ld.param.u64 in_addr, [input];
+ ld.param.u64 out_addr, [output];
+
+ ld.u32 in_val, [in_addr];
+ mov.u32 global_count, %ntid.x;
+ add.u32 in_val, in_val, global_count;
+ st.u32 [out_addr], in_val;
+ ret;
+}
diff --git a/ptx/src/test/spirv_run/ntid.spvtxt b/ptx/src/test/spirv_run/ntid.spvtxt
new file mode 100644
index 0000000..ef308f0
--- /dev/null
+++ b/ptx/src/test/spirv_run/ntid.spvtxt
@@ -0,0 +1,56 @@
+ OpCapability GenericPointer
+ OpCapability Linkage
+ OpCapability Addresses
+ OpCapability Kernel
+ OpCapability Int64
+ OpCapability Int8
+ %29 = OpExtInstImport "OpenCL.std"
+ OpMemoryModel Physical64 OpenCL
+ OpEntryPoint Kernel %1 "add" %GlobalSize
+ OpDecorate %GlobalSize BuiltIn GlobalSize
+ %void = OpTypeVoid
+ %uint = OpTypeInt 32 0
+ %v3uint = OpTypeVector %uint 3
+%_ptr_UniformConstant_v3uint = OpTypePointer UniformConstant %v3uint
+ %GlobalSize = OpVariable %_ptr_UniformConstant_v3uint UniformConstant
+ %ulong = OpTypeInt 64 0
+ %35 = OpTypeFunction %void %ulong %ulong
+%_ptr_Function_ulong = OpTypePointer Function %ulong
+%_ptr_Function_uint = OpTypePointer Function %uint
+%_ptr_Generic_uint = OpTypePointer Generic %uint
+ %1 = OpFunction %void None %35
+ %9 = OpFunctionParameter %ulong
+ %10 = OpFunctionParameter %ulong
+ %27 = OpLabel
+ %2 = OpVariable %_ptr_Function_ulong Function
+ %3 = OpVariable %_ptr_Function_ulong Function
+ %4 = OpVariable %_ptr_Function_ulong Function
+ %5 = OpVariable %_ptr_Function_ulong Function
+ %6 = OpVariable %_ptr_Function_uint Function
+ %7 = OpVariable %_ptr_Function_uint Function
+ OpStore %2 %9
+ OpStore %3 %10
+ %12 = OpLoad %ulong %2
+ %11 = OpCopyObject %ulong %12
+ OpStore %4 %11
+ %14 = OpLoad %ulong %3
+ %13 = OpCopyObject %ulong %14
+ OpStore %5 %13
+ %16 = OpLoad %ulong %4
+ %25 = OpConvertUToPtr %_ptr_Generic_uint %16
+ %15 = OpLoad %uint %25
+ OpStore %6 %15
+ %18 = OpLoad %v3uint %GlobalSize
+ %24 = OpCompositeExtract %uint %18 0
+ %17 = OpCopyObject %uint %24
+ OpStore %7 %17
+ %20 = OpLoad %uint %6
+ %21 = OpLoad %uint %7
+ %19 = OpIAdd %uint %20 %21
+ OpStore %6 %19
+ %22 = OpLoad %ulong %5
+ %23 = OpLoad %uint %6
+ %26 = OpConvertUToPtr %_ptr_Generic_uint %22
+ OpStore %26 %23
+ OpReturn
+ OpFunctionEnd
diff --git a/ptx/src/test/spirv_run/reg_slm.ptx b/ptx/src/test/spirv_run/reg_slm.ptx
new file mode 100644
index 0000000..929d116
--- /dev/null
+++ b/ptx/src/test/spirv_run/reg_slm.ptx
@@ -0,0 +1,26 @@
+.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .entry reg_slm(
+ .param .u64 input,
+ .param .u64 output
+)
+{
+ .local .align 8 .b8 slm[8];
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .b64 temp;
+ .reg .s64 unused;
+
+ ld.param.u64 in_addr, [input];
+ ld.param.u64 out_addr, [output];
+
+ mov.s64 unused, slm;
+
+ ld.global.u64 temp, [in_addr];
+ st.u64 [slm], temp;
+ ld.u64 temp, [slm];
+ st.global.u64 [out_addr], temp;
+ ret;
+} \ No newline at end of file
diff --git a/ptx/src/test/spirv_run/reg_slm.spvtxt b/ptx/src/test/spirv_run/reg_slm.spvtxt
new file mode 100644
index 0000000..6810fec
--- /dev/null
+++ b/ptx/src/test/spirv_run/reg_slm.spvtxt
@@ -0,0 +1,46 @@
+ OpCapability GenericPointer
+ OpCapability Linkage
+ OpCapability Addresses
+ OpCapability Kernel
+ OpCapability Int64
+ OpCapability Int8
+ %25 = OpExtInstImport "OpenCL.std"
+ OpMemoryModel Physical64 OpenCL
+ OpEntryPoint Kernel %1 "add"
+ %void = OpTypeVoid
+ %ulong = OpTypeInt 64 0
+ %28 = OpTypeFunction %void %ulong %ulong
+%_ptr_Function_ulong = OpTypePointer Function %ulong
+%_ptr_Generic_ulong = OpTypePointer Generic %ulong
+ %ulong_1 = OpConstant %ulong 1
+ %1 = OpFunction %void None %28
+ %8 = OpFunctionParameter %ulong
+ %9 = OpFunctionParameter %ulong
+ %23 = OpLabel
+ %2 = OpVariable %_ptr_Function_ulong Function
+ %3 = OpVariable %_ptr_Function_ulong Function
+ %4 = OpVariable %_ptr_Function_ulong Function
+ %5 = OpVariable %_ptr_Function_ulong Function
+ %6 = OpVariable %_ptr_Function_ulong Function
+ %7 = OpVariable %_ptr_Function_ulong Function
+ OpStore %2 %8
+ OpStore %3 %9
+ %11 = OpLoad %ulong %2
+ %10 = OpCopyObject %ulong %11
+ OpStore %4 %10
+ %13 = OpLoad %ulong %3
+ %12 = OpCopyObject %ulong %13
+ OpStore %5 %12
+ %15 = OpLoad %ulong %4
+ %21 = OpConvertUToPtr %_ptr_Generic_ulong %15
+ %14 = OpLoad %ulong %21
+ OpStore %6 %14
+ %17 = OpLoad %ulong %6
+ %16 = OpIAdd %ulong %17 %ulong_1
+ OpStore %7 %16
+ %18 = OpLoad %ulong %5
+ %19 = OpLoad %ulong %7
+ %22 = OpConvertUToPtr %_ptr_Generic_ulong %18
+ OpStore %22 %19
+ OpReturn
+ OpFunctionEnd
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index f5c0ecb..45372f1 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -286,7 +286,7 @@ fn expand_kernel_params<'a, 'b>(
args: impl Iterator<Item = &'b ast::KernelArgument<ast::ParsedArgParams<'a>>>,
) -> Vec<ast::KernelArgument<ExpandedArgParams>> {
args.map(|a| ast::KernelArgument {
- name: fn_resolver.add_def(a.name, Some(ast::Type::from(a.v_type))),
+ name: fn_resolver.add_def(a.name, Some((StateSpace::Param, ast::Type::from(a.v_type)))),
v_type: a.v_type,
align: a.align,
})
@@ -297,10 +297,16 @@ fn expand_fn_params<'a, 'b>(
fn_resolver: &mut FnStringIdResolver<'a, 'b>,
args: impl Iterator<Item = &'b ast::FnArgument<ast::ParsedArgParams<'a>>>,
) -> Vec<ast::FnArgument<ExpandedArgParams>> {
- args.map(|a| ast::FnArgument {
- name: fn_resolver.add_def(a.name, Some(ast::Type::from(a.v_type))),
- v_type: a.v_type,
- align: a.align,
+ args.map(|a| {
+ let ss = match a.v_type {
+ ast::FnArgumentType::Reg(_) => StateSpace::Reg,
+ ast::FnArgumentType::Param(_) => StateSpace::Param,
+ };
+ ast::FnArgument {
+ name: fn_resolver.add_def(a.name, Some((ss, ast::Type::from(a.v_type)))),
+ v_type: a.v_type,
+ align: a.align,
+ }
})
.collect()
}
@@ -325,6 +331,8 @@ fn to_ssa<'input, 'b>(
let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs);
let unadorned_statements =
add_types_to_statements(unadorned_statements, &fn_defs, &numeric_id_defs);
+ todo!()
+ /*
let (f_args, ssa_statements) =
insert_mem_ssa_statements(unadorned_statements, &mut numeric_id_defs, f_args);
let expanded_statements = expand_arguments(ssa_statements, &mut numeric_id_defs);
@@ -336,6 +344,7 @@ fn to_ssa<'input, 'b>(
func_directive: f_args,
body: Some(sorted_statements),
}
+ */
}
fn normalize_variable_decls(mut func: Vec<ExpandedStatement>) -> Vec<ExpandedStatement> {
@@ -350,7 +359,7 @@ fn add_types_to_statements(
func: Vec<UnadornedStatement>,
fn_defs: &GlobalFnDeclResolver,
id_defs: &NumericIdResolver,
-) -> Vec<UnadornedStatement> {
+) -> Vec<TypedStatement> {
func.into_iter()
.map(|s| {
match s {
@@ -359,7 +368,7 @@ fn add_types_to_statements(
let fn_def = fn_defs.get_fn_decl(call.func);
let ret_params = to_resolved_fn_args(call.ret_params, &*fn_def.ret_vals);
let param_list = to_resolved_fn_args(call.param_list, &*fn_def.params);
- let resolved_call = ResolvedCall {
+ let resolved_call: ResolvedCall<TypedArgParams> = ResolvedCall {
uniform: call.uniform,
ret_params,
func: call.func,
@@ -367,18 +376,13 @@ fn add_types_to_statements(
};
Statement::Call(resolved_call)
}
+ Statement::Instruction(ast::Instruction::Ld(d, arg)) => {
+ todo!()
+ }
Statement::Instruction(ast::Instruction::MovVector(dets, args)) => {
- // TODO fail on type mismatch
- let new_dets = match id_defs.get_type(*args.dst()) {
- Some(ast::Type::Vector(_, len)) => ast::MovVectorDetails {
- length: len,
- ..dets
- },
- _ => dets,
- };
- Statement::Instruction(ast::Instruction::MovVector(new_dets, args))
+ todo!()
}
- s => s,
+ s => todo!(),
}
})
.collect()
@@ -485,7 +489,7 @@ fn insert_mem_ssa_statements<'a, 'b>(
ast::MethodDecl::Kernel(_, in_params) => {
for p in in_params.iter_mut() {
let typ = ast::Type::from(p.v_type);
- let new_id = id_def.new_id(Some(typ));
+ let new_id = id_def.new_id(Some((StateSpace::Param, typ)));
result.push(Statement::Variable(ast::Variable {
align: p.align,
v_type: ast::VariableType::Param(p.v_type),
@@ -504,8 +508,12 @@ fn insert_mem_ssa_statements<'a, 'b>(
}
ast::MethodDecl::Func(out_params, _, in_params) => {
for p in in_params.iter_mut() {
+ let ss = match p.v_type {
+ ast::FnArgumentType::Reg(_) => StateSpace::Reg,
+ ast::FnArgumentType::Param(_) => StateSpace::Param,
+ };
let typ = ast::Type::from(p.v_type);
- let new_id = id_def.new_id(Some(typ));
+ let new_id = id_def.new_id(Some((ss, typ)));
let var_typ = ast::VariableType::from(p.v_type);
result.push(Statement::Variable(ast::Variable {
align: p.align,
@@ -548,7 +556,7 @@ fn insert_mem_ssa_statements<'a, 'b>(
dst: new_id,
src: out_param,
},
- typ.unwrap(),
+ typ.unwrap().1,
));
result.push(Statement::RetValue(d, new_id));
} else {
@@ -558,7 +566,10 @@ fn insert_mem_ssa_statements<'a, 'b>(
inst => insert_mem_ssa_statement_default(id_def, &mut result, inst),
},
Statement::Conditional(mut bra) => {
- let generated_id = id_def.new_id(Some(ast::Type::Scalar(ast::ScalarType::Pred)));
+ let generated_id = id_def.new_id(Some((
+ StateSpace::Reg,
+ ast::Type::Scalar(ast::ScalarType::Pred),
+ )));
result.push(Statement::LoadVar(
Arg2 {
dst: generated_id,
@@ -607,11 +618,12 @@ fn insert_mem_ssa_statement_default<'a, F: VisitVariable>(
let mut post_statements = Vec::new();
let new_statement = stmt.visit_variable(&mut |desc: ArgumentDescriptor<spirv::Word>, _| {
let id_type = match (id_def.get_type(desc.op), desc.sema) {
- (Some(t), ArgumentSemantics::ParamPtr) | (Some(t), ArgumentSemantics::Default) => t,
- (Some(t), ArgumentSemantics::Ptr) => ast::Type::Scalar(ast::ScalarType::B64),
+ (Some((_, t)), ArgumentSemantics::ParamPtr)
+ | (Some((_, t)), ArgumentSemantics::Default) => t,
+ (Some((_, t)), ArgumentSemantics::Ptr) => ast::Type::Scalar(ast::ScalarType::B64),
(None, _) => return desc.op,
};
- let generated_id = id_def.new_id(Some(id_type));
+ let generated_id = id_def.new_id(Some((StateSpace::Reg, id_type)));
if !desc.is_dst {
result.push(Statement::LoadVar(
Arg2 {
@@ -716,11 +728,13 @@ impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, ExpandedArgParams>
} else {
todo!()
};
- let id = self.id_def.new_id(Some(ast::Type::Scalar(scalar_t)));
+ let id = self
+ .id_def
+ .new_id(Some((StateSpace::Reg, ast::Type::Scalar(scalar_t))));
self.func.push(Statement::Constant(ConstantDefinition {
dst: id,
typ: scalar_t,
- value: x,
+ value: x as i64,
}));
id
}
@@ -732,13 +746,14 @@ impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, ExpandedArgParams>
} else {
todo!()
};
- let id_constant_stmt =
- self.id_def.new_id(Some(ast::Type::Scalar(scalar_t)));
- let result_id = self.id_def.new_id(Some(typ));
+ let id_constant_stmt = self
+ .id_def
+ .new_id(Some((StateSpace::Reg, ast::Type::Scalar(scalar_t))));
+ let result_id = self.id_def.new_id(Some((StateSpace::Reg, typ)));
self.func.push(Statement::Constant(ConstantDefinition {
dst: id_constant_stmt,
typ: scalar_t,
- value: offset as i128,
+ value: offset as i64,
}));
let int_type = ast::IntType::try_new(scalar_t).unwrap_or_else(|| todo!());
self.func.push(Statement::Instruction(
@@ -758,13 +773,14 @@ impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, ExpandedArgParams>
}
ArgumentSemantics::Ptr => {
let scalar_t = ast::ScalarType::U64;
- let id_constant_stmt =
- self.id_def.new_id(Some(ast::Type::Scalar(scalar_t)));
- let result_id = self.id_def.new_id(Some(typ));
+ let id_constant_stmt = self
+ .id_def
+ .new_id(Some((StateSpace::Reg, ast::Type::Scalar(scalar_t))));
+ let result_id = self.id_def.new_id(Some((StateSpace::Reg, typ)));
self.func.push(Statement::Constant(ConstantDefinition {
dst: id_constant_stmt,
typ: scalar_t,
- value: offset as i128,
+ value: offset as i64,
}));
let int_type = ast::IntType::U64;
self.func.push(Statement::Instruction(
@@ -810,9 +826,10 @@ impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, ExpandedArgParams>
desc: ArgumentDescriptor<(spirv::Word, u8)>,
(scalar_type, vec_len): (ast::MovVectorType, u8),
) -> spirv::Word {
- let new_id = self
- .id_def
- .new_id(Some(ast::Type::Vector(scalar_type.into(), vec_len)));
+ let new_id = self.id_def.new_id(Some((
+ StateSpace::Reg,
+ ast::Type::Vector(scalar_type.into(), vec_len),
+ )));
self.func.push(Statement::Composite(CompositeRead {
typ: scalar_type,
dst: new_id,
@@ -821,6 +838,14 @@ impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, ExpandedArgParams>
}));
new_id
}
+
+ fn mov_operand(
+ &mut self,
+ desc: ArgumentDescriptor<ast::Operand<spirv::Word>>,
+ typ: ast::Type,
+ ) -> spirv::Word {
+ self.operand(desc, typ)
+ }
}
/*
@@ -911,7 +936,7 @@ fn insert_implicit_conversions(
let mut did_vector_implicit = false;
let mut post_conv = None;
if inst_typ_is_bit {
- let src_type = id_def.get_type(arg.src).unwrap_or_else(|| todo!());
+ let src_type = id_def.get_type(arg.src).unwrap_or_else(|| todo!()).1;
if let ast::Type::Vector(_, _) = src_type {
arg.src = insert_conversion_src(
&mut result,
@@ -923,7 +948,7 @@ fn insert_implicit_conversions(
);
did_vector_implicit = true;
}
- let dst_type = id_def.get_type(arg.dst).unwrap_or_else(|| todo!());
+ let dst_type = id_def.get_type(arg.dst).unwrap_or_else(|| todo!()).1;
if let ast::Type::Vector(_, _) = src_type {
post_conv = Some(get_conversion_dst(
id_def,
@@ -1615,25 +1640,32 @@ fn expand_map_variables<'a, 'b>(
p.map(|p| p.map_variable(&mut |id| id_defs.get_id(id))),
i.map_variable(&mut |id| id_defs.get_id(id)),
))),
- ast::Statement::Variable(var) => match var.count {
- Some(count) => {
- for new_id in id_defs.add_defs(var.var.name, count, var.var.v_type.into()) {
+ ast::Statement::Variable(var) => {
+ let ss = match var.var.v_type {
+ ast::VariableType::Reg(_) => StateSpace::Reg,
+ ast::VariableType::Local(_) => StateSpace::Local,
+ ast::VariableType::Param(_) => StateSpace::ParamReg,
+ };
+ match var.count {
+ Some(count) => {
+ for new_id in id_defs.add_defs(var.var.name, count, ss, var.var.v_type.into()) {
+ result.push(Statement::Variable(ast::Variable {
+ align: var.var.align,
+ v_type: var.var.v_type,
+ name: new_id,
+ }))
+ }
+ }
+ None => {
+ let new_id = id_defs.add_def(var.var.name, Some((ss, var.var.v_type.into())));
result.push(Statement::Variable(ast::Variable {
align: var.var.align,
v_type: var.var.v_type,
name: new_id,
- }))
+ }));
}
}
- None => {
- let new_id = id_defs.add_def(var.var.name, Some(var.var.v_type.into()));
- result.push(Statement::Variable(ast::Variable {
- align: var.var.align,
- v_type: var.var.v_type,
- name: new_id,
- }));
- }
- },
+ }
}
}
@@ -1766,7 +1798,7 @@ struct FnStringIdResolver<'input, 'b> {
global_variables: &'b HashMap<Cow<'input, str>, spirv::Word>,
special_registers: &'b mut HashMap<PtxSpecialRegister, spirv::Word>,
variables: Vec<HashMap<Cow<'input, str>, spirv::Word>>,
- type_check: HashMap<u32, ast::Type>,
+ type_check: HashMap<u32, (StateSpace, ast::Type)>,
}
impl<'a, 'b> FnStringIdResolver<'a, 'b> {
@@ -1809,7 +1841,7 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> {
}
}
- fn add_def(&mut self, id: &'a str, typ: Option<ast::Type>) -> spirv::Word {
+ fn add_def(&mut self, id: &'a str, typ: Option<(StateSpace, ast::Type)>) -> spirv::Word {
let numeric_id = *self.current_id;
self.variables
.last_mut()
@@ -1827,6 +1859,7 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> {
&mut self,
base_id: &'a str,
count: u32,
+ ss: StateSpace,
typ: ast::Type,
) -> impl Iterator<Item = spirv::Word> {
let numeric_id = *self.current_id;
@@ -1835,7 +1868,7 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> {
.last_mut()
.unwrap()
.insert(Cow::Owned(format!("{}{}", base_id, i)), numeric_id + i);
- self.type_check.insert(numeric_id + i, typ);
+ self.type_check.insert(numeric_id + i, (ss, typ));
}
*self.current_id += count;
(0..count).into_iter().map(move |i| i + numeric_id)
@@ -1844,15 +1877,15 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> {
struct NumericIdResolver<'b> {
current_id: &'b mut spirv::Word,
- type_check: HashMap<u32, ast::Type>,
+ type_check: HashMap<u32, (StateSpace, ast::Type)>,
}
impl<'b> NumericIdResolver<'b> {
- fn get_type(&self, id: spirv::Word) -> Option<ast::Type> {
+ fn get_type(&self, id: spirv::Word) -> Option<(StateSpace, ast::Type)> {
self.type_check.get(&id).map(|x| *x)
}
- fn new_id(&mut self, typ: Option<ast::Type>) -> spirv::Word {
+ fn new_id(&mut self, typ: Option<(StateSpace, ast::Type)>) -> spirv::Word {
let new_id = *self.current_id;
if let Some(typ) = typ {
self.type_check.insert(new_id, typ);
@@ -1982,9 +2015,20 @@ type UnadornedStatement = Statement<ast::Instruction<NormalizedArgParams>, Norma
impl ast::ArgParams for NormalizedArgParams {
type ID = spirv::Word;
type Operand = ast::Operand<spirv::Word>;
+ type MemoryOperand = ast::Operand<spirv::Word>;
+ type CallOperand = ast::CallOperand<spirv::Word>;
+ type VecOperand = (spirv::Word, u8);
+}
+
+enum TypedArgParams {}
+impl ast::ArgParams for TypedArgParams {
+ type ID = spirv::Word;
+ type Operand = ast::Operand<spirv::Word>;
+ type MemoryOperand = MemoryOperand;
type CallOperand = ast::CallOperand<spirv::Word>;
type VecOperand = (spirv::Word, u8);
}
+type TypedStatement = Statement<ast::Instruction<TypedArgParams>, TypedArgParams>;
impl ArgParamsEx for NormalizedArgParams {
fn get_fn_decl<'a, 'b>(id: &Self::ID, decl: &'b GlobalFnDeclResolver<'a, 'b>) -> &'b FnDecl {
@@ -1992,6 +2036,27 @@ impl ArgParamsEx for NormalizedArgParams {
}
}
+#[derive(Copy, Clone)]
+pub enum StateSpace {
+ Reg,
+ Sreg,
+ Const,
+ Global,
+ Local,
+ Shared,
+ Param,
+ ParamReg,
+}
+
+#[derive(Copy, Clone)]
+pub enum MemoryOperand {
+ Reg(spirv::Word),
+ Address(spirv::Word),
+ RegOffset(spirv::Word, i32),
+ AddressOffset(spirv::Word, i32),
+ Imm(u32),
+}
+
enum ExpandedArgParams {}
type ExpandedStatement = Statement<ast::Instruction<ExpandedArgParams>, ExpandedArgParams>;
type ExpandedFunction<'a> = ast::Function<'a, ExpandedArgParams, ExpandedStatement>;
@@ -1999,6 +2064,7 @@ type ExpandedFunction<'a> = ast::Function<'a, ExpandedArgParams, ExpandedStateme
impl ast::ArgParams for ExpandedArgParams {
type ID = spirv::Word;
type Operand = spirv::Word;
+ type MemoryOperand = spirv::Word;
type CallOperand = spirv::Word;
type VecOperand = spirv::Word;
}
@@ -2012,6 +2078,11 @@ impl ArgParamsEx for ExpandedArgParams {
trait ArgumentMapVisitor<T: ArgParamsEx, U: ArgParamsEx> {
fn variable(&mut self, desc: ArgumentDescriptor<T::ID>, typ: Option<ast::Type>) -> U::ID;
fn operand(&mut self, desc: ArgumentDescriptor<T::Operand>, typ: ast::Type) -> U::Operand;
+ fn mov_operand(
+ &mut self,
+ desc: ArgumentDescriptor<T::MemoryOperand>,
+ typ: ast::Type,
+ ) -> U::MemoryOperand;
fn src_call_operand(
&mut self,
desc: ArgumentDescriptor<T::CallOperand>,
@@ -2035,9 +2106,15 @@ where
) -> spirv::Word {
self(desc, t)
}
+
fn operand(&mut self, desc: ArgumentDescriptor<spirv::Word>, t: ast::Type) -> spirv::Word {
self(desc, Some(t))
}
+
+ fn mov_operand(&mut self, desc: ArgumentDescriptor<spirv::Word>, t: ast::Type) -> spirv::Word {
+ self(desc, Some(t))
+ }
+
fn src_call_operand(
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
@@ -2045,6 +2122,7 @@ where
) -> spirv::Word {
self(desc, Some(t))
}
+
fn src_vec_operand(
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
@@ -2095,6 +2173,14 @@ where
) -> (spirv::Word, u8) {
(self(desc.op.0), desc.op.1)
}
+
+ fn mov_operand(
+ &mut self,
+ desc: ArgumentDescriptor<ast::Operand<&str>>,
+ typ: ast::Type,
+ ) -> ast::Operand<spirv::Word> {
+ self.operand(desc, typ)
+ }
}
struct ArgumentDescriptor<Op> {
@@ -2260,6 +2346,16 @@ where
desc.op.1,
)
}
+
+ fn mov_operand(
+ &mut self,
+ desc: ArgumentDescriptor<ast::Operand<spirv::Word>>,
+ typ: ast::Type,
+ ) -> ast::Operand<spirv::Word> {
+ <Self as ArgumentMapVisitor<NormalizedArgParams, NormalizedArgParams>>::operand(
+ self, desc, typ,
+ )
+ }
}
impl ast::Type {
@@ -2365,7 +2461,7 @@ struct CompositeRead {
struct ConstantDefinition {
pub dst: spirv::Word,
pub typ: ast::ScalarType,
- pub value: i128,
+ pub value: i64,
}
struct BrachCondition {
@@ -2534,7 +2630,7 @@ impl<T: ArgParamsEx> ast::Arg2St<T> {
is_param: bool,
) -> ast::Arg2St<U> {
ast::Arg2St {
- src1: visitor.operand(
+ src1: visitor.mov_operand(
ArgumentDescriptor {
op: self.src1,
is_dst: is_param,
@@ -3012,6 +3108,16 @@ impl From<ast::FnArgumentType> for ast::VariableType {
}
}
+impl<T> ast::Operand<T> {
+ fn underlying(&self) -> Option<&T> {
+ match self {
+ ast::Operand::Reg(r) => Some(r),
+ ast::Operand::RegOffset(r, _) => Some(r),
+ ast::Operand::Imm(_) => None,
+ }
+ }
+}
+
fn should_bitcast(instr: ast::Type, operand: ast::Type) -> bool {
match (instr, operand) {
(ast::Type::Scalar(inst), ast::Type::Scalar(operand)) => {
@@ -3053,7 +3159,7 @@ fn insert_with_conversions<T, ToInstruction: FnOnce(T) -> ast::Instruction<Expan
insert_with_conversions_pre_conv(func, id_def, &mut instr, pre_conv_src, &mut src);
insert_with_conversions_pre_conv(func, id_def, &mut instr, pre_conv_dst, &mut dst);
if post_conv.len() > 0 {
- let new_id = id_def.new_id(Some(post_conv[0].from));
+ let new_id = id_def.new_id(Some((StateSpace::Reg, post_conv[0].from)));
post_conv[0].src = new_id;
post_conv.last_mut().unwrap().dst = *dst(&mut instr);
*dst(&mut instr) = new_id;
@@ -3078,7 +3184,7 @@ fn insert_with_conversions_pre_conv<T>(
conv.src = *original_src;
}
if i == pre_conv_len - 1 {
- let new_id = id_def.new_id(Some(conv.to));
+ let new_id = id_def.new_id(Some((StateSpace::Reg, conv.to)));
conv.dst = new_id;
*original_src = new_id;
}
@@ -3095,7 +3201,7 @@ fn get_implicit_conversions_ld_dst<
should_convert: ShouldConvert,
in_reverse: bool,
) -> Option<ImplicitConversion> {
- let dst_type = id_def.get_type(dst).unwrap_or_else(|| todo!());
+ let dst_type = id_def.get_type(dst).unwrap_or_else(|| todo!()).1;
if let Some(conv) = should_convert(dst_type, instr_type) {
Some(ImplicitConversion {
src: u32::max_value(),
@@ -3115,7 +3221,7 @@ fn get_implicit_conversions_ld_src(
state_space: ast::LdStateSpace,
src: spirv::Word,
) -> Vec<ImplicitConversion> {
- let src_type = id_def.get_type(src).unwrap_or_else(|| todo!());
+ let src_type = id_def.get_type(src).unwrap_or_else(|| todo!()).1;
match state_space {
ast::LdStateSpace::Param => {
if src_type != instr_type {
@@ -3162,7 +3268,7 @@ fn get_implicit_conversions_ld_src(
kind: ConversionKind::Ptr(state_space),
});
if result.len() == 2 {
- let new_id = id_def.new_id(Some(new_src_type));
+ let new_id = id_def.new_id(Some((StateSpace::Reg, new_src_type)));
result[0].dst = new_id;
result[1].src = new_id;
result[1].from = new_src_type;
@@ -3221,9 +3327,9 @@ fn insert_implicit_conversions_ld_src_impl<
src: spirv::Word,
should_convert: ShouldConvert,
) -> spirv::Word {
- let src_type = id_def.get_type(src);
- if let Some(conv) = should_convert(src_type.unwrap(), instr_type) {
- insert_conversion_src(func, id_def, src, src_type.unwrap(), instr_type, conv)
+ let src_type = id_def.get_type(src).unwrap_or_else(|| todo!()).1;
+ if let Some(conv) = should_convert(src_type, instr_type) {
+ insert_conversion_src(func, id_def, src, src_type, instr_type, conv)
} else {
src
}
@@ -3263,7 +3369,7 @@ fn insert_conversion_src(
instr_type: ast::Type,
conv: ConversionKind,
) -> spirv::Word {
- let temp_src = id_def.new_id(Some(instr_type));
+ let temp_src = id_def.new_id(Some((StateSpace::Reg, instr_type)));
func.push(Statement::Conversion(ImplicitConversion {
src: src,
dst: temp_src,
@@ -3309,7 +3415,7 @@ fn get_conversion_dst(
kind: ConversionKind,
) -> ExpandedStatement {
let original_dst = *dst;
- let temp_dst = id_def.new_id(Some(instr_type));
+ let temp_dst = id_def.new_id(Some((StateSpace::Reg, instr_type)));
*dst = temp_dst;
Statement::Conversion(ImplicitConversion {
src: temp_dst,
@@ -3428,8 +3534,8 @@ fn insert_implicit_bitcasts(
Some(t) => t,
None => return desc.op,
};
- let id_actual_type = id_def.get_type(desc.op).unwrap();
- if should_bitcast(id_type_from_instr, id_def.get_type(desc.op).unwrap()) {
+ let id_actual_type = id_def.get_type(desc.op).unwrap().1;
+ if should_bitcast(id_type_from_instr, id_def.get_type(desc.op).unwrap().1) {
if desc.is_dst {
dst_coercion = Some(get_conversion_dst(
id_def,