diff options
Diffstat (limited to 'ptx/src/pass/mod.rs')
-rw-r--r-- | ptx/src/pass/mod.rs | 365 |
1 files changed, 356 insertions, 9 deletions
diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index 3aa3b0a..0e233ed 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -1,5 +1,6 @@ use ptx_parser as ast;
use rspirv::{binary::Assemble, dr};
+use rustc_hash::FxHashMap;
use std::hash::Hash;
use std::num::NonZeroU8;
use std::{
@@ -12,20 +13,31 @@ use std::{ mem,
rc::Rc,
};
+use strum::IntoEnumIterator;
+use strum_macros::EnumIter;
mod convert_dynamic_shared_memory_usage;
mod convert_to_stateful_memory_access;
mod convert_to_typed;
+mod deparamize_functions;
pub(crate) mod emit_llvm;
mod emit_spirv;
mod expand_arguments;
+mod expand_operands;
mod extract_globals;
mod fix_special_registers;
+mod fix_special_registers2;
+mod hoist_globals;
+mod insert_explicit_load_store;
mod insert_implicit_conversions;
+mod insert_implicit_conversions2;
mod insert_mem_ssa_statements;
mod normalize_identifiers;
+mod normalize_identifiers2;
mod normalize_labels;
mod normalize_predicates;
+mod normalize_predicates2;
+mod resolve_function_pointers;
static ZLUDA_PTX_IMPL_INTEL: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.spv");
static ZLUDA_PTX_IMPL_AMD: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.bc");
@@ -57,7 +69,30 @@ pub fn to_llvm_module<'input>(ast: ast::Module<'input>) -> Result<Module, Transl })?;
normalize_variable_decls(&mut directives);
let denorm_information = compute_denorm_information(&directives);
- let llvm_ir = emit_llvm::run(&id_defs, call_map, directives)?;
+ todo!()
+ /*
+ let llvm_ir: emit_llvm::MemoryBuffer = emit_llvm::run(&id_defs, call_map, directives)?;
+ Ok(Module {
+ llvm_ir,
+ kernel_info: HashMap::new(),
+ }) */
+}
+
+pub fn to_llvm_module2<'input>(ast: ast::Module<'input>) -> Result<Module, TranslateError> {
+ let mut flat_resolver = GlobalStringIdentResolver2::<'input>::new(SpirvWord(1));
+ let mut scoped_resolver = ScopedResolver::new(&mut flat_resolver);
+ let sreg_map = SpecialRegistersMap2::new(&mut scoped_resolver)?;
+ let directives = normalize_identifiers2::run(&mut scoped_resolver, ast.directives)?;
+ let directives = normalize_predicates2::run(&mut flat_resolver, directives)?;
+ let directives = resolve_function_pointers::run(directives)?;
+ let directives = fix_special_registers2::run(&mut flat_resolver, &sreg_map, directives)?;
+ let directives: Vec<Directive2<'_, ptx_parser::Instruction<SpirvWord>, SpirvWord>> =
+ expand_operands::run(&mut flat_resolver, directives)?;
+ let directives = deparamize_functions::run(&mut flat_resolver, directives)?;
+ let directives = insert_explicit_load_store::run(&mut flat_resolver, directives)?;
+ let directives = insert_implicit_conversions2::run(&mut flat_resolver, directives)?;
+ let directives = hoist_globals::run(directives)?;
+ let llvm_ir = emit_llvm::run(flat_resolver, directives)?;
Ok(Module {
llvm_ir,
kernel_info: HashMap::new(),
@@ -319,7 +354,7 @@ pub struct KernelInfo { pub uses_shared_mem: bool,
}
-#[derive(Ord, PartialOrd, Eq, PartialEq, Hash, Copy, Clone)]
+#[derive(Ord, PartialOrd, Eq, PartialEq, Hash, Copy, Clone, EnumIter)]
enum PtxSpecialRegister {
Tid,
Ntid,
@@ -342,6 +377,17 @@ impl PtxSpecialRegister { }
}
+ fn as_str(self) -> &'static str {
+ match self {
+ Self::Tid => "%tid",
+ Self::Ntid => "%ntid",
+ Self::Ctaid => "%ctaid",
+ Self::Nctaid => "%nctaid",
+ Self::Clock => "%clock",
+ Self::LanemaskLt => "%lanemask_lt",
+ }
+ }
+
fn get_type(self) -> ast::Type {
match self {
PtxSpecialRegister::Tid
@@ -525,7 +571,7 @@ impl<'b> NumericIdResolver<'b> { Some(Some(x)) => Ok(x.clone()),
Some(None) => Err(TranslateError::UntypedSymbol),
None => match self.special_registers.get(id) {
- Some(x) => Ok((x.get_type(), ast::StateSpace::Sreg, true)),
+ Some(x) => Ok((x.get_type(), ast::StateSpace::Reg, true)),
None => match self.global_type_check.get(&id) {
Some(Some(result)) => Ok(result.clone()),
Some(None) | None => Err(TranslateError::UntypedSymbol),
@@ -722,6 +768,7 @@ enum Statement<I, P: ast::Operand> { PtrAccess(PtrAccess<P>),
RepackVector(RepackVectorDetails),
FunctionPointer(FunctionPointerDetails),
+ VectorAccess(VectorAccess),
}
impl<T: ast::Operand<Ident = SpirvWord>> Statement<ast::Instruction<T>, T> {
@@ -890,6 +937,36 @@ impl<T: ast::Operand<Ident = SpirvWord>> Statement<ast::Instruction<T>, T> { offset_src,
})
}
+ Statement::VectorAccess(VectorAccess {
+ scalar_type,
+ vector_width,
+ dst,
+ src: vector_src,
+ member,
+ }) => {
+ let dst: SpirvWord = visitor.visit_ident(
+ dst,
+ Some((&scalar_type.into(), ast::StateSpace::Reg)),
+ true,
+ false,
+ )?;
+ let src = visitor.visit_ident(
+ vector_src,
+ Some((
+ &ast::Type::Vector(vector_width, scalar_type),
+ ast::StateSpace::Reg,
+ )),
+ false,
+ false,
+ )?;
+ Statement::VectorAccess(VectorAccess {
+ scalar_type,
+ vector_width,
+ dst,
+ src,
+ member,
+ })
+ }
Statement::RepackVector(RepackVectorDetails {
is_extract,
typ,
@@ -1207,12 +1284,6 @@ impl< }
}
-fn space_is_compatible(this: ast::StateSpace, other: ast::StateSpace) -> bool {
- this == other
- || this == ast::StateSpace::Reg && other == ast::StateSpace::Sreg
- || this == ast::StateSpace::Sreg && other == ast::StateSpace::Reg
-}
-
fn register_external_fn_call<'a>(
id_defs: &mut NumericIdResolver,
ptx_impl_imports: &mut HashMap<String, Directive>,
@@ -1450,6 +1521,7 @@ fn compute_denorm_information<'input>( Statement::Label(_) => {}
Statement::Variable(_) => {}
Statement::PtrAccess { .. } => {}
+ Statement::VectorAccess { .. } => {}
Statement::RepackVector(_) => {}
Statement::FunctionPointer(_) => {}
}
@@ -1663,3 +1735,278 @@ fn denorm_count_map_update_impl<T: Eq + Hash>( }
}
}
+
+pub(crate) enum Directive2<'input, Instruction, Operand: ast::Operand> {
+ Variable(ast::LinkingDirective, ast::Variable<SpirvWord>),
+ Method(Function2<'input, Instruction, Operand>),
+}
+
+pub(crate) struct Function2<'input, Instruction, Operand: ast::Operand> {
+ pub func_decl: ast::MethodDeclaration<'input, SpirvWord>,
+ pub globals: Vec<ast::Variable<SpirvWord>>,
+ pub body: Option<Vec<Statement<Instruction, Operand>>>,
+ import_as: Option<String>,
+ tuning: Vec<ast::TuningDirective>,
+ linkage: ast::LinkingDirective,
+}
+
+type NormalizedDirective2<'input> = Directive2<
+ 'input,
+ (
+ Option<ast::PredAt<SpirvWord>>,
+ ast::Instruction<ast::ParsedOperand<SpirvWord>>,
+ ),
+ ast::ParsedOperand<SpirvWord>,
+>;
+
+type NormalizedFunction2<'input> = Function2<
+ 'input,
+ (
+ Option<ast::PredAt<SpirvWord>>,
+ ast::Instruction<ast::ParsedOperand<SpirvWord>>,
+ ),
+ ast::ParsedOperand<SpirvWord>,
+>;
+
+type UnconditionalDirective<'input> = Directive2<
+ 'input,
+ ast::Instruction<ast::ParsedOperand<SpirvWord>>,
+ ast::ParsedOperand<SpirvWord>,
+>;
+
+type UnconditionalFunction<'input> = Function2<
+ 'input,
+ ast::Instruction<ast::ParsedOperand<SpirvWord>>,
+ ast::ParsedOperand<SpirvWord>,
+>;
+
+struct GlobalStringIdentResolver2<'input> {
+ pub(crate) current_id: SpirvWord,
+ pub(crate) ident_map: FxHashMap<SpirvWord, IdentEntry<'input>>,
+}
+
+impl<'input> GlobalStringIdentResolver2<'input> {
+ fn new(spirv_word: SpirvWord) -> Self {
+ Self {
+ current_id: spirv_word,
+ ident_map: FxHashMap::default(),
+ }
+ }
+
+ fn register_named(
+ &mut self,
+ name: Cow<'input, str>,
+ type_space: Option<(ast::Type, ast::StateSpace)>,
+ ) -> SpirvWord {
+ let new_id = self.current_id;
+ self.ident_map.insert(
+ new_id,
+ IdentEntry {
+ name: Some(name),
+ type_space,
+ },
+ );
+ self.current_id.0 += 1;
+ new_id
+ }
+
+ fn register_unnamed(&mut self, type_space: Option<(ast::Type, ast::StateSpace)>) -> SpirvWord {
+ let new_id = self.current_id;
+ self.ident_map.insert(
+ new_id,
+ IdentEntry {
+ name: None,
+ type_space,
+ },
+ );
+ self.current_id.0 += 1;
+ new_id
+ }
+
+ fn get_typed(&self, id: SpirvWord) -> Result<&(ast::Type, ast::StateSpace), TranslateError> {
+ match self.ident_map.get(&id) {
+ Some(IdentEntry {
+ type_space: Some(type_space),
+ ..
+ }) => Ok(type_space),
+ _ => Err(error_unknown_symbol()),
+ }
+ }
+}
+
+struct IdentEntry<'input> {
+ name: Option<Cow<'input, str>>,
+ type_space: Option<(ast::Type, ast::StateSpace)>,
+}
+
+struct ScopedResolver<'input, 'b> {
+ flat_resolver: &'b mut GlobalStringIdentResolver2<'input>,
+ scopes: Vec<ScopeMarker<'input>>,
+}
+
+impl<'input, 'b> ScopedResolver<'input, 'b> {
+ fn new(flat_resolver: &'b mut GlobalStringIdentResolver2<'input>) -> Self {
+ Self {
+ flat_resolver,
+ scopes: vec![ScopeMarker::new()],
+ }
+ }
+
+ fn start_scope(&mut self) {
+ self.scopes.push(ScopeMarker::new());
+ }
+
+ fn end_scope(&mut self) {
+ let scope = self.scopes.pop().unwrap();
+ scope.flush(self.flat_resolver);
+ }
+
+ fn add(
+ &mut self,
+ name: Cow<'input, str>,
+ type_space: Option<(ast::Type, ast::StateSpace)>,
+ ) -> Result<SpirvWord, TranslateError> {
+ let result = self.flat_resolver.current_id;
+ self.flat_resolver.current_id.0 += 1;
+ let current_scope = self.scopes.last_mut().unwrap();
+ if current_scope
+ .name_to_ident
+ .insert(name.clone(), result)
+ .is_some()
+ {
+ return Err(error_unknown_symbol());
+ }
+ current_scope.ident_map.insert(
+ result,
+ IdentEntry {
+ name: Some(name),
+ type_space,
+ },
+ );
+ Ok(result)
+ }
+
+ fn get(&mut self, name: &str) -> Result<SpirvWord, TranslateError> {
+ self.scopes
+ .iter()
+ .rev()
+ .find_map(|resolver| resolver.name_to_ident.get(name).copied())
+ .ok_or_else(|| error_unreachable())
+ }
+
+ fn get_in_current_scope(&self, label: &'input str) -> Result<SpirvWord, TranslateError> {
+ let current_scope = self.scopes.last().unwrap();
+ current_scope
+ .name_to_ident
+ .get(label)
+ .copied()
+ .ok_or_else(|| error_unreachable())
+ }
+}
+
+struct ScopeMarker<'input> {
+ ident_map: FxHashMap<SpirvWord, IdentEntry<'input>>,
+ name_to_ident: FxHashMap<Cow<'input, str>, SpirvWord>,
+}
+
+impl<'input> ScopeMarker<'input> {
+ fn new() -> Self {
+ Self {
+ ident_map: FxHashMap::default(),
+ name_to_ident: FxHashMap::default(),
+ }
+ }
+
+ fn flush(self, resolver: &mut GlobalStringIdentResolver2<'input>) {
+ resolver.ident_map.extend(self.ident_map);
+ }
+}
+
+struct SpecialRegistersMap2 {
+ reg_to_id: FxHashMap<PtxSpecialRegister, SpirvWord>,
+ id_to_reg: FxHashMap<SpirvWord, PtxSpecialRegister>,
+}
+
+impl SpecialRegistersMap2 {
+ fn new(resolver: &mut ScopedResolver) -> Result<Self, TranslateError> {
+ let mut result = SpecialRegistersMap2 {
+ reg_to_id: FxHashMap::default(),
+ id_to_reg: FxHashMap::default(),
+ };
+ for sreg in PtxSpecialRegister::iter() {
+ let text = sreg.as_str();
+ let id = resolver.add(
+ Cow::Borrowed(text),
+ Some((sreg.get_type(), ast::StateSpace::Reg)),
+ )?;
+ result.reg_to_id.insert(sreg, id);
+ result.id_to_reg.insert(id, sreg);
+ }
+ Ok(result)
+ }
+
+ fn get(&self, id: SpirvWord) -> Option<PtxSpecialRegister> {
+ self.id_to_reg.get(&id).copied()
+ }
+
+ fn get_or_add(&mut self, current_id: &mut SpirvWord, reg: PtxSpecialRegister) -> SpirvWord {
+ match self.reg_to_id.entry(reg) {
+ hash_map::Entry::Occupied(e) => *e.get(),
+ hash_map::Entry::Vacant(e) => {
+ let numeric_id = SpirvWord(current_id.0);
+ current_id.0 += 1;
+ e.insert(numeric_id);
+ self.id_to_reg.insert(numeric_id, reg);
+ numeric_id
+ }
+ }
+ }
+
+ fn generate_declarations<'a, 'input>(
+ resolver: &'a mut GlobalStringIdentResolver2<'input>,
+ ) -> impl ExactSizeIterator<
+ Item = (
+ PtxSpecialRegister,
+ ast::MethodDeclaration<'input, SpirvWord>,
+ ),
+ > + 'a {
+ PtxSpecialRegister::iter().map(|sreg| {
+ let external_fn_name = [ZLUDA_PTX_PREFIX, sreg.get_unprefixed_function_name()].concat();
+ let name =
+ ast::MethodName::Func(resolver.register_named(Cow::Owned(external_fn_name), None));
+ let return_type = sreg.get_function_return_type();
+ let input_type = sreg.get_function_return_type();
+ (
+ sreg,
+ ast::MethodDeclaration {
+ return_arguments: vec![ast::Variable {
+ align: None,
+ v_type: return_type.into(),
+ state_space: ast::StateSpace::Reg,
+ name: resolver
+ .register_unnamed(Some((return_type.into(), ast::StateSpace::Reg))),
+ array_init: Vec::new(),
+ }],
+ name: name,
+ input_arguments: vec![ast::Variable {
+ align: None,
+ v_type: input_type.into(),
+ state_space: ast::StateSpace::Reg,
+ name: resolver
+ .register_unnamed(Some((input_type.into(), ast::StateSpace::Reg))),
+ array_init: Vec::new(),
+ }],
+ shared_mem: None,
+ },
+ )
+ })
+ }
+}
+
+pub struct VectorAccess {
+ scalar_type: ast::ScalarType,
+ vector_width: u8,
+ dst: SpirvWord,
+ src: SpirvWord,
+ member: u8,
+}
|