diff options
-rw-r--r-- | ptx/Cargo.toml | 2 | ||||
-rw-r--r-- | ptx/src/pass/deparamize_functions.rs | 141 | ||||
-rw-r--r-- | ptx/src/pass/emit_llvm.rs | 1 | ||||
-rw-r--r-- | ptx/src/pass/emit_spirv.rs | 1 | ||||
-rw-r--r-- | ptx/src/pass/expand_operands.rs | 289 | ||||
-rw-r--r-- | ptx/src/pass/fix_special_registers2.rs | 209 | ||||
-rw-r--r-- | ptx/src/pass/insert_explicit_load_store.rs | 273 | ||||
-rw-r--r-- | ptx/src/pass/insert_implicit_conversions.rs | 7 | ||||
-rw-r--r-- | ptx/src/pass/mod.rs | 272 | ||||
-rw-r--r-- | ptx/src/pass/normalize_identifiers2.rs | 111 | ||||
-rw-r--r-- | ptx/src/pass/normalize_labels.rs | 1 | ||||
-rw-r--r-- | ptx/src/pass/normalize_predicates2.rs | 4 | ||||
-rw-r--r-- | ptx/src/pass/resolve_function_pointers.rs | 2 |
13 files changed, 1208 insertions, 105 deletions
diff --git a/ptx/Cargo.toml b/ptx/Cargo.toml index fd86f15..e2c4ff8 100644 --- a/ptx/Cargo.toml +++ b/ptx/Cargo.toml @@ -18,6 +18,8 @@ bit-vec = "0.6" half ="1.6" bitflags = "1.2" rustc-hash = "2.0.0" +strum = "0.26" +strum_macros = "0.26" [dependencies.lalrpop-util] version = "0.19.12" diff --git a/ptx/src/pass/deparamize_functions.rs b/ptx/src/pass/deparamize_functions.rs new file mode 100644 index 0000000..04c8831 --- /dev/null +++ b/ptx/src/pass/deparamize_functions.rs @@ -0,0 +1,141 @@ +use std::collections::BTreeMap;
+
+use super::*;
+
+pub(super) fn run<'a, 'input>(
+ resolver: &mut GlobalStringIdentResolver2<'input>,
+ directives: Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>,
+) -> Result<Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
+ directives
+ .into_iter()
+ .map(|directive| run_directive(resolver, directive))
+ .collect::<Result<Vec<_>, _>>()
+}
+
+fn run_directive<'input>(
+ resolver: &mut GlobalStringIdentResolver2,
+ directive: Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
+) -> Result<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
+ Ok(match directive {
+ var @ Directive2::Variable(..) => var,
+ Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?),
+ })
+}
+
+fn run_method<'input>(
+ resolver: &mut GlobalStringIdentResolver2,
+ mut method: Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
+) -> Result<Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
+ if method.func_decl.name.is_kernel() {
+ return Ok(method);
+ }
+ let is_declaration = method.body.is_none();
+ let mut body = Vec::new();
+ let mut remap_returns = Vec::new();
+ for arg in method.func_decl.return_arguments.iter_mut() {
+ match arg.state_space {
+ ptx_parser::StateSpace::Param => {
+ arg.state_space = ptx_parser::StateSpace::Reg;
+ let old_name = arg.name;
+ arg.name = resolver.register_unnamed(Some((arg.v_type.clone(), arg.state_space)));
+ if is_declaration {
+ continue;
+ }
+ remap_returns.push((old_name, arg.name, arg.v_type.clone()));
+ body.push(Statement::Variable(ast::Variable {
+ align: None,
+ name: old_name,
+ v_type: arg.v_type.clone(),
+ state_space: ptx_parser::StateSpace::Param,
+ array_init: Vec::new(),
+ }));
+ }
+ ptx_parser::StateSpace::Reg => {}
+ _ => return Err(error_unreachable()),
+ }
+ }
+ for arg in method.func_decl.input_arguments.iter_mut() {
+ match arg.state_space {
+ ptx_parser::StateSpace::Param => {
+ arg.state_space = ptx_parser::StateSpace::Reg;
+ let old_name = arg.name;
+ arg.name = resolver.register_unnamed(Some((arg.v_type.clone(), arg.state_space)));
+ if is_declaration {
+ continue;
+ }
+ body.push(Statement::Variable(ast::Variable {
+ align: None,
+ name: old_name,
+ v_type: arg.v_type.clone(),
+ state_space: ptx_parser::StateSpace::Param,
+ array_init: Vec::new(),
+ }));
+ body.push(Statement::Instruction(ast::Instruction::St {
+ data: ast::StData {
+ qualifier: ast::LdStQualifier::Weak,
+ state_space: ast::StateSpace::Param,
+ caching: ast::StCacheOperator::Writethrough,
+ typ: arg.v_type.clone(),
+ },
+ arguments: ast::StArgs {
+ src1: old_name,
+ src2: arg.name,
+ },
+ }));
+ }
+ ptx_parser::StateSpace::Reg => {}
+ _ => return Err(error_unreachable()),
+ }
+ }
+ if remap_returns.is_empty() {
+ return Ok(method);
+ }
+ let body = method
+ .body
+ .map(|statements| {
+ for statement in statements {
+ run_statement(&remap_returns, &mut body, statement)?;
+ }
+ Ok::<_, TranslateError>(body)
+ })
+ .transpose()?;
+ Ok(Function2 {
+ func_decl: method.func_decl,
+ globals: method.globals,
+ body,
+ import_as: method.import_as,
+ tuning: method.tuning,
+ linkage: method.linkage,
+ })
+}
+
+fn run_statement<'input>(
+ remap_returns: &Vec<(SpirvWord, SpirvWord, ast::Type)>,
+ result: &mut Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
+ statement: Statement<ast::Instruction<SpirvWord>, SpirvWord>,
+) -> Result<(), TranslateError> {
+ match statement {
+ Statement::Instruction(ast::Instruction::Ret { .. }) => {
+ for (old_name, new_name, type_) in remap_returns.iter().cloned() {
+ result.push(Statement::Instruction(ast::Instruction::Ld {
+ data: ast::LdDetails {
+ qualifier: ast::LdStQualifier::Weak,
+ state_space: ast::StateSpace::Reg,
+ caching: ast::LdCacheOperator::Cached,
+ typ: type_,
+ non_coherent: false,
+ },
+ arguments: ast::LdArgs {
+ dst: new_name,
+ src: old_name,
+ },
+ }));
+ }
+ result.push(statement);
+ }
+ statement => {
+ result.push(statement);
+ }
+ }
+ Ok(())
+}
diff --git a/ptx/src/pass/emit_llvm.rs b/ptx/src/pass/emit_llvm.rs index de85efc..3060335 100644 --- a/ptx/src/pass/emit_llvm.rs +++ b/ptx/src/pass/emit_llvm.rs @@ -308,6 +308,7 @@ impl<'a, 'input> MethodEmitContext<'a, 'input> { Statement::PtrAccess(_) => todo!(),
Statement::RepackVector(_) => todo!(),
Statement::FunctionPointer(_) => todo!(),
+ Statement::VectorAccess(_) => todo!(),
})
}
diff --git a/ptx/src/pass/emit_spirv.rs b/ptx/src/pass/emit_spirv.rs index ae4dcfe..120a477 100644 --- a/ptx/src/pass/emit_spirv.rs +++ b/ptx/src/pass/emit_spirv.rs @@ -1561,6 +1561,7 @@ fn emit_function_body_ops<'input>( builder.copy_object(vector_type.0, Some(repack.packed.0), temp_vec)?;
}
}
+ Statement::VectorAccess(vector_access) => todo!(),
}
}
Ok(())
diff --git a/ptx/src/pass/expand_operands.rs b/ptx/src/pass/expand_operands.rs new file mode 100644 index 0000000..3dabf40 --- /dev/null +++ b/ptx/src/pass/expand_operands.rs @@ -0,0 +1,289 @@ +use super::*;
+
+pub(super) fn run<'a, 'input>(
+ resolver: &mut GlobalStringIdentResolver2<'input>,
+ directives: Vec<UnconditionalDirective<'input>>,
+) -> Result<Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
+ directives
+ .into_iter()
+ .map(|directive| run_directive(resolver, directive))
+ .collect::<Result<Vec<_>, _>>()
+}
+
+fn run_directive<'input>(
+ resolver: &mut GlobalStringIdentResolver2<'input>,
+ directive: Directive2<
+ 'input,
+ ast::Instruction<ast::ParsedOperand<SpirvWord>>,
+ ast::ParsedOperand<SpirvWord>,
+ >,
+) -> Result<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
+ Ok(match directive {
+ Directive2::Variable(linking, var) => Directive2::Variable(linking, var),
+ Directive2::Method(method) => Directive2::Method(run_method(resolver, method)?),
+ })
+}
+
+fn run_method<'input>(
+ resolver: &mut GlobalStringIdentResolver2<'input>,
+ method: Function2<
+ 'input,
+ ast::Instruction<ast::ParsedOperand<SpirvWord>>,
+ ast::ParsedOperand<SpirvWord>,
+ >,
+) -> Result<Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
+ let body = method
+ .body
+ .map(|statements| {
+ let mut result = Vec::with_capacity(statements.len());
+ for statement in statements {
+ run_statement(resolver, &mut result, statement)?;
+ }
+ Ok::<_, TranslateError>(result)
+ })
+ .transpose()?;
+ Ok(Function2 {
+ func_decl: method.func_decl,
+ globals: method.globals,
+ body,
+ import_as: method.import_as,
+ tuning: method.tuning,
+ linkage: method.linkage,
+ })
+}
+
+fn run_statement<'input>(
+ resolver: &mut GlobalStringIdentResolver2<'input>,
+ result: &mut Vec<Statement<ast::Instruction<SpirvWord>, SpirvWord>>,
+ statement: UnconditionalStatement,
+) -> Result<(), TranslateError> {
+ let mut visitor = FlattenArguments::new(resolver, result);
+ let new_statement = statement.visit_map(&mut visitor)?;
+ visitor.result.push(new_statement);
+ Ok(())
+}
+
+struct FlattenArguments<'a, 'input> {
+ result: &'a mut Vec<ExpandedStatement>,
+ resolver: &'a mut GlobalStringIdentResolver2<'input>,
+ post_stmts: Vec<ExpandedStatement>,
+}
+
+impl<'a, 'input> FlattenArguments<'a, 'input> {
+ fn new(
+ resolver: &'a mut GlobalStringIdentResolver2<'input>,
+ result: &'a mut Vec<ExpandedStatement>,
+ ) -> Self {
+ FlattenArguments {
+ result,
+ resolver,
+ post_stmts: Vec::new(),
+ }
+ }
+
+ fn reg(&mut self, name: SpirvWord) -> Result<SpirvWord, TranslateError> {
+ Ok(name)
+ }
+
+ fn reg_offset(
+ &mut self,
+ reg: SpirvWord,
+ offset: i32,
+ type_space: Option<(&ast::Type, ast::StateSpace)>,
+ _is_dst: bool,
+ ) -> Result<SpirvWord, TranslateError> {
+ let (type_, state_space) = if let Some((type_, state_space)) = type_space {
+ (type_, state_space)
+ } else {
+ return Err(TranslateError::UntypedSymbol);
+ };
+ if state_space == ast::StateSpace::Reg {
+ let (reg_type, reg_space) = self.resolver.get_typed(reg)?;
+ if *reg_space != ast::StateSpace::Reg {
+ return Err(error_mismatched_type());
+ }
+ let reg_scalar_type = match reg_type {
+ ast::Type::Scalar(underlying_type) => *underlying_type,
+ _ => return Err(error_mismatched_type()),
+ };
+ let reg_type = reg_type.clone();
+ let id_constant_stmt = self
+ .resolver
+ .register_unnamed(Some((reg_type.clone(), ast::StateSpace::Reg)));
+ self.result.push(Statement::Constant(ConstantDefinition {
+ dst: id_constant_stmt,
+ typ: reg_scalar_type,
+ value: ast::ImmediateValue::S64(offset as i64),
+ }));
+ let arith_details = match reg_scalar_type.kind() {
+ ast::ScalarKind::Signed => ast::ArithDetails::Integer(ast::ArithInteger {
+ type_: reg_scalar_type,
+ saturate: false,
+ }),
+ ast::ScalarKind::Unsigned | ast::ScalarKind::Bit => {
+ ast::ArithDetails::Integer(ast::ArithInteger {
+ type_: reg_scalar_type,
+ saturate: false,
+ })
+ }
+ _ => return Err(error_unreachable()),
+ };
+ let id_add_result = self
+ .resolver
+ .register_unnamed(Some((reg_type, state_space)));
+ self.result
+ .push(Statement::Instruction(ast::Instruction::Add {
+ data: arith_details,
+ arguments: ast::AddArgs {
+ dst: id_add_result,
+ src1: reg,
+ src2: id_constant_stmt,
+ },
+ }));
+ Ok(id_add_result)
+ } else {
+ let id_constant_stmt = self.resolver.register_unnamed(Some((
+ ast::Type::Scalar(ast::ScalarType::S64),
+ ast::StateSpace::Reg,
+ )));
+ self.result.push(Statement::Constant(ConstantDefinition {
+ dst: id_constant_stmt,
+ typ: ast::ScalarType::S64,
+ value: ast::ImmediateValue::S64(offset as i64),
+ }));
+ let dst = self
+ .resolver
+ .register_unnamed(Some((type_.clone(), state_space)));
+ self.result.push(Statement::PtrAccess(PtrAccess {
+ underlying_type: type_.clone(),
+ state_space: state_space,
+ dst,
+ ptr_src: reg,
+ offset_src: id_constant_stmt,
+ }));
+ Ok(dst)
+ }
+ }
+
+ fn immediate(
+ &mut self,
+ value: ast::ImmediateValue,
+ type_space: Option<(&ast::Type, ast::StateSpace)>,
+ ) -> Result<SpirvWord, TranslateError> {
+ let (scalar_t, state_space) =
+ if let Some((ast::Type::Scalar(scalar), state_space)) = type_space {
+ (*scalar, state_space)
+ } else {
+ return Err(TranslateError::UntypedSymbol);
+ };
+ let id = self
+ .resolver
+ .register_unnamed(Some((ast::Type::Scalar(scalar_t), state_space)));
+ self.result.push(Statement::Constant(ConstantDefinition {
+ dst: id,
+ typ: scalar_t,
+ value,
+ }));
+ Ok(id)
+ }
+
+ fn vec_member(
+ &mut self,
+ vector_src: SpirvWord,
+ member: u8,
+ _type_space: Option<(&ast::Type, ast::StateSpace)>,
+ is_dst: bool,
+ ) -> Result<SpirvWord, TranslateError> {
+ if is_dst {
+ return Err(error_mismatched_type());
+ }
+ let (vector_width, scalar_type, space) = match self.resolver.get_typed(vector_src)? {
+ (ast::Type::Vector(vector_width, scalar_t), space) => {
+ (*vector_width, *scalar_t, *space)
+ }
+ _ => return Err(error_mismatched_type()),
+ };
+ let temporary = self
+ .resolver
+ .register_unnamed(Some((scalar_type.into(), space)));
+ self.result.push(Statement::VectorAccess(VectorAccess {
+ scalar_type,
+ vector_width,
+ dst: temporary,
+ src: vector_src,
+ member: member,
+ }));
+ Ok(temporary)
+ }
+
+ fn vec_pack(
+ &mut self,
+ vecs: Vec<SpirvWord>,
+ type_space: Option<(&ast::Type, ast::StateSpace)>,
+ is_dst: bool,
+ relaxed_type_check: bool,
+ ) -> Result<SpirvWord, TranslateError> {
+ let (scalar_t, state_space) = match type_space {
+ Some((ast::Type::Vector(_, scalar_t), space)) => (*scalar_t, space),
+ _ => return Err(error_mismatched_type()),
+ };
+ let temp_vec = self
+ .resolver
+ .register_unnamed(Some((scalar_t.into(), state_space)));
+ let statement = Statement::RepackVector(RepackVectorDetails {
+ is_extract: is_dst,
+ typ: scalar_t,
+ packed: temp_vec,
+ unpacked: vecs,
+ relaxed_type_check,
+ });
+ if is_dst {
+ self.post_stmts.push(statement);
+ } else {
+ self.result.push(statement);
+ }
+ Ok(temp_vec)
+ }
+}
+
+impl<'a, 'b> ast::VisitorMap<ast::ParsedOperand<SpirvWord>, SpirvWord, TranslateError>
+ for FlattenArguments<'a, 'b>
+{
+ fn visit(
+ &mut self,
+ args: ast::ParsedOperand<SpirvWord>,
+ type_space: Option<(&ast::Type, ast::StateSpace)>,
+ is_dst: bool,
+ relaxed_type_check: bool,
+ ) -> Result<SpirvWord, TranslateError> {
+ match args {
+ ast::ParsedOperand::Reg(r) => self.reg(r),
+ ast::ParsedOperand::Imm(x) => self.immediate(x, type_space),
+ ast::ParsedOperand::RegOffset(reg, offset) => {
+ self.reg_offset(reg, offset, type_space, is_dst)
+ }
+ ast::ParsedOperand::VecMember(vec, member) => {
+ self.vec_member(vec, member, type_space, is_dst)
+ }
+ ast::ParsedOperand::VecPack(vecs) => {
+ self.vec_pack(vecs, type_space, is_dst, relaxed_type_check)
+ }
+ }
+ }
+
+ fn visit_ident(
+ &mut self,
+ name: <TypedOperand as ast::Operand>::Ident,
+ _type_space: Option<(&ast::Type, ast::StateSpace)>,
+ _is_dst: bool,
+ _relaxed_type_check: bool,
+ ) -> Result<<SpirvWord as ast::Operand>::Ident, TranslateError> {
+ self.reg(name)
+ }
+}
+
+impl Drop for FlattenArguments<'_, '_> {
+ fn drop(&mut self) {
+ self.result.extend(self.post_stmts.drain(..));
+ }
+}
diff --git a/ptx/src/pass/fix_special_registers2.rs b/ptx/src/pass/fix_special_registers2.rs new file mode 100644 index 0000000..97f6356 --- /dev/null +++ b/ptx/src/pass/fix_special_registers2.rs @@ -0,0 +1,209 @@ +use super::*;
+
+pub(super) fn run<'a, 'input>(
+ resolver: &mut GlobalStringIdentResolver2<'input>,
+ special_registers: &'a SpecialRegistersMap2,
+ directives: Vec<UnconditionalDirective<'input>>,
+) -> Result<Vec<UnconditionalDirective<'input>>, TranslateError> {
+ let declarations = SpecialRegistersMap2::generate_declarations(resolver);
+ let mut result = Vec::with_capacity(declarations.len() + directives.len());
+ let mut sreg_to_function =
+ FxHashMap::with_capacity_and_hasher(declarations.len(), Default::default());
+ for (sreg, declaration) in declarations {
+ let name = if let ast::MethodName::Func(name) = declaration.name {
+ name
+ } else {
+ return Err(error_unreachable());
+ };
+ result.push(UnconditionalDirective::Method(UnconditionalFunction {
+ func_decl: declaration,
+ globals: Vec::new(),
+ body: None,
+ import_as: None,
+ tuning: Vec::new(),
+ linkage: ast::LinkingDirective::EXTERN,
+ }));
+ sreg_to_function.insert(sreg, name);
+ }
+ let mut visitor = SpecialRegisterResolver {
+ resolver,
+ special_registers,
+ sreg_to_function,
+ result: Vec::new(),
+ };
+ directives
+ .into_iter()
+ .map(|directive| run_directive(&mut visitor, directive))
+ .collect::<Result<Vec<_>, _>>()
+}
+
+fn run_directive<'a, 'input>(
+ visitor: &mut SpecialRegisterResolver<'a, 'input>,
+ directive: UnconditionalDirective<'input>,
+) -> Result<UnconditionalDirective<'input>, TranslateError> {
+ Ok(match directive {
+ var @ Directive2::Variable(..) => var,
+ Directive2::Method(method) => Directive2::Method(run_method(visitor, method)?),
+ })
+}
+
+fn run_method<'a, 'input>(
+ visitor: &mut SpecialRegisterResolver<'a, 'input>,
+ method: UnconditionalFunction<'input>,
+) -> Result<UnconditionalFunction<'input>, TranslateError> {
+ let body = method
+ .body
+ .map(|statements| {
+ let mut result = Vec::with_capacity(statements.len());
+ for statement in statements {
+ run_statement(visitor, &mut result, statement)?;
+ }
+ Ok::<_, TranslateError>(result)
+ })
+ .transpose()?;
+ Ok(Function2 {
+ func_decl: method.func_decl,
+ globals: method.globals,
+ body,
+ import_as: method.import_as,
+ tuning: method.tuning,
+ linkage: method.linkage,
+ })
+}
+
+fn run_statement<'a, 'input>(
+ visitor: &mut SpecialRegisterResolver<'a, 'input>,
+ result: &mut Vec<UnconditionalStatement>,
+ statement: UnconditionalStatement,
+) -> Result<(), TranslateError> {
+ let converted_statement = statement.visit_map(visitor)?;
+ result.extend(visitor.result.drain(..));
+ result.push(converted_statement);
+ Ok(())
+}
+
+struct SpecialRegisterResolver<'a, 'input> {
+ resolver: &'a mut GlobalStringIdentResolver2<'input>,
+ special_registers: &'a SpecialRegistersMap2,
+ sreg_to_function: FxHashMap<PtxSpecialRegister, SpirvWord>,
+ result: Vec<UnconditionalStatement>,
+}
+
+impl<'a, 'b, 'input>
+ ast::VisitorMap<ast::ParsedOperand<SpirvWord>, ast::ParsedOperand<SpirvWord>, TranslateError>
+ for SpecialRegisterResolver<'a, 'input>
+{
+ fn visit(
+ &mut self,
+ operand: ast::ParsedOperand<SpirvWord>,
+ _type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
+ is_dst: bool,
+ _relaxed_type_check: bool,
+ ) -> Result<ast::ParsedOperand<SpirvWord>, TranslateError> {
+ map_operand(operand, &mut |ident, vector_index| {
+ self.replace_sreg(ident, vector_index, is_dst)
+ })
+ }
+
+ fn visit_ident(
+ &mut self,
+ args: SpirvWord,
+ _type_space: Option<(&ptx_parser::Type, ptx_parser::StateSpace)>,
+ is_dst: bool,
+ _relaxed_type_check: bool,
+ ) -> Result<SpirvWord, TranslateError> {
+ self.replace_sreg(args, None, is_dst)
+ }
+}
+
+impl<'a, 'b, 'input> SpecialRegisterResolver<'a, 'input> {
+ fn replace_sreg(
+ &mut self,
+ name: SpirvWord,
+ vector_index: Option<u8>,
+ is_dst: bool,
+ ) -> Result<SpirvWord, TranslateError> {
+ if let Some(sreg) = self.special_registers.get(name) {
+ if is_dst {
+ return Err(error_mismatched_type());
+ }
+ let input_arguments = match (vector_index, sreg.get_function_input_type()) {
+ (Some(idx), Some(inp_type)) => {
+ if inp_type != ast::ScalarType::U8 {
+ return Err(TranslateError::Unreachable);
+ }
+ let constant = self.resolver.register_unnamed(Some((
+ ast::Type::Scalar(inp_type),
+ ast::StateSpace::Reg,
+ )));
+ self.result.push(Statement::Constant(ConstantDefinition {
+ dst: constant,
+ typ: inp_type,
+ value: ast::ImmediateValue::U64(idx as u64),
+ }));
+ vec![(constant, ast::Type::Scalar(inp_type), ast::StateSpace::Reg)]
+ }
+ (None, None) => Vec::new(),
+ _ => return Err(error_mismatched_type()),
+ };
+ let return_type = sreg.get_function_return_type();
+ let fn_result = self
+ .resolver
+ .register_unnamed(Some((ast::Type::Scalar(return_type), ast::StateSpace::Reg)));
+ let return_arguments = vec![(
+ fn_result,
+ ast::Type::Scalar(return_type),
+ ast::StateSpace::Reg,
+ )];
+ let data = ast::CallDetails {
+ uniform: false,
+ return_arguments: return_arguments
+ .iter()
+ .map(|(_, typ, space)| (typ.clone(), *space))
+ .collect(),
+ input_arguments: input_arguments
+ .iter()
+ .map(|(_, typ, space)| (typ.clone(), *space))
+ .collect(),
+ };
+ let arguments = ast::CallArgs::<ast::ParsedOperand<SpirvWord>> {
+ return_arguments: return_arguments.iter().map(|(name, _, _)| *name).collect(),
+ func: self.sreg_to_function[&sreg],
+ input_arguments: input_arguments
+ .iter()
+ .map(|(name, _, _)| ast::ParsedOperand::Reg(*name))
+ .collect(),
+ };
+ self.result
+ .push(Statement::Instruction(ast::Instruction::Call {
+ data,
+ arguments,
+ }));
+ Ok(fn_result)
+ } else {
+ Ok(name)
+ }
+ }
+}
+
+pub fn map_operand<T, U, Err>(
+ this: ast::ParsedOperand<T>,
+ fn_: &mut impl FnMut(T, Option<u8>) -> Result<U, Err>,
+) -> Result<ast::ParsedOperand<U>, Err> {
+ Ok(match this {
+ ast::ParsedOperand::Reg(ident) => ast::ParsedOperand::Reg(fn_(ident, None)?),
+ ast::ParsedOperand::RegOffset(ident, offset) => {
+ ast::ParsedOperand::RegOffset(fn_(ident, None)?, offset)
+ }
+ ast::ParsedOperand::Imm(imm) => ast::ParsedOperand::Imm(imm),
+ ast::ParsedOperand::VecMember(ident, member) => {
+ ast::ParsedOperand::Reg(fn_(ident, Some(member))?)
+ }
+ ast::ParsedOperand::VecPack(idents) => ast::ParsedOperand::VecPack(
+ idents
+ .into_iter()
+ .map(|ident| fn_(ident, None))
+ .collect::<Result<Vec<_>, _>>()?,
+ ),
+ })
+}
diff --git a/ptx/src/pass/insert_explicit_load_store.rs b/ptx/src/pass/insert_explicit_load_store.rs new file mode 100644 index 0000000..e8f01cd --- /dev/null +++ b/ptx/src/pass/insert_explicit_load_store.rs @@ -0,0 +1,273 @@ +use super::*;
+use ptx_parser::VisitorMap;
+use rustc_hash::FxHashSet;
+
+// This pass:
+// * Turns all .local, .param and .reg in-body variables into .local variables
+// (if _not_ an input method argument)
+// * Inserts explicit `ld`/`st` for newly converted .reg variables
+// * Fixup state space of all existing `ld`/`st` instructions into newly
+// converted variables
+// * Turns `.entry` input arguments into param::entry and all related `.param`
+// loads into `param::entry` loads
+// * All `.func` input arguments are turned into `.reg` arguments by another
+// pass, so we do nothing there
+pub(super) fn run<'a, 'input>(
+ resolver: &mut GlobalStringIdentResolver2<'input>,
+ directives: Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>,
+) -> Result<Vec<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>>, TranslateError> {
+ directives
+ .into_iter()
+ .map(|directive| run_directive(resolver, directive))
+ .collect::<Result<Vec<_>, _>>()
+}
+
+fn run_directive<'a, 'input>(
+ resolver: &mut GlobalStringIdentResolver2<'input>,
+ directive: Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
+) -> Result<Directive2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
+ Ok(match directive {
+ var @ Directive2::Variable(..) => var,
+ Directive2::Method(method) => {
+ let visitor = InsertMemSSAVisitor::new(resolver);
+ Directive2::Method(run_method(visitor, method)?)
+ }
+ })
+}
+
+fn run_method<'a, 'input>(
+ mut visitor: InsertMemSSAVisitor<'a, 'input>,
+ method: Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>,
+) -> Result<Function2<'input, ast::Instruction<SpirvWord>, SpirvWord>, TranslateError> {
+ let mut func_decl = method.func_decl;
+ for arg in func_decl.return_arguments.iter_mut() {
+ visitor.visit_variable(arg);
+ }
+ let is_kernel = func_decl.name.is_kernel();
+ // let mut prelude = Vec::with_capacity(method.body.as_ref().map(Vec::len).unwrap_or(0));
+ if is_kernel {
+ for arg in func_decl.input_arguments.iter_mut() {
+ let old_name = arg.name;
+ let old_space = arg.state_space;
+ let new_space = ast::StateSpace::ParamEntry;
+ let new_name = visitor
+ .resolver
+ .register_unnamed(Some((arg.v_type.clone(), new_space)));
+ visitor.input_argument(old_name, new_name, old_space);
+ arg.name = new_name;
+ arg.state_space = new_space;
+ }
+ };
+ let body = method
+ .body
+ .map(move |statements| {
+ let mut result = Vec::with_capacity(statements.len());
+ for statement in statements {
+ run_statement(&mut visitor, &mut result, statement)?;
+ }
+ Ok::<_, TranslateError>(result)
+ })
+ .transpose()?;
+ Ok(Function2 {
+ func_decl: func_decl,
+ globals: method.globals,
+ body,
+ import_as: method.import_as,
+ tuning: method.tuning,
+ linkage: method.linkage,
+ })
+}
+
+fn run_statement<'a, 'input>(
+ visitor: &mut InsertMemSSAVisitor<'a, 'input>,
+ result: &mut Vec<ExpandedStatement>,
+ statement: ExpandedStatement,
+) -> Result<(), TranslateError> {
+ match statement {
+ Statement::Variable(mut var) => {
+ visitor.visit_variable(&mut var);
+ result.push(Statement::Variable(var));
+ }
+ Statement::Instruction(ast::Instruction::Ld { data, arguments }) => {
+ let instruction = visitor.visit_ld(data, arguments)?;
+ let instruction = ast::visit_map(instruction, visitor)?;
+ result.push(Statement::Instruction(instruction));
+ }
+ Statement::Instruction(ast::Instruction::St {
+ data,
+ mut arguments,
+ }) => {
+ let instruction = visitor.visit_st(data, arguments)?;
+ let instruction = ast::visit_map(instruction, visitor)?;
+ result.push(Statement::Instruction(instruction));
+ }
+ s => result.push(s.visit_map(visitor)?),
+ }
+ Ok(())
+}
+
+struct InsertMemSSAVisitor<'a, 'input> {
+ resolver: &'a mut GlobalStringIdentResolver2<'input>,
+ variables: FxHashMap<SpirvWord, RemapAction>,
+}
+
+impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> {
+ fn new(resolver: &'a mut GlobalStringIdentResolver2<'input>) -> Self {
+ Self {
+ resolver,
+ variables: FxHashMap::default(),
+ }
+ }
+
+ fn input_argument(
+ &mut self,
+ old_name: SpirvWord,
+ new_name: SpirvWord,
+ old_space: ast::StateSpace,
+ ) -> Result<(), TranslateError> {
+ if old_space != ast::StateSpace::Param {
+ return Err(error_unreachable());
+ }
+ self.variables.insert(
+ old_name,
+ RemapAction::LDStSpaceChange {
+ name: new_name,
+ old_space,
+ new_space: ast::StateSpace::ParamEntry,
+ },
+ );
+ Ok(())
+ }
+
+ fn variable(
+ &mut self,
+ old_name: SpirvWord,
+ new_name: SpirvWord,
+ old_space: ast::StateSpace,
+ ) -> Result<(), TranslateError> {
+ Ok(match old_space {
+ ast::StateSpace::Reg => {
+ self.variables
+ .insert(old_name, RemapAction::PreLdPostSt(new_name));
+ }
+ ast::StateSpace::Param => {
+ self.variables.insert(
+ old_name,
+ RemapAction::LDStSpaceChange {
+ old_space,
+ new_space: ast::StateSpace::Local,
+ name: new_name,
+ },
+ );
+ }
+ // Good as-is
+ ast::StateSpace::Local => {}
+ // Will be pulled into global scope later
+ ast::StateSpace::Generic
+ | ast::StateSpace::SharedCluster
+ | ast::StateSpace::Global
+ | ast::StateSpace::Const
+ | ast::StateSpace::SharedCta
+ | ast::StateSpace::Shared => {}
+ ast::StateSpace::ParamEntry | ast::StateSpace::ParamFunc => {
+ return Err(error_unreachable())
+ }
+ })
+ }
+
+ fn visit_st(
+ &self,
+ mut data: ast::StData,
+ mut arguments: ast::StArgs<SpirvWord>,
+ ) -> Result<ast::Instruction<SpirvWord>, TranslateError> {
+ if let Some(remap) = self.variables.get(&arguments.src1) {
+ match remap {
+ RemapAction::PreLdPostSt(_) => return Err(error_mismatched_type()),
+ RemapAction::LDStSpaceChange {
+ old_space,
+ new_space,
+ name,
+ } => {
+ if data.state_space != *old_space {
+ return Err(error_mismatched_type());
+ }
+ data.state_space = *new_space;
+ arguments.src1 = *name;
+ }
+ }
+ }
+ Ok(ast::Instruction::St { data, arguments })
+ }
+
+ fn visit_ld(
+ &self,
+ mut data: ast::LdDetails,
+ mut arguments: ast::LdArgs<SpirvWord>,
+ ) -> Result<ast::Instruction<SpirvWord>, TranslateError> {
+ if let Some(remap) = self.variables.get(&arguments.src) {
+ match remap {
+ RemapAction::PreLdPostSt(_) => return Err(error_mismatched_type()),
+ RemapAction::LDStSpaceChange {
+ old_space,
+ new_space,
+ name,
+ } => {
+ if data.state_space != *old_space {
+ return Err(error_mismatched_type());
+ }
+ data.state_space = *new_space;
+ arguments.src = *name;
+ }
+ }
+ }
+ Ok(ast::Instruction::Ld { data, arguments })
+ }
+
+ fn visit_variable(&mut self, var: &mut ast::Variable<SpirvWord>) {
+ if var.state_space != ast::StateSpace::Local {
+ let old_name = var.name;
+ let old_space = var.state_space;
+ let new_space = ast::StateSpace::Local;
+ let new_name = self
+ .resolver
+ .register_unnamed(Some((var.v_type.clone(), new_space)));
+ self.variable(old_name, new_name, old_space);
+ var.name = new_name;
+ var.state_space = new_space;
+ }
+ }
+}
+
+impl<'a, 'input> ast::VisitorMap<SpirvWord, SpirvWord, TranslateError>
+ for InsertMemSSAVisitor<'a, 'input>
+{
+ fn visit(
+ &mut self,
+ args: SpirvWord,
+ type_space: Option<(&ast::Type, ast::StateSpace)>,
+ is_dst: bool,
+ relaxed_type_check: bool,
+ ) -> Result<SpirvWord, TranslateError> {
+ todo!()
+ }
+
+ fn visit_ident(
+ &mut self,
+ args: SpirvWord,
+ type_space: Option<(&ast::Type, ast::StateSpace)>,
+ is_dst: bool,
+ relaxed_type_check: bool,
+ ) -> Result<SpirvWord, TranslateError> {
+ self.visit(args, type_space, is_dst, relaxed_type_check)
+ }
+}
+
+#[derive(Clone, Copy)]
+enum RemapAction {
+ PreLdPostSt(SpirvWord),
+ LDStSpaceChange {
+ old_space: ast::StateSpace,
+ new_space: ast::StateSpace,
+ name: SpirvWord,
+ },
+}
diff --git a/ptx/src/pass/insert_implicit_conversions.rs b/ptx/src/pass/insert_implicit_conversions.rs index 3249b82..c04fa09 100644 --- a/ptx/src/pass/insert_implicit_conversions.rs +++ b/ptx/src/pass/insert_implicit_conversions.rs @@ -45,6 +45,13 @@ pub(super) fn run( Statement::RepackVector(repack),
)?;
}
+ Statement::VectorAccess(vector_access) => {
+ insert_implicit_conversions_impl(
+ &mut result,
+ id_def,
+ Statement::VectorAccess(vector_access),
+ )?;
+ }
s @ Statement::Conditional(_)
| s @ Statement::Conversion(_)
| s @ Statement::Label(_)
diff --git a/ptx/src/pass/mod.rs b/ptx/src/pass/mod.rs index 04d3e49..b82d3c5 100644 --- a/ptx/src/pass/mod.rs +++ b/ptx/src/pass/mod.rs @@ -13,15 +13,21 @@ use std::{ mem,
rc::Rc,
};
+use strum::IntoEnumIterator;
+use strum_macros::EnumIter;
mod convert_dynamic_shared_memory_usage;
mod convert_to_stateful_memory_access;
mod convert_to_typed;
+mod deparamize_functions;
pub(crate) mod emit_llvm;
mod emit_spirv;
mod expand_arguments;
+mod expand_operands;
mod extract_globals;
mod fix_special_registers;
+mod fix_special_registers2;
+mod insert_explicit_load_store;
mod insert_implicit_conversions;
mod insert_mem_ssa_statements;
mod normalize_identifiers;
@@ -68,6 +74,20 @@ pub fn to_llvm_module<'input>(ast: ast::Module<'input>) -> Result<Module, Transl })
}
+pub fn to_llvm_module2<'input>(ast: ast::Module<'input>) -> Result<Module, TranslateError> {
+ let mut flat_resolver = GlobalStringIdentResolver2::<'input>::new(SpirvWord(1));
+ let mut scoped_resolver = ScopedResolver::new(&mut flat_resolver);
+ let sreg_map = SpecialRegistersMap2::new(&mut scoped_resolver)?;
+ let directives = normalize_identifiers2::run(&mut scoped_resolver, ast.directives)?;
+ let directives = normalize_predicates2::run(&mut flat_resolver, directives)?;
+ let directives = resolve_function_pointers::run(directives)?;
+ let directives = fix_special_registers2::run(&mut flat_resolver, &sreg_map, directives)?;
+ let directives = expand_operands::run(&mut flat_resolver, directives)?;
+ let directives = deparamize_functions::run(&mut flat_resolver, directives)?;
+ let directives = insert_explicit_load_store::run(&mut flat_resolver, directives)?;
+ todo!()
+}
+
fn translate_directive<'input, 'a>(
id_defs: &'a mut GlobalStringIdResolver<'input>,
ptx_impl_imports: &'a mut HashMap<String, Directive<'input>>,
@@ -323,7 +343,7 @@ pub struct KernelInfo { pub uses_shared_mem: bool,
}
-#[derive(Ord, PartialOrd, Eq, PartialEq, Hash, Copy, Clone)]
+#[derive(Ord, PartialOrd, Eq, PartialEq, Hash, Copy, Clone, EnumIter)]
enum PtxSpecialRegister {
Tid,
Ntid,
@@ -346,6 +366,17 @@ impl PtxSpecialRegister { }
}
+ fn as_str(self) -> &'static str {
+ match self {
+ Self::Tid => "%tid",
+ Self::Ntid => "%ntid",
+ Self::Ctaid => "%ctaid",
+ Self::Nctaid => "%nctaid",
+ Self::Clock => "%clock",
+ Self::LanemaskLt => "%lanemask_lt",
+ }
+ }
+
fn get_type(self) -> ast::Type {
match self {
PtxSpecialRegister::Tid
@@ -726,6 +757,7 @@ enum Statement<I, P: ast::Operand> { PtrAccess(PtrAccess<P>),
RepackVector(RepackVectorDetails),
FunctionPointer(FunctionPointerDetails),
+ VectorAccess(VectorAccess),
}
impl<T: ast::Operand<Ident = SpirvWord>> Statement<ast::Instruction<T>, T> {
@@ -894,6 +926,36 @@ impl<T: ast::Operand<Ident = SpirvWord>> Statement<ast::Instruction<T>, T> { offset_src,
})
}
+ Statement::VectorAccess(VectorAccess {
+ scalar_type,
+ vector_width,
+ dst,
+ src: vector_src,
+ member,
+ }) => {
+ let dst: SpirvWord = visitor.visit_ident(
+ dst,
+ Some((&scalar_type.into(), ast::StateSpace::Reg)),
+ true,
+ false,
+ )?;
+ let src = visitor.visit_ident(
+ vector_src,
+ Some((
+ &ast::Type::Vector(vector_width, scalar_type),
+ ast::StateSpace::Reg,
+ )),
+ false,
+ false,
+ )?;
+ Statement::VectorAccess(VectorAccess {
+ scalar_type,
+ vector_width,
+ dst,
+ src,
+ member,
+ })
+ }
Statement::RepackVector(RepackVectorDetails {
is_extract,
typ,
@@ -1448,6 +1510,7 @@ fn compute_denorm_information<'input>( Statement::Label(_) => {}
Statement::Variable(_) => {}
Statement::PtrAccess { .. } => {}
+ Statement::VectorAccess { .. } => {}
Statement::RepackVector(_) => {}
Statement::FunctionPointer(_) => {}
}
@@ -1668,7 +1731,7 @@ pub(crate) enum Directive2<'input, Instruction, Operand: ast::Operand> { }
pub(crate) struct Function2<'input, Instruction, Operand: ast::Operand> {
- pub func_decl: Rc<RefCell<ast::MethodDeclaration<'input, SpirvWord>>>,
+ pub func_decl: ast::MethodDeclaration<'input, SpirvWord>,
pub globals: Vec<ast::Variable<SpirvWord>>,
pub body: Option<Vec<Statement<Instruction, Operand>>>,
import_as: Option<String>,
@@ -1712,14 +1775,35 @@ struct GlobalStringIdentResolver2<'input> { }
impl<'input> GlobalStringIdentResolver2<'input> {
- fn register_intermediate(
+ fn new(spirv_word: SpirvWord) -> Self {
+ Self {
+ current_id: spirv_word,
+ ident_map: FxHashMap::default(),
+ }
+ }
+
+ fn register_named(
&mut self,
+ name: Cow<'input, str>,
type_space: Option<(ast::Type, ast::StateSpace)>,
) -> SpirvWord {
let new_id = self.current_id;
self.ident_map.insert(
new_id,
IdentEntry {
+ name: Some(name),
+ type_space,
+ },
+ );
+ self.current_id.0 += 1;
+ new_id
+ }
+
+ fn register_unnamed(&mut self, type_space: Option<(ast::Type, ast::StateSpace)>) -> SpirvWord {
+ let new_id = self.current_id;
+ self.ident_map.insert(
+ new_id,
+ IdentEntry {
name: None,
type_space,
},
@@ -1727,9 +1811,191 @@ impl<'input> GlobalStringIdentResolver2<'input> { self.current_id.0 += 1;
new_id
}
+
+ fn get_typed(&self, id: SpirvWord) -> Result<&(ast::Type, ast::StateSpace), TranslateError> {
+ match self.ident_map.get(&id) {
+ Some(IdentEntry {
+ type_space: Some(type_space),
+ ..
+ }) => Ok(type_space),
+ _ => Err(error_unknown_symbol()),
+ }
+ }
}
struct IdentEntry<'input> {
name: Option<Cow<'input, str>>,
type_space: Option<(ast::Type, ast::StateSpace)>,
}
+
+struct ScopedResolver<'input, 'b> {
+ flat_resolver: &'b mut GlobalStringIdentResolver2<'input>,
+ scopes: Vec<ScopeMarker<'input>>,
+}
+
+impl<'input, 'b> ScopedResolver<'input, 'b> {
+ fn new(flat_resolver: &'b mut GlobalStringIdentResolver2<'input>) -> Self {
+ Self {
+ flat_resolver,
+ scopes: vec![ScopeMarker::new()],
+ }
+ }
+
+ fn start_scope(&mut self) {
+ self.scopes.push(ScopeMarker::new());
+ }
+
+ fn end_scope(&mut self) {
+ let scope = self.scopes.pop().unwrap();
+ scope.flush(self.flat_resolver);
+ }
+
+ fn add(
+ &mut self,
+ name: Cow<'input, str>,
+ type_space: Option<(ast::Type, ast::StateSpace)>,
+ ) -> Result<SpirvWord, TranslateError> {
+ let result = self.flat_resolver.current_id;
+ self.flat_resolver.current_id.0 += 1;
+ let current_scope = self.scopes.last_mut().unwrap();
+ if current_scope
+ .name_to_ident
+ .insert(name.clone(), result)
+ .is_some()
+ {
+ return Err(error_unknown_symbol());
+ }
+ current_scope.ident_map.insert(
+ result,
+ IdentEntry {
+ name: Some(name),
+ type_space,
+ },
+ );
+ Ok(result)
+ }
+
+ fn get(&mut self, name: &str) -> Result<SpirvWord, TranslateError> {
+ self.scopes
+ .iter()
+ .rev()
+ .find_map(|resolver| resolver.name_to_ident.get(name).copied())
+ .ok_or_else(|| error_unreachable())
+ }
+
+ fn get_in_current_scope(&self, label: &'input str) -> Result<SpirvWord, TranslateError> {
+ let current_scope = self.scopes.last().unwrap();
+ current_scope
+ .name_to_ident
+ .get(label)
+ .copied()
+ .ok_or_else(|| error_unreachable())
+ }
+}
+
+struct ScopeMarker<'input> {
+ ident_map: FxHashMap<SpirvWord, IdentEntry<'input>>,
+ name_to_ident: FxHashMap<Cow<'input, str>, SpirvWord>,
+}
+
+impl<'input> ScopeMarker<'input> {
+ fn new() -> Self {
+ Self {
+ ident_map: FxHashMap::default(),
+ name_to_ident: FxHashMap::default(),
+ }
+ }
+
+ fn flush(self, resolver: &mut GlobalStringIdentResolver2<'input>) {
+ resolver.ident_map.extend(self.ident_map);
+ }
+}
+
+struct SpecialRegistersMap2 {
+ reg_to_id: FxHashMap<PtxSpecialRegister, SpirvWord>,
+ id_to_reg: FxHashMap<SpirvWord, PtxSpecialRegister>,
+}
+
+impl SpecialRegistersMap2 {
+ fn new(resolver: &mut ScopedResolver) -> Result<Self, TranslateError> {
+ let mut result = SpecialRegistersMap2 {
+ reg_to_id: FxHashMap::default(),
+ id_to_reg: FxHashMap::default(),
+ };
+ for sreg in PtxSpecialRegister::iter() {
+ let text = sreg.as_str();
+ let id = resolver.add(
+ Cow::Borrowed(text),
+ Some((sreg.get_type(), ast::StateSpace::Reg)),
+ )?;
+ result.reg_to_id.insert(sreg, id);
+ result.id_to_reg.insert(id, sreg);
+ }
+ Ok(result)
+ }
+
+ fn get(&self, id: SpirvWord) -> Option<PtxSpecialRegister> {
+ self.id_to_reg.get(&id).copied()
+ }
+
+ fn get_or_add(&mut self, current_id: &mut SpirvWord, reg: PtxSpecialRegister) -> SpirvWord {
+ match self.reg_to_id.entry(reg) {
+ hash_map::Entry::Occupied(e) => *e.get(),
+ hash_map::Entry::Vacant(e) => {
+ let numeric_id = SpirvWord(current_id.0);
+ current_id.0 += 1;
+ e.insert(numeric_id);
+ self.id_to_reg.insert(numeric_id, reg);
+ numeric_id
+ }
+ }
+ }
+
+ fn generate_declarations<'a, 'input>(
+ resolver: &'a mut GlobalStringIdentResolver2<'input>,
+ ) -> impl ExactSizeIterator<
+ Item = (
+ PtxSpecialRegister,
+ ast::MethodDeclaration<'input, SpirvWord>,
+ ),
+ > + 'a {
+ PtxSpecialRegister::iter().map(|sreg| {
+ let external_fn_name = [ZLUDA_PTX_PREFIX, sreg.get_unprefixed_function_name()].concat();
+ let name =
+ ast::MethodName::Func(resolver.register_named(Cow::Owned(external_fn_name), None));
+ let return_type = sreg.get_function_return_type();
+ let input_type = sreg.get_function_return_type();
+ (
+ sreg,
+ ast::MethodDeclaration {
+ return_arguments: vec![ast::Variable {
+ align: None,
+ v_type: return_type.into(),
+ state_space: ast::StateSpace::Reg,
+ name: resolver
+ .register_unnamed(Some((return_type.into(), ast::StateSpace::Reg))),
+ array_init: Vec::new(),
+ }],
+ name: name,
+ input_arguments: vec![ast::Variable {
+ align: None,
+ v_type: input_type.into(),
+ state_space: ast::StateSpace::Reg,
+ name: resolver
+ .register_unnamed(Some((input_type.into(), ast::StateSpace::Reg))),
+ array_init: Vec::new(),
+ }],
+ shared_mem: None,
+ },
+ )
+ })
+ }
+}
+
+pub struct VectorAccess {
+ scalar_type: ast::ScalarType,
+ vector_width: u8,
+ dst: SpirvWord,
+ src: SpirvWord,
+ member: u8,
+}
diff --git a/ptx/src/pass/normalize_identifiers2.rs b/ptx/src/pass/normalize_identifiers2.rs index e3fb88d..beaf08b 100644 --- a/ptx/src/pass/normalize_identifiers2.rs +++ b/ptx/src/pass/normalize_identifiers2.rs @@ -2,21 +2,21 @@ use super::*; use ptx_parser as ast;
use rustc_hash::FxHashMap;
-pub(crate) fn run<'input>(
- fn_defs: &mut GlobalStringIdentResolver2<'input>,
+pub(crate) fn run<'input, 'b>(
+ resolver: &mut ScopedResolver<'input, 'b>,
directives: Vec<ast::Directive<'input, ast::ParsedOperand<&'input str>>>,
) -> Result<Vec<NormalizedDirective2<'input>>, TranslateError> {
- let mut resolver = NameResolver::new(fn_defs);
+ resolver.start_scope();
let result = directives
.into_iter()
- .map(|directive| run_directive(&mut resolver, directive))
+ .map(|directive| run_directive(resolver, directive))
.collect::<Result<Vec<_>, _>>()?;
resolver.end_scope();
Ok(result)
}
fn run_directive<'input, 'b>(
- resolver: &mut NameResolver<'input, 'b>,
+ resolver: &mut ScopedResolver<'input, 'b>,
directive: ast::Directive<'input, ast::ParsedOperand<&'input str>>,
) -> Result<NormalizedDirective2<'input>, TranslateError> {
Ok(match directive {
@@ -30,7 +30,7 @@ fn run_directive<'input, 'b>( }
fn run_method<'input, 'b>(
- resolver: &mut NameResolver<'input, 'b>,
+ resolver: &mut ScopedResolver<'input, 'b>,
linkage: ast::LinkingDirective,
method: ast::Function<'input, &'input str, ast::Statement<ast::ParsedOperand<&'input str>>>,
) -> Result<NormalizedFunction2<'input>, TranslateError> {
@@ -41,11 +41,7 @@ fn run_method<'input, 'b>( }
};
resolver.start_scope();
- let func_decl = Rc::new(RefCell::new(run_function_decl(
- resolver,
- method.func_directive,
- name,
- )?));
+ let func_decl = run_function_decl(resolver, method.func_directive, name)?;
let body = method
.body
.map(|statements| {
@@ -66,7 +62,7 @@ fn run_method<'input, 'b>( }
fn run_function_decl<'input, 'b>(
- resolver: &mut NameResolver<'input, 'b>,
+ resolver: &mut ScopedResolver<'input, 'b>,
func_directive: ast::MethodDeclaration<'input, &'input str>,
name: ast::MethodName<'input, SpirvWord>,
) -> Result<ast::MethodDeclaration<'input, SpirvWord>, TranslateError> {
@@ -90,7 +86,7 @@ fn run_function_decl<'input, 'b>( }
fn run_variable<'input, 'b>(
- resolver: &mut NameResolver<'input, 'b>,
+ resolver: &mut ScopedResolver<'input, 'b>,
variable: ast::Variable<&'input str>,
) -> Result<ast::Variable<SpirvWord>, TranslateError> {
Ok(ast::Variable {
@@ -106,7 +102,7 @@ fn run_variable<'input, 'b>( }
fn run_statements<'input, 'b>(
- resolver: &mut NameResolver<'input, 'b>,
+ resolver: &mut ScopedResolver<'input, 'b>,
result: &mut Vec<NormalizedStatement>,
statements: Vec<ast::Statement<ast::ParsedOperand<&'input str>>>,
) -> Result<(), TranslateError> {
@@ -148,7 +144,7 @@ fn run_statements<'input, 'b>( }
fn run_instruction<'input, 'b>(
- resolver: &mut NameResolver<'input, 'b>,
+ resolver: &mut ScopedResolver<'input, 'b>,
instruction: ast::Instruction<ast::ParsedOperand<&'input str>>,
) -> Result<ast::Instruction<ast::ParsedOperand<SpirvWord>>, TranslateError> {
ast::visit_map(instruction, &mut |name: &'input str,
@@ -163,7 +159,7 @@ fn run_instruction<'input, 'b>( }
fn run_multivariable<'input, 'b>(
- resolver: &mut NameResolver<'input, 'b>,
+ resolver: &mut ScopedResolver<'input, 'b>,
result: &mut Vec<NormalizedStatement>,
variable: ast::MultiVariable<&'input str>,
) -> Result<(), TranslateError> {
@@ -201,86 +197,3 @@ fn run_multivariable<'input, 'b>( }
Ok(())
}
-
-struct NameResolver<'input, 'b> {
- flat_resolver: &'b mut GlobalStringIdentResolver2<'input>,
- scopes: Vec<ScopeStringIdentResolver<'input>>,
-}
-
-impl<'input, 'b> NameResolver<'input, 'b> {
- fn new(flat_resolver: &'b mut GlobalStringIdentResolver2<'input>) -> Self {
- Self {
- flat_resolver,
- scopes: vec![ScopeStringIdentResolver::new()],
- }
- }
-
- fn start_scope(&mut self) {
- self.scopes.push(ScopeStringIdentResolver::new());
- }
-
- fn end_scope(&mut self) {
- let scope = self.scopes.pop().unwrap();
- scope.flush(self.flat_resolver);
- }
-
- fn add(
- &mut self,
- name: Cow<'input, str>,
- type_space: Option<(ast::Type, ast::StateSpace)>,
- ) -> Result<SpirvWord, TranslateError> {
- let result = self.flat_resolver.current_id;
- self.flat_resolver.current_id.0 += 1;
- let current_scope = self.scopes.last_mut().unwrap();
- if current_scope
- .name_to_ident
- .insert(name.clone(), result)
- .is_some()
- {
- return Err(error_unknown_symbol());
- }
- current_scope.ident_map.insert(
- result,
- IdentEntry {
- name: Some(name),
- type_space,
- },
- );
- Ok(result)
- }
-
- fn get(&mut self, name: &str) -> Result<SpirvWord, TranslateError> {
- self.scopes
- .iter()
- .rev()
- .find_map(|resolver| resolver.name_to_ident.get(name).copied())
- .ok_or_else(|| error_unreachable())
- }
-
- fn get_in_current_scope(&self, label: &'input str) -> Result<SpirvWord, TranslateError> {
- let current_scope = self.scopes.last().unwrap();
- current_scope
- .name_to_ident
- .get(label)
- .copied()
- .ok_or_else(|| error_unreachable())
- }
-}
-
-struct ScopeStringIdentResolver<'input> {
- ident_map: FxHashMap<SpirvWord, IdentEntry<'input>>,
- name_to_ident: FxHashMap<Cow<'input, str>, SpirvWord>,
-}
-
-impl<'input> ScopeStringIdentResolver<'input> {
- fn new() -> Self {
- Self {
- ident_map: FxHashMap::default(),
- name_to_ident: FxHashMap::default(),
- }
- }
-
- fn flush(self, resolver: &mut GlobalStringIdentResolver2<'input>) {
- resolver.ident_map.extend(self.ident_map);
- }
-}
diff --git a/ptx/src/pass/normalize_labels.rs b/ptx/src/pass/normalize_labels.rs index 097d87c..037e918 100644 --- a/ptx/src/pass/normalize_labels.rs +++ b/ptx/src/pass/normalize_labels.rs @@ -26,6 +26,7 @@ pub(super) fn run( | Statement::Constant(..)
| Statement::Label(..)
| Statement::PtrAccess { .. }
+ | Statement::VectorAccess { .. }
| Statement::RepackVector(..)
| Statement::FunctionPointer(..) => {}
}
diff --git a/ptx/src/pass/normalize_predicates2.rs b/ptx/src/pass/normalize_predicates2.rs index 2d15bba..d91e23c 100644 --- a/ptx/src/pass/normalize_predicates2.rs +++ b/ptx/src/pass/normalize_predicates2.rs @@ -55,8 +55,8 @@ fn run_statement<'input>( Statement::Variable(var) => result.push(Statement::Variable(var)),
Statement::Instruction((predicate, instruction)) => {
if let Some(pred) = predicate {
- let if_true = resolver.register_intermediate(None);
- let if_false = resolver.register_intermediate(None);
+ let if_true = resolver.register_unnamed(None);
+ let if_false = resolver.register_unnamed(None);
let folded_bra = match &instruction {
ast::Instruction::Bra { arguments, .. } => Some(arguments.src),
_ => None,
diff --git a/ptx/src/pass/resolve_function_pointers.rs b/ptx/src/pass/resolve_function_pointers.rs index 9aaa694..eb7abb1 100644 --- a/ptx/src/pass/resolve_function_pointers.rs +++ b/ptx/src/pass/resolve_function_pointers.rs @@ -20,7 +20,7 @@ fn run_directive<'input>( var @ Directive2::Variable(..) => var,
Directive2::Method(method) => {
{
- let func_decl = method.func_decl.borrow();
+ let func_decl = &method.func_decl;
match func_decl.name {
ptx_parser::MethodName::Kernel(_) => {}
ptx_parser::MethodName::Func(name) => {
|