diff options
author | Andrzej Janik <[email protected]> | 2024-09-16 17:08:12 +0200 |
---|---|---|
committer | Andrzej Janik <[email protected]> | 2024-09-16 17:08:12 +0200 |
commit | e87388bc352601201960458c2768b571c5947696 (patch) | |
tree | da03bae8e883f5200325e8d53d0eaa43b01fe85f | |
parent | 3b5efbf88b99392f12e95e084fabb5a8960ae04c (diff) | |
download | ZLUDA-e87388bc352601201960458c2768b571c5947696.tar.gz ZLUDA-e87388bc352601201960458c2768b571c5947696.zip |
Port normalize_predicates
-rw-r--r-- | ptx/src/pass/mod.rs | 42 | ||||
-rw-r--r-- | ptx/src/pass/normalize_identifiers2.rs | 68 | ||||
-rw-r--r-- | ptx/src/pass/normalize_predicates2.rs | 84 |
3 files changed, 157 insertions, 37 deletions
diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index 409425f..9277de4 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -1,5 +1,6 @@ use ptx_parser as ast;
use rspirv::{binary::Assemble, dr};
+use rustc_hash::FxHashMap;
use std::hash::Hash;
use std::num::NonZeroU8;
use std::{
@@ -27,6 +28,7 @@ mod normalize_identifiers; mod normalize_identifiers2;
mod normalize_labels;
mod normalize_predicates;
+mod normalize_predicates2;
static ZLUDA_PTX_IMPL_INTEL: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.spv");
static ZLUDA_PTX_IMPL_AMD: &'static [u8] = include_bytes!("../../lib/zluda_ptx_impl.bc");
@@ -1690,3 +1692,43 @@ type NormalizedFunction2<'input> = Function2< ),
ast::ParsedOperand<SpirvWord>,
>;
+
+type UnconditionalDirective<'input> = Directive2<
+ 'input,
+ ast::Instruction<ast::ParsedOperand<SpirvWord>>,
+ ast::ParsedOperand<SpirvWord>,
+>;
+
+type UnconditionalFunction<'input> = Function2<
+ 'input,
+ ast::Instruction<ast::ParsedOperand<SpirvWord>>,
+ ast::ParsedOperand<SpirvWord>,
+>;
+
+struct GlobalStringIdentResolver2<'input> {
+ pub(crate) current_id: SpirvWord,
+ pub(crate) ident_map: FxHashMap<SpirvWord, IdentEntry<'input>>,
+}
+
+impl<'input> GlobalStringIdentResolver2<'input> {
+ fn register_intermediate(
+ &mut self,
+ type_space: Option<(ast::Type, ast::StateSpace)>,
+ ) -> SpirvWord {
+ let new_id = self.current_id;
+ self.ident_map.insert(
+ new_id,
+ IdentEntry {
+ name: None,
+ type_space,
+ },
+ );
+ self.current_id.0 += 1;
+ new_id
+ }
+}
+
+struct IdentEntry<'input> {
+ name: Option<Cow<'input, str>>,
+ type_space: Option<(ast::Type, ast::StateSpace)>,
+}
diff --git a/ptx/src/pass/normalize_identifiers2.rs b/ptx/src/pass/normalize_identifiers2.rs index 925feb7..e3fb88d 100644 --- a/ptx/src/pass/normalize_identifiers2.rs +++ b/ptx/src/pass/normalize_identifiers2.rs @@ -3,45 +3,45 @@ use ptx_parser as ast; use rustc_hash::FxHashMap;
pub(crate) fn run<'input>(
- fn_defs: &mut GlobalStringIdentResolver<'input>,
+ fn_defs: &mut GlobalStringIdentResolver2<'input>,
directives: Vec<ast::Directive<'input, ast::ParsedOperand<&'input str>>>,
) -> Result<Vec<NormalizedDirective2<'input>>, TranslateError> {
let mut resolver = NameResolver::new(fn_defs);
let result = directives
.into_iter()
- .map(|directive| remap_directive(&mut resolver, directive))
+ .map(|directive| run_directive(&mut resolver, directive))
.collect::<Result<Vec<_>, _>>()?;
resolver.end_scope();
Ok(result)
}
-fn remap_directive<'input, 'b>(
+fn run_directive<'input, 'b>(
resolver: &mut NameResolver<'input, 'b>,
directive: ast::Directive<'input, ast::ParsedOperand<&'input str>>,
) -> Result<NormalizedDirective2<'input>, TranslateError> {
Ok(match directive {
ast::Directive::Variable(linking, var) => {
- NormalizedDirective2::Variable(linking, remap_variable(resolver, var)?)
+ NormalizedDirective2::Variable(linking, run_variable(resolver, var)?)
}
ast::Directive::Method(linking, directive) => {
- NormalizedDirective2::Method(remap_method(resolver, linking, directive)?)
+ NormalizedDirective2::Method(run_method(resolver, linking, directive)?)
}
})
}
-fn remap_method<'input, 'b>(
+fn run_method<'input, 'b>(
resolver: &mut NameResolver<'input, 'b>,
linkage: ast::LinkingDirective,
method: ast::Function<'input, &'input str, ast::Statement<ast::ParsedOperand<&'input str>>>,
) -> Result<NormalizedFunction2<'input>, TranslateError> {
let name = match method.func_directive.name {
ast::MethodName::Kernel(name) => ast::MethodName::Kernel(name),
- ast::MethodName::Func(text) => ast::MethodName::Func(
- resolver.add(Cow::Borrowed(method.func_directive.name.text()), None)?,
- ),
+ ast::MethodName::Func(text) => {
+ ast::MethodName::Func(resolver.add(Cow::Borrowed(text), None)?)
+ }
};
resolver.start_scope();
- let func_decl = Rc::new(RefCell::new(remap_function_decl(
+ let func_decl = Rc::new(RefCell::new(run_function_decl(
resolver,
method.func_directive,
name,
@@ -50,7 +50,7 @@ fn remap_method<'input, 'b>( .body
.map(|statements| {
let mut result = Vec::with_capacity(statements.len());
- remap_statements(resolver, &mut result, statements)?;
+ run_statements(resolver, &mut result, statements)?;
Ok::<_, TranslateError>(result)
})
.transpose()?;
@@ -65,7 +65,7 @@ fn remap_method<'input, 'b>( })
}
-fn remap_function_decl<'input, 'b>(
+fn run_function_decl<'input, 'b>(
resolver: &mut NameResolver<'input, 'b>,
func_directive: ast::MethodDeclaration<'input, &'input str>,
name: ast::MethodName<'input, SpirvWord>,
@@ -74,12 +74,12 @@ fn remap_function_decl<'input, 'b>( let return_arguments = func_directive
.return_arguments
.into_iter()
- .map(|var| remap_variable(resolver, var))
+ .map(|var| run_variable(resolver, var))
.collect::<Result<Vec<_>, _>>()?;
let input_arguments = func_directive
.input_arguments
.into_iter()
- .map(|var| remap_variable(resolver, var))
+ .map(|var| run_variable(resolver, var))
.collect::<Result<Vec<_>, _>>()?;
Ok(ast::MethodDeclaration {
return_arguments,
@@ -89,7 +89,7 @@ fn remap_function_decl<'input, 'b>( })
}
-fn remap_variable<'input, 'b>(
+fn run_variable<'input, 'b>(
resolver: &mut NameResolver<'input, 'b>,
variable: ast::Variable<&'input str>,
) -> Result<ast::Variable<SpirvWord>, TranslateError> {
@@ -105,7 +105,7 @@ fn remap_variable<'input, 'b>( })
}
-fn remap_statements<'input, 'b>(
+fn run_statements<'input, 'b>(
resolver: &mut NameResolver<'input, 'b>,
result: &mut Vec<NormalizedStatement>,
statements: Vec<ast::Statement<ast::ParsedOperand<&'input str>>>,
@@ -123,7 +123,7 @@ fn remap_statements<'input, 'b>( ast::Statement::Label(label) => {
result.push(Statement::Label(resolver.get_in_current_scope(label)?))
}
- ast::Statement::Variable(variable) => remap_multivariable(resolver, result, variable)?,
+ ast::Statement::Variable(variable) => run_multivariable(resolver, result, variable)?,
ast::Statement::Instruction(predicate, instruction) => {
result.push(Statement::Instruction((
predicate
@@ -134,12 +134,12 @@ fn remap_statements<'input, 'b>( })
})
.transpose()?,
- remap_instruction(resolver, instruction)?,
+ run_instruction(resolver, instruction)?,
)))
}
ast::Statement::Block(block) => {
resolver.start_scope();
- remap_statements(resolver, result, block)?;
+ run_statements(resolver, result, block)?;
resolver.end_scope();
}
}
@@ -147,7 +147,7 @@ fn remap_statements<'input, 'b>( Ok(())
}
-fn remap_instruction<'input, 'b>(
+fn run_instruction<'input, 'b>(
resolver: &mut NameResolver<'input, 'b>,
instruction: ast::Instruction<ast::ParsedOperand<&'input str>>,
) -> Result<ast::Instruction<ast::ParsedOperand<SpirvWord>>, TranslateError> {
@@ -162,7 +162,7 @@ fn remap_instruction<'input, 'b>( })
}
-fn remap_multivariable<'input, 'b>(
+fn run_multivariable<'input, 'b>(
resolver: &mut NameResolver<'input, 'b>,
result: &mut Vec<NormalizedStatement>,
variable: ast::MultiVariable<&'input str>,
@@ -203,12 +203,12 @@ fn remap_multivariable<'input, 'b>( }
struct NameResolver<'input, 'b> {
- flat_resolver: &'b mut GlobalStringIdentResolver<'input>,
+ flat_resolver: &'b mut GlobalStringIdentResolver2<'input>,
scopes: Vec<ScopeStringIdentResolver<'input>>,
}
impl<'input, 'b> NameResolver<'input, 'b> {
- fn new(flat_resolver: &'b mut GlobalStringIdentResolver<'input>) -> Self {
+ fn new(flat_resolver: &'b mut GlobalStringIdentResolver2<'input>) -> Self {
Self {
flat_resolver,
scopes: vec![ScopeStringIdentResolver::new()],
@@ -239,9 +239,13 @@ impl<'input, 'b> NameResolver<'input, 'b> { {
return Err(error_unknown_symbol());
}
- current_scope
- .ident_map
- .insert(result, IdentEntry { name, type_space });
+ current_scope.ident_map.insert(
+ result,
+ IdentEntry {
+ name: Some(name),
+ type_space,
+ },
+ );
Ok(result)
}
@@ -276,17 +280,7 @@ impl<'input> ScopeStringIdentResolver<'input> { }
}
- fn flush(self, resolver: &mut GlobalStringIdentResolver<'input>) {
+ fn flush(self, resolver: &mut GlobalStringIdentResolver2<'input>) {
resolver.ident_map.extend(self.ident_map);
}
}
-
-struct GlobalStringIdentResolver<'input> {
- pub(crate) current_id: SpirvWord,
- pub(crate) ident_map: FxHashMap<SpirvWord, IdentEntry<'input>>,
-}
-
-struct IdentEntry<'input> {
- name: Cow<'input, str>,
- type_space: Option<(ast::Type, ast::StateSpace)>,
-}
diff --git a/ptx/src/pass/normalize_predicates2.rs b/ptx/src/pass/normalize_predicates2.rs new file mode 100644 index 0000000..2d15bba --- /dev/null +++ b/ptx/src/pass/normalize_predicates2.rs @@ -0,0 +1,84 @@ +use super::*;
+use ptx_parser as ast;
+
+pub(crate) fn run<'input>(
+ resolver: &mut GlobalStringIdentResolver2<'input>,
+ directives: Vec<NormalizedDirective2<'input>>,
+) -> Result<Vec<UnconditionalDirective<'input>>, TranslateError> {
+ directives
+ .into_iter()
+ .map(|directive| run_directive(resolver, directive))
+ .collect::<Result<Vec<_>, _>>()
+}
+
+fn run_directive<'input>(
+ resolver: &mut GlobalStringIdentResolver2<'input>,
+ directive: NormalizedDirective2<'input>,
+) -> Result<UnconditionalDirective<'input>, TranslateError> {
+ Ok(match directive {
+ Directive2::Variable(linking, var) => Directive2::Variable(linking, var),
+ Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?),
+ })
+}
+
+fn run_method<'input>(
+ resolver: &mut GlobalStringIdentResolver2<'input>,
+ method: NormalizedFunction2<'input>,
+) -> Result<UnconditionalFunction<'input>, TranslateError> {
+ let body = method
+ .body
+ .map(|statements| {
+ let mut result = Vec::with_capacity(statements.len());
+ for statement in statements {
+ run_statement(resolver, &mut result, statement)?;
+ }
+ Ok::<_, TranslateError>(result)
+ })
+ .transpose()?;
+ Ok(Function2 {
+ func_decl: method.func_decl,
+ globals: method.globals,
+ body,
+ import_as: method.import_as,
+ tuning: method.tuning,
+ linkage: method.linkage,
+ })
+}
+
+fn run_statement<'input>(
+ resolver: &mut GlobalStringIdentResolver2<'input>,
+ result: &mut Vec<UnconditionalStatement>,
+ statement: NormalizedStatement,
+) -> Result<(), TranslateError> {
+ Ok(match statement {
+ Statement::Label(label) => result.push(Statement::Label(label)),
+ Statement::Variable(var) => result.push(Statement::Variable(var)),
+ Statement::Instruction((predicate, instruction)) => {
+ if let Some(pred) = predicate {
+ let if_true = resolver.register_intermediate(None);
+ let if_false = resolver.register_intermediate(None);
+ let folded_bra = match &instruction {
+ ast::Instruction::Bra { arguments, .. } => Some(arguments.src),
+ _ => None,
+ };
+ let mut branch = BrachCondition {
+ predicate: pred.label,
+ if_true: folded_bra.unwrap_or(if_true),
+ if_false,
+ };
+ if pred.not {
+ std::mem::swap(&mut branch.if_true, &mut branch.if_false);
+ }
+ result.push(Statement::Conditional(branch));
+ if folded_bra.is_none() {
+ result.push(Statement::Label(if_true));
+ result.push(Statement::Instruction(instruction));
+ }
+ result.push(Statement::Label(if_false));
+ } else {
+ result.push(Statement::Instruction(instruction));
+ }
+ }
+ _ => return Err(error_unreachable()),
+ })
+}
|