use crate::ast;
use half::f16;
use rspirv::dr;
use std::{borrow::Cow, collections::BTreeSet, ffi::CString, hash::Hash, iter, mem};
use std::{
    collections::{hash_map, HashMap, HashSet},
    convert::TryInto,
};

use rspirv::binary::Assemble;

static ZLUDA_PTX_IMPL: &'static [u8] = include_bytes!("../lib/zluda_ptx_impl.spv");

quick_error! {
    #[derive(Debug)]
    pub enum TranslateError {
        UnknownSymbol {}
        UntypedSymbol {}
        MismatchedType {}
        Spirv(err: rspirv::dr::Error) {
            from()
            display("{}", err)
            cause(err)
        }
        Unreachable {}
        Todo {}
    }
}

#[cfg(debug_assertions)]
fn error_unreachable() -> TranslateError {
    unreachable!()
}

#[cfg(not(debug_assertions))]
fn error_unreachable() -> TranslateError {
    TranslateError::Unreachable
}

#[derive(PartialEq, Eq, Hash, Clone)]
enum SpirvType {
    Base(SpirvScalarKey),
    Vector(SpirvScalarKey, u8),
    Array(SpirvScalarKey, Vec<u32>),
    Pointer(Box<SpirvType>, spirv::StorageClass),
    Func(Option<Box<SpirvType>>, Vec<SpirvType>),
    Struct(Vec<SpirvScalarKey>),
}

impl SpirvType {
    fn new_pointer(t: ast::Type, sc: spirv::StorageClass) -> Self {
        let key = t.into();
        SpirvType::Pointer(Box::new(key), sc)
    }
}

impl From<ast::Type> for SpirvType {
    fn from(t: ast::Type) -> Self {
        match t {
            ast::Type::Scalar(t) => SpirvType::Base(t.into()),
            ast::Type::Vector(typ, len) => SpirvType::Vector(typ.into(), len),
            ast::Type::Array(t, len) => SpirvType::Array(t.into(), len),
            ast::Type::Pointer(pointer_t, state_space) => SpirvType::Pointer(
                Box::new(SpirvType::from(ast::Type::from(pointer_t))),
                state_space.to_spirv(),
            ),
        }
    }
}

impl From<ast::PointerType> for ast::Type {
    fn from(t: ast::PointerType) -> Self {
        match t {
            ast::PointerType::Scalar(t) => ast::Type::Scalar(t),
            ast::PointerType::Vector(t, len) => ast::Type::Vector(t, len),
            ast::PointerType::Array(t, dims) => ast::Type::Array(t, dims),
            ast::PointerType::Pointer(t, space) => {
                ast::Type::Pointer(ast::PointerType::Scalar(t), space)
            }
        }
    }
}

impl ast::Type {
    fn param_pointer_to(self, space: ast::LdStateSpace) -> Result<Self, TranslateError> {
        Ok(match self {
            ast::Type::Scalar(t) => ast::Type::Pointer(ast::PointerType::Scalar(t), space),
            ast::Type::Vector(t, len) => {
                ast::Type::Pointer(ast::PointerType::Vector(t, len), space)
            }
            ast::Type::Array(t, _) => ast::Type::Pointer(ast::PointerType::Scalar(t), space),
            ast::Type::Pointer(ast::PointerType::Scalar(t), space) => {
                ast::Type::Pointer(ast::PointerType::Pointer(t, space), space)
            }
            ast::Type::Pointer(_, _) => return Err(error_unreachable()),
        })
    }
}

impl Into<spirv::StorageClass> for ast::PointerStateSpace {
    fn into(self) -> spirv::StorageClass {
        match self {
            ast::PointerStateSpace::Const => spirv::StorageClass::UniformConstant,
            ast::PointerStateSpace::Global => spirv::StorageClass::CrossWorkgroup,
            ast::PointerStateSpace::Shared => spirv::StorageClass::Workgroup,
            ast::PointerStateSpace::Param => spirv::StorageClass::Function,
            ast::PointerStateSpace::Generic => spirv::StorageClass::Generic,
        }
    }
}

impl From<ast::ScalarType> for SpirvType {
    fn from(t: ast::ScalarType) -> Self {
        SpirvType::Base(t.into())
    }
}

struct TypeWordMap {
    void: spirv::Word,
    complex: HashMap<SpirvType, spirv::Word>,
    constants: HashMap<(SpirvType, u64), spirv::Word>,
}

// SPIR-V integer type definitions are signless, more below:
// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_unsignedsigned_a_unsigned_versus_signed_integers
// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_validation_rules_for_kernel_a_href_capability_capabilities_a
#[derive(PartialEq, Eq, Hash, Clone, Copy)]
enum SpirvScalarKey {
    B8,
    B16,
    B32,
    B64,
    F16,
    F32,
    F64,
    Pred,
    F16x2,
}
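// Added illustration (not part of the original file): a minimal sanity check of
// the signless mapping described above. Signed, unsigned and bit-typed PTX
// scalars of the same width should collapse into a single SPIR-V key. The test
// module and its name are assumptions made for this sketch.
#[cfg(test)]
mod spirv_scalar_key_sketch {
    use super::*;

    #[test]
    fn signedness_is_erased_in_spirv_keys() {
        // `assert!` with `==` is used because SpirvScalarKey does not derive Debug.
        assert!(
            SpirvScalarKey::from(ast::ScalarType::U8) == SpirvScalarKey::from(ast::ScalarType::S8)
        );
        assert!(
            SpirvScalarKey::from(ast::ScalarType::S32) == SpirvScalarKey::from(ast::ScalarType::B32)
        );
        assert!(
            SpirvScalarKey::from(ast::ScalarType::F32) != SpirvScalarKey::from(ast::ScalarType::B32)
        );
    }
}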
impl From<ast::ScalarType> for SpirvScalarKey {
    fn from(t: ast::ScalarType) -> Self {
        match t {
            ast::ScalarType::B8 | ast::ScalarType::U8 | ast::ScalarType::S8 => SpirvScalarKey::B8,
            ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => {
                SpirvScalarKey::B16
            }
            ast::ScalarType::B32 | ast::ScalarType::U32 | ast::ScalarType::S32 => {
                SpirvScalarKey::B32
            }
            ast::ScalarType::B64 | ast::ScalarType::U64 | ast::ScalarType::S64 => {
                SpirvScalarKey::B64
            }
            ast::ScalarType::F16 => SpirvScalarKey::F16,
            ast::ScalarType::F32 => SpirvScalarKey::F32,
            ast::ScalarType::F64 => SpirvScalarKey::F64,
            ast::ScalarType::F16x2 => SpirvScalarKey::F16x2,
            ast::ScalarType::Pred => SpirvScalarKey::Pred,
        }
    }
}

impl TypeWordMap {
    fn new(b: &mut dr::Builder) -> TypeWordMap {
        let void = b.type_void();
        TypeWordMap {
            void: void,
            complex: HashMap::<SpirvType, spirv::Word>::new(),
            constants: HashMap::new(),
        }
    }

    fn void(&self) -> spirv::Word {
        self.void
    }

    fn get_or_add_scalar(&mut self, b: &mut dr::Builder, t: ast::ScalarType) -> spirv::Word {
        let key: SpirvScalarKey = t.into();
        self.get_or_add_spirv_scalar(b, key)
    }

    fn get_or_add_spirv_scalar(&mut self, b: &mut dr::Builder, key: SpirvScalarKey) -> spirv::Word {
        *self
            .complex
            .entry(SpirvType::Base(key))
            .or_insert_with(|| match key {
                SpirvScalarKey::B8 => b.type_int(8, 0),
                SpirvScalarKey::B16 => b.type_int(16, 0),
                SpirvScalarKey::B32 => b.type_int(32, 0),
                SpirvScalarKey::B64 => b.type_int(64, 0),
                SpirvScalarKey::F16 => b.type_float(16),
                SpirvScalarKey::F32 => b.type_float(32),
                SpirvScalarKey::F64 => b.type_float(64),
                SpirvScalarKey::Pred => b.type_bool(),
                SpirvScalarKey::F16x2 => todo!(),
            })
    }

    fn get_or_add(&mut self, b: &mut dr::Builder, t: SpirvType) -> spirv::Word {
        match t {
            SpirvType::Base(key) => self.get_or_add_spirv_scalar(b, key),
            SpirvType::Pointer(ref typ, storage) => {
                let base = self.get_or_add(b, *typ.clone());
                *self
                    .complex
                    .entry(t)
                    .or_insert_with(|| b.type_pointer(None, storage, base))
            }
            SpirvType::Vector(typ, len) => {
                let base = self.get_or_add_spirv_scalar(b, typ);
                *self
                    .complex
                    .entry(t)
                    .or_insert_with(|| b.type_vector(base, len as u32))
            }
            SpirvType::Array(typ, array_dimensions) => {
                let u32_type = self.get_or_add_scalar(b, ast::ScalarType::U32);
                let (base_type, length) = match &*array_dimensions {
                    &[len] => {
                        let base = self.get_or_add_spirv_scalar(b, typ);
                        let len_const = b.constant_u32(u32_type, None, len);
                        (base, len_const)
                    }
                    array_dimensions => {
                        let base = self
                            .get_or_add(b, SpirvType::Array(typ, array_dimensions[1..].to_vec()));
                        let len_const = b.constant_u32(u32_type, None, array_dimensions[0]);
                        (base, len_const)
                    }
                };
                *self
                    .complex
                    .entry(SpirvType::Array(typ, array_dimensions))
                    .or_insert_with(|| b.type_array(base_type, length))
            }
            SpirvType::Func(ref out_params, ref in_params) => {
                let out_t = match out_params {
                    Some(p) => self.get_or_add(b, *p.clone()),
                    None => self.void(),
                };
                let in_t = in_params
                    .iter()
                    .map(|t| self.get_or_add(b, t.clone()))
                    .collect::<Vec<_>>();
                *self
                    .complex
                    .entry(t)
                    .or_insert_with(|| b.type_function(out_t, in_t))
            }
            SpirvType::Struct(ref underlying) => {
                let underlying_ids = underlying
                    .iter()
                    .map(|t| self.get_or_add_spirv_scalar(b, *t))
                    .collect::<Vec<_>>();
                *self
                    .complex
                    .entry(t)
                    .or_insert_with(|| b.type_struct(underlying_ids))
            }
        }
    }
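    // Added note (not in the original source): for a multi-dimensional array
    // such as `.b32 arr[2][3]`, the `SpirvType::Array` arm above peels off the
    // outermost dimension and recurses on the rest, so the emitted type is an
    // OpTypeArray of an OpTypeArray, conceptually:
    //     get_or_add(Array(B32, vec![2, 3]))
    //         -> OpTypeArray(get_or_add(Array(B32, vec![3])), 2)
    //         -> OpTypeArray(OpTypeArray(int32, 3), 2)
    // Both levels are cached in `self.complex`, keyed by the full SpirvType.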
    fn get_or_add_fn(
        &mut self,
        b: &mut dr::Builder,
        in_params: impl ExactSizeIterator<Item = SpirvType>,
        mut out_params: impl ExactSizeIterator<Item = SpirvType>,
    ) -> (spirv::Word, spirv::Word) {
        let (out_args, out_spirv_type) = if out_params.len() == 0 {
            (None, self.void())
        } else if out_params.len() == 1 {
            let arg_as_key = out_params.next().unwrap();
            (
                Some(Box::new(arg_as_key.clone())),
                self.get_or_add(b, arg_as_key),
            )
        } else {
            todo!()
        };
        (
            out_spirv_type,
            self.get_or_add(b, SpirvType::Func(out_args, in_params.collect::<Vec<_>>())),
        )
    }

    fn get_or_add_constant(
        &mut self,
        b: &mut dr::Builder,
        typ: &ast::Type,
        init: &[u8],
    ) -> Result<spirv::Word, TranslateError> {
        Ok(match typ {
            ast::Type::Scalar(t) => match t {
                ast::ScalarType::B8 | ast::ScalarType::U8 | ast::ScalarType::S8 => self
                    .get_or_add_constant_single::<u8, _, _>(
                        b,
                        *t,
                        init,
                        |v| v as u64,
                        |b, result_type, v| b.constant_u32(result_type, None, v as u32),
                    ),
                ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => self
                    .get_or_add_constant_single::<u16, _, _>(
                        b,
                        *t,
                        init,
                        |v| v as u64,
                        |b, result_type, v| b.constant_u32(result_type, None, v as u32),
                    ),
                ast::ScalarType::B32 | ast::ScalarType::U32 | ast::ScalarType::S32 => self
                    .get_or_add_constant_single::<u32, _, _>(
                        b,
                        *t,
                        init,
                        |v| v as u64,
                        |b, result_type, v| b.constant_u32(result_type, None, v),
                    ),
                ast::ScalarType::B64 | ast::ScalarType::U64 | ast::ScalarType::S64 => self
                    .get_or_add_constant_single::<u64, _, _>(
                        b,
                        *t,
                        init,
                        |v| v,
                        |b, result_type, v| b.constant_u64(result_type, None, v),
                    ),
                ast::ScalarType::F16 => self.get_or_add_constant_single::<f16, _, _>(
                    b,
                    *t,
                    init,
                    |v| unsafe { mem::transmute::<_, u16>(v) } as u64,
                    |b, result_type, v| b.constant_f32(result_type, None, v.to_f32()),
                ),
                ast::ScalarType::F32 => self.get_or_add_constant_single::<f32, _, _>(
                    b,
                    *t,
                    init,
                    |v| unsafe { mem::transmute::<_, u32>(v) } as u64,
                    |b, result_type, v| b.constant_f32(result_type, None, v),
                ),
                ast::ScalarType::F64 => self.get_or_add_constant_single::<f64, _, _>(
                    b,
                    *t,
                    init,
                    |v| unsafe { mem::transmute::<_, u64>(v) },
                    |b, result_type, v| b.constant_f64(result_type, None, v),
                ),
                ast::ScalarType::F16x2 => return Err(TranslateError::Todo),
                ast::ScalarType::Pred => self.get_or_add_constant_single::<u8, _, _>(
                    b,
                    *t,
                    init,
                    |v| v as u64,
                    |b, result_type, v| {
                        if v == 0 {
                            b.constant_false(result_type, None)
                        } else {
                            b.constant_true(result_type, None)
                        }
                    },
                ),
            },
            ast::Type::Vector(typ, len) => {
                let result_type =
                    self.get_or_add(b, SpirvType::Vector(SpirvScalarKey::from(*typ), *len));
                let size_of_t = typ.size_of();
                let components = (0..*len)
                    .map(|x| {
                        self.get_or_add_constant(
                            b,
                            &ast::Type::Scalar(*typ),
                            &init[((size_of_t as usize) * (x as usize))..],
                        )
                    })
                    .collect::<Result<Vec<_>, _>>()?;
                b.constant_composite(result_type, None, &components)
            }
            ast::Type::Array(typ, dims) => match dims.as_slice() {
                [] => return Err(error_unreachable()),
                [dim] => {
                    let result_type = self
                        .get_or_add(b, SpirvType::Array(SpirvScalarKey::from(*typ), vec![*dim]));
                    let size_of_t = typ.size_of();
                    let components = (0..*dim)
                        .map(|x| {
                            self.get_or_add_constant(
                                b,
                                &ast::Type::Scalar(*typ),
                                &init[((size_of_t as usize) * (x as usize))..],
                            )
                        })
                        .collect::<Result<Vec<_>, _>>()?;
                    b.constant_composite(result_type, None, &components)
                }
                [first_dim, rest @ ..] => {
                    let result_type = self.get_or_add(
                        b,
                        SpirvType::Array(SpirvScalarKey::from(*typ), rest.to_vec()),
                    );
                    let size_of_t = rest
                        .iter()
                        .fold(typ.size_of() as u32, |x, y| (x as u32) * (*y));
                    let components = (0..*first_dim)
                        .map(|x| {
                            self.get_or_add_constant(
                                b,
                                &ast::Type::Array(*typ, rest.to_vec()),
                                &init[((size_of_t as usize) * (x as usize))..],
                            )
                        })
                        .collect::<Result<Vec<_>, _>>()?;
                    b.constant_composite(result_type, None, &components)
                }
            },
            ast::Type::Pointer(typ, state_space) => {
                let base_t = typ.clone().into();
                let base = self.get_or_add_constant(b, &base_t, &[])?;
                let result_type = self.get_or_add(
                    b,
                    SpirvType::Pointer(
                        Box::new(SpirvType::from(base_t)),
                        (*state_space).to_spirv(),
                    ),
                );
                b.variable(result_type, None, (*state_space).to_spirv(), Some(base))
            }
        })
    }
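    // Added note (not in the original source): scalar constants funnel through
    // `get_or_add_constant_single` below, which reinterprets the leading bytes
    // of `init` as the requested scalar and caches the result id in
    // `self.constants` under the key (SPIR-V type, value widened to u64).
    // For example, a `.u32` initializer starting with bytes [7, 0, 0, 0] is
    // read back as the value 7 and cached under (Base(B32), 7), so repeated
    // uses of the same constant reuse one OpConstant instead of emitting
    // duplicates.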
    fn get_or_add_constant_single<
        T: Copy,
        CastAsU64: FnOnce(T) -> u64,
        InsertConstant: FnOnce(&mut dr::Builder, spirv::Word, T) -> spirv::Word,
    >(
        &mut self,
        b: &mut dr::Builder,
        key: ast::ScalarType,
        init: &[u8],
        cast: CastAsU64,
        f: InsertConstant,
    ) -> spirv::Word {
        let value = unsafe { *(init.as_ptr() as *const T) };
        let value_64 = cast(value);
        let ht_key = (SpirvType::Base(SpirvScalarKey::from(key)), value_64);
        match self.constants.get(&ht_key) {
            Some(value) => *value,
            None => {
                let spirv_type = self.get_or_add_scalar(b, key);
                let result = f(b, spirv_type, value);
                self.constants.insert(ht_key, result);
                result
            }
        }
    }
}

pub struct Module {
    pub spirv: dr::Module,
    pub kernel_info: HashMap<String, KernelInfo>,
    pub should_link_ptx_impl: Option<&'static [u8]>,
    pub build_options: CString,
}

impl Module {
    pub fn assemble(&self) -> Vec<u32> {
        self.spirv.assemble()
    }
}

pub struct KernelInfo {
    pub arguments_sizes: Vec<usize>,
    pub uses_shared_mem: bool,
}

pub fn to_spirv_module<'a>(ast: ast::Module<'a>) -> Result<Module, TranslateError> {
    let mut id_defs = GlobalStringIdResolver::new(1);
    let mut ptx_impl_imports = HashMap::new();
    let directives = ast
        .directives
        .into_iter()
        .filter_map(|directive| {
            translate_directive(&mut id_defs, &mut ptx_impl_imports, directive).transpose()
        })
        .collect::<Result<Vec<_>, _>>()?;
    let must_link_ptx_impl = ptx_impl_imports.len() > 0;
    let directives = ptx_impl_imports
        .into_iter()
        .map(|(_, v)| v)
        .chain(directives.into_iter())
        .collect::<Vec<_>>();
    let mut builder = dr::Builder::new();
    builder.reserve_ids(id_defs.current_id());
    let call_map = get_call_map(&directives);
    let mut directives = convert_dynamic_shared_memory_usage(directives, &mut || builder.id());
    normalize_variable_decls(&mut directives);
    let denorm_information = compute_denorm_information(&directives);
    // https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_logicallayout_a_logical_layout_of_a_module
    builder.set_version(1, 3);
    emit_capabilities(&mut builder);
    emit_extensions(&mut builder);
    let opencl_id = emit_opencl_import(&mut builder);
    emit_memory_model(&mut builder);
    let mut map = TypeWordMap::new(&mut builder);
    emit_builtins(&mut builder, &mut map, &id_defs);
    let mut kernel_info = HashMap::new();
    let build_options = emit_denorm_build_string(&call_map, &denorm_information);
    emit_directives(
        &mut builder,
        &mut map,
        &id_defs,
        opencl_id,
        &denorm_information,
        &call_map,
        directives,
        &mut kernel_info,
    )?;
    let spirv = builder.module();
    Ok(Module {
        spirv,
        kernel_info,
        should_link_ptx_impl: if must_link_ptx_impl {
            Some(ZLUDA_PTX_IMPL)
        } else {
            None
        },
        build_options,
    })
}
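// Added usage sketch (not part of the original file): a caller is expected to
// feed a parsed `ast::Module` into `to_spirv_module`, assemble the result and,
// when `should_link_ptx_impl` is `Some`, link the bundled `zluda_ptx_impl.spv`
// against it. The surrounding parser and linker APIs are assumptions here, so
// this is left as an illustrative comment rather than compiled code:
//
//     let module = to_spirv_module(parsed_ast)?;
//     let spirv_words: Vec<u32> = module.assemble();
//     if let Some(impl_spirv) = module.should_link_ptx_impl {
//         // link the raw SPIR-V bytes of `impl_spirv` with `spirv_words`
//     }
//     let build_options = module.build_options; // e.g. "-cl-denorms-are-zero"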
// TODO: remove this once we have per-function support for denorms
fn emit_denorm_build_string(
    call_map: &HashMap<&str, HashSet<spirv::Word>>,
    denorm_information: &HashMap<MethodName, HashMap<u8, (spirv::FPDenormMode, isize)>>,
) -> CString {
    let denorm_counts = denorm_information
        .iter()
        .map(|(method, meth_denorm)| {
            let f16_count = meth_denorm
                .get(&(mem::size_of::<f16>() as u8))
                .unwrap_or(&(spirv::FPDenormMode::FlushToZero, 0))
                .1;
            let f32_count = meth_denorm
                .get(&(mem::size_of::<f32>() as u8))
                .unwrap_or(&(spirv::FPDenormMode::FlushToZero, 0))
                .1;
            (method, (f16_count + f32_count))
        })
        .collect::<HashMap<_, _>>();
    let mut flush_over_preserve = 0;
    for (kernel, children) in call_map {
        flush_over_preserve += *denorm_counts.get(&MethodName::Kernel(kernel)).unwrap_or(&0);
        for child_fn in children {
            flush_over_preserve += *denorm_counts
                .get(&MethodName::Func(*child_fn))
                .unwrap_or(&0);
        }
    }
    if flush_over_preserve > 0 {
        CString::new("-cl-denorms-are-zero").unwrap()
    } else {
        CString::default()
    }
}

fn emit_directives<'input>(
    builder: &mut dr::Builder,
    map: &mut TypeWordMap,
    id_defs: &GlobalStringIdResolver<'input>,
    opencl_id: spirv::Word,
    denorm_information: &HashMap<MethodName<'input>, HashMap<u8, (spirv::FPDenormMode, isize)>>,
    call_map: &HashMap<&'input str, HashSet<spirv::Word>>,
    directives: Vec<Directive<'input>>,
    kernel_info: &mut HashMap<String, KernelInfo>,
) -> Result<(), TranslateError> {
    let empty_body = Vec::new();
    for d in directives.iter() {
        match d {
            Directive::Variable(var) => {
                emit_variable(builder, map, &var)?;
            }
            Directive::Method(f) => {
                let f_body = match &f.body {
                    Some(f) => f,
                    None => {
                        if f.import_as.is_some() {
                            &empty_body
                        } else {
                            continue;
                        }
                    }
                };
                for var in f.globals.iter() {
                    emit_variable(builder, map, var)?;
                }
                emit_function_header(
                    builder,
                    map,
                    &id_defs,
                    &f.globals,
                    &f.spirv_decl,
                    &denorm_information,
                    call_map,
                    &directives,
                    kernel_info,
                )?;
                emit_function_body_ops(builder, map, opencl_id, &f_body)?;
                builder.end_function()?;
                if let (ast::MethodDecl::Func(_, fn_id, _), Some(name)) =
                    (&f.func_decl, &f.import_as)
                {
                    builder.decorate(
                        *fn_id,
                        spirv::Decoration::LinkageAttributes,
                        &[
                            dr::Operand::LiteralString(name.clone()),
                            dr::Operand::LinkageType(spirv::LinkageType::Import),
                        ],
                    );
                }
            }
        }
    }
    Ok(())
}
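// Added note (not part of the original file): `get_call_map` below first
// records, for every method with a body, which functions it calls directly,
// then flattens that relation into kernel -> transitively reachable functions.
// Illustrative example: if kernel `k` calls `f1` and `f1` calls `f2`, the
// returned map contains k -> {f1, f2}; helper-to-helper entries are dropped
// because only kernels become keys of the result.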
fn get_call_map<'input>(
    module: &[Directive<'input>],
) -> HashMap<&'input str, HashSet<spirv::Word>> {
    let mut directly_called_by = HashMap::new();
    for directive in module {
        match directive {
            Directive::Method(Function {
                func_decl,
                body: Some(statements),
                ..
            }) => {
                let call_key = MethodName::new(&func_decl);
                if let hash_map::Entry::Vacant(entry) = directly_called_by.entry(call_key) {
                    entry.insert(Vec::new());
                }
                for statement in statements {
                    match statement {
                        Statement::Call(call) => {
                            multi_hash_map_append(&mut directly_called_by, call_key, call.func);
                        }
                        _ => {}
                    }
                }
            }
            _ => {}
        }
    }
    let mut result = HashMap::new();
    for (method_key, children) in directly_called_by.iter() {
        match method_key {
            MethodName::Kernel(name) => {
                let mut visited = HashSet::new();
                for child in children {
                    add_call_map_single(&directly_called_by, &mut visited, *child);
                }
                result.insert(*name, visited);
            }
            MethodName::Func(_) => {}
        }
    }
    result
}

fn add_call_map_single<'input>(
    directly_called_by: &MultiHashMap<MethodName<'input>, spirv::Word>,
    visited: &mut HashSet<spirv::Word>,
    current: spirv::Word,
) {
    if !visited.insert(current) {
        return;
    }
    if let Some(children) = directly_called_by.get(&MethodName::Func(current)) {
        for child in children {
            add_call_map_single(directly_called_by, visited, *child);
        }
    }
}

type MultiHashMap<K, V> = HashMap<K, Vec<V>>;

fn multi_hash_map_append<K: Eq + Hash, V>(m: &mut MultiHashMap<K, V>, key: K, value: V) {
    match m.entry(key) {
        hash_map::Entry::Occupied(mut entry) => {
            entry.get_mut().push(value);
        }
        hash_map::Entry::Vacant(entry) => {
            entry.insert(vec![value]);
        }
    }
}

// PTX represents dynamically allocated shared local memory as
//     .extern .shared .align 4 .b8 shared_mem[];
// In the SPIR-V/OpenCL world this is expressed as an additional kernel argument.
// This pass looks for all uses of .extern .shared and converts them to
// an additional method argument
fn convert_dynamic_shared_memory_usage<'input>(
    module: Vec<Directive<'input>>,
    new_id: &mut impl FnMut() -> spirv::Word,
) -> Vec<Directive<'input>> {
    let mut extern_shared_decls = HashMap::new();
    for dir in module.iter() {
        match dir {
            Directive::Variable(var) => {
                if let ast::VariableType::Shared(ast::VariableGlobalType::Pointer(p_type, _)) =
                    var.v_type
                {
                    extern_shared_decls.insert(var.name, p_type);
                }
            }
            _ => {}
        }
    }
    if extern_shared_decls.len() == 0 {
        return module;
    }
    let mut methods_using_extern_shared = HashSet::new();
    let mut directly_called_by = MultiHashMap::new();
    let module = module
        .into_iter()
        .map(|directive| match directive {
            Directive::Method(Function {
                func_decl,
                globals,
                body: Some(statements),
                import_as,
                spirv_decl,
            }) => {
                let call_key = MethodName::new(&func_decl);
                let statements = statements
                    .into_iter()
                    .map(|statement| match statement {
                        Statement::Call(call) => {
                            multi_hash_map_append(&mut directly_called_by, call.func, call_key);
                            Statement::Call(call)
                        }
                        statement => statement.map_id(&mut |id, _| {
                            if extern_shared_decls.contains_key(&id) {
                                methods_using_extern_shared.insert(call_key);
                            }
                            id
                        }),
                    })
                    .collect();
                Directive::Method(Function {
                    func_decl,
                    globals,
                    body: Some(statements),
                    import_as,
                    spirv_decl,
                })
            }
            directive => directive,
        })
        .collect::<Vec<_>>();
    // If there's a chain `kernel` -> `fn1` -> `fn2`, where only `fn2` uses extern shared,
    // make sure it gets propagated to `fn1` and `kernel`
    get_callers_of_extern_shared(&mut methods_using_extern_shared, &directly_called_by);
    // now visit every method declaration and inject those additional arguments
    module
        .into_iter()
        .map(|directive| match directive {
            Directive::Method(Function {
                func_decl,
                globals,
                body: Some(statements),
                import_as,
                mut spirv_decl,
            }) => {
                if !methods_using_extern_shared.contains(&spirv_decl.name) {
                    return Directive::Method(Function {
                        func_decl,
                        globals,
                        body: Some(statements),
                        import_as,
                        spirv_decl,
                    });
                }
                let shared_id_param = new_id();
                spirv_decl.input.push({
                    ast::Variable {
                        align: None,
                        v_type: ast::Type::Pointer(
                            ast::PointerType::Scalar(ast::ScalarType::U8),
                            ast::LdStateSpace::Shared,
                        ),
                        array_init: Vec::new(),
                        name: shared_id_param,
                    }
                });
                spirv_decl.uses_shared_mem = true;
                let shared_var_id = new_id();
                let shared_var = ExpandedStatement::Variable(ast::Variable {
                    align: None,
                    name: shared_var_id,
                    array_init: Vec::new(),
                    v_type: ast::VariableType::Reg(ast::VariableRegType::Pointer(
                        ast::SizedScalarType::B8,
                        ast::PointerStateSpace::Shared,
                    )),
                });
                let shared_var_st = ExpandedStatement::StoreVar(StoreVarDetails {
                    arg: ast::Arg2St {
                        src1: shared_var_id,
                        src2: shared_id_param,
                    },
                    typ: ast::Type::Scalar(ast::ScalarType::B8),
                    member_index: None,
                });
                let mut new_statements = vec![shared_var, shared_var_st];
                replace_uses_of_shared_memory(
                    &mut new_statements,
                    new_id,
                    &extern_shared_decls,
                    &mut methods_using_extern_shared,
                    shared_id_param,
                    shared_var_id,
                    statements,
                );
                Directive::Method(Function {
                    func_decl,
                    globals,
                    body: Some(new_statements),
                    import_as,
                    spirv_decl,
                })
            }
            directive => directive,
        })
        .collect::<Vec<_>>()
}
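// Added illustration (not part of the original file): the effect of the pass
// above, sketched at the PTX level. A kernel that declares
//     .extern .shared .align 4 .b8 shared_mem[];
// and uses `shared_mem` in its body keeps its source unchanged, but its
// SPIR-V signature grows one trailing `.shared` pointer parameter; a local
// b8-pointer variable is initialized from that parameter, and every use of
// `shared_mem` is redirected to that variable (through a pointer-to-pointer
// conversion when the declared element type is not b8). Functions on the call
// path from a kernel to a user of `shared_mem` get the same extra parameter,
// and their call sites forward it.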
fn replace_uses_of_shared_memory<'a>(
    result: &mut Vec<ExpandedStatement>,
    new_id: &mut impl FnMut() -> spirv::Word,
    extern_shared_decls: &HashMap<spirv::Word, ast::SizedScalarType>,
    methods_using_extern_shared: &mut HashSet<MethodName<'a>>,
    shared_id_param: spirv::Word,
    shared_var_id: spirv::Word,
    statements: Vec<ExpandedStatement>,
) {
    for statement in statements {
        match statement {
            Statement::Call(mut call) => {
                // We can safely skip checking call arguments,
                // because there's simply no way to pass a shared ptr
                // without converting it to .b64 first
                if methods_using_extern_shared.contains(&MethodName::Func(call.func)) {
                    call.param_list
                        .push((shared_id_param, ast::FnArgumentType::Shared));
                }
                result.push(Statement::Call(call))
            }
            statement => {
                let new_statement = statement.map_id(&mut |id, _| {
                    if let Some(typ) = extern_shared_decls.get(&id) {
                        if *typ == ast::SizedScalarType::B8 {
                            return shared_var_id;
                        }
                        let replacement_id = new_id();
                        result.push(Statement::Conversion(ImplicitConversion {
                            src: shared_var_id,
                            dst: replacement_id,
                            from: ast::Type::Pointer(
                                ast::PointerType::Scalar(ast::ScalarType::B8),
                                ast::LdStateSpace::Shared,
                            ),
                            to: ast::Type::Pointer(
                                ast::PointerType::Scalar((*typ).into()),
                                ast::LdStateSpace::Shared,
                            ),
                            kind: ConversionKind::PtrToPtr { spirv_ptr: true },
                            src_sema: ArgumentSemantics::Default,
                            dst_sema: ArgumentSemantics::Default,
                        }));
                        replacement_id
                    } else {
                        id
                    }
                });
                result.push(new_statement);
            }
        }
    }
}

fn get_callers_of_extern_shared<'a>(
    methods_using_extern_shared: &mut HashSet<MethodName<'a>>,
    directly_called_by: &MultiHashMap<spirv::Word, MethodName<'a>>,
) {
    let direct_uses_of_extern_shared = methods_using_extern_shared
        .iter()
        .filter_map(|method| {
            if let MethodName::Func(f_id) = method {
                Some(*f_id)
            } else {
                None
            }
        })
        .collect::<Vec<_>>();
    for fn_id in direct_uses_of_extern_shared {
        get_callers_of_extern_shared_single(methods_using_extern_shared, directly_called_by, fn_id);
    }
}

fn get_callers_of_extern_shared_single<'a>(
    methods_using_extern_shared: &mut HashSet<MethodName<'a>>,
    directly_called_by: &MultiHashMap<spirv::Word, MethodName<'a>>,
    fn_id: spirv::Word,
) {
    if let Some(callers) = directly_called_by.get(&fn_id) {
        for caller in callers {
            if methods_using_extern_shared.insert(*caller) {
                if let MethodName::Func(caller_fn) = caller {
                    get_callers_of_extern_shared_single(
                        methods_using_extern_shared,
                        directly_called_by,
                        *caller_fn,
                    );
                }
            }
        }
    }
}

type DenormCountMap<T> = HashMap<T, isize>;

fn denorm_count_map_update<T: Eq + Hash>(map: &mut DenormCountMap<T>, key: T, value: bool) {
    let num_value = if value { 1 } else { -1 };
    denorm_count_map_update_impl(map, key, num_value);
}

fn denorm_count_map_update_impl<T: Eq + Hash>(
    map: &mut DenormCountMap<T>,
    key: T,
    num_value: isize,
) {
    match map.entry(key) {
        hash_map::Entry::Occupied(mut counter) => {
            *(counter.get_mut()) += num_value;
        }
        hash_map::Entry::Vacant(entry) => {
            entry.insert(num_value);
        }
    }
}
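// Added illustration (not part of the original file): small sketches of the
// helper-map semantics defined above and used by the denormal-mode heuristic
// below. `true` (flush to zero) increments the per-width counter and `false`
// (preserve) decrements it; `multi_hash_map_append` groups values under one
// key. The test module and its name are assumptions made for this sketch.
#[cfg(test)]
mod pass_helper_sketches {
    use super::*;

    #[test]
    fn flush_votes_outweigh_preserve_votes() {
        let mut map = DenormCountMap::new();
        denorm_count_map_update(&mut map, 32u8, true);
        denorm_count_map_update(&mut map, 32u8, true);
        denorm_count_map_update(&mut map, 32u8, false);
        // Two flush votes and one preserve vote leave a net count of +1.
        assert_eq!(map[&32u8], 1);
    }

    #[test]
    fn multi_hash_map_append_groups_values_by_key() {
        let mut map = MultiHashMap::new();
        multi_hash_map_append(&mut map, "kernel", 1u32);
        multi_hash_map_append(&mut map, "kernel", 2u32);
        assert_eq!(map["kernel"], vec![1, 2]);
    }
}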
// HACK ALERT!
// This function is a "good enough" heuristic of whether to mark f16/f32 operations
// in the kernel as flushing denorms to zero or preserving them.
// PTX supports per-instruction ftz information. Unfortunately SPIR-V has no
// such capability, so instead we guesstimate which use is more common in the kernel
// and emit the suitable execution mode
fn compute_denorm_information<'input>(
    module: &[Directive<'input>],
) -> HashMap<MethodName<'input>, HashMap<u8, (spirv::FPDenormMode, isize)>> {
    let mut denorm_methods = HashMap::new();
    for directive in module {
        match directive {
            Directive::Variable(_) | Directive::Method(Function { body: None, .. }) => {}
            Directive::Method(Function {
                func_decl,
                body: Some(statements),
                ..
            }) => {
                let mut flush_counter = DenormCountMap::new();
                let method_key = MethodName::new(func_decl);
                for statement in statements {
                    match statement {
                        Statement::Instruction(inst) => {
                            if let Some((flush, width)) = inst.flush_to_zero() {
                                denorm_count_map_update(&mut flush_counter, width, flush);
                            }
                        }
                        Statement::LoadVar(..) => {}
                        Statement::StoreVar(..) => {}
                        Statement::Call(_) => {}
                        Statement::Conditional(_) => {}
                        Statement::Conversion(_) => {}
                        Statement::Constant(_) => {}
                        Statement::RetValue(_, _) => {}
                        Statement::Label(_) => {}
                        Statement::Variable(_) => {}
                        Statement::PtrAccess { .. } => {}
                        Statement::RepackVector(_) => {}
                    }
                }
                denorm_methods.insert(method_key, flush_counter);
            }
        }
    }
    denorm_methods
        .into_iter()
        .map(|(name, v)| {
            let width_to_denorm = v
                .into_iter()
                .map(|(k, flush_over_preserve)| {
                    let mode = if flush_over_preserve > 0 {
                        spirv::FPDenormMode::FlushToZero
                    } else {
                        spirv::FPDenormMode::Preserve
                    };
                    (k, (mode, flush_over_preserve))
                })
                .collect();
            (name, width_to_denorm)
        })
        .collect()
}

#[derive(Hash, PartialEq, Eq, Copy, Clone)]
enum MethodName<'input> {
    Kernel(&'input str),
    Func(spirv::Word),
}

impl<'input> MethodName<'input> {
    fn new(decl: &ast::MethodDecl<'input, spirv::Word>) -> Self {
        match decl {
            ast::MethodDecl::Kernel { name, ..
} => MethodName::Kernel(name), ast::MethodDecl::Func(_, id, _) => MethodName::Func(*id), } } } fn emit_builtins( builder: &mut dr::Builder, map: &mut TypeWordMap, id_defs: &GlobalStringIdResolver, ) { for (reg, id) in id_defs.special_registers.builtins() { let result_type = map.get_or_add( builder, SpirvType::Pointer( Box::new(SpirvType::from(reg.get_type())), spirv::StorageClass::Input, ), ); builder.variable(result_type, Some(id), spirv::StorageClass::Input, None); builder.decorate( id, spirv::Decoration::BuiltIn, &[dr::Operand::BuiltIn(reg.get_builtin())], ); } } fn emit_function_header<'a>( builder: &mut dr::Builder, map: &mut TypeWordMap, defined_globals: &GlobalStringIdResolver<'a>, synthetic_globals: &[ast::Variable], func_decl: &SpirvMethodDecl<'a>, _denorm_information: &HashMap, HashMap>, call_map: &HashMap<&'a str, HashSet>, direcitves: &[Directive], kernel_info: &mut HashMap, ) -> Result<(), TranslateError> { if let MethodName::Kernel(name) = func_decl.name { let input_args = if !func_decl.uses_shared_mem { func_decl.input.as_slice() } else { &func_decl.input[0..func_decl.input.len() - 1] }; let args_lens = input_args .iter() .map(|param| param.v_type.size_of()) .collect(); kernel_info.insert( name.to_string(), KernelInfo { arguments_sizes: args_lens, uses_shared_mem: func_decl.uses_shared_mem, }, ); } let (ret_type, func_type) = get_function_type(builder, map, &func_decl.input, &func_decl.output); let fn_id = match func_decl.name { MethodName::Kernel(name) => { let fn_id = defined_globals.get_id(name)?; let mut global_variables = defined_globals .variables_type_check .iter() .filter_map(|(k, t)| t.as_ref().map(|_| *k)) .collect::>(); let mut interface = defined_globals.special_registers.interface(); for ast::Variable { name, .. } in synthetic_globals { interface.push(*name); } let empty_hash_set = HashSet::new(); let child_fns = call_map.get(name).unwrap_or(&empty_hash_set); for directive in direcitves { match directive { Directive::Method(Function { func_decl: ast::MethodDecl::Func(_, name, _), globals, .. 
}) => { if child_fns.contains(name) { for var in globals { interface.push(var.name); } } } _ => {} } } global_variables.append(&mut interface); builder.entry_point(spirv::ExecutionModel::Kernel, fn_id, name, global_variables); fn_id } MethodName::Func(name) => name, }; builder.begin_function( ret_type, Some(fn_id), spirv::FunctionControl::NONE, func_type, )?; // TODO: re-enable when Intel float control extension works /* if let Some(denorm_modes) = denorm_information.get(&func_decl.name) { for (size_of, denorm_mode) in denorm_modes { builder.decorate( fn_id, spirv::Decoration::FunctionDenormModeINTEL, [ dr::Operand::LiteralInt32((*size_of as u32) * 8), dr::Operand::FPDenormMode(*denorm_mode), ], ) } } */ for input in &func_decl.input { let result_type = map.get_or_add(builder, SpirvType::from(input.v_type.clone())); let inst = dr::Instruction::new( spirv::Op::FunctionParameter, Some(result_type), Some(input.name), Vec::new(), ); builder.function.as_mut().unwrap().parameters.push(inst); } Ok(()) } fn emit_capabilities(builder: &mut dr::Builder) { builder.capability(spirv::Capability::GenericPointer); builder.capability(spirv::Capability::Linkage); builder.capability(spirv::Capability::Addresses); builder.capability(spirv::Capability::Kernel); builder.capability(spirv::Capability::Int8); builder.capability(spirv::Capability::Int16); builder.capability(spirv::Capability::Int64); builder.capability(spirv::Capability::Float16); builder.capability(spirv::Capability::Float64); // TODO: re-enable when Intel float control extension works //builder.capability(spirv::Capability::FunctionFloatControlINTEL); } // http://htmlpreview.github.io/?https://github.com/KhronosGroup/SPIRV-Registry/blob/master/extensions/KHR/SPV_KHR_float_controls.html fn emit_extensions(_builder: &mut dr::Builder) { // TODO: re-enable when Intel float control extension works //builder.extension("SPV_INTEL_float_controls2"); } fn emit_opencl_import(builder: &mut dr::Builder) -> spirv::Word { builder.ext_inst_import("OpenCL.std") } fn emit_memory_model(builder: &mut dr::Builder) { builder.memory_model( spirv::AddressingModel::Physical64, spirv::MemoryModel::OpenCL, ); } fn translate_directive<'input>( id_defs: &mut GlobalStringIdResolver<'input>, ptx_impl_imports: &mut HashMap>, d: ast::Directive<'input, ast::ParsedArgParams<'input>>, ) -> Result>, TranslateError> { Ok(match d { ast::Directive::Variable(v) => Some(Directive::Variable(translate_variable(id_defs, v)?)), ast::Directive::Method(f) => { translate_function(id_defs, ptx_impl_imports, f)?.map(Directive::Method) } }) } fn translate_variable<'a>( id_defs: &mut GlobalStringIdResolver<'a>, var: ast::Variable, ) -> Result, TranslateError> { let (space, var_type) = var.v_type.to_type(); let mut is_variable = false; let var_type = match space { ast::StateSpace::Reg => { is_variable = true; var_type } ast::StateSpace::Const => var_type.param_pointer_to(ast::LdStateSpace::Const)?, ast::StateSpace::Global => var_type.param_pointer_to(ast::LdStateSpace::Global)?, ast::StateSpace::Local => var_type.param_pointer_to(ast::LdStateSpace::Local)?, ast::StateSpace::Shared => { // If it's a pointer it will be translated to a method parameter later if let ast::Type::Pointer(..) = var_type { is_variable = true; var_type } else { var_type.param_pointer_to(ast::LdStateSpace::Shared)? 
} } ast::StateSpace::Param => var_type.param_pointer_to(ast::LdStateSpace::Param)?, }; Ok(ast::Variable { align: var.align, v_type: var.v_type, name: id_defs.get_or_add_def_typed(var.name, var_type, is_variable), array_init: var.array_init, }) } fn translate_function<'a>( id_defs: &mut GlobalStringIdResolver<'a>, ptx_impl_imports: &mut HashMap>, f: ast::ParsedFunction<'a>, ) -> Result>, TranslateError> { let import_as = match &f.func_directive { ast::MethodDecl::Func(_, "__assertfail", _) => { Some("__zluda_ptx_impl____assertfail".to_owned()) } _ => None, }; let (str_resolver, fn_resolver, fn_decl) = id_defs.start_fn(&f.func_directive)?; let mut func = to_ssa(ptx_impl_imports, str_resolver, fn_resolver, fn_decl, f.body)?; func.import_as = import_as; if func.import_as.is_some() { ptx_impl_imports.insert( func.import_as.as_ref().unwrap().clone(), Directive::Method(func), ); Ok(None) } else { Ok(Some(func)) } } fn expand_kernel_params<'a, 'b>( fn_resolver: &mut FnStringIdResolver<'a, 'b>, args: impl Iterator>, ) -> Result>, TranslateError> { args.map(|a| { Ok(ast::KernelArgument { name: fn_resolver.add_def( a.name, Some(ast::Type::from(a.v_type.clone()).param_pointer_to(ast::LdStateSpace::Param)?), false, ), v_type: a.v_type.clone(), align: a.align, array_init: Vec::new(), }) }) .collect::>() } fn expand_fn_params<'a, 'b>( fn_resolver: &mut FnStringIdResolver<'a, 'b>, args: impl Iterator>, ) -> Result>, TranslateError> { args.map(|a| { let is_variable = match a.v_type { ast::FnArgumentType::Reg(_) => true, _ => false, }; let var_type = a.v_type.to_func_type(); Ok(ast::FnArgument { name: fn_resolver.add_def(a.name, Some(var_type), is_variable), v_type: a.v_type.clone(), align: a.align, array_init: Vec::new(), }) }) .collect() } fn to_ssa<'input, 'b>( ptx_impl_imports: &mut HashMap, mut id_defs: FnStringIdResolver<'input, 'b>, fn_defs: GlobalFnDeclResolver<'input, 'b>, f_args: ast::MethodDecl<'input, spirv::Word>, f_body: Option>>>, ) -> Result, TranslateError> { let mut spirv_decl = SpirvMethodDecl::new(&f_args); let f_body = match f_body { Some(vec) => vec, None => { return Ok(Function { func_decl: f_args, body: None, globals: Vec::new(), import_as: None, spirv_decl, }) } }; let normalized_ids = normalize_identifiers(&mut id_defs, &fn_defs, f_body)?; let mut numeric_id_defs = id_defs.finish(); let unadorned_statements = normalize_predicates(normalized_ids, &mut numeric_id_defs)?; let typed_statements = convert_to_typed_statements(unadorned_statements, &fn_defs, &mut numeric_id_defs)?; let typed_statements = convert_to_stateful_memory_access(&mut spirv_decl, typed_statements, &mut numeric_id_defs)?; let ssa_statements = insert_mem_ssa_statements( typed_statements, &mut numeric_id_defs, &f_args, &mut spirv_decl, )?; let ssa_statements = fix_builtins(ssa_statements, &mut numeric_id_defs)?; let mut numeric_id_defs = numeric_id_defs.finish(); let expanded_statements = expand_arguments(ssa_statements, &mut numeric_id_defs)?; let expanded_statements = insert_implicit_conversions(expanded_statements, &mut numeric_id_defs)?; let mut numeric_id_defs = numeric_id_defs.unmut(); let labeled_statements = normalize_labels(expanded_statements, &mut numeric_id_defs); let (f_body, globals) = extract_globals(labeled_statements, ptx_impl_imports, &mut numeric_id_defs); Ok(Function { func_decl: f_args, globals: globals, body: Some(f_body), import_as: None, spirv_decl, }) } fn fix_builtins( typed_statements: Vec, numeric_id_defs: &mut NumericIdResolver, ) -> Result, TranslateError> { let mut result = 
Vec::with_capacity(typed_statements.len()); for s in typed_statements { match s { Statement::LoadVar( mut details @ LoadVarDetails { member_index: Some((_, Some(_))), .. }, ) => { let index = details.member_index.unwrap().0; if index == 3 { result.push(Statement::Constant(ConstantDefinition { dst: details.arg.dst, typ: ast::ScalarType::U32, value: ast::ImmediateValue::U64(0), })); } else { let sreg_and_type = match numeric_id_defs.special_registers.get(details.arg.src) { Some(reg) => get_sreg_id_scalar_type(numeric_id_defs, reg), None => None, }; let (sreg_src, scalar_typ, vector_width) = match sreg_and_type { Some(sreg_and_type) => sreg_and_type, None => { result.push(Statement::LoadVar(details)); continue; } }; let temp_id = numeric_id_defs.new_non_variable(Some(details.typ.clone())); let real_dst = details.arg.dst; details.arg.dst = temp_id; result.push(Statement::LoadVar(LoadVarDetails { arg: Arg2 { src: sreg_src, dst: temp_id, }, typ: ast::Type::Scalar(scalar_typ), member_index: Some((index, Some(vector_width))), })); result.push(Statement::Conversion(ImplicitConversion { src: temp_id, dst: real_dst, from: ast::Type::Scalar(scalar_typ), to: ast::Type::Scalar(ast::ScalarType::U32), kind: ConversionKind::Default, src_sema: ArgumentSemantics::Default, dst_sema: ArgumentSemantics::Default, })); } } s => result.push(s), } } Ok(result) } fn get_sreg_id_scalar_type( numeric_id_defs: &mut NumericIdResolver, sreg: PtxSpecialRegister, ) -> Option<(spirv::Word, ast::ScalarType, u8)> { match sreg.normalized_sreg_and_type() { Some((normalized_sreg, typ, vec_width)) => Some(( numeric_id_defs .special_registers .get_or_add(numeric_id_defs.current_id, normalized_sreg), typ, vec_width, )), None => None, } } fn extract_globals<'input, 'b>( sorted_statements: Vec, ptx_impl_imports: &mut HashMap, id_def: &mut NumericIdResolver, ) -> ( Vec, Vec>, ) { let mut local = Vec::with_capacity(sorted_statements.len()); let mut global = Vec::new(); for statement in sorted_statements { match statement { Statement::Variable( var @ ast::Variable { v_type: ast::VariableType::Shared(_), .. }, ) | Statement::Variable( var @ ast::Variable { v_type: ast::VariableType::Global(_), .. }, ) => global.push(var), Statement::Instruction(ast::Instruction::Bfe { typ, arg }) => { local.push(to_ptx_impl_bfe_call(id_def, ptx_impl_imports, typ, arg)); } Statement::Instruction(ast::Instruction::Atom( d @ ast::AtomDetails { inner: ast::AtomInnerDetails::Unsigned { op: ast::AtomUIntOp::Inc, .. }, .. }, a, )) => { local.push(to_ptx_impl_atomic_call( id_def, ptx_impl_imports, d, a, "inc", )); } Statement::Instruction(ast::Instruction::Atom( d @ ast::AtomDetails { inner: ast::AtomInnerDetails::Unsigned { op: ast::AtomUIntOp::Dec, .. }, .. }, a, )) => { local.push(to_ptx_impl_atomic_call( id_def, ptx_impl_imports, d, a, "dec", )); } s => local.push(s), } } (local, global) } fn normalize_variable_decls(directives: &mut Vec) { for directive in directives { match directive { Directive::Method(Function { body: Some(func), .. 
}) => { func[1..].sort_by_key(|s| match s { Statement::Variable(_) => 0, _ => 1, }); } _ => (), } } } fn convert_to_typed_statements( func: Vec, fn_defs: &GlobalFnDeclResolver, id_defs: &mut NumericIdResolver, ) -> Result, TranslateError> { let mut result = Vec::::with_capacity(func.len()); for s in func { match s { Statement::Instruction(inst) => match inst { ast::Instruction::Call(call) => { // TODO: error out if lengths don't match let fn_def = fn_defs.get_fn_decl(call.func)?; let out_args = to_resolved_fn_args(call.ret_params, &*fn_def.ret_vals); let in_args = to_resolved_fn_args(call.param_list, &*fn_def.params); let (out_params, out_non_params): (Vec<_>, Vec<_>) = out_args .into_iter() .partition(|(_, arg_type)| arg_type.is_param()); let normalized_input_args = out_params .into_iter() .map(|(id, typ)| (ast::Operand::Reg(id), typ)) .chain(in_args.into_iter()) .collect(); let resolved_call = ResolvedCall { uniform: call.uniform, ret_params: out_non_params, func: call.func, param_list: normalized_input_args, }; let mut visitor = VectorRepackVisitor::new(&mut result, id_defs); let reresolved_call = resolved_call.visit(&mut visitor)?; visitor.func.push(reresolved_call); visitor.func.extend(visitor.post_stmts); } ast::Instruction::Mov(mut d, ast::Arg2Mov { dst, src }) => { if let Some(src_id) = src.underlying() { let (typ, _) = id_defs.get_typed(*src_id)?; let take_address = match typ { ast::Type::Scalar(_) => false, ast::Type::Vector(_, _) => false, ast::Type::Array(_, _) => true, ast::Type::Pointer(_, _) => true, }; d.src_is_address = take_address; } let mut visitor = VectorRepackVisitor::new(&mut result, id_defs); let instruction = Statement::Instruction( ast::Instruction::Mov(d, ast::Arg2Mov { dst, src }).map(&mut visitor)?, ); visitor.func.push(instruction); visitor.func.extend(visitor.post_stmts); } inst => { let mut visitor = VectorRepackVisitor::new(&mut result, id_defs); let instruction = Statement::Instruction(inst.map(&mut visitor)?); visitor.func.push(instruction); visitor.func.extend(visitor.post_stmts); } }, Statement::Label(i) => result.push(Statement::Label(i)), Statement::Variable(v) => result.push(Statement::Variable(v)), Statement::Conditional(c) => result.push(Statement::Conditional(c)), _ => return Err(error_unreachable()), } } Ok(result) } struct VectorRepackVisitor<'a, 'b> { func: &'b mut Vec, id_def: &'b mut NumericIdResolver<'a>, post_stmts: Option, } impl<'a, 'b> VectorRepackVisitor<'a, 'b> { fn new(func: &'b mut Vec, id_def: &'b mut NumericIdResolver<'a>) -> Self { VectorRepackVisitor { func, id_def, post_stmts: None, } } fn convert_vector( &mut self, is_dst: bool, vector_sema: ArgumentSemantics, typ: &ast::Type, idx: Vec, ) -> Result { // mov.u32 foobar, {a,b}; let scalar_t = match typ { ast::Type::Vector(scalar_t, _) => *scalar_t, _ => return Err(TranslateError::MismatchedType), }; let temp_vec = self.id_def.new_non_variable(Some(typ.clone())); let statement = Statement::RepackVector(RepackVectorDetails { is_extract: is_dst, typ: scalar_t, packed: temp_vec, unpacked: idx, vector_sema, }); if is_dst { self.post_stmts = Some(statement); } else { self.func.push(statement); } Ok(temp_vec) } } impl<'a, 'b> ArgumentMapVisitor for VectorRepackVisitor<'a, 'b> { fn id( &mut self, desc: ArgumentDescriptor, _: Option<&ast::Type>, ) -> Result { Ok(desc.op) } fn operand( &mut self, desc: ArgumentDescriptor>, typ: &ast::Type, ) -> Result { Ok(match desc.op { ast::Operand::Reg(reg) => TypedOperand::Reg(reg), ast::Operand::RegOffset(reg, offset) => TypedOperand::RegOffset(reg, 
offset), ast::Operand::Imm(x) => TypedOperand::Imm(x), ast::Operand::VecMember(vec, idx) => TypedOperand::VecMember(vec, idx), ast::Operand::VecPack(vec) => { TypedOperand::Reg(self.convert_vector(desc.is_dst, desc.sema, typ, vec)?) } }) } } //TODO: share common code between this and to_ptx_impl_bfe_call fn to_ptx_impl_atomic_call( id_defs: &mut NumericIdResolver, ptx_impl_imports: &mut HashMap, details: ast::AtomDetails, arg: ast::Arg3, op: &'static str, ) -> ExpandedStatement { let semantics = ptx_semantics_name(details.semantics); let scope = ptx_scope_name(details.scope); let space = ptx_space_name(details.space); let fn_name = format!( "__zluda_ptx_impl__atom_{}_{}_{}_{}", semantics, scope, space, op ); // TODO: extract to a function let ptr_space = match details.space { ast::AtomSpace::Generic => ast::PointerStateSpace::Generic, ast::AtomSpace::Global => ast::PointerStateSpace::Global, ast::AtomSpace::Shared => ast::PointerStateSpace::Shared, }; let fn_id = match ptx_impl_imports.entry(fn_name) { hash_map::Entry::Vacant(entry) => { let fn_id = id_defs.new_non_variable(None); let func_decl = ast::MethodDecl::Func::( vec![ast::FnArgument { align: None, v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar( ast::ScalarType::U32, )), name: id_defs.new_non_variable(None), array_init: Vec::new(), }], fn_id, vec![ ast::FnArgument { align: None, v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Pointer( ast::SizedScalarType::U32, ptr_space, )), name: id_defs.new_non_variable(None), array_init: Vec::new(), }, ast::FnArgument { align: None, v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar( ast::ScalarType::U32, )), name: id_defs.new_non_variable(None), array_init: Vec::new(), }, ], ); let spirv_decl = SpirvMethodDecl::new(&func_decl); let func = Function { func_decl, globals: Vec::new(), body: None, import_as: Some(entry.key().clone()), spirv_decl, }; entry.insert(Directive::Method(func)); fn_id } hash_map::Entry::Occupied(entry) => match entry.get() { Directive::Method(Function { func_decl: ast::MethodDecl::Func(_, name, _), .. 
}) => *name, _ => unreachable!(), }, }; Statement::Call(ResolvedCall { uniform: false, func: fn_id, ret_params: vec![( arg.dst, ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(ast::ScalarType::U32)), )], param_list: vec![ ( arg.src1, ast::FnArgumentType::Reg(ast::VariableRegType::Pointer( ast::SizedScalarType::U32, ptr_space, )), ), ( arg.src2, ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(ast::ScalarType::U32)), ), ], }) } fn to_ptx_impl_bfe_call( id_defs: &mut NumericIdResolver, ptx_impl_imports: &mut HashMap, typ: ast::IntType, arg: ast::Arg4, ) -> ExpandedStatement { let prefix = "__zluda_ptx_impl__"; let suffix = match typ { ast::IntType::U32 => "bfe_u32", ast::IntType::U64 => "bfe_u64", ast::IntType::S32 => "bfe_s32", ast::IntType::S64 => "bfe_s64", _ => unreachable!(), }; let fn_name = format!("{}{}", prefix, suffix); let fn_id = match ptx_impl_imports.entry(fn_name) { hash_map::Entry::Vacant(entry) => { let fn_id = id_defs.new_non_variable(None); let func_decl = ast::MethodDecl::Func::( vec![ast::FnArgument { align: None, v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())), name: id_defs.new_non_variable(None), array_init: Vec::new(), }], fn_id, vec![ ast::FnArgument { align: None, v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())), name: id_defs.new_non_variable(None), array_init: Vec::new(), }, ast::FnArgument { align: None, v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar( ast::ScalarType::U32, )), name: id_defs.new_non_variable(None), array_init: Vec::new(), }, ast::FnArgument { align: None, v_type: ast::FnArgumentType::Reg(ast::VariableRegType::Scalar( ast::ScalarType::U32, )), name: id_defs.new_non_variable(None), array_init: Vec::new(), }, ], ); let spirv_decl = SpirvMethodDecl::new(&func_decl); let func = Function { func_decl, globals: Vec::new(), body: None, import_as: Some(entry.key().clone()), spirv_decl, }; entry.insert(Directive::Method(func)); fn_id } hash_map::Entry::Occupied(entry) => match entry.get() { Directive::Method(Function { func_decl: ast::MethodDecl::Func(_, name, _), .. }) => *name, _ => unreachable!(), }, }; Statement::Call(ResolvedCall { uniform: false, func: fn_id, ret_params: vec![( arg.dst, ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())), )], param_list: vec![ ( arg.src1, ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(typ.into())), ), ( arg.src2, ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(ast::ScalarType::U32)), ), ( arg.src3, ast::FnArgumentType::Reg(ast::VariableRegType::Scalar(ast::ScalarType::U32)), ), ], }) } fn to_resolved_fn_args( params: Vec, params_decl: &[ast::FnArgumentType], ) -> Vec<(T, ast::FnArgumentType)> { params .into_iter() .zip(params_decl.iter()) .map(|(id, typ)| (id, typ.clone())) .collect::>() } fn normalize_labels( func: Vec, id_def: &mut NumericIdResolver, ) -> Vec { let mut labels_in_use = HashSet::new(); for s in func.iter() { match s { Statement::Instruction(i) => { if let Some(target) = i.jump_target() { labels_in_use.insert(target); } } Statement::Conditional(cond) => { labels_in_use.insert(cond.if_true); labels_in_use.insert(cond.if_false); } Statement::Call(..) | Statement::Variable(..) | Statement::LoadVar(..) | Statement::StoreVar(..) | Statement::RetValue(..) | Statement::Conversion(..) | Statement::Constant(..) | Statement::Label(..) | Statement::PtrAccess { .. } | Statement::RepackVector(..) 
=> {} } } iter::once(Statement::Label(id_def.new_non_variable(None))) .chain(func.into_iter().filter(|s| match s { Statement::Label(i) => labels_in_use.contains(i), _ => true, })) .collect::>() } fn normalize_predicates( func: Vec, id_def: &mut NumericIdResolver, ) -> Result, TranslateError> { let mut result = Vec::with_capacity(func.len()); for s in func { match s { Statement::Label(id) => result.push(Statement::Label(id)), Statement::Instruction((pred, inst)) => { if let Some(pred) = pred { let if_true = id_def.new_non_variable(None); let if_false = id_def.new_non_variable(None); let folded_bra = match &inst { ast::Instruction::Bra(_, arg) => Some(arg.src), _ => None, }; let mut branch = BrachCondition { predicate: pred.label, if_true: folded_bra.unwrap_or(if_true), if_false, }; if pred.not { std::mem::swap(&mut branch.if_true, &mut branch.if_false); } result.push(Statement::Conditional(branch)); if folded_bra.is_none() { result.push(Statement::Label(if_true)); result.push(Statement::Instruction(inst)); } result.push(Statement::Label(if_false)); } else { result.push(Statement::Instruction(inst)); } } Statement::Variable(var) => result.push(Statement::Variable(var)), // Blocks are flattened when resolving ids _ => return Err(error_unreachable()), } } Ok(result) } fn insert_mem_ssa_statements<'a, 'b>( func: Vec, id_def: &mut NumericIdResolver, ast_fn_decl: &'a ast::MethodDecl<'b, spirv::Word>, fn_decl: &mut SpirvMethodDecl, ) -> Result, TranslateError> { let is_func = match ast_fn_decl { ast::MethodDecl::Func(..) => true, ast::MethodDecl::Kernel { .. } => false, }; let mut result = Vec::with_capacity(func.len()); for arg in fn_decl.output.iter() { match type_to_variable_type(&arg.v_type, is_func)? { Some(var_type) => { result.push(Statement::Variable(ast::Variable { align: arg.align, v_type: var_type, name: arg.name, array_init: arg.array_init.clone(), })); } None => return Err(error_unreachable()), } } for spirv_arg in fn_decl.input.iter_mut() { match type_to_variable_type(&spirv_arg.v_type, is_func)? { Some(var_type) => { let typ = spirv_arg.v_type.clone(); let new_id = id_def.new_non_variable(Some(typ.clone())); result.push(Statement::Variable(ast::Variable { align: spirv_arg.align, v_type: var_type, name: spirv_arg.name, array_init: spirv_arg.array_init.clone(), })); result.push(Statement::StoreVar(StoreVarDetails { arg: ast::Arg2St { src1: spirv_arg.name, src2: new_id, }, typ, member_index: None, })); spirv_arg.name = new_id; } None => {} } } for s in func { match s { Statement::Call(call) => { insert_mem_ssa_statement_default(id_def, &mut result, call.cast())? 
} Statement::Instruction(inst) => match inst { ast::Instruction::Ret(d) => { // TODO: handle multiple output args if let &[out_param] = &fn_decl.output.as_slice() { let (typ, _) = id_def.get_typed(out_param.name)?; let new_id = id_def.new_non_variable(Some(typ.clone())); result.push(Statement::LoadVar(LoadVarDetails { arg: ast::Arg2 { dst: new_id, src: out_param.name, }, typ: typ.clone(), member_index: None, })); result.push(Statement::RetValue(d, new_id)); } else { result.push(Statement::Instruction(ast::Instruction::Ret(d))) } } inst => insert_mem_ssa_statement_default(id_def, &mut result, inst)?, }, Statement::Conditional(mut bra) => { let generated_id = id_def.new_non_variable(Some(ast::Type::Scalar(ast::ScalarType::Pred))); result.push(Statement::LoadVar(LoadVarDetails { arg: Arg2 { dst: generated_id, src: bra.predicate, }, typ: ast::Type::Scalar(ast::ScalarType::Pred), member_index: None, })); bra.predicate = generated_id; result.push(Statement::Conditional(bra)); } Statement::Conversion(conv) => { insert_mem_ssa_statement_default(id_def, &mut result, conv)? } Statement::PtrAccess(ptr_access) => { insert_mem_ssa_statement_default(id_def, &mut result, ptr_access)? } Statement::RepackVector(repack) => { insert_mem_ssa_statement_default(id_def, &mut result, repack)? } s @ Statement::Variable(_) | s @ Statement::Label(_) => result.push(s), _ => return Err(error_unreachable()), } } Ok(result) } fn type_to_variable_type( t: &ast::Type, is_func: bool, ) -> Result, TranslateError> { Ok(match t { ast::Type::Scalar(typ) => Some(ast::VariableType::Reg(ast::VariableRegType::Scalar(*typ))), ast::Type::Vector(typ, len) => Some(ast::VariableType::Reg(ast::VariableRegType::Vector( (*typ) .try_into() .map_err(|_| TranslateError::MismatchedType)?, *len, ))), ast::Type::Array(typ, len) => Some(ast::VariableType::Reg(ast::VariableRegType::Array( (*typ) .try_into() .map_err(|_| TranslateError::MismatchedType)?, len.clone(), ))), ast::Type::Pointer(ast::PointerType::Scalar(scalar_type), space) => { if is_func { return Ok(None); } Some(ast::VariableType::Reg(ast::VariableRegType::Pointer( scalar_type .clone() .try_into() .map_err(|_| error_unreachable())?, (*space).try_into().map_err(|_| error_unreachable())?, ))) } ast::Type::Pointer(_, ast::LdStateSpace::Shared) => None, _ => return Err(error_unreachable()), }) } trait Visitable: Sized { fn visit( self, visitor: &mut impl ArgumentMapVisitor, ) -> Result, To>, TranslateError>; } struct VisitArgumentDescriptor< 'a, Ctor: FnOnce(spirv::Word) -> Statement, U>, U: ArgParamsEx, > { desc: ArgumentDescriptor, typ: &'a ast::Type, stmt_ctor: Ctor, } impl< 'a, Ctor: FnOnce(spirv::Word) -> Statement, U>, T: ArgParamsEx, U: ArgParamsEx, > Visitable for VisitArgumentDescriptor<'a, Ctor, U> { fn visit( self, visitor: &mut impl ArgumentMapVisitor, ) -> Result, U>, TranslateError> { Ok((self.stmt_ctor)(visitor.id(self.desc, Some(self.typ))?)) } } struct InsertMemSSAVisitor<'a, 'input> { id_def: &'a mut NumericIdResolver<'input>, func: &'a mut Vec, post_statements: Vec, } impl<'a, 'input> InsertMemSSAVisitor<'a, 'input> { fn symbol( &mut self, desc: ArgumentDescriptor<(spirv::Word, Option)>, expected_type: Option<&ast::Type>, ) -> Result { let symbol = desc.op.0; if expected_type.is_none() { return Ok(symbol); }; let (mut var_type, is_variable) = self.id_def.get_typed(symbol)?; if !is_variable { return Ok(symbol); }; let member_index = match desc.op.1 { Some(idx) => { let vector_width = match var_type { ast::Type::Vector(scalar_t, width) => { var_type = 
ast::Type::Scalar(scalar_t); width } _ => return Err(TranslateError::MismatchedType), }; Some(( idx, if self.id_def.special_registers.get(symbol).is_some() { Some(vector_width) } else { None }, )) } None => None, }; let generated_id = self.id_def.new_non_variable(Some(var_type.clone())); if !desc.is_dst { self.func.push(Statement::LoadVar(LoadVarDetails { arg: Arg2 { dst: generated_id, src: symbol, }, typ: var_type, member_index, })); } else { self.post_statements .push(Statement::StoreVar(StoreVarDetails { arg: Arg2St { src1: symbol, src2: generated_id, }, typ: var_type, member_index: member_index.map(|(idx, _)| idx), })); } Ok(generated_id) } } impl<'a, 'input> ArgumentMapVisitor for InsertMemSSAVisitor<'a, 'input> { fn id( &mut self, desc: ArgumentDescriptor, typ: Option<&ast::Type>, ) -> Result { self.symbol(desc.new_op((desc.op, None)), typ) } fn operand( &mut self, desc: ArgumentDescriptor, typ: &ast::Type, ) -> Result { Ok(match desc.op { TypedOperand::Reg(reg) => { TypedOperand::Reg(self.symbol(desc.new_op((reg, None)), Some(typ))?) } TypedOperand::RegOffset(reg, offset) => { TypedOperand::RegOffset(self.symbol(desc.new_op((reg, None)), Some(typ))?, offset) } op @ TypedOperand::Imm(..) => op, TypedOperand::VecMember(symbol, index) => { TypedOperand::Reg(self.symbol(desc.new_op((symbol, Some(index))), Some(typ))?) } }) } } fn insert_mem_ssa_statement_default<'a, 'input, S: Visitable>( id_def: &'a mut NumericIdResolver<'input>, func: &'a mut Vec, stmt: S, ) -> Result<(), TranslateError> { let mut visitor = InsertMemSSAVisitor { id_def, func, post_statements: Vec::new(), }; let new_stmt = stmt.visit(&mut visitor)?; visitor.func.push(new_stmt); visitor.func.extend(visitor.post_statements); Ok(()) } fn expand_arguments<'a, 'b>( func: Vec, id_def: &'b mut MutableNumericIdResolver<'a>, ) -> Result, TranslateError> { let mut result = Vec::with_capacity(func.len()); for s in func { match s { Statement::Call(call) => { let mut visitor = FlattenArguments::new(&mut result, id_def); let (new_call, post_stmts) = (call.map(&mut visitor)?, visitor.post_stmts); result.push(Statement::Call(new_call)); result.extend(post_stmts); } Statement::Instruction(inst) => { let mut visitor = FlattenArguments::new(&mut result, id_def); let (new_inst, post_stmts) = (inst.map(&mut visitor)?, visitor.post_stmts); result.push(Statement::Instruction(new_inst)); result.extend(post_stmts); } Statement::Variable(ast::Variable { align, v_type, name, array_init, }) => result.push(Statement::Variable(ast::Variable { align, v_type, name, array_init, })), Statement::PtrAccess(ptr_access) => { let mut visitor = FlattenArguments::new(&mut result, id_def); let (new_inst, post_stmts) = (ptr_access.map(&mut visitor)?, visitor.post_stmts); result.push(Statement::PtrAccess(new_inst)); result.extend(post_stmts); } Statement::RepackVector(repack) => { let mut visitor = FlattenArguments::new(&mut result, id_def); let (new_inst, post_stmts) = (repack.map(&mut visitor)?, visitor.post_stmts); result.push(Statement::RepackVector(new_inst)); result.extend(post_stmts); } Statement::Label(id) => result.push(Statement::Label(id)), Statement::Conditional(bra) => result.push(Statement::Conditional(bra)), Statement::LoadVar(details) => result.push(Statement::LoadVar(details)), Statement::StoreVar(details) => result.push(Statement::StoreVar(details)), Statement::RetValue(d, id) => result.push(Statement::RetValue(d, id)), Statement::Conversion(conv) => result.push(Statement::Conversion(conv)), Statement::Constant(_) => return 
Err(error_unreachable()), } } Ok(result) } struct FlattenArguments<'a, 'b> { func: &'b mut Vec, id_def: &'b mut MutableNumericIdResolver<'a>, post_stmts: Vec, } impl<'a, 'b> FlattenArguments<'a, 'b> { fn new( func: &'b mut Vec, id_def: &'b mut MutableNumericIdResolver<'a>, ) -> Self { FlattenArguments { func, id_def, post_stmts: Vec::new(), } } fn reg( &mut self, desc: ArgumentDescriptor, _: Option<&ast::Type>, ) -> Result { Ok(desc.op) } fn reg_offset( &mut self, desc: ArgumentDescriptor<(spirv::Word, i32)>, typ: &ast::Type, ) -> Result { let (reg, offset) = desc.op; let add_type; match typ { ast::Type::Pointer(underlying_type, state_space) => { let reg_typ = self.id_def.get_typed(reg)?; if let ast::Type::Pointer(_, _) = reg_typ { let id_constant_stmt = self.id_def.new_non_variable(typ.clone()); self.func.push(Statement::Constant(ConstantDefinition { dst: id_constant_stmt, typ: ast::ScalarType::S64, value: ast::ImmediateValue::S64(offset as i64), })); let dst = self.id_def.new_non_variable(typ.clone()); self.func.push(Statement::PtrAccess(PtrAccess { underlying_type: underlying_type.clone(), state_space: *state_space, dst, ptr_src: reg, offset_src: id_constant_stmt, })); return Ok(dst); } else { add_type = self.id_def.get_typed(reg)?; } } _ => { add_type = typ.clone(); } }; let (width, kind) = match add_type { ast::Type::Scalar(scalar_t) => { let kind = match scalar_t.kind() { kind @ ScalarKind::Bit | kind @ ScalarKind::Unsigned | kind @ ScalarKind::Signed => kind, ScalarKind::Float => return Err(TranslateError::MismatchedType), ScalarKind::Float2 => return Err(TranslateError::MismatchedType), ScalarKind::Pred => return Err(TranslateError::MismatchedType), }; (scalar_t.size_of(), kind) } _ => return Err(TranslateError::MismatchedType), }; let arith_detail = if kind == ScalarKind::Signed { ast::ArithDetails::Signed(ast::ArithSInt { typ: ast::SIntType::from_size(width), saturate: false, }) } else { ast::ArithDetails::Unsigned(ast::UIntType::from_size(width)) }; let id_constant_stmt = self.id_def.new_non_variable(add_type.clone()); let result_id = self.id_def.new_non_variable(add_type); // TODO: check for edge cases around min value/max value/wrapping if offset < 0 && kind != ScalarKind::Signed { self.func.push(Statement::Constant(ConstantDefinition { dst: id_constant_stmt, typ: ast::ScalarType::from_parts(width, kind), value: ast::ImmediateValue::U64(-(offset as i64) as u64), })); self.func.push(Statement::Instruction( ast::Instruction::::Sub( arith_detail, ast::Arg3 { dst: result_id, src1: reg, src2: id_constant_stmt, }, ), )); } else { self.func.push(Statement::Constant(ConstantDefinition { dst: id_constant_stmt, typ: ast::ScalarType::from_parts(width, kind), value: ast::ImmediateValue::S64(offset as i64), })); self.func.push(Statement::Instruction( ast::Instruction::::Add( arith_detail, ast::Arg3 { dst: result_id, src1: reg, src2: id_constant_stmt, }, ), )); } Ok(result_id) } fn immediate( &mut self, desc: ArgumentDescriptor, typ: &ast::Type, ) -> Result { let scalar_t = if let ast::Type::Scalar(scalar) = typ { *scalar } else { todo!() }; let id = self.id_def.new_non_variable(ast::Type::Scalar(scalar_t)); self.func.push(Statement::Constant(ConstantDefinition { dst: id, typ: scalar_t, value: desc.op, })); Ok(id) } } impl<'a, 'b> ArgumentMapVisitor for FlattenArguments<'a, 'b> { fn id( &mut self, desc: ArgumentDescriptor, t: Option<&ast::Type>, ) -> Result { self.reg(desc, t) } fn operand( &mut self, desc: ArgumentDescriptor, typ: &ast::Type, ) -> Result { match desc.op { 
            TypedOperand::Reg(r) => self.reg(desc.new_op(r), Some(typ)),
            TypedOperand::Imm(x) => self.immediate(desc.new_op(x), typ),
            TypedOperand::RegOffset(reg, offset) => {
                self.reg_offset(desc.new_op((reg, offset)), typ)
            }
            TypedOperand::VecMember(..) => Err(error_unreachable()),
        }
    }
}

/*
 There are several kinds of implicit conversions in PTX:
 * auto-bitcast: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#type-information-for-instructions-and-operands
 * special ld/st/cvt conversion rules: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size
   - ld.param: not documented, but for an instruction `ld.param.<type> x, [y]`, the semantics are to
     first zext/chop/bitcast `y` as needed and then apply the documented special ld/st/cvt
     conversion rules to the destination operand
   - st.param [x], y (used for function return arguments): the same rule as above applies
   - generic/global ld: for an instruction `ld x, [y]`, y must be of type b64/u64/s64; it is bitcast
     to a pointer, dereferenced, and then the documented special ld/st/cvt conversion rules are
     applied to dst
   - generic/global st: for an instruction `st [x], y`, x must be of type b64/u64/s64, which is
     bitcast to a pointer
*/
fn insert_implicit_conversions(
    func: Vec,
    id_def: &mut MutableNumericIdResolver,
) -> Result, TranslateError> {
    let mut result = Vec::with_capacity(func.len());
    for s in func.into_iter() {
        match s {
            Statement::Call(call) => insert_implicit_conversions_impl(
                &mut result,
                id_def,
                call,
                should_bitcast_wrapper,
                None,
            )?,
            Statement::Instruction(inst) => {
                let mut default_conversion_fn =
                    should_bitcast_wrapper as for<'a> fn(&'a ast::Type, &'a ast::Type, _) -> _;
                let mut state_space = None;
                if let ast::Instruction::Ld(d, _) = &inst {
                    state_space = Some(d.state_space);
                }
                if let ast::Instruction::St(d, _) = &inst {
                    state_space = Some(d.state_space.to_ld_ss());
                }
                if let ast::Instruction::Atom(d, _) = &inst {
                    state_space = Some(d.space.to_ld_ss());
                }
                if let ast::Instruction::AtomCas(d, _) = &inst {
                    state_space = Some(d.space.to_ld_ss());
                }
                if let ast::Instruction::Mov(..) = &inst {
                    default_conversion_fn = should_bitcast_packed;
                }
                insert_implicit_conversions_impl(
                    &mut result,
                    id_def,
                    inst,
                    default_conversion_fn,
                    state_space,
                )?;
            }
            Statement::PtrAccess(PtrAccess {
                underlying_type,
                state_space,
                dst,
                ptr_src,
                offset_src: constant_src,
            }) => {
                let visit_desc = VisitArgumentDescriptor {
                    desc: ArgumentDescriptor {
                        op: ptr_src,
                        is_dst: false,
                        sema: ArgumentSemantics::PhysicalPointer,
                    },
                    typ: &ast::Type::Pointer(underlying_type.clone(), state_space),
                    stmt_ctor: |new_ptr_src| {
                        Statement::PtrAccess(PtrAccess {
                            underlying_type,
                            state_space,
                            dst,
                            ptr_src: new_ptr_src,
                            offset_src: constant_src,
                        })
                    },
                };
                insert_implicit_conversions_impl(
                    &mut result,
                    id_def,
                    visit_desc,
                    bitcast_physical_pointer,
                    Some(state_space),
                )?;
            }
            Statement::RepackVector(repack) => insert_implicit_conversions_impl(
                &mut result,
                id_def,
                repack,
                should_bitcast_wrapper,
                None,
            )?,
            s @ Statement::Conditional(_)
            | s @ Statement::Conversion(_)
            | s @ Statement::Label(_)
            | s @ Statement::Constant(_)
            | s @ Statement::Variable(_)
            | s @ Statement::LoadVar(..)
            | s @ Statement::StoreVar(..)
| s @ Statement::RetValue(_, _) => result.push(s), } } Ok(result) } fn insert_implicit_conversions_impl( func: &mut Vec, id_def: &mut MutableNumericIdResolver, stmt: impl Visitable, default_conversion_fn: for<'a> fn( &'a ast::Type, &'a ast::Type, Option, ) -> Result, TranslateError>, state_space: Option, ) -> Result<(), TranslateError> { let mut post_conv = Vec::new(); let statement = stmt.visit( &mut |desc: ArgumentDescriptor, typ: Option<&ast::Type>| { let instr_type = match typ { None => return Ok(desc.op), Some(t) => t, }; let operand_type = id_def.get_typed(desc.op)?; let mut conversion_fn = default_conversion_fn; match desc.sema { ArgumentSemantics::Default => {} ArgumentSemantics::DefaultRelaxed => { if desc.is_dst { conversion_fn = should_convert_relaxed_dst_wrapper; } else { conversion_fn = should_convert_relaxed_src_wrapper; } } ArgumentSemantics::PhysicalPointer => { conversion_fn = bitcast_physical_pointer; } ArgumentSemantics::RegisterPointer => { conversion_fn = bitcast_register_pointer; } ArgumentSemantics::Address => { conversion_fn = force_bitcast_ptr_to_bit; } }; match conversion_fn(&operand_type, instr_type, state_space)? { Some(conv_kind) => { let conv_output = if desc.is_dst { &mut post_conv } else { &mut *func }; let mut from = instr_type.clone(); let mut to = operand_type; let mut src = id_def.new_non_variable(instr_type.clone()); let mut dst = desc.op; let result = Ok(src); if !desc.is_dst { mem::swap(&mut src, &mut dst); mem::swap(&mut from, &mut to); } conv_output.push(Statement::Conversion(ImplicitConversion { src, dst, from, to, kind: conv_kind, src_sema: ArgumentSemantics::Default, dst_sema: ArgumentSemantics::Default, })); result } None => Ok(desc.op), } }, )?; func.push(statement); func.append(&mut post_conv); Ok(()) } fn get_function_type( builder: &mut dr::Builder, map: &mut TypeWordMap, spirv_input: &[ast::Variable], spirv_output: &[ast::Variable], ) -> (spirv::Word, spirv::Word) { map.get_or_add_fn( builder, spirv_input .iter() .map(|var| SpirvType::from(var.v_type.clone())), spirv_output .iter() .map(|var| SpirvType::from(var.v_type.clone())), ) } fn emit_function_body_ops( builder: &mut dr::Builder, map: &mut TypeWordMap, opencl: spirv::Word, func: &[ExpandedStatement], ) -> Result<(), TranslateError> { for s in func { match s { Statement::Label(id) => { if builder.block.is_some() { builder.branch(*id)?; } builder.begin_block(Some(*id))?; } _ => { if builder.block.is_none() && builder.function.is_some() { builder.begin_block(None)?; } } } match s { Statement::Label(_) => (), Statement::Call(call) => { let (result_type, result_id) = match &*call.ret_params { [(id, typ)] => ( map.get_or_add(builder, SpirvType::from(typ.to_func_type())), Some(*id), ), [] => (map.void(), None), _ => todo!(), }; let arg_list = call .param_list .iter() .map(|(id, _)| *id) .collect::>(); builder.function_call(result_type, result_id, call.func, arg_list)?; } Statement::Variable(var) => { emit_variable(builder, map, var)?; } Statement::Constant(cnst) => { let typ_id = map.get_or_add_scalar(builder, cnst.typ); match (cnst.typ, cnst.value) { (ast::ScalarType::B8, ast::ImmediateValue::U64(value)) | (ast::ScalarType::U8, ast::ImmediateValue::U64(value)) => { builder.constant_u32(typ_id, Some(cnst.dst), value as u8 as u32); } (ast::ScalarType::B16, ast::ImmediateValue::U64(value)) | (ast::ScalarType::U16, ast::ImmediateValue::U64(value)) => { builder.constant_u32(typ_id, Some(cnst.dst), value as u16 as u32); } (ast::ScalarType::B32, ast::ImmediateValue::U64(value)) | 
(ast::ScalarType::U32, ast::ImmediateValue::U64(value)) => { builder.constant_u32(typ_id, Some(cnst.dst), value as u32); } (ast::ScalarType::B64, ast::ImmediateValue::U64(value)) | (ast::ScalarType::U64, ast::ImmediateValue::U64(value)) => { builder.constant_u64(typ_id, Some(cnst.dst), value); } (ast::ScalarType::S8, ast::ImmediateValue::U64(value)) => { builder.constant_u32(typ_id, Some(cnst.dst), value as i8 as u32); } (ast::ScalarType::S16, ast::ImmediateValue::U64(value)) => { builder.constant_u32(typ_id, Some(cnst.dst), value as i16 as u32); } (ast::ScalarType::S32, ast::ImmediateValue::U64(value)) => { builder.constant_u32(typ_id, Some(cnst.dst), value as i32 as u32); } (ast::ScalarType::S64, ast::ImmediateValue::U64(value)) => { builder.constant_u64(typ_id, Some(cnst.dst), value as i64 as u64); } (ast::ScalarType::B8, ast::ImmediateValue::S64(value)) | (ast::ScalarType::U8, ast::ImmediateValue::S64(value)) => { builder.constant_u32(typ_id, Some(cnst.dst), value as u8 as u32); } (ast::ScalarType::B16, ast::ImmediateValue::S64(value)) | (ast::ScalarType::U16, ast::ImmediateValue::S64(value)) => { builder.constant_u32(typ_id, Some(cnst.dst), value as u16 as u32); } (ast::ScalarType::B32, ast::ImmediateValue::S64(value)) | (ast::ScalarType::U32, ast::ImmediateValue::S64(value)) => { builder.constant_u32(typ_id, Some(cnst.dst), value as u32); } (ast::ScalarType::B64, ast::ImmediateValue::S64(value)) | (ast::ScalarType::U64, ast::ImmediateValue::S64(value)) => { builder.constant_u64(typ_id, Some(cnst.dst), value as u64); } (ast::ScalarType::S8, ast::ImmediateValue::S64(value)) => { builder.constant_u32(typ_id, Some(cnst.dst), value as i8 as u32); } (ast::ScalarType::S16, ast::ImmediateValue::S64(value)) => { builder.constant_u32(typ_id, Some(cnst.dst), value as i16 as u32); } (ast::ScalarType::S32, ast::ImmediateValue::S64(value)) => { builder.constant_u32(typ_id, Some(cnst.dst), value as i32 as u32); } (ast::ScalarType::S64, ast::ImmediateValue::S64(value)) => { builder.constant_u64(typ_id, Some(cnst.dst), value as u64); } (ast::ScalarType::F16, ast::ImmediateValue::F32(value)) => { builder.constant_f32(typ_id, Some(cnst.dst), f16::from_f32(value).to_f32()); } (ast::ScalarType::F32, ast::ImmediateValue::F32(value)) => { builder.constant_f32(typ_id, Some(cnst.dst), value); } (ast::ScalarType::F64, ast::ImmediateValue::F32(value)) => { builder.constant_f64(typ_id, Some(cnst.dst), value as f64); } (ast::ScalarType::F16, ast::ImmediateValue::F64(value)) => { builder.constant_f32(typ_id, Some(cnst.dst), f16::from_f64(value).to_f32()); } (ast::ScalarType::F32, ast::ImmediateValue::F64(value)) => { builder.constant_f32(typ_id, Some(cnst.dst), value as f32); } (ast::ScalarType::F64, ast::ImmediateValue::F64(value)) => { builder.constant_f64(typ_id, Some(cnst.dst), value); } (ast::ScalarType::Pred, ast::ImmediateValue::U64(value)) => { let bool_type = map.get_or_add_scalar(builder, ast::ScalarType::Pred); if value == 0 { builder.constant_false(bool_type, Some(cnst.dst)); } else { builder.constant_true(bool_type, Some(cnst.dst)); } } (ast::ScalarType::Pred, ast::ImmediateValue::S64(value)) => { let bool_type = map.get_or_add_scalar(builder, ast::ScalarType::Pred); if value == 0 { builder.constant_false(bool_type, Some(cnst.dst)); } else { builder.constant_true(bool_type, Some(cnst.dst)); } } _ => return Err(TranslateError::MismatchedType), } } Statement::Conversion(cv) => emit_implicit_conversion(builder, map, cv)?, Statement::Conditional(bra) => { builder.branch_conditional(bra.predicate, 
bra.if_true, bra.if_false, [])?; } Statement::Instruction(inst) => match inst { ast::Instruction::Abs(d, arg) => emit_abs(builder, map, opencl, d, arg)?, ast::Instruction::Call(_) => unreachable!(), // SPIR-V does not support marking jumps as guaranteed-converged ast::Instruction::Bra(_, arg) => { builder.branch(arg.src)?; } ast::Instruction::Ld(data, arg) => { if data.qualifier != ast::LdStQualifier::Weak { todo!() } let result_type = map.get_or_add(builder, SpirvType::from(ast::Type::from(data.typ.clone()))); builder.load( result_type, Some(arg.dst), arg.src, Some(spirv::MemoryAccess::ALIGNED), [dr::Operand::LiteralInt32( ast::Type::from(data.typ.clone()).size_of() as u32, )], )?; } ast::Instruction::St(data, arg) => { if data.qualifier != ast::LdStQualifier::Weak { todo!() } builder.store( arg.src1, arg.src2, Some(spirv::MemoryAccess::ALIGNED), [dr::Operand::LiteralInt32( ast::Type::from(data.typ.clone()).size_of() as u32, )], )?; } // SPIR-V does not support ret as guaranteed-converged ast::Instruction::Ret(_) => builder.ret()?, ast::Instruction::Mov(d, arg) => { let result_type = map.get_or_add(builder, SpirvType::from(ast::Type::from(d.typ.clone()))); builder.copy_object(result_type, Some(arg.dst), arg.src)?; } ast::Instruction::Mul(mul, arg) => match mul { ast::MulDetails::Signed(ref ctr) => { emit_mul_sint(builder, map, opencl, ctr, arg)? } ast::MulDetails::Unsigned(ref ctr) => { emit_mul_uint(builder, map, opencl, ctr, arg)? } ast::MulDetails::Float(ref ctr) => emit_mul_float(builder, map, ctr, arg)?, }, ast::Instruction::Add(add, arg) => match add { ast::ArithDetails::Signed(ref desc) => { emit_add_int(builder, map, desc.typ.into(), desc.saturate, arg)? } ast::ArithDetails::Unsigned(ref desc) => { emit_add_int(builder, map, (*desc).into(), false, arg)? } ast::ArithDetails::Float(desc) => emit_add_float(builder, map, desc, arg)?, }, ast::Instruction::Setp(setp, arg) => { if arg.dst2.is_some() { todo!() } emit_setp(builder, map, setp, arg)?; } ast::Instruction::Not(t, a) => { let result_type = map.get_or_add(builder, SpirvType::from(t.to_type())); let result_id = Some(a.dst); let operand = a.src; match t { ast::BooleanType::Pred => { // HACK ALERT // Temporary workaround until IGC gets its shit together // Currently IGC carries two copies of SPIRV-LLVM translator // a new one in /llvm-spirv/ and old one in /IGC/AdaptorOCL/SPIRV/. 
                        // Obviously, the old and buggy one is used for compiling L0 SPIRV
                        // https://github.com/intel/intel-graphics-compiler/issues/148
                        let type_pred = map.get_or_add_scalar(builder, ast::ScalarType::Pred);
                        let const_true = builder.constant_true(type_pred, None);
                        let const_false = builder.constant_false(type_pred, None);
                        builder.select(result_type, result_id, operand, const_false, const_true)
                    }
                    _ => builder.not(result_type, result_id, operand),
                }?;
            }
            ast::Instruction::Shl(t, a) => {
                let full_type = t.to_type();
                let size_of = full_type.size_of();
                let result_type = map.get_or_add(builder, SpirvType::from(full_type));
                let offset_src = insert_shift_hack(builder, map, a.src2, size_of)?;
                builder.shift_left_logical(result_type, Some(a.dst), a.src1, offset_src)?;
            }
            ast::Instruction::Shr(t, a) => {
                let full_type = ast::ScalarType::from(*t);
                let size_of = full_type.size_of();
                let result_type = map.get_or_add_scalar(builder, full_type);
                let offset_src = insert_shift_hack(builder, map, a.src2, size_of as usize)?;
                if t.signed() {
                    builder.shift_right_arithmetic(result_type, Some(a.dst), a.src1, offset_src)?;
                } else {
                    builder.shift_right_logical(result_type, Some(a.dst), a.src1, offset_src)?;
                }
            }
            ast::Instruction::Cvt(dets, arg) => {
                emit_cvt(builder, map, opencl, dets, arg)?;
            }
            ast::Instruction::Cvta(_, arg) => {
                // This would only be meaningful if const/slm/global pointers had a different
                // format than generic pointers, but by the PTX definition they pretty much don't.
                // Honestly, I have no idea why this instruction exists and is emitted by the compiler
                let result_type = map.get_or_add_scalar(builder, ast::ScalarType::B64);
                builder.copy_object(result_type, Some(arg.dst), arg.src)?;
            }
            ast::Instruction::SetpBool(_, _) => todo!(),
            ast::Instruction::Mad(mad, arg) => match mad {
                ast::MulDetails::Signed(ref desc) => {
                    emit_mad_sint(builder, map, opencl, desc, arg)?
                }
                ast::MulDetails::Unsigned(ref desc) => {
                    emit_mad_uint(builder, map, opencl, desc, arg)?
                }
                ast::MulDetails::Float(desc) => {
                    emit_mad_float(builder, map, opencl, desc, arg)?
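                    // Note: the float path maps onto OpenCL `mad`, whose precision is
                    // implementation-defined, and unlike the float mul/add helpers it
                    // emits no rounding decoration.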
} }, ast::Instruction::Or(t, a) => { let result_type = map.get_or_add_scalar(builder, ast::ScalarType::from(*t)); if *t == ast::BooleanType::Pred { builder.logical_or(result_type, Some(a.dst), a.src1, a.src2)?; } else { builder.bitwise_or(result_type, Some(a.dst), a.src1, a.src2)?; } } ast::Instruction::Sub(d, arg) => match d { ast::ArithDetails::Signed(desc) => { emit_sub_int(builder, map, desc.typ.into(), desc.saturate, arg)?; } ast::ArithDetails::Unsigned(desc) => { emit_sub_int(builder, map, (*desc).into(), false, arg)?; } ast::ArithDetails::Float(desc) => { emit_sub_float(builder, map, desc, arg)?; } }, ast::Instruction::Min(d, a) => { emit_min(builder, map, opencl, d, a)?; } ast::Instruction::Max(d, a) => { emit_max(builder, map, opencl, d, a)?; } ast::Instruction::Rcp(d, a) => { emit_rcp(builder, map, d, a)?; } ast::Instruction::And(t, a) => { let result_type = map.get_or_add_scalar(builder, ast::ScalarType::from(*t)); if *t == ast::BooleanType::Pred { builder.logical_and(result_type, Some(a.dst), a.src1, a.src2)?; } else { builder.bitwise_and(result_type, Some(a.dst), a.src1, a.src2)?; } } ast::Instruction::Selp(t, a) => { let result_type = map.get_or_add_scalar(builder, ast::ScalarType::from(*t)); builder.select(result_type, Some(a.dst), a.src3, a.src1, a.src2)?; } // TODO: implement named barriers ast::Instruction::Bar(d, _) => { let workgroup_scope = map.get_or_add_constant( builder, &ast::Type::Scalar(ast::ScalarType::U32), &vec_repr(spirv::Scope::Workgroup as u32), )?; let barrier_semantics = match d { ast::BarDetails::SyncAligned => map.get_or_add_constant( builder, &ast::Type::Scalar(ast::ScalarType::U32), &vec_repr( spirv::MemorySemantics::CROSS_WORKGROUP_MEMORY | spirv::MemorySemantics::WORKGROUP_MEMORY | spirv::MemorySemantics::SEQUENTIALLY_CONSISTENT, ), )?, }; builder.control_barrier(workgroup_scope, workgroup_scope, barrier_semantics)?; } ast::Instruction::Atom(details, arg) => { emit_atom(builder, map, details, arg)?; } ast::Instruction::AtomCas(details, arg) => { let result_type = map.get_or_add_scalar(builder, details.typ.into()); let memory_const = map.get_or_add_constant( builder, &ast::Type::Scalar(ast::ScalarType::U32), &vec_repr(details.scope.to_spirv() as u32), )?; let semantics_const = map.get_or_add_constant( builder, &ast::Type::Scalar(ast::ScalarType::U32), &vec_repr(details.semantics.to_spirv().bits()), )?; builder.atomic_compare_exchange( result_type, Some(arg.dst), arg.src1, memory_const, semantics_const, semantics_const, arg.src3, arg.src2, )?; } ast::Instruction::Div(details, arg) => match details { ast::DivDetails::Unsigned(t) => { let result_type = map.get_or_add_scalar(builder, (*t).into()); builder.u_div(result_type, Some(arg.dst), arg.src1, arg.src2)?; } ast::DivDetails::Signed(t) => { let result_type = map.get_or_add_scalar(builder, (*t).into()); builder.s_div(result_type, Some(arg.dst), arg.src1, arg.src2)?; } ast::DivDetails::Float(t) => { let result_type = map.get_or_add_scalar(builder, t.typ.into()); builder.f_div(result_type, Some(arg.dst), arg.src1, arg.src2)?; emit_float_div_decoration(builder, arg.dst, t.kind); } }, ast::Instruction::Sqrt(details, a) => { emit_sqrt(builder, map, opencl, details, a)?; } ast::Instruction::Rsqrt(details, a) => { let result_type = map.get_or_add_scalar(builder, details.typ.into()); builder.ext_inst( result_type, Some(a.dst), opencl, spirv::CLOp::native_rsqrt as spirv::Word, &[a.src], )?; } ast::Instruction::Neg(details, arg) => { let result_type = map.get_or_add_scalar(builder, details.typ); let negate_func 
= if details.typ.kind() == ScalarKind::Float { dr::Builder::f_negate } else { dr::Builder::s_negate }; negate_func(builder, result_type, Some(arg.dst), arg.src)?; } ast::Instruction::Sin { arg, .. } => { let result_type = map.get_or_add_scalar(builder, ast::ScalarType::F32); builder.ext_inst( result_type, Some(arg.dst), opencl, spirv::CLOp::sin as u32, [arg.src], )?; } ast::Instruction::Cos { arg, .. } => { let result_type = map.get_or_add_scalar(builder, ast::ScalarType::F32); builder.ext_inst( result_type, Some(arg.dst), opencl, spirv::CLOp::cos as u32, [arg.src], )?; } ast::Instruction::Lg2 { arg, .. } => { let result_type = map.get_or_add_scalar(builder, ast::ScalarType::F32); builder.ext_inst( result_type, Some(arg.dst), opencl, spirv::CLOp::log2 as u32, [arg.src], )?; } ast::Instruction::Ex2 { arg, .. } => { let result_type = map.get_or_add_scalar(builder, ast::ScalarType::F32); builder.ext_inst( result_type, Some(arg.dst), opencl, spirv::CLOp::exp2 as u32, [arg.src], )?; } ast::Instruction::Clz { typ, arg } => { let result_type = map.get_or_add_scalar(builder, (*typ).into()); builder.ext_inst( result_type, Some(arg.dst), opencl, spirv::CLOp::clz as u32, [arg.src], )?; } ast::Instruction::Brev { typ, arg } => { let result_type = map.get_or_add_scalar(builder, (*typ).into()); builder.bit_reverse(result_type, Some(arg.dst), arg.src)?; } ast::Instruction::Popc { typ, arg } => { let result_type = map.get_or_add_scalar(builder, (*typ).into()); builder.bit_count(result_type, Some(arg.dst), arg.src)?; } ast::Instruction::Xor { typ, arg } => { let builder_fn = match typ { ast::BooleanType::Pred => emit_logical_xor_spirv, _ => dr::Builder::bitwise_xor, }; let result_type = map.get_or_add_scalar(builder, (*typ).into()); builder_fn(builder, result_type, Some(arg.dst), arg.src1, arg.src2)?; } ast::Instruction::Bfe { typ, arg } => { let builder_fn = if typ.is_signed() { dr::Builder::bit_field_s_extract } else { dr::Builder::bit_field_u_extract }; let result_type = map.get_or_add_scalar(builder, (*typ).into()); builder_fn( builder, result_type, Some(arg.dst), arg.src1, arg.src2, arg.src3, )?; } ast::Instruction::Rem { typ, arg } => { let builder_fn = if typ.is_signed() { dr::Builder::s_mod } else { dr::Builder::u_mod }; let result_type = map.get_or_add_scalar(builder, (*typ).into()); builder_fn(builder, result_type, Some(arg.dst), arg.src1, arg.src2)?; } }, Statement::LoadVar(details) => { emit_load_var(builder, map, details)?; } Statement::StoreVar(details) => { let dst_ptr = match details.member_index { Some(index) => { let result_ptr_type = map.get_or_add( builder, SpirvType::new_pointer( details.typ.clone(), spirv::StorageClass::Function, ), ); let index_spirv = map.get_or_add_constant( builder, &ast::Type::Scalar(ast::ScalarType::U32), &vec_repr(index as u32), )?; builder.in_bounds_access_chain( result_ptr_type, None, details.arg.src1, &[index_spirv], )? 
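                        // (a store into a single vector member goes through an access chain into
                        // the underlying variable; whole-variable stores below just use the
                        // variable pointer directly)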
} None => details.arg.src1, }; builder.store(dst_ptr, details.arg.src2, None, [])?; } Statement::RetValue(_, id) => { builder.ret_value(*id)?; } Statement::PtrAccess(PtrAccess { underlying_type, state_space, dst, ptr_src, offset_src, }) => { let u8_pointer = map.get_or_add( builder, SpirvType::from(ast::Type::Pointer( ast::PointerType::Scalar(ast::ScalarType::U8), *state_space, )), ); let result_type = map.get_or_add( builder, SpirvType::from(ast::Type::Pointer(underlying_type.clone(), *state_space)), ); let ptr_src_u8 = builder.bitcast(u8_pointer, None, *ptr_src)?; let temp = builder.in_bounds_ptr_access_chain( u8_pointer, None, ptr_src_u8, *offset_src, &[], )?; builder.bitcast(result_type, Some(*dst), temp)?; } Statement::RepackVector(repack) => { if repack.is_extract { let scalar_type = map.get_or_add_scalar(builder, repack.typ); for (index, dst_id) in repack.unpacked.iter().enumerate() { builder.composite_extract( scalar_type, Some(*dst_id), repack.packed, &[index as u32], )?; } } else { let vector_type = map.get_or_add( builder, SpirvType::Vector( SpirvScalarKey::from(repack.typ), repack.unpacked.len() as u8, ), ); let mut temp_vec = builder.undef(vector_type, None); for (index, src_id) in repack.unpacked.iter().enumerate() { temp_vec = builder.composite_insert( vector_type, None, *src_id, temp_vec, &[index as u32], )?; } builder.copy_object(vector_type, Some(repack.packed), temp_vec)?; } } } } Ok(()) } // HACK ALERT // For some reason IGC fails linking if the value and shift size are of different type fn insert_shift_hack( builder: &mut dr::Builder, map: &mut TypeWordMap, offset_var: spirv::Word, size_of: usize, ) -> Result { let result_type = match size_of { 2 => map.get_or_add_scalar(builder, ast::ScalarType::B16), 8 => map.get_or_add_scalar(builder, ast::ScalarType::B64), 4 => return Ok(offset_var), _ => return Err(error_unreachable()), }; Ok(builder.u_convert(result_type, None, offset_var)?) 
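    // PTX shift amounts are 32-bit values; the OpUConvert above narrows the amount to 16 bits
    // or widens it to 64 bits to match the width of the shifted value, which is what IGC
    // appears to require for successful linking.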
} // TODO: check what kind of assembly do we emit fn emit_logical_xor_spirv( builder: &mut dr::Builder, result_type: spirv::Word, result_id: Option, op1: spirv::Word, op2: spirv::Word, ) -> Result { let temp_or = builder.logical_or(result_type, None, op1, op2)?; let temp_and = builder.logical_and(result_type, None, op1, op2)?; let temp_neg = builder.logical_not(result_type, None, temp_and)?; builder.logical_and(result_type, result_id, temp_or, temp_neg) } fn emit_sqrt( builder: &mut dr::Builder, map: &mut TypeWordMap, opencl: spirv::Word, details: &ast::SqrtDetails, a: &ast::Arg2, ) -> Result<(), TranslateError> { let result_type = map.get_or_add_scalar(builder, details.typ.into()); let (ocl_op, rounding) = match details.kind { ast::SqrtKind::Approx => (spirv::CLOp::native_sqrt, None), ast::SqrtKind::Rounding(rnd) => (spirv::CLOp::sqrt, Some(rnd)), }; builder.ext_inst( result_type, Some(a.dst), opencl, ocl_op as spirv::Word, &[a.src], )?; emit_rounding_decoration(builder, a.dst, rounding); Ok(()) } fn emit_float_div_decoration(builder: &mut dr::Builder, dst: spirv::Word, kind: ast::DivFloatKind) { match kind { ast::DivFloatKind::Approx => { builder.decorate( dst, spirv::Decoration::FPFastMathMode, &[dr::Operand::FPFastMathMode( spirv::FPFastMathMode::ALLOW_RECIP, )], ); } ast::DivFloatKind::Rounding(rnd) => { emit_rounding_decoration(builder, dst, Some(rnd)); } ast::DivFloatKind::Full => {} } } fn emit_atom( builder: &mut dr::Builder, map: &mut TypeWordMap, details: &ast::AtomDetails, arg: &ast::Arg3, ) -> Result<(), TranslateError> { let (spirv_op, typ) = match details.inner { ast::AtomInnerDetails::Bit { op, typ } => { let spirv_op = match op { ast::AtomBitOp::And => dr::Builder::atomic_and, ast::AtomBitOp::Or => dr::Builder::atomic_or, ast::AtomBitOp::Xor => dr::Builder::atomic_xor, ast::AtomBitOp::Exchange => dr::Builder::atomic_exchange, }; (spirv_op, ast::ScalarType::from(typ)) } ast::AtomInnerDetails::Unsigned { op, typ } => { let spirv_op = match op { ast::AtomUIntOp::Add => dr::Builder::atomic_i_add, ast::AtomUIntOp::Inc | ast::AtomUIntOp::Dec => { return Err(error_unreachable()); } ast::AtomUIntOp::Min => dr::Builder::atomic_u_min, ast::AtomUIntOp::Max => dr::Builder::atomic_u_max, }; (spirv_op, typ.into()) } ast::AtomInnerDetails::Signed { op, typ } => { let spirv_op = match op { ast::AtomSIntOp::Add => dr::Builder::atomic_i_add, ast::AtomSIntOp::Min => dr::Builder::atomic_s_min, ast::AtomSIntOp::Max => dr::Builder::atomic_s_max, }; (spirv_op, typ.into()) } // TODO: Hardware is capable of this, implement it through builtin ast::AtomInnerDetails::Float { .. 
} => todo!(), }; let result_type = map.get_or_add_scalar(builder, typ); let memory_const = map.get_or_add_constant( builder, &ast::Type::Scalar(ast::ScalarType::U32), &vec_repr(details.scope.to_spirv() as u32), )?; let semantics_const = map.get_or_add_constant( builder, &ast::Type::Scalar(ast::ScalarType::U32), &vec_repr(details.semantics.to_spirv().bits()), )?; spirv_op( builder, result_type, Some(arg.dst), arg.src1, memory_const, semantics_const, arg.src2, )?; Ok(()) } #[derive(Clone)] struct PtxImplImport { out_arg: ast::Type, fn_id: u32, in_args: Vec, } fn ptx_semantics_name(sema: ast::AtomSemantics) -> &'static str { match sema { ast::AtomSemantics::Relaxed => "relaxed", ast::AtomSemantics::Acquire => "acquire", ast::AtomSemantics::Release => "release", ast::AtomSemantics::AcquireRelease => "acq_rel", } } fn ptx_scope_name(scope: ast::MemScope) -> &'static str { match scope { ast::MemScope::Cta => "cta", ast::MemScope::Gpu => "gpu", ast::MemScope::Sys => "sys", } } fn ptx_space_name(space: ast::AtomSpace) -> &'static str { match space { ast::AtomSpace::Generic => "generic", ast::AtomSpace::Global => "global", ast::AtomSpace::Shared => "shared", } } fn emit_mul_float( builder: &mut dr::Builder, map: &mut TypeWordMap, ctr: &ast::ArithFloat, arg: &ast::Arg3, ) -> Result<(), dr::Error> { if ctr.saturate { todo!() } let result_type = map.get_or_add_scalar(builder, ctr.typ.into()); builder.f_mul(result_type, Some(arg.dst), arg.src1, arg.src2)?; emit_rounding_decoration(builder, arg.dst, ctr.rounding); Ok(()) } fn emit_rcp( builder: &mut dr::Builder, map: &mut TypeWordMap, desc: &ast::RcpDetails, a: &ast::Arg2, ) -> Result<(), TranslateError> { let (instr_type, constant) = if desc.is_f64 { (ast::ScalarType::F64, vec_repr(1.0f64)) } else { (ast::ScalarType::F32, vec_repr(1.0f32)) }; let one = map.get_or_add_constant(builder, &ast::Type::Scalar(instr_type), &constant)?; let result_type = map.get_or_add_scalar(builder, instr_type); builder.f_div(result_type, Some(a.dst), one, a.src)?; emit_rounding_decoration(builder, a.dst, desc.rounding); builder.decorate( a.dst, spirv::Decoration::FPFastMathMode, &[dr::Operand::FPFastMathMode( spirv::FPFastMathMode::ALLOW_RECIP, )], ); Ok(()) } fn vec_repr(t: T) -> Vec { let mut result = vec![0; mem::size_of::()]; unsafe { std::ptr::copy_nonoverlapping(&t, result.as_mut_ptr() as *mut _, 1) }; result } fn emit_variable( builder: &mut dr::Builder, map: &mut TypeWordMap, var: &ast::Variable, ) -> Result<(), TranslateError> { let (must_init, st_class) = match var.v_type { ast::VariableType::Reg(_) | ast::VariableType::Param(_) | ast::VariableType::Local(_) => { (false, spirv::StorageClass::Function) } ast::VariableType::Global(_) => (true, spirv::StorageClass::CrossWorkgroup), ast::VariableType::Shared(_) => (false, spirv::StorageClass::Workgroup), }; let initalizer = if var.array_init.len() > 0 { Some(map.get_or_add_constant( builder, &ast::Type::from(var.v_type.clone()), &*var.array_init, )?) 
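    // Globals with no explicit array_init still get a null-constant initializer below
    // (must_init); other variables without an array_init are left uninitialized.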
} else if must_init { let type_id = map.get_or_add( builder, SpirvType::from(ast::Type::from(var.v_type.clone())), ); Some(builder.constant_null(type_id, None)) } else { None }; let ptr_type_id = map.get_or_add( builder, SpirvType::new_pointer(ast::Type::from(var.v_type.clone()), st_class), ); builder.variable(ptr_type_id, Some(var.name), st_class, initalizer); if let Some(align) = var.align { builder.decorate( var.name, spirv::Decoration::Alignment, &[dr::Operand::LiteralInt32(align)], ); } Ok(()) } fn emit_mad_uint( builder: &mut dr::Builder, map: &mut TypeWordMap, opencl: spirv::Word, desc: &ast::MulUInt, arg: &ast::Arg4, ) -> Result<(), dr::Error> { let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ))); match desc.control { ast::MulIntControl::Low => { let mul_result = builder.i_mul(inst_type, None, arg.src1, arg.src2)?; builder.i_add(inst_type, Some(arg.dst), arg.src3, mul_result)?; } ast::MulIntControl::High => { builder.ext_inst( inst_type, Some(arg.dst), opencl, spirv::CLOp::u_mad_hi as spirv::Word, [arg.src1, arg.src2, arg.src3], )?; } ast::MulIntControl::Wide => todo!(), }; Ok(()) } fn emit_mad_sint( builder: &mut dr::Builder, map: &mut TypeWordMap, opencl: spirv::Word, desc: &ast::MulSInt, arg: &ast::Arg4, ) -> Result<(), dr::Error> { let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ))); match desc.control { ast::MulIntControl::Low => { let mul_result = builder.i_mul(inst_type, None, arg.src1, arg.src2)?; builder.i_add(inst_type, Some(arg.dst), arg.src3, mul_result)?; } ast::MulIntControl::High => { builder.ext_inst( inst_type, Some(arg.dst), opencl, spirv::CLOp::s_mad_hi as spirv::Word, [arg.src1, arg.src2, arg.src3], )?; } ast::MulIntControl::Wide => todo!(), }; Ok(()) } fn emit_mad_float( builder: &mut dr::Builder, map: &mut TypeWordMap, opencl: spirv::Word, desc: &ast::ArithFloat, arg: &ast::Arg4, ) -> Result<(), dr::Error> { let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ))); builder.ext_inst( inst_type, Some(arg.dst), opencl, spirv::CLOp::mad as spirv::Word, [arg.src1, arg.src2, arg.src3], )?; Ok(()) } fn emit_add_float( builder: &mut dr::Builder, map: &mut TypeWordMap, desc: &ast::ArithFloat, arg: &ast::Arg3, ) -> Result<(), dr::Error> { let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ))); builder.f_add(inst_type, Some(arg.dst), arg.src1, arg.src2)?; emit_rounding_decoration(builder, arg.dst, desc.rounding); Ok(()) } fn emit_sub_float( builder: &mut dr::Builder, map: &mut TypeWordMap, desc: &ast::ArithFloat, arg: &ast::Arg3, ) -> Result<(), dr::Error> { let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ))); builder.f_sub(inst_type, Some(arg.dst), arg.src1, arg.src2)?; emit_rounding_decoration(builder, arg.dst, desc.rounding); Ok(()) } fn emit_min( builder: &mut dr::Builder, map: &mut TypeWordMap, opencl: spirv::Word, desc: &ast::MinMaxDetails, arg: &ast::Arg3, ) -> Result<(), dr::Error> { let cl_op = match desc { ast::MinMaxDetails::Signed(_) => spirv::CLOp::s_min, ast::MinMaxDetails::Unsigned(_) => spirv::CLOp::u_min, ast::MinMaxDetails::Float(_) => spirv::CLOp::fmin, }; let inst_type = map.get_or_add(builder, SpirvType::from(desc.get_type())); builder.ext_inst( inst_type, Some(arg.dst), opencl, cl_op as spirv::Word, [arg.src1, arg.src2], )?; Ok(()) } fn emit_max( builder: &mut dr::Builder, map: &mut TypeWordMap, opencl: spirv::Word, desc: &ast::MinMaxDetails, arg: &ast::Arg3, ) -> 
Result<(), dr::Error> { let cl_op = match desc { ast::MinMaxDetails::Signed(_) => spirv::CLOp::s_max, ast::MinMaxDetails::Unsigned(_) => spirv::CLOp::u_max, ast::MinMaxDetails::Float(_) => spirv::CLOp::fmax, }; let inst_type = map.get_or_add(builder, SpirvType::from(desc.get_type())); builder.ext_inst( inst_type, Some(arg.dst), opencl, cl_op as spirv::Word, [arg.src1, arg.src2], )?; Ok(()) } fn emit_cvt( builder: &mut dr::Builder, map: &mut TypeWordMap, opencl: spirv::Word, dets: &ast::CvtDetails, arg: &ast::Arg2, ) -> Result<(), TranslateError> { match dets { ast::CvtDetails::FloatFromFloat(desc) => { if desc.saturate { todo!() } let dest_t: ast::ScalarType = desc.dst.into(); let result_type = map.get_or_add(builder, SpirvType::from(dest_t)); if desc.dst == desc.src { match desc.rounding { Some(ast::RoundingMode::NearestEven) => { builder.ext_inst( result_type, Some(arg.dst), opencl, spirv::CLOp::rint as u32, [arg.src], )?; } Some(ast::RoundingMode::Zero) => { builder.ext_inst( result_type, Some(arg.dst), opencl, spirv::CLOp::trunc as u32, [arg.src], )?; } Some(ast::RoundingMode::NegativeInf) => { builder.ext_inst( result_type, Some(arg.dst), opencl, spirv::CLOp::floor as u32, [arg.src], )?; } Some(ast::RoundingMode::PositiveInf) => { builder.ext_inst( result_type, Some(arg.dst), opencl, spirv::CLOp::ceil as u32, [arg.src], )?; } None => { builder.copy_object(result_type, Some(arg.dst), arg.src)?; } } } else { builder.f_convert(result_type, Some(arg.dst), arg.src)?; emit_rounding_decoration(builder, arg.dst, desc.rounding); } } ast::CvtDetails::FloatFromInt(desc) => { if desc.saturate { todo!() } let dest_t: ast::ScalarType = desc.dst.into(); let result_type = map.get_or_add(builder, SpirvType::from(dest_t)); if desc.src.is_signed() { builder.convert_s_to_f(result_type, Some(arg.dst), arg.src)?; } else { builder.convert_u_to_f(result_type, Some(arg.dst), arg.src)?; } emit_rounding_decoration(builder, arg.dst, desc.rounding); } ast::CvtDetails::IntFromFloat(desc) => { let dest_t: ast::ScalarType = desc.dst.into(); let result_type = map.get_or_add(builder, SpirvType::from(dest_t)); if desc.dst.is_signed() { builder.convert_f_to_s(result_type, Some(arg.dst), arg.src)?; } else { builder.convert_f_to_u(result_type, Some(arg.dst), arg.src)?; } emit_rounding_decoration(builder, arg.dst, desc.rounding); emit_saturating_decoration(builder, arg.dst, desc.saturate); } ast::CvtDetails::IntFromInt(desc) => { let dest_t: ast::ScalarType = desc.dst.into(); let src_t: ast::ScalarType = desc.src.into(); // first do shortening/widening let src = if desc.dst.width() != desc.src.width() { let new_dst = if dest_t.kind() == src_t.kind() { arg.dst } else { builder.id() }; let cv = ImplicitConversion { src: arg.src, dst: new_dst, from: ast::Type::Scalar(src_t), to: ast::Type::Scalar(ast::ScalarType::from_parts( dest_t.size_of(), src_t.kind(), )), kind: ConversionKind::Default, src_sema: ArgumentSemantics::Default, dst_sema: ArgumentSemantics::Default, }; emit_implicit_conversion(builder, map, &cv)?; new_dst } else { arg.src }; if dest_t.kind() == src_t.kind() { return Ok(()); } // now do actual conversion let result_type = map.get_or_add(builder, SpirvType::from(dest_t)); if desc.saturate { if desc.dst.is_signed() { builder.sat_convert_u_to_s(result_type, Some(arg.dst), src)?; } else { builder.sat_convert_s_to_u(result_type, Some(arg.dst), src)?; } } else { builder.bitcast(result_type, Some(arg.dst), src)?; } } } Ok(()) } fn emit_saturating_decoration(builder: &mut dr::Builder, dst: u32, saturate: bool) { if 
saturate { builder.decorate(dst, spirv::Decoration::SaturatedConversion, []); } } fn emit_rounding_decoration( builder: &mut dr::Builder, dst: spirv::Word, rounding: Option, ) { if let Some(rounding) = rounding { builder.decorate( dst, spirv::Decoration::FPRoundingMode, [rounding.to_spirv()], ); } } impl ast::RoundingMode { fn to_spirv(self) -> rspirv::dr::Operand { let mode = match self { ast::RoundingMode::NearestEven => spirv::FPRoundingMode::RTE, ast::RoundingMode::Zero => spirv::FPRoundingMode::RTZ, ast::RoundingMode::PositiveInf => spirv::FPRoundingMode::RTP, ast::RoundingMode::NegativeInf => spirv::FPRoundingMode::RTN, }; rspirv::dr::Operand::FPRoundingMode(mode) } } fn emit_setp( builder: &mut dr::Builder, map: &mut TypeWordMap, setp: &ast::SetpData, arg: &ast::Arg4Setp, ) -> Result<(), dr::Error> { let result_type = map.get_or_add(builder, SpirvType::Base(SpirvScalarKey::Pred)); let result_id = Some(arg.dst1); let operand_1 = arg.src1; let operand_2 = arg.src2; match (setp.cmp_op, setp.typ.kind()) { (ast::SetpCompareOp::Eq, ScalarKind::Signed) | (ast::SetpCompareOp::Eq, ScalarKind::Unsigned) | (ast::SetpCompareOp::Eq, ScalarKind::Bit) => { builder.i_equal(result_type, result_id, operand_1, operand_2) } (ast::SetpCompareOp::Eq, ScalarKind::Float) => { builder.f_ord_equal(result_type, result_id, operand_1, operand_2) } (ast::SetpCompareOp::NotEq, ScalarKind::Signed) | (ast::SetpCompareOp::NotEq, ScalarKind::Unsigned) | (ast::SetpCompareOp::NotEq, ScalarKind::Bit) => { builder.i_not_equal(result_type, result_id, operand_1, operand_2) } (ast::SetpCompareOp::NotEq, ScalarKind::Float) => { builder.f_ord_not_equal(result_type, result_id, operand_1, operand_2) } (ast::SetpCompareOp::Less, ScalarKind::Unsigned) | (ast::SetpCompareOp::Less, ScalarKind::Bit) => { builder.u_less_than(result_type, result_id, operand_1, operand_2) } (ast::SetpCompareOp::Less, ScalarKind::Signed) => { builder.s_less_than(result_type, result_id, operand_1, operand_2) } (ast::SetpCompareOp::Less, ScalarKind::Float) => { builder.f_ord_less_than(result_type, result_id, operand_1, operand_2) } (ast::SetpCompareOp::LessOrEq, ScalarKind::Unsigned) | (ast::SetpCompareOp::LessOrEq, ScalarKind::Bit) => { builder.u_less_than_equal(result_type, result_id, operand_1, operand_2) } (ast::SetpCompareOp::LessOrEq, ScalarKind::Signed) => { builder.s_less_than_equal(result_type, result_id, operand_1, operand_2) } (ast::SetpCompareOp::LessOrEq, ScalarKind::Float) => { builder.f_ord_less_than_equal(result_type, result_id, operand_1, operand_2) } (ast::SetpCompareOp::Greater, ScalarKind::Unsigned) | (ast::SetpCompareOp::Greater, ScalarKind::Bit) => { builder.u_greater_than(result_type, result_id, operand_1, operand_2) } (ast::SetpCompareOp::Greater, ScalarKind::Signed) => { builder.s_greater_than(result_type, result_id, operand_1, operand_2) } (ast::SetpCompareOp::Greater, ScalarKind::Float) => { builder.f_ord_greater_than(result_type, result_id, operand_1, operand_2) } (ast::SetpCompareOp::GreaterOrEq, ScalarKind::Unsigned) | (ast::SetpCompareOp::GreaterOrEq, ScalarKind::Bit) => { builder.u_greater_than_equal(result_type, result_id, operand_1, operand_2) } (ast::SetpCompareOp::GreaterOrEq, ScalarKind::Signed) => { builder.s_greater_than_equal(result_type, result_id, operand_1, operand_2) } (ast::SetpCompareOp::GreaterOrEq, ScalarKind::Float) => { builder.f_ord_greater_than_equal(result_type, result_id, operand_1, operand_2) } (ast::SetpCompareOp::NanEq, _) => { builder.f_unord_equal(result_type, result_id, operand_1, operand_2) } 
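        // The Nan* comparison variants accept the comparison when either operand is NaN,
        // hence the unordered (OpFUnord*) SPIR-V opcodes used for them.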
(ast::SetpCompareOp::NanNotEq, _) => { builder.f_unord_not_equal(result_type, result_id, operand_1, operand_2) } (ast::SetpCompareOp::NanLess, _) => { builder.f_unord_less_than(result_type, result_id, operand_1, operand_2) } (ast::SetpCompareOp::NanLessOrEq, _) => { builder.f_unord_less_than_equal(result_type, result_id, operand_1, operand_2) } (ast::SetpCompareOp::NanGreater, _) => { builder.f_unord_greater_than(result_type, result_id, operand_1, operand_2) } (ast::SetpCompareOp::NanGreaterOrEq, _) => { builder.f_unord_greater_than_equal(result_type, result_id, operand_1, operand_2) } _ => todo!(), }?; Ok(()) } fn emit_mul_sint( builder: &mut dr::Builder, map: &mut TypeWordMap, opencl: spirv::Word, desc: &ast::MulSInt, arg: &ast::Arg3, ) -> Result<(), dr::Error> { let instruction_type = ast::ScalarType::from(desc.typ); let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ))); match desc.control { ast::MulIntControl::Low => { builder.i_mul(inst_type, Some(arg.dst), arg.src1, arg.src2)?; } ast::MulIntControl::High => { builder.ext_inst( inst_type, Some(arg.dst), opencl, spirv::CLOp::s_mul_hi as spirv::Word, [arg.src1, arg.src2], )?; } ast::MulIntControl::Wide => { let mul_ext_type = SpirvType::Struct(vec![ SpirvScalarKey::from(instruction_type), SpirvScalarKey::from(instruction_type), ]); let mul_ext_type_id = map.get_or_add(builder, mul_ext_type); let mul = builder.s_mul_extended(mul_ext_type_id, None, arg.src1, arg.src2)?; let instr_width = instruction_type.size_of(); let instr_kind = instruction_type.kind(); let dst_type = ast::ScalarType::from_parts(instr_width * 2, instr_kind); let dst_type_id = map.get_or_add_scalar(builder, dst_type); struct2_bitcast_to_wide( builder, map, SpirvScalarKey::from(instruction_type), inst_type, arg.dst, dst_type_id, mul, )?; } } Ok(()) } fn emit_mul_uint( builder: &mut dr::Builder, map: &mut TypeWordMap, opencl: spirv::Word, desc: &ast::MulUInt, arg: &ast::Arg3, ) -> Result<(), dr::Error> { let instruction_type = ast::ScalarType::from(desc.typ); let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.typ))); match desc.control { ast::MulIntControl::Low => { builder.i_mul(inst_type, Some(arg.dst), arg.src1, arg.src2)?; } ast::MulIntControl::High => { builder.ext_inst( inst_type, Some(arg.dst), opencl, spirv::CLOp::u_mul_hi as spirv::Word, [arg.src1, arg.src2], )?; } ast::MulIntControl::Wide => { let mul_ext_type = SpirvType::Struct(vec![ SpirvScalarKey::from(instruction_type), SpirvScalarKey::from(instruction_type), ]); let mul_ext_type_id = map.get_or_add(builder, mul_ext_type); let mul = builder.u_mul_extended(mul_ext_type_id, None, arg.src1, arg.src2)?; let instr_width = instruction_type.size_of(); let instr_kind = instruction_type.kind(); let dst_type = ast::ScalarType::from_parts(instr_width * 2, instr_kind); let dst_type_id = map.get_or_add_scalar(builder, dst_type); struct2_bitcast_to_wide( builder, map, SpirvScalarKey::from(instruction_type), inst_type, arg.dst, dst_type_id, mul, )?; } } Ok(()) } // Surprisingly, structs can't be bitcast, so we route everything through a vector fn struct2_bitcast_to_wide( builder: &mut dr::Builder, map: &mut TypeWordMap, base_type_key: SpirvScalarKey, instruction_type: spirv::Word, dst: spirv::Word, dst_type_id: spirv::Word, src: spirv::Word, ) -> Result<(), dr::Error> { let low_bits = builder.composite_extract(instruction_type, None, src, [0])?; let high_bits = builder.composite_extract(instruction_type, None, src, [1])?; let vector_type = 
map.get_or_add(builder, SpirvType::Vector(base_type_key, 2)); let vector = builder.composite_construct(vector_type, None, [low_bits, high_bits])?; builder.bitcast(dst_type_id, Some(dst), vector)?; Ok(()) } fn emit_abs( builder: &mut dr::Builder, map: &mut TypeWordMap, opencl: spirv::Word, d: &ast::AbsDetails, arg: &ast::Arg2, ) -> Result<(), dr::Error> { let scalar_t = ast::ScalarType::from(d.typ); let result_type = map.get_or_add(builder, SpirvType::from(scalar_t)); let cl_abs = if scalar_t.kind() == ScalarKind::Signed { spirv::CLOp::s_abs } else { spirv::CLOp::fabs }; builder.ext_inst( result_type, Some(arg.dst), opencl, cl_abs as spirv::Word, [arg.src], )?; Ok(()) } fn emit_add_int( builder: &mut dr::Builder, map: &mut TypeWordMap, typ: ast::ScalarType, saturate: bool, arg: &ast::Arg3, ) -> Result<(), dr::Error> { if saturate { todo!() } let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(typ))); builder.i_add(inst_type, Some(arg.dst), arg.src1, arg.src2)?; Ok(()) } fn emit_sub_int( builder: &mut dr::Builder, map: &mut TypeWordMap, typ: ast::ScalarType, saturate: bool, arg: &ast::Arg3, ) -> Result<(), dr::Error> { if saturate { todo!() } let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(typ))); builder.i_sub(inst_type, Some(arg.dst), arg.src1, arg.src2)?; Ok(()) } fn emit_implicit_conversion( builder: &mut dr::Builder, map: &mut TypeWordMap, cv: &ImplicitConversion, ) -> Result<(), TranslateError> { let from_parts = cv.from.to_parts(); let to_parts = cv.to.to_parts(); match (from_parts.kind, to_parts.kind, cv.kind) { (_, _, ConversionKind::PtrToBit(typ)) => { let dst_type = map.get_or_add_scalar(builder, typ.into()); builder.convert_ptr_to_u(dst_type, Some(cv.dst), cv.src)?; } (_, _, ConversionKind::BitToPtr(_)) => { let dst_type = map.get_or_add(builder, SpirvType::from(cv.to.clone())); builder.convert_u_to_ptr(dst_type, Some(cv.dst), cv.src)?; } (TypeKind::Scalar, TypeKind::Scalar, ConversionKind::Default) => { if from_parts.width == to_parts.width { let dst_type = map.get_or_add(builder, SpirvType::from(cv.to.clone())); if from_parts.scalar_kind != ScalarKind::Float && to_parts.scalar_kind != ScalarKind::Float { // It is noop, but another instruction expects result of this conversion builder.copy_object(dst_type, Some(cv.dst), cv.src)?; } else { builder.bitcast(dst_type, Some(cv.dst), cv.src)?; } } else { // This block is safe because it's illegal to implictly convert between floating point instructions let same_width_bit_type = map.get_or_add( builder, SpirvType::from(ast::Type::from_parts(TypeParts { scalar_kind: ScalarKind::Bit, ..from_parts })), ); let same_width_bit_value = builder.bitcast(same_width_bit_type, None, cv.src)?; let wide_bit_type = ast::Type::from_parts(TypeParts { scalar_kind: ScalarKind::Bit, ..to_parts }); let wide_bit_type_spirv = map.get_or_add(builder, SpirvType::from(wide_bit_type.clone())); if to_parts.scalar_kind == ScalarKind::Unsigned || to_parts.scalar_kind == ScalarKind::Bit { builder.u_convert(wide_bit_type_spirv, Some(cv.dst), same_width_bit_value)?; } else { let conversion_fn = if from_parts.scalar_kind == ScalarKind::Signed && to_parts.scalar_kind == ScalarKind::Signed { dr::Builder::s_convert } else { dr::Builder::u_convert }; let wide_bit_value = conversion_fn(builder, wide_bit_type_spirv, None, same_width_bit_value)?; emit_implicit_conversion( builder, map, &ImplicitConversion { src: wide_bit_value, dst: cv.dst, from: wide_bit_type, to: cv.to.clone(), kind: ConversionKind::Default, src_sema: 
cv.src_sema, dst_sema: cv.dst_sema, }, )?; } } } (TypeKind::Scalar, TypeKind::Scalar, ConversionKind::SignExtend) => todo!(), (TypeKind::Vector, TypeKind::Scalar, ConversionKind::Default) | (TypeKind::Scalar, TypeKind::Array, ConversionKind::Default) | (TypeKind::Array, TypeKind::Scalar, ConversionKind::Default) => { let into_type = map.get_or_add(builder, SpirvType::from(cv.to.clone())); builder.bitcast(into_type, Some(cv.dst), cv.src)?; } (_, _, ConversionKind::PtrToPtr { spirv_ptr }) => { let result_type = if spirv_ptr { map.get_or_add( builder, SpirvType::Pointer( Box::new(SpirvType::from(cv.to.clone())), spirv::StorageClass::Function, ), ) } else { map.get_or_add(builder, SpirvType::from(cv.to.clone())) }; builder.bitcast(result_type, Some(cv.dst), cv.src)?; } _ => unreachable!(), } Ok(()) } fn emit_load_var( builder: &mut dr::Builder, map: &mut TypeWordMap, details: &LoadVarDetails, ) -> Result<(), TranslateError> { let result_type = map.get_or_add(builder, SpirvType::from(details.typ.clone())); match details.member_index { Some((index, Some(width))) => { let vector_type = match details.typ { ast::Type::Scalar(scalar_t) => ast::Type::Vector(scalar_t, width), _ => return Err(TranslateError::MismatchedType), }; let vector_type_spirv = map.get_or_add(builder, SpirvType::from(vector_type)); let vector_temp = builder.load(vector_type_spirv, None, details.arg.src, None, [])?; builder.composite_extract( result_type, Some(details.arg.dst), vector_temp, &[index as u32], )?; } Some((index, None)) => { let result_ptr_type = map.get_or_add( builder, SpirvType::new_pointer(details.typ.clone(), spirv::StorageClass::Function), ); let index_spirv = map.get_or_add_constant( builder, &ast::Type::Scalar(ast::ScalarType::U32), &vec_repr(index as u32), )?; let src = builder.in_bounds_access_chain( result_ptr_type, None, details.arg.src, &[index_spirv], )?; builder.load(result_type, Some(details.arg.dst), src, None, [])?; } None => { builder.load( result_type, Some(details.arg.dst), details.arg.src, None, [], )?; } }; Ok(()) } fn normalize_identifiers<'a, 'b>( id_defs: &mut FnStringIdResolver<'a, 'b>, fn_defs: &GlobalFnDeclResolver<'a, 'b>, func: Vec>>, ) -> Result, TranslateError> { for s in func.iter() { match s { ast::Statement::Label(id) => { id_defs.add_def(*id, None, false); } _ => (), } } let mut result = Vec::new(); for s in func { expand_map_variables(id_defs, fn_defs, &mut result, s)?; } Ok(result) } fn expand_map_variables<'a, 'b>( id_defs: &mut FnStringIdResolver<'a, 'b>, fn_defs: &GlobalFnDeclResolver<'a, 'b>, result: &mut Vec, s: ast::Statement>, ) -> Result<(), TranslateError> { match s { ast::Statement::Block(block) => { id_defs.start_block(); for s in block { expand_map_variables(id_defs, fn_defs, result, s)?; } id_defs.end_block(); } ast::Statement::Label(name) => result.push(Statement::Label(id_defs.get_id(name)?)), ast::Statement::Instruction(p, i) => result.push(Statement::Instruction(( p.map(|p| p.map_variable(&mut |id| id_defs.get_id(id))) .transpose()?, i.map_variable(&mut |id| id_defs.get_id(id))?, ))), ast::Statement::Variable(var) => { let mut var_type = ast::Type::from(var.var.v_type.clone()); let mut is_variable = false; var_type = match var.var.v_type { ast::VariableType::Reg(_) => { is_variable = true; var_type } ast::VariableType::Shared(_) => { // If it's a pointer it will be translated to a method parameter later if let ast::Type::Pointer(..) = var_type { is_variable = true; var_type } else { var_type.param_pointer_to(ast::LdStateSpace::Shared)? 
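                    // Non-register variables are re-typed as pointers into their state space
                    // (shared/global/param/local), so that later passes can treat their names
                    // as addresses rather than plain values.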
} } ast::VariableType::Global(_) => { var_type.param_pointer_to(ast::LdStateSpace::Global)? } ast::VariableType::Param(_) => { var_type.param_pointer_to(ast::LdStateSpace::Param)? } ast::VariableType::Local(_) => { var_type.param_pointer_to(ast::LdStateSpace::Local)? } }; match var.count { Some(count) => { for new_id in id_defs.add_defs(var.var.name, count, var_type, is_variable) { result.push(Statement::Variable(ast::Variable { align: var.var.align, v_type: var.var.v_type.clone(), name: new_id, array_init: var.var.array_init.clone(), })) } } None => { let new_id = id_defs.add_def(var.var.name, Some(var_type), is_variable); result.push(Statement::Variable(ast::Variable { align: var.var.align, v_type: var.var.v_type.clone(), name: new_id, array_init: var.var.array_init, })); } } } }; Ok(()) } // TODO: detect more patterns (mov, call via reg, call via param) // TODO: don't convert to ptr if the register is not ultimately used for ld/st // TODO: once insert_mem_ssa_statements is moved to later, move this pass after // argument expansion // TODO: propagate through calls? fn convert_to_stateful_memory_access<'a>( func_args: &mut SpirvMethodDecl, func_body: Vec, id_defs: &mut NumericIdResolver<'a>, ) -> Result, TranslateError> { let func_args_64bit = func_args .input .iter() .filter_map(|arg| match arg.v_type { ast::Type::Scalar(ast::ScalarType::U64) | ast::Type::Scalar(ast::ScalarType::B64) | ast::Type::Scalar(ast::ScalarType::S64) => Some(arg.name), _ => None, }) .collect::>(); let mut stateful_markers = Vec::new(); let mut stateful_init_reg = MultiHashMap::new(); for statement in func_body.iter() { match statement { Statement::Instruction(ast::Instruction::Cvta( ast::CvtaDetails { to: ast::CvtaStateSpace::Global, size: ast::CvtaSize::U64, from: ast::CvtaStateSpace::Generic, }, arg, )) => { if let (TypedOperand::Reg(dst), Some(src)) = (arg.dst, arg.src.upcast().underlying()) { if is_64_bit_integer(id_defs, *src) && is_64_bit_integer(id_defs, dst) { stateful_markers.push((dst, *src)); } } } Statement::Instruction(ast::Instruction::Ld( ast::LdDetails { state_space: ast::LdStateSpace::Param, typ: ast::LdStType::Scalar(ast::LdStScalarType::U64), .. }, arg, )) | Statement::Instruction(ast::Instruction::Ld( ast::LdDetails { state_space: ast::LdStateSpace::Param, typ: ast::LdStType::Scalar(ast::LdStScalarType::S64), .. }, arg, )) | Statement::Instruction(ast::Instruction::Ld( ast::LdDetails { state_space: ast::LdStateSpace::Param, typ: ast::LdStType::Scalar(ast::LdStScalarType::B64), .. 
}, arg, )) => { if let (TypedOperand::Reg(dst), Some(src)) = (&arg.dst, arg.src.upcast().underlying()) { if func_args_64bit.contains(src) { multi_hash_map_append(&mut stateful_init_reg, *dst, *src); } } } _ => {} } } let mut func_args_ptr = HashSet::new(); let mut regs_ptr_current = HashSet::new(); for (dst, src) in stateful_markers { if let Some(func_args) = stateful_init_reg.get(&src) { for a in func_args { func_args_ptr.insert(*a); regs_ptr_current.insert(src); regs_ptr_current.insert(dst); } } } // BTreeSet here to have a stable order of iteration, // unfortunately our tests rely on it let mut regs_ptr_seen = BTreeSet::new(); while regs_ptr_current.len() > 0 { let mut regs_ptr_new = HashSet::new(); for statement in func_body.iter() { match statement { Statement::Instruction(ast::Instruction::Add( ast::ArithDetails::Unsigned(ast::UIntType::U64), arg, )) | Statement::Instruction(ast::Instruction::Add( ast::ArithDetails::Signed(ast::ArithSInt { typ: ast::SIntType::S64, saturate: false, }), arg, )) | Statement::Instruction(ast::Instruction::Sub( ast::ArithDetails::Unsigned(ast::UIntType::U64), arg, )) | Statement::Instruction(ast::Instruction::Sub( ast::ArithDetails::Signed(ast::ArithSInt { typ: ast::SIntType::S64, saturate: false, }), arg, )) => { if let (TypedOperand::Reg(dst), Some(src1)) = (arg.dst, arg.src1.upcast().underlying()) { if regs_ptr_current.contains(src1) && !regs_ptr_seen.contains(src1) { regs_ptr_new.insert(dst); } } else if let (TypedOperand::Reg(dst), Some(src2)) = (arg.dst, arg.src2.upcast().underlying()) { if regs_ptr_current.contains(src2) && !regs_ptr_seen.contains(src2) { regs_ptr_new.insert(dst); } } } _ => {} } } for id in regs_ptr_current { regs_ptr_seen.insert(id); } regs_ptr_current = regs_ptr_new; } drop(regs_ptr_current); let mut remapped_ids = HashMap::new(); let mut result = Vec::with_capacity(regs_ptr_seen.len() + func_body.len()); for reg in regs_ptr_seen { let new_id = id_defs.new_variable(ast::Type::Pointer( ast::PointerType::Scalar(ast::ScalarType::U8), ast::LdStateSpace::Global, )); result.push(Statement::Variable(ast::Variable { align: None, name: new_id, array_init: Vec::new(), v_type: ast::VariableType::Reg(ast::VariableRegType::Pointer( ast::SizedScalarType::U8, ast::PointerStateSpace::Global, )), })); remapped_ids.insert(reg, new_id); } for statement in func_body { match statement { l @ Statement::Label(_) => result.push(l), c @ Statement::Conditional(_) => result.push(c), Statement::Variable(var) => { if !remapped_ids.contains_key(&var.name) { result.push(Statement::Variable(var)); } } Statement::Instruction(ast::Instruction::Add( ast::ArithDetails::Unsigned(ast::UIntType::U64), arg, )) | Statement::Instruction(ast::Instruction::Add( ast::ArithDetails::Signed(ast::ArithSInt { typ: ast::SIntType::S64, saturate: false, }), arg, )) if is_add_ptr_direct(&remapped_ids, &arg) => { let (ptr, offset) = match arg.src1.upcast().underlying() { Some(src1) if remapped_ids.contains_key(src1) => { (remapped_ids.get(src1).unwrap(), arg.src2) } Some(src2) if remapped_ids.contains_key(src2) => { (remapped_ids.get(src2).unwrap(), arg.src1) } _ => return Err(error_unreachable()), }; let dst = arg.dst.upcast().unwrap_reg()?; result.push(Statement::PtrAccess(PtrAccess { underlying_type: ast::PointerType::Scalar(ast::ScalarType::U8), state_space: ast::LdStateSpace::Global, dst: *remapped_ids.get(&dst).unwrap(), ptr_src: *ptr, offset_src: offset, })) } Statement::Instruction(ast::Instruction::Sub( ast::ArithDetails::Unsigned(ast::UIntType::U64), arg, )) | 
Statement::Instruction(ast::Instruction::Sub( ast::ArithDetails::Signed(ast::ArithSInt { typ: ast::SIntType::S64, saturate: false, }), arg, )) if is_add_ptr_direct(&remapped_ids, &arg) => { let (ptr, offset) = match arg.src1.upcast().underlying() { Some(src1) if remapped_ids.contains_key(src1) => { (remapped_ids.get(src1).unwrap(), arg.src2) } Some(src2) if remapped_ids.contains_key(src2) => { (remapped_ids.get(src2).unwrap(), arg.src1) } _ => return Err(error_unreachable()), }; let offset_neg = id_defs.new_non_variable(Some(ast::Type::Scalar(ast::ScalarType::S64))); result.push(Statement::Instruction(ast::Instruction::Neg( ast::NegDetails { typ: ast::ScalarType::S64, flush_to_zero: None, }, ast::Arg2 { src: offset, dst: TypedOperand::Reg(offset_neg), }, ))); let dst = arg.dst.upcast().unwrap_reg()?; result.push(Statement::PtrAccess(PtrAccess { underlying_type: ast::PointerType::Scalar(ast::ScalarType::U8), state_space: ast::LdStateSpace::Global, dst: *remapped_ids.get(&dst).unwrap(), ptr_src: *ptr, offset_src: TypedOperand::Reg(offset_neg), })) } Statement::Instruction(inst) => { let mut post_statements = Vec::new(); let new_statement = inst.visit( &mut |arg_desc: ArgumentDescriptor, expected_type: Option<&ast::Type>| { convert_to_stateful_memory_access_postprocess( id_defs, &remapped_ids, &func_args_ptr, &mut result, &mut post_statements, arg_desc, expected_type, ) }, )?; result.push(new_statement); result.extend(post_statements); } Statement::Call(call) => { let mut post_statements = Vec::new(); let new_statement = call.visit( &mut |arg_desc: ArgumentDescriptor, expected_type: Option<&ast::Type>| { convert_to_stateful_memory_access_postprocess( id_defs, &remapped_ids, &func_args_ptr, &mut result, &mut post_statements, arg_desc, expected_type, ) }, )?; result.push(new_statement); result.extend(post_statements); } Statement::RepackVector(pack) => { let mut post_statements = Vec::new(); let new_statement = pack.visit( &mut |arg_desc: ArgumentDescriptor, expected_type: Option<&ast::Type>| { convert_to_stateful_memory_access_postprocess( id_defs, &remapped_ids, &func_args_ptr, &mut result, &mut post_statements, arg_desc, expected_type, ) }, )?; result.push(new_statement); result.extend(post_statements); } _ => return Err(error_unreachable()), } } for arg in func_args.input.iter_mut() { if func_args_ptr.contains(&arg.name) { arg.v_type = ast::Type::Pointer( ast::PointerType::Scalar(ast::ScalarType::U8), ast::LdStateSpace::Global, ); } } Ok(result) } fn convert_to_stateful_memory_access_postprocess( id_defs: &mut NumericIdResolver, remapped_ids: &HashMap, func_args_ptr: &HashSet, result: &mut Vec, post_statements: &mut Vec, arg_desc: ArgumentDescriptor, expected_type: Option<&ast::Type>, ) -> Result { Ok(match remapped_ids.get(&arg_desc.op) { Some(new_id) => { // We skip conversion here to trigger PtrAcces in a later pass let old_type = match expected_type { Some(ast::Type::Pointer(_, ast::LdStateSpace::Global)) => return Ok(*new_id), _ => id_defs.get_typed(arg_desc.op)?.0, }; let old_type_clone = old_type.clone(); let converting_id = id_defs.new_non_variable(Some(old_type_clone)); if arg_desc.is_dst { post_statements.push(Statement::Conversion(ImplicitConversion { src: converting_id, dst: *new_id, from: old_type, to: ast::Type::Pointer( ast::PointerType::Scalar(ast::ScalarType::U8), ast::LdStateSpace::Global, ), kind: ConversionKind::BitToPtr(ast::LdStateSpace::Global), src_sema: ArgumentSemantics::Default, dst_sema: arg_desc.sema, })); converting_id } else { 
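            // Source operand: emit the opposite conversion, turning the remapped pointer back
            // into the integer type the instruction expects, and place it before the statement
            // itself (post_statements above is only used for destinations).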
result.push(Statement::Conversion(ImplicitConversion { src: *new_id, dst: converting_id, from: ast::Type::Pointer( ast::PointerType::Scalar(ast::ScalarType::U8), ast::LdStateSpace::Global, ), to: old_type, kind: ConversionKind::PtrToBit(ast::UIntType::U64), src_sema: arg_desc.sema, dst_sema: ArgumentSemantics::Default, })); converting_id } } None => match func_args_ptr.get(&arg_desc.op) { Some(new_id) => { if arg_desc.is_dst { return Err(error_unreachable()); } // We skip conversion here to trigger PtrAcces in a later pass let old_type = match expected_type { Some(ast::Type::Pointer(_, ast::LdStateSpace::Global)) => return Ok(*new_id), _ => id_defs.get_typed(arg_desc.op)?.0, }; let old_type_clone = old_type.clone(); let converting_id = id_defs.new_non_variable(Some(old_type)); result.push(Statement::Conversion(ImplicitConversion { src: *new_id, dst: converting_id, from: ast::Type::Pointer( ast::PointerType::Pointer(ast::ScalarType::U8, ast::LdStateSpace::Global), ast::LdStateSpace::Param, ), to: old_type_clone, kind: ConversionKind::PtrToPtr { spirv_ptr: false }, src_sema: arg_desc.sema, dst_sema: ArgumentSemantics::Default, })); converting_id } None => arg_desc.op, }, }) } fn is_add_ptr_direct(remapped_ids: &HashMap, arg: &ast::Arg3) -> bool { match arg.dst { TypedOperand::Imm(..) | TypedOperand::RegOffset(..) | TypedOperand::VecMember(..) => { return false } TypedOperand::Reg(dst) => { if !remapped_ids.contains_key(&dst) { return false; } match arg.src1.upcast().underlying() { Some(src1) if remapped_ids.contains_key(src1) => true, Some(src2) if remapped_ids.contains_key(src2) => true, _ => false, } } } } fn is_64_bit_integer(id_defs: &NumericIdResolver, id: spirv::Word) -> bool { match id_defs.get_typed(id) { Ok((ast::Type::Scalar(ast::ScalarType::U64), _)) | Ok((ast::Type::Scalar(ast::ScalarType::S64), _)) | Ok((ast::Type::Scalar(ast::ScalarType::B64), _)) => true, _ => false, } } #[derive(Ord, PartialOrd, Eq, PartialEq, Hash, Copy, Clone)] enum PtxSpecialRegister { Tid, Tid64, Ntid, Ntid64, Ctaid, Ctaid64, Nctaid, Nctaid64, } impl PtxSpecialRegister { fn try_parse(s: &str) -> Option { match s { "%tid" => Some(Self::Tid), "%ntid" => Some(Self::Ntid), "%ctaid" => Some(Self::Ctaid), "%nctaid" => Some(Self::Nctaid), _ => None, } } fn get_type(self) -> ast::Type { match self { PtxSpecialRegister::Tid => ast::Type::Vector(ast::ScalarType::U32, 4), PtxSpecialRegister::Tid64 => ast::Type::Vector(ast::ScalarType::U64, 3), PtxSpecialRegister::Ntid => ast::Type::Vector(ast::ScalarType::U32, 4), PtxSpecialRegister::Ntid64 => ast::Type::Vector(ast::ScalarType::U64, 3), PtxSpecialRegister::Ctaid => ast::Type::Vector(ast::ScalarType::U32, 4), PtxSpecialRegister::Ctaid64 => ast::Type::Vector(ast::ScalarType::U64, 3), PtxSpecialRegister::Nctaid => ast::Type::Vector(ast::ScalarType::U32, 4), PtxSpecialRegister::Nctaid64 => ast::Type::Vector(ast::ScalarType::U64, 3), } } fn get_builtin(self) -> spirv::BuiltIn { match self { PtxSpecialRegister::Tid | PtxSpecialRegister::Tid64 => { spirv::BuiltIn::LocalInvocationId } PtxSpecialRegister::Ntid | PtxSpecialRegister::Ntid64 => spirv::BuiltIn::WorkgroupSize, PtxSpecialRegister::Ctaid | PtxSpecialRegister::Ctaid64 => spirv::BuiltIn::WorkgroupId, PtxSpecialRegister::Nctaid | PtxSpecialRegister::Nctaid64 => { spirv::BuiltIn::NumWorkgroups } } } fn normalized_sreg_and_type(self) -> Option<(PtxSpecialRegister, ast::ScalarType, u8)> { match self { PtxSpecialRegister::Tid => Some((PtxSpecialRegister::Tid64, ast::ScalarType::U64, 3)), PtxSpecialRegister::Ntid => 
Some((PtxSpecialRegister::Ntid64, ast::ScalarType::U64, 3)), PtxSpecialRegister::Ctaid => { Some((PtxSpecialRegister::Ctaid64, ast::ScalarType::U64, 3)) } PtxSpecialRegister::Nctaid => { Some((PtxSpecialRegister::Nctaid64, ast::ScalarType::U64, 3)) } PtxSpecialRegister::Tid64 | PtxSpecialRegister::Ntid64 | PtxSpecialRegister::Ctaid64 | PtxSpecialRegister::Nctaid64 => None, } } } struct SpecialRegistersMap { reg_to_id: HashMap, id_to_reg: HashMap, } impl SpecialRegistersMap { fn new() -> Self { SpecialRegistersMap { reg_to_id: HashMap::new(), id_to_reg: HashMap::new(), } } fn builtins<'a>(&'a self) -> impl Iterator + 'a { self.reg_to_id.iter().filter_map(|(sreg, id)| { if sreg.normalized_sreg_and_type().is_none() { Some((*sreg, *id)) } else { None } }) } fn interface(&self) -> Vec { self.reg_to_id .iter() .filter_map(|(sreg, id)| { if sreg.normalized_sreg_and_type().is_none() { Some(*id) } else { None } }) .collect::>() } fn get(&self, id: spirv::Word) -> Option { self.id_to_reg.get(&id).copied() } fn get_or_add(&mut self, current_id: &mut spirv::Word, reg: PtxSpecialRegister) -> spirv::Word { match self.reg_to_id.entry(reg) { hash_map::Entry::Occupied(e) => *e.get(), hash_map::Entry::Vacant(e) => { let numeric_id = *current_id; *current_id += 1; e.insert(numeric_id); self.id_to_reg.insert(numeric_id, reg); numeric_id } } } } struct GlobalStringIdResolver<'input> { current_id: spirv::Word, variables: HashMap, spirv::Word>, variables_type_check: HashMap>, special_registers: SpecialRegistersMap, fns: HashMap, } pub struct FnDecl { ret_vals: Vec, params: Vec, } impl<'a> GlobalStringIdResolver<'a> { fn new(start_id: spirv::Word) -> Self { Self { current_id: start_id, variables: HashMap::new(), variables_type_check: HashMap::new(), special_registers: SpecialRegistersMap::new(), fns: HashMap::new(), } } fn get_or_add_def(&mut self, id: &'a str) -> spirv::Word { self.get_or_add_impl(id, None) } fn get_or_add_def_typed( &mut self, id: &'a str, typ: ast::Type, is_variable: bool, ) -> spirv::Word { self.get_or_add_impl(id, Some((typ, is_variable))) } fn get_or_add_impl(&mut self, id: &'a str, typ: Option<(ast::Type, bool)>) -> spirv::Word { let id = match self.variables.entry(Cow::Borrowed(id)) { hash_map::Entry::Occupied(e) => *(e.get()), hash_map::Entry::Vacant(e) => { let numeric_id = self.current_id; e.insert(numeric_id); self.current_id += 1; numeric_id } }; self.variables_type_check.insert(id, typ); id } fn get_id(&self, id: &str) -> Result { self.variables .get(id) .copied() .ok_or(TranslateError::UnknownSymbol) } fn current_id(&self) -> spirv::Word { self.current_id } fn start_fn<'b>( &'b mut self, header: &'b ast::MethodDecl<'a, &'a str>, ) -> Result< ( FnStringIdResolver<'a, 'b>, GlobalFnDeclResolver<'a, 'b>, ast::MethodDecl<'a, spirv::Word>, ), TranslateError, > { // In case a function decl was inserted earlier we want to use its id let name_id = self.get_or_add_def(header.name()); let mut fn_resolver = FnStringIdResolver { current_id: &mut self.current_id, global_variables: &self.variables, global_type_check: &self.variables_type_check, special_registers: &mut self.special_registers, variables: vec![HashMap::new(); 1], type_check: HashMap::new(), }; let new_fn_decl = match header { ast::MethodDecl::Kernel { name, in_args } => ast::MethodDecl::Kernel { name, in_args: expand_kernel_params(&mut fn_resolver, in_args.iter())?, }, ast::MethodDecl::Func(ret_params, _, params) => { let ret_params_ids = expand_fn_params(&mut fn_resolver, ret_params.iter())?; let params_ids = expand_fn_params(&mut 
fn_resolver, params.iter())?; self.fns.insert( name_id, FnDecl { ret_vals: ret_params_ids.iter().map(|p| p.v_type.clone()).collect(), params: params_ids.iter().map(|p| p.v_type.clone()).collect(), }, ); ast::MethodDecl::Func(ret_params_ids, name_id, params_ids) } }; Ok(( fn_resolver, GlobalFnDeclResolver { variables: &self.variables, fns: &self.fns, }, new_fn_decl, )) } } pub struct GlobalFnDeclResolver<'input, 'a> { variables: &'a HashMap, spirv::Word>, fns: &'a HashMap, } impl<'input, 'a> GlobalFnDeclResolver<'input, 'a> { fn get_fn_decl(&self, id: spirv::Word) -> Result<&FnDecl, TranslateError> { self.fns.get(&id).ok_or(TranslateError::UnknownSymbol) } fn get_fn_decl_str(&self, id: &str) -> Result<&'a FnDecl, TranslateError> { match self.variables.get(id).map(|var_id| self.fns.get(var_id)) { Some(Some(fn_d)) => Ok(fn_d), _ => Err(TranslateError::UnknownSymbol), } } } struct FnStringIdResolver<'input, 'b> { current_id: &'b mut spirv::Word, global_variables: &'b HashMap, spirv::Word>, global_type_check: &'b HashMap>, special_registers: &'b mut SpecialRegistersMap, variables: Vec, spirv::Word>>, type_check: HashMap>, } impl<'a, 'b> FnStringIdResolver<'a, 'b> { fn finish(self) -> NumericIdResolver<'b> { NumericIdResolver { current_id: self.current_id, global_type_check: self.global_type_check, type_check: self.type_check, special_registers: self.special_registers, } } fn start_block(&mut self) { self.variables.push(HashMap::new()) } fn end_block(&mut self) { self.variables.pop(); } fn get_id(&mut self, id: &str) -> Result { for scope in self.variables.iter().rev() { match scope.get(id) { Some(id) => return Ok(*id), None => continue, } } match self.global_variables.get(id) { Some(id) => Ok(*id), None => { let sreg = PtxSpecialRegister::try_parse(id).ok_or(TranslateError::UnknownSymbol)?; Ok(self.special_registers.get_or_add(self.current_id, sreg)) } } } fn add_def(&mut self, id: &'a str, typ: Option, is_variable: bool) -> spirv::Word { let numeric_id = *self.current_id; self.variables .last_mut() .unwrap() .insert(Cow::Borrowed(id), numeric_id); self.type_check .insert(numeric_id, typ.map(|t| (t, is_variable))); *self.current_id += 1; numeric_id } #[must_use] fn add_defs( &mut self, base_id: &'a str, count: u32, typ: ast::Type, is_variable: bool, ) -> impl Iterator { let numeric_id = *self.current_id; for i in 0..count { self.variables .last_mut() .unwrap() .insert(Cow::Owned(format!("{}{}", base_id, i)), numeric_id + i); self.type_check .insert(numeric_id + i, Some((typ.clone(), is_variable))); } *self.current_id += count; (0..count).into_iter().map(move |i| i + numeric_id) } } struct NumericIdResolver<'b> { current_id: &'b mut spirv::Word, global_type_check: &'b HashMap>, type_check: HashMap>, special_registers: &'b mut SpecialRegistersMap, } impl<'b> NumericIdResolver<'b> { fn finish(self) -> MutableNumericIdResolver<'b> { MutableNumericIdResolver { base: self } } fn get_typed(&self, id: spirv::Word) -> Result<(ast::Type, bool), TranslateError> { match self.type_check.get(&id) { Some(Some(x)) => Ok(x.clone()), Some(None) => Err(TranslateError::UntypedSymbol), None => match self.special_registers.get(id) { Some(x) => Ok((x.get_type(), true)), None => match self.global_type_check.get(&id) { Some(Some(result)) => Ok(result.clone()), Some(None) | None => Err(TranslateError::UntypedSymbol), }, }, } } // This is for identifiers which will be emitted later as OpVariable // They are candidates for insertion of LoadVar/StoreVar fn new_variable(&mut self, typ: ast::Type) -> spirv::Word { let new_id 
= *self.current_id; self.type_check.insert(new_id, Some((typ, true))); *self.current_id += 1; new_id } fn new_non_variable(&mut self, typ: Option) -> spirv::Word { let new_id = *self.current_id; self.type_check.insert(new_id, typ.map(|t| (t, false))); *self.current_id += 1; new_id } } struct MutableNumericIdResolver<'b> { base: NumericIdResolver<'b>, } impl<'b> MutableNumericIdResolver<'b> { fn unmut(self) -> NumericIdResolver<'b> { self.base } fn get_typed(&self, id: spirv::Word) -> Result { self.base.get_typed(id).map(|(t, _)| t) } fn new_non_variable(&mut self, typ: ast::Type) -> spirv::Word { self.base.new_non_variable(Some(typ)) } } enum Statement { Label(u32), Variable(ast::Variable), Instruction(I), // SPIR-V compatible replacement for PTX predicates Conditional(BrachCondition), Call(ResolvedCall
<P>
), LoadVar(LoadVarDetails), StoreVar(StoreVarDetails), Conversion(ImplicitConversion), Constant(ConstantDefinition), RetValue(ast::RetData, spirv::Word), PtrAccess(PtrAccess
<P>
), RepackVector(RepackVectorDetails), } impl ExpandedStatement { fn map_id(self, f: &mut impl FnMut(spirv::Word, bool) -> spirv::Word) -> ExpandedStatement { match self { Statement::Label(id) => Statement::Label(f(id, false)), Statement::Variable(mut var) => { var.name = f(var.name, true); Statement::Variable(var) } Statement::Instruction(inst) => inst .visit(&mut |arg: ArgumentDescriptor<_>, _: Option<&ast::Type>| { Ok(f(arg.op, arg.is_dst)) }) .unwrap(), Statement::LoadVar(mut details) => { details.arg.dst = f(details.arg.dst, true); details.arg.src = f(details.arg.src, false); Statement::LoadVar(details) } Statement::StoreVar(mut details) => { details.arg.src1 = f(details.arg.src1, false); details.arg.src2 = f(details.arg.src2, false); Statement::StoreVar(details) } Statement::Call(mut call) => { for (id, typ) in call.ret_params.iter_mut() { let is_dst = match typ { ast::FnArgumentType::Reg(_) => true, ast::FnArgumentType::Param(_) => false, ast::FnArgumentType::Shared => false, }; *id = f(*id, is_dst); } call.func = f(call.func, false); for (id, _) in call.param_list.iter_mut() { *id = f(*id, false); } Statement::Call(call) } Statement::Conditional(mut conditional) => { conditional.predicate = f(conditional.predicate, false); conditional.if_true = f(conditional.if_true, false); conditional.if_false = f(conditional.if_false, false); Statement::Conditional(conditional) } Statement::Conversion(mut conv) => { conv.dst = f(conv.dst, true); conv.src = f(conv.src, false); Statement::Conversion(conv) } Statement::Constant(mut constant) => { constant.dst = f(constant.dst, true); Statement::Constant(constant) } Statement::RetValue(data, id) => { let id = f(id, false); Statement::RetValue(data, id) } Statement::PtrAccess(PtrAccess { underlying_type, state_space, dst, ptr_src, offset_src: constant_src, }) => { let dst = f(dst, true); let ptr_src = f(ptr_src, false); let constant_src = f(constant_src, false); Statement::PtrAccess(PtrAccess { underlying_type, state_space, dst, ptr_src, offset_src: constant_src, }) } Statement::RepackVector(repack) => { let packed = f(repack.packed, !repack.is_extract); let unpacked = repack .unpacked .iter() .map(|id| f(*id, repack.is_extract)) .collect(); Statement::RepackVector(RepackVectorDetails { packed, unpacked, ..repack }) } } } } struct LoadVarDetails { arg: ast::Arg2, typ: ast::Type, // (index, vector_width) // HACK ALERT // For some reason IGC explodes when you try to load from builtin vectors // using OpInBoundsAccessChain, the one true way to do it is to // OpLoad+OpCompositeExtract member_index: Option<(u8, Option)>, } struct StoreVarDetails { arg: ast::Arg2St, typ: ast::Type, member_index: Option, } struct RepackVectorDetails { is_extract: bool, typ: ast::ScalarType, packed: spirv::Word, unpacked: Vec, vector_sema: ArgumentSemantics, } impl RepackVectorDetails { fn map< From: ArgParamsEx, To: ArgParamsEx, V: ArgumentMapVisitor, >( self, visitor: &mut V, ) -> Result { let scalar = visitor.id( ArgumentDescriptor { op: self.packed, is_dst: !self.is_extract, sema: ArgumentSemantics::Default, }, Some(&ast::Type::Vector(self.typ, self.unpacked.len() as u8)), )?; let scalar_type = self.typ; let is_extract = self.is_extract; let vector_sema = self.vector_sema; let vector = self .unpacked .into_iter() .map(|id| { visitor.id( ArgumentDescriptor { op: id, is_dst: is_extract, sema: vector_sema, }, Some(&ast::Type::Scalar(scalar_type)), ) }) .collect::>()?; Ok(RepackVectorDetails { is_extract, typ: self.typ, packed: scalar, unpacked: vector, vector_sema, }) } } 
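// Every statement kind above funnels id/operand rewriting through `ArgumentMapVisitor`:
// a pass hands in one visitor (typically a closure taking an `ArgumentDescriptor` plus an
// optional expected `ast::Type`), and each `Visitable` impl, like the ones below, replays
// its arguments through that visitor and rebuilds itself with the returned ids.
// A minimal sketch of such a pass, assuming the closure-based `ArgumentMapVisitor` impls
// defined further down in this file (the `renames` map is hypothetical, for illustration only):
//
//     let mut renames: HashMap<spirv::Word, spirv::Word> = HashMap::new();
//     let rewritten = statement.visit(&mut |desc: ArgumentDescriptor<_>,
//                                           _: Option<&ast::Type>| {
//         Ok(*renames.entry(desc.op).or_insert(desc.op))
//     })?;
//
// `is_dst` and `sema` on the descriptor tell a smarter visitor whether it must also emit
// LoadVar/StoreVar or ImplicitConversion statements around the rewritten id.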
impl, U: ArgParamsEx> Visitable for RepackVectorDetails { fn visit( self, visitor: &mut impl ArgumentMapVisitor, ) -> Result, U>, TranslateError> { Ok(Statement::RepackVector(self.map::<_, _, _>(visitor)?)) } } struct ResolvedCall { pub uniform: bool, pub ret_params: Vec<(P::Id, ast::FnArgumentType)>, pub func: P::Id, pub param_list: Vec<(P::Operand, ast::FnArgumentType)>, } impl ResolvedCall { fn cast>(self) -> ResolvedCall { ResolvedCall { uniform: self.uniform, ret_params: self.ret_params, func: self.func, param_list: self.param_list, } } } impl> ResolvedCall { fn map, V: ArgumentMapVisitor>( self, visitor: &mut V, ) -> Result, TranslateError> { let ret_params = self .ret_params .into_iter() .map::, _>(|(id, typ)| { let new_id = visitor.id( ArgumentDescriptor { op: id, is_dst: !typ.is_param(), sema: typ.semantics(), }, Some(&typ.to_func_type()), )?; Ok((new_id, typ)) }) .collect::, _>>()?; let func = visitor.id( ArgumentDescriptor { op: self.func, is_dst: false, sema: ArgumentSemantics::Default, }, None, )?; let param_list = self .param_list .into_iter() .map::, _>(|(id, typ)| { let new_id = visitor.operand( ArgumentDescriptor { op: id, is_dst: false, sema: typ.semantics(), }, &typ.to_func_type(), )?; Ok((new_id, typ)) }) .collect::, _>>()?; Ok(ResolvedCall { uniform: self.uniform, ret_params, func, param_list, }) } } impl, U: ArgParamsEx> Visitable for ResolvedCall { fn visit( self, visitor: &mut impl ArgumentMapVisitor, ) -> Result, U>, TranslateError> { Ok(Statement::Call(self.map(visitor)?)) } } impl> PtrAccess
<P>
{ fn map, V: ArgumentMapVisitor>( self, visitor: &mut V, ) -> Result, TranslateError> { let sema = match self.state_space { ast::LdStateSpace::Const | ast::LdStateSpace::Global | ast::LdStateSpace::Shared | ast::LdStateSpace::Generic => ArgumentSemantics::PhysicalPointer, ast::LdStateSpace::Local | ast::LdStateSpace::Param => { ArgumentSemantics::RegisterPointer } }; let ptr_type = ast::Type::Pointer(self.underlying_type.clone(), self.state_space); let new_dst = visitor.id( ArgumentDescriptor { op: self.dst, is_dst: true, sema, }, Some(&ptr_type), )?; let new_ptr_src = visitor.id( ArgumentDescriptor { op: self.ptr_src, is_dst: false, sema, }, Some(&ptr_type), )?; let new_constant_src = visitor.operand( ArgumentDescriptor { op: self.offset_src, is_dst: false, sema: ArgumentSemantics::Default, }, &ast::Type::Scalar(ast::ScalarType::S64), )?; Ok(PtrAccess { underlying_type: self.underlying_type, state_space: self.state_space, dst: new_dst, ptr_src: new_ptr_src, offset_src: new_constant_src, }) } } impl, U: ArgParamsEx> Visitable for PtrAccess { fn visit( self, visitor: &mut impl ArgumentMapVisitor, ) -> Result, U>, TranslateError> { Ok(Statement::PtrAccess(self.map(visitor)?)) } } pub trait ArgParamsEx: ast::ArgParams + Sized { fn get_fn_decl<'x, 'b>( id: &Self::Id, decl: &'b GlobalFnDeclResolver<'x, 'b>, ) -> Result<&'b FnDecl, TranslateError>; } impl<'input> ArgParamsEx for ast::ParsedArgParams<'input> { fn get_fn_decl<'x, 'b>( id: &Self::Id, decl: &'b GlobalFnDeclResolver<'x, 'b>, ) -> Result<&'b FnDecl, TranslateError> { decl.get_fn_decl_str(id) } } enum NormalizedArgParams {} impl ast::ArgParams for NormalizedArgParams { type Id = spirv::Word; type Operand = ast::Operand; } impl ArgParamsEx for NormalizedArgParams { fn get_fn_decl<'a, 'b>( id: &Self::Id, decl: &'b GlobalFnDeclResolver<'a, 'b>, ) -> Result<&'b FnDecl, TranslateError> { decl.get_fn_decl(*id) } } type NormalizedStatement = Statement< ( Option>, ast::Instruction, ), NormalizedArgParams, >; type UnconditionalStatement = Statement, NormalizedArgParams>; enum TypedArgParams {} impl ast::ArgParams for TypedArgParams { type Id = spirv::Word; type Operand = TypedOperand; } impl ArgParamsEx for TypedArgParams { fn get_fn_decl<'a, 'b>( id: &Self::Id, decl: &'b GlobalFnDeclResolver<'a, 'b>, ) -> Result<&'b FnDecl, TranslateError> { decl.get_fn_decl(*id) } } #[derive(Copy, Clone)] enum TypedOperand { Reg(spirv::Word), RegOffset(spirv::Word, i32), Imm(ast::ImmediateValue), VecMember(spirv::Word, u8), } impl TypedOperand { fn upcast(self) -> ast::Operand { match self { TypedOperand::Reg(reg) => ast::Operand::Reg(reg), TypedOperand::RegOffset(reg, idx) => ast::Operand::RegOffset(reg, idx), TypedOperand::Imm(x) => ast::Operand::Imm(x), TypedOperand::VecMember(vec, idx) => ast::Operand::VecMember(vec, idx), } } } type TypedStatement = Statement, TypedArgParams>; enum ExpandedArgParams {} type ExpandedStatement = Statement, ExpandedArgParams>; impl ast::ArgParams for ExpandedArgParams { type Id = spirv::Word; type Operand = spirv::Word; } impl ArgParamsEx for ExpandedArgParams { fn get_fn_decl<'a, 'b>( id: &Self::Id, decl: &'b GlobalFnDeclResolver<'a, 'b>, ) -> Result<&'b FnDecl, TranslateError> { decl.get_fn_decl(*id) } } enum Directive<'input> { Variable(ast::Variable), Method(Function<'input>), } struct Function<'input> { pub func_decl: ast::MethodDecl<'input, spirv::Word>, pub spirv_decl: SpirvMethodDecl<'input>, pub globals: Vec>, pub body: Option>, import_as: Option, } pub trait ArgumentMapVisitor { fn id( &mut self, desc: 
ArgumentDescriptor, typ: Option<&ast::Type>, ) -> Result; fn operand( &mut self, desc: ArgumentDescriptor, typ: &ast::Type, ) -> Result; } impl ArgumentMapVisitor for T where T: FnMut( ArgumentDescriptor, Option<&ast::Type>, ) -> Result, { fn id( &mut self, desc: ArgumentDescriptor, t: Option<&ast::Type>, ) -> Result { self(desc, t) } fn operand( &mut self, desc: ArgumentDescriptor, typ: &ast::Type, ) -> Result { self(desc, Some(typ)) } } impl<'a, T> ArgumentMapVisitor, NormalizedArgParams> for T where T: FnMut(&str) -> Result, { fn id( &mut self, desc: ArgumentDescriptor<&str>, _: Option<&ast::Type>, ) -> Result { self(desc.op) } fn operand( &mut self, desc: ArgumentDescriptor>, typ: &ast::Type, ) -> Result, TranslateError> { Ok(match desc.op { ast::Operand::Reg(id) => ast::Operand::Reg(self(id)?), ast::Operand::RegOffset(id, imm) => ast::Operand::RegOffset(self(id)?, imm), ast::Operand::Imm(imm) => ast::Operand::Imm(imm), ast::Operand::VecMember(id, member) => ast::Operand::VecMember(self(id)?, member), ast::Operand::VecPack(ref ids) => ast::Operand::VecPack( ids.into_iter() .map(|id| self.id(desc.new_op(id), Some(typ))) .collect::, _>>()?, ), }) } } pub struct ArgumentDescriptor { op: Op, is_dst: bool, sema: ArgumentSemantics, } pub struct PtrAccess { underlying_type: ast::PointerType, state_space: ast::LdStateSpace, dst: spirv::Word, ptr_src: spirv::Word, offset_src: P::Operand, } #[derive(Copy, Clone, PartialEq, Eq, Debug)] pub enum ArgumentSemantics { // normal register access Default, // normal register access with relaxed conversion rules (ld/st) DefaultRelaxed, // st/ld global PhysicalPointer, // st/ld .param, .local RegisterPointer, // mov of .local/.global variables Address, } impl ArgumentDescriptor { fn new_op(&self, u: U) -> ArgumentDescriptor { ArgumentDescriptor { op: u, is_dst: self.is_dst, sema: self.sema, } } } impl ast::Instruction { fn map>( self, visitor: &mut V, ) -> Result, TranslateError> { Ok(match self { ast::Instruction::Abs(d, arg) => { ast::Instruction::Abs(d, arg.map(visitor, &ast::Type::Scalar(d.typ))?) } // Call instruction is converted to a call statement early on ast::Instruction::Call(_) => return Err(error_unreachable()), ast::Instruction::Ld(d, a) => { let new_args = a.map(visitor, &d)?; ast::Instruction::Ld(d, new_args) } ast::Instruction::Mov(d, a) => { let mapped = a.map(visitor, &d)?; ast::Instruction::Mov(d, mapped) } ast::Instruction::Mul(d, a) => { let inst_type = d.get_type(); let is_wide = d.is_wide(); ast::Instruction::Mul(d, a.map_non_shift(visitor, &inst_type, is_wide)?) } ast::Instruction::Add(d, a) => { let inst_type = d.get_type(); ast::Instruction::Add(d, a.map_non_shift(visitor, &inst_type, false)?) } ast::Instruction::Setp(d, a) => { let inst_type = d.typ; ast::Instruction::Setp(d, a.map(visitor, &ast::Type::Scalar(inst_type))?) } ast::Instruction::SetpBool(d, a) => { let inst_type = d.typ; ast::Instruction::SetpBool(d, a.map(visitor, &ast::Type::Scalar(inst_type))?) 
} ast::Instruction::Not(t, a) => ast::Instruction::Not(t, a.map(visitor, &t.to_type())?), ast::Instruction::Cvt(d, a) => { let (dst_t, src_t) = match &d { ast::CvtDetails::FloatFromFloat(desc) => ( ast::Type::Scalar(desc.dst.into()), ast::Type::Scalar(desc.src.into()), ), ast::CvtDetails::FloatFromInt(desc) => ( ast::Type::Scalar(desc.dst.into()), ast::Type::Scalar(desc.src.into()), ), ast::CvtDetails::IntFromFloat(desc) => ( ast::Type::Scalar(desc.dst.into()), ast::Type::Scalar(desc.src.into()), ), ast::CvtDetails::IntFromInt(desc) => ( ast::Type::Scalar(desc.dst.into()), ast::Type::Scalar(desc.src.into()), ), }; ast::Instruction::Cvt(d, a.map_different_types(visitor, &dst_t, &src_t)?) } ast::Instruction::Shl(t, a) => { ast::Instruction::Shl(t, a.map_shift(visitor, &t.to_type())?) } ast::Instruction::Shr(t, a) => { ast::Instruction::Shr(t, a.map_shift(visitor, &ast::Type::Scalar(t.into()))?) } ast::Instruction::St(d, a) => { let new_args = a.map(visitor, &d)?; ast::Instruction::St(d, new_args) } ast::Instruction::Bra(d, a) => ast::Instruction::Bra(d, a.map(visitor, None)?), ast::Instruction::Ret(d) => ast::Instruction::Ret(d), ast::Instruction::Cvta(d, a) => { let inst_type = ast::Type::Scalar(ast::ScalarType::B64); ast::Instruction::Cvta(d, a.map(visitor, &inst_type)?) } ast::Instruction::Mad(d, a) => { let inst_type = d.get_type(); let is_wide = d.is_wide(); ast::Instruction::Mad(d, a.map(visitor, &inst_type, is_wide)?) } ast::Instruction::Or(t, a) => ast::Instruction::Or( t, a.map_non_shift(visitor, &ast::Type::Scalar(t.into()), false)?, ), ast::Instruction::Sub(d, a) => { let typ = d.get_type(); ast::Instruction::Sub(d, a.map_non_shift(visitor, &typ, false)?) } ast::Instruction::Min(d, a) => { let typ = d.get_type(); ast::Instruction::Min(d, a.map_non_shift(visitor, &typ, false)?) } ast::Instruction::Max(d, a) => { let typ = d.get_type(); ast::Instruction::Max(d, a.map_non_shift(visitor, &typ, false)?) } ast::Instruction::Rcp(d, a) => { let typ = ast::Type::Scalar(if d.is_f64 { ast::ScalarType::F64 } else { ast::ScalarType::F32 }); ast::Instruction::Rcp(d, a.map(visitor, &typ)?) } ast::Instruction::And(t, a) => ast::Instruction::And( t, a.map_non_shift(visitor, &ast::Type::Scalar(t.into()), false)?, ), ast::Instruction::Selp(t, a) => ast::Instruction::Selp(t, a.map_selp(visitor, t)?), ast::Instruction::Bar(d, a) => ast::Instruction::Bar(d, a.map(visitor)?), ast::Instruction::Atom(d, a) => { ast::Instruction::Atom(d, a.map_atom(visitor, d.inner.get_type(), d.space)?) } ast::Instruction::AtomCas(d, a) => { ast::Instruction::AtomCas(d, a.map_atom(visitor, d.typ, d.space)?) } ast::Instruction::Div(d, a) => { ast::Instruction::Div(d, a.map_non_shift(visitor, &d.get_type(), false)?) } ast::Instruction::Sqrt(d, a) => { ast::Instruction::Sqrt(d, a.map(visitor, &ast::Type::Scalar(d.typ.into()))?) } ast::Instruction::Rsqrt(d, a) => { ast::Instruction::Rsqrt(d, a.map(visitor, &ast::Type::Scalar(d.typ.into()))?) } ast::Instruction::Neg(d, a) => { ast::Instruction::Neg(d, a.map(visitor, &ast::Type::Scalar(d.typ))?) 
} ast::Instruction::Sin { flush_to_zero, arg } => { let typ = ast::Type::Scalar(ast::ScalarType::F32); ast::Instruction::Sin { flush_to_zero, arg: arg.map(visitor, &typ)?, } } ast::Instruction::Cos { flush_to_zero, arg } => { let typ = ast::Type::Scalar(ast::ScalarType::F32); ast::Instruction::Cos { flush_to_zero, arg: arg.map(visitor, &typ)?, } } ast::Instruction::Lg2 { flush_to_zero, arg } => { let typ = ast::Type::Scalar(ast::ScalarType::F32); ast::Instruction::Lg2 { flush_to_zero, arg: arg.map(visitor, &typ)?, } } ast::Instruction::Ex2 { flush_to_zero, arg } => { let typ = ast::Type::Scalar(ast::ScalarType::F32); ast::Instruction::Ex2 { flush_to_zero, arg: arg.map(visitor, &typ)?, } } ast::Instruction::Clz { typ, arg } => { let dst_type = ast::Type::Scalar(ast::ScalarType::B32); let src_type = ast::Type::Scalar(typ.into()); ast::Instruction::Clz { typ, arg: arg.map_different_types(visitor, &dst_type, &src_type)?, } } ast::Instruction::Brev { typ, arg } => { let full_type = ast::Type::Scalar(typ.into()); ast::Instruction::Brev { typ, arg: arg.map(visitor, &full_type)?, } } ast::Instruction::Popc { typ, arg } => { let dst_type = ast::Type::Scalar(ast::ScalarType::B32); let src_type = ast::Type::Scalar(typ.into()); ast::Instruction::Popc { typ, arg: arg.map_different_types(visitor, &dst_type, &src_type)?, } } ast::Instruction::Xor { typ, arg } => { let full_type = ast::Type::Scalar(typ.into()); ast::Instruction::Xor { typ, arg: arg.map_non_shift(visitor, &full_type, false)?, } } ast::Instruction::Bfe { typ, arg } => { let full_type = ast::Type::Scalar(typ.into()); ast::Instruction::Bfe { typ, arg: arg.map_bfe(visitor, &full_type)?, } } ast::Instruction::Rem { typ, arg } => { let full_type = ast::Type::Scalar(typ.into()); ast::Instruction::Rem { typ, arg: arg.map_non_shift(visitor, &full_type, false)?, } } }) } } impl Visitable for ast::Instruction { fn visit( self, visitor: &mut impl ArgumentMapVisitor, ) -> Result, U>, TranslateError> { Ok(Statement::Instruction(self.map(visitor)?)) } } impl ImplicitConversion { fn map< T: ArgParamsEx, U: ArgParamsEx, V: ArgumentMapVisitor, >( self, visitor: &mut V, ) -> Result, U>, TranslateError> { let new_dst = visitor.id( ArgumentDescriptor { op: self.dst, is_dst: true, sema: self.dst_sema, }, Some(&self.to), )?; let new_src = visitor.id( ArgumentDescriptor { op: self.src, is_dst: false, sema: self.src_sema, }, Some(&self.from), )?; Ok(Statement::Conversion({ ImplicitConversion { src: new_src, dst: new_dst, ..self } })) } } impl, To: ArgParamsEx> Visitable for ImplicitConversion { fn visit( self, visitor: &mut impl ArgumentMapVisitor, ) -> Result, To>, TranslateError> { Ok(self.map(visitor)?) 
} } impl ArgumentMapVisitor for T where T: FnMut( ArgumentDescriptor, Option<&ast::Type>, ) -> Result, { fn id( &mut self, desc: ArgumentDescriptor, t: Option<&ast::Type>, ) -> Result { self(desc, t) } fn operand( &mut self, desc: ArgumentDescriptor, typ: &ast::Type, ) -> Result { Ok(match desc.op { TypedOperand::Reg(id) => TypedOperand::Reg(self(desc.new_op(id), Some(typ))?), TypedOperand::Imm(imm) => TypedOperand::Imm(imm), TypedOperand::RegOffset(id, imm) => { TypedOperand::RegOffset(self(desc.new_op(id), Some(typ))?, imm) } TypedOperand::VecMember(reg, index) => { let scalar_type = match typ { ast::Type::Scalar(scalar_t) => *scalar_t, _ => return Err(error_unreachable()), }; let vec_type = ast::Type::Vector(scalar_type, index + 1); TypedOperand::VecMember(self(desc.new_op(reg), Some(&vec_type))?, index) } }) } } impl ast::Type { fn widen(self) -> Result { match self { ast::Type::Scalar(scalar) => { let kind = scalar.kind(); let width = scalar.size_of(); if (kind != ScalarKind::Signed && kind != ScalarKind::Unsigned && kind != ScalarKind::Bit) || (width == 8) { return Err(TranslateError::MismatchedType); } Ok(ast::Type::Scalar(ast::ScalarType::from_parts( width * 2, kind, ))) } _ => Err(error_unreachable()), } } fn to_parts(&self) -> TypeParts { match self { ast::Type::Scalar(scalar) => TypeParts { kind: TypeKind::Scalar, scalar_kind: scalar.kind(), width: scalar.size_of(), components: Vec::new(), state_space: ast::LdStateSpace::Global, }, ast::Type::Vector(scalar, components) => TypeParts { kind: TypeKind::Vector, scalar_kind: scalar.kind(), width: scalar.size_of(), components: vec![*components as u32], state_space: ast::LdStateSpace::Global, }, ast::Type::Array(scalar, components) => TypeParts { kind: TypeKind::Array, scalar_kind: scalar.kind(), width: scalar.size_of(), components: components.clone(), state_space: ast::LdStateSpace::Global, }, ast::Type::Pointer(ast::PointerType::Scalar(scalar), state_space) => TypeParts { kind: TypeKind::PointerScalar, scalar_kind: scalar.kind(), width: scalar.size_of(), components: Vec::new(), state_space: *state_space, }, ast::Type::Pointer(ast::PointerType::Vector(scalar, len), state_space) => TypeParts { kind: TypeKind::PointerVector, scalar_kind: scalar.kind(), width: scalar.size_of(), components: vec![*len as u32], state_space: *state_space, }, ast::Type::Pointer(ast::PointerType::Array(scalar, components), state_space) => { TypeParts { kind: TypeKind::PointerArray, scalar_kind: scalar.kind(), width: scalar.size_of(), components: components.clone(), state_space: *state_space, } } ast::Type::Pointer(ast::PointerType::Pointer(scalar, inner_space), state_space) => { TypeParts { kind: TypeKind::PointerPointer, scalar_kind: scalar.kind(), width: scalar.size_of(), components: vec![*inner_space as u32], state_space: *state_space, } } } } fn from_parts(t: TypeParts) -> Self { match t.kind { TypeKind::Scalar => { ast::Type::Scalar(ast::ScalarType::from_parts(t.width, t.scalar_kind)) } TypeKind::Vector => ast::Type::Vector( ast::ScalarType::from_parts(t.width, t.scalar_kind), t.components[0] as u8, ), TypeKind::Array => ast::Type::Array( ast::ScalarType::from_parts(t.width, t.scalar_kind), t.components, ), TypeKind::PointerScalar => ast::Type::Pointer( ast::PointerType::Scalar(ast::ScalarType::from_parts(t.width, t.scalar_kind)), t.state_space, ), TypeKind::PointerVector => ast::Type::Pointer( ast::PointerType::Vector( ast::ScalarType::from_parts(t.width, t.scalar_kind), t.components[0] as u8, ), t.state_space, ), TypeKind::PointerArray => 
ast::Type::Pointer( ast::PointerType::Array( ast::ScalarType::from_parts(t.width, t.scalar_kind), t.components, ), t.state_space, ), TypeKind::PointerPointer => ast::Type::Pointer( ast::PointerType::Pointer( ast::ScalarType::from_parts(t.width, t.scalar_kind), unsafe { mem::transmute::<_, ast::LdStateSpace>(t.components[0] as u8) }, ), t.state_space, ), } } pub fn size_of(&self) -> usize { match self { ast::Type::Scalar(typ) => typ.size_of() as usize, ast::Type::Vector(typ, len) => (typ.size_of() as usize) * (*len as usize), ast::Type::Array(typ, len) => len .iter() .fold(typ.size_of() as usize, |x, y| (x as usize) * (*y as usize)), ast::Type::Pointer(_, _) => mem::size_of::(), } } } #[derive(Eq, PartialEq, Clone)] struct TypeParts { kind: TypeKind, scalar_kind: ScalarKind, width: u8, components: Vec, state_space: ast::LdStateSpace, } #[derive(Eq, PartialEq, Copy, Clone)] enum TypeKind { Scalar, Vector, Array, PointerScalar, PointerVector, PointerArray, PointerPointer, } impl ast::Instruction { fn jump_target(&self) -> Option { match self { ast::Instruction::Bra(_, a) => Some(a.src), _ => None, } } // .wide instructions don't support ftz, so it's enough to just look at the // type declared by the instruction fn flush_to_zero(&self) -> Option<(bool, u8)> { match self { ast::Instruction::Ld(_, _) => None, ast::Instruction::St(_, _) => None, ast::Instruction::Mov(_, _) => None, ast::Instruction::Not(_, _) => None, ast::Instruction::Bra(_, _) => None, ast::Instruction::Shl(_, _) => None, ast::Instruction::Shr(_, _) => None, ast::Instruction::Ret(_) => None, ast::Instruction::Call(_) => None, ast::Instruction::Or(_, _) => None, ast::Instruction::And(_, _) => None, ast::Instruction::Cvta(_, _) => None, ast::Instruction::Selp(_, _) => None, ast::Instruction::Bar(_, _) => None, ast::Instruction::Atom(_, _) => None, ast::Instruction::AtomCas(_, _) => None, ast::Instruction::Sub(ast::ArithDetails::Signed(_), _) => None, ast::Instruction::Sub(ast::ArithDetails::Unsigned(_), _) => None, ast::Instruction::Add(ast::ArithDetails::Signed(_), _) => None, ast::Instruction::Add(ast::ArithDetails::Unsigned(_), _) => None, ast::Instruction::Mul(ast::MulDetails::Unsigned(_), _) => None, ast::Instruction::Mul(ast::MulDetails::Signed(_), _) => None, ast::Instruction::Mad(ast::MulDetails::Unsigned(_), _) => None, ast::Instruction::Mad(ast::MulDetails::Signed(_), _) => None, ast::Instruction::Min(ast::MinMaxDetails::Signed(_), _) => None, ast::Instruction::Min(ast::MinMaxDetails::Unsigned(_), _) => None, ast::Instruction::Max(ast::MinMaxDetails::Signed(_), _) => None, ast::Instruction::Max(ast::MinMaxDetails::Unsigned(_), _) => None, ast::Instruction::Cvt(ast::CvtDetails::IntFromInt(_), _) => None, ast::Instruction::Cvt(ast::CvtDetails::FloatFromInt(_), _) => None, ast::Instruction::Div(ast::DivDetails::Unsigned(_), _) => None, ast::Instruction::Div(ast::DivDetails::Signed(_), _) => None, ast::Instruction::Clz { .. } => None, ast::Instruction::Brev { .. } => None, ast::Instruction::Popc { .. } => None, ast::Instruction::Xor { .. } => None, ast::Instruction::Bfe { .. } => None, ast::Instruction::Rem { .. 
} => None, ast::Instruction::Sub(ast::ArithDetails::Float(float_control), _) | ast::Instruction::Add(ast::ArithDetails::Float(float_control), _) | ast::Instruction::Mul(ast::MulDetails::Float(float_control), _) | ast::Instruction::Mad(ast::MulDetails::Float(float_control), _) => float_control .flush_to_zero .map(|ftz| (ftz, ast::ScalarType::from(float_control.typ).size_of())), ast::Instruction::Setp(details, _) => details .flush_to_zero .map(|ftz| (ftz, details.typ.size_of())), ast::Instruction::SetpBool(details, _) => details .flush_to_zero .map(|ftz| (ftz, details.typ.size_of())), ast::Instruction::Abs(details, _) => details .flush_to_zero .map(|ftz| (ftz, details.typ.size_of())), ast::Instruction::Min(ast::MinMaxDetails::Float(float_control), _) | ast::Instruction::Max(ast::MinMaxDetails::Float(float_control), _) => float_control .flush_to_zero .map(|ftz| (ftz, ast::ScalarType::from(float_control.typ).size_of())), ast::Instruction::Rcp(details, _) => details .flush_to_zero .map(|ftz| (ftz, if details.is_f64 { 8 } else { 4 })), // Modifier .ftz can only be specified when either .dtype or .atype // is .f32 and applies only to single precision (.f32) inputs and results. ast::Instruction::Cvt( ast::CvtDetails::FloatFromFloat(ast::CvtDesc { flush_to_zero, .. }), _, ) | ast::Instruction::Cvt( ast::CvtDetails::IntFromFloat(ast::CvtDesc { flush_to_zero, .. }), _, ) => flush_to_zero.map(|ftz| (ftz, 4)), ast::Instruction::Div(ast::DivDetails::Float(details), _) => details .flush_to_zero .map(|ftz| (ftz, ast::ScalarType::from(details.typ).size_of())), ast::Instruction::Sqrt(details, _) => details .flush_to_zero .map(|ftz| (ftz, ast::ScalarType::from(details.typ).size_of())), ast::Instruction::Rsqrt(details, _) => Some(( details.flush_to_zero, ast::ScalarType::from(details.typ).size_of(), )), ast::Instruction::Neg(details, _) => details .flush_to_zero .map(|ftz| (ftz, details.typ.size_of())), ast::Instruction::Sin { flush_to_zero, .. } | ast::Instruction::Cos { flush_to_zero, .. } | ast::Instruction::Lg2 { flush_to_zero, .. } | ast::Instruction::Ex2 { flush_to_zero, .. 
} => { Some((*flush_to_zero, mem::size_of::() as u8)) } } } } type Arg2 = ast::Arg2; type Arg2St = ast::Arg2St; struct ConstantDefinition { pub dst: spirv::Word, pub typ: ast::ScalarType, pub value: ast::ImmediateValue, } struct BrachCondition { predicate: spirv::Word, if_true: spirv::Word, if_false: spirv::Word, } #[derive(Clone)] struct ImplicitConversion { src: spirv::Word, dst: spirv::Word, from: ast::Type, to: ast::Type, kind: ConversionKind, src_sema: ArgumentSemantics, dst_sema: ArgumentSemantics, } #[derive(PartialEq, Copy, Clone)] enum ConversionKind { Default, // zero-extend/chop/bitcast depending on types SignExtend, BitToPtr(ast::LdStateSpace), PtrToBit(ast::UIntType), PtrToPtr { spirv_ptr: bool }, } impl ast::PredAt { fn map_variable Result>( self, f: &mut F, ) -> Result, TranslateError> { let new_label = f(self.label)?; Ok(ast::PredAt { not: self.not, label: new_label, }) } } impl<'a> ast::Instruction> { fn map_variable Result>( self, f: &mut F, ) -> Result, TranslateError> { match self { ast::Instruction::Call(call) => { let call_inst = ast::CallInst { uniform: call.uniform, ret_params: call .ret_params .into_iter() .map(|p| f(p)) .collect::>()?, func: f(call.func)?, param_list: call .param_list .into_iter() .map(|p| p.map_variable(f)) .collect::>()?, }; Ok(ast::Instruction::Call(call_inst)) } i => i.map(f), } } } impl ast::Arg1 { fn map>( self, visitor: &mut V, t: Option<&ast::Type>, ) -> Result, TranslateError> { let new_src = visitor.id( ArgumentDescriptor { op: self.src, is_dst: false, sema: ArgumentSemantics::Default, }, t, )?; Ok(ast::Arg1 { src: new_src }) } } impl ast::Arg1Bar { fn map>( self, visitor: &mut V, ) -> Result, TranslateError> { let new_src = visitor.operand( ArgumentDescriptor { op: self.src, is_dst: false, sema: ArgumentSemantics::Default, }, &ast::Type::Scalar(ast::ScalarType::U32), )?; Ok(ast::Arg1Bar { src: new_src }) } } impl ast::Arg2 { fn map>( self, visitor: &mut V, t: &ast::Type, ) -> Result, TranslateError> { let new_dst = visitor.operand( ArgumentDescriptor { op: self.dst, is_dst: true, sema: ArgumentSemantics::Default, }, t, )?; let new_src = visitor.operand( ArgumentDescriptor { op: self.src, is_dst: false, sema: ArgumentSemantics::Default, }, t, )?; Ok(ast::Arg2 { dst: new_dst, src: new_src, }) } fn map_different_types>( self, visitor: &mut V, dst_t: &ast::Type, src_t: &ast::Type, ) -> Result, TranslateError> { let dst = visitor.operand( ArgumentDescriptor { op: self.dst, is_dst: true, sema: ArgumentSemantics::Default, }, dst_t, )?; let src = visitor.operand( ArgumentDescriptor { op: self.src, is_dst: false, sema: ArgumentSemantics::Default, }, src_t, )?; Ok(ast::Arg2 { dst, src }) } } impl ast::Arg2Ld { fn map>( self, visitor: &mut V, details: &ast::LdDetails, ) -> Result, TranslateError> { let dst = visitor.operand( ArgumentDescriptor { op: self.dst, is_dst: true, sema: ArgumentSemantics::DefaultRelaxed, }, &ast::Type::from(details.typ.clone()), )?; let is_logical_ptr = details.state_space == ast::LdStateSpace::Param || details.state_space == ast::LdStateSpace::Local; let src = visitor.operand( ArgumentDescriptor { op: self.src, is_dst: false, sema: if is_logical_ptr { ArgumentSemantics::RegisterPointer } else { ArgumentSemantics::PhysicalPointer }, }, &ast::Type::Pointer( ast::PointerType::from(details.typ.clone()), details.state_space, ), )?; Ok(ast::Arg2Ld { dst, src }) } } impl ast::Arg2St { fn map>( self, visitor: &mut V, details: &ast::StData, ) -> Result, TranslateError> { let is_logical_ptr = details.state_space == 
ast::StStateSpace::Param || details.state_space == ast::StStateSpace::Local; let src1 = visitor.operand( ArgumentDescriptor { op: self.src1, is_dst: false, sema: if is_logical_ptr { ArgumentSemantics::RegisterPointer } else { ArgumentSemantics::PhysicalPointer }, }, &ast::Type::Pointer( ast::PointerType::from(details.typ.clone()), details.state_space.to_ld_ss(), ), )?; let src2 = visitor.operand( ArgumentDescriptor { op: self.src2, is_dst: false, sema: ArgumentSemantics::DefaultRelaxed, }, &details.typ.clone().into(), )?; Ok(ast::Arg2St { src1, src2 }) } } impl ast::Arg2Mov { fn map>( self, visitor: &mut V, details: &ast::MovDetails, ) -> Result, TranslateError> { let dst = visitor.operand( ArgumentDescriptor { op: self.dst, is_dst: true, sema: ArgumentSemantics::Default, }, &details.typ.clone().into(), )?; let src = visitor.operand( ArgumentDescriptor { op: self.src, is_dst: false, sema: if details.src_is_address { ArgumentSemantics::Address } else { ArgumentSemantics::Default }, }, &details.typ.clone().into(), )?; Ok(ast::Arg2Mov { dst, src }) } } impl ast::Arg3 { fn map_non_shift>( self, visitor: &mut V, typ: &ast::Type, is_wide: bool, ) -> Result, TranslateError> { let wide_type = if is_wide { Some(typ.clone().widen()?) } else { None }; let dst = visitor.operand( ArgumentDescriptor { op: self.dst, is_dst: true, sema: ArgumentSemantics::Default, }, wide_type.as_ref().unwrap_or(typ), )?; let src1 = visitor.operand( ArgumentDescriptor { op: self.src1, is_dst: false, sema: ArgumentSemantics::Default, }, typ, )?; let src2 = visitor.operand( ArgumentDescriptor { op: self.src2, is_dst: false, sema: ArgumentSemantics::Default, }, typ, )?; Ok(ast::Arg3 { dst, src1, src2 }) } fn map_shift>( self, visitor: &mut V, t: &ast::Type, ) -> Result, TranslateError> { let dst = visitor.operand( ArgumentDescriptor { op: self.dst, is_dst: true, sema: ArgumentSemantics::Default, }, t, )?; let src1 = visitor.operand( ArgumentDescriptor { op: self.src1, is_dst: false, sema: ArgumentSemantics::Default, }, t, )?; let src2 = visitor.operand( ArgumentDescriptor { op: self.src2, is_dst: false, sema: ArgumentSemantics::Default, }, &ast::Type::Scalar(ast::ScalarType::U32), )?; Ok(ast::Arg3 { dst, src1, src2 }) } fn map_atom>( self, visitor: &mut V, t: ast::ScalarType, state_space: ast::AtomSpace, ) -> Result, TranslateError> { let scalar_type = ast::ScalarType::from(t); let dst = visitor.operand( ArgumentDescriptor { op: self.dst, is_dst: true, sema: ArgumentSemantics::Default, }, &ast::Type::Scalar(scalar_type), )?; let src1 = visitor.operand( ArgumentDescriptor { op: self.src1, is_dst: false, sema: ArgumentSemantics::PhysicalPointer, }, &ast::Type::Pointer( ast::PointerType::Scalar(scalar_type), state_space.to_ld_ss(), ), )?; let src2 = visitor.operand( ArgumentDescriptor { op: self.src2, is_dst: false, sema: ArgumentSemantics::Default, }, &ast::Type::Scalar(scalar_type), )?; Ok(ast::Arg3 { dst, src1, src2 }) } } impl ast::Arg4 { fn map>( self, visitor: &mut V, t: &ast::Type, is_wide: bool, ) -> Result, TranslateError> { let wide_type = if is_wide { Some(t.clone().widen()?) 
} else { None }; let dst = visitor.operand( ArgumentDescriptor { op: self.dst, is_dst: true, sema: ArgumentSemantics::Default, }, wide_type.as_ref().unwrap_or(t), )?; let src1 = visitor.operand( ArgumentDescriptor { op: self.src1, is_dst: false, sema: ArgumentSemantics::Default, }, t, )?; let src2 = visitor.operand( ArgumentDescriptor { op: self.src2, is_dst: false, sema: ArgumentSemantics::Default, }, t, )?; let src3 = visitor.operand( ArgumentDescriptor { op: self.src3, is_dst: false, sema: ArgumentSemantics::Default, }, t, )?; Ok(ast::Arg4 { dst, src1, src2, src3, }) } fn map_selp>( self, visitor: &mut V, t: ast::SelpType, ) -> Result, TranslateError> { let dst = visitor.operand( ArgumentDescriptor { op: self.dst, is_dst: true, sema: ArgumentSemantics::Default, }, &ast::Type::Scalar(t.into()), )?; let src1 = visitor.operand( ArgumentDescriptor { op: self.src1, is_dst: false, sema: ArgumentSemantics::Default, }, &ast::Type::Scalar(t.into()), )?; let src2 = visitor.operand( ArgumentDescriptor { op: self.src2, is_dst: false, sema: ArgumentSemantics::Default, }, &ast::Type::Scalar(t.into()), )?; let src3 = visitor.operand( ArgumentDescriptor { op: self.src3, is_dst: false, sema: ArgumentSemantics::Default, }, &ast::Type::Scalar(ast::ScalarType::Pred), )?; Ok(ast::Arg4 { dst, src1, src2, src3, }) } fn map_atom>( self, visitor: &mut V, t: ast::BitType, state_space: ast::AtomSpace, ) -> Result, TranslateError> { let scalar_type = ast::ScalarType::from(t); let dst = visitor.operand( ArgumentDescriptor { op: self.dst, is_dst: true, sema: ArgumentSemantics::Default, }, &ast::Type::Scalar(scalar_type), )?; let src1 = visitor.operand( ArgumentDescriptor { op: self.src1, is_dst: false, sema: ArgumentSemantics::PhysicalPointer, }, &ast::Type::Pointer( ast::PointerType::Scalar(scalar_type), state_space.to_ld_ss(), ), )?; let src2 = visitor.operand( ArgumentDescriptor { op: self.src2, is_dst: false, sema: ArgumentSemantics::Default, }, &ast::Type::Scalar(scalar_type), )?; let src3 = visitor.operand( ArgumentDescriptor { op: self.src3, is_dst: false, sema: ArgumentSemantics::Default, }, &ast::Type::Scalar(scalar_type), )?; Ok(ast::Arg4 { dst, src1, src2, src3, }) } fn map_bfe>( self, visitor: &mut V, typ: &ast::Type, ) -> Result, TranslateError> { let dst = visitor.operand( ArgumentDescriptor { op: self.dst, is_dst: true, sema: ArgumentSemantics::Default, }, typ, )?; let src1 = visitor.operand( ArgumentDescriptor { op: self.src1, is_dst: false, sema: ArgumentSemantics::Default, }, typ, )?; let u32_type = ast::Type::Scalar(ast::ScalarType::U32); let src2 = visitor.operand( ArgumentDescriptor { op: self.src2, is_dst: false, sema: ArgumentSemantics::Default, }, &u32_type, )?; let src3 = visitor.operand( ArgumentDescriptor { op: self.src3, is_dst: false, sema: ArgumentSemantics::Default, }, &u32_type, )?; Ok(ast::Arg4 { dst, src1, src2, src3, }) } } impl ast::Arg4Setp { fn map>( self, visitor: &mut V, t: &ast::Type, ) -> Result, TranslateError> { let dst1 = visitor.id( ArgumentDescriptor { op: self.dst1, is_dst: true, sema: ArgumentSemantics::Default, }, Some(&ast::Type::Scalar(ast::ScalarType::Pred)), )?; let dst2 = self .dst2 .map(|dst2| { visitor.id( ArgumentDescriptor { op: dst2, is_dst: true, sema: ArgumentSemantics::Default, }, Some(&ast::Type::Scalar(ast::ScalarType::Pred)), ) }) .transpose()?; let src1 = visitor.operand( ArgumentDescriptor { op: self.src1, is_dst: false, sema: ArgumentSemantics::Default, }, t, )?; let src2 = visitor.operand( ArgumentDescriptor { op: self.src2, is_dst: false, sema: 
ArgumentSemantics::Default, }, t, )?; Ok(ast::Arg4Setp { dst1, dst2, src1, src2, }) } } impl ast::Arg5Setp { fn map>( self, visitor: &mut V, t: &ast::Type, ) -> Result, TranslateError> { let dst1 = visitor.id( ArgumentDescriptor { op: self.dst1, is_dst: true, sema: ArgumentSemantics::Default, }, Some(&ast::Type::Scalar(ast::ScalarType::Pred)), )?; let dst2 = self .dst2 .map(|dst2| { visitor.id( ArgumentDescriptor { op: dst2, is_dst: true, sema: ArgumentSemantics::Default, }, Some(&ast::Type::Scalar(ast::ScalarType::Pred)), ) }) .transpose()?; let src1 = visitor.operand( ArgumentDescriptor { op: self.src1, is_dst: false, sema: ArgumentSemantics::Default, }, t, )?; let src2 = visitor.operand( ArgumentDescriptor { op: self.src2, is_dst: false, sema: ArgumentSemantics::Default, }, t, )?; let src3 = visitor.operand( ArgumentDescriptor { op: self.src3, is_dst: false, sema: ArgumentSemantics::Default, }, &ast::Type::Scalar(ast::ScalarType::Pred), )?; Ok(ast::Arg5Setp { dst1, dst2, src1, src2, src3, }) } } impl ast::Operand { fn map_variable Result>( self, f: &mut F, ) -> Result, TranslateError> { Ok(match self { ast::Operand::Reg(reg) => ast::Operand::Reg(f(reg)?), ast::Operand::RegOffset(reg, offset) => ast::Operand::RegOffset(f(reg)?, offset), ast::Operand::Imm(x) => ast::Operand::Imm(x), ast::Operand::VecMember(reg, idx) => ast::Operand::VecMember(f(reg)?, idx), ast::Operand::VecPack(vec) => { ast::Operand::VecPack(vec.into_iter().map(f).collect::>()?) } }) } } impl ast::Operand { fn unwrap_reg(&self) -> Result { match self { ast::Operand::Reg(reg) => Ok(*reg), _ => Err(error_unreachable()), } } } impl ast::StStateSpace { fn to_ld_ss(self) -> ast::LdStateSpace { match self { ast::StStateSpace::Generic => ast::LdStateSpace::Generic, ast::StStateSpace::Global => ast::LdStateSpace::Global, ast::StStateSpace::Local => ast::LdStateSpace::Local, ast::StStateSpace::Param => ast::LdStateSpace::Param, ast::StStateSpace::Shared => ast::LdStateSpace::Shared, } } } #[derive(Clone, Copy, PartialEq, Eq)] enum ScalarKind { Bit, Unsigned, Signed, Float, Float2, Pred, } impl ast::ScalarType { fn kind(self) -> ScalarKind { match self { ast::ScalarType::U8 => ScalarKind::Unsigned, ast::ScalarType::U16 => ScalarKind::Unsigned, ast::ScalarType::U32 => ScalarKind::Unsigned, ast::ScalarType::U64 => ScalarKind::Unsigned, ast::ScalarType::S8 => ScalarKind::Signed, ast::ScalarType::S16 => ScalarKind::Signed, ast::ScalarType::S32 => ScalarKind::Signed, ast::ScalarType::S64 => ScalarKind::Signed, ast::ScalarType::B8 => ScalarKind::Bit, ast::ScalarType::B16 => ScalarKind::Bit, ast::ScalarType::B32 => ScalarKind::Bit, ast::ScalarType::B64 => ScalarKind::Bit, ast::ScalarType::F16 => ScalarKind::Float, ast::ScalarType::F32 => ScalarKind::Float, ast::ScalarType::F64 => ScalarKind::Float, ast::ScalarType::F16x2 => ScalarKind::Float2, ast::ScalarType::Pred => ScalarKind::Pred, } } fn from_parts(width: u8, kind: ScalarKind) -> Self { match kind { ScalarKind::Float => match width { 2 => ast::ScalarType::F16, 4 => ast::ScalarType::F32, 8 => ast::ScalarType::F64, _ => unreachable!(), }, ScalarKind::Bit => match width { 1 => ast::ScalarType::B8, 2 => ast::ScalarType::B16, 4 => ast::ScalarType::B32, 8 => ast::ScalarType::B64, _ => unreachable!(), }, ScalarKind::Signed => match width { 1 => ast::ScalarType::S8, 2 => ast::ScalarType::S16, 4 => ast::ScalarType::S32, 8 => ast::ScalarType::S64, _ => unreachable!(), }, ScalarKind::Unsigned => match width { 1 => ast::ScalarType::U8, 2 => ast::ScalarType::U16, 4 => ast::ScalarType::U32, 8 => 
ast::ScalarType::U64, _ => unreachable!(), }, ScalarKind::Float2 => match width { 4 => ast::ScalarType::F16x2, _ => unreachable!(), }, ScalarKind::Pred => ast::ScalarType::Pred, } } } impl ast::BooleanType { fn to_type(self) -> ast::Type { match self { ast::BooleanType::Pred => ast::Type::Scalar(ast::ScalarType::Pred), ast::BooleanType::B16 => ast::Type::Scalar(ast::ScalarType::B16), ast::BooleanType::B32 => ast::Type::Scalar(ast::ScalarType::B32), ast::BooleanType::B64 => ast::Type::Scalar(ast::ScalarType::B64), } } } impl ast::ShlType { fn to_type(self) -> ast::Type { match self { ast::ShlType::B16 => ast::Type::Scalar(ast::ScalarType::B16), ast::ShlType::B32 => ast::Type::Scalar(ast::ScalarType::B32), ast::ShlType::B64 => ast::Type::Scalar(ast::ScalarType::B64), } } } impl ast::ShrType { fn signed(&self) -> bool { match self { ast::ShrType::S16 | ast::ShrType::S32 | ast::ShrType::S64 => true, _ => false, } } } impl ast::ArithDetails { fn get_type(&self) -> ast::Type { ast::Type::Scalar(match self { ast::ArithDetails::Unsigned(t) => (*t).into(), ast::ArithDetails::Signed(d) => d.typ.into(), ast::ArithDetails::Float(d) => d.typ.into(), }) } } impl ast::MulDetails { fn get_type(&self) -> ast::Type { ast::Type::Scalar(match self { ast::MulDetails::Unsigned(d) => d.typ.into(), ast::MulDetails::Signed(d) => d.typ.into(), ast::MulDetails::Float(d) => d.typ.into(), }) } } impl ast::MinMaxDetails { fn get_type(&self) -> ast::Type { ast::Type::Scalar(match self { ast::MinMaxDetails::Signed(t) => (*t).into(), ast::MinMaxDetails::Unsigned(t) => (*t).into(), ast::MinMaxDetails::Float(d) => d.typ.into(), }) } } impl ast::DivDetails { fn get_type(&self) -> ast::Type { ast::Type::Scalar(match self { ast::DivDetails::Unsigned(t) => (*t).into(), ast::DivDetails::Signed(t) => (*t).into(), ast::DivDetails::Float(d) => d.typ.into(), }) } } impl ast::AtomInnerDetails { fn get_type(&self) -> ast::ScalarType { match self { ast::AtomInnerDetails::Bit { typ, .. } => (*typ).into(), ast::AtomInnerDetails::Unsigned { typ, .. } => (*typ).into(), ast::AtomInnerDetails::Signed { typ, .. } => (*typ).into(), ast::AtomInnerDetails::Float { typ, .. } => (*typ).into(), } } } impl ast::SIntType { fn from_size(width: u8) -> Self { match width { 1 => ast::SIntType::S8, 2 => ast::SIntType::S16, 4 => ast::SIntType::S32, 8 => ast::SIntType::S64, _ => unreachable!(), } } } impl ast::UIntType { fn from_size(width: u8) -> Self { match width { 1 => ast::UIntType::U8, 2 => ast::UIntType::U16, 4 => ast::UIntType::U32, 8 => ast::UIntType::U64, _ => unreachable!(), } } } impl ast::LdStateSpace { fn to_spirv(self) -> spirv::StorageClass { match self { ast::LdStateSpace::Const => spirv::StorageClass::UniformConstant, ast::LdStateSpace::Generic => spirv::StorageClass::Generic, ast::LdStateSpace::Global => spirv::StorageClass::CrossWorkgroup, ast::LdStateSpace::Local => spirv::StorageClass::Function, ast::LdStateSpace::Shared => spirv::StorageClass::Workgroup, ast::LdStateSpace::Param => spirv::StorageClass::Function, } } } impl From for ast::VariableType { fn from(t: ast::FnArgumentType) -> Self { match t { ast::FnArgumentType::Reg(t) => ast::VariableType::Reg(t), ast::FnArgumentType::Param(t) => ast::VariableType::Param(t), ast::FnArgumentType::Shared => todo!(), } } } impl ast::Operand { fn underlying(&self) -> Option<&T> { match self { ast::Operand::Reg(r) | ast::Operand::RegOffset(r, _) => Some(r), ast::Operand::Imm(_) => None, ast::Operand::VecMember(reg, _) => Some(reg), ast::Operand::VecPack(..) 
=> None, } } } impl ast::MulDetails { fn is_wide(&self) -> bool { match self { ast::MulDetails::Unsigned(d) => d.control == ast::MulIntControl::Wide, ast::MulDetails::Signed(d) => d.control == ast::MulIntControl::Wide, ast::MulDetails::Float(_) => false, } } } impl ast::AtomSpace { fn to_ld_ss(self) -> ast::LdStateSpace { match self { ast::AtomSpace::Generic => ast::LdStateSpace::Generic, ast::AtomSpace::Global => ast::LdStateSpace::Global, ast::AtomSpace::Shared => ast::LdStateSpace::Shared, } } } impl ast::MemScope { fn to_spirv(self) -> spirv::Scope { match self { ast::MemScope::Cta => spirv::Scope::Workgroup, ast::MemScope::Gpu => spirv::Scope::Device, ast::MemScope::Sys => spirv::Scope::CrossDevice, } } } impl ast::AtomSemantics { fn to_spirv(self) -> spirv::MemorySemantics { match self { ast::AtomSemantics::Relaxed => spirv::MemorySemantics::RELAXED, ast::AtomSemantics::Acquire => spirv::MemorySemantics::ACQUIRE, ast::AtomSemantics::Release => spirv::MemorySemantics::RELEASE, ast::AtomSemantics::AcquireRelease => spirv::MemorySemantics::ACQUIRE_RELEASE, } } } impl ast::FnArgumentType { fn semantics(&self) -> ArgumentSemantics { match self { ast::FnArgumentType::Reg(_) => ArgumentSemantics::Default, ast::FnArgumentType::Param(_) => ArgumentSemantics::RegisterPointer, ast::FnArgumentType::Shared => ArgumentSemantics::PhysicalPointer, } } } fn bitcast_register_pointer( operand_type: &ast::Type, instr_type: &ast::Type, ss: Option, ) -> Result, TranslateError> { bitcast_physical_pointer(operand_type, instr_type, ss) } fn bitcast_physical_pointer( operand_type: &ast::Type, instr_type: &ast::Type, ss: Option, ) -> Result, TranslateError> { match operand_type { // array decays to a pointer ast::Type::Array(op_scalar_t, _) => { if let ast::Type::Pointer(instr_scalar_t, instr_space) = instr_type { if ss == Some(*instr_space) { if ast::Type::Scalar(*op_scalar_t) == ast::Type::from(instr_scalar_t.clone()) { Ok(None) } else { Ok(Some(ConversionKind::PtrToPtr { spirv_ptr: false })) } } else { if ss == Some(ast::LdStateSpace::Generic) || *instr_space == ast::LdStateSpace::Generic { Ok(Some(ConversionKind::PtrToPtr { spirv_ptr: false })) } else { Err(TranslateError::MismatchedType) } } } else { Err(TranslateError::MismatchedType) } } ast::Type::Scalar(ast::ScalarType::B64) | ast::Type::Scalar(ast::ScalarType::U64) | ast::Type::Scalar(ast::ScalarType::S64) => { if let Some(space) = ss { Ok(Some(ConversionKind::BitToPtr(space))) } else { Err(error_unreachable()) } } ast::Type::Scalar(ast::ScalarType::B32) | ast::Type::Scalar(ast::ScalarType::U32) | ast::Type::Scalar(ast::ScalarType::S32) => match ss { Some(ast::LdStateSpace::Shared) | Some(ast::LdStateSpace::Generic) | Some(ast::LdStateSpace::Param) | Some(ast::LdStateSpace::Local) => { Ok(Some(ConversionKind::BitToPtr(ast::LdStateSpace::Shared))) } _ => Err(TranslateError::MismatchedType), }, ast::Type::Pointer(op_scalar_t, op_space) => { if let ast::Type::Pointer(instr_scalar_t, instr_space) = instr_type { if op_space == instr_space { if op_scalar_t == instr_scalar_t { Ok(None) } else { Ok(Some(ConversionKind::PtrToPtr { spirv_ptr: false })) } } else { if *op_space == ast::LdStateSpace::Generic || *instr_space == ast::LdStateSpace::Generic { Ok(Some(ConversionKind::PtrToPtr { spirv_ptr: false })) } else { Err(TranslateError::MismatchedType) } } } else { Err(TranslateError::MismatchedType) } } _ => Err(TranslateError::MismatchedType), } } fn force_bitcast_ptr_to_bit( _: &ast::Type, instr_type: &ast::Type, _: Option, ) -> Result, TranslateError> { // 
fn force_bitcast_ptr_to_bit(
    _: &ast::Type,
    instr_type: &ast::Type,
    _: Option<ast::LdStateSpace>,
) -> Result<Option<ConversionKind>, TranslateError> {
    // TODO: verify this on f32, u16 and the like
    if let ast::Type::Scalar(scalar_t) = instr_type {
        if let Ok(int_type) = (*scalar_t).try_into() {
            return Ok(Some(ConversionKind::PtrToBit(int_type)));
        }
    }
    Err(TranslateError::MismatchedType)
}

fn should_bitcast(instr: &ast::Type, operand: &ast::Type) -> bool {
    match (instr, operand) {
        (ast::Type::Scalar(inst), ast::Type::Scalar(operand)) => {
            if inst.size_of() != operand.size_of() {
                return false;
            }
            match inst.kind() {
                ScalarKind::Bit => operand.kind() != ScalarKind::Bit,
                ScalarKind::Float => operand.kind() == ScalarKind::Bit,
                ScalarKind::Signed => {
                    operand.kind() == ScalarKind::Bit || operand.kind() == ScalarKind::Unsigned
                }
                ScalarKind::Unsigned => {
                    operand.kind() == ScalarKind::Bit || operand.kind() == ScalarKind::Signed
                }
                ScalarKind::Float2 => false,
                ScalarKind::Pred => false,
            }
        }
        (ast::Type::Vector(inst, _), ast::Type::Vector(operand, _))
        | (ast::Type::Array(inst, _), ast::Type::Array(operand, _)) => {
            should_bitcast(&ast::Type::Scalar(*inst), &ast::Type::Scalar(*operand))
        }
        _ => false,
    }
}

fn should_bitcast_packed(
    operand: &ast::Type,
    instr: &ast::Type,
    ss: Option<ast::LdStateSpace>,
) -> Result<Option<ConversionKind>, TranslateError> {
    if let (ast::Type::Vector(vec_underlying_type, vec_len), ast::Type::Scalar(scalar)) =
        (operand, instr)
    {
        if scalar.kind() == ScalarKind::Bit
            && scalar.size_of() == (vec_underlying_type.size_of() * vec_len)
        {
            return Ok(Some(ConversionKind::Default));
        }
    }
    should_bitcast_wrapper(operand, instr, ss)
}

fn should_bitcast_wrapper(
    operand: &ast::Type,
    instr: &ast::Type,
    _: Option<ast::LdStateSpace>,
) -> Result<Option<ConversionKind>, TranslateError> {
    if instr == operand {
        return Ok(None);
    }
    if should_bitcast(instr, operand) {
        Ok(Some(ConversionKind::Default))
    } else {
        Err(TranslateError::MismatchedType)
    }
}

fn should_convert_relaxed_src_wrapper(
    src_type: &ast::Type,
    instr_type: &ast::Type,
    _: Option<ast::LdStateSpace>,
) -> Result<Option<ConversionKind>, TranslateError> {
    if src_type == instr_type {
        return Ok(None);
    }
    match should_convert_relaxed_src(src_type, instr_type) {
        conv @ Some(_) => Ok(conv),
        None => Err(TranslateError::MismatchedType),
    }
}

// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-source-operands
fn should_convert_relaxed_src(
    src_type: &ast::Type,
    instr_type: &ast::Type,
) -> Option<ConversionKind> {
    if src_type == instr_type {
        return None;
    }
    match (src_type, instr_type) {
        (ast::Type::Scalar(src_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() {
            ScalarKind::Bit => {
                if instr_type.size_of() <= src_type.size_of() {
                    Some(ConversionKind::Default)
                } else {
                    None
                }
            }
            ScalarKind::Signed | ScalarKind::Unsigned => {
                if instr_type.size_of() <= src_type.size_of()
                    && src_type.kind() != ScalarKind::Float
                {
                    Some(ConversionKind::Default)
                } else {
                    None
                }
            }
            ScalarKind::Float => {
                if instr_type.size_of() <= src_type.size_of() && src_type.kind() == ScalarKind::Bit
                {
                    Some(ConversionKind::Default)
                } else {
                    None
                }
            }
            ScalarKind::Float2 => todo!(),
            ScalarKind::Pred => None,
        },
        (ast::Type::Vector(dst_type, _), ast::Type::Vector(instr_type, _))
        | (ast::Type::Array(dst_type, _), ast::Type::Array(instr_type, _)) => {
            should_convert_relaxed_src(
                &ast::Type::Scalar(*dst_type),
                &ast::Type::Scalar(*instr_type),
            )
        }
        _ => None,
    }
}

fn should_convert_relaxed_dst_wrapper(
    dst_type: &ast::Type,
    instr_type: &ast::Type,
    _: Option<ast::LdStateSpace>,
) -> Result<Option<ConversionKind>, TranslateError> {
    if dst_type == instr_type {
        return Ok(None);
    }
    match should_convert_relaxed_dst(dst_type, instr_type) {
        conv @ Some(_) => Ok(conv),
        None => Err(TranslateError::MismatchedType),
    }
}
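// Worked examples of the source-operand rules above (derived from the code and
// kept here as documentation only):
// * `should_bitcast_packed(v2.f16, b32)` succeeds with the default bitcast,
//   because a bit-typed scalar may absorb a vector of the same total width;
// * `should_convert_relaxed_src(u32, b16)` yields `ConversionKind::Default`
//   (the wider source is chopped to the instruction width);
// * `should_convert_relaxed_src(f32, s32)` yields `None`: float sources are
//   never relaxed into integer instruction types, and no widening is allowed.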
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#operand-size-exceeding-instruction-type-size__relaxed-type-checking-rules-destination-operands
fn should_convert_relaxed_dst(
    dst_type: &ast::Type,
    instr_type: &ast::Type,
) -> Option<ConversionKind> {
    if dst_type == instr_type {
        return None;
    }
    match (dst_type, instr_type) {
        (ast::Type::Scalar(dst_type), ast::Type::Scalar(instr_type)) => match instr_type.kind() {
            ScalarKind::Bit => {
                if instr_type.size_of() <= dst_type.size_of() {
                    Some(ConversionKind::Default)
                } else {
                    None
                }
            }
            ScalarKind::Signed => {
                if dst_type.kind() != ScalarKind::Float {
                    if instr_type.size_of() == dst_type.size_of() {
                        Some(ConversionKind::Default)
                    } else if instr_type.size_of() < dst_type.size_of() {
                        Some(ConversionKind::SignExtend)
                    } else {
                        None
                    }
                } else {
                    None
                }
            }
            ScalarKind::Unsigned => {
                if instr_type.size_of() <= dst_type.size_of()
                    && dst_type.kind() != ScalarKind::Float
                {
                    Some(ConversionKind::Default)
                } else {
                    None
                }
            }
            ScalarKind::Float => {
                if instr_type.size_of() <= dst_type.size_of() && dst_type.kind() == ScalarKind::Bit
                {
                    Some(ConversionKind::Default)
                } else {
                    None
                }
            }
            ScalarKind::Float2 => todo!(),
            ScalarKind::Pred => None,
        },
        (ast::Type::Vector(dst_type, _), ast::Type::Vector(instr_type, _))
        | (ast::Type::Array(dst_type, _), ast::Type::Array(instr_type, _)) => {
            should_convert_relaxed_dst(
                &ast::Type::Scalar(*dst_type),
                &ast::Type::Scalar(*instr_type),
            )
        }
        _ => None,
    }
}

impl<'a> ast::MethodDecl<'a, &'a str> {
    fn name(&self) -> &'a str {
        match self {
            ast::MethodDecl::Kernel { name, .. } => name,
            ast::MethodDecl::Func(_, name, _) => name,
        }
    }
}

struct SpirvMethodDecl<'input> {
    input: Vec<ast::Variable<ast::Type, spirv::Word>>,
    output: Vec<ast::Variable<ast::Type, spirv::Word>>,
    name: MethodName<'input>,
    uses_shared_mem: bool,
}
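// `SpirvMethodDecl::new` below flattens both PTX method forms into one
// signature: kernel arguments become param/shared inputs with no outputs,
// while for a `.func` the `.param` output arguments are folded into the input
// list (they behave like pointers on the SPIR-V side, cf. `semantics()` above)
// and only the remaining non-`.param` outputs are kept as true return values.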
impl<'input> SpirvMethodDecl<'input> {
    fn new(ast_decl: &ast::MethodDecl<'input, spirv::Word>) -> Self {
        let (input, output) = match ast_decl {
            ast::MethodDecl::Kernel { in_args, .. } => {
                let spirv_input = in_args
                    .iter()
                    .map(|var| {
                        let v_type = match &var.v_type {
                            ast::KernelArgumentType::Normal(t) => {
                                ast::FnArgumentType::Param(t.clone())
                            }
                            ast::KernelArgumentType::Shared => ast::FnArgumentType::Shared,
                        };
                        ast::Variable {
                            name: var.name,
                            align: var.align,
                            v_type: v_type.to_kernel_type(),
                            array_init: var.array_init.clone(),
                        }
                    })
                    .collect();
                (spirv_input, Vec::new())
            }
            ast::MethodDecl::Func(out_args, _, in_args) => {
                let (param_output, non_param_output): (Vec<_>, Vec<_>) =
                    out_args.iter().partition(|var| var.v_type.is_param());
                let spirv_output = non_param_output
                    .into_iter()
                    .cloned()
                    .map(|var| ast::Variable {
                        name: var.name,
                        align: var.align,
                        v_type: var.v_type.to_func_type(),
                        array_init: var.array_init.clone(),
                    })
                    .collect();
                let spirv_input = param_output
                    .into_iter()
                    .cloned()
                    .chain(in_args.iter().cloned())
                    .map(|var| ast::Variable {
                        name: var.name,
                        align: var.align,
                        v_type: var.v_type.to_func_type(),
                        array_init: var.array_init.clone(),
                    })
                    .collect();
                (spirv_input, spirv_output)
            }
        };
        SpirvMethodDecl {
            input,
            output,
            name: MethodName::new(ast_decl),
            uses_shared_mem: false,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::ast;

    static SCALAR_TYPES: [ast::ScalarType; 15] = [
        ast::ScalarType::B8,
        ast::ScalarType::B16,
        ast::ScalarType::B32,
        ast::ScalarType::B64,
        ast::ScalarType::S8,
        ast::ScalarType::S16,
        ast::ScalarType::S32,
        ast::ScalarType::S64,
        ast::ScalarType::U8,
        ast::ScalarType::U16,
        ast::ScalarType::U32,
        ast::ScalarType::U64,
        ast::ScalarType::F16,
        ast::ScalarType::F32,
        ast::ScalarType::F64,
    ];

    static RELAXED_SRC_CONVERSION_TABLE: &'static str =
        "b8 - chop chop chop - chop chop chop - chop chop chop chop chop chop
        b16 inv - chop chop inv - chop chop inv - chop chop - chop chop
        b32 inv inv - chop inv inv - chop inv inv - chop inv - chop
        b64 inv inv inv - inv inv inv - inv inv inv - inv inv -
        s8 - chop chop chop - chop chop chop - chop chop chop inv inv inv
        s16 inv - chop chop inv - chop chop inv - chop chop inv inv inv
        s32 inv inv - chop inv inv - chop inv inv - chop inv inv inv
        s64 inv inv inv - inv inv inv - inv inv inv - inv inv inv
        u8 - chop chop chop - chop chop chop - chop chop chop inv inv inv
        u16 inv - chop chop inv - chop chop inv - chop chop inv inv inv
        u32 inv inv - chop inv inv - chop inv inv - chop inv inv inv
        u64 inv inv inv - inv inv inv - inv inv inv - inv inv inv
        f16 inv - chop chop inv inv inv inv inv inv inv inv - inv inv
        f32 inv inv - chop inv inv inv inv inv inv inv inv inv - inv
        f64 inv inv inv - inv inv inv inv inv inv inv inv inv inv -";

    static RELAXED_DST_CONVERSION_TABLE: &'static str =
        "b8 - zext zext zext - zext zext zext - zext zext zext zext zext zext
        b16 inv - zext zext inv - zext zext inv - zext zext - zext zext
        b32 inv inv - zext inv inv - zext inv inv - zext inv - zext
        b64 inv inv inv - inv inv inv - inv inv inv - inv inv -
        s8 - sext sext sext - sext sext sext - sext sext sext inv inv inv
        s16 inv - sext sext inv - sext sext inv - sext sext inv inv inv
        s32 inv inv - sext inv inv - sext inv inv - sext inv inv inv
        s64 inv inv inv - inv inv inv - inv inv inv - inv inv inv
        u8 - zext zext zext - zext zext zext - zext zext zext inv inv inv
        u16 inv - zext zext inv - zext zext inv - zext zext inv inv inv
        u32 inv inv - zext inv inv - zext inv inv - zext inv inv inv
        u64 inv inv inv - inv inv inv - inv inv inv - inv inv inv
        f16 inv - zext zext inv inv inv inv inv inv inv inv - inv inv
        f32 inv inv - zext inv inv inv inv inv inv inv inv inv - inv
        f64 inv inv inv - inv inv inv inv inv inv inv inv inv inv -";
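    // In both tables above, rows are instruction types and columns are operand
    // types, listed in `SCALAR_TYPES` order; the leading token of each row is
    // just a label and is skipped by the parser. "-" stands for a plain
    // bitcast, "inv" for an invalid combination, and "chop"/"zext"/"sext" for
    // the implicit PTX conversion; `table_entry_to_conversion` collapses
    // chop/zext into `ConversionKind::Default` and keeps sign-extension
    // distinct.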
    fn table_entry_to_conversion(entry: &'static str) -> Option<ConversionKind> {
        match entry {
            "-" => Some(ConversionKind::Default),
            "inv" => None,
            "zext" => Some(ConversionKind::Default),
            "chop" => Some(ConversionKind::Default),
            "sext" => Some(ConversionKind::SignExtend),
            _ => unreachable!(),
        }
    }

    fn parse_conversion_table(table: &'static str) -> Vec<Vec<Option<ConversionKind>>> {
        table
            .lines()
            .map(|line| {
                line.split_ascii_whitespace()
                    .skip(1)
                    .map(table_entry_to_conversion)
                    .collect::<Vec<_>>()
            })
            .collect::<Vec<_>>()
    }

    fn assert_conversion_table<F: Fn(&ast::Type, &ast::Type) -> Option<ConversionKind>>(
        table: &'static str,
        f: F,
    ) {
        let conv_table = parse_conversion_table(table);
        for (instr_idx, instr_type) in SCALAR_TYPES.iter().enumerate() {
            for (op_idx, op_type) in SCALAR_TYPES.iter().enumerate() {
                let conversion = f(
                    &ast::Type::Scalar(*op_type),
                    &ast::Type::Scalar(*instr_type),
                );
                if instr_idx == op_idx {
                    assert!(conversion == None);
                } else {
                    assert!(conversion == conv_table[instr_idx][op_idx]);
                }
            }
        }
    }

    #[test]
    fn should_convert_relaxed_src_all_combinations() {
        assert_conversion_table(RELAXED_SRC_CONVERSION_TABLE, should_convert_relaxed_src);
    }

    #[test]
    fn should_convert_relaxed_dst_all_combinations() {
        assert_conversion_table(RELAXED_DST_CONVERSION_TABLE, should_convert_relaxed_dst);
    }
}
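// A few spot checks for the pointer/packed bitcast helpers defined earlier in
// this file. This is an illustrative sketch rather than exhaustive coverage:
// the module name and the chosen cases are new, but every expected result
// follows directly from the match arms of `bitcast_physical_pointer` and
// `should_bitcast_packed`.
#[cfg(test)]
mod pointer_conversion_tests {
    use super::*;
    use crate::ast;

    #[test]
    fn physical_pointer_bitcasts() {
        let global_u8 = ast::Type::Pointer(
            ast::PointerType::Scalar(ast::ScalarType::U8),
            ast::LdStateSpace::Global,
        );
        let global_b8 = ast::Type::Pointer(
            ast::PointerType::Scalar(ast::ScalarType::B8),
            ast::LdStateSpace::Global,
        );
        // Identical pointer types need no conversion at all.
        assert!(bitcast_physical_pointer(&global_u8, &global_u8, None).unwrap() == None);
        // Same state space, different pointee: a plain pointer-to-pointer cast.
        assert!(
            bitcast_physical_pointer(&global_u8, &global_b8, None).unwrap()
                == Some(ConversionKind::PtrToPtr { spirv_ptr: false })
        );
        // A 64-bit integer operand is reinterpreted as a pointer into `ss`.
        assert!(
            bitcast_physical_pointer(
                &ast::Type::Scalar(ast::ScalarType::B64),
                &global_u8,
                Some(ast::LdStateSpace::Global)
            )
            .unwrap()
                == Some(ConversionKind::BitToPtr(ast::LdStateSpace::Global))
        );
    }

    #[test]
    fn packed_vector_bitcasts_into_wide_scalar() {
        // A b32 instruction type accepts a v2.f16 operand via the default
        // bitcast because the total widths match (2 * 16 bits == 32 bits).
        assert!(
            should_bitcast_packed(
                &ast::Type::Vector(ast::ScalarType::F16, 2),
                &ast::Type::Scalar(ast::ScalarType::B32),
                None
            )
            .unwrap()
                == Some(ConversionKind::Default)
        );
    }
}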