diff options
Diffstat (limited to 'ptx/src/pass/mod.rs')
-rw-r--r-- | ptx/src/pass/mod.rs | 272 |
1 files changed, 269 insertions, 3 deletions
diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index 04d3e49..b82d3c5 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -13,15 +13,21 @@ use std::{ mem,
rc::Rc,
};
+use strum::IntoEnumIterator;
+use strum_macros::EnumIter;
mod convert_dynamic_shared_memory_usage;
mod convert_to_stateful_memory_access;
mod convert_to_typed;
+mod deparamize_functions;
pub(crate) mod emit_llvm;
mod emit_spirv;
mod expand_arguments;
+mod expand_operands;
mod extract_globals;
mod fix_special_registers;
+mod fix_special_registers2;
+mod insert_explicit_load_store;
mod insert_implicit_conversions;
mod insert_mem_ssa_statements;
mod normalize_identifiers;
@@ -68,6 +74,20 @@ pub fn to_llvm_module<'input>(ast: ast::Module<'input>) -> Result<Module, Transl })
}
+pub fn to_llvm_module2<'input>(ast: ast::Module<'input>) -> Result<Module, TranslateError> {
+ let mut flat_resolver = GlobalStringIdentResolver2::<'input>::new(SpirvWord(1));
+ let mut scoped_resolver = ScopedResolver::new(&mut flat_resolver);
+ let sreg_map = SpecialRegistersMap2::new(&mut scoped_resolver)?;
+ let directives = normalize_identifiers2::run(&mut scoped_resolver, ast.directives)?;
+ let directives = normalize_predicates2::run(&mut flat_resolver, directives)?;
+ let directives = resolve_function_pointers::run(directives)?;
+ let directives = fix_special_registers2::run(&mut flat_resolver, &sreg_map, directives)?;
+ let directives = expand_operands::run(&mut flat_resolver, directives)?;
+ let directives = deparamize_functions::run(&mut flat_resolver, directives)?;
+ let directives = insert_explicit_load_store::run(&mut flat_resolver, directives)?;
+ todo!()
+}
+
fn translate_directive<'input, 'a>(
id_defs: &'a mut GlobalStringIdResolver<'input>,
ptx_impl_imports: &'a mut HashMap<String, Directive<'input>>,
@@ -323,7 +343,7 @@ pub struct KernelInfo { pub uses_shared_mem: bool,
}
-#[derive(Ord, PartialOrd, Eq, PartialEq, Hash, Copy, Clone)]
+#[derive(Ord, PartialOrd, Eq, PartialEq, Hash, Copy, Clone, EnumIter)]
enum PtxSpecialRegister {
Tid,
Ntid,
@@ -346,6 +366,17 @@ impl PtxSpecialRegister { }
}
+ fn as_str(self) -> &'static str {
+ match self {
+ Self::Tid => "%tid",
+ Self::Ntid => "%ntid",
+ Self::Ctaid => "%ctaid",
+ Self::Nctaid => "%nctaid",
+ Self::Clock => "%clock",
+ Self::LanemaskLt => "%lanemask_lt",
+ }
+ }
+
fn get_type(self) -> ast::Type {
match self {
PtxSpecialRegister::Tid
@@ -726,6 +757,7 @@ enum Statement<I, P: ast::Operand> { PtrAccess(PtrAccess<P>),
RepackVector(RepackVectorDetails),
FunctionPointer(FunctionPointerDetails),
+ VectorAccess(VectorAccess),
}
impl<T: ast::Operand<Ident = SpirvWord>> Statement<ast::Instruction<T>, T> {
@@ -894,6 +926,36 @@ impl<T: ast::Operand<Ident = SpirvWord>> Statement<ast::Instruction<T>, T> { offset_src,
})
}
+ Statement::VectorAccess(VectorAccess {
+ scalar_type,
+ vector_width,
+ dst,
+ src: vector_src,
+ member,
+ }) => {
+ let dst: SpirvWord = visitor.visit_ident(
+ dst,
+ Some((&scalar_type.into(), ast::StateSpace::Reg)),
+ true,
+ false,
+ )?;
+ let src = visitor.visit_ident(
+ vector_src,
+ Some((
+ &ast::Type::Vector(vector_width, scalar_type),
+ ast::StateSpace::Reg,
+ )),
+ false,
+ false,
+ )?;
+ Statement::VectorAccess(VectorAccess {
+ scalar_type,
+ vector_width,
+ dst,
+ src,
+ member,
+ })
+ }
Statement::RepackVector(RepackVectorDetails {
is_extract,
typ,
@@ -1448,6 +1510,7 @@ fn compute_denorm_information<'input>( Statement::Label(_) => {}
Statement::Variable(_) => {}
Statement::PtrAccess { .. } => {}
+ Statement::VectorAccess { .. } => {}
Statement::RepackVector(_) => {}
Statement::FunctionPointer(_) => {}
}
@@ -1668,7 +1731,7 @@ pub(crate) enum Directive2<'input, Instruction, Operand: ast::Operand> { }
pub(crate) struct Function2<'input, Instruction, Operand: ast::Operand> {
- pub func_decl: Rc<RefCell<ast::MethodDeclaration<'input, SpirvWord>>>,
+ pub func_decl: ast::MethodDeclaration<'input, SpirvWord>,
pub globals: Vec<ast::Variable<SpirvWord>>,
pub body: Option<Vec<Statement<Instruction, Operand>>>,
import_as: Option<String>,
@@ -1712,14 +1775,35 @@ struct GlobalStringIdentResolver2<'input> { }
impl<'input> GlobalStringIdentResolver2<'input> {
- fn register_intermediate(
+ fn new(spirv_word: SpirvWord) -> Self {
+ Self {
+ current_id: spirv_word,
+ ident_map: FxHashMap::default(),
+ }
+ }
+
+ fn register_named(
&mut self,
+ name: Cow<'input, str>,
type_space: Option<(ast::Type, ast::StateSpace)>,
) -> SpirvWord {
let new_id = self.current_id;
self.ident_map.insert(
new_id,
IdentEntry {
+ name: Some(name),
+ type_space,
+ },
+ );
+ self.current_id.0 += 1;
+ new_id
+ }
+
+ fn register_unnamed(&mut self, type_space: Option<(ast::Type, ast::StateSpace)>) -> SpirvWord {
+ let new_id = self.current_id;
+ self.ident_map.insert(
+ new_id,
+ IdentEntry {
name: None,
type_space,
},
@@ -1727,9 +1811,191 @@ impl<'input> GlobalStringIdentResolver2<'input> { self.current_id.0 += 1;
new_id
}
+
+ fn get_typed(&self, id: SpirvWord) -> Result<&(ast::Type, ast::StateSpace), TranslateError> {
+ match self.ident_map.get(&id) {
+ Some(IdentEntry {
+ type_space: Some(type_space),
+ ..
+ }) => Ok(type_space),
+ _ => Err(error_unknown_symbol()),
+ }
+ }
}
struct IdentEntry<'input> {
name: Option<Cow<'input, str>>,
type_space: Option<(ast::Type, ast::StateSpace)>,
}
+
+struct ScopedResolver<'input, 'b> {
+ flat_resolver: &'b mut GlobalStringIdentResolver2<'input>,
+ scopes: Vec<ScopeMarker<'input>>,
+}
+
+impl<'input, 'b> ScopedResolver<'input, 'b> {
+ fn new(flat_resolver: &'b mut GlobalStringIdentResolver2<'input>) -> Self {
+ Self {
+ flat_resolver,
+ scopes: vec![ScopeMarker::new()],
+ }
+ }
+
+ fn start_scope(&mut self) {
+ self.scopes.push(ScopeMarker::new());
+ }
+
+ fn end_scope(&mut self) {
+ let scope = self.scopes.pop().unwrap();
+ scope.flush(self.flat_resolver);
+ }
+
+ fn add(
+ &mut self,
+ name: Cow<'input, str>,
+ type_space: Option<(ast::Type, ast::StateSpace)>,
+ ) -> Result<SpirvWord, TranslateError> {
+ let result = self.flat_resolver.current_id;
+ self.flat_resolver.current_id.0 += 1;
+ let current_scope = self.scopes.last_mut().unwrap();
+ if current_scope
+ .name_to_ident
+ .insert(name.clone(), result)
+ .is_some()
+ {
+ return Err(error_unknown_symbol());
+ }
+ current_scope.ident_map.insert(
+ result,
+ IdentEntry {
+ name: Some(name),
+ type_space,
+ },
+ );
+ Ok(result)
+ }
+
+ fn get(&mut self, name: &str) -> Result<SpirvWord, TranslateError> {
+ self.scopes
+ .iter()
+ .rev()
+ .find_map(|resolver| resolver.name_to_ident.get(name).copied())
+ .ok_or_else(|| error_unreachable())
+ }
+
+ fn get_in_current_scope(&self, label: &'input str) -> Result<SpirvWord, TranslateError> {
+ let current_scope = self.scopes.last().unwrap();
+ current_scope
+ .name_to_ident
+ .get(label)
+ .copied()
+ .ok_or_else(|| error_unreachable())
+ }
+}
+
+struct ScopeMarker<'input> {
+ ident_map: FxHashMap<SpirvWord, IdentEntry<'input>>,
+ name_to_ident: FxHashMap<Cow<'input, str>, SpirvWord>,
+}
+
+impl<'input> ScopeMarker<'input> {
+ fn new() -> Self {
+ Self {
+ ident_map: FxHashMap::default(),
+ name_to_ident: FxHashMap::default(),
+ }
+ }
+
+ fn flush(self, resolver: &mut GlobalStringIdentResolver2<'input>) {
+ resolver.ident_map.extend(self.ident_map);
+ }
+}
+
+struct SpecialRegistersMap2 {
+ reg_to_id: FxHashMap<PtxSpecialRegister, SpirvWord>,
+ id_to_reg: FxHashMap<SpirvWord, PtxSpecialRegister>,
+}
+
+impl SpecialRegistersMap2 {
+ fn new(resolver: &mut ScopedResolver) -> Result<Self, TranslateError> {
+ let mut result = SpecialRegistersMap2 {
+ reg_to_id: FxHashMap::default(),
+ id_to_reg: FxHashMap::default(),
+ };
+ for sreg in PtxSpecialRegister::iter() {
+ let text = sreg.as_str();
+ let id = resolver.add(
+ Cow::Borrowed(text),
+ Some((sreg.get_type(), ast::StateSpace::Reg)),
+ )?;
+ result.reg_to_id.insert(sreg, id);
+ result.id_to_reg.insert(id, sreg);
+ }
+ Ok(result)
+ }
+
+ fn get(&self, id: SpirvWord) -> Option<PtxSpecialRegister> {
+ self.id_to_reg.get(&id).copied()
+ }
+
+ fn get_or_add(&mut self, current_id: &mut SpirvWord, reg: PtxSpecialRegister) -> SpirvWord {
+ match self.reg_to_id.entry(reg) {
+ hash_map::Entry::Occupied(e) => *e.get(),
+ hash_map::Entry::Vacant(e) => {
+ let numeric_id = SpirvWord(current_id.0);
+ current_id.0 += 1;
+ e.insert(numeric_id);
+ self.id_to_reg.insert(numeric_id, reg);
+ numeric_id
+ }
+ }
+ }
+
+ fn generate_declarations<'a, 'input>(
+ resolver: &'a mut GlobalStringIdentResolver2<'input>,
+ ) -> impl ExactSizeIterator<
+ Item = (
+ PtxSpecialRegister,
+ ast::MethodDeclaration<'input, SpirvWord>,
+ ),
+ > + 'a {
+ PtxSpecialRegister::iter().map(|sreg| {
+ let external_fn_name = [ZLUDA_PTX_PREFIX, sreg.get_unprefixed_function_name()].concat();
+ let name =
+ ast::MethodName::Func(resolver.register_named(Cow::Owned(external_fn_name), None));
+ let return_type = sreg.get_function_return_type();
+ let input_type = sreg.get_function_return_type();
+ (
+ sreg,
+ ast::MethodDeclaration {
+ return_arguments: vec![ast::Variable {
+ align: None,
+ v_type: return_type.into(),
+ state_space: ast::StateSpace::Reg,
+ name: resolver
+ .register_unnamed(Some((return_type.into(), ast::StateSpace::Reg))),
+ array_init: Vec::new(),
+ }],
+ name: name,
+ input_arguments: vec![ast::Variable {
+ align: None,
+ v_type: input_type.into(),
+ state_space: ast::StateSpace::Reg,
+ name: resolver
+ .register_unnamed(Some((input_type.into(), ast::StateSpace::Reg))),
+ array_init: Vec::new(),
+ }],
+ shared_mem: None,
+ },
+ )
+ })
+ }
+}
+
+pub struct VectorAccess {
+ scalar_type: ast::ScalarType,
+ vector_width: u8,
+ dst: SpirvWord,
+ src: SpirvWord,
+ member: u8,
+}
|