From c92abba2bb884a4dba8ca5e3df4d46a30878f27e Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Mon, 23 Sep 2024 16:33:46 +0200 Subject: Refactor compilation passes (#270) The overarching goal is to refactor all passes so they are module-scoped and not function-scoped. Additionally, make improvements to the most egregiously buggy/unfit passes (so the code is ready for the next major features: linking, ftz handling) and continue adding more code to the LLVM backend --- .github/workflows/rust.yml | 60 --- ptx/Cargo.toml | 3 + ptx/src/pass/convert_to_stateful_memory_access.rs | 2 +- ptx/src/pass/deparamize_functions.rs | 141 +++++++ ptx/src/pass/emit_llvm.rs | 224 ++++++++---- ptx/src/pass/emit_spirv.rs | 3 +- ptx/src/pass/expand_arguments.rs | 4 +- ptx/src/pass/expand_operands.rs | 289 +++++++++++++++ ptx/src/pass/extract_globals.rs | 1 - ptx/src/pass/fix_special_registers2.rs | 209 +++++++++++ ptx/src/pass/hoist_globals.rs | 45 +++ ptx/src/pass/insert_explicit_load_store.rs | 338 +++++++++++++++++ ptx/src/pass/insert_implicit_conversions.rs | 26 +- ptx/src/pass/insert_implicit_conversions2.rs | 426 ++++++++++++++++++++++ ptx/src/pass/insert_mem_ssa_statements.rs | 2 +- ptx/src/pass/mod.rs | 365 +++++++++++++++++- ptx/src/pass/normalize_identifiers2.rs | 199 ++++++++++ ptx/src/pass/normalize_labels.rs | 1 + ptx/src/pass/normalize_predicates2.rs | 84 +++++ ptx/src/pass/resolve_function_pointers.rs | 82 +++++ ptx/src/test/spirv_run/mod.rs | 2 +- ptx_parser/src/ast.rs | 30 +- ptx_parser/src/lib.rs | 1 - 23 files changed, 2365 insertions(+), 172 deletions(-) delete mode 100644 .github/workflows/rust.yml create mode 100644 ptx/src/pass/deparamize_functions.rs create mode 100644 ptx/src/pass/expand_operands.rs create mode 100644 ptx/src/pass/fix_special_registers2.rs create mode 100644 ptx/src/pass/hoist_globals.rs create mode 100644 ptx/src/pass/insert_explicit_load_store.rs create mode 100644 ptx/src/pass/insert_implicit_conversions2.rs create mode 100644 ptx/src/pass/normalize_identifiers2.rs create mode 100644 ptx/src/pass/normalize_predicates2.rs create mode 100644 ptx/src/pass/resolve_function_pointers.rs diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml deleted file mode 100644 index 26ede14..0000000 --- a/.github/workflows/rust.yml +++ /dev/null @@ -1,60 +0,0 @@ -name: Rust - -on: - push: - branches: [ master ] - pull_request: - branches: [ master ] - -env: - CARGO_TERM_COLOR: always - -jobs: - build_lin: - name: Build and publish (Linux) - runs-on: ubuntu-20.04 - steps: - - uses: actions/checkout@v2 - with: - submodules: true - - name: Install GPU drivers - run: | - sudo apt-get install -y gpg-agent wget - wget -qO - https://repositories.intel.com/graphics/intel-graphics.key | sudo apt-key add - - sudo apt-add-repository 'deb [arch=amd64] https://repositories.intel.com/graphics/ubuntu focal main' - sudo apt-get update - sudo apt-get install intel-opencl-icd intel-level-zero-gpu level-zero intel-media-va-driver-non-free libmfx1 libigc-dev intel-igc-cm libigdfcl-dev libigfxcmrt-dev level-zero-dev ocl-icd-opencl-dev - - name: Build - run: cargo build --workspace --verbose --release - - name: Rename to libcuda.so - run: | - mv target/release/libnvcuda.so target/release/libcuda.so - ln -s libcuda.so target/release/libcuda.so.1 - - uses: actions/upload-artifact@v2 - with: - name: Linux - path: | - target/release/libcuda.so - target/release/libcuda.so.1 - target/release/libnvml.so - build_win: - name: Build and publish (Windows) - runs-on: windows-latest - steps: - - uses: actions/checkout@v2 - with: - submodules: true - - name: Build - run: cargo build --workspace --verbose --release - - uses: actions/upload-artifact@v2 - with: - name: Windows - path: | - target/release/nvcuda.dll - target/release/nvml.dll - target/release/zluda_redirect.dll - target/release/zluda_with.exe - target/release/zluda_dump.dll - # TODO(take-cheeze): Support testing - # - name: Run tests - # run: cargo test --verbose diff --git a/ptx/Cargo.toml b/ptx/Cargo.toml index 2e2995f..e2c4ff8 100644 --- a/ptx/Cargo.toml +++ b/ptx/Cargo.toml @@ -17,6 +17,9 @@ thiserror = "1.0" bit-vec = "0.6" half ="1.6" bitflags = "1.2" +rustc-hash = "2.0.0" +strum = "0.26" +strum_macros = "0.26" [dependencies.lalrpop-util] version = "0.19.12" diff --git a/ptx/src/pass/convert_to_stateful_memory_access.rs b/ptx/src/pass/convert_to_stateful_memory_access.rs index 455a8c2..3b8fa93 100644 --- a/ptx/src/pass/convert_to_stateful_memory_access.rs +++ b/ptx/src/pass/convert_to_stateful_memory_access.rs @@ -489,7 +489,7 @@ fn convert_to_stateful_memory_access_postprocess( let (old_operand_type, old_operand_space, _) = id_defs.get_typed(operand)?; let converting_id = id_defs .register_intermediate(Some((old_operand_type.clone(), old_operand_space))); - let kind = if space_is_compatible(new_operand_space, ast::StateSpace::Reg) { + let kind = if new_operand_space == ast::StateSpace::Reg { ConversionKind::Default } else { ConversionKind::PtrToPtr diff --git a/ptx/src/pass/deparamize_functions.rs b/ptx/src/pass/deparamize_functions.rs new file mode 100644 index 0000000..04c8831 --- /dev/null +++ b/ptx/src/pass/deparamize_functions.rs @@ -0,0 +1,141 @@ +use std::collections::BTreeMap; + +use super::*; + +pub(super) fn run<'a, 'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + directives: Vec, SpirvWord>>, +) -> Result, SpirvWord>>, TranslateError> { + directives + .into_iter() + .map(|directive| run_directive(resolver, directive)) + .collect::, _>>() +} + +fn run_directive<'input>( + resolver: &mut GlobalStringIdentResolver2, + directive: Directive2<'input, ast::Instruction, SpirvWord>, +) -> Result, SpirvWord>, TranslateError> { + Ok(match directive { + var @ Directive2::Variable(..) => var, + Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?), + }) +} + +fn run_method<'input>( + resolver: &mut GlobalStringIdentResolver2, + mut method: Function2<'input, ast::Instruction, SpirvWord>, +) -> Result, SpirvWord>, TranslateError> { + if method.func_decl.name.is_kernel() { + return Ok(method); + } + let is_declaration = method.body.is_none(); + let mut body = Vec::new(); + let mut remap_returns = Vec::new(); + for arg in method.func_decl.return_arguments.iter_mut() { + match arg.state_space { + ptx_parser::StateSpace::Param => { + arg.state_space = ptx_parser::StateSpace::Reg; + let old_name = arg.name; + arg.name = resolver.register_unnamed(Some((arg.v_type.clone(), arg.state_space))); + if is_declaration { + continue; + } + remap_returns.push((old_name, arg.name, arg.v_type.clone())); + body.push(Statement::Variable(ast::Variable { + align: None, + name: old_name, + v_type: arg.v_type.clone(), + state_space: ptx_parser::StateSpace::Param, + array_init: Vec::new(), + })); + } + ptx_parser::StateSpace::Reg => {} + _ => return Err(error_unreachable()), + } + } + for arg in method.func_decl.input_arguments.iter_mut() { + match arg.state_space { + ptx_parser::StateSpace::Param => { + arg.state_space = ptx_parser::StateSpace::Reg; + let old_name = arg.name; + arg.name = resolver.register_unnamed(Some((arg.v_type.clone(), arg.state_space))); + if is_declaration { + continue; + } + body.push(Statement::Variable(ast::Variable { + align: None, + name: old_name, + v_type: arg.v_type.clone(), + state_space: ptx_parser::StateSpace::Param, + array_init: Vec::new(), + })); + body.push(Statement::Instruction(ast::Instruction::St { + data: ast::StData { + qualifier: ast::LdStQualifier::Weak, + state_space: ast::StateSpace::Param, + caching: ast::StCacheOperator::Writethrough, + typ: arg.v_type.clone(), + }, + arguments: ast::StArgs { + src1: old_name, + src2: arg.name, + }, + })); + } + ptx_parser::StateSpace::Reg => {} + _ => return Err(error_unreachable()), + } + } + if remap_returns.is_empty() { + return Ok(method); + } + let body = method + .body + .map(|statements| { + for statement in statements { + run_statement(&remap_returns, &mut body, statement)?; + } + Ok::<_, TranslateError>(body) + }) + .transpose()?; + Ok(Function2 { + func_decl: method.func_decl, + globals: method.globals, + body, + import_as: method.import_as, + tuning: method.tuning, + linkage: method.linkage, + }) +} + +fn run_statement<'input>( + remap_returns: &Vec<(SpirvWord, SpirvWord, ast::Type)>, + result: &mut Vec, SpirvWord>>, + statement: Statement, SpirvWord>, +) -> Result<(), TranslateError> { + match statement { + Statement::Instruction(ast::Instruction::Ret { .. }) => { + for (old_name, new_name, type_) in remap_returns.iter().cloned() { + result.push(Statement::Instruction(ast::Instruction::Ld { + data: ast::LdDetails { + qualifier: ast::LdStQualifier::Weak, + state_space: ast::StateSpace::Reg, + caching: ast::LdCacheOperator::Cached, + typ: type_, + non_coherent: false, + }, + arguments: ast::LdArgs { + dst: new_name, + src: old_name, + }, + })); + } + result.push(statement); + } + statement => { + result.push(statement); + } + } + Ok(()) +} diff --git a/ptx/src/pass/emit_llvm.rs b/ptx/src/pass/emit_llvm.rs index 44debba..235ad7d 100644 --- a/ptx/src/pass/emit_llvm.rs +++ b/ptx/src/pass/emit_llvm.rs @@ -164,17 +164,16 @@ impl Deref for MemoryBuffer { } pub(super) fn run<'input>( - id_defs: &GlobalStringIdResolver<'input>, - call_map: MethodsCallMap<'input>, - directives: Vec>, + id_defs: GlobalStringIdentResolver2<'input>, + directives: Vec, SpirvWord>>, ) -> Result { let context = Context::new(); let module = Module::new(&context, LLVM_UNNAMED); - let mut emit_ctx = ModuleEmitContext::new(&context, &module, id_defs); + let mut emit_ctx = ModuleEmitContext::new(&context, &module, &id_defs); for directive in directives { match directive { - Directive::Variable(..) => todo!(), - Directive::Method(method) => emit_ctx.emit_method(method)?, + Directive2::Variable(..) => todo!(), + Directive2::Method(method) => emit_ctx.emit_method(method)?, } } module.write_to_stderr(); @@ -188,7 +187,7 @@ struct ModuleEmitContext<'a, 'input> { context: LLVMContextRef, module: LLVMModuleRef, builder: Builder, - id_defs: &'a GlobalStringIdResolver<'input>, + id_defs: &'a GlobalStringIdentResolver2<'input>, resolver: ResolveIdent, } @@ -196,7 +195,7 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> { fn new( context: &Context, module: &Module, - id_defs: &'a GlobalStringIdResolver<'input>, + id_defs: &'a GlobalStringIdentResolver2<'input>, ) -> Self { ModuleEmitContext { context: context.get(), @@ -215,26 +214,50 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> { LLVMCallConv::LLVMCCallConv as u32 } - fn emit_method(&mut self, method: Function<'input>) -> Result<(), TranslateError> { - let func_decl = method.func_decl.borrow(); + fn emit_method( + &mut self, + method: Function2<'input, ast::Instruction, SpirvWord>, + ) -> Result<(), TranslateError> { + let func_decl = method.func_decl; let name = method .import_as .as_deref() - .unwrap_or_else(|| match func_decl.name { - ast::MethodName::Kernel(name) => name, - ast::MethodName::Func(id) => self.id_defs.reverse_variables[&id], - }); + .or_else(|| match func_decl.name { + ast::MethodName::Kernel(name) => Some(name), + ast::MethodName::Func(id) => self.id_defs.ident_map[&id].name.as_deref(), + }) + .ok_or_else(|| error_unreachable())?; let name = CString::new(name).map_err(|_| error_unreachable())?; - let fn_type = self.function_type( + let fn_type = get_function_type( + self.context, func_decl.return_arguments.iter().map(|v| &v.v_type), - func_decl.input_arguments.iter().map(|v| &v.v_type), - ); + func_decl + .input_arguments + .iter() + .map(|v| get_input_argument_type(self.context, &v.v_type, v.state_space)), + )?; let fn_ = unsafe { LLVMAddFunction(self.module, name.as_ptr(), fn_type) }; + if let ast::MethodName::Func(name) = func_decl.name { + self.resolver.register(name, fn_); + } for (i, param) in func_decl.input_arguments.iter().enumerate() { let value = unsafe { LLVMGetParam(fn_, i as u32) }; let name = self.resolver.get_or_add(param.name); unsafe { LLVMSetValueName2(value, name.as_ptr().cast(), name.len()) }; self.resolver.register(param.name, value); + if func_decl.name.is_kernel() { + let attr_kind = unsafe { + LLVMGetEnumAttributeKindForName(b"byref".as_ptr().cast(), b"byref".len()) + }; + let attr = unsafe { + LLVMCreateTypeAttribute( + self.context, + attr_kind, + get_type(self.context, ¶m.v_type)?, + ) + }; + unsafe { LLVMAddAttributeAtIndex(fn_, i as u32 + 1, attr) }; + } } let call_conv = if func_decl.name.is_kernel() { Self::kernel_call_convention() @@ -258,66 +281,19 @@ impl<'a, 'input> ModuleEmitContext<'a, 'input> { } Ok(()) } +} - fn function_type( - &self, - return_args: impl ExactSizeIterator, - input_args: impl ExactSizeIterator, - ) -> LLVMTypeRef { - if return_args.len() == 0 { - let mut input_args = input_args - .map(|type_| match type_ { - ast::Type::Scalar(scalar) => match scalar { - ast::ScalarType::Pred => { - unsafe { LLVMInt1TypeInContext(self.context) } - } - ast::ScalarType::S8 | ast::ScalarType::B8 | ast::ScalarType::U8 => { - unsafe { LLVMInt8TypeInContext(self.context) } - } - ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => { - unsafe { LLVMInt16TypeInContext(self.context) } - } - ast::ScalarType::S32 | ast::ScalarType::B32 | ast::ScalarType::U32 => { - unsafe { LLVMInt32TypeInContext(self.context) } - } - ast::ScalarType::U64 | ast::ScalarType::S64 | ast::ScalarType::B64 => { - unsafe { LLVMInt64TypeInContext(self.context) } - } - ast::ScalarType::B128 => { - unsafe { LLVMInt128TypeInContext(self.context) } - } - ast::ScalarType::F16 => { - unsafe { LLVMHalfTypeInContext(self.context) } - } - ast::ScalarType::F32 => { - unsafe { LLVMFloatTypeInContext(self.context) } - } - ast::ScalarType::F64 => { - unsafe { LLVMDoubleTypeInContext(self.context) } - } - ast::ScalarType::BF16 => { - unsafe { LLVMBFloatTypeInContext(self.context) } - } - ast::ScalarType::U16x2 => todo!(), - ast::ScalarType::S16x2 => todo!(), - ast::ScalarType::F16x2 => todo!(), - ast::ScalarType::BF16x2 => todo!(), - }, - ast::Type::Vector(_, _) => todo!(), - ast::Type::Array(_, _, _) => todo!(), - ast::Type::Pointer(_, _) => todo!(), - }) - .collect::>(); - return unsafe { - LLVMFunctionType( - LLVMVoidTypeInContext(self.context), - input_args.as_mut_ptr(), - input_args.len() as u32, - 0, - ) - }; +fn get_input_argument_type( + context: LLVMContextRef, + v_type: &ptx_parser::Type, + state_space: ptx_parser::StateSpace, +) -> Result { + match state_space { + ptx_parser::StateSpace::ParamEntry => { + Ok(unsafe { LLVMPointerTypeInContext(context, get_state_space(state_space)?) }) } - todo!() + ptx_parser::StateSpace::Reg => get_type(context, v_type), + _ => return Err(error_unreachable()), } } @@ -326,7 +302,7 @@ struct MethodEmitContext<'a, 'input> { module: LLVMModuleRef, method: LLVMValueRef, builder: LLVMBuilderRef, - id_defs: &'a GlobalStringIdResolver<'input>, + id_defs: &'a GlobalStringIdentResolver2<'input>, variables_builder: Builder, resolver: &'a mut ResolveIdent, } @@ -365,6 +341,7 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { Statement::PtrAccess(_) => todo!(), Statement::RepackVector(_) => todo!(), Statement::FunctionPointer(_) => todo!(), + Statement::VectorAccess(_) => todo!(), }) } @@ -414,7 +391,7 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { inst: ast::Instruction, ) -> Result<(), TranslateError> { match inst { - ast::Instruction::Mov { data, arguments } => todo!(), + ast::Instruction::Mov { data, arguments } => self.emit_mov(data, arguments), ast::Instruction::Ld { data, arguments } => self.emit_ld(data, arguments), ast::Instruction::Add { data, arguments } => self.emit_add(data, arguments), ast::Instruction::St { data, arguments } => self.emit_st(data, arguments), @@ -425,7 +402,7 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { ast::Instruction::Or { data, arguments } => todo!(), ast::Instruction::And { data, arguments } => todo!(), ast::Instruction::Bra { arguments } => todo!(), - ast::Instruction::Call { data, arguments } => todo!(), + ast::Instruction::Call { data, arguments } => self.emit_call(data, arguments), ast::Instruction::Cvt { data, arguments } => todo!(), ast::Instruction::Shr { data, arguments } => todo!(), ast::Instruction::Shl { data, arguments } => todo!(), @@ -563,6 +540,70 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { fn emit_ret(&self, _data: ptx_parser::RetData) { unsafe { LLVMBuildRetVoid(self.builder) }; } + + fn emit_call( + &mut self, + data: ptx_parser::CallDetails, + arguments: ptx_parser::CallArgs, + ) -> Result<(), TranslateError> { + if cfg!(debug_assertions) { + for (_, space) in data.return_arguments.iter() { + if *space != ast::StateSpace::Reg { + panic!() + } + } + for (_, space) in data.input_arguments.iter() { + if *space != ast::StateSpace::Reg { + panic!() + } + } + } + let name = match (&*data.return_arguments, &*arguments.return_arguments) { + ([], []) => LLVM_UNNAMED.as_ptr(), + ([(type_, _)], [dst]) => self.resolver.get_or_add_raw(*dst), + _ => todo!(), + }; + let type_ = get_function_type( + self.context, + data.return_arguments.iter().map(|(type_, space)| type_), + data.input_arguments + .iter() + .map(|(type_, space)| get_input_argument_type(self.context, &type_, *space)), + )?; + let mut input_arguments = arguments + .input_arguments + .iter() + .map(|arg| self.resolver.value(*arg)) + .collect::, _>>()?; + let llvm_fn = unsafe { + LLVMBuildCall2( + self.builder, + type_, + self.resolver.value(arguments.func)?, + input_arguments.as_mut_ptr(), + input_arguments.len() as u32, + name, + ) + }; + match &*arguments.return_arguments { + [] => {} + [name] => { + self.resolver.register(*name, llvm_fn); + } + _ => todo!(), + } + Ok(()) + } + + fn emit_mov( + &mut self, + _data: ptx_parser::MovDetails, + arguments: ptx_parser::MovArgs, + ) -> Result<(), TranslateError> { + self.resolver + .register(arguments.dst, self.resolver.value(arguments.src)?); + Ok(()) + } } fn get_pointer_type<'ctx>( @@ -624,13 +665,34 @@ fn get_scalar_type(context: LLVMContextRef, type_: ast::ScalarType) -> LLVMTypeR } } +fn get_function_type<'a>( + context: LLVMContextRef, + mut return_args: impl ExactSizeIterator, + input_args: impl ExactSizeIterator>, +) -> Result { + let mut input_args: Vec<*mut llvm_zluda::LLVMType> = + input_args.collect::, _>>()?; + let return_type = match return_args.len() { + 0 => unsafe { LLVMVoidTypeInContext(context) }, + 1 => get_type(context, return_args.next().unwrap())?, + _ => todo!(), + }; + Ok(unsafe { + LLVMFunctionType( + return_type, + input_args.as_mut_ptr(), + input_args.len() as u32, + 0, + ) + }) +} + fn get_state_space(space: ast::StateSpace) -> Result { match space { ast::StateSpace::Reg => Ok(PRIVATE_ADDRESS_SPACE), ast::StateSpace::Generic => Ok(GENERIC_ADDRESS_SPACE), - ast::StateSpace::Sreg => Ok(PRIVATE_ADDRESS_SPACE), ast::StateSpace::Param => Err(TranslateError::Todo), - ast::StateSpace::ParamEntry => Err(TranslateError::Todo), + ast::StateSpace::ParamEntry => Ok(CONSTANT_ADDRESS_SPACE), ast::StateSpace::ParamFunc => Err(TranslateError::Todo), ast::StateSpace::Local => Ok(PRIVATE_ADDRESS_SPACE), ast::StateSpace::Global => Ok(GLOBAL_ADDRESS_SPACE), @@ -647,7 +709,7 @@ struct ResolveIdent { } impl ResolveIdent { - fn new<'input>(_id_defs: &GlobalStringIdResolver<'input>) -> Self { + fn new<'input>(_id_defs: &GlobalStringIdentResolver2<'input>) -> Self { ResolveIdent { words: HashMap::new(), values: HashMap::new(), diff --git a/ptx/src/pass/emit_spirv.rs b/ptx/src/pass/emit_spirv.rs index 5147b79..120a477 100644 --- a/ptx/src/pass/emit_spirv.rs +++ b/ptx/src/pass/emit_spirv.rs @@ -469,7 +469,6 @@ fn space_to_spirv(this: ast::StateSpace) -> spirv::StorageClass { ast::StateSpace::Shared => spirv::StorageClass::Workgroup, ast::StateSpace::Param => spirv::StorageClass::Function, ast::StateSpace::Reg => spirv::StorageClass::Function, - ast::StateSpace::Sreg => spirv::StorageClass::Input, ast::StateSpace::ParamEntry | ast::StateSpace::ParamFunc | ast::StateSpace::SharedCluster @@ -693,7 +692,6 @@ fn emit_variable<'input>( ast::StateSpace::Shared => (false, spirv::StorageClass::Workgroup), ast::StateSpace::Const => (false, spirv::StorageClass::UniformConstant), ast::StateSpace::Generic => todo!(), - ast::StateSpace::Sreg => todo!(), ast::StateSpace::ParamEntry | ast::StateSpace::ParamFunc | ast::StateSpace::SharedCluster @@ -1563,6 +1561,7 @@ fn emit_function_body_ops<'input>( builder.copy_object(vector_type.0, Some(repack.packed.0), temp_vec)?; } } + Statement::VectorAccess(vector_access) => todo!(), } } Ok(()) diff --git a/ptx/src/pass/expand_arguments.rs b/ptx/src/pass/expand_arguments.rs index d0c7c98..e496c75 100644 --- a/ptx/src/pass/expand_arguments.rs +++ b/ptx/src/pass/expand_arguments.rs @@ -63,9 +63,9 @@ impl<'a, 'b> FlattenArguments<'a, 'b> { } else { return Err(TranslateError::UntypedSymbol); }; - if state_space == ast::StateSpace::Reg || state_space == ast::StateSpace::Sreg { + if state_space == ast::StateSpace::Reg { let (reg_type, reg_space) = self.id_def.get_typed(reg)?; - if !space_is_compatible(reg_space, ast::StateSpace::Reg) { + if reg_space != ast::StateSpace::Reg { return Err(error_mismatched_type()); } let reg_scalar_type = match reg_type { diff --git a/ptx/src/pass/expand_operands.rs b/ptx/src/pass/expand_operands.rs new file mode 100644 index 0000000..3dabf40 --- /dev/null +++ b/ptx/src/pass/expand_operands.rs @@ -0,0 +1,289 @@ +use super::*; + +pub(super) fn run<'a, 'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + directives: Vec>, +) -> Result, SpirvWord>>, TranslateError> { + directives + .into_iter() + .map(|directive| run_directive(resolver, directive)) + .collect::, _>>() +} + +fn run_directive<'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + directive: Directive2< + 'input, + ast::Instruction>, + ast::ParsedOperand, + >, +) -> Result, SpirvWord>, TranslateError> { + Ok(match directive { + Directive2::Variable(linking, var) => Directive2::Variable(linking, var), + Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?), + }) +} + +fn run_method<'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + method: Function2< + 'input, + ast::Instruction>, + ast::ParsedOperand, + >, +) -> Result, SpirvWord>, TranslateError> { + let body = method + .body + .map(|statements| { + let mut result = Vec::with_capacity(statements.len()); + for statement in statements { + run_statement(resolver, &mut result, statement)?; + } + Ok::<_, TranslateError>(result) + }) + .transpose()?; + Ok(Function2 { + func_decl: method.func_decl, + globals: method.globals, + body, + import_as: method.import_as, + tuning: method.tuning, + linkage: method.linkage, + }) +} + +fn run_statement<'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + result: &mut Vec, SpirvWord>>, + statement: UnconditionalStatement, +) -> Result<(), TranslateError> { + let mut visitor = FlattenArguments::new(resolver, result); + let new_statement = statement.visit_map(&mut visitor)?; + visitor.result.push(new_statement); + Ok(()) +} + +struct FlattenArguments<'a, 'input> { + result: &'a mut Vec, + resolver: &'a mut GlobalStringIdentResolver2<'input>, + post_stmts: Vec, +} + +impl<'a, 'input> FlattenArguments<'a, 'input> { + fn new( + resolver: &'a mut GlobalStringIdentResolver2<'input>, + result: &'a mut Vec, + ) -> Self { + FlattenArguments { + result, + resolver, + post_stmts: Vec::new(), + } + } + + fn reg(&mut self, name: SpirvWord) -> Result { + Ok(name) + } + + fn reg_offset( + &mut self, + reg: SpirvWord, + offset: i32, + type_space: Option<(&ast::Type, ast::StateSpace)>, + _is_dst: bool, + ) -> Result { + let (type_, state_space) = if let Some((type_, state_space)) = type_space { + (type_, state_space) + } else { + return Err(TranslateError::UntypedSymbol); + }; + if state_space == ast::StateSpace::Reg { + let (reg_type, reg_space) = self.resolver.get_typed(reg)?; + if *reg_space != ast::StateSpace::Reg { + return Err(error_mismatched_type()); + } + let reg_scalar_type = match reg_type { + ast::Type::Scalar(underlying_type) => *underlying_type, + _ => return Err(error_mismatched_type()), + }; + let reg_type = reg_type.clone(); + let id_constant_stmt = self + .resolver + .register_unnamed(Some((reg_type.clone(), ast::StateSpace::Reg))); + self.result.push(Statement::Constant(ConstantDefinition { + dst: id_constant_stmt, + typ: reg_scalar_type, + value: ast::ImmediateValue::S64(offset as i64), + })); + let arith_details = match reg_scalar_type.kind() { + ast::ScalarKind::Signed => ast::ArithDetails::Integer(ast::ArithInteger { + type_: reg_scalar_type, + saturate: false, + }), + ast::ScalarKind::Unsigned | ast::ScalarKind::Bit => { + ast::ArithDetails::Integer(ast::ArithInteger { + type_: reg_scalar_type, + saturate: false, + }) + } + _ => return Err(error_unreachable()), + }; + let id_add_result = self + .resolver + .register_unnamed(Some((reg_type, state_space))); + self.result + .push(Statement::Instruction(ast::Instruction::Add { + data: arith_details, + arguments: ast::AddArgs { + dst: id_add_result, + src1: reg, + src2: id_constant_stmt, + }, + })); + Ok(id_add_result) + } else { + let id_constant_stmt = self.resolver.register_unnamed(Some(( + ast::Type::Scalar(ast::ScalarType::S64), + ast::StateSpace::Reg, + ))); + self.result.push(Statement::Constant(ConstantDefinition { + dst: id_constant_stmt, + typ: ast::ScalarType::S64, + value: ast::ImmediateValue::S64(offset as i64), + })); + let dst = self + .resolver + .register_unnamed(Some((type_.clone(), state_space))); + self.result.push(Statement::PtrAccess(PtrAccess { + underlying_type: type_.clone(), + state_space: state_space, + dst, + ptr_src: reg, + offset_src: id_constant_stmt, + })); + Ok(dst) + } + } + + fn immediate( + &mut self, + value: ast::ImmediateValue, + type_space: Option<(&ast::Type, ast::StateSpace)>, + ) -> Result { + let (scalar_t, state_space) = + if let Some((ast::Type::Scalar(scalar), state_space)) = type_space { + (*scalar, state_space) + } else { + return Err(TranslateError::UntypedSymbol); + }; + let id = self + .resolver + .register_unnamed(Some((ast::Type::Scalar(scalar_t), state_space))); + self.result.push(Statement::Constant(ConstantDefinition { + dst: id, + typ: scalar_t, + value, + })); + Ok(id) + } + + fn vec_member( + &mut self, + vector_src: SpirvWord, + member: u8, + _type_space: Option<(&ast::Type, ast::StateSpace)>, + is_dst: bool, + ) -> Result { + if is_dst { + return Err(error_mismatched_type()); + } + let (vector_width, scalar_type, space) = match self.resolver.get_typed(vector_src)? { + (ast::Type::Vector(vector_width, scalar_t), space) => { + (*vector_width, *scalar_t, *space) + } + _ => return Err(error_mismatched_type()), + }; + let temporary = self + .resolver + .register_unnamed(Some((scalar_type.into(), space))); + self.result.push(Statement::VectorAccess(VectorAccess { + scalar_type, + vector_width, + dst: temporary, + src: vector_src, + member: member, + })); + Ok(temporary) + } + + fn vec_pack( + &mut self, + vecs: Vec, + type_space: Option<(&ast::Type, ast::StateSpace)>, + is_dst: bool, + relaxed_type_check: bool, + ) -> Result { + let (scalar_t, state_space) = match type_space { + Some((ast::Type::Vector(_, scalar_t), space)) => (*scalar_t, space), + _ => return Err(error_mismatched_type()), + }; + let temp_vec = self + .resolver + .register_unnamed(Some((scalar_t.into(), state_space))); + let statement = Statement::RepackVector(RepackVectorDetails { + is_extract: is_dst, + typ: scalar_t, + packed: temp_vec, + unpacked: vecs, + relaxed_type_check, + }); + if is_dst { + self.post_stmts.push(statement); + } else { + self.result.push(statement); + } + Ok(temp_vec) + } +} + +impl<'a, 'b> ast::VisitorMap, SpirvWord, TranslateError> + for FlattenArguments<'a, 'b> +{ + fn visit( + &mut self, + args: ast::ParsedOperand, + type_space: Option<(&ast::Type, ast::StateSpace)>, + is_dst: bool, + relaxed_type_check: bool, + ) -> Result { + match args { + ast::ParsedOperand::Reg(r) => self.reg(r), + ast::ParsedOperand::Imm(x) => self.immediate(x, type_space), + ast::ParsedOperand::RegOffset(reg, offset) => { + self.reg_offset(reg, offset, type_space, is_dst) + } + ast::ParsedOperand::VecMember(vec, member) => { + self.vec_member(vec, member, type_space, is_dst) + } + ast::ParsedOperand::VecPack(vecs) => { + self.vec_pack(vecs, type_space, is_dst, relaxed_type_check) + } + } + } + + fn visit_ident( + &mut self, + name: ::Ident, + _type_space: Option<(&ast::Type, ast::StateSpace)>, + _is_dst: bool, + _relaxed_type_check: bool, + ) -> Result<::Ident, TranslateError> { + self.reg(name) + } +} + +impl Drop for FlattenArguments<'_, '_> { + fn drop(&mut self) { + self.result.extend(self.post_stmts.drain(..)); + } +} diff --git a/ptx/src/pass/extract_globals.rs b/ptx/src/pass/extract_globals.rs index 680a5ee..2912366 100644 --- a/ptx/src/pass/extract_globals.rs +++ b/ptx/src/pass/extract_globals.rs @@ -273,7 +273,6 @@ fn space_to_ptx_name(this: ast::StateSpace) -> &'static str { ast::StateSpace::Const => "const", ast::StateSpace::Local => "local", ast::StateSpace::Param => "param", - ast::StateSpace::Sreg => "sreg", ast::StateSpace::SharedCluster => "shared_cluster", ast::StateSpace::ParamEntry => "param_entry", ast::StateSpace::SharedCta => "shared_cta", diff --git a/ptx/src/pass/fix_special_registers2.rs b/ptx/src/pass/fix_special_registers2.rs new file mode 100644 index 0000000..97f6356 --- /dev/null +++ b/ptx/src/pass/fix_special_registers2.rs @@ -0,0 +1,209 @@ +use super::*; + +pub(super) fn run<'a, 'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + special_registers: &'a SpecialRegistersMap2, + directives: Vec>, +) -> Result>, TranslateError> { + let declarations = SpecialRegistersMap2::generate_declarations(resolver); + let mut result = Vec::with_capacity(declarations.len() + directives.len()); + let mut sreg_to_function = + FxHashMap::with_capacity_and_hasher(declarations.len(), Default::default()); + for (sreg, declaration) in declarations { + let name = if let ast::MethodName::Func(name) = declaration.name { + name + } else { + return Err(error_unreachable()); + }; + result.push(UnconditionalDirective::Method(UnconditionalFunction { + func_decl: declaration, + globals: Vec::new(), + body: None, + import_as: None, + tuning: Vec::new(), + linkage: ast::LinkingDirective::EXTERN, + })); + sreg_to_function.insert(sreg, name); + } + let mut visitor = SpecialRegisterResolver { + resolver, + special_registers, + sreg_to_function, + result: Vec::new(), + }; + directives + .into_iter() + .map(|directive| run_directive(&mut visitor, directive)) + .collect::, _>>() +} + +fn run_directive<'a, 'input>( + visitor: &mut SpecialRegisterResolver<'a, 'input>, + directive: UnconditionalDirective<'input>, +) -> Result, TranslateError> { + Ok(match directive { + var @ Directive2::Variable(..) => var, + Directive2::Method(method) => Directive2::Method(run_method(visitor, method)?), + }) +} + +fn run_method<'a, 'input>( + visitor: &mut SpecialRegisterResolver<'a, 'input>, + method: UnconditionalFunction<'input>, +) -> Result, TranslateError> { + let body = method + .body + .map(|statements| { + let mut result = Vec::with_capacity(statements.len()); + for statement in statements { + run_statement(visitor, &mut result, statement)?; + } + Ok::<_, TranslateError>(result) + }) + .transpose()?; + Ok(Function2 { + func_decl: method.func_decl, + globals: method.globals, + body, + import_as: method.import_as, + tuning: method.tuning, + linkage: method.linkage, + }) +} + +fn run_statement<'a, 'input>( + visitor: &mut SpecialRegisterResolver<'a, 'input>, + result: &mut Vec, + statement: UnconditionalStatement, +) -> Result<(), TranslateError> { + let converted_statement = statement.visit_map(visitor)?; + result.extend(visitor.result.drain(..)); + result.push(converted_statement); + Ok(()) +} + +struct SpecialRegisterResolver<'a, 'input> { + resolver: &'a mut GlobalStringIdentResolver2<'input>, + special_registers: &'a SpecialRegistersMap2, + sreg_to_function: FxHashMap, + result: Vec, +} + +impl<'a, 'b, 'input> + ast::VisitorMap, ast::ParsedOperand, TranslateError> + for SpecialRegisterResolver<'a, 'input> +{ + fn visit( + &mut self, + operand: ast::ParsedOperand, + _type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>, + is_dst: bool, + _relaxed_type_check: bool, + ) -> Result, TranslateError> { + map_operand(operand, &mut |ident, vector_index| { + self.replace_sreg(ident, vector_index, is_dst) + }) + } + + fn visit_ident( + &mut self, + args: SpirvWord, + _type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>, + is_dst: bool, + _relaxed_type_check: bool, + ) -> Result { + self.replace_sreg(args, None, is_dst) + } +} + +impl<'a, 'b, 'input> SpecialRegisterResolver<'a, 'input> { + fn replace_sreg( + &mut self, + name: SpirvWord, + vector_index: Option, + is_dst: bool, + ) -> Result { + if let Some(sreg) = self.special_registers.get(name) { + if is_dst { + return Err(error_mismatched_type()); + } + let input_arguments = match (vector_index, sreg.get_function_input_type()) { + (Some(idx), Some(inp_type)) => { + if inp_type != ast::ScalarType::U8 { + return Err(TranslateError::Unreachable); + } + let constant = self.resolver.register_unnamed(Some(( + ast::Type::Scalar(inp_type), + ast::StateSpace::Reg, + ))); + self.result.push(Statement::Constant(ConstantDefinition { + dst: constant, + typ: inp_type, + value: ast::ImmediateValue::U64(idx as u64), + })); + vec![(constant, ast::Type::Scalar(inp_type), ast::StateSpace::Reg)] + } + (None, None) => Vec::new(), + _ => return Err(error_mismatched_type()), + }; + let return_type = sreg.get_function_return_type(); + let fn_result = self + .resolver + .register_unnamed(Some((ast::Type::Scalar(return_type), ast::StateSpace::Reg))); + let return_arguments = vec![( + fn_result, + ast::Type::Scalar(return_type), + ast::StateSpace::Reg, + )]; + let data = ast::CallDetails { + uniform: false, + return_arguments: return_arguments + .iter() + .map(|(_, typ, space)| (typ.clone(), *space)) + .collect(), + input_arguments: input_arguments + .iter() + .map(|(_, typ, space)| (typ.clone(), *space)) + .collect(), + }; + let arguments = ast::CallArgs::> { + return_arguments: return_arguments.iter().map(|(name, _, _)| *name).collect(), + func: self.sreg_to_function[&sreg], + input_arguments: input_arguments + .iter() + .map(|(name, _, _)| ast::ParsedOperand::Reg(*name)) + .collect(), + }; + self.result + .push(Statement::Instruction(ast::Instruction::Call { + data, + arguments, + })); + Ok(fn_result) + } else { + Ok(name) + } + } +} + +pub fn map_operand( + this: ast::ParsedOperand, + fn_: &mut impl FnMut(T, Option) -> Result, +) -> Result, Err> { + Ok(match this { + ast::ParsedOperand::Reg(ident) => ast::ParsedOperand::Reg(fn_(ident, None)?), + ast::ParsedOperand::RegOffset(ident, offset) => { + ast::ParsedOperand::RegOffset(fn_(ident, None)?, offset) + } + ast::ParsedOperand::Imm(imm) => ast::ParsedOperand::Imm(imm), + ast::ParsedOperand::VecMember(ident, member) => { + ast::ParsedOperand::Reg(fn_(ident, Some(member))?) + } + ast::ParsedOperand::VecPack(idents) => ast::ParsedOperand::VecPack( + idents + .into_iter() + .map(|ident| fn_(ident, None)) + .collect::, _>>()?, + ), + }) +} diff --git a/ptx/src/pass/hoist_globals.rs b/ptx/src/pass/hoist_globals.rs new file mode 100644 index 0000000..753172a --- /dev/null +++ b/ptx/src/pass/hoist_globals.rs @@ -0,0 +1,45 @@ +use super::*; + +pub(super) fn run<'input>( + directives: Vec, SpirvWord>>, +) -> Result, SpirvWord>>, TranslateError> { + let mut result = Vec::with_capacity(directives.len()); + for mut directive in directives.into_iter() { + run_directive(&mut result, &mut directive); + result.push(directive); + } + Ok(result) +} + +fn run_directive<'input>( + result: &mut Vec, SpirvWord>>, + directive: &mut Directive2<'input, ptx_parser::Instruction, SpirvWord>, +) -> Result<(), TranslateError> { + match directive { + Directive2::Variable(..) => {} + Directive2::Method(function2) => run_function(result, function2), + } + Ok(()) +} + +fn run_function<'input>( + result: &mut Vec, SpirvWord>>, + function: &mut Function2<'input, ptx_parser::Instruction, SpirvWord>, +) { + function.body = function.body.take().map(|statements| { + statements + .into_iter() + .filter_map(|statement| match statement { + Statement::Variable(var @ ast::Variable { + state_space: + ast::StateSpace::Global | ast::StateSpace::Const | ast::StateSpace::Shared, + .. + }) => { + result.push(Directive2::Variable(ast::LinkingDirective::NONE, var)); + None + } + s => Some(s), + }) + .collect() + }); +} diff --git a/ptx/src/pass/insert_explicit_load_store.rs b/ptx/src/pass/insert_explicit_load_store.rs new file mode 100644 index 0000000..ec6498c --- /dev/null +++ b/ptx/src/pass/insert_explicit_load_store.rs @@ -0,0 +1,338 @@ +use super::*; +use ptx_parser::VisitorMap; +use rustc_hash::FxHashSet; + +// This pass: +// * Turns all .local, .param and .reg in-body variables into .local variables +// (if _not_ an input method argument) +// * Inserts explicit `ld`/`st` for newly converted .reg variables +// * Fixup state space of all existing `ld`/`st` instructions into newly +// converted variables +// * Turns `.entry` input arguments into param::entry and all related `.param` +// loads into `param::entry` loads +// * All `.func` input arguments are turned into `.reg` arguments by another +// pass, so we do nothing there +pub(super) fn run<'a, 'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + directives: Vec, SpirvWord>>, +) -> Result, SpirvWord>>, TranslateError> { + directives + .into_iter() + .map(|directive| run_directive(resolver, directive)) + .collect::, _>>() +} + +fn run_directive<'a, 'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + directive: Directive2<'input, ast::Instruction, SpirvWord>, +) -> Result, SpirvWord>, TranslateError> { + Ok(match directive { + var @ Directive2::Variable(..) => var, + Directive2::Method(method) => { + let visitor = InsertMemSSAVisitor::new(resolver); + Directive2::Method(run_method(visitor, method)?) + } + }) +} + +fn run_method<'a, 'input>( + mut visitor: InsertMemSSAVisitor<'a, 'input>, + method: Function2<'input, ast::Instruction, SpirvWord>, +) -> Result, SpirvWord>, TranslateError> { + let mut func_decl = method.func_decl; + for arg in func_decl.return_arguments.iter_mut() { + visitor.visit_variable(arg)?; + } + let is_kernel = func_decl.name.is_kernel(); + if is_kernel { + for arg in func_decl.input_arguments.iter_mut() { + let old_name = arg.name; + let old_space = arg.state_space; + let new_space = ast::StateSpace::ParamEntry; + let new_name = visitor + .resolver + .register_unnamed(Some((arg.v_type.clone(), new_space))); + visitor.input_argument(old_name, new_name, old_space); + arg.name = new_name; + arg.state_space = new_space; + } + }; + let body = method + .body + .map(move |statements| { + let mut result = Vec::with_capacity(statements.len()); + for statement in statements { + run_statement(&mut visitor, &mut result, statement)?; + } + Ok::<_, TranslateError>(result) + }) + .transpose()?; + Ok(Function2 { + func_decl: func_decl, + globals: method.globals, + body, + import_as: method.import_as, + tuning: method.tuning, + linkage: method.linkage, + }) +} + +fn run_statement<'a, 'input>( + visitor: &mut InsertMemSSAVisitor<'a, 'input>, + result: &mut Vec, + statement: ExpandedStatement, +) -> Result<(), TranslateError> { + match statement { + Statement::Variable(mut var) => { + visitor.visit_variable(&mut var)?; + result.push(Statement::Variable(var)); + } + Statement::Instruction(ast::Instruction::Ld { data, arguments }) => { + let instruction = visitor.visit_ld(data, arguments)?; + let instruction = ast::visit_map(instruction, visitor)?; + result.extend(visitor.pre.drain(..).map(Statement::Instruction)); + result.push(Statement::Instruction(instruction)); + result.extend(visitor.post.drain(..).map(Statement::Instruction)); + } + Statement::Instruction(ast::Instruction::St { data, arguments }) => { + let instruction = visitor.visit_st(data, arguments)?; + let instruction = ast::visit_map(instruction, visitor)?; + result.extend(visitor.pre.drain(..).map(Statement::Instruction)); + result.push(Statement::Instruction(instruction)); + result.extend(visitor.post.drain(..).map(Statement::Instruction)); + } + s => { + let new_statement = s.visit_map(visitor)?; + result.extend(visitor.pre.drain(..).map(Statement::Instruction)); + result.push(new_statement); + result.extend(visitor.post.drain(..).map(Statement::Instruction)); + } + } + Ok(()) +} + +struct InsertMemSSAVisitor<'a, 'input> { + resolver: &'a mut GlobalStringIdentResolver2<'input>, + variables: FxHashMap, + pre: Vec>, + post: Vec>, +} + +impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { + fn new(resolver: &'a mut GlobalStringIdentResolver2<'input>) -> Self { + Self { + resolver, + variables: FxHashMap::default(), + pre: Vec::new(), + post: Vec::new(), + } + } + + fn input_argument( + &mut self, + old_name: SpirvWord, + new_name: SpirvWord, + old_space: ast::StateSpace, + ) -> Result<(), TranslateError> { + if old_space != ast::StateSpace::Param { + return Err(error_unreachable()); + } + self.variables.insert( + old_name, + RemapAction::LDStSpaceChange { + name: new_name, + old_space, + new_space: ast::StateSpace::ParamEntry, + }, + ); + Ok(()) + } + + fn variable( + &mut self, + type_: &ast::Type, + old_name: SpirvWord, + new_name: SpirvWord, + old_space: ast::StateSpace, + ) -> Result<(), TranslateError> { + Ok(match old_space { + ast::StateSpace::Reg => { + self.variables.insert( + old_name, + RemapAction::PreLdPostSt { + name: new_name, + type_: type_.clone(), + }, + ); + } + ast::StateSpace::Param => { + self.variables.insert( + old_name, + RemapAction::LDStSpaceChange { + old_space, + new_space: ast::StateSpace::Local, + name: new_name, + }, + ); + } + // Good as-is + ast::StateSpace::Local => {} + // Will be pulled into global scope later + ast::StateSpace::Generic + | ast::StateSpace::SharedCluster + | ast::StateSpace::Global + | ast::StateSpace::Const + | ast::StateSpace::SharedCta + | ast::StateSpace::Shared => {} + ast::StateSpace::ParamEntry | ast::StateSpace::ParamFunc => { + return Err(error_unreachable()) + } + }) + } + + fn visit_st( + &self, + mut data: ast::StData, + mut arguments: ast::StArgs, + ) -> Result, TranslateError> { + if let Some(remap) = self.variables.get(&arguments.src1) { + match remap { + RemapAction::PreLdPostSt { .. } => {} + RemapAction::LDStSpaceChange { + old_space, + new_space, + name, + } => { + if data.state_space != *old_space { + return Err(error_mismatched_type()); + } + data.state_space = *new_space; + arguments.src1 = *name; + } + } + } + Ok(ast::Instruction::St { data, arguments }) + } + + fn visit_ld( + &self, + mut data: ast::LdDetails, + mut arguments: ast::LdArgs, + ) -> Result, TranslateError> { + if let Some(remap) = self.variables.get(&arguments.src) { + match remap { + RemapAction::PreLdPostSt { .. } => {} + RemapAction::LDStSpaceChange { + old_space, + new_space, + name, + } => { + if data.state_space != *old_space { + return Err(error_mismatched_type()); + } + data.state_space = *new_space; + arguments.src = *name; + } + } + } + Ok(ast::Instruction::Ld { data, arguments }) + } + + fn visit_variable(&mut self, var: &mut ast::Variable) -> Result<(), TranslateError> { + if var.state_space != ast::StateSpace::Local { + let old_name = var.name; + let old_space = var.state_space; + let new_space = ast::StateSpace::Local; + let new_name = self + .resolver + .register_unnamed(Some((var.v_type.clone(), new_space))); + self.variable(&var.v_type, old_name, new_name, old_space)?; + var.name = new_name; + var.state_space = new_space; + } + Ok(()) + } +} + +impl<'a, 'input> ast::VisitorMap + for InsertMemSSAVisitor<'a, 'input> +{ + fn visit( + &mut self, + ident: SpirvWord, + type_space: Option<(&ast::Type, ast::StateSpace)>, + is_dst: bool, + relaxed_type_check: bool, + ) -> Result { + if let Some(remap) = self.variables.get(&ident) { + match remap { + RemapAction::PreLdPostSt { name, type_ } => { + if is_dst { + let temp = self + .resolver + .register_unnamed(Some((type_.clone(), ast::StateSpace::Reg))); + self.post.push(ast::Instruction::St { + data: ast::StData { + state_space: ast::StateSpace::Local, + qualifier: ast::LdStQualifier::Weak, + caching: ast::StCacheOperator::Writethrough, + typ: type_.clone(), + }, + arguments: ast::StArgs { + src1: *name, + src2: temp, + }, + }); + Ok(temp) + } else { + let temp = self + .resolver + .register_unnamed(Some((type_.clone(), ast::StateSpace::Reg))); + self.pre.push(ast::Instruction::Ld { + data: ast::LdDetails { + state_space: ast::StateSpace::Local, + qualifier: ast::LdStQualifier::Weak, + caching: ast::LdCacheOperator::Cached, + typ: type_.clone(), + non_coherent: false, + }, + arguments: ast::LdArgs { + dst: temp, + src: *name, + }, + }); + Ok(temp) + } + } + RemapAction::LDStSpaceChange { .. } => { + return Err(error_mismatched_type()); + } + } + } else { + Ok(ident) + } + } + + fn visit_ident( + &mut self, + args: SpirvWord, + type_space: Option<(&ast::Type, ast::StateSpace)>, + is_dst: bool, + relaxed_type_check: bool, + ) -> Result { + self.visit(args, type_space, is_dst, relaxed_type_check) + } +} + +#[derive(Clone)] +enum RemapAction { + PreLdPostSt { + name: SpirvWord, + type_: ast::Type, + }, + LDStSpaceChange { + old_space: ast::StateSpace, + new_space: ast::StateSpace, + name: SpirvWord, + }, +} diff --git a/ptx/src/pass/insert_implicit_conversions.rs b/ptx/src/pass/insert_implicit_conversions.rs index 25e80f0..c04fa09 100644 --- a/ptx/src/pass/insert_implicit_conversions.rs +++ b/ptx/src/pass/insert_implicit_conversions.rs @@ -45,6 +45,13 @@ pub(super) fn run( Statement::RepackVector(repack), )?; } + Statement::VectorAccess(vector_access) => { + insert_implicit_conversions_impl( + &mut result, + id_def, + Statement::VectorAccess(vector_access), + )?; + } s @ Statement::Conditional(_) | s @ Statement::Conversion(_) | s @ Statement::Label(_) @@ -128,7 +135,7 @@ pub(crate) fn default_implicit_conversion( (instruction_space, instruction_type): (ast::StateSpace, &ast::Type), ) -> Result, TranslateError> { if instruction_space == ast::StateSpace::Reg { - if space_is_compatible(operand_space, ast::StateSpace::Reg) { + if operand_space == ast::StateSpace::Reg { if let (ast::Type::Vector(vec_len, vec_underlying_type), ast::Type::Scalar(scalar)) = (operand_type, instruction_type) { @@ -142,7 +149,7 @@ pub(crate) fn default_implicit_conversion( return Ok(Some(ConversionKind::AddressOf)); } } - if !space_is_compatible(instruction_space, operand_space) { + if instruction_space != operand_space { default_implicit_conversion_space( (operand_space, operand_type), (instruction_space, instruction_type), @@ -161,7 +168,7 @@ fn is_addressable(this: ast::StateSpace) -> bool { | ast::StateSpace::Global | ast::StateSpace::Local | ast::StateSpace::Shared => true, - ast::StateSpace::Param | ast::StateSpace::Reg | ast::StateSpace::Sreg => false, + ast::StateSpace::Param | ast::StateSpace::Reg => false, ast::StateSpace::SharedCluster | ast::StateSpace::SharedCta | ast::StateSpace::ParamEntry @@ -178,7 +185,7 @@ fn default_implicit_conversion_space( || (operand_space == ast::StateSpace::Generic && coerces_to_generic(instruction_space)) { Ok(Some(ConversionKind::PtrToPtr)) - } else if space_is_compatible(operand_space, ast::StateSpace::Reg) { + } else if operand_space == ast::StateSpace::Reg { match operand_type { ast::Type::Pointer(operand_ptr_type, operand_ptr_space) if *operand_ptr_space == instruction_space => @@ -210,7 +217,7 @@ fn default_implicit_conversion_space( }, _ => Err(error_mismatched_type()), } - } else if space_is_compatible(instruction_space, ast::StateSpace::Reg) { + } else if instruction_space == ast::StateSpace::Reg { match instruction_type { ast::Type::Pointer(instruction_ptr_type, instruction_ptr_space) if operand_space == *instruction_ptr_space => @@ -234,7 +241,7 @@ fn default_implicit_conversion_type( operand_type: &ast::Type, instruction_type: &ast::Type, ) -> Result, TranslateError> { - if space_is_compatible(space, ast::StateSpace::Reg) { + if space == ast::StateSpace::Reg { if should_bitcast(instruction_type, operand_type) { Ok(Some(ConversionKind::Default)) } else { @@ -257,8 +264,7 @@ fn coerces_to_generic(this: ast::StateSpace) -> bool { | ast::StateSpace::Param | ast::StateSpace::ParamEntry | ast::StateSpace::ParamFunc - | ast::StateSpace::Generic - | ast::StateSpace::Sreg => false, + | ast::StateSpace::Generic => false, } } @@ -294,7 +300,7 @@ pub(crate) fn should_convert_relaxed_dst_wrapper( (operand_space, operand_type): (ast::StateSpace, &ast::Type), (instruction_space, instruction_type): (ast::StateSpace, &ast::Type), ) -> Result, TranslateError> { - if !space_is_compatible(operand_space, instruction_space) { + if operand_space != instruction_space { return Err(TranslateError::MismatchedType); } if operand_type == instruction_type { @@ -371,7 +377,7 @@ pub(crate) fn should_convert_relaxed_src_wrapper( (operand_space, operand_type): (ast::StateSpace, &ast::Type), (instruction_space, instruction_type): (ast::StateSpace, &ast::Type), ) -> Result, TranslateError> { - if !space_is_compatible(operand_space, instruction_space) { + if operand_space != instruction_space { return Err(error_mismatched_type()); } if operand_type == instruction_type { diff --git a/ptx/src/pass/insert_implicit_conversions2.rs b/ptx/src/pass/insert_implicit_conversions2.rs new file mode 100644 index 0000000..4f738f5 --- /dev/null +++ b/ptx/src/pass/insert_implicit_conversions2.rs @@ -0,0 +1,426 @@ +use std::mem; + +use super::*; +use ptx_parser as ast; + +/* + There are several kinds of implicit conversions in PTX: + * auto-bitcast: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#type-information-for-instructions-and-operands + * special ld/st/cvt conversion rules: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size + - ld.param: not documented, but for instruction `ld.param. x, [y]`, + semantics are to first zext/chop/bitcast `y` as needed and then do + documented special ld/st/cvt conversion rules for destination operands + - st.param [x] y (used as function return arguments) same rule as above applies + - generic/global ld: for instruction `ld x, [y]`, y must be of type + b64/u64/s64, which is bitcast to a pointer, dereferenced and then + documented special ld/st/cvt conversion rules are applied to dst + - generic/global st: for instruction `st [x], y`, x must be of type + b64/u64/s64, which is bitcast to a pointer +*/ +pub(super) fn run<'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + directives: Vec, SpirvWord>>, +) -> Result, SpirvWord>>, TranslateError> { + directives + .into_iter() + .map(|directive| run_directive(resolver, directive)) + .collect::, _>>() +} + +fn run_directive<'a, 'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + directive: Directive2<'input, ast::Instruction, SpirvWord>, +) -> Result, SpirvWord>, TranslateError> { + Ok(match directive { + var @ Directive2::Variable(..) => var, + Directive2::Method(mut method) => { + method.body = method + .body + .map(|statements| run_statements(resolver, statements)) + .transpose()?; + Directive2::Method(method) + } + }) +} + +fn run_statements<'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + func: Vec, +) -> Result, TranslateError> { + let mut result = Vec::with_capacity(func.len()); + for s in func.into_iter() { + insert_implicit_conversions_impl(resolver, &mut result, s)?; + } + Ok(result) +} + +fn insert_implicit_conversions_impl<'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + func: &mut Vec, + stmt: ExpandedStatement, +) -> Result<(), TranslateError> { + let mut post_conv = Vec::new(); + let statement = stmt.visit_map::( + &mut |operand, + type_state: Option<(&ast::Type, ast::StateSpace)>, + is_dst, + relaxed_type_check| { + let (instr_type, instruction_space) = match type_state { + None => return Ok(operand), + Some(t) => t, + }; + let (operand_type, operand_space) = resolver.get_typed(operand)?; + let conversion_fn = if relaxed_type_check { + if is_dst { + should_convert_relaxed_dst_wrapper + } else { + should_convert_relaxed_src_wrapper + } + } else { + default_implicit_conversion + }; + match conversion_fn( + (*operand_space, &operand_type), + (instruction_space, instr_type), + )? { + Some(conv_kind) => { + let conv_output = if is_dst { &mut post_conv } else { &mut *func }; + let mut from_type = instr_type.clone(); + let mut from_space = instruction_space; + let mut to_type = operand_type.clone(); + let mut to_space = *operand_space; + let mut src = + resolver.register_unnamed(Some((instr_type.clone(), instruction_space))); + let mut dst = operand; + let result = Ok::<_, TranslateError>(src); + if !is_dst { + mem::swap(&mut src, &mut dst); + mem::swap(&mut from_type, &mut to_type); + mem::swap(&mut from_space, &mut to_space); + } + conv_output.push(Statement::Conversion(ImplicitConversion { + src, + dst, + from_type, + from_space, + to_type, + to_space, + kind: conv_kind, + })); + result + } + None => Ok(operand), + } + }, + )?; + func.push(statement); + func.append(&mut post_conv); + Ok(()) +} + +pub(crate) fn default_implicit_conversion( + (operand_space, operand_type): (ast::StateSpace, &ast::Type), + (instruction_space, instruction_type): (ast::StateSpace, &ast::Type), +) -> Result, TranslateError> { + if instruction_space == ast::StateSpace::Reg { + if operand_space == ast::StateSpace::Reg { + if let (ast::Type::Vector(vec_len, vec_underlying_type), ast::Type::Scalar(scalar)) = + (operand_type, instruction_type) + { + if scalar.kind() == ast::ScalarKind::Bit + && scalar.size_of() == (vec_underlying_type.size_of() * vec_len) + { + return Ok(Some(ConversionKind::Default)); + } + } + } else if is_addressable(operand_space) { + return Ok(Some(ConversionKind::AddressOf)); + } + } + if instruction_space != operand_space { + default_implicit_conversion_space( + (operand_space, operand_type), + (instruction_space, instruction_type), + ) + } else if instruction_type != operand_type { + default_implicit_conversion_type(instruction_space, operand_type, instruction_type) + } else { + Ok(None) + } +} + +fn is_addressable(this: ast::StateSpace) -> bool { + match this { + ast::StateSpace::Const + | ast::StateSpace::Generic + | ast::StateSpace::Global + | ast::StateSpace::Local + | ast::StateSpace::Shared => true, + ast::StateSpace::Param | ast::StateSpace::Reg => false, + ast::StateSpace::SharedCluster + | ast::StateSpace::SharedCta + | ast::StateSpace::ParamEntry + | ast::StateSpace::ParamFunc => todo!(), + } +} + +// Space is different +fn default_implicit_conversion_space( + (operand_space, operand_type): (ast::StateSpace, &ast::Type), + (instruction_space, instruction_type): (ast::StateSpace, &ast::Type), +) -> Result, TranslateError> { + if (instruction_space == ast::StateSpace::Generic && coerces_to_generic(operand_space)) + || (operand_space == ast::StateSpace::Generic && coerces_to_generic(instruction_space)) + { + Ok(Some(ConversionKind::PtrToPtr)) + } else if operand_space == ast::StateSpace::Reg { + match operand_type { + ast::Type::Pointer(operand_ptr_type, operand_ptr_space) + if *operand_ptr_space == instruction_space => + { + if instruction_type != &ast::Type::Scalar(*operand_ptr_type) { + Ok(Some(ConversionKind::PtrToPtr)) + } else { + Ok(None) + } + } + // TODO: 32 bit + ast::Type::Scalar(ast::ScalarType::B64) + | ast::Type::Scalar(ast::ScalarType::U64) + | ast::Type::Scalar(ast::ScalarType::S64) => match instruction_space { + ast::StateSpace::Global + | ast::StateSpace::Generic + | ast::StateSpace::Const + | ast::StateSpace::Local + | ast::StateSpace::Shared => Ok(Some(ConversionKind::BitToPtr)), + _ => Err(error_mismatched_type()), + }, + ast::Type::Scalar(ast::ScalarType::B32) + | ast::Type::Scalar(ast::ScalarType::U32) + | ast::Type::Scalar(ast::ScalarType::S32) => match instruction_space { + ast::StateSpace::Const | ast::StateSpace::Local | ast::StateSpace::Shared => { + Ok(Some(ConversionKind::BitToPtr)) + } + _ => Err(error_mismatched_type()), + }, + _ => Err(error_mismatched_type()), + } + } else if instruction_space == ast::StateSpace::Reg { + match instruction_type { + ast::Type::Pointer(instruction_ptr_type, instruction_ptr_space) + if operand_space == *instruction_ptr_space => + { + if operand_type != &ast::Type::Scalar(*instruction_ptr_type) { + Ok(Some(ConversionKind::PtrToPtr)) + } else { + Ok(None) + } + } + _ => Err(error_mismatched_type()), + } + } else { + Err(error_mismatched_type()) + } +} + +// Space is same, but type is different +fn default_implicit_conversion_type( + space: ast::StateSpace, + operand_type: &ast::Type, + instruction_type: &ast::Type, +) -> Result, TranslateError> { + if space == ast::StateSpace::Reg { + if should_bitcast(instruction_type, operand_type) { + Ok(Some(ConversionKind::Default)) + } else { + Err(TranslateError::MismatchedType) + } + } else { + Ok(Some(ConversionKind::PtrToPtr)) + } +} + +fn coerces_to_generic(this: ast::StateSpace) -> bool { + match this { + ast::StateSpace::Global + | ast::StateSpace::Const + | ast::StateSpace::Local + | ptx_parser::StateSpace::SharedCta + | ast::StateSpace::SharedCluster + | ast::StateSpace::Shared => true, + ast::StateSpace::Reg + | ast::StateSpace::Param + | ast::StateSpace::ParamEntry + | ast::StateSpace::ParamFunc + | ast::StateSpace::Generic => false, + } +} + +fn should_bitcast(instr: &ast::Type, operand: &ast::Type) -> bool { + match (instr, operand) { + (ast::Type::Scalar(inst), ast::Type::Scalar(operand)) => { + if inst.size_of() != operand.size_of() { + return false; + } + match inst.kind() { + ast::ScalarKind::Bit => operand.kind() != ast::ScalarKind::Bit, + ast::ScalarKind::Float => operand.kind() == ast::ScalarKind::Bit, + ast::ScalarKind::Signed => { + operand.kind() == ast::ScalarKind::Bit + || operand.kind() == ast::ScalarKind::Unsigned + } + ast::ScalarKind::Unsigned => { + operand.kind() == ast::ScalarKind::Bit + || operand.kind() == ast::ScalarKind::Signed + } + ast::ScalarKind::Pred => false, + } + } + (ast::Type::Vector(_, inst), ast::Type::Vector(_, operand)) + | (ast::Type::Array(_, inst, _), ast::Type::Array(_, operand, _)) => { + should_bitcast(&ast::Type::Scalar(*inst), &ast::Type::Scalar(*operand)) + } + _ => false, + } +} + +pub(crate) fn should_convert_relaxed_dst_wrapper( + (operand_space, operand_type): (ast::StateSpace, &ast::Type), + (instruction_space, instruction_type): (ast::StateSpace, &ast::Type), +) -> Result, TranslateError> { + if operand_space != instruction_space { + return Err(TranslateError::MismatchedType); + } + if operand_type == instruction_type { + return Ok(None); + } + match should_convert_relaxed_dst(operand_type, instruction_type) { + conv @ Some(_) => Ok(conv), + None => Err(TranslateError::MismatchedType), + } +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-destination-operands +fn should_convert_relaxed_dst( + dst_type: &ast::Type, + instr_type: &ast::Type, +) -> Option { + if dst_type == instr_type { + return None; + } + match (dst_type, instr_type) { + (ast::Type::Scalar(dst_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() { + ast::ScalarKind::Bit => { + if instr_type.size_of() <= dst_type.size_of() { + Some(ConversionKind::Default) + } else { + None + } + } + ast::ScalarKind::Signed => { + if dst_type.kind() != ast::ScalarKind::Float { + if instr_type.size_of() == dst_type.size_of() { + Some(ConversionKind::Default) + } else if instr_type.size_of() < dst_type.size_of() { + Some(ConversionKind::SignExtend) + } else { + None + } + } else { + None + } + } + ast::ScalarKind::Unsigned => { + if instr_type.size_of() <= dst_type.size_of() + && dst_type.kind() != ast::ScalarKind::Float + { + Some(ConversionKind::Default) + } else { + None + } + } + ast::ScalarKind::Float => { + if instr_type.size_of() <= dst_type.size_of() + && dst_type.kind() == ast::ScalarKind::Bit + { + Some(ConversionKind::Default) + } else { + None + } + } + ast::ScalarKind::Pred => None, + }, + (ast::Type::Vector(_, dst_type), ast::Type::Vector(_, instr_type)) + | (ast::Type::Array(_, dst_type, _), ast::Type::Array(_, instr_type, _)) => { + should_convert_relaxed_dst( + &ast::Type::Scalar(*dst_type), + &ast::Type::Scalar(*instr_type), + ) + } + _ => None, + } +} + +pub(crate) fn should_convert_relaxed_src_wrapper( + (operand_space, operand_type): (ast::StateSpace, &ast::Type), + (instruction_space, instruction_type): (ast::StateSpace, &ast::Type), +) -> Result, TranslateError> { + if operand_space != instruction_space { + return Err(error_mismatched_type()); + } + if operand_type == instruction_type { + return Ok(None); + } + match should_convert_relaxed_src(operand_type, instruction_type) { + conv @ Some(_) => Ok(conv), + None => Err(error_mismatched_type()), + } +} + +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-source-operands +fn should_convert_relaxed_src( + src_type: &ast::Type, + instr_type: &ast::Type, +) -> Option { + if src_type == instr_type { + return None; + } + match (src_type, instr_type) { + (ast::Type::Scalar(src_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() { + ast::ScalarKind::Bit => { + if instr_type.size_of() <= src_type.size_of() { + Some(ConversionKind::Default) + } else { + None + } + } + ast::ScalarKind::Signed | ast::ScalarKind::Unsigned => { + if instr_type.size_of() <= src_type.size_of() + && src_type.kind() != ast::ScalarKind::Float + { + Some(ConversionKind::Default) + } else { + None + } + } + ast::ScalarKind::Float => { + if instr_type.size_of() <= src_type.size_of() + && src_type.kind() == ast::ScalarKind::Bit + { + Some(ConversionKind::Default) + } else { + None + } + } + ast::ScalarKind::Pred => None, + }, + (ast::Type::Vector(_, dst_type), ast::Type::Vector(_, instr_type)) + | (ast::Type::Array(_, dst_type, _), ast::Type::Array(_, instr_type, _)) => { + should_convert_relaxed_src( + &ast::Type::Scalar(*dst_type), + &ast::Type::Scalar(*instr_type), + ) + } + _ => None, + } +} diff --git a/ptx/src/pass/insert_mem_ssa_statements.rs b/ptx/src/pass/insert_mem_ssa_statements.rs index e314b05..150109b 100644 --- a/ptx/src/pass/insert_mem_ssa_statements.rs +++ b/ptx/src/pass/insert_mem_ssa_statements.rs @@ -189,7 +189,7 @@ impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { return Ok(symbol); }; let (mut var_type, var_space, is_variable) = self.id_def.get_typed(symbol)?; - if !space_is_compatible(var_space, ast::StateSpace::Reg) || !is_variable { + if var_space != ast::StateSpace::Reg || !is_variable { return Ok(symbol); }; let member_index = match member_index { diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index 3aa3b0a..0e233ed 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -1,5 +1,6 @@ use ptx_parser as ast; use rspirv::{binary::Assemble, dr}; +use rustc_hash::FxHashMap; use std::hash::Hash; use std::num::NonZeroU8; use std::{ @@ -12,20 +13,31 @@ use std::{ mem, rc::Rc, }; +use strum::IntoEnumIterator; +use strum_macros::EnumIter; mod convert_dynamic_shared_memory_usage; mod convert_to_stateful_memory_access; mod convert_to_typed; +mod deparamize_functions; pub(crate) mod emit_llvm; mod emit_spirv; mod expand_arguments; +mod expand_operands; mod extract_globals; mod fix_special_registers; +mod fix_special_registers2; +mod hoist_globals; +mod insert_explicit_load_store; mod insert_implicit_conversions; +mod insert_implicit_conversions2; mod insert_mem_ssa_statements; mod normalize_identifiers; +mod normalize_identifiers2; mod normalize_labels; mod normalize_predicates; +mod normalize_predicates2; +mod resolve_function_pointers; static ZLUDA_PTX_IMPL_INTEL: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.spv"); static ZLUDA_PTX_IMPL_AMD: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.bc"); @@ -57,7 +69,30 @@ pub fn to_llvm_module<'input>(ast: ast::Module<'input>) -> Result(ast: ast::Module<'input>) -> Result { + let mut flat_resolver = GlobalStringIdentResolver2::<'input>::new(SpirvWord(1)); + let mut scoped_resolver = ScopedResolver::new(&mut flat_resolver); + let sreg_map = SpecialRegistersMap2::new(&mut scoped_resolver)?; + let directives = normalize_identifiers2::run(&mut scoped_resolver, ast.directives)?; + let directives = normalize_predicates2::run(&mut flat_resolver, directives)?; + let directives = resolve_function_pointers::run(directives)?; + let directives = fix_special_registers2::run(&mut flat_resolver, &sreg_map, directives)?; + let directives: Vec, SpirvWord>> = + expand_operands::run(&mut flat_resolver, directives)?; + let directives = deparamize_functions::run(&mut flat_resolver, directives)?; + let directives = insert_explicit_load_store::run(&mut flat_resolver, directives)?; + let directives = insert_implicit_conversions2::run(&mut flat_resolver, directives)?; + let directives = hoist_globals::run(directives)?; + let llvm_ir = emit_llvm::run(flat_resolver, directives)?; Ok(Module { llvm_ir, kernel_info: HashMap::new(), @@ -319,7 +354,7 @@ pub struct KernelInfo { pub uses_shared_mem: bool, } -#[derive(Ord, PartialOrd, Eq, PartialEq, Hash, Copy, Clone)] +#[derive(Ord, PartialOrd, Eq, PartialEq, Hash, Copy, Clone, EnumIter)] enum PtxSpecialRegister { Tid, Ntid, @@ -342,6 +377,17 @@ impl PtxSpecialRegister { } } + fn as_str(self) -> &'static str { + match self { + Self::Tid => "%tid", + Self::Ntid => "%ntid", + Self::Ctaid => "%ctaid", + Self::Nctaid => "%nctaid", + Self::Clock => "%clock", + Self::LanemaskLt => "%lanemask_lt", + } + } + fn get_type(self) -> ast::Type { match self { PtxSpecialRegister::Tid @@ -525,7 +571,7 @@ impl<'b> NumericIdResolver<'b> { Some(Some(x)) => Ok(x.clone()), Some(None) => Err(TranslateError::UntypedSymbol), None => match self.special_registers.get(id) { - Some(x) => Ok((x.get_type(), ast::StateSpace::Sreg, true)), + Some(x) => Ok((x.get_type(), ast::StateSpace::Reg, true)), None => match self.global_type_check.get(&id) { Some(Some(result)) => Ok(result.clone()), Some(None) | None => Err(TranslateError::UntypedSymbol), @@ -722,6 +768,7 @@ enum Statement { PtrAccess(PtrAccess

), RepackVector(RepackVectorDetails), FunctionPointer(FunctionPointerDetails), + VectorAccess(VectorAccess), } impl> Statement, T> { @@ -890,6 +937,36 @@ impl> Statement, T> { offset_src, }) } + Statement::VectorAccess(VectorAccess { + scalar_type, + vector_width, + dst, + src: vector_src, + member, + }) => { + let dst: SpirvWord = visitor.visit_ident( + dst, + Some((&scalar_type.into(), ast::StateSpace::Reg)), + true, + false, + )?; + let src = visitor.visit_ident( + vector_src, + Some(( + &ast::Type::Vector(vector_width, scalar_type), + ast::StateSpace::Reg, + )), + false, + false, + )?; + Statement::VectorAccess(VectorAccess { + scalar_type, + vector_width, + dst, + src, + member, + }) + } Statement::RepackVector(RepackVectorDetails { is_extract, typ, @@ -1207,12 +1284,6 @@ impl< } } -fn space_is_compatible(this: ast::StateSpace, other: ast::StateSpace) -> bool { - this == other - || this == ast::StateSpace::Reg && other == ast::StateSpace::Sreg - || this == ast::StateSpace::Sreg && other == ast::StateSpace::Reg -} - fn register_external_fn_call<'a>( id_defs: &mut NumericIdResolver, ptx_impl_imports: &mut HashMap, @@ -1450,6 +1521,7 @@ fn compute_denorm_information<'input>( Statement::Label(_) => {} Statement::Variable(_) => {} Statement::PtrAccess { .. } => {} + Statement::VectorAccess { .. } => {} Statement::RepackVector(_) => {} Statement::FunctionPointer(_) => {} } @@ -1663,3 +1735,278 @@ fn denorm_count_map_update_impl( } } } + +pub(crate) enum Directive2<'input, Instruction, Operand: ast::Operand> { + Variable(ast::LinkingDirective, ast::Variable), + Method(Function2<'input, Instruction, Operand>), +} + +pub(crate) struct Function2<'input, Instruction, Operand: ast::Operand> { + pub func_decl: ast::MethodDeclaration<'input, SpirvWord>, + pub globals: Vec>, + pub body: Option>>, + import_as: Option, + tuning: Vec, + linkage: ast::LinkingDirective, +} + +type NormalizedDirective2<'input> = Directive2< + 'input, + ( + Option>, + ast::Instruction>, + ), + ast::ParsedOperand, +>; + +type NormalizedFunction2<'input> = Function2< + 'input, + ( + Option>, + ast::Instruction>, + ), + ast::ParsedOperand, +>; + +type UnconditionalDirective<'input> = Directive2< + 'input, + ast::Instruction>, + ast::ParsedOperand, +>; + +type UnconditionalFunction<'input> = Function2< + 'input, + ast::Instruction>, + ast::ParsedOperand, +>; + +struct GlobalStringIdentResolver2<'input> { + pub(crate) current_id: SpirvWord, + pub(crate) ident_map: FxHashMap>, +} + +impl<'input> GlobalStringIdentResolver2<'input> { + fn new(spirv_word: SpirvWord) -> Self { + Self { + current_id: spirv_word, + ident_map: FxHashMap::default(), + } + } + + fn register_named( + &mut self, + name: Cow<'input, str>, + type_space: Option<(ast::Type, ast::StateSpace)>, + ) -> SpirvWord { + let new_id = self.current_id; + self.ident_map.insert( + new_id, + IdentEntry { + name: Some(name), + type_space, + }, + ); + self.current_id.0 += 1; + new_id + } + + fn register_unnamed(&mut self, type_space: Option<(ast::Type, ast::StateSpace)>) -> SpirvWord { + let new_id = self.current_id; + self.ident_map.insert( + new_id, + IdentEntry { + name: None, + type_space, + }, + ); + self.current_id.0 += 1; + new_id + } + + fn get_typed(&self, id: SpirvWord) -> Result<&(ast::Type, ast::StateSpace), TranslateError> { + match self.ident_map.get(&id) { + Some(IdentEntry { + type_space: Some(type_space), + .. + }) => Ok(type_space), + _ => Err(error_unknown_symbol()), + } + } +} + +struct IdentEntry<'input> { + name: Option>, + type_space: Option<(ast::Type, ast::StateSpace)>, +} + +struct ScopedResolver<'input, 'b> { + flat_resolver: &'b mut GlobalStringIdentResolver2<'input>, + scopes: Vec>, +} + +impl<'input, 'b> ScopedResolver<'input, 'b> { + fn new(flat_resolver: &'b mut GlobalStringIdentResolver2<'input>) -> Self { + Self { + flat_resolver, + scopes: vec![ScopeMarker::new()], + } + } + + fn start_scope(&mut self) { + self.scopes.push(ScopeMarker::new()); + } + + fn end_scope(&mut self) { + let scope = self.scopes.pop().unwrap(); + scope.flush(self.flat_resolver); + } + + fn add( + &mut self, + name: Cow<'input, str>, + type_space: Option<(ast::Type, ast::StateSpace)>, + ) -> Result { + let result = self.flat_resolver.current_id; + self.flat_resolver.current_id.0 += 1; + let current_scope = self.scopes.last_mut().unwrap(); + if current_scope + .name_to_ident + .insert(name.clone(), result) + .is_some() + { + return Err(error_unknown_symbol()); + } + current_scope.ident_map.insert( + result, + IdentEntry { + name: Some(name), + type_space, + }, + ); + Ok(result) + } + + fn get(&mut self, name: &str) -> Result { + self.scopes + .iter() + .rev() + .find_map(|resolver| resolver.name_to_ident.get(name).copied()) + .ok_or_else(|| error_unreachable()) + } + + fn get_in_current_scope(&self, label: &'input str) -> Result { + let current_scope = self.scopes.last().unwrap(); + current_scope + .name_to_ident + .get(label) + .copied() + .ok_or_else(|| error_unreachable()) + } +} + +struct ScopeMarker<'input> { + ident_map: FxHashMap>, + name_to_ident: FxHashMap, SpirvWord>, +} + +impl<'input> ScopeMarker<'input> { + fn new() -> Self { + Self { + ident_map: FxHashMap::default(), + name_to_ident: FxHashMap::default(), + } + } + + fn flush(self, resolver: &mut GlobalStringIdentResolver2<'input>) { + resolver.ident_map.extend(self.ident_map); + } +} + +struct SpecialRegistersMap2 { + reg_to_id: FxHashMap, + id_to_reg: FxHashMap, +} + +impl SpecialRegistersMap2 { + fn new(resolver: &mut ScopedResolver) -> Result { + let mut result = SpecialRegistersMap2 { + reg_to_id: FxHashMap::default(), + id_to_reg: FxHashMap::default(), + }; + for sreg in PtxSpecialRegister::iter() { + let text = sreg.as_str(); + let id = resolver.add( + Cow::Borrowed(text), + Some((sreg.get_type(), ast::StateSpace::Reg)), + )?; + result.reg_to_id.insert(sreg, id); + result.id_to_reg.insert(id, sreg); + } + Ok(result) + } + + fn get(&self, id: SpirvWord) -> Option { + self.id_to_reg.get(&id).copied() + } + + fn get_or_add(&mut self, current_id: &mut SpirvWord, reg: PtxSpecialRegister) -> SpirvWord { + match self.reg_to_id.entry(reg) { + hash_map::Entry::Occupied(e) => *e.get(), + hash_map::Entry::Vacant(e) => { + let numeric_id = SpirvWord(current_id.0); + current_id.0 += 1; + e.insert(numeric_id); + self.id_to_reg.insert(numeric_id, reg); + numeric_id + } + } + } + + fn generate_declarations<'a, 'input>( + resolver: &'a mut GlobalStringIdentResolver2<'input>, + ) -> impl ExactSizeIterator< + Item = ( + PtxSpecialRegister, + ast::MethodDeclaration<'input, SpirvWord>, + ), + > + 'a { + PtxSpecialRegister::iter().map(|sreg| { + let external_fn_name = [ZLUDA_PTX_PREFIX, sreg.get_unprefixed_function_name()].concat(); + let name = + ast::MethodName::Func(resolver.register_named(Cow::Owned(external_fn_name), None)); + let return_type = sreg.get_function_return_type(); + let input_type = sreg.get_function_return_type(); + ( + sreg, + ast::MethodDeclaration { + return_arguments: vec![ast::Variable { + align: None, + v_type: return_type.into(), + state_space: ast::StateSpace::Reg, + name: resolver + .register_unnamed(Some((return_type.into(), ast::StateSpace::Reg))), + array_init: Vec::new(), + }], + name: name, + input_arguments: vec![ast::Variable { + align: None, + v_type: input_type.into(), + state_space: ast::StateSpace::Reg, + name: resolver + .register_unnamed(Some((input_type.into(), ast::StateSpace::Reg))), + array_init: Vec::new(), + }], + shared_mem: None, + }, + ) + }) + } +} + +pub struct VectorAccess { + scalar_type: ast::ScalarType, + vector_width: u8, + dst: SpirvWord, + src: SpirvWord, + member: u8, +} diff --git a/ptx/src/pass/normalize_identifiers2.rs b/ptx/src/pass/normalize_identifiers2.rs new file mode 100644 index 0000000..beaf08b --- /dev/null +++ b/ptx/src/pass/normalize_identifiers2.rs @@ -0,0 +1,199 @@ +use super::*; +use ptx_parser as ast; +use rustc_hash::FxHashMap; + +pub(crate) fn run<'input, 'b>( + resolver: &mut ScopedResolver<'input, 'b>, + directives: Vec>>, +) -> Result>, TranslateError> { + resolver.start_scope(); + let result = directives + .into_iter() + .map(|directive| run_directive(resolver, directive)) + .collect::, _>>()?; + resolver.end_scope(); + Ok(result) +} + +fn run_directive<'input, 'b>( + resolver: &mut ScopedResolver<'input, 'b>, + directive: ast::Directive<'input, ast::ParsedOperand<&'input str>>, +) -> Result, TranslateError> { + Ok(match directive { + ast::Directive::Variable(linking, var) => { + NormalizedDirective2::Variable(linking, run_variable(resolver, var)?) + } + ast::Directive::Method(linking, directive) => { + NormalizedDirective2::Method(run_method(resolver, linking, directive)?) + } + }) +} + +fn run_method<'input, 'b>( + resolver: &mut ScopedResolver<'input, 'b>, + linkage: ast::LinkingDirective, + method: ast::Function<'input, &'input str, ast::Statement>>, +) -> Result, TranslateError> { + let name = match method.func_directive.name { + ast::MethodName::Kernel(name) => ast::MethodName::Kernel(name), + ast::MethodName::Func(text) => { + ast::MethodName::Func(resolver.add(Cow::Borrowed(text), None)?) + } + }; + resolver.start_scope(); + let func_decl = run_function_decl(resolver, method.func_directive, name)?; + let body = method + .body + .map(|statements| { + let mut result = Vec::with_capacity(statements.len()); + run_statements(resolver, &mut result, statements)?; + Ok::<_, TranslateError>(result) + }) + .transpose()?; + resolver.end_scope(); + Ok(Function2 { + func_decl, + globals: Vec::new(), + body, + import_as: None, + tuning: method.tuning, + linkage, + }) +} + +fn run_function_decl<'input, 'b>( + resolver: &mut ScopedResolver<'input, 'b>, + func_directive: ast::MethodDeclaration<'input, &'input str>, + name: ast::MethodName<'input, SpirvWord>, +) -> Result, TranslateError> { + assert!(func_directive.shared_mem.is_none()); + let return_arguments = func_directive + .return_arguments + .into_iter() + .map(|var| run_variable(resolver, var)) + .collect::, _>>()?; + let input_arguments = func_directive + .input_arguments + .into_iter() + .map(|var| run_variable(resolver, var)) + .collect::, _>>()?; + Ok(ast::MethodDeclaration { + return_arguments, + name, + input_arguments, + shared_mem: None, + }) +} + +fn run_variable<'input, 'b>( + resolver: &mut ScopedResolver<'input, 'b>, + variable: ast::Variable<&'input str>, +) -> Result, TranslateError> { + Ok(ast::Variable { + name: resolver.add( + Cow::Borrowed(variable.name), + Some((variable.v_type.clone(), variable.state_space)), + )?, + align: variable.align, + v_type: variable.v_type, + state_space: variable.state_space, + array_init: variable.array_init, + }) +} + +fn run_statements<'input, 'b>( + resolver: &mut ScopedResolver<'input, 'b>, + result: &mut Vec, + statements: Vec>>, +) -> Result<(), TranslateError> { + for statement in statements.iter() { + match statement { + ast::Statement::Label(label) => { + resolver.add(Cow::Borrowed(*label), None)?; + } + _ => {} + } + } + for statement in statements { + match statement { + ast::Statement::Label(label) => { + result.push(Statement::Label(resolver.get_in_current_scope(label)?)) + } + ast::Statement::Variable(variable) => run_multivariable(resolver, result, variable)?, + ast::Statement::Instruction(predicate, instruction) => { + result.push(Statement::Instruction(( + predicate + .map(|pred| { + Ok::<_, TranslateError>(ast::PredAt { + not: pred.not, + label: resolver.get(pred.label)?, + }) + }) + .transpose()?, + run_instruction(resolver, instruction)?, + ))) + } + ast::Statement::Block(block) => { + resolver.start_scope(); + run_statements(resolver, result, block)?; + resolver.end_scope(); + } + } + } + Ok(()) +} + +fn run_instruction<'input, 'b>( + resolver: &mut ScopedResolver<'input, 'b>, + instruction: ast::Instruction>, +) -> Result>, TranslateError> { + ast::visit_map(instruction, &mut |name: &'input str, + _: Option<( + &ast::Type, + ast::StateSpace, + )>, + _, + _| { + resolver.get(&name) + }) +} + +fn run_multivariable<'input, 'b>( + resolver: &mut ScopedResolver<'input, 'b>, + result: &mut Vec, + variable: ast::MultiVariable<&'input str>, +) -> Result<(), TranslateError> { + match variable.count { + Some(count) => { + for i in 0..count { + let name = Cow::Owned(format!("{}{}", variable.var.name, i)); + let ident = resolver.add( + name, + Some((variable.var.v_type.clone(), variable.var.state_space)), + )?; + result.push(Statement::Variable(ast::Variable { + align: variable.var.align, + v_type: variable.var.v_type.clone(), + state_space: variable.var.state_space, + name: ident, + array_init: variable.var.array_init.clone(), + })); + } + } + None => { + let name = Cow::Borrowed(variable.var.name); + let ident = resolver.add( + name, + Some((variable.var.v_type.clone(), variable.var.state_space)), + )?; + result.push(Statement::Variable(ast::Variable { + align: variable.var.align, + v_type: variable.var.v_type.clone(), + state_space: variable.var.state_space, + name: ident, + array_init: variable.var.array_init.clone(), + })); + } + } + Ok(()) +} diff --git a/ptx/src/pass/normalize_labels.rs b/ptx/src/pass/normalize_labels.rs index 097d87c..037e918 100644 --- a/ptx/src/pass/normalize_labels.rs +++ b/ptx/src/pass/normalize_labels.rs @@ -26,6 +26,7 @@ pub(super) fn run( | Statement::Constant(..) | Statement::Label(..) | Statement::PtrAccess { .. } + | Statement::VectorAccess { .. } | Statement::RepackVector(..) | Statement::FunctionPointer(..) => {} } diff --git a/ptx/src/pass/normalize_predicates2.rs b/ptx/src/pass/normalize_predicates2.rs new file mode 100644 index 0000000..d91e23c --- /dev/null +++ b/ptx/src/pass/normalize_predicates2.rs @@ -0,0 +1,84 @@ +use super::*; +use ptx_parser as ast; + +pub(crate) fn run<'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + directives: Vec>, +) -> Result>, TranslateError> { + directives + .into_iter() + .map(|directive| run_directive(resolver, directive)) + .collect::, _>>() +} + +fn run_directive<'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + directive: NormalizedDirective2<'input>, +) -> Result, TranslateError> { + Ok(match directive { + Directive2::Variable(linking, var) => Directive2::Variable(linking, var), + Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?), + }) +} + +fn run_method<'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + method: NormalizedFunction2<'input>, +) -> Result, TranslateError> { + let body = method + .body + .map(|statements| { + let mut result = Vec::with_capacity(statements.len()); + for statement in statements { + run_statement(resolver, &mut result, statement)?; + } + Ok::<_, TranslateError>(result) + }) + .transpose()?; + Ok(Function2 { + func_decl: method.func_decl, + globals: method.globals, + body, + import_as: method.import_as, + tuning: method.tuning, + linkage: method.linkage, + }) +} + +fn run_statement<'input>( + resolver: &mut GlobalStringIdentResolver2<'input>, + result: &mut Vec, + statement: NormalizedStatement, +) -> Result<(), TranslateError> { + Ok(match statement { + Statement::Label(label) => result.push(Statement::Label(label)), + Statement::Variable(var) => result.push(Statement::Variable(var)), + Statement::Instruction((predicate, instruction)) => { + if let Some(pred) = predicate { + let if_true = resolver.register_unnamed(None); + let if_false = resolver.register_unnamed(None); + let folded_bra = match &instruction { + ast::Instruction::Bra { arguments, .. } => Some(arguments.src), + _ => None, + }; + let mut branch = BrachCondition { + predicate: pred.label, + if_true: folded_bra.unwrap_or(if_true), + if_false, + }; + if pred.not { + std::mem::swap(&mut branch.if_true, &mut branch.if_false); + } + result.push(Statement::Conditional(branch)); + if folded_bra.is_none() { + result.push(Statement::Label(if_true)); + result.push(Statement::Instruction(instruction)); + } + result.push(Statement::Label(if_false)); + } else { + result.push(Statement::Instruction(instruction)); + } + } + _ => return Err(error_unreachable()), + }) +} diff --git a/ptx/src/pass/resolve_function_pointers.rs b/ptx/src/pass/resolve_function_pointers.rs new file mode 100644 index 0000000..eb7abb1 --- /dev/null +++ b/ptx/src/pass/resolve_function_pointers.rs @@ -0,0 +1,82 @@ +use super::*; +use ptx_parser as ast; +use rustc_hash::FxHashSet; + +pub(crate) fn run<'input>( + directives: Vec>, +) -> Result>, TranslateError> { + let mut functions = FxHashSet::default(); + directives + .into_iter() + .map(|directive| run_directive(&mut functions, directive)) + .collect::, _>>() +} + +fn run_directive<'input>( + functions: &mut FxHashSet, + directive: UnconditionalDirective<'input>, +) -> Result, TranslateError> { + Ok(match directive { + var @ Directive2::Variable(..) => var, + Directive2::Method(method) => { + { + let func_decl = &method.func_decl; + match func_decl.name { + ptx_parser::MethodName::Kernel(_) => {} + ptx_parser::MethodName::Func(name) => { + functions.insert(name); + } + } + } + Directive2::Method(run_method(functions, method)?) + } + }) +} + +fn run_method<'input>( + functions: &mut FxHashSet, + method: UnconditionalFunction<'input>, +) -> Result, TranslateError> { + let body = method + .body + .map(|statements| { + statements + .into_iter() + .map(|statement| run_statement(functions, statement)) + .collect::, _>>() + }) + .transpose()?; + Ok(Function2 { + func_decl: method.func_decl, + globals: method.globals, + body, + import_as: method.import_as, + tuning: method.tuning, + linkage: method.linkage, + }) +} + +fn run_statement<'input>( + functions: &mut FxHashSet, + statement: UnconditionalStatement, +) -> Result { + Ok(match statement { + Statement::Instruction(ast::Instruction::Mov { + data, + arguments: + ast::MovArgs { + dst: ast::ParsedOperand::Reg(dst_reg), + src: ast::ParsedOperand::Reg(src_reg), + }, + }) if functions.contains(&src_reg) => { + if data.typ != ast::Type::Scalar(ast::ScalarType::U64) { + return Err(error_mismatched_type()); + } + UnconditionalStatement::FunctionPointer(FunctionPointerDetails { + dst: dst_reg, + src: src_reg, + }) + } + s => s, + }) +} diff --git a/ptx/src/test/spirv_run/mod.rs b/ptx/src/test/spirv_run/mod.rs index 69dd206..e15d6ea 100644 --- a/ptx/src/test/spirv_run/mod.rs +++ b/ptx/src/test/spirv_run/mod.rs @@ -236,7 +236,7 @@ fn test_hip_assert< output: &mut [Output], ) -> Result<(), Box> { let ast = ptx_parser::parse_module_checked(ptx_text).unwrap(); - let llvm_ir = pass::to_llvm_module(ast).unwrap(); + let llvm_ir = pass::to_llvm_module2(ast).unwrap(); let name = CString::new(name)?; let result = run_hip(name.as_c_str(), llvm_ir, input, output).map_err(|err| DisplayError { err })?; diff --git a/ptx_parser/src/ast.rs b/ptx_parser/src/ast.rs index a90b21e..65c624e 100644 --- a/ptx_parser/src/ast.rs +++ b/ptx_parser/src/ast.rs @@ -1049,6 +1049,15 @@ impl<'input, ID> MethodName<'input, ID> { } } +impl<'input> MethodName<'input, &'input str> { + pub fn text(&self) -> &'input str { + match self { + MethodName::Kernel(name) => *name, + MethodName::Func(name) => *name, + } + } +} + bitflags! { pub struct LinkingDirective: u8 { const NONE = 0b000; @@ -1291,7 +1300,12 @@ impl CallArgs { .iter() .zip(details.return_arguments.iter()) { - visitor.visit_ident(param, Some((type_, *space)), true, false)?; + visitor.visit_ident( + param, + Some((type_, *space)), + *space == StateSpace::Reg, + false, + )?; } visitor.visit_ident(&self.func, None, false, false)?; for (param, (type_, space)) in self @@ -1315,7 +1329,12 @@ impl CallArgs { .iter_mut() .zip(details.return_arguments.iter()) { - visitor.visit_ident(param, Some((type_, *space)), true, false)?; + visitor.visit_ident( + param, + Some((type_, *space)), + *space == StateSpace::Reg, + false, + )?; } visitor.visit_ident(&mut self.func, None, false, false)?; for (param, (type_, space)) in self @@ -1339,7 +1358,12 @@ impl CallArgs { .into_iter() .zip(details.return_arguments.iter()) .map(|(param, (type_, space))| { - visitor.visit_ident(param, Some((type_, *space)), true, false) + visitor.visit_ident( + param, + Some((type_, *space)), + *space == StateSpace::Reg, + false, + ) }) .collect::, _>>()?; let func = visitor.visit_ident(self.func, None, false, false)?; diff --git a/ptx_parser/src/lib.rs b/ptx_parser/src/lib.rs index f842ace..fee11aa 100644 --- a/ptx_parser/src/lib.rs +++ b/ptx_parser/src/lib.rs @@ -1499,7 +1499,6 @@ derive_parser!( pub enum StateSpace { Reg, Generic, - Sreg, } #[derive(Copy, Clone, PartialEq, Eq, Hash)] -- cgit v1.2.3