aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorAndrzej Janik <[email protected]>2020-09-18 18:08:40 +0200
committerAndrzej Janik <[email protected]>2020-09-18 18:08:40 +0200
commitbcb749cdd913cb32c988f786982772e9b9b33bcb (patch)
treeca160793a8921669fa2d6de38a47e8ac4d1043e4
parent952ed5d5049462c60abf4149ee0ddbcb9cdb8cdc (diff)
downloadZLUDA-bcb749cdd913cb32c988f786982772e9b9b33bcb.tar.gz
ZLUDA-bcb749cdd913cb32c988f786982772e9b9b33bcb.zip
Continue working on a better addressable support
-rw-r--r--ptx/src/ast.rs38
-rw-r--r--ptx/src/lib.rs1
-rw-r--r--ptx/src/ptx.lalrpop2
-rw-r--r--ptx/src/test/mod.rs7
-rw-r--r--ptx/src/translate.rs1401
5 files changed, 844 insertions, 605 deletions
diff --git a/ptx/src/ast.rs b/ptx/src/ast.rs
index 7ac9d18..3a5022d 100644
--- a/ptx/src/ast.rs
+++ b/ptx/src/ast.rs
@@ -164,8 +164,8 @@ pub enum MethodDecl<'a, P: ArgParams> {
Kernel(&'a str, Vec<KernelArgument<P>>),
}
-pub type FnArgument<P: ArgParams> = Variable<FnArgumentType, P>;
-pub type KernelArgument<P: ArgParams> = Variable<VariableParamType, P>;
+pub type FnArgument<P> = Variable<FnArgumentType, P>;
+pub type KernelArgument<P> = Variable<VariableParamType, P>;
pub struct Function<'a, P: ArgParams, S> {
pub func_directive: MethodDecl<'a, P>,
@@ -316,7 +316,7 @@ pub struct PredAt<ID> {
pub enum Instruction<P: ArgParams> {
Ld(LdData, Arg2<P>),
- Mov(MovType, Arg2<P>),
+ Mov(MovType, Arg2Mov<P>),
MovVector(MovVectorDetails, Arg2Vec<P>),
Mul(MulDetails, Arg3<P>),
Add(AddDetails, Arg3<P>),
@@ -354,7 +354,7 @@ pub struct CallInst<P: ArgParams> {
pub trait ArgParams {
type ID;
type Operand;
- type MemoryOperand;
+ type MovOperand;
type CallOperand;
type VecOperand;
}
@@ -366,7 +366,7 @@ pub struct ParsedArgParams<'a> {
impl<'a> ArgParams for ParsedArgParams<'a> {
type ID = &'a str;
type Operand = Operand<&'a str>;
- type MemoryOperand = Operand<&'a str>;
+ type MovOperand = MovOperand<&'a str>;
type CallOperand = CallOperand<&'a str>;
type VecOperand = (&'a str, u8);
}
@@ -380,13 +380,27 @@ pub struct Arg2<P: ArgParams> {
pub src: P::Operand,
}
-pub struct Arg2Ld<P: ArgParams> {
+pub struct Arg2Mov<P: ArgParams> {
pub dst: P::ID,
- pub src: P::MemoryOperand,
+ pub src: P::MovOperand,
+}
+
+impl<'input> From<Arg2<ParsedArgParams<'input>>> for Arg2Mov<ParsedArgParams<'input>> {
+ fn from(a: Arg2<ParsedArgParams<'input>>) -> Arg2Mov<ParsedArgParams<'input>> {
+ let new_src = match a.src {
+ Operand::Reg(r) => MovOperand::Reg(r),
+ Operand::RegOffset(r, imm) => MovOperand::RegOffset(r, imm),
+ Operand::Imm(x) => MovOperand::Imm(x),
+ };
+ Arg2Mov {
+ dst: a.dst,
+ src: new_src,
+ }
+ }
}
pub struct Arg2St<P: ArgParams> {
- pub src1: P::MemoryOperand,
+ pub src1: P::Operand,
pub src2: P::Operand,
}
@@ -420,6 +434,14 @@ pub struct Arg5<P: ArgParams> {
}
#[derive(Copy, Clone)]
+pub enum MovOperand<ID> {
+ Reg(ID),
+ Address(ID),
+ RegOffset(ID, i32),
+ AddressOffset(ID, i32),
+ Imm(u32),
+}
+#[derive(Copy, Clone)]
pub enum Operand<ID> {
Reg(ID),
RegOffset(ID, i32),
diff --git a/ptx/src/lib.rs b/ptx/src/lib.rs
index 5e12579..8ae1c6d 100644
--- a/ptx/src/lib.rs
+++ b/ptx/src/lib.rs
@@ -31,6 +31,7 @@ pub use crate::ptx::ModuleParser;
pub use lalrpop_util::lexer::Token;
pub use lalrpop_util::ParseError;
pub use rspirv::dr::Error as SpirvError;
+pub use translate::TranslateError as TranslateError;
pub use translate::to_spirv;
pub(crate) fn without_none<T>(x: Vec<Option<T>>) -> Vec<T> {
diff --git a/ptx/src/ptx.lalrpop b/ptx/src/ptx.lalrpop
index 44f29a5..46d0b48 100644
--- a/ptx/src/ptx.lalrpop
+++ b/ptx/src/ptx.lalrpop
@@ -496,7 +496,7 @@ LdCacheOperator: ast::LdCacheOperator = {
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-mov
InstMov: ast::Instruction<ast::ParsedArgParams<'input>> = {
"mov" <t:MovType> <a:Arg2> => {
- ast::Instruction::Mov(t, a)
+ ast::Instruction::Mov(t, a.into())
},
"mov" <t:MovVectorType> <a:Arg2Vec> => {
ast::Instruction::MovVector(ast::MovVectorDetails{typ: t, length: 0}, a)
diff --git a/ptx/src/test/mod.rs b/ptx/src/test/mod.rs
index f40fc02..d251884 100644
--- a/ptx/src/test/mod.rs
+++ b/ptx/src/test/mod.rs
@@ -1,4 +1,5 @@
use super::ptx;
+use super::TranslateError;
mod spirv_run;
@@ -8,7 +9,7 @@ fn parse_and_assert(s: &str) {
assert!(errors.len() == 0);
}
-fn compile_and_assert(s: &str) -> Result<(), rspirv::dr::Error> {
+fn compile_and_assert(s: &str) -> Result<(), TranslateError> {
let mut errors = Vec::new();
let ast = ptx::ModuleParser::new().parse(&mut errors, s).unwrap();
crate::to_spirv(ast)?;
@@ -28,14 +29,14 @@ fn operands_ptx() {
#[test]
#[allow(non_snake_case)]
-fn vectorAdd_kernel64_ptx() -> Result<(), rspirv::dr::Error> {
+fn vectorAdd_kernel64_ptx() -> Result<(), TranslateError> {
let vector_add = include_str!("vectorAdd_kernel64.ptx");
compile_and_assert(vector_add)
}
#[test]
#[allow(non_snake_case)]
-fn _Z9vectorAddPKfS0_Pfi_ptx() -> Result<(), rspirv::dr::Error> {
+fn _Z9vectorAddPKfS0_Pfi_ptx() -> Result<(), TranslateError> {
let vector_add = include_str!("_Z9vectorAddPKfS0_Pfi.ptx");
compile_and_assert(vector_add)
}
diff --git a/ptx/src/translate.rs b/ptx/src/translate.rs
index 45372f1..0617cbe 100644
--- a/ptx/src/translate.rs
+++ b/ptx/src/translate.rs
@@ -5,6 +5,22 @@ use std::{borrow::Cow, iter, mem};
use rspirv::binary::Assemble;
+quick_error! {
+ #[derive(Debug)]
+ pub enum TranslateError {
+ UnknownSymbol {}
+ UntypedSymbol {}
+ MismatchedType {}
+ Spirv (err: rspirv::dr::Error) {
+ from()
+ display("{}", err)
+ cause(err)
+ }
+ Unreachable {}
+ Todo {}
+ }
+}
+
#[derive(PartialEq, Eq, Hash, Clone)]
enum SpirvType {
Base(SpirvScalarKey),
@@ -184,13 +200,13 @@ impl TypeWordMap {
}
}
-pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result<dr::Module, dr::Error> {
+pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result<dr::Module, TranslateError> {
let mut id_defs = GlobalStringIdResolver::new(1);
let ssa_functions = ast
.functions
.into_iter()
.map(|f| to_ssa_function(&mut id_defs, f))
- .collect::<Vec<_>>();
+ .collect::<Result<Vec<_>, _>>()?;
let mut builder = dr::Builder::new();
builder.reserve_ids(id_defs.current_id());
// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_logicallayout_a_logical_layout_of_a_module
@@ -217,11 +233,11 @@ fn emit_function_header<'a>(
map: &mut TypeWordMap,
global: &GlobalStringIdResolver<'a>,
func_directive: ast::MethodDecl<ExpandedArgParams>,
-) -> Result<(), dr::Error> {
+) -> Result<(), TranslateError> {
let (ret_type, func_type) = get_function_type(builder, map, &func_directive);
let fn_id = match func_directive {
ast::MethodDecl::Kernel(name, _) => {
- let fn_id = global.get_id(name);
+ let fn_id = global.get_id(name)?;
builder.entry_point(spirv::ExecutionModel::Kernel, fn_id, name, &[]);
fn_id
}
@@ -246,7 +262,7 @@ fn emit_function_header<'a>(
Ok(())
}
-pub fn to_spirv<'a>(ast: ast::Module<'a>) -> Result<Vec<u32>, dr::Error> {
+pub fn to_spirv<'a>(ast: ast::Module<'a>) -> Result<Vec<u32>, TranslateError> {
let module = to_spirv_module(ast)?;
Ok(module.assemble())
}
@@ -276,7 +292,7 @@ fn emit_memory_model(builder: &mut dr::Builder) {
fn to_ssa_function<'a>(
id_defs: &mut GlobalStringIdResolver<'a>,
f: ast::ParsedFunction<'a>,
-) -> ExpandedFunction<'a> {
+) -> Result<ExpandedFunction<'a>, TranslateError> {
let (str_resolver, fn_resolver, fn_decl) = id_defs.start_fn(&f.func_directive);
to_ssa(str_resolver, fn_resolver, fn_decl, f.body)
}
@@ -316,25 +332,26 @@ fn to_ssa<'input, 'b>(
fn_defs: GlobalFnDeclResolver<'input, 'b>,
f_args: ast::MethodDecl<'input, ExpandedArgParams>,
f_body: Option<Vec<ast::Statement<ast::ParsedArgParams<'input>>>>,
-) -> ExpandedFunction<'input> {
+) -> Result<ExpandedFunction<'input>, TranslateError> {
let f_body = match f_body {
Some(vec) => vec,
None => {
- return ExpandedFunction {
+ return Ok(ExpandedFunction {
func_directive: f_args,
body: None,
- }
+ })
}
};
- let normalized_ids = normalize_identifiers(&mut id_defs, &fn_defs, f_body);
+ let normalized_ids = normalize_identifiers(&mut id_defs, &fn_defs, f_body)?;
let mut numeric_id_defs = id_defs.finish();
let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs);
let unadorned_statements =
- add_types_to_statements(unadorned_statements, &fn_defs, &numeric_id_defs);
+ add_types_to_statements(unadorned_statements, &fn_defs, &numeric_id_defs)?;
+ let mut numeric_id_defs = numeric_id_defs.finish();
+ let (f_args, ssa_statements) =
+ insert_mem_ssa_statements(unadorned_statements, &mut numeric_id_defs, f_args)?;
todo!()
/*
- let (f_args, ssa_statements) =
- insert_mem_ssa_statements(unadorned_statements, &mut numeric_id_defs, f_args);
let expanded_statements = expand_arguments(ssa_statements, &mut numeric_id_defs);
let expanded_statements =
insert_implicit_conversions(expanded_statements, &mut numeric_id_defs);
@@ -359,33 +376,104 @@ fn add_types_to_statements(
func: Vec<UnadornedStatement>,
fn_defs: &GlobalFnDeclResolver,
id_defs: &NumericIdResolver,
-) -> Vec<TypedStatement> {
+) -> Result<Vec<UnadornedStatement>, TranslateError> {
func.into_iter()
.map(|s| {
match s {
Statement::Instruction(ast::Instruction::Call(call)) => {
// TODO: error out if lengths don't match
- let fn_def = fn_defs.get_fn_decl(call.func);
+ let fn_def = fn_defs.get_fn_decl(call.func)?;
let ret_params = to_resolved_fn_args(call.ret_params, &*fn_def.ret_vals);
let param_list = to_resolved_fn_args(call.param_list, &*fn_def.params);
- let resolved_call: ResolvedCall<TypedArgParams> = ResolvedCall {
+ let resolved_call = ResolvedCall {
uniform: call.uniform,
ret_params,
func: call.func,
param_list,
};
- Statement::Call(resolved_call)
+ Ok(Statement::Call(resolved_call))
}
- Statement::Instruction(ast::Instruction::Ld(d, arg)) => {
- todo!()
+ // Supported ld/st:
+ // global: only compatible with reg b64/u64/s64 source/dest
+ // generic: compatible with global/local sources
+ // param: compiled as mov
+ // local compiled as mov
+ // We would like to convert ld/st local/param to movs here,
+ // but they have different semantics for implicit conversions
+ // For now, we convert generic ld from local params to ld.local.
+ // This way, we can rely on further stages of the compilation on
+ // ld.generic & ld.global having bytes address source
+ // One complication: immediate address is only allowed in local,
+ // It is not supported in generic ld
+ // ld.local foo, [1];
+ Statement::Instruction(ast::Instruction::Ld(mut d, arg)) => {
+ match arg.src.underlying() {
+ None => return Ok(Statement::Instruction(ast::Instruction::Ld(d, arg))),
+ Some(u) => {
+ let (ss, typ) = id_defs.get_typed(*u)?;
+ match (d.state_space, ss) {
+ (ast::LdStateSpace::Generic, StateSpace::Local) => {
+ d.state_space = ast::LdStateSpace::Local;
+ }
+ _ => (),
+ };
+ }
+ };
+
+ Ok(Statement::Instruction(ast::Instruction::Ld(d, arg)))
}
- Statement::Instruction(ast::Instruction::MovVector(dets, args)) => {
- todo!()
+ Statement::Instruction(ast::Instruction::St(mut d, arg)) => {
+ match arg.src1.underlying() {
+ None => return Ok(Statement::Instruction(ast::Instruction::St(d, arg))),
+ Some(u) => {
+ let (ss, typ) = id_defs.get_typed(*u)?;
+ match (d.state_space, ss) {
+ (ast::StStateSpace::Generic, StateSpace::Local) => {
+ d.state_space = ast::StStateSpace::Local;
+ }
+ _ => (),
+ };
+ }
+ };
+ Ok(Statement::Instruction(ast::Instruction::St(d, arg)))
}
- s => todo!(),
+ Statement::Instruction(ast::Instruction::Mov(d, mut arg)) => {
+ arg.src = match arg.src {
+ ast::MovOperand::Reg(id) => {
+ let (ss, typ) = id_defs.get_typed(id)?;
+ match ss {
+ StateSpace::Reg => ast::MovOperand::Reg(id),
+ StateSpace::Const
+ | StateSpace::Global
+ | StateSpace::Local
+ | StateSpace::Shared
+ | StateSpace::Param
+ | StateSpace::ParamReg => ast::MovOperand::Address(id),
+ }
+ }
+ ast::MovOperand::RegOffset(id, imm) => {
+ let (ss, typ) = id_defs.get_typed(id)?;
+ match ss {
+ StateSpace::Reg => ast::MovOperand::RegOffset(id, imm),
+ StateSpace::Const
+ | StateSpace::Global
+ | StateSpace::Local
+ | StateSpace::Shared
+ | StateSpace::Param
+ | StateSpace::ParamReg => ast::MovOperand::AddressOffset(id, imm),
+ }
+ }
+ a @ ast::MovOperand::Imm(_) => a,
+ ast::MovOperand::Address(_) | ast::MovOperand::AddressOffset(_, _) => {
+ unreachable!()
+ }
+ };
+ Ok(Statement::Instruction(ast::Instruction::Mov(d, arg)))
+ }
+ s => Ok(s),
}
})
- .collect()
+ .collect::<Result<Vec<_>, _>>()
}
fn to_resolved_fn_args<T>(
@@ -478,18 +566,21 @@ fn normalize_predicates(
fn insert_mem_ssa_statements<'a, 'b>(
func: Vec<UnadornedStatement>,
- id_def: &mut NumericIdResolver,
+ id_def: &mut MutableNumericIdResolver,
mut f_args: ast::MethodDecl<'a, ExpandedArgParams>,
-) -> (
- ast::MethodDecl<'a, ExpandedArgParams>,
- Vec<UnadornedStatement>,
-) {
+) -> Result<
+ (
+ ast::MethodDecl<'a, ExpandedArgParams>,
+ Vec<UnadornedStatement>,
+ ),
+ TranslateError,
+> {
let mut result = Vec::with_capacity(func.len());
let out_param = match &mut f_args {
ast::MethodDecl::Kernel(_, in_params) => {
for p in in_params.iter_mut() {
let typ = ast::Type::from(p.v_type);
- let new_id = id_def.new_id(Some((StateSpace::Param, typ)));
+ let new_id = id_def.new_id(typ);
result.push(Statement::Variable(ast::Variable {
align: p.align,
v_type: ast::VariableType::Param(p.v_type),
@@ -508,12 +599,8 @@ fn insert_mem_ssa_statements<'a, 'b>(
}
ast::MethodDecl::Func(out_params, _, in_params) => {
for p in in_params.iter_mut() {
- let ss = match p.v_type {
- ast::FnArgumentType::Reg(_) => StateSpace::Reg,
- ast::FnArgumentType::Param(_) => StateSpace::Param,
- };
let typ = ast::Type::from(p.v_type);
- let new_id = id_def.new_id(Some((ss, typ)));
+ let new_id = id_def.new_id(typ);
let var_typ = ast::VariableType::from(p.v_type);
result.push(Statement::Variable(ast::Variable {
align: p.align,
@@ -545,31 +632,28 @@ fn insert_mem_ssa_statements<'a, 'b>(
};
for s in func {
match s {
- Statement::Call(call) => insert_mem_ssa_statement_default(id_def, &mut result, call),
+ Statement::Call(call) => insert_mem_ssa_statement_default(id_def, &mut result, call)?,
Statement::Instruction(inst) => match inst {
ast::Instruction::Ret(d) => {
if let Some(out_param) = out_param {
- let typ = id_def.get_type(out_param);
+ let typ = id_def.get_typed(out_param)?;
let new_id = id_def.new_id(typ);
result.push(Statement::LoadVar(
ast::Arg2 {
dst: new_id,
src: out_param,
},
- typ.unwrap().1,
+ typ,
));
result.push(Statement::RetValue(d, new_id));
} else {
result.push(Statement::Instruction(ast::Instruction::Ret(d)))
}
}
- inst => insert_mem_ssa_statement_default(id_def, &mut result, inst),
+ inst => insert_mem_ssa_statement_default(id_def, &mut result, inst)?,
},
Statement::Conditional(mut bra) => {
- let generated_id = id_def.new_id(Some((
- StateSpace::Reg,
- ast::Type::Scalar(ast::ScalarType::Pred),
- )));
+ let generated_id = id_def.new_id(ast::Type::Scalar(ast::ScalarType::Pred));
result.push(Statement::LoadVar(
Arg2 {
dst: generated_id,
@@ -589,41 +673,45 @@ fn insert_mem_ssa_statements<'a, 'b>(
Statement::Composite(_) => todo!(),
}
}
- (f_args, result)
+ Ok((f_args, result))
}
trait VisitVariable: Sized {
fn visit_variable<
'a,
- F: FnMut(ArgumentDescriptor<spirv::Word>, Option<ast::Type>) -> spirv::Word,
+ F: FnMut(
+ ArgumentDescriptor<spirv::Word>,
+ Option<ast::Type>,
+ ) -> Result<spirv::Word, TranslateError>,
>(
self,
f: &mut F,
- ) -> UnadornedStatement;
+ ) -> Result<UnadornedStatement, TranslateError>;
}
trait VisitVariableExpanded {
fn visit_variable_extended<
- F: FnMut(ArgumentDescriptor<spirv::Word>, Option<ast::Type>) -> spirv::Word,
+ F: FnMut(
+ ArgumentDescriptor<spirv::Word>,
+ Option<ast::Type>,
+ ) -> Result<spirv::Word, TranslateError>,
>(
self,
f: &mut F,
- ) -> ExpandedStatement;
+ ) -> Result<ExpandedStatement, TranslateError>;
}
fn insert_mem_ssa_statement_default<'a, F: VisitVariable>(
- id_def: &mut NumericIdResolver,
+ id_def: &mut MutableNumericIdResolver,
result: &mut Vec<UnadornedStatement>,
stmt: F,
-) {
+) -> Result<(), TranslateError> {
let mut post_statements = Vec::new();
let new_statement = stmt.visit_variable(&mut |desc: ArgumentDescriptor<spirv::Word>, _| {
- let id_type = match (id_def.get_type(desc.op), desc.sema) {
- (Some((_, t)), ArgumentSemantics::ParamPtr)
- | (Some((_, t)), ArgumentSemantics::Default) => t,
- (Some((_, t)), ArgumentSemantics::Ptr) => ast::Type::Scalar(ast::ScalarType::B64),
- (None, _) => return desc.op,
+ let id_type = match (id_def.get_typed(desc.op)?, desc.sema) {
+ (t, ArgumentSemantics::ParamPtr) | (t, ArgumentSemantics::Default) => t,
+ (t, ArgumentSemantics::Ptr) => ast::Type::Scalar(ast::ScalarType::B64),
};
- let generated_id = id_def.new_id(Some((StateSpace::Reg, id_type)));
+ let generated_id = id_def.new_id(id_type);
if !desc.is_dst {
result.push(Statement::LoadVar(
Arg2 {
@@ -641,12 +729,14 @@ fn insert_mem_ssa_statement_default<'a, F: VisitVariable>(
id_type,
));
}
- generated_id
- });
+ Ok(generated_id)
+ })?;
result.push(new_statement);
result.append(&mut post_statements);
+ Ok(())
}
+/*
fn expand_arguments<'a, 'b>(
func: Vec<UnadornedStatement>,
id_def: &'b mut NumericIdResolver<'a>,
@@ -656,7 +746,7 @@ fn expand_arguments<'a, 'b>(
match s {
Statement::Call(call) => {
let mut visitor = FlattenArguments::new(&mut result, id_def);
- let (new_call, post_stmts) = (call.map(&mut visitor), visitor.post_stmts);
+ let (new_call, post_stmts) = (call.map(&mut visitor)?, visitor.post_stmts);
result.push(Statement::Call(new_call));
result.extend(post_stmts);
}
@@ -687,6 +777,7 @@ fn expand_arguments<'a, 'b>(
}
result
}
+*/
struct FlattenArguments<'a, 'b> {
func: &'b mut Vec<ExpandedStatement>,
@@ -711,15 +802,15 @@ impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, ExpandedArgParams>
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
_: Option<ast::Type>,
- ) -> spirv::Word {
- desc.op
+ ) -> Result<spirv::Word, TranslateError> {
+ Ok(desc.op)
}
fn operand(
&mut self,
desc: ArgumentDescriptor<ast::Operand<spirv::Word>>,
typ: ast::Type,
- ) -> spirv::Word {
+ ) -> Result<spirv::Word, TranslateError> {
match desc.op {
ast::Operand::Reg(r) => self.variable(desc.new_op(r), Some(typ)),
ast::Operand::Imm(x) => {
@@ -736,77 +827,74 @@ impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, ExpandedArgParams>
typ: scalar_t,
value: x as i64,
}));
- id
- }
- ast::Operand::RegOffset(reg, offset) => {
- match desc.sema {
- ArgumentSemantics::Default => {
- let scalar_t = if let ast::Type::Scalar(scalar) = typ {
- scalar
- } else {
- todo!()
- };
- let id_constant_stmt = self
- .id_def
- .new_id(Some((StateSpace::Reg, ast::Type::Scalar(scalar_t))));
- let result_id = self.id_def.new_id(Some((StateSpace::Reg, typ)));
- self.func.push(Statement::Constant(ConstantDefinition {
- dst: id_constant_stmt,
- typ: scalar_t,
- value: offset as i64,
- }));
- let int_type = ast::IntType::try_new(scalar_t).unwrap_or_else(|| todo!());
- self.func.push(Statement::Instruction(
- ast::Instruction::<ExpandedArgParams>::Add(
- ast::AddDetails::Int(ast::AddIntDesc {
- typ: int_type,
- saturate: false,
- }),
- ast::Arg3 {
- dst: result_id,
- src1: reg,
- src2: id_constant_stmt,
- },
- ),
- ));
- result_id
- }
- ArgumentSemantics::Ptr => {
- let scalar_t = ast::ScalarType::U64;
- let id_constant_stmt = self
- .id_def
- .new_id(Some((StateSpace::Reg, ast::Type::Scalar(scalar_t))));
- let result_id = self.id_def.new_id(Some((StateSpace::Reg, typ)));
- self.func.push(Statement::Constant(ConstantDefinition {
- dst: id_constant_stmt,
- typ: scalar_t,
- value: offset as i64,
- }));
- let int_type = ast::IntType::U64;
- self.func.push(Statement::Instruction(
- ast::Instruction::<ExpandedArgParams>::Add(
- ast::AddDetails::Int(ast::AddIntDesc {
- typ: int_type,
- saturate: false,
- }),
- ast::Arg3 {
- dst: result_id,
- src1: reg,
- src2: id_constant_stmt,
- },
- ),
- ));
- result_id
- }
- ArgumentSemantics::ParamPtr => {
- if offset == 0 {
- return reg;
- }
- // Will be needed for arrays
+ Ok(id)
+ }
+ ast::Operand::RegOffset(reg, offset) => match desc.sema {
+ ArgumentSemantics::Default => {
+ let scalar_t = if let ast::Type::Scalar(scalar) = typ {
+ scalar
+ } else {
todo!()
+ };
+ let id_constant_stmt = self
+ .id_def
+ .new_id(Some((StateSpace::Reg, ast::Type::Scalar(scalar_t))));
+ let result_id = self.id_def.new_id(Some((StateSpace::Reg, typ)));
+ self.func.push(Statement::Constant(ConstantDefinition {
+ dst: id_constant_stmt,
+ typ: scalar_t,
+ value: offset as i64,
+ }));
+ let int_type = ast::IntType::try_new(scalar_t).unwrap_or_else(|| todo!());
+ self.func.push(Statement::Instruction(
+ ast::Instruction::<ExpandedArgParams>::Add(
+ ast::AddDetails::Int(ast::AddIntDesc {
+ typ: int_type,
+ saturate: false,
+ }),
+ ast::Arg3 {
+ dst: result_id,
+ src1: reg,
+ src2: id_constant_stmt,
+ },
+ ),
+ ));
+ Ok(result_id)
+ }
+ ArgumentSemantics::Ptr => {
+ let scalar_t = ast::ScalarType::U64;
+ let id_constant_stmt = self
+ .id_def
+ .new_id(Some((StateSpace::Reg, ast::Type::Scalar(scalar_t))));
+ let result_id = self.id_def.new_id(Some((StateSpace::Reg, typ)));
+ self.func.push(Statement::Constant(ConstantDefinition {
+ dst: id_constant_stmt,
+ typ: scalar_t,
+ value: offset as i64,
+ }));
+ let int_type = ast::IntType::U64;
+ self.func.push(Statement::Instruction(
+ ast::Instruction::<ExpandedArgParams>::Add(
+ ast::AddDetails::Int(ast::AddIntDesc {
+ typ: int_type,
+ saturate: false,
+ }),
+ ast::Arg3 {
+ dst: result_id,
+ src1: reg,
+ src2: id_constant_stmt,
+ },
+ ),
+ ));
+ Ok(result_id)
+ }
+ ArgumentSemantics::ParamPtr => {
+ if offset == 0 {
+ return Ok(reg);
}
+ todo!()
}
- }
+ },
}
}
@@ -814,7 +902,7 @@ impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, ExpandedArgParams>
&mut self,
desc: ArgumentDescriptor<ast::CallOperand<spirv::Word>>,
typ: ast::Type,
- ) -> spirv::Word {
+ ) -> Result<spirv::Word, TranslateError> {
match desc.op {
ast::CallOperand::Reg(reg) => self.variable(desc.new_op(reg), Some(typ)),
ast::CallOperand::Imm(x) => self.operand(desc.new_op(ast::Operand::Imm(x)), typ),
@@ -825,7 +913,7 @@ impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, ExpandedArgParams>
&mut self,
desc: ArgumentDescriptor<(spirv::Word, u8)>,
(scalar_type, vec_len): (ast::MovVectorType, u8),
- ) -> spirv::Word {
+ ) -> Result<spirv::Word, TranslateError> {
let new_id = self.id_def.new_id(Some((
StateSpace::Reg,
ast::Type::Vector(scalar_type.into(), vec_len),
@@ -836,15 +924,15 @@ impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, ExpandedArgParams>
src_composite: desc.op.0,
src_index: desc.op.1 as u32,
}));
- new_id
+ Ok(new_id)
}
fn mov_operand(
&mut self,
- desc: ArgumentDescriptor<ast::Operand<spirv::Word>>,
+ desc: ArgumentDescriptor<ast::MovOperand<spirv::Word>>,
typ: ast::Type,
- ) -> spirv::Word {
- self.operand(desc, typ)
+ ) -> Result<spirv::Word, TranslateError> {
+ todo!()
}
}
@@ -862,9 +950,10 @@ impl<'a, 'b> ArgumentMapVisitor<NormalizedArgParams, ExpandedArgParams>
- generic/global st: for instruction `st [x], y`, x must be of type
b64/u64/s64, which is bitcast to a pointer
*/
+/*
fn insert_implicit_conversions(
func: Vec<ExpandedStatement>,
- id_def: &mut NumericIdResolver,
+ id_def: &mut MutableNumericIdResolver,
) -> Vec<ExpandedStatement> {
let mut result = Vec::with_capacity(func.len());
for s in func.into_iter() {
@@ -936,7 +1025,7 @@ fn insert_implicit_conversions(
let mut did_vector_implicit = false;
let mut post_conv = None;
if inst_typ_is_bit {
- let src_type = id_def.get_type(arg.src).unwrap_or_else(|| todo!()).1;
+ let src_type = id_def.get_typed(arg.src)?;
if let ast::Type::Vector(_, _) = src_type {
arg.src = insert_conversion_src(
&mut result,
@@ -948,7 +1037,7 @@ fn insert_implicit_conversions(
);
did_vector_implicit = true;
}
- let dst_type = id_def.get_type(arg.dst).unwrap_or_else(|| todo!()).1;
+ let dst_type = id_def.get_typed(arg.dst)?;
if let ast::Type::Vector(_, _) = src_type {
post_conv = Some(get_conversion_dst(
id_def,
@@ -988,6 +1077,7 @@ fn insert_implicit_conversions(
}
result
}
+*/
fn get_function_type(
builder: &mut dr::Builder,
@@ -1600,12 +1690,11 @@ fn emit_implicit_conversion(
Ok(())
}
-// TODO: support scopes
fn normalize_identifiers<'a, 'b>(
id_defs: &mut FnStringIdResolver<'a, 'b>,
fn_defs: &GlobalFnDeclResolver<'a, 'b>,
func: Vec<ast::Statement<ast::ParsedArgParams<'a>>>,
-) -> Vec<NormalizedStatement> {
+) -> Result<Vec<NormalizedStatement>, TranslateError> {
for s in func.iter() {
match s {
ast::Statement::Label(id) => {
@@ -1616,9 +1705,9 @@ fn normalize_identifiers<'a, 'b>(
}
let mut result = Vec::new();
for s in func {
- expand_map_variables(id_defs, fn_defs, &mut result, s);
+ expand_map_variables(id_defs, fn_defs, &mut result, s)?;
}
- result
+ Ok(result)
}
fn expand_map_variables<'a, 'b>(
@@ -1626,19 +1715,20 @@ fn expand_map_variables<'a, 'b>(
fn_defs: &GlobalFnDeclResolver<'a, 'b>,
result: &mut Vec<NormalizedStatement>,
s: ast::Statement<ast::ParsedArgParams<'a>>,
-) {
+) -> Result<(), TranslateError> {
match s {
ast::Statement::Block(block) => {
id_defs.start_block();
for s in block {
- expand_map_variables(id_defs, fn_defs, result, s);
+ expand_map_variables(id_defs, fn_defs, result, s)?;
}
id_defs.end_block();
}
- ast::Statement::Label(name) => result.push(Statement::Label(id_defs.get_id(name))),
+ ast::Statement::Label(name) => result.push(Statement::Label(id_defs.get_id(name)?)),
ast::Statement::Instruction(p, i) => result.push(Statement::Instruction((
- p.map(|p| p.map_variable(&mut |id| id_defs.get_id(id))),
- i.map_variable(&mut |id| id_defs.get_id(id)),
+ p.map(|p| p.map_variable(&mut |id| id_defs.get_id(id)))
+ .transpose()?,
+ i.map_variable(&mut |id| id_defs.get_id(id))?,
))),
ast::Statement::Variable(var) => {
let ss = match var.var.v_type {
@@ -1666,16 +1756,16 @@ fn expand_map_variables<'a, 'b>(
}
}
}
- }
+ };
+ Ok(())
}
-#[derive(Ord, PartialOrd, Eq, PartialEq, Hash)]
+#[derive(Ord, PartialOrd, Eq, PartialEq, Hash, Copy, Clone)]
enum PtxSpecialRegister {
Tid,
Ntid,
Ctaid,
Nctaid,
- Gridid,
}
impl PtxSpecialRegister {
@@ -1685,10 +1775,27 @@ impl PtxSpecialRegister {
"%ntid" => Some(Self::Ntid),
"%ctaid" => Some(Self::Ctaid),
"%nctaid" => Some(Self::Nctaid),
- "%gridid" => Some(Self::Gridid),
_ => None,
}
}
+
+ fn get_type(self) -> ast::Type {
+ match self {
+ PtxSpecialRegister::Tid => ast::Type::Vector(ast::ScalarType::U32, 4),
+ PtxSpecialRegister::Ntid => ast::Type::Vector(ast::ScalarType::U32, 4),
+ PtxSpecialRegister::Ctaid => ast::Type::Vector(ast::ScalarType::U32, 4),
+ PtxSpecialRegister::Nctaid => ast::Type::Vector(ast::ScalarType::U32, 4),
+ }
+ }
+
+ fn get_builtin(self) -> spirv::BuiltIn {
+ match self {
+ PtxSpecialRegister::Tid => spirv::BuiltIn::GlobalInvocationId,
+ PtxSpecialRegister::Ntid => spirv::BuiltIn::GlobalSize,
+ PtxSpecialRegister::Ctaid => spirv::BuiltIn::WorkgroupId,
+ PtxSpecialRegister::Nctaid => spirv::BuiltIn::NumWorkgroups,
+ }
+ }
}
struct GlobalStringIdResolver<'input> {
@@ -1725,8 +1832,11 @@ impl<'a> GlobalStringIdResolver<'a> {
}
}
- fn get_id(&self, id: &str) -> spirv::Word {
- self.variables[id]
+ fn get_id(&self, id: &str) -> Result<spirv::Word, TranslateError> {
+ self.variables
+ .get(id)
+ .copied()
+ .ok_or(TranslateError::UnknownSymbol)
}
fn current_id(&self) -> spirv::Word {
@@ -1741,7 +1851,7 @@ impl<'a> GlobalStringIdResolver<'a> {
GlobalFnDeclResolver<'a, 'b>,
ast::MethodDecl<'a, ExpandedArgParams>,
) {
- // In case a function decl was inserted eearlier we want to use its id
+ // In case a function decl was inserted earlier we want to use its id
let name_id = self.get_or_add_def(header.name());
let mut fn_resolver = FnStringIdResolver {
current_id: &mut self.current_id,
@@ -1784,12 +1894,15 @@ pub struct GlobalFnDeclResolver<'input, 'a> {
}
impl<'input, 'a> GlobalFnDeclResolver<'input, 'a> {
- fn get_fn_decl(&self, id: spirv::Word) -> &FnDecl {
- &self.fns[&id]
+ fn get_fn_decl(&self, id: spirv::Word) -> Result<&FnDecl, TranslateError> {
+ self.fns.get(&id).ok_or(TranslateError::UnknownSymbol)
}
- fn get_fn_decl_str(&self, id: &str) -> &'a FnDecl {
- &self.fns[&self.variables[id]]
+ fn get_fn_decl_str(&self, id: &str) -> Result<&'a FnDecl, TranslateError> {
+ match self.variables.get(id).map(|var_id| self.fns.get(var_id)) {
+ Some(Some(fn_d)) => Ok(fn_d),
+ _ => Err(TranslateError::UnknownSymbol),
+ }
}
}
@@ -1798,7 +1911,7 @@ struct FnStringIdResolver<'input, 'b> {
global_variables: &'b HashMap<Cow<'input, str>, spirv::Word>,
special_registers: &'b mut HashMap<PtxSpecialRegister, spirv::Word>,
variables: Vec<HashMap<Cow<'input, str>, spirv::Word>>,
- type_check: HashMap<u32, (StateSpace, ast::Type)>,
+ type_check: HashMap<u32, Option<(StateSpace, ast::Type)>>,
}
impl<'a, 'b> FnStringIdResolver<'a, 'b> {
@@ -1806,6 +1919,11 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> {
NumericIdResolver {
current_id: self.current_id,
type_check: self.type_check,
+ special_registers: self
+ .special_registers
+ .iter()
+ .map(|(reg, id)| (*id, *reg))
+ .collect(),
}
}
@@ -1817,24 +1935,25 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> {
self.variables.pop();
}
- fn get_id(&mut self, id: &str) -> spirv::Word {
+ fn get_id(&mut self, id: &str) -> Result<spirv::Word, TranslateError> {
for scope in self.variables.iter().rev() {
match scope.get(id) {
- Some(id) => return *id,
+ Some(id) => return Ok(*id),
None => continue,
}
}
match self.global_variables.get(id) {
- Some(id) => *id,
+ Some(id) => Ok(*id),
None => {
- let sreg = PtxSpecialRegister::try_parse(id).unwrap_or_else(|| todo!());
+ let sreg =
+ PtxSpecialRegister::try_parse(id).ok_or(TranslateError::UnknownSymbol)?;
match self.special_registers.entry(sreg) {
- hash_map::Entry::Occupied(e) => *e.get(),
+ hash_map::Entry::Occupied(e) => Ok(*e.get()),
hash_map::Entry::Vacant(e) => {
let numeric_id = *self.current_id;
*self.current_id += 1;
e.insert(numeric_id);
- numeric_id
+ Ok(numeric_id)
}
}
}
@@ -1847,9 +1966,7 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> {
.last_mut()
.unwrap()
.insert(Cow::Borrowed(id), numeric_id);
- if let Some(typ) = typ {
- self.type_check.insert(numeric_id, typ);
- }
+ self.type_check.insert(numeric_id, typ);
*self.current_id += 1;
numeric_id
}
@@ -1868,7 +1985,7 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> {
.last_mut()
.unwrap()
.insert(Cow::Owned(format!("{}{}", base_id, i)), numeric_id + i);
- self.type_check.insert(numeric_id + i, (ss, typ));
+ self.type_check.insert(numeric_id + i, Some((ss, typ)));
}
*self.current_id += count;
(0..count).into_iter().map(move |i| i + numeric_id)
@@ -1877,24 +1994,48 @@ impl<'a, 'b> FnStringIdResolver<'a, 'b> {
struct NumericIdResolver<'b> {
current_id: &'b mut spirv::Word,
- type_check: HashMap<u32, (StateSpace, ast::Type)>,
+ type_check: HashMap<u32, Option<(StateSpace, ast::Type)>>,
+ special_registers: HashMap<spirv::Word, PtxSpecialRegister>,
}
impl<'b> NumericIdResolver<'b> {
- fn get_type(&self, id: spirv::Word) -> Option<(StateSpace, ast::Type)> {
- self.type_check.get(&id).map(|x| *x)
+ fn finish(self) -> MutableNumericIdResolver<'b> {
+ MutableNumericIdResolver { base: self }
+ }
+
+ fn get_typed(&self, id: spirv::Word) -> Result<(StateSpace, ast::Type), TranslateError> {
+ match self.type_check.get(&id) {
+ Some(Some(x)) => Ok(*x),
+ Some(None) => Err(TranslateError::UntypedSymbol),
+ None => match self.special_registers.get(&id) {
+ Some(x) => Ok((StateSpace::Reg, x.get_type())),
+ None => Err(TranslateError::UntypedSymbol),
+ },
+ }
}
fn new_id(&mut self, typ: Option<(StateSpace, ast::Type)>) -> spirv::Word {
let new_id = *self.current_id;
- if let Some(typ) = typ {
- self.type_check.insert(new_id, typ);
- }
+ self.type_check.insert(new_id, typ);
*self.current_id += 1;
new_id
}
}
+struct MutableNumericIdResolver<'b> {
+ base: NumericIdResolver<'b>,
+}
+
+impl<'b> MutableNumericIdResolver<'b> {
+ fn get_typed(&self, id: spirv::Word) -> Result<ast::Type, TranslateError> {
+ self.base.get_typed(id).map(|(_, t)| t)
+ }
+
+ fn new_id(&mut self, typ: ast::Type) -> spirv::Word {
+ self.base.new_id(Some((StateSpace::Reg, typ)))
+ }
+}
+
enum Statement<I, P: ast::ArgParams> {
Label(u32),
Variable(ast::Variable<ast::VariableType, P>),
@@ -1921,11 +2062,11 @@ impl<From: ArgParamsEx<ID = spirv::Word>> ResolvedCall<From> {
fn map<To: ArgParamsEx<ID = spirv::Word>, V: ArgumentMapVisitor<From, To>>(
self,
visitor: &mut V,
- ) -> ResolvedCall<To> {
+ ) -> Result<ResolvedCall<To>, TranslateError> {
let ret_params = self
.ret_params
.into_iter()
- .map(|(id, typ)| {
+ .map::<Result<_, TranslateError>, _>(|(id, typ)| {
let new_id = visitor.variable(
ArgumentDescriptor {
op: id,
@@ -1933,10 +2074,10 @@ impl<From: ArgParamsEx<ID = spirv::Word>> ResolvedCall<From> {
sema: ArgumentSemantics::Default,
},
Some(typ.into()),
- );
- (new_id, typ)
+ )?;
+ Ok((new_id, typ))
})
- .collect();
+ .collect::<Result<Vec<_>, _>>()?;
let func = visitor.variable(
ArgumentDescriptor {
op: self.func,
@@ -1944,11 +2085,11 @@ impl<From: ArgParamsEx<ID = spirv::Word>> ResolvedCall<From> {
sema: ArgumentSemantics::Default,
},
None,
- );
+ )?;
let param_list = self
.param_list
.into_iter()
- .map(|(id, typ)| {
+ .map::<Result<_, TranslateError>, _>(|(id, typ)| {
let new_id = visitor.src_call_operand(
ArgumentDescriptor {
op: id,
@@ -1956,48 +2097,60 @@ impl<From: ArgParamsEx<ID = spirv::Word>> ResolvedCall<From> {
sema: ArgumentSemantics::Default,
},
typ.into(),
- );
- (new_id, typ)
+ )?;
+ Ok((new_id, typ))
})
- .collect();
- ResolvedCall {
+ .collect::<Result<Vec<_>, _>>()?;
+ Ok(ResolvedCall {
uniform: self.uniform,
ret_params,
func,
param_list,
- }
+ })
}
}
impl VisitVariable for ResolvedCall<NormalizedArgParams> {
fn visit_variable<
'a,
- F: FnMut(ArgumentDescriptor<spirv::Word>, Option<ast::Type>) -> spirv::Word,
+ F: FnMut(
+ ArgumentDescriptor<spirv::Word>,
+ Option<ast::Type>,
+ ) -> Result<spirv::Word, TranslateError>,
>(
self,
f: &mut F,
- ) -> UnadornedStatement {
- Statement::Call(self.map(f))
+ ) -> Result<UnadornedStatement, TranslateError> {
+ Ok(Statement::Call(self.map(f)?))
}
}
impl VisitVariableExpanded for ResolvedCall<ExpandedArgParams> {
fn visit_variable_extended<
- F: FnMut(ArgumentDescriptor<spirv::Word>, Option<ast::Type>) -> spirv::Word,
+ F: FnMut(
+ ArgumentDescriptor<spirv::Word>,
+ Option<ast::Type>,
+ ) -> Result<spirv::Word, TranslateError>,
>(
self,
f: &mut F,
- ) -> ExpandedStatement {
- Statement::Call(self.map(f))
+ ) -> Result<ExpandedStatement, TranslateError> {
+ Ok(Statement::Call(self.map(f)?))
}
}
pub trait ArgParamsEx: ast::ArgParams {
- fn get_fn_decl<'x, 'b>(id: &Self::ID, decl: &'b GlobalFnDeclResolver<'x, 'b>) -> &'b FnDecl;
+ fn get_fn_decl<'x, 'b>(
+ id: &Self::ID,
+ decl: &'b GlobalFnDeclResolver<'x, 'b>,
+ ) -> Result<&'b FnDecl, TranslateError>;
}
impl<'input> ArgParamsEx for ast::ParsedArgParams<'input> {
- fn get_fn_decl<'x, 'b>(id: &Self::ID, decl: &'b GlobalFnDeclResolver<'x, 'b>) -> &'b FnDecl {
+ fn get_fn_decl<'x, 'b>(
+ id: &Self::ID,
+ decl: &'b GlobalFnDeclResolver<'x, 'b>,
+ ) -> Result<&'b FnDecl, TranslateError> {
decl.get_fn_decl_str(id)
}
}
@@ -2015,23 +2168,16 @@ type UnadornedStatement = Statement<ast::Instruction<NormalizedArgParams>, Norma
impl ast::ArgParams for NormalizedArgParams {
type ID = spirv::Word;
type Operand = ast::Operand<spirv::Word>;
- type MemoryOperand = ast::Operand<spirv::Word>;
- type CallOperand = ast::CallOperand<spirv::Word>;
- type VecOperand = (spirv::Word, u8);
-}
-
-enum TypedArgParams {}
-impl ast::ArgParams for TypedArgParams {
- type ID = spirv::Word;
- type Operand = ast::Operand<spirv::Word>;
- type MemoryOperand = MemoryOperand;
+ type MovOperand = ast::MovOperand<spirv::Word>;
type CallOperand = ast::CallOperand<spirv::Word>;
type VecOperand = (spirv::Word, u8);
}
-type TypedStatement = Statement<ast::Instruction<TypedArgParams>, TypedArgParams>;
impl ArgParamsEx for NormalizedArgParams {
- fn get_fn_decl<'a, 'b>(id: &Self::ID, decl: &'b GlobalFnDeclResolver<'a, 'b>) -> &'b FnDecl {
+ fn get_fn_decl<'a, 'b>(
+ id: &Self::ID,
+ decl: &'b GlobalFnDeclResolver<'a, 'b>,
+ ) -> Result<&'b FnDecl, TranslateError> {
decl.get_fn_decl(*id)
}
}
@@ -2039,7 +2185,6 @@ impl ArgParamsEx for NormalizedArgParams {
#[derive(Copy, Clone)]
pub enum StateSpace {
Reg,
- Sreg,
Const,
Global,
Local,
@@ -2048,15 +2193,6 @@ pub enum StateSpace {
ParamReg,
}
-#[derive(Copy, Clone)]
-pub enum MemoryOperand {
- Reg(spirv::Word),
- Address(spirv::Word),
- RegOffset(spirv::Word, i32),
- AddressOffset(spirv::Word, i32),
- Imm(u32),
-}
-
enum ExpandedArgParams {}
type ExpandedStatement = Statement<ast::Instruction<ExpandedArgParams>, ExpandedArgParams>;
type ExpandedFunction<'a> = ast::Function<'a, ExpandedArgParams, ExpandedStatement>;
@@ -2064,54 +2200,76 @@ type ExpandedFunction<'a> = ast::Function<'a, ExpandedArgParams, ExpandedStateme
impl ast::ArgParams for ExpandedArgParams {
type ID = spirv::Word;
type Operand = spirv::Word;
- type MemoryOperand = spirv::Word;
+ type MovOperand = spirv::Word;
type CallOperand = spirv::Word;
type VecOperand = spirv::Word;
}
impl ArgParamsEx for ExpandedArgParams {
- fn get_fn_decl<'a, 'b>(id: &Self::ID, decl: &'b GlobalFnDeclResolver<'a, 'b>) -> &'b FnDecl {
+ fn get_fn_decl<'a, 'b>(
+ id: &Self::ID,
+ decl: &'b GlobalFnDeclResolver<'a, 'b>,
+ ) -> Result<&'b FnDecl, TranslateError> {
decl.get_fn_decl(*id)
}
}
trait ArgumentMapVisitor<T: ArgParamsEx, U: ArgParamsEx> {
- fn variable(&mut self, desc: ArgumentDescriptor<T::ID>, typ: Option<ast::Type>) -> U::ID;
- fn operand(&mut self, desc: ArgumentDescriptor<T::Operand>, typ: ast::Type) -> U::Operand;
+ fn variable(
+ &mut self,
+ desc: ArgumentDescriptor<T::ID>,
+ typ: Option<ast::Type>,
+ ) -> Result<U::ID, TranslateError>;
+ fn operand(
+ &mut self,
+ desc: ArgumentDescriptor<T::Operand>,
+ typ: ast::Type,
+ ) -> Result<U::Operand, TranslateError>;
fn mov_operand(
&mut self,
- desc: ArgumentDescriptor<T::MemoryOperand>,
+ desc: ArgumentDescriptor<T::MovOperand>,
typ: ast::Type,
- ) -> U::MemoryOperand;
+ ) -> Result<U::MovOperand, TranslateError>;
fn src_call_operand(
&mut self,
desc: ArgumentDescriptor<T::CallOperand>,
typ: ast::Type,
- ) -> U::CallOperand;
+ ) -> Result<U::CallOperand, TranslateError>;
fn src_vec_operand(
&mut self,
desc: ArgumentDescriptor<T::VecOperand>,
typ: (ast::MovVectorType, u8),
- ) -> U::VecOperand;
+ ) -> Result<U::VecOperand, TranslateError>;
}
impl<T> ArgumentMapVisitor<ExpandedArgParams, ExpandedArgParams> for T
where
- T: FnMut(ArgumentDescriptor<spirv::Word>, Option<ast::Type>) -> spirv::Word,
+ T: FnMut(
+ ArgumentDescriptor<spirv::Word>,
+ Option<ast::Type>,
+ ) -> Result<spirv::Word, TranslateError>,
{
fn variable(
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
t: Option<ast::Type>,
- ) -> spirv::Word {
+ ) -> Result<spirv::Word, TranslateError> {
self(desc, t)
}
- fn operand(&mut self, desc: ArgumentDescriptor<spirv::Word>, t: ast::Type) -> spirv::Word {
+ fn operand(
+ &mut self,
+ desc: ArgumentDescriptor<spirv::Word>,
+ t: ast::Type,
+ ) -> Result<spirv::Word, TranslateError> {
self(desc, Some(t))
}
- fn mov_operand(&mut self, desc: ArgumentDescriptor<spirv::Word>, t: ast::Type) -> spirv::Word {
+ fn mov_operand(
+ &mut self,
+ desc: ArgumentDescriptor<spirv::Word>,
+ t: ast::Type,
+ ) -> Result<spirv::Word, TranslateError> {
self(desc, Some(t))
}
@@ -2119,7 +2277,7 @@ where
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
t: ast::Type,
- ) -> spirv::Word {
+ ) -> Result<spirv::Word, TranslateError> {
self(desc, Some(t))
}
@@ -2127,7 +2285,7 @@ where
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
(scalar_type, vec_len): (ast::MovVectorType, u8),
- ) -> spirv::Word {
+ ) -> Result<spirv::Word, TranslateError> {
self(
desc.new_op(desc.op),
Some(ast::Type::Vector(scalar_type.into(), vec_len)),
@@ -2137,9 +2295,13 @@ where
impl<'a, T> ArgumentMapVisitor<ast::ParsedArgParams<'a>, NormalizedArgParams> for T
where
- T: FnMut(&str) -> spirv::Word,
+ T: FnMut(&str) -> Result<spirv::Word, TranslateError>,
{
- fn variable(&mut self, desc: ArgumentDescriptor<&str>, _: Option<ast::Type>) -> spirv::Word {
+ fn variable(
+ &mut self,
+ desc: ArgumentDescriptor<&str>,
+ _: Option<ast::Type>,
+ ) -> Result<spirv::Word, TranslateError> {
self(desc.op)
}
@@ -2147,11 +2309,11 @@ where
&mut self,
desc: ArgumentDescriptor<ast::Operand<&str>>,
_: ast::Type,
- ) -> ast::Operand<spirv::Word> {
+ ) -> Result<ast::Operand<spirv::Word>, TranslateError> {
match desc.op {
- ast::Operand::Reg(id) => ast::Operand::Reg(self(id)),
- ast::Operand::Imm(imm) => ast::Operand::Imm(imm),
- ast::Operand::RegOffset(id, imm) => ast::Operand::RegOffset(self(id), imm),
+ ast::Operand::Reg(id) => Ok(ast::Operand::Reg(self(id)?)),
+ ast::Operand::Imm(imm) => Ok(ast::Operand::Imm(imm)),
+ ast::Operand::RegOffset(id, imm) => Ok(ast::Operand::RegOffset(self(id)?, imm)),
}
}
@@ -2159,10 +2321,10 @@ where
&mut self,
desc: ArgumentDescriptor<ast::CallOperand<&str>>,
_: ast::Type,
- ) -> ast::CallOperand<spirv::Word> {
+ ) -> Result<ast::CallOperand<spirv::Word>, TranslateError> {
match desc.op {
- ast::CallOperand::Reg(id) => ast::CallOperand::Reg(self(id)),
- ast::CallOperand::Imm(imm) => ast::CallOperand::Imm(imm),
+ ast::CallOperand::Reg(id) => Ok(ast::CallOperand::Reg(self(id)?)),
+ ast::CallOperand::Imm(imm) => Ok(ast::CallOperand::Imm(imm)),
}
}
@@ -2170,16 +2332,16 @@ where
&mut self,
desc: ArgumentDescriptor<(&str, u8)>,
_: (ast::MovVectorType, u8),
- ) -> (spirv::Word, u8) {
- (self(desc.op.0), desc.op.1)
+ ) -> Result<(spirv::Word, u8), TranslateError> {
+ Ok((self(desc.op.0)?, desc.op.1))
}
fn mov_operand(
&mut self,
- desc: ArgumentDescriptor<ast::Operand<&str>>,
+ desc: ArgumentDescriptor<ast::MovOperand<&str>>,
typ: ast::Type,
- ) -> ast::Operand<spirv::Word> {
- self.operand(desc, typ)
+ ) -> Result<ast::MovOperand<spirv::Word>, TranslateError> {
+ todo!()
}
}
@@ -2210,41 +2372,41 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
- ) -> ast::Instruction<U> {
- match self {
+ ) -> Result<ast::Instruction<U>, TranslateError> {
+ Ok(match self {
ast::Instruction::MovVector(t, a) => {
- ast::Instruction::MovVector(t, a.map(visitor, (t.typ, t.length)))
+ ast::Instruction::MovVector(t, a.map(visitor, (t.typ, t.length))?)
}
ast::Instruction::Abs(d, arg) => {
- ast::Instruction::Abs(d, arg.map(visitor, ast::Type::Scalar(d.typ)))
+ ast::Instruction::Abs(d, arg.map(visitor, ast::Type::Scalar(d.typ))?)
}
// Call instruction is converted to a call statement early on
ast::Instruction::Call(_) => unreachable!(),
ast::Instruction::Ld(d, a) => {
let inst_type = d.typ;
let is_param = d.state_space == ast::LdStateSpace::Param;
- ast::Instruction::Ld(d, a.map_ld(visitor, inst_type, is_param))
+ ast::Instruction::Ld(d, a.map_ld(visitor, inst_type, is_param)?)
}
ast::Instruction::Mov(mov_type, a) => {
- ast::Instruction::Mov(mov_type, a.map(visitor, mov_type.into()))
+ ast::Instruction::Mov(mov_type, a.map(visitor, mov_type.into())?)
}
ast::Instruction::Mul(d, a) => {
let inst_type = d.get_type();
- ast::Instruction::Mul(d, a.map_non_shift(visitor, inst_type))
+ ast::Instruction::Mul(d, a.map_non_shift(visitor, inst_type)?)
}
ast::Instruction::Add(d, a) => {
let inst_type = d.get_type();
- ast::Instruction::Add(d, a.map_non_shift(visitor, inst_type))
+ ast::Instruction::Add(d, a.map_non_shift(visitor, inst_type)?)
}
ast::Instruction::Setp(d, a) => {
let inst_type = d.typ;
- ast::Instruction::Setp(d, a.map(visitor, ast::Type::Scalar(inst_type)))
+ ast::Instruction::Setp(d, a.map(visitor, ast::Type::Scalar(inst_type))?)
}
ast::Instruction::SetpBool(d, a) => {
let inst_type = d.typ;
- ast::Instruction::SetpBool(d, a.map(visitor, ast::Type::Scalar(inst_type)))
+ ast::Instruction::SetpBool(d, a.map(visitor, ast::Type::Scalar(inst_type))?)
}
- ast::Instruction::Not(t, a) => ast::Instruction::Not(t, a.map(visitor, t.to_type())),
+ ast::Instruction::Not(t, a) => ast::Instruction::Not(t, a.map(visitor, t.to_type())?),
ast::Instruction::Cvt(d, a) => {
let (dst_t, src_t) = match &d {
ast::CvtDetails::FloatFromFloat(desc) => (
@@ -2264,47 +2426,53 @@ impl<T: ArgParamsEx> ast::Instruction<T> {
ast::Type::Scalar(desc.src.into()),
),
};
- ast::Instruction::Cvt(d, a.map_cvt(visitor, dst_t, src_t))
+ ast::Instruction::Cvt(d, a.map_cvt(visitor, dst_t, src_t)?)
}
ast::Instruction::Shl(t, a) => {
- ast::Instruction::Shl(t, a.map_shift(visitor, t.to_type()))
+ ast::Instruction::Shl(t, a.map_shift(visitor, t.to_type())?)
}
ast::Instruction::St(d, a) => {
let inst_type = d.typ;
let is_param = d.state_space == ast::StStateSpace::Param;
- ast::Instruction::St(d, a.map(visitor, inst_type, is_param))
+ ast::Instruction::St(d, a.map(visitor, inst_type, is_param)?)
}
- ast::Instruction::Bra(d, a) => ast::Instruction::Bra(d, a.map(visitor, None)),
+ ast::Instruction::Bra(d, a) => ast::Instruction::Bra(d, a.map(visitor, None)?),
ast::Instruction::Ret(d) => ast::Instruction::Ret(d),
ast::Instruction::Cvta(d, a) => {
let inst_type = ast::Type::Scalar(ast::ScalarType::B64);
- ast::Instruction::Cvta(d, a.map(visitor, inst_type))
+ ast::Instruction::Cvta(d, a.map(visitor, inst_type)?)
}
- }
+ })
}
}
impl VisitVariable for ast::Instruction<NormalizedArgParams> {
fn visit_variable<
'a,
- F: FnMut(ArgumentDescriptor<spirv::Word>, Option<ast::Type>) -> spirv::Word,
+ F: FnMut(
+ ArgumentDescriptor<spirv_headers::Word>,
+ Option<ast::Type>,
+ ) -> Result<spirv_headers::Word, TranslateError>,
>(
self,
f: &mut F,
- ) -> UnadornedStatement {
- Statement::Instruction(self.map(f))
+ ) -> Result<UnadornedStatement, TranslateError> {
+ Ok(Statement::Instruction(self.map(f)?))
}
}
impl<T> ArgumentMapVisitor<NormalizedArgParams, NormalizedArgParams> for T
where
- T: FnMut(ArgumentDescriptor<spirv::Word>, Option<ast::Type>) -> spirv::Word,
+ T: FnMut(
+ ArgumentDescriptor<spirv::Word>,
+ Option<ast::Type>,
+ ) -> Result<spirv::Word, TranslateError>,
{
fn variable(
&mut self,
desc: ArgumentDescriptor<spirv::Word>,
t: Option<ast::Type>,
- ) -> spirv::Word {
+ ) -> Result<spirv::Word, TranslateError> {
self(desc, t)
}
@@ -2312,13 +2480,14 @@ where
&mut self,
desc: ArgumentDescriptor<ast::Operand<spirv::Word>>,
t: ast::Type,
- ) -> ast::Operand<spirv::Word> {
+ ) -> Result<ast::Operand<spirv::Word>, TranslateError> {
match desc.op {
- ast::Operand::Reg(id) => ast::Operand::Reg(self(desc.new_op(id), Some(t))),
- ast::Operand::Imm(imm) => ast::Operand::Imm(imm),
- ast::Operand::RegOffset(id, imm) => {
- ast::Operand::RegOffset(self(desc.new_op(id), Some(t)), imm)
- }
+ ast::Operand::Reg(id) => Ok(ast::Operand::Reg(self(desc.new_op(id), Some(t))?)),
+ ast::Operand::Imm(imm) => Ok(ast::Operand::Imm(imm)),
+ ast::Operand::RegOffset(id, imm) => Ok(ast::Operand::RegOffset(
+ self(desc.new_op(id), Some(t))?,
+ imm,
+ )),
}
}
@@ -2326,10 +2495,10 @@ where
&mut self,
desc: ArgumentDescriptor<ast::CallOperand<spirv::Word>>,
t: ast::Type,
- ) -> ast::CallOperand<spirv::Word> {
+ ) -> Result<ast::CallOperand<spirv::Word>, TranslateError> {
match desc.op {
- ast::CallOperand::Reg(id) => ast::CallOperand::Reg(self(desc.new_op(id), Some(t))),
- ast::CallOperand::Imm(imm) => ast::CallOperand::Imm(imm),
+ ast::CallOperand::Reg(id) => Ok(ast::CallOperand::Reg(self(desc.new_op(id), Some(t))?)),
+ ast::CallOperand::Imm(imm) => Ok(ast::CallOperand::Imm(imm)),
}
}
@@ -2337,24 +2506,22 @@ where
&mut self,
desc: ArgumentDescriptor<(spirv::Word, u8)>,
(scalar_type, vector_len): (ast::MovVectorType, u8),
- ) -> (spirv::Word, u8) {
- (
+ ) -> Result<(spirv::Word, u8), TranslateError> {
+ Ok((
self(
desc.new_op(desc.op.0),
Some(ast::Type::Vector(scalar_type.into(), vector_len)),
- ),
+ )?,
desc.op.1,
- )
+ ))
}
fn mov_operand(
&mut self,
- desc: ArgumentDescriptor<ast::Operand<spirv::Word>>,
+ desc: ArgumentDescriptor<ast::MovOperand<spirv::Word>>,
typ: ast::Type,
- ) -> ast::Operand<spirv::Word> {
- <Self as ArgumentMapVisitor<NormalizedArgParams, NormalizedArgParams>>::operand(
- self, desc, typ,
- )
+ ) -> Result<ast::MovOperand<spirv::Word>, TranslateError> {
+ todo!()
}
}
@@ -2439,12 +2606,15 @@ impl ast::Instruction<ExpandedArgParams> {
impl VisitVariableExpanded for ast::Instruction<ExpandedArgParams> {
fn visit_variable_extended<
- F: FnMut(ArgumentDescriptor<spirv::Word>, Option<ast::Type>) -> spirv::Word,
+ F: FnMut(
+ ArgumentDescriptor<spirv_headers::Word>,
+ Option<ast::Type>,
+ ) -> Result<spirv_headers::Word, TranslateError>,
>(
self,
f: &mut F,
- ) -> ExpandedStatement {
- Statement::Instruction(self.map(f))
+ ) -> Result<ExpandedStatement, TranslateError> {
+ Ok(Statement::Instruction(self.map(f)?))
}
}
@@ -2488,32 +2658,40 @@ enum ConversionKind {
}
impl<T> ast::PredAt<T> {
- fn map_variable<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::PredAt<U> {
- ast::PredAt {
+ fn map_variable<U, F: FnMut(T) -> Result<U, TranslateError>>(
+ self,
+ f: &mut F,
+ ) -> Result<ast::PredAt<U>, TranslateError> {
+ let new_label = f(self.label)?;
+ Ok(ast::PredAt {
not: self.not,
- label: f(self.label),
- }
+ label: new_label,
+ })
}
}
impl<'a> ast::Instruction<ast::ParsedArgParams<'a>> {
- fn map_variable<F: FnMut(&str) -> spirv::Word>(
+ fn map_variable<F: FnMut(&str) -> Result<spirv::Word, TranslateError>>(
self,
f: &mut F,
- ) -> ast::Instruction<NormalizedArgParams> {
+ ) -> Result<ast::Instruction<NormalizedArgParams>, TranslateError> {
match self {
ast::Instruction::Call(call) => {
let call_inst = ast::CallInst {
uniform: call.uniform,
- ret_params: call.ret_params.into_iter().map(|p| f(p)).collect(),
- func: f(call.func),
+ ret_params: call
+ .ret_params
+ .into_iter()
+ .map(|p| f(p))
+ .collect::<Result<_, _>>()?,
+ func: f(call.func)?,
param_list: call
.param_list
.into_iter()
.map(|p| p.map_variable(f))
- .collect(),
+ .collect::<Result<_, _>>()?,
};
- ast::Instruction::Call(call_inst)
+ Ok(ast::Instruction::Call(call_inst))
}
i => i.map(f),
}
@@ -2525,17 +2703,16 @@ impl<T: ArgParamsEx> ast::Arg1<T> {
self,
visitor: &mut V,
t: Option<ast::Type>,
- ) -> ast::Arg1<U> {
- ast::Arg1 {
- src: visitor.variable(
- ArgumentDescriptor {
- op: self.src,
- is_dst: false,
- sema: ArgumentSemantics::Default,
- },
- t,
- ),
- }
+ ) -> Result<ast::Arg1<U>, TranslateError> {
+ let new_src = visitor.variable(
+ ArgumentDescriptor {
+ op: self.src,
+ is_dst: false,
+ sema: ArgumentSemantics::Default,
+ },
+ t,
+ )?;
+ Ok(ast::Arg1 { src: new_src })
}
}
@@ -2544,25 +2721,27 @@ impl<T: ArgParamsEx> ast::Arg2<T> {
self,
visitor: &mut V,
t: ast::Type,
- ) -> ast::Arg2<U> {
- ast::Arg2 {
- dst: visitor.variable(
- ArgumentDescriptor {
- op: self.dst,
- is_dst: true,
- sema: ArgumentSemantics::Default,
- },
- Some(t),
- ),
- src: visitor.operand(
- ArgumentDescriptor {
- op: self.src,
- is_dst: false,
- sema: ArgumentSemantics::Default,
- },
- t,
- ),
- }
+ ) -> Result<ast::Arg2<U>, TranslateError> {
+ let new_dst = visitor.variable(
+ ArgumentDescriptor {
+ op: self.dst,
+ is_dst: true,
+ sema: ArgumentSemantics::Default,
+ },
+ Some(t),
+ )?;
+ let new_src = visitor.operand(
+ ArgumentDescriptor {
+ op: self.src,
+ is_dst: false,
+ sema: ArgumentSemantics::Default,
+ },
+ t,
+ )?;
+ Ok(ast::Arg2 {
+ dst: new_dst,
+ src: new_src,
+ })
}
fn map_ld<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
@@ -2570,29 +2749,28 @@ impl<T: ArgParamsEx> ast::Arg2<T> {
visitor: &mut V,
t: ast::Type,
is_param: bool,
- ) -> ast::Arg2<U> {
- ast::Arg2 {
- dst: visitor.variable(
- ArgumentDescriptor {
- op: self.dst,
- is_dst: true,
- sema: ArgumentSemantics::Default,
- },
- Some(t),
- ),
- src: visitor.operand(
- ArgumentDescriptor {
- op: self.src,
- is_dst: false,
- sema: if is_param {
- ArgumentSemantics::ParamPtr
- } else {
- ArgumentSemantics::Ptr
- },
+ ) -> Result<ast::Arg2<U>, TranslateError> {
+ let dst = visitor.variable(
+ ArgumentDescriptor {
+ op: self.dst,
+ is_dst: true,
+ sema: ArgumentSemantics::Default,
+ },
+ Some(t),
+ )?;
+ let src = visitor.operand(
+ ArgumentDescriptor {
+ op: self.src,
+ is_dst: false,
+ sema: if is_param {
+ ArgumentSemantics::ParamPtr
+ } else {
+ ArgumentSemantics::Ptr
},
- t,
- ),
- }
+ },
+ t,
+ )?;
+ Ok(ast::Arg2 { dst, src })
}
fn map_cvt<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
@@ -2600,25 +2778,50 @@ impl<T: ArgParamsEx> ast::Arg2<T> {
visitor: &mut V,
dst_t: ast::Type,
src_t: ast::Type,
- ) -> ast::Arg2<U> {
- ast::Arg2 {
- dst: visitor.variable(
- ArgumentDescriptor {
- op: self.dst,
- is_dst: true,
- sema: ArgumentSemantics::Default,
- },
- Some(dst_t),
- ),
- src: visitor.operand(
- ArgumentDescriptor {
- op: self.src,
- is_dst: false,
- sema: ArgumentSemantics::Default,
- },
- src_t,
- ),
- }
+ ) -> Result<ast::Arg2<U>, TranslateError> {
+ let dst = visitor.variable(
+ ArgumentDescriptor {
+ op: self.dst,
+ is_dst: true,
+ sema: ArgumentSemantics::Default,
+ },
+ Some(dst_t),
+ )?;
+ let src = visitor.operand(
+ ArgumentDescriptor {
+ op: self.src,
+ is_dst: false,
+ sema: ArgumentSemantics::Default,
+ },
+ src_t,
+ )?;
+ Ok(ast::Arg2 { dst, src })
+ }
+}
+
+impl<T: ArgParamsEx> ast::Arg2Mov<T> {
+ fn map<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
+ self,
+ visitor: &mut V,
+ t: ast::Type,
+ ) -> Result<ast::Arg2Mov<U>, TranslateError> {
+ let dst = visitor.variable(
+ ArgumentDescriptor {
+ op: self.dst,
+ is_dst: true,
+ sema: ArgumentSemantics::Default,
+ },
+ Some(t),
+ )?;
+ let src = visitor.mov_operand(
+ ArgumentDescriptor {
+ op: self.src,
+ is_dst: false,
+ sema: ArgumentSemantics::Default,
+ },
+ t,
+ )?;
+ Ok(ast::Arg2Mov { dst, src })
}
}
@@ -2628,29 +2831,28 @@ impl<T: ArgParamsEx> ast::Arg2St<T> {
visitor: &mut V,
t: ast::Type,
is_param: bool,
- ) -> ast::Arg2St<U> {
- ast::Arg2St {
- src1: visitor.mov_operand(
- ArgumentDescriptor {
- op: self.src1,
- is_dst: is_param,
- sema: if is_param {
- ArgumentSemantics::ParamPtr
- } else {
- ArgumentSemantics::Ptr
- },
- },
- t,
- ),
- src2: visitor.operand(
- ArgumentDescriptor {
- op: self.src2,
- is_dst: false,
- sema: ArgumentSemantics::Default,
+ ) -> Result<ast::Arg2St<U>, TranslateError> {
+ let src1 = visitor.operand(
+ ArgumentDescriptor {
+ op: self.src1,
+ is_dst: is_param,
+ sema: if is_param {
+ ArgumentSemantics::ParamPtr
+ } else {
+ ArgumentSemantics::Ptr
},
- t,
- ),
- }
+ },
+ t,
+ )?;
+ let src2 = visitor.operand(
+ ArgumentDescriptor {
+ op: self.src2,
+ is_dst: false,
+ sema: ArgumentSemantics::Default,
+ },
+ t,
+ )?;
+ Ok(ast::Arg2St { src1, src2 })
}
}
@@ -2667,84 +2869,81 @@ impl<T: ArgParamsEx> ast::Arg2Vec<T> {
self,
visitor: &mut V,
(scalar_type, vec_len): (ast::MovVectorType, u8),
- ) -> ast::Arg2Vec<U> {
+ ) -> Result<ast::Arg2Vec<U>, TranslateError> {
match self {
- ast::Arg2Vec::Dst((dst, len), composite_src, scalar_src) => ast::Arg2Vec::Dst(
- (
- visitor.variable(
- ArgumentDescriptor {
- op: dst,
- is_dst: true,
- sema: ArgumentSemantics::Default,
- },
- Some(ast::Type::Scalar(scalar_type.into())),
- ),
- len,
- ),
- visitor.variable(
+ ast::Arg2Vec::Dst((dst, len), composite_src, scalar_src) => {
+ let dst = visitor.variable(
+ ArgumentDescriptor {
+ op: dst,
+ is_dst: true,
+ sema: ArgumentSemantics::Default,
+ },
+ Some(ast::Type::Scalar(scalar_type.into())),
+ )?;
+ let src1 = visitor.variable(
ArgumentDescriptor {
op: composite_src,
is_dst: false,
sema: ArgumentSemantics::Default,
},
Some(ast::Type::Scalar(scalar_type.into())),
- ),
- visitor.variable(
+ )?;
+ let src2 = visitor.variable(
ArgumentDescriptor {
op: scalar_src,
is_dst: false,
sema: ArgumentSemantics::Default,
},
Some(ast::Type::Scalar(scalar_type.into())),
- ),
- ),
- ast::Arg2Vec::Src(dst, src) => ast::Arg2Vec::Src(
- visitor.variable(
+ )?;
+ Ok(ast::Arg2Vec::Dst((dst, len), src1, src2))
+ }
+ ast::Arg2Vec::Src(dst, src) => {
+ let dst = visitor.variable(
ArgumentDescriptor {
op: dst,
is_dst: true,
sema: ArgumentSemantics::Default,
},
Some(ast::Type::Scalar(scalar_type.into())),
- ),
- visitor.src_vec_operand(
+ )?;
+ let src = visitor.src_vec_operand(
ArgumentDescriptor {
op: src,
is_dst: false,
sema: ArgumentSemantics::Default,
},
(scalar_type, vec_len),
- ),
- ),
- ast::Arg2Vec::Both((dst, len), composite_src, src) => ast::Arg2Vec::Both(
- (
- visitor.variable(
- ArgumentDescriptor {
- op: dst,
- is_dst: true,
- sema: ArgumentSemantics::Default,
- },
- Some(ast::Type::Scalar(scalar_type.into())),
- ),
- len,
- ),
- visitor.variable(
+ )?;
+ Ok(ast::Arg2Vec::Src(dst, src))
+ }
+ ast::Arg2Vec::Both((dst, len), composite_src, src) => {
+ let dst = visitor.variable(
+ ArgumentDescriptor {
+ op: dst,
+ is_dst: true,
+ sema: ArgumentSemantics::Default,
+ },
+ Some(ast::Type::Scalar(scalar_type.into())),
+ )?;
+ let composite_src = visitor.variable(
ArgumentDescriptor {
op: composite_src,
is_dst: false,
sema: ArgumentSemantics::Default,
},
Some(ast::Type::Scalar(scalar_type.into())),
- ),
- visitor.src_vec_operand(
+ )?;
+ let src = visitor.src_vec_operand(
ArgumentDescriptor {
op: src,
is_dst: false,
sema: ArgumentSemantics::Default,
},
(scalar_type, vec_len),
- ),
- ),
+ )?;
+ Ok(ast::Arg2Vec::Both((dst, len), composite_src, src))
+ }
}
}
}
@@ -2754,66 +2953,64 @@ impl<T: ArgParamsEx> ast::Arg3<T> {
self,
visitor: &mut V,
t: ast::Type,
- ) -> ast::Arg3<U> {
- ast::Arg3 {
- dst: visitor.variable(
- ArgumentDescriptor {
- op: self.dst,
- is_dst: true,
- sema: ArgumentSemantics::Default,
- },
- Some(t),
- ),
- src1: visitor.operand(
- ArgumentDescriptor {
- op: self.src1,
- is_dst: false,
- sema: ArgumentSemantics::Default,
- },
- t,
- ),
- src2: visitor.operand(
- ArgumentDescriptor {
- op: self.src2,
- is_dst: false,
- sema: ArgumentSemantics::Default,
- },
- t,
- ),
- }
+ ) -> Result<ast::Arg3<U>, TranslateError> {
+ let dst = visitor.variable(
+ ArgumentDescriptor {
+ op: self.dst,
+ is_dst: true,
+ sema: ArgumentSemantics::Default,
+ },
+ Some(t),
+ )?;
+ let src1 = visitor.operand(
+ ArgumentDescriptor {
+ op: self.src1,
+ is_dst: false,
+ sema: ArgumentSemantics::Default,
+ },
+ t,
+ )?;
+ let src2 = visitor.operand(
+ ArgumentDescriptor {
+ op: self.src2,
+ is_dst: false,
+ sema: ArgumentSemantics::Default,
+ },
+ t,
+ )?;
+ Ok(ast::Arg3 { dst, src1, src2 })
}
fn map_shift<U: ArgParamsEx, V: ArgumentMapVisitor<T, U>>(
self,
visitor: &mut V,
t: ast::Type,
- ) -> ast::Arg3<U> {
- ast::Arg3 {
- dst: visitor.variable(
- ArgumentDescriptor {
- op: self.dst,
- is_dst: true,
- sema: ArgumentSemantics::Default,
- },
- Some(t),
- ),
- src1: visitor.operand(
- ArgumentDescriptor {
- op: self.src1,
- is_dst: false,
- sema: ArgumentSemantics::Default,
- },
- t,
- ),
- src2: visitor.operand(
- ArgumentDescriptor {
- op: self.src2,
- is_dst: false,
- sema: ArgumentSemantics::Default,
- },
- ast::Type::Scalar(ast::ScalarType::U32),
- ),
- }
+ ) -> Result<ast::Arg3<U>, TranslateError> {
+ let dst = visitor.variable(
+ ArgumentDescriptor {
+ op: self.dst,
+ is_dst: true,
+ sema: ArgumentSemantics::Default,
+ },
+ Some(t),
+ )?;
+ let src1 = visitor.operand(
+ ArgumentDescriptor {
+ op: self.src1,
+ is_dst: false,
+ sema: ArgumentSemantics::Default,
+ },
+ t,
+ )?;
+ let src2 = visitor.operand(
+ ArgumentDescriptor {
+ op: self.src2,
+ is_dst: false,
+ sema: ArgumentSemantics::Default,
+ },
+ ast::Type::Scalar(ast::ScalarType::U32),
+ )?;
+ Ok(ast::Arg3 { dst, src1, src2 })
}
}
@@ -2822,17 +3019,18 @@ impl<T: ArgParamsEx> ast::Arg4<T> {
self,
visitor: &mut V,
t: ast::Type,
- ) -> ast::Arg4<U> {
- ast::Arg4 {
- dst1: visitor.variable(
- ArgumentDescriptor {
- op: self.dst1,
- is_dst: true,
- sema: ArgumentSemantics::Default,
- },
- Some(ast::Type::Scalar(ast::ScalarType::Pred)),
- ),
- dst2: self.dst2.map(|dst2| {
+ ) -> Result<ast::Arg4<U>, TranslateError> {
+ let dst1 = visitor.variable(
+ ArgumentDescriptor {
+ op: self.dst1,
+ is_dst: true,
+ sema: ArgumentSemantics::Default,
+ },
+ Some(ast::Type::Scalar(ast::ScalarType::Pred)),
+ )?;
+ let dst2 = self
+ .dst2
+ .map(|dst2| {
visitor.variable(
ArgumentDescriptor {
op: dst2,
@@ -2841,24 +3039,30 @@ impl<T: ArgParamsEx> ast::Arg4<T> {
},
Some(ast::Type::Scalar(ast::ScalarType::Pred)),
)
- }),
- src1: visitor.operand(
- ArgumentDescriptor {
- op: self.src1,
- is_dst: false,
- sema: ArgumentSemantics::Default,
- },
- t,
- ),
- src2: visitor.operand(
- ArgumentDescriptor {
- op: self.src2,
- is_dst: false,
- sema: ArgumentSemantics::Default,
- },
- t,
- ),
- }
+ })
+ .transpose()?;
+ let src1 = visitor.operand(
+ ArgumentDescriptor {
+ op: self.src1,
+ is_dst: false,
+ sema: ArgumentSemantics::Default,
+ },
+ t,
+ )?;
+ let src2 = visitor.operand(
+ ArgumentDescriptor {
+ op: self.src2,
+ is_dst: false,
+ sema: ArgumentSemantics::Default,
+ },
+ t,
+ )?;
+ Ok(ast::Arg4 {
+ dst1,
+ dst2,
+ src1,
+ src2,
+ })
}
}
@@ -2867,17 +3071,18 @@ impl<T: ArgParamsEx> ast::Arg5<T> {
self,
visitor: &mut V,
t: ast::Type,
- ) -> ast::Arg5<U> {
- ast::Arg5 {
- dst1: visitor.variable(
- ArgumentDescriptor {
- op: self.dst1,
- is_dst: true,
- sema: ArgumentSemantics::Default,
- },
- Some(ast::Type::Scalar(ast::ScalarType::Pred)),
- ),
- dst2: self.dst2.map(|dst2| {
+ ) -> Result<ast::Arg5<U>, TranslateError> {
+ let dst1 = visitor.variable(
+ ArgumentDescriptor {
+ op: self.dst1,
+ is_dst: true,
+ sema: ArgumentSemantics::Default,
+ },
+ Some(ast::Type::Scalar(ast::ScalarType::Pred)),
+ )?;
+ let dst2 = self
+ .dst2
+ .map(|dst2| {
visitor.variable(
ArgumentDescriptor {
op: dst2,
@@ -2886,40 +3091,47 @@ impl<T: ArgParamsEx> ast::Arg5<T> {
},
Some(ast::Type::Scalar(ast::ScalarType::Pred)),
)
- }),
- src1: visitor.operand(
- ArgumentDescriptor {
- op: self.src1,
- is_dst: false,
- sema: ArgumentSemantics::Default,
- },
- t,
- ),
- src2: visitor.operand(
- ArgumentDescriptor {
- op: self.src2,
- is_dst: false,
- sema: ArgumentSemantics::Default,
- },
- t,
- ),
- src3: visitor.operand(
- ArgumentDescriptor {
- op: self.src3,
- is_dst: false,
- sema: ArgumentSemantics::Default,
- },
- ast::Type::Scalar(ast::ScalarType::Pred),
- ),
- }
+ })
+ .transpose()?;
+ let src1 = visitor.operand(
+ ArgumentDescriptor {
+ op: self.src1,
+ is_dst: false,
+ sema: ArgumentSemantics::Default,
+ },
+ t,
+ )?;
+ let src2 = visitor.operand(
+ ArgumentDescriptor {
+ op: self.src2,
+ is_dst: false,
+ sema: ArgumentSemantics::Default,
+ },
+ t,
+ )?;
+ let src3 = visitor.operand(
+ ArgumentDescriptor {
+ op: self.src3,
+ is_dst: false,
+ sema: ArgumentSemantics::Default,
+ },
+ ast::Type::Scalar(ast::ScalarType::Pred),
+ )?;
+ Ok(ast::Arg5 {
+ dst1,
+ dst2,
+ src1,
+ src2,
+ src3,
+ })
}
}
impl<T> ast::CallOperand<T> {
- fn map_variable<U, F: FnMut(T) -> U>(self, f: &mut F) -> ast::CallOperand<U> {
+ fn map_variable<U, F: FnMut(T) -> Result<U, TranslateError>>(self, f: &mut F) -> Result<ast::CallOperand<U>, TranslateError> {
match self {
- ast::CallOperand::Reg(id) => ast::CallOperand::Reg(f(id)),
- ast::CallOperand::Imm(x) => ast::CallOperand::Imm(x),
+ ast::CallOperand::Reg(id) => Ok(ast::CallOperand::Reg(f(id)?)),
+ ast::CallOperand::Imm(x) => Ok(ast::CallOperand::Imm(x)),
}
}
}
@@ -3195,37 +3407,37 @@ fn insert_with_conversions_pre_conv<T>(
fn get_implicit_conversions_ld_dst<
ShouldConvert: FnOnce(ast::Type, ast::Type) -> Option<ConversionKind>,
>(
- id_def: &mut NumericIdResolver,
+ id_def: &mut MutableNumericIdResolver,
instr_type: ast::Type,
dst: spirv::Word,
should_convert: ShouldConvert,
in_reverse: bool,
-) -> Option<ImplicitConversion> {
- let dst_type = id_def.get_type(dst).unwrap_or_else(|| todo!()).1;
+) -> Result<Option<ImplicitConversion>, TranslateError> {
+ let dst_type = id_def.get_typed(dst)?;
if let Some(conv) = should_convert(dst_type, instr_type) {
- Some(ImplicitConversion {
+ Ok(Some(ImplicitConversion {
src: u32::max_value(),
dst: u32::max_value(),
from: if !in_reverse { dst_type } else { instr_type },
to: if !in_reverse { instr_type } else { dst_type },
kind: conv,
- })
+ }))
} else {
- None
+ Ok(None)
}
}
fn get_implicit_conversions_ld_src(
- id_def: &mut NumericIdResolver,
+ id_def: &mut MutableNumericIdResolver,
instr_type: ast::Type,
state_space: ast::LdStateSpace,
src: spirv::Word,
-) -> Vec<ImplicitConversion> {
- let src_type = id_def.get_type(src).unwrap_or_else(|| todo!()).1;
+) -> Result<Vec<ImplicitConversion>, TranslateError> {
+ let src_type = id_def.get_typed(src)?;
match state_space {
ast::LdStateSpace::Param => {
if src_type != instr_type {
- vec![
+ Ok(vec![
ImplicitConversion {
src: u32::max_value(),
dst: u32::max_value(),
@@ -3234,9 +3446,9 @@ fn get_implicit_conversions_ld_src(
kind: ConversionKind::Default,
};
1
- ]
+ ])
} else {
- Vec::new()
+ Ok(Vec::new())
}
}
ast::LdStateSpace::Generic | ast::LdStateSpace::Global => {
@@ -3268,12 +3480,12 @@ fn get_implicit_conversions_ld_src(
kind: ConversionKind::Ptr(state_space),
});
if result.len() == 2 {
- let new_id = id_def.new_id(Some((StateSpace::Reg, new_src_type)));
+ let new_id = id_def.new_id(new_src_type);
result[0].dst = new_id;
result[1].src = new_id;
result[1].from = new_src_type;
}
- result
+ Ok(result)
}
_ => todo!(),
}
@@ -3281,10 +3493,10 @@ fn get_implicit_conversions_ld_src(
fn insert_implicit_conversions_ld_src(
func: &mut Vec<ExpandedStatement>,
instr_type: ast::Type,
- id_def: &mut NumericIdResolver,
+ id_def: &mut MutableNumericIdResolver,
state_space: ast::LdStateSpace,
src: spirv::Word,
-) -> spirv::Word {
+) -> Result<spirv::Word, TranslateError> {
match state_space {
ast::LdStateSpace::Param => insert_implicit_conversions_ld_src_impl(
func,
@@ -3304,15 +3516,15 @@ fn insert_implicit_conversions_ld_src(
new_src_type,
src,
should_convert_ld_generic_src_to_bitcast,
- );
- insert_conversion_src(
+ )?;
+ Ok(insert_conversion_src(
func,
id_def,
new_src,
new_src_type,
instr_type,
ConversionKind::Ptr(state_space),
- )
+ ))
}
_ => todo!(),
}
@@ -3322,16 +3534,18 @@ fn insert_implicit_conversions_ld_src_impl<
ShouldConvert: FnOnce(ast::Type, ast::Type) -> Option<ConversionKind>,
>(
func: &mut Vec<ExpandedStatement>,
- id_def: &mut NumericIdResolver,
+ id_def: &mut MutableNumericIdResolver,
instr_type: ast::Type,
src: spirv::Word,
should_convert: ShouldConvert,
-) -> spirv::Word {
- let src_type = id_def.get_type(src).unwrap_or_else(|| todo!()).1;
+) -> Result<spirv::Word, TranslateError> {
+ let src_type = id_def.get_typed(src)?;
if let Some(conv) = should_convert(src_type, instr_type) {
- insert_conversion_src(func, id_def, src, src_type, instr_type, conv)
+ Ok(insert_conversion_src(
+ func, id_def, src, src_type, instr_type, conv,
+ ))
} else {
- src
+ Ok(src)
}
}
@@ -3363,13 +3577,13 @@ fn should_convert_ld_generic_src_to_bitcast(
#[must_use]
fn insert_conversion_src(
func: &mut Vec<ExpandedStatement>,
- id_def: &mut NumericIdResolver,
+ id_def: &mut MutableNumericIdResolver,
src: spirv::Word,
src_type: ast::Type,
instr_type: ast::Type,
conv: ConversionKind,
) -> spirv::Word {
- let temp_src = id_def.new_id(Some((StateSpace::Reg, instr_type)));
+ let temp_src = id_def.new_id(instr_type);
func.push(Statement::Conversion(ImplicitConversion {
src: src,
dst: temp_src,
@@ -3408,14 +3622,14 @@ fn insert_with_implicit_conversion_dst<
#[must_use]
fn get_conversion_dst(
- id_def: &mut NumericIdResolver,
+ id_def: &mut MutableNumericIdResolver,
dst: &mut spirv::Word,
instr_type: ast::Type,
dst_type: ast::Type,
kind: ConversionKind,
) -> ExpandedStatement {
let original_dst = *dst;
- let temp_dst = id_def.new_id(Some((StateSpace::Reg, instr_type)));
+ let temp_dst = id_def.new_id(instr_type);
*dst = temp_dst;
Statement::Conversion(ImplicitConversion {
src: temp_dst,
@@ -3525,17 +3739,17 @@ fn should_convert_relaxed_dst(
fn insert_implicit_bitcasts(
func: &mut Vec<ExpandedStatement>,
- id_def: &mut NumericIdResolver,
+ id_def: &mut MutableNumericIdResolver,
stmt: impl VisitVariableExpanded,
-) {
+) -> Result<(), TranslateError> {
let mut dst_coercion = None;
let instr = stmt.visit_variable_extended(&mut |mut desc, typ| {
let id_type_from_instr = match typ {
Some(t) => t,
- None => return desc.op,
+ None => return Ok(desc.op),
};
- let id_actual_type = id_def.get_type(desc.op).unwrap().1;
- if should_bitcast(id_type_from_instr, id_def.get_type(desc.op).unwrap().1) {
+ let id_actual_type = id_def.get_typed(desc.op)?;
+ if should_bitcast(id_type_from_instr, id_def.get_typed(desc.op)?) {
if desc.is_dst {
dst_coercion = Some(get_conversion_dst(
id_def,
@@ -3544,25 +3758,26 @@ fn insert_implicit_bitcasts(
id_actual_type,
ConversionKind::Default,
));
- desc.op
+ Ok(desc.op)
} else {
- insert_conversion_src(
+ Ok(insert_conversion_src(
func,
id_def,
desc.op,
id_actual_type,
id_type_from_instr,
ConversionKind::Default,
- )
+ ))
}
} else {
- desc.op
+ Ok(desc.op)
}
- });
+ })?;
func.push(instr);
if let Some(cond) = dst_coercion {
func.push(cond);
}
+ Ok(())
}
impl<'a> ast::MethodDecl<'a, ast::ParsedArgParams<'a>> {
fn name(&self) -> &'a str {