diff options
author | Andrzej Janik <[email protected]> | 2021-03-03 00:59:47 +0100 |
---|---|---|
committer | Andrzej Janik <[email protected]> | 2021-03-03 00:59:47 +0100 |
commit | cdac38d572e2cd86036cbd85f753214b8e1a5172 (patch) | |
tree | e7cd1536e7a5e90b548f4159a7b196fb8e9326ce /ptx | |
parent | 648035a01a84cd87b7f917b277e4e2faad7bb731 (diff) | |
download | ZLUDA-cdac38d572e2cd86036cbd85f753214b8e1a5172.tar.gz ZLUDA-cdac38d572e2cd86036cbd85f753214b8e1a5172.zip |
Support kernel tuning directives
Diffstat (limited to 'ptx')
-rw-r--r-- | ptx/src/ast.rs | 9 | ||||
-rw-r--r-- | ptx/src/ptx.lalrpop | 18 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/add_tuning.ptx | 24 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/add_tuning.spvtxt | 48 | ||||
-rw-r--r-- | ptx/src/test/spirv_run/mod.rs | 1 | ||||
-rw-r--r-- | ptx/src/translate.rs | 65 |
6 files changed, 148 insertions, 17 deletions
diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index 8bbd1d7..22d378e 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -283,6 +283,7 @@ pub type KernelArgument<ID> = Variable<KernelArgumentType, ID>; pub struct Function<'a, ID, S> { pub func_directive: MethodDecl<'a, ID>, + pub tuning: Vec<TuningDirective>, pub body: Option<Vec<S>>, } @@ -1369,6 +1370,14 @@ bitflags! { } } +#[derive(Copy, Clone, PartialEq, Eq)] +pub enum TuningDirective { + MaxNReg(u32), + MaxNtid(u32, u32, u32), + ReqNtid(u32, u32, u32), + MinNCtaPerSm(u32), +} + #[cfg(test)] mod tests { use super::*; diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop index ce3e387..631d5ad 100644 --- a/ptx/src/ptx.lalrpop +++ b/ptx/src/ptx.lalrpop @@ -87,6 +87,9 @@ match { ".ltu", ".lu", ".max", + ".maxnreg", + ".maxntid", + ".minnctapersm", ".min", ".nan", ".NaN", @@ -100,6 +103,7 @@ match { ".reg", ".relaxed", ".release", + ".reqntid", ".rm", ".rmi", ".rn", @@ -356,15 +360,27 @@ AddressSize = { Function: ast::Function<'input, &'input str, ast::Statement<ast::ParsedArgParams<'input>>> = { LinkingDirectives <func_directive:MethodDecl> + <tuning:TuningDirective*> <body:FunctionBody> => ast::Function{<>} }; - + LinkingDirective: ast::LinkingDirective = { ".extern" => ast::LinkingDirective::EXTERN, ".visible" => ast::LinkingDirective::VISIBLE, ".weak" => ast::LinkingDirective::WEAK, }; +TuningDirective: ast::TuningDirective = { + ".maxnreg" <ncta:U32Num> => ast::TuningDirective::MaxNReg(ncta), + ".maxntid" <nx:U32Num> => ast::TuningDirective::MaxNtid(nx, 1, 1), + ".maxntid" <nx:U32Num> "," <ny:U32Num> => ast::TuningDirective::MaxNtid(nx, ny, 1), + ".maxntid" <nx:U32Num> "," <ny:U32Num> "," <nz:U32Num> => ast::TuningDirective::MaxNtid(nx, ny, nz), + ".reqntid" <nx:U32Num> => ast::TuningDirective::ReqNtid(nx, 1, 1), + ".reqntid" <nx:U32Num> "," <ny:U32Num> => ast::TuningDirective::ReqNtid(nx, ny, 1), + ".reqntid" <nx:U32Num> "," <ny:U32Num> "," <nz:U32Num> => ast::TuningDirective::ReqNtid(nx, ny, nz), + ".minnctapersm" <ncta:U32Num> => ast::TuningDirective::MinNCtaPerSm(ncta), +}; + LinkingDirectives: ast::LinkingDirective = { <ldirs:LinkingDirective*> => { ldirs.into_iter().fold(ast::LinkingDirective::NONE, |x, y| x | y) diff --git a/ptx/src/test/spirv_run/add_tuning.ptx b/ptx/src/test/spirv_run/add_tuning.ptx new file mode 100644 index 0000000..2a5dcf8 --- /dev/null +++ b/ptx/src/test/spirv_run/add_tuning.ptx @@ -0,0 +1,24 @@ +.version 6.5
+.target sm_30
+.address_size 64
+
+.visible .entry add_tuning(
+ .param .u64 input,
+ .param .u64 output
+)
+.maxntid 256, 1, 1
+.minnctapersm 4
+{
+ .reg .u64 in_addr;
+ .reg .u64 out_addr;
+ .reg .u64 temp;
+ .reg .u64 temp2;
+
+ ld.param.u64 in_addr, [input];
+ ld.param.u64 out_addr, [output];
+
+ ld.u64 temp, [in_addr];
+ add.u64 temp2, temp, 1;
+ st.u64 [out_addr], temp2;
+ ret;
+}
diff --git a/ptx/src/test/spirv_run/add_tuning.spvtxt b/ptx/src/test/spirv_run/add_tuning.spvtxt new file mode 100644 index 0000000..173e0d4 --- /dev/null +++ b/ptx/src/test/spirv_run/add_tuning.spvtxt @@ -0,0 +1,48 @@ + OpCapability GenericPointer + OpCapability Linkage + OpCapability Addresses + OpCapability Kernel + OpCapability Int8 + OpCapability Int16 + OpCapability Int64 + OpCapability Float16 + OpCapability Float64 + %23 = OpExtInstImport "OpenCL.std" + OpMemoryModel Physical64 OpenCL + OpEntryPoint Kernel %1 "add_tuning" + OpExecutionMode %1 MaxWorkgroupSizeINTEL 256 1 1 + %void = OpTypeVoid + %ulong = OpTypeInt 64 0 + %26 = OpTypeFunction %void %ulong %ulong +%_ptr_Function_ulong = OpTypePointer Function %ulong +%_ptr_Generic_ulong = OpTypePointer Generic %ulong + %ulong_1 = OpConstant %ulong 1 + %1 = OpFunction %void None %26 + %8 = OpFunctionParameter %ulong + %9 = OpFunctionParameter %ulong + %21 = OpLabel + %2 = OpVariable %_ptr_Function_ulong Function + %3 = OpVariable %_ptr_Function_ulong Function + %4 = OpVariable %_ptr_Function_ulong Function + %5 = OpVariable %_ptr_Function_ulong Function + %6 = OpVariable %_ptr_Function_ulong Function + %7 = OpVariable %_ptr_Function_ulong Function + OpStore %2 %8 + OpStore %3 %9 + %10 = OpLoad %ulong %2 Aligned 8 + OpStore %4 %10 + %11 = OpLoad %ulong %3 Aligned 8 + OpStore %5 %11 + %13 = OpLoad %ulong %4 + %19 = OpConvertUToPtr %_ptr_Generic_ulong %13 + %12 = OpLoad %ulong %19 Aligned 8 + OpStore %6 %12 + %15 = OpLoad %ulong %6 + %14 = OpIAdd %ulong %15 %ulong_1 + OpStore %7 %14 + %16 = OpLoad %ulong %5 + %17 = OpLoad %ulong %7 + %20 = OpConvertUToPtr %_ptr_Generic_ulong %16 + OpStore %20 %17 Aligned 8 + OpReturn + OpFunctionEnd diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index 91e6113..4178e2f 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -152,6 +152,7 @@ test_ptx!(shared_ptr_take_address, [97815231u64], [97815231u64]); // For now, we just make sure that it builds and links
test_ptx!(assertfail, [716523871u64], [716523872u64]);
test_ptx!(cvt_s64_s32, [-1i32], [-1i64]);
+test_ptx!(add_tuning, [2u64], [3u64]);
struct DisplayError<T: Debug> {
err: T,
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index 7efcaf6..da0cc07 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -589,7 +589,7 @@ fn emit_directives<'input>( for var in f.globals.iter() {
emit_variable(builder, map, var)?;
}
- emit_function_header(
+ let fn_id = emit_function_header(
builder,
map,
&id_defs,
@@ -600,6 +600,27 @@ fn emit_directives<'input>( &directives,
kernel_info,
)?;
+ for t in f.tuning.iter() {
+ match *t {
+ ast::TuningDirective::MaxNtid(nx, ny, nz) => {
+ builder.execution_mode(
+ fn_id,
+ spirv_headers::ExecutionMode::MaxWorkgroupSizeINTEL,
+ [nx, ny, nz],
+ );
+ }
+ ast::TuningDirective::ReqNtid(nx, ny, nz) => {
+ builder.execution_mode(
+ fn_id,
+ spirv_headers::ExecutionMode::LocalSize,
+ [nx, ny, nz],
+ );
+ }
+ // Too architecture specific
+ ast::TuningDirective::MaxNReg(..)
+ | ast::TuningDirective::MinNCtaPerSm(..) => {}
+ }
+ }
emit_function_body_ops(builder, map, opencl_id, &f_body)?;
builder.end_function()?;
if let (ast::MethodDecl::Func(_, fn_id, _), Some(name)) =
@@ -729,6 +750,7 @@ fn convert_dynamic_shared_memory_usage<'input>( body: Some(statements),
import_as,
spirv_decl,
+ tuning,
}) => {
let call_key = MethodName::new(&func_decl);
let statements = statements
@@ -752,6 +774,7 @@ fn convert_dynamic_shared_memory_usage<'input>( body: Some(statements),
import_as,
spirv_decl,
+ tuning,
})
}
directive => directive,
@@ -770,6 +793,7 @@ fn convert_dynamic_shared_memory_usage<'input>( body: Some(statements),
import_as,
mut spirv_decl,
+ tuning,
}) => {
if !methods_using_extern_shared.contains(&spirv_decl.name) {
return Directive::Method(Function {
@@ -778,6 +802,7 @@ fn convert_dynamic_shared_memory_usage<'input>( body: Some(statements),
import_as,
spirv_decl,
+ tuning,
});
}
let shared_id_param = new_id();
@@ -827,6 +852,7 @@ fn convert_dynamic_shared_memory_usage<'input>( body: Some(new_statements),
import_as,
spirv_decl,
+ tuning,
})
}
directive => directive,
@@ -1044,9 +1070,7 @@ fn emit_builtins( builder.decorate(
id,
spirv::Decoration::BuiltIn,
- [dr::Operand::BuiltIn(reg.get_builtin())]
- .iter()
- .cloned(),
+ [dr::Operand::BuiltIn(reg.get_builtin())].iter().cloned(),
);
}
}
@@ -1061,7 +1085,7 @@ fn emit_function_header<'a>( call_map: &HashMap<&'a str, HashSet<spirv::Word>>,
direcitves: &[Directive],
kernel_info: &mut HashMap<String, KernelInfo>,
-) -> Result<(), TranslateError> {
+) -> Result<spirv::Word, TranslateError> {
if let MethodName::Kernel(name) = func_decl.name {
let input_args = if !func_decl.uses_shared_mem {
func_decl.input.as_slice()
@@ -1143,7 +1167,7 @@ fn emit_function_header<'a>( let result_type = map.get_or_add(builder, SpirvType::from(input.v_type.clone()));
builder.function_parameter(Some(input.name), result_type)?;
}
- Ok(())
+ Ok(fn_id)
}
fn emit_capabilities(builder: &mut dr::Builder) {
@@ -1235,7 +1259,14 @@ fn translate_function<'a>( _ => None,
};
let (str_resolver, fn_resolver, fn_decl) = id_defs.start_fn(&f.func_directive)?;
- let mut func = to_ssa(ptx_impl_imports, str_resolver, fn_resolver, fn_decl, f.body)?;
+ let mut func = to_ssa(
+ ptx_impl_imports,
+ str_resolver,
+ fn_resolver,
+ fn_decl,
+ f.body,
+ f.tuning,
+ )?;
func.import_as = import_as;
if func.import_as.is_some() {
ptx_impl_imports.insert(
@@ -1293,6 +1324,7 @@ fn to_ssa<'input, 'b>( fn_defs: GlobalFnDeclResolver<'input, 'b>,
f_args: ast::MethodDecl<'input, spirv::Word>,
f_body: Option<Vec<ast::Statement<ast::ParsedArgParams<'input>>>>,
+ tuning: Vec<ast::TuningDirective>,
) -> Result<Function<'input>, TranslateError> {
let mut spirv_decl = SpirvMethodDecl::new(&f_args);
let f_body = match f_body {
@@ -1304,6 +1336,7 @@ fn to_ssa<'input, 'b>( globals: Vec::new(),
import_as: None,
spirv_decl,
+ tuning,
})
}
};
@@ -1335,6 +1368,7 @@ fn to_ssa<'input, 'b>( body: Some(f_body),
import_as: None,
spirv_decl,
+ tuning,
})
}
@@ -1716,6 +1750,7 @@ fn to_ptx_impl_atomic_call( body: None,
import_as: Some(entry.key().clone()),
spirv_decl,
+ tuning: Vec::new(),
};
entry.insert(Directive::Method(func));
fn_id
@@ -1809,6 +1844,7 @@ fn to_ptx_impl_bfe_call( body: None,
import_as: Some(entry.key().clone()),
spirv_decl,
+ tuning: Vec::new(),
};
entry.insert(Directive::Method(func));
fn_id
@@ -1907,6 +1943,7 @@ fn to_ptx_impl_bfi_call( body: None,
import_as: Some(entry.key().clone()),
spirv_decl,
+ tuning: Vec::new(),
};
entry.insert(Directive::Method(func));
fn_id
@@ -4112,16 +4149,11 @@ fn struct2_bitcast_to_wide( dst_type_id: spirv::Word,
src: spirv::Word,
) -> Result<(), dr::Error> {
- let low_bits =
- builder.composite_extract(instruction_type, None, src, [0].iter().copied())?;
- let high_bits =
- builder.composite_extract(instruction_type, None, src, [1].iter().copied())?;
+ let low_bits = builder.composite_extract(instruction_type, None, src, [0].iter().copied())?;
+ let high_bits = builder.composite_extract(instruction_type, None, src, [1].iter().copied())?;
let vector_type = map.get_or_add(builder, SpirvType::Vector(base_type_key, 2));
- let vector = builder.composite_construct(
- vector_type,
- None,
- [low_bits, high_bits].iter().copied(),
- )?;
+ let vector =
+ builder.composite_construct(vector_type, None, [low_bits, high_bits].iter().copied())?;
builder.bitcast(dst_type_id, Some(dst), vector)?;
Ok(())
}
@@ -5668,6 +5700,7 @@ struct Function<'input> { pub globals: Vec<ast::Variable<ast::VariableType, spirv::Word>>,
pub body: Option<Vec<ExpandedStatement>>,
import_as: Option<String>,
+ tuning: Vec<ast::TuningDirective>,
}
pub trait ArgumentMapVisitor<T: ArgParamsEx, U: ArgParamsEx> {
|