diff options
author | Andrzej Janik <[email protected]> | 2020-11-18 01:50:29 +0100 |
---|---|---|
committer | Andrzej Janik <[email protected]> | 2020-11-18 01:50:29 +0100 |
commit | ac45c2bde42cef74d68b572920b0749f8eea6837 (patch) | |
tree | 1a425dcce73e5caafcb891e68e843ea3072282f2 | |
parent | 106ed74cb79b09a24964967ed7a9b4b13aff1b98 (diff) | |
download | ZLUDA-ac45c2bde42cef74d68b572920b0749f8eea6837.tar.gz ZLUDA-ac45c2bde42cef74d68b572920b0749f8eea6837.zip |
Optimize operations involving the statefuls
-rw-r--r-- | ptx/src/ast.rs | 6 | ||||
-rw-r--r-- | ptx/src/translate.rs | 129 |
2 files changed, 135 insertions, 0 deletions
diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs index 5a5f6be..367f060 100644 --- a/ptx/src/ast.rs +++ b/ptx/src/ast.rs @@ -766,6 +766,8 @@ sub_type! { LdStType { Scalar(LdStScalarType), Vector(LdStScalarType, u8), + // Used in generated code + Pointer(PointerType, LdStateSpace), } } @@ -774,6 +776,10 @@ impl From<LdStType> for PointerType { match t { LdStType::Scalar(t) => PointerType::Scalar(t.into()), LdStType::Vector(t, len) => PointerType::Vector(t.into(), len), + LdStType::Pointer(PointerType::Scalar(scalar_type), space) => { + PointerType::Pointer(scalar_type, space) + } + LdStType::Pointer(..) => unreachable!(), } } } diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs index f644a27..29fa93e 100644 --- a/ptx/src/translate.rs +++ b/ptx/src/translate.rs @@ -4397,6 +4397,105 @@ fn convert_to_stateful_memory_access<'a>( result.push(Statement::Variable(var));
}
}
+ Statement::Instruction(ast::Instruction::Cvta(
+ ast::CvtaDetails {
+ to: ast::CvtaStateSpace::Global,
+ size: ast::CvtaSize::U64,
+ from: ast::CvtaStateSpace::Generic,
+ },
+ arg,
+ )) if is_cvta_ptr_direct(&remapped_ids, &arg) => {
+ let new_dst = *remapped_ids.get(&arg.dst).unwrap();
+ let new_src = *remapped_ids.get(&arg.src.underlying().unwrap()).unwrap();
+ result.push(Statement::Instruction(ast::Instruction::Mov(
+ ast::MovDetails {
+ typ: ast::Type::Pointer(
+ ast::PointerType::Scalar(ast::ScalarType::U8),
+ ast::LdStateSpace::Global,
+ ),
+ src_is_address: false,
+ dst_width: 0,
+ src_width: 0,
+ relaxed_src2_conv: false,
+ },
+ ast::Arg2Mov::Normal(ast::Arg2MovNormal {
+ dst: ast::IdOrVector::Reg(new_dst),
+ src: ast::OperandOrVector::Reg(new_src),
+ }),
+ )));
+ }
+ Statement::Instruction(ast::Instruction::Ld(
+ details
+ @
+ ast::LdDetails {
+ state_space: ast::LdStateSpace::Param,
+ ..
+ },
+ arg,
+ )) if is_param_ld_ptr_direct(&remapped_ids, &func_args_ptr, &arg) => {
+ let new_dst = if let ast::IdOrVector::Reg(dst) = arg.dst {
+ *remapped_ids.get(&dst).unwrap()
+ } else {
+ return Err(TranslateError::Unreachable);
+ };
+ result.push(Statement::Instruction(ast::Instruction::Ld(
+ ast::LdDetails {
+ typ: ast::LdStType::Pointer(
+ ast::PointerType::Scalar(ast::ScalarType::U8),
+ ast::LdStateSpace::Global,
+ ),
+ ..details
+ },
+ ast::Arg2Ld {
+ src: arg.src,
+ dst: ast::IdOrVector::Reg(new_dst),
+ },
+ )));
+ }
+ Statement::Instruction(ast::Instruction::Ld(
+ details
+ @
+ ast::LdDetails {
+ state_space: ast::LdStateSpace::Global,
+ ..
+ },
+ arg,
+ )) if is_ldst_global_ptr_direct(&remapped_ids, &arg.src) => {
+ let new_src = if let ast::Operand::Reg(src) = arg.src {
+ *remapped_ids.get(&src).unwrap()
+ } else {
+ return Err(TranslateError::Unreachable);
+ };
+ result.push(Statement::Instruction(ast::Instruction::Ld(
+ details,
+ ast::Arg2Ld {
+ src: ast::Operand::Reg(new_src),
+ ..arg
+ },
+ )));
+ }
+ Statement::Instruction(ast::Instruction::St(
+ details
+ @
+ ast::StData {
+ state_space: ast::StStateSpace::Global,
+ ..
+ },
+ arg,
+ )) if is_ldst_global_ptr_direct(&remapped_ids, &arg.src1) => {
+ let new_src1 = if let ast::Operand::Reg(src1) = arg.src1 {
+ *remapped_ids.get(&src1).unwrap()
+ } else {
+ return Err(TranslateError::Unreachable);
+ };
+ result.push(Statement::Instruction(ast::Instruction::St(
+ details,
+ ast::Arg2St {
+ src1: ast::Operand::Reg(new_src1),
+ ..arg
+ },
+ )));
+ }
Statement::Instruction(inst) => {
let mut post_statements = Vec::new();
let new_statement =
@@ -4489,6 +4588,36 @@ fn convert_to_stateful_memory_access<'a>( Ok(result)
}
+/// Guard for the `Param`-space `Ld` arm: decides whether the load can be
+/// re-emitted with a pointer-typed destination.
+///
+/// Returns `true` only when the source operand exposes an underlying id that
+/// is a member of `func_args_ptr` and the destination is a plain register
+/// that has an entry in `remapped_ids`. Any other operand shape yields
+/// `false`.
+fn is_param_ld_ptr_direct(
+    remapped_ids: &HashMap<u32, u32>,
+    func_args_ptr: &HashSet<u32>,
+    arg: &ast::Arg2Ld<TypedArgParams>,
+) -> bool {
+    // No underlying id on the source means there is nothing to look up.
+    let src_id = match arg.src.underlying() {
+        Some(id) => id,
+        None => return false,
+    };
+    if let ast::IdOrVector::Reg(dst_id) = &arg.dst {
+        func_args_ptr.contains(src_id) && remapped_ids.contains_key(dst_id)
+    } else {
+        false
+    }
+}
+
+/// Guard for the `Cvta` arm: decides whether the conversion can be replaced
+/// by a move between remapped ids.
+///
+/// Returns `true` when the source operand exposes an underlying id and both
+/// that id and the destination id are keys of `remapped_ids`.
+fn is_cvta_ptr_direct(remapped_ids: &HashMap<u32, u32>, arg: &ast::Arg2<TypedArgParams>) -> bool {
+    arg.src.underlying().map_or(false, |src_id| {
+        remapped_ids.contains_key(src_id) && remapped_ids.contains_key(&arg.dst)
+    })
+}
+
+/// Guard for the `Global`-space `Ld`/`St` arms: reports whether the address
+/// operand is a bare register that has an entry in `remapped_ids`.
+///
+/// Only `ast::Operand::Reg` is accepted. The match arms guarded by this
+/// predicate destructure the operand as `Operand::Reg` and return
+/// `TranslateError::Unreachable` for any other shape, so the guard rejects
+/// those shapes itself rather than depending on which operand forms
+/// `underlying()` reports an id for. Rejected operands fall through to the
+/// generic instruction arm.
+fn is_ldst_global_ptr_direct(
+    remapped_ids: &HashMap<u32, u32>,
+    src: &ast::Operand<spirv::Word>,
+) -> bool {
+    match src {
+        ast::Operand::Reg(id) => remapped_ids.contains_key(id),
+        // Every non-`Reg` operand is left to the generic path.
+        _ => false,
+    }
+}
+
fn is_64_bit_integer(id_defs: &NumericIdResolver, id: spirv::Word) -> bool {
match id_defs.get_typed(id) {
Ok((ast::Type::Scalar(ast::ScalarType::U64), _))
|