From 3870a96592c6a93d3a68391f6cbaecd9c7a2bc97 Mon Sep 17 00:00:00 2001 From: Andrzej Janik Date: Wed, 16 Oct 2024 03:15:48 +0200 Subject: Re-enable all failing PTX tests (#277) Additionally remove unused compilation paths --- ptx/src/pass/emit_spirv.rs | 2762 -------------------------------------------- 1 file changed, 2762 deletions(-) delete mode 100644 ptx/src/pass/emit_spirv.rs (limited to 'ptx/src/pass/emit_spirv.rs') diff --git a/ptx/src/pass/emit_spirv.rs b/ptx/src/pass/emit_spirv.rs deleted file mode 100644 index 120a477..0000000 --- a/ptx/src/pass/emit_spirv.rs +++ /dev/null @@ -1,2762 +0,0 @@ -use super::*; -use half::f16; -use ptx_parser as ast; -use rspirv::{binary::Assemble, dr}; -use std::{ - collections::{HashMap, HashSet}, - ffi::CString, - mem, -}; - -pub(super) fn run<'input>( - mut builder: dr::Builder, - id_defs: &GlobalStringIdResolver<'input>, - call_map: MethodsCallMap<'input>, - denorm_information: HashMap< - ptx_parser::MethodName, - HashMap, - >, - directives: Vec>, -) -> Result<(dr::Module, HashMap, CString), TranslateError> { - builder.set_version(1, 3); - emit_capabilities(&mut builder); - emit_extensions(&mut builder); - let opencl_id = emit_opencl_import(&mut builder); - emit_memory_model(&mut builder); - let mut map = TypeWordMap::new(&mut builder); - //emit_builtins(&mut builder, &mut map, &id_defs); - let mut kernel_info = HashMap::new(); - let (build_options, should_flush_denorms) = - emit_denorm_build_string(&call_map, &denorm_information); - let (directives, globals_use_map) = get_globals_use_map(directives); - emit_directives( - &mut builder, - &mut map, - &id_defs, - opencl_id, - should_flush_denorms, - &call_map, - globals_use_map, - directives, - &mut kernel_info, - )?; - Ok((builder.module(), kernel_info, build_options)) -} - -fn emit_capabilities(builder: &mut dr::Builder) { - builder.capability(spirv::Capability::GenericPointer); - builder.capability(spirv::Capability::Linkage); - builder.capability(spirv::Capability::Addresses); - builder.capability(spirv::Capability::Kernel); - builder.capability(spirv::Capability::Int8); - builder.capability(spirv::Capability::Int16); - builder.capability(spirv::Capability::Int64); - builder.capability(spirv::Capability::Float16); - builder.capability(spirv::Capability::Float64); - builder.capability(spirv::Capability::DenormFlushToZero); - // TODO: re-enable when Intel float control extension works - //builder.capability(spirv::Capability::FunctionFloatControlINTEL); -} - -// http://htmlpreview.github.io/?https://github.com/KhronosGroup/SPIRV-Registry/blob/master/extensions/KHR/SPV_KHR_float_controls.html -fn emit_extensions(builder: &mut dr::Builder) { - // TODO: re-enable when Intel float control extension works - //builder.extension("SPV_INTEL_float_controls2"); - builder.extension("SPV_KHR_float_controls"); - builder.extension("SPV_KHR_no_integer_wrap_decoration"); -} - -fn emit_opencl_import(builder: &mut dr::Builder) -> spirv::Word { - builder.ext_inst_import("OpenCL.std") -} - -fn emit_memory_model(builder: &mut dr::Builder) { - builder.memory_model( - spirv::AddressingModel::Physical64, - spirv::MemoryModel::OpenCL, - ); -} - -struct TypeWordMap { - void: spirv::Word, - complex: HashMap, - constants: HashMap<(SpirvType, u64), SpirvWord>, -} - -impl TypeWordMap { - fn new(b: &mut dr::Builder) -> TypeWordMap { - let void = b.type_void(None); - TypeWordMap { - void: void, - complex: HashMap::::new(), - constants: HashMap::new(), - } - } - - fn void(&self) -> spirv::Word { - self.void - } - - fn get_or_add_scalar(&mut self, b: &mut dr::Builder, t: ast::ScalarType) -> SpirvWord { - let key: SpirvScalarKey = t.into(); - self.get_or_add_spirv_scalar(b, key) - } - - fn get_or_add_spirv_scalar(&mut self, b: &mut dr::Builder, key: SpirvScalarKey) -> SpirvWord { - *self.complex.entry(SpirvType::Base(key)).or_insert_with(|| { - SpirvWord(match key { - SpirvScalarKey::B8 => b.type_int(None, 8, 0), - SpirvScalarKey::B16 => b.type_int(None, 16, 0), - SpirvScalarKey::B32 => b.type_int(None, 32, 0), - SpirvScalarKey::B64 => b.type_int(None, 64, 0), - SpirvScalarKey::F16 => b.type_float(None, 16), - SpirvScalarKey::F32 => b.type_float(None, 32), - SpirvScalarKey::F64 => b.type_float(None, 64), - SpirvScalarKey::Pred => b.type_bool(None), - SpirvScalarKey::F16x2 => todo!(), - }) - }) - } - - fn get_or_add(&mut self, b: &mut dr::Builder, t: SpirvType) -> SpirvWord { - match t { - SpirvType::Base(key) => self.get_or_add_spirv_scalar(b, key), - SpirvType::Pointer(ref typ, storage) => { - let base = self.get_or_add(b, *typ.clone()); - *self - .complex - .entry(t) - .or_insert_with(|| SpirvWord(b.type_pointer(None, storage, base.0))) - } - SpirvType::Vector(typ, len) => { - let base = self.get_or_add_spirv_scalar(b, typ); - *self - .complex - .entry(t) - .or_insert_with(|| SpirvWord(b.type_vector(None, base.0, len as u32))) - } - SpirvType::Array(typ, array_dimensions) => { - let (base_type, length) = match &*array_dimensions { - &[] => { - return self.get_or_add(b, SpirvType::Base(typ)); - } - &[len] => { - let u32_type = self.get_or_add_scalar(b, ast::ScalarType::U32); - let base = self.get_or_add_spirv_scalar(b, typ); - let len_const = b.constant_u32(u32_type.0, None, len); - (base, len_const) - } - array_dimensions => { - let u32_type = self.get_or_add_scalar(b, ast::ScalarType::U32); - let base = self - .get_or_add(b, SpirvType::Array(typ, array_dimensions[1..].to_vec())); - let len_const = b.constant_u32(u32_type.0, None, array_dimensions[0]); - (base, len_const) - } - }; - *self - .complex - .entry(SpirvType::Array(typ, array_dimensions)) - .or_insert_with(|| SpirvWord(b.type_array(None, base_type.0, length))) - } - SpirvType::Func(ref out_params, ref in_params) => { - let out_t = match out_params { - Some(p) => self.get_or_add(b, *p.clone()), - None => SpirvWord(self.void()), - }; - let in_t = in_params - .iter() - .map(|t| self.get_or_add(b, t.clone()).0) - .collect::>(); - *self - .complex - .entry(t) - .or_insert_with(|| SpirvWord(b.type_function(None, out_t.0, in_t))) - } - SpirvType::Struct(ref underlying) => { - let underlying_ids = underlying - .iter() - .map(|t| self.get_or_add_spirv_scalar(b, *t).0) - .collect::>(); - *self - .complex - .entry(t) - .or_insert_with(|| SpirvWord(b.type_struct(None, underlying_ids))) - } - } - } - - fn get_or_add_fn( - &mut self, - b: &mut dr::Builder, - in_params: impl Iterator, - mut out_params: impl ExactSizeIterator, - ) -> (SpirvWord, SpirvWord) { - let (out_args, out_spirv_type) = if out_params.len() == 0 { - (None, SpirvWord(self.void())) - } else if out_params.len() == 1 { - let arg_as_key = out_params.next().unwrap(); - ( - Some(Box::new(arg_as_key.clone())), - self.get_or_add(b, arg_as_key), - ) - } else { - // TODO: support multiple return values - todo!() - }; - ( - out_spirv_type, - self.get_or_add(b, SpirvType::Func(out_args, in_params.collect::>())), - ) - } - - fn get_or_add_constant( - &mut self, - b: &mut dr::Builder, - typ: &ast::Type, - init: &[u8], - ) -> Result { - Ok(match typ { - ast::Type::Scalar(t) => match t { - ast::ScalarType::B8 | ast::ScalarType::U8 | ast::ScalarType::S8 => self - .get_or_add_constant_single::( - b, - *t, - init, - |v| v as u64, - |b, result_type, v| b.constant_u32(result_type, None, v as u32), - ), - ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => self - .get_or_add_constant_single::( - b, - *t, - init, - |v| v as u64, - |b, result_type, v| b.constant_u32(result_type, None, v as u32), - ), - ast::ScalarType::B32 | ast::ScalarType::U32 | ast::ScalarType::S32 => self - .get_or_add_constant_single::( - b, - *t, - init, - |v| v as u64, - |b, result_type, v| b.constant_u32(result_type, None, v), - ), - ast::ScalarType::B64 | ast::ScalarType::U64 | ast::ScalarType::S64 => self - .get_or_add_constant_single::( - b, - *t, - init, - |v| v, - |b, result_type, v| b.constant_u64(result_type, None, v), - ), - ast::ScalarType::F16 => self.get_or_add_constant_single::( - b, - *t, - init, - |v| unsafe { mem::transmute::<_, u16>(v) } as u64, - |b, result_type, v| b.constant_f32(result_type, None, v.to_f32()), - ), - ast::ScalarType::F32 => self.get_or_add_constant_single::( - b, - *t, - init, - |v| unsafe { mem::transmute::<_, u32>(v) } as u64, - |b, result_type, v| b.constant_f32(result_type, None, v), - ), - ast::ScalarType::F64 => self.get_or_add_constant_single::( - b, - *t, - init, - |v| unsafe { mem::transmute::<_, u64>(v) }, - |b, result_type, v| b.constant_f64(result_type, None, v), - ), - ast::ScalarType::F16x2 => return Err(TranslateError::Todo), - ast::ScalarType::Pred => self.get_or_add_constant_single::( - b, - *t, - init, - |v| v as u64, - |b, result_type, v| { - if v == 0 { - b.constant_false(result_type, None) - } else { - b.constant_true(result_type, None) - } - }, - ), - ast::ScalarType::S16x2 - | ast::ScalarType::U16x2 - | ast::ScalarType::BF16 - | ast::ScalarType::BF16x2 - | ast::ScalarType::B128 => todo!(), - }, - ast::Type::Vector(len, typ) => { - let result_type = - self.get_or_add(b, SpirvType::Vector(SpirvScalarKey::from(*typ), *len)); - let size_of_t = typ.size_of(); - let components = (0..*len) - .map(|x| { - Ok::<_, TranslateError>( - self.get_or_add_constant( - b, - &ast::Type::Scalar(*typ), - &init[((size_of_t as usize) * (x as usize))..], - )? - .0, - ) - }) - .collect::, _>>()?; - SpirvWord(b.constant_composite(result_type.0, None, components.into_iter())) - } - ast::Type::Array(_, typ, dims) => match dims.as_slice() { - [] => return Err(error_unreachable()), - [dim] => { - let result_type = self - .get_or_add(b, SpirvType::Array(SpirvScalarKey::from(*typ), vec![*dim])); - let size_of_t = typ.size_of(); - let components = (0..*dim) - .map(|x| { - Ok::<_, TranslateError>( - self.get_or_add_constant( - b, - &ast::Type::Scalar(*typ), - &init[((size_of_t as usize) * (x as usize))..], - )? - .0, - ) - }) - .collect::, _>>()?; - SpirvWord(b.constant_composite(result_type.0, None, components.into_iter())) - } - [first_dim, rest @ ..] => { - let result_type = self.get_or_add( - b, - SpirvType::Array(SpirvScalarKey::from(*typ), rest.to_vec()), - ); - let size_of_t = rest - .iter() - .fold(typ.size_of() as u32, |x, y| (x as u32) * (*y)); - let components = (0..*first_dim) - .map(|x| { - Ok::<_, TranslateError>( - self.get_or_add_constant( - b, - &ast::Type::Array(None, *typ, rest.to_vec()), - &init[((size_of_t as usize) * (x as usize))..], - )? - .0, - ) - }) - .collect::, _>>()?; - SpirvWord(b.constant_composite(result_type.0, None, components.into_iter())) - } - }, - ast::Type::Pointer(..) => return Err(error_unreachable()), - }) - } - - fn get_or_add_constant_single< - T: Copy, - CastAsU64: FnOnce(T) -> u64, - InsertConstant: FnOnce(&mut dr::Builder, spirv::Word, T) -> spirv::Word, - >( - &mut self, - b: &mut dr::Builder, - key: ast::ScalarType, - init: &[u8], - cast: CastAsU64, - f: InsertConstant, - ) -> SpirvWord { - let value = unsafe { *(init.as_ptr() as *const T) }; - let value_64 = cast(value); - let ht_key = (SpirvType::Base(SpirvScalarKey::from(key)), value_64); - match self.constants.get(&ht_key) { - Some(value) => *value, - None => { - let spirv_type = self.get_or_add_scalar(b, key); - let result = SpirvWord(f(b, spirv_type.0, value)); - self.constants.insert(ht_key, result); - result - } - } - } -} - -#[derive(PartialEq, Eq, Hash, Clone)] -enum SpirvType { - Base(SpirvScalarKey), - Vector(SpirvScalarKey, u8), - Array(SpirvScalarKey, Vec), - Pointer(Box, spirv::StorageClass), - Func(Option>, Vec), - Struct(Vec), -} - -impl SpirvType { - fn new(t: ast::Type) -> Self { - match t { - ast::Type::Scalar(t) => SpirvType::Base(t.into()), - ast::Type::Vector(len, typ) => SpirvType::Vector(typ.into(), len), - ast::Type::Array(_, t, len) => SpirvType::Array(t.into(), len), - ast::Type::Pointer(pointer_t, space) => SpirvType::Pointer( - Box::new(SpirvType::Base(pointer_t.into())), - space_to_spirv(space), - ), - } - } - - fn pointer_to(t: ast::Type, outer_space: spirv::StorageClass) -> Self { - let key = Self::new(t); - SpirvType::Pointer(Box::new(key), outer_space) - } -} - -impl From for SpirvType { - fn from(t: ast::ScalarType) -> Self { - SpirvType::Base(t.into()) - } -} -// SPIR-V integer type definitions are signless, more below: -// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_unsignedsigned_a_unsigned_versus_signed_integers -// https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_validation_rules_for_kernel_a_href_capability_capabilities_a -#[derive(PartialEq, Eq, Hash, Clone, Copy)] -enum SpirvScalarKey { - B8, - B16, - B32, - B64, - F16, - F32, - F64, - Pred, - F16x2, -} - -impl From for SpirvScalarKey { - fn from(t: ast::ScalarType) -> Self { - match t { - ast::ScalarType::B8 | ast::ScalarType::U8 | ast::ScalarType::S8 => SpirvScalarKey::B8, - ast::ScalarType::B16 | ast::ScalarType::U16 | ast::ScalarType::S16 => { - SpirvScalarKey::B16 - } - ast::ScalarType::B32 | ast::ScalarType::U32 | ast::ScalarType::S32 => { - SpirvScalarKey::B32 - } - ast::ScalarType::B64 | ast::ScalarType::U64 | ast::ScalarType::S64 => { - SpirvScalarKey::B64 - } - ast::ScalarType::F16 => SpirvScalarKey::F16, - ast::ScalarType::F32 => SpirvScalarKey::F32, - ast::ScalarType::F64 => SpirvScalarKey::F64, - ast::ScalarType::F16x2 => SpirvScalarKey::F16x2, - ast::ScalarType::Pred => SpirvScalarKey::Pred, - ast::ScalarType::S16x2 - | ast::ScalarType::U16x2 - | ast::ScalarType::BF16 - | ast::ScalarType::BF16x2 - | ast::ScalarType::B128 => todo!(), - } - } -} - -fn space_to_spirv(this: ast::StateSpace) -> spirv::StorageClass { - match this { - ast::StateSpace::Const => spirv::StorageClass::UniformConstant, - ast::StateSpace::Generic => spirv::StorageClass::Generic, - ast::StateSpace::Global => spirv::StorageClass::CrossWorkgroup, - ast::StateSpace::Local => spirv::StorageClass::Function, - ast::StateSpace::Shared => spirv::StorageClass::Workgroup, - ast::StateSpace::Param => spirv::StorageClass::Function, - ast::StateSpace::Reg => spirv::StorageClass::Function, - ast::StateSpace::ParamEntry - | ast::StateSpace::ParamFunc - | ast::StateSpace::SharedCluster - | ast::StateSpace::SharedCta => todo!(), - } -} - -// TODO: remove this once we have pef-function support for denorms -fn emit_denorm_build_string<'input>( - call_map: &MethodsCallMap, - denorm_information: &HashMap< - ast::MethodName<'input, SpirvWord>, - HashMap, - >, -) -> (CString, bool) { - let denorm_counts = denorm_information - .iter() - .map(|(method, meth_denorm)| { - let f16_count = meth_denorm - .get(&(mem::size_of::() as u8)) - .unwrap_or(&(spirv::FPDenormMode::FlushToZero, 0)) - .1; - let f32_count = meth_denorm - .get(&(mem::size_of::() as u8)) - .unwrap_or(&(spirv::FPDenormMode::FlushToZero, 0)) - .1; - (method, (f16_count + f32_count)) - }) - .collect::>(); - let mut flush_over_preserve = 0; - for (kernel, children) in call_map.kernels() { - flush_over_preserve += *denorm_counts - .get(&ast::MethodName::Kernel(kernel)) - .unwrap_or(&0); - for child_fn in children { - flush_over_preserve += *denorm_counts - .get(&ast::MethodName::Func(*child_fn)) - .unwrap_or(&0); - } - } - if flush_over_preserve > 0 { - ( - CString::new("-ze-take-global-address -ze-denorms-are-zero").unwrap(), - true, - ) - } else { - (CString::new("-ze-take-global-address").unwrap(), false) - } -} - -fn get_globals_use_map<'input>( - directives: Vec>, -) -> ( - Vec>, - HashMap, HashSet>, -) { - let mut known_globals = HashSet::new(); - for directive in directives.iter() { - match directive { - Directive::Variable(_, ast::Variable { name, .. }) => { - known_globals.insert(*name); - } - Directive::Method(..) => {} - } - } - let mut symbol_uses_map = HashMap::new(); - let directives = directives - .into_iter() - .map(|directive| match directive { - Directive::Variable(..) | Directive::Method(Function { body: None, .. }) => directive, - Directive::Method(Function { - func_decl, - body: Some(mut statements), - globals, - import_as, - tuning, - linkage, - }) => { - let method_name = func_decl.borrow().name; - statements = statements - .into_iter() - .map(|statement| { - statement.visit_map( - &mut |symbol, _: Option<(&ast::Type, ast::StateSpace)>, _, _| { - if known_globals.contains(&symbol) { - multi_hash_map_append( - &mut symbol_uses_map, - method_name, - symbol, - ); - } - Ok::<_, TranslateError>(symbol) - }, - ) - }) - .collect::, _>>() - .unwrap(); - Directive::Method(Function { - func_decl, - body: Some(statements), - globals, - import_as, - tuning, - linkage, - }) - } - }) - .collect::>(); - (directives, symbol_uses_map) -} - -fn emit_directives<'input>( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - id_defs: &GlobalStringIdResolver<'input>, - opencl_id: spirv::Word, - should_flush_denorms: bool, - call_map: &MethodsCallMap<'input>, - globals_use_map: HashMap, HashSet>, - directives: Vec>, - kernel_info: &mut HashMap, -) -> Result<(), TranslateError> { - let empty_body = Vec::new(); - for d in directives.iter() { - match d { - Directive::Variable(linking, var) => { - emit_variable(builder, map, id_defs, *linking, &var)?; - } - Directive::Method(f) => { - let f_body = match &f.body { - Some(f) => f, - None => { - if f.linkage.contains(ast::LinkingDirective::EXTERN) { - &empty_body - } else { - continue; - } - } - }; - for var in f.globals.iter() { - emit_variable(builder, map, id_defs, ast::LinkingDirective::NONE, var)?; - } - let func_decl = (*f.func_decl).borrow(); - let fn_id = emit_function_header( - builder, - map, - &id_defs, - &*func_decl, - call_map, - &globals_use_map, - kernel_info, - )?; - if matches!(func_decl.name, ast::MethodName::Kernel(_)) { - if should_flush_denorms { - builder.execution_mode( - fn_id.0, - spirv_headers::ExecutionMode::DenormFlushToZero, - [16], - ); - builder.execution_mode( - fn_id.0, - spirv_headers::ExecutionMode::DenormFlushToZero, - [32], - ); - builder.execution_mode( - fn_id.0, - spirv_headers::ExecutionMode::DenormFlushToZero, - [64], - ); - } - // FP contraction happens when compiling source -> PTX and is illegal at this stage (unless you force it in cuModuleLoadDataEx) - builder.execution_mode( - fn_id.0, - spirv_headers::ExecutionMode::ContractionOff, - [], - ); - for t in f.tuning.iter() { - match *t { - ast::TuningDirective::MaxNtid(nx, ny, nz) => { - builder.execution_mode( - fn_id.0, - spirv_headers::ExecutionMode::MaxWorkgroupSizeINTEL, - [nx, ny, nz], - ); - } - ast::TuningDirective::ReqNtid(nx, ny, nz) => { - builder.execution_mode( - fn_id.0, - spirv_headers::ExecutionMode::LocalSize, - [nx, ny, nz], - ); - } - // Too architecture specific - ast::TuningDirective::MaxNReg(..) - | ast::TuningDirective::MinNCtaPerSm(..) => {} - } - } - } - emit_function_body_ops(builder, map, id_defs, opencl_id, &f_body)?; - emit_function_linkage(builder, id_defs, f, fn_id)?; - builder.select_block(None)?; - builder.end_function()?; - } - } - } - Ok(()) -} - -fn emit_variable<'input>( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - id_defs: &GlobalStringIdResolver<'input>, - linking: ast::LinkingDirective, - var: &ast::Variable, -) -> Result<(), TranslateError> { - let (must_init, st_class) = match var.state_space { - ast::StateSpace::Reg | ast::StateSpace::Param | ast::StateSpace::Local => { - (false, spirv::StorageClass::Function) - } - ast::StateSpace::Global => (true, spirv::StorageClass::CrossWorkgroup), - ast::StateSpace::Shared => (false, spirv::StorageClass::Workgroup), - ast::StateSpace::Const => (false, spirv::StorageClass::UniformConstant), - ast::StateSpace::Generic => todo!(), - ast::StateSpace::ParamEntry - | ast::StateSpace::ParamFunc - | ast::StateSpace::SharedCluster - | ast::StateSpace::SharedCta => todo!(), - }; - let initalizer = if var.array_init.len() > 0 { - Some( - map.get_or_add_constant( - builder, - &ast::Type::from(var.v_type.clone()), - &*var.array_init, - )? - .0, - ) - } else if must_init { - let type_id = map.get_or_add(builder, SpirvType::new(var.v_type.clone())); - Some(builder.constant_null(type_id.0, None)) - } else { - None - }; - let ptr_type_id = map.get_or_add(builder, SpirvType::pointer_to(var.v_type.clone(), st_class)); - builder.variable(ptr_type_id.0, Some(var.name.0), st_class, initalizer); - if let Some(align) = var.align { - builder.decorate( - var.name.0, - spirv::Decoration::Alignment, - [dr::Operand::LiteralInt32(align)].iter().cloned(), - ); - } - if var.state_space != ast::StateSpace::Shared - || !linking.contains(ast::LinkingDirective::EXTERN) - { - emit_linking_decoration(builder, id_defs, None, var.name, linking); - } - Ok(()) -} - -fn emit_function_header<'input>( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - defined_globals: &GlobalStringIdResolver<'input>, - func_decl: &ast::MethodDeclaration<'input, SpirvWord>, - call_map: &MethodsCallMap<'input>, - globals_use_map: &HashMap, HashSet>, - kernel_info: &mut HashMap, -) -> Result { - if let ast::MethodName::Kernel(name) = func_decl.name { - let args_lens = func_decl - .input_arguments - .iter() - .map(|param| { - ( - type_size_of(¶m.v_type), - matches!(param.v_type, ast::Type::Pointer(..)), - ) - }) - .collect(); - kernel_info.insert( - name.to_string(), - KernelInfo { - arguments_sizes: args_lens, - uses_shared_mem: func_decl.shared_mem.is_some(), - }, - ); - } - let (ret_type, func_type) = get_function_type( - builder, - map, - effective_input_arguments(func_decl).map(|(_, typ)| typ), - &func_decl.return_arguments, - ); - let fn_id = match func_decl.name { - ast::MethodName::Kernel(name) => { - let fn_id = defined_globals.get_id(name)?; - let interface = globals_use_map - .get(&ast::MethodName::Kernel(name)) - .into_iter() - .flatten() - .copied() - .chain({ - call_map - .get_kernel_children(name) - .copied() - .flat_map(|subfunction| { - globals_use_map - .get(&ast::MethodName::Func(subfunction)) - .into_iter() - .flatten() - .copied() - }) - .into_iter() - }) - .map(|word| word.0) - .collect::>(); - builder.entry_point(spirv::ExecutionModel::Kernel, fn_id.0, name, interface); - fn_id - } - ast::MethodName::Func(name) => name, - }; - builder.begin_function( - ret_type.0, - Some(fn_id.0), - spirv::FunctionControl::NONE, - func_type.0, - )?; - for (name, typ) in effective_input_arguments(func_decl) { - let result_type = map.get_or_add(builder, typ); - builder.function_parameter(Some(name.0), result_type.0)?; - } - Ok(fn_id) -} - -pub fn type_size_of(this: &ast::Type) -> usize { - match this { - ast::Type::Scalar(typ) => typ.size_of() as usize, - ast::Type::Vector(len, typ) => (typ.size_of() as usize) * (*len as usize), - ast::Type::Array(_, typ, len) => len - .iter() - .fold(typ.size_of() as usize, |x, y| (x as usize) * (*y as usize)), - ast::Type::Pointer(..) => mem::size_of::(), - } -} -fn emit_function_body_ops<'input>( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - id_defs: &GlobalStringIdResolver<'input>, - opencl: spirv::Word, - func: &[ExpandedStatement], -) -> Result<(), TranslateError> { - for s in func { - match s { - Statement::Label(id) => { - if builder.selected_block().is_some() { - builder.branch(id.0)?; - } - builder.begin_block(Some(id.0))?; - } - _ => { - if builder.selected_block().is_none() && builder.selected_function().is_some() { - builder.begin_block(None)?; - } - } - } - match s { - Statement::Label(_) => (), - Statement::Variable(var) => { - emit_variable(builder, map, id_defs, ast::LinkingDirective::NONE, var)?; - } - Statement::Constant(cnst) => { - let typ_id = map.get_or_add_scalar(builder, cnst.typ); - match (cnst.typ, cnst.value) { - (ast::ScalarType::B8, ast::ImmediateValue::U64(value)) - | (ast::ScalarType::U8, ast::ImmediateValue::U64(value)) => { - builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as u8 as u32); - } - (ast::ScalarType::B16, ast::ImmediateValue::U64(value)) - | (ast::ScalarType::U16, ast::ImmediateValue::U64(value)) => { - builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as u16 as u32); - } - (ast::ScalarType::B32, ast::ImmediateValue::U64(value)) - | (ast::ScalarType::U32, ast::ImmediateValue::U64(value)) => { - builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as u32); - } - (ast::ScalarType::B64, ast::ImmediateValue::U64(value)) - | (ast::ScalarType::U64, ast::ImmediateValue::U64(value)) => { - builder.constant_u64(typ_id.0, Some(cnst.dst.0), value); - } - (ast::ScalarType::S8, ast::ImmediateValue::U64(value)) => { - builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as i8 as u32); - } - (ast::ScalarType::S16, ast::ImmediateValue::U64(value)) => { - builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as i16 as u32); - } - (ast::ScalarType::S32, ast::ImmediateValue::U64(value)) => { - builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as i32 as u32); - } - (ast::ScalarType::S64, ast::ImmediateValue::U64(value)) => { - builder.constant_u64(typ_id.0, Some(cnst.dst.0), value as i64 as u64); - } - (ast::ScalarType::B8, ast::ImmediateValue::S64(value)) - | (ast::ScalarType::U8, ast::ImmediateValue::S64(value)) => { - builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as u8 as u32); - } - (ast::ScalarType::B16, ast::ImmediateValue::S64(value)) - | (ast::ScalarType::U16, ast::ImmediateValue::S64(value)) => { - builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as u16 as u32); - } - (ast::ScalarType::B32, ast::ImmediateValue::S64(value)) - | (ast::ScalarType::U32, ast::ImmediateValue::S64(value)) => { - builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as u32); - } - (ast::ScalarType::B64, ast::ImmediateValue::S64(value)) - | (ast::ScalarType::U64, ast::ImmediateValue::S64(value)) => { - builder.constant_u64(typ_id.0, Some(cnst.dst.0), value as u64); - } - (ast::ScalarType::S8, ast::ImmediateValue::S64(value)) => { - builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as i8 as u32); - } - (ast::ScalarType::S16, ast::ImmediateValue::S64(value)) => { - builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as i16 as u32); - } - (ast::ScalarType::S32, ast::ImmediateValue::S64(value)) => { - builder.constant_u32(typ_id.0, Some(cnst.dst.0), value as i32 as u32); - } - (ast::ScalarType::S64, ast::ImmediateValue::S64(value)) => { - builder.constant_u64(typ_id.0, Some(cnst.dst.0), value as u64); - } - (ast::ScalarType::F16, ast::ImmediateValue::F32(value)) => { - builder.constant_f32( - typ_id.0, - Some(cnst.dst.0), - f16::from_f32(value).to_f32(), - ); - } - (ast::ScalarType::F32, ast::ImmediateValue::F32(value)) => { - builder.constant_f32(typ_id.0, Some(cnst.dst.0), value); - } - (ast::ScalarType::F64, ast::ImmediateValue::F32(value)) => { - builder.constant_f64(typ_id.0, Some(cnst.dst.0), value as f64); - } - (ast::ScalarType::F16, ast::ImmediateValue::F64(value)) => { - builder.constant_f32( - typ_id.0, - Some(cnst.dst.0), - f16::from_f64(value).to_f32(), - ); - } - (ast::ScalarType::F32, ast::ImmediateValue::F64(value)) => { - builder.constant_f32(typ_id.0, Some(cnst.dst.0), value as f32); - } - (ast::ScalarType::F64, ast::ImmediateValue::F64(value)) => { - builder.constant_f64(typ_id.0, Some(cnst.dst.0), value); - } - (ast::ScalarType::Pred, ast::ImmediateValue::U64(value)) => { - let bool_type = map.get_or_add_scalar(builder, ast::ScalarType::Pred).0; - if value == 0 { - builder.constant_false(bool_type, Some(cnst.dst.0)); - } else { - builder.constant_true(bool_type, Some(cnst.dst.0)); - } - } - (ast::ScalarType::Pred, ast::ImmediateValue::S64(value)) => { - let bool_type = map.get_or_add_scalar(builder, ast::ScalarType::Pred).0; - if value == 0 { - builder.constant_false(bool_type, Some(cnst.dst.0)); - } else { - builder.constant_true(bool_type, Some(cnst.dst.0)); - } - } - _ => return Err(error_mismatched_type()), - } - } - Statement::Conversion(cv) => emit_implicit_conversion(builder, map, cv)?, - Statement::Conditional(bra) => { - builder.branch_conditional( - bra.predicate.0, - bra.if_true.0, - bra.if_false.0, - iter::empty(), - )?; - } - Statement::FunctionPointer(FunctionPointerDetails { dst, src }) => { - // TODO: implement properly - let zero = map.get_or_add_constant( - builder, - &ast::Type::Scalar(ast::ScalarType::U64), - &vec_repr(0u64), - )?; - let result_type = map.get_or_add_scalar(builder, ast::ScalarType::U64); - builder.copy_object(result_type.0, Some(dst.0), zero.0)?; - } - Statement::Instruction(inst) => match inst { - ast::Instruction::PrmtSlow { .. } | ast::Instruction::Trap { .. } => todo!(), - ast::Instruction::Call { data, arguments } => { - let (result_type, result_id) = - match (&*data.return_arguments, &*arguments.return_arguments) { - ([(type_, space)], [id]) => { - if *space != ast::StateSpace::Reg { - return Err(error_unreachable()); - } - ( - map.get_or_add(builder, SpirvType::new(type_.clone())).0, - Some(id.0), - ) - } - ([], []) => (map.void(), None), - _ => todo!(), - }; - let arg_list = arguments - .input_arguments - .iter() - .map(|id| id.0) - .collect::>(); - builder.function_call(result_type, result_id, arguments.func.0, arg_list)?; - } - ast::Instruction::Abs { data, arguments } => { - emit_abs(builder, map, opencl, data, arguments)? - } - // SPIR-V does not support marking jumps as guaranteed-converged - ast::Instruction::Bra { arguments, .. } => { - builder.branch(arguments.src.0)?; - } - ast::Instruction::Ld { data, arguments } => { - let mem_access = match data.qualifier { - ast::LdStQualifier::Weak => spirv::MemoryAccess::NONE, - // ld.volatile does not match Volatile OpLoad nor Relaxed OpAtomicLoad - ast::LdStQualifier::Volatile => spirv::MemoryAccess::VOLATILE, - _ => return Err(TranslateError::Todo), - }; - let result_type = - map.get_or_add(builder, SpirvType::new(ast::Type::from(data.typ.clone()))); - builder.load( - result_type.0, - Some(arguments.dst.0), - arguments.src.0, - Some(mem_access | spirv::MemoryAccess::ALIGNED), - [dr::Operand::LiteralInt32( - type_size_of(&ast::Type::from(data.typ.clone())) as u32, - )] - .iter() - .cloned(), - )?; - } - ast::Instruction::St { data, arguments } => { - let mem_access = match data.qualifier { - ast::LdStQualifier::Weak => spirv::MemoryAccess::NONE, - // st.volatile does not match Volatile OpStore nor Relaxed OpAtomicStore - ast::LdStQualifier::Volatile => spirv::MemoryAccess::VOLATILE, - _ => return Err(TranslateError::Todo), - }; - builder.store( - arguments.src1.0, - arguments.src2.0, - Some(mem_access | spirv::MemoryAccess::ALIGNED), - [dr::Operand::LiteralInt32( - type_size_of(&ast::Type::from(data.typ.clone())) as u32, - )] - .iter() - .cloned(), - )?; - } - // SPIR-V does not support ret as guaranteed-converged - ast::Instruction::Ret { .. } => builder.ret()?, - ast::Instruction::Mov { data, arguments } => { - let result_type = - map.get_or_add(builder, SpirvType::new(ast::Type::from(data.typ.clone()))); - builder.copy_object(result_type.0, Some(arguments.dst.0), arguments.src.0)?; - } - ast::Instruction::Mul { data, arguments } => match data { - ast::MulDetails::Integer { type_, control } => { - emit_mul_int(builder, map, opencl, *type_, *control, arguments)? - } - ast::MulDetails::Float(ref ctr) => { - emit_mul_float(builder, map, ctr, arguments)? - } - }, - ast::Instruction::Add { data, arguments } => match data { - ast::ArithDetails::Integer(desc) => { - emit_add_int(builder, map, desc.type_.into(), desc.saturate, arguments)? - } - ast::ArithDetails::Float(desc) => { - emit_add_float(builder, map, desc, arguments)? - } - }, - ast::Instruction::Setp { data, arguments } => { - if arguments.dst2.is_some() { - todo!() - } - emit_setp(builder, map, data, arguments)?; - } - ast::Instruction::Not { data, arguments } => { - let result_type = map.get_or_add(builder, SpirvType::from(*data)); - let result_id = Some(arguments.dst.0); - let operand = arguments.src; - match data { - ast::ScalarType::Pred => { - logical_not(builder, result_type.0, result_id, operand.0) - } - _ => builder.not(result_type.0, result_id, operand.0), - }?; - } - ast::Instruction::Shl { data, arguments } => { - let full_type = ast::Type::Scalar(*data); - let size_of = type_size_of(&full_type); - let result_type = map.get_or_add(builder, SpirvType::new(full_type)); - let offset_src = insert_shift_hack(builder, map, arguments.src2.0, size_of)?; - builder.shift_left_logical( - result_type.0, - Some(arguments.dst.0), - arguments.src1.0, - offset_src, - )?; - } - ast::Instruction::Shr { data, arguments } => { - let full_type = ast::ScalarType::from(data.type_); - let size_of = full_type.size_of(); - let result_type = map.get_or_add_scalar(builder, full_type).0; - let offset_src = - insert_shift_hack(builder, map, arguments.src2.0, size_of as usize)?; - match data.kind { - ptx_parser::RightShiftKind::Arithmetic => { - builder.shift_right_arithmetic( - result_type, - Some(arguments.dst.0), - arguments.src1.0, - offset_src, - )?; - } - ptx_parser::RightShiftKind::Logical => { - builder.shift_right_logical( - result_type, - Some(arguments.dst.0), - arguments.src1.0, - offset_src, - )?; - } - } - } - ast::Instruction::Cvt { data, arguments } => { - emit_cvt(builder, map, opencl, data, arguments)?; - } - ast::Instruction::Cvta { data, arguments } => { - // This would be only meaningful if const/slm/global pointers - // had a different format than generic pointers, but they don't pretty much by ptx definition - // Honestly, I have no idea why this instruction exists and is emitted by the compiler - let result_type = map.get_or_add_scalar(builder, ast::ScalarType::B64); - builder.copy_object(result_type.0, Some(arguments.dst.0), arguments.src.0)?; - } - ast::Instruction::SetpBool { .. } => todo!(), - ast::Instruction::Mad { data, arguments } => match data { - ast::MadDetails::Integer { - type_, - control, - saturate, - } => { - if *saturate { - todo!() - } - if type_.kind() == ast::ScalarKind::Signed { - emit_mad_sint(builder, map, opencl, *type_, *control, arguments)? - } else { - emit_mad_uint(builder, map, opencl, *type_, *control, arguments)? - } - } - ast::MadDetails::Float(desc) => { - emit_mad_float(builder, map, opencl, desc, arguments)? - } - }, - ast::Instruction::Fma { data, arguments } => { - emit_fma_float(builder, map, opencl, data, arguments)? - } - ast::Instruction::Or { data, arguments } => { - let result_type = map.get_or_add_scalar(builder, *data).0; - if *data == ast::ScalarType::Pred { - builder.logical_or( - result_type, - Some(arguments.dst.0), - arguments.src1.0, - arguments.src2.0, - )?; - } else { - builder.bitwise_or( - result_type, - Some(arguments.dst.0), - arguments.src1.0, - arguments.src2.0, - )?; - } - } - ast::Instruction::Sub { data, arguments } => match data { - ast::ArithDetails::Integer(desc) => { - emit_sub_int(builder, map, desc.type_.into(), desc.saturate, arguments)?; - } - ast::ArithDetails::Float(desc) => { - emit_sub_float(builder, map, desc, arguments)?; - } - }, - ast::Instruction::Min { data, arguments } => { - emit_min(builder, map, opencl, data, arguments)?; - } - ast::Instruction::Max { data, arguments } => { - emit_max(builder, map, opencl, data, arguments)?; - } - ast::Instruction::Rcp { data, arguments } => { - emit_rcp(builder, map, opencl, data, arguments)?; - } - ast::Instruction::And { data, arguments } => { - let result_type = map.get_or_add_scalar(builder, *data); - if *data == ast::ScalarType::Pred { - builder.logical_and( - result_type.0, - Some(arguments.dst.0), - arguments.src1.0, - arguments.src2.0, - )?; - } else { - builder.bitwise_and( - result_type.0, - Some(arguments.dst.0), - arguments.src1.0, - arguments.src2.0, - )?; - } - } - ast::Instruction::Selp { data, arguments } => { - let result_type = map.get_or_add_scalar(builder, *data); - builder.select( - result_type.0, - Some(arguments.dst.0), - arguments.src3.0, - arguments.src1.0, - arguments.src2.0, - )?; - } - // TODO: implement named barriers - ast::Instruction::Bar { data, arguments } => { - let workgroup_scope = map.get_or_add_constant( - builder, - &ast::Type::Scalar(ast::ScalarType::U32), - &vec_repr(spirv::Scope::Workgroup as u32), - )?; - let barrier_semantics = map.get_or_add_constant( - builder, - &ast::Type::Scalar(ast::ScalarType::U32), - &vec_repr( - spirv::MemorySemantics::CROSS_WORKGROUP_MEMORY - | spirv::MemorySemantics::WORKGROUP_MEMORY - | spirv::MemorySemantics::SEQUENTIALLY_CONSISTENT, - ), - )?; - builder.control_barrier( - workgroup_scope.0, - workgroup_scope.0, - barrier_semantics.0, - )?; - } - ast::Instruction::Atom { data, arguments } => { - emit_atom(builder, map, data, arguments)?; - } - ast::Instruction::AtomCas { data, arguments } => { - let result_type = map.get_or_add_scalar(builder, data.type_); - let memory_const = map.get_or_add_constant( - builder, - &ast::Type::Scalar(ast::ScalarType::U32), - &vec_repr(scope_to_spirv(data.scope) as u32), - )?; - let semantics_const = map.get_or_add_constant( - builder, - &ast::Type::Scalar(ast::ScalarType::U32), - &vec_repr(semantics_to_spirv(data.semantics).bits()), - )?; - builder.atomic_compare_exchange( - result_type.0, - Some(arguments.dst.0), - arguments.src1.0, - memory_const.0, - semantics_const.0, - semantics_const.0, - arguments.src3.0, - arguments.src2.0, - )?; - } - ast::Instruction::Div { data, arguments } => match data { - ast::DivDetails::Unsigned(t) => { - let result_type = map.get_or_add_scalar(builder, (*t).into()); - builder.u_div( - result_type.0, - Some(arguments.dst.0), - arguments.src1.0, - arguments.src2.0, - )?; - } - ast::DivDetails::Signed(t) => { - let result_type = map.get_or_add_scalar(builder, (*t).into()); - builder.s_div( - result_type.0, - Some(arguments.dst.0), - arguments.src1.0, - arguments.src2.0, - )?; - } - ast::DivDetails::Float(t) => { - let result_type = map.get_or_add_scalar(builder, t.type_.into()); - builder.f_div( - result_type.0, - Some(arguments.dst.0), - arguments.src1.0, - arguments.src2.0, - )?; - emit_float_div_decoration(builder, arguments.dst, t.kind); - } - }, - ast::Instruction::Sqrt { data, arguments } => { - emit_sqrt(builder, map, opencl, data, arguments)?; - } - ast::Instruction::Rsqrt { data, arguments } => { - let result_type = map.get_or_add_scalar(builder, data.type_.into()); - builder.ext_inst( - result_type.0, - Some(arguments.dst.0), - opencl, - spirv::CLOp::rsqrt as spirv::Word, - [dr::Operand::IdRef(arguments.src.0)].iter().cloned(), - )?; - } - ast::Instruction::Neg { data, arguments } => { - let result_type = map.get_or_add_scalar(builder, data.type_); - let negate_func = if data.type_.kind() == ast::ScalarKind::Float { - dr::Builder::f_negate - } else { - dr::Builder::s_negate - }; - negate_func( - builder, - result_type.0, - Some(arguments.dst.0), - arguments.src.0, - )?; - } - ast::Instruction::Sin { arguments, .. } => { - let result_type = map.get_or_add_scalar(builder, ast::ScalarType::F32); - builder.ext_inst( - result_type.0, - Some(arguments.dst.0), - opencl, - spirv::CLOp::sin as u32, - [dr::Operand::IdRef(arguments.src.0)].iter().cloned(), - )?; - } - ast::Instruction::Cos { arguments, .. } => { - let result_type = map.get_or_add_scalar(builder, ast::ScalarType::F32); - builder.ext_inst( - result_type.0, - Some(arguments.dst.0), - opencl, - spirv::CLOp::cos as u32, - [dr::Operand::IdRef(arguments.src.0)].iter().cloned(), - )?; - } - ast::Instruction::Lg2 { arguments, .. } => { - let result_type = map.get_or_add_scalar(builder, ast::ScalarType::F32); - builder.ext_inst( - result_type.0, - Some(arguments.dst.0), - opencl, - spirv::CLOp::log2 as u32, - [dr::Operand::IdRef(arguments.src.0)].iter().cloned(), - )?; - } - ast::Instruction::Ex2 { arguments, .. } => { - let result_type = map.get_or_add_scalar(builder, ast::ScalarType::F32); - builder.ext_inst( - result_type.0, - Some(arguments.dst.0), - opencl, - spirv::CLOp::exp2 as u32, - [dr::Operand::IdRef(arguments.src.0)].iter().cloned(), - )?; - } - ast::Instruction::Clz { data, arguments } => { - let result_type = map.get_or_add_scalar(builder, (*data).into()); - builder.ext_inst( - result_type.0, - Some(arguments.dst.0), - opencl, - spirv::CLOp::clz as u32, - [dr::Operand::IdRef(arguments.src.0)].iter().cloned(), - )?; - } - ast::Instruction::Brev { data, arguments } => { - let result_type = map.get_or_add_scalar(builder, (*data).into()); - builder.bit_reverse(result_type.0, Some(arguments.dst.0), arguments.src.0)?; - } - ast::Instruction::Popc { data, arguments } => { - let result_type = map.get_or_add_scalar(builder, (*data).into()); - builder.bit_count(result_type.0, Some(arguments.dst.0), arguments.src.0)?; - } - ast::Instruction::Xor { data, arguments } => { - let builder_fn: fn( - &mut dr::Builder, - u32, - Option, - u32, - u32, - ) -> Result = match data { - ast::ScalarType::Pred => emit_logical_xor_spirv, - _ => dr::Builder::bitwise_xor, - }; - let result_type = map.get_or_add_scalar(builder, (*data).into()); - builder_fn( - builder, - result_type.0, - Some(arguments.dst.0), - arguments.src1.0, - arguments.src2.0, - )?; - } - ast::Instruction::Bfe { .. } - | ast::Instruction::Bfi { .. } - | ast::Instruction::Activemask { .. } => { - // Should have beeen replaced with a funciton call earlier - return Err(error_unreachable()); - } - - ast::Instruction::Rem { data, arguments } => { - let builder_fn = if data.kind() == ast::ScalarKind::Signed { - dr::Builder::s_mod - } else { - dr::Builder::u_mod - }; - let result_type = map.get_or_add_scalar(builder, (*data).into()); - builder_fn( - builder, - result_type.0, - Some(arguments.dst.0), - arguments.src1.0, - arguments.src2.0, - )?; - } - ast::Instruction::Prmt { data, arguments } => { - let control = *data as u32; - let components = [ - (control >> 0) & 0b1111, - (control >> 4) & 0b1111, - (control >> 8) & 0b1111, - (control >> 12) & 0b1111, - ]; - if components.iter().any(|&c| c > 7) { - return Err(TranslateError::Todo); - } - let vec4_b8_type = - map.get_or_add(builder, SpirvType::Vector(SpirvScalarKey::B8, 4)); - let b32_type = map.get_or_add_scalar(builder, ast::ScalarType::B32); - let src1_vector = builder.bitcast(vec4_b8_type.0, None, arguments.src1.0)?; - let src2_vector = builder.bitcast(vec4_b8_type.0, None, arguments.src2.0)?; - let dst_vector = builder.vector_shuffle( - vec4_b8_type.0, - None, - src1_vector, - src2_vector, - components, - )?; - builder.bitcast(b32_type.0, Some(arguments.dst.0), dst_vector)?; - } - ast::Instruction::Membar { data } => { - let (scope, semantics) = match data { - ast::MemScope::Cta => ( - spirv::Scope::Workgroup, - spirv::MemorySemantics::CROSS_WORKGROUP_MEMORY - | spirv::MemorySemantics::WORKGROUP_MEMORY - | spirv::MemorySemantics::SEQUENTIALLY_CONSISTENT, - ), - ast::MemScope::Gpu => ( - spirv::Scope::Device, - spirv::MemorySemantics::CROSS_WORKGROUP_MEMORY - | spirv::MemorySemantics::WORKGROUP_MEMORY - | spirv::MemorySemantics::SEQUENTIALLY_CONSISTENT, - ), - ast::MemScope::Sys => ( - spirv::Scope::CrossDevice, - spirv::MemorySemantics::CROSS_WORKGROUP_MEMORY - | spirv::MemorySemantics::WORKGROUP_MEMORY - | spirv::MemorySemantics::SEQUENTIALLY_CONSISTENT, - ), - - ast::MemScope::Cluster => todo!(), - }; - let spirv_scope = map.get_or_add_constant( - builder, - &ast::Type::Scalar(ast::ScalarType::U32), - &vec_repr(scope as u32), - )?; - let spirv_semantics = map.get_or_add_constant( - builder, - &ast::Type::Scalar(ast::ScalarType::U32), - &vec_repr(semantics), - )?; - builder.memory_barrier(spirv_scope.0, spirv_semantics.0)?; - } - }, - Statement::LoadVar(details) => { - emit_load_var(builder, map, details)?; - } - Statement::StoreVar(details) => { - let dst_ptr = match details.member_index { - Some(index) => { - let result_ptr_type = map.get_or_add( - builder, - SpirvType::pointer_to( - details.typ.clone(), - spirv::StorageClass::Function, - ), - ); - let index_spirv = map.get_or_add_constant( - builder, - &ast::Type::Scalar(ast::ScalarType::U32), - &vec_repr(index as u32), - )?; - builder.in_bounds_access_chain( - result_ptr_type.0, - None, - details.arg.src1.0, - [index_spirv.0].iter().copied(), - )? - } - None => details.arg.src1.0, - }; - builder.store(dst_ptr, details.arg.src2.0, None, iter::empty())?; - } - Statement::RetValue(_, id) => { - builder.ret_value(id.0)?; - } - Statement::PtrAccess(PtrAccess { - underlying_type, - state_space, - dst, - ptr_src, - offset_src, - }) => { - let u8_pointer = map.get_or_add( - builder, - SpirvType::new(ast::Type::Pointer(ast::ScalarType::U8, *state_space)), - ); - let result_type = map.get_or_add( - builder, - SpirvType::pointer_to(underlying_type.clone(), space_to_spirv(*state_space)), - ); - let ptr_src_u8 = builder.bitcast(u8_pointer.0, None, ptr_src.0)?; - let temp = builder.in_bounds_ptr_access_chain( - u8_pointer.0, - None, - ptr_src_u8, - offset_src.0, - iter::empty(), - )?; - builder.bitcast(result_type.0, Some(dst.0), temp)?; - } - Statement::RepackVector(repack) => { - if repack.is_extract { - let scalar_type = map.get_or_add_scalar(builder, repack.typ); - for (index, dst_id) in repack.unpacked.iter().enumerate() { - builder.composite_extract( - scalar_type.0, - Some(dst_id.0), - repack.packed.0, - [index as u32].iter().copied(), - )?; - } - } else { - let vector_type = map.get_or_add( - builder, - SpirvType::Vector( - SpirvScalarKey::from(repack.typ), - repack.unpacked.len() as u8, - ), - ); - let mut temp_vec = builder.undef(vector_type.0, None); - for (index, src_id) in repack.unpacked.iter().enumerate() { - temp_vec = builder.composite_insert( - vector_type.0, - None, - src_id.0, - temp_vec, - [index as u32].iter().copied(), - )?; - } - builder.copy_object(vector_type.0, Some(repack.packed.0), temp_vec)?; - } - } - Statement::VectorAccess(vector_access) => todo!(), - } - } - Ok(()) -} - -fn emit_function_linkage<'input>( - builder: &mut dr::Builder, - id_defs: &GlobalStringIdResolver<'input>, - f: &Function, - fn_name: SpirvWord, -) -> Result<(), TranslateError> { - if f.linkage == ast::LinkingDirective::NONE { - return Ok(()); - }; - let linking_name = match f.func_decl.borrow().name { - // According to SPIR-V rules linkage attributes are invalid on kernels - ast::MethodName::Kernel(..) => return Ok(()), - ast::MethodName::Func(fn_id) => f.import_as.as_deref().map_or_else( - || match id_defs.reverse_variables.get(&fn_id) { - Some(fn_name) => Ok(fn_name), - None => Err(error_unknown_symbol()), - }, - Result::Ok, - )?, - }; - emit_linking_decoration(builder, id_defs, Some(linking_name), fn_name, f.linkage); - Ok(()) -} - -fn get_function_type( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - spirv_input: impl Iterator, - spirv_output: &[ast::Variable], -) -> (SpirvWord, SpirvWord) { - map.get_or_add_fn( - builder, - spirv_input, - spirv_output - .iter() - .map(|var| SpirvType::new(var.v_type.clone())), - ) -} - -fn emit_linking_decoration<'input>( - builder: &mut dr::Builder, - id_defs: &GlobalStringIdResolver<'input>, - name_override: Option<&str>, - name: SpirvWord, - linking: ast::LinkingDirective, -) { - if linking == ast::LinkingDirective::NONE { - return; - } - if linking.contains(ast::LinkingDirective::VISIBLE) { - let string_name = - name_override.unwrap_or_else(|| id_defs.reverse_variables.get(&name).unwrap()); - builder.decorate( - name.0, - spirv::Decoration::LinkageAttributes, - [ - dr::Operand::LiteralString(string_name.to_string()), - dr::Operand::LinkageType(spirv::LinkageType::Export), - ] - .iter() - .cloned(), - ); - } else if linking.contains(ast::LinkingDirective::EXTERN) { - let string_name = - name_override.unwrap_or_else(|| id_defs.reverse_variables.get(&name).unwrap()); - builder.decorate( - name.0, - spirv::Decoration::LinkageAttributes, - [ - dr::Operand::LiteralString(string_name.to_string()), - dr::Operand::LinkageType(spirv::LinkageType::Import), - ] - .iter() - .cloned(), - ); - } - // TODO: handle LinkingDirective::WEAK -} - -fn effective_input_arguments<'a>( - this: &'a ast::MethodDeclaration<'a, SpirvWord>, -) -> impl Iterator + 'a { - let is_kernel = matches!(this.name, ast::MethodName::Kernel(_)); - this.input_arguments.iter().map(move |arg| { - if !is_kernel && arg.state_space != ast::StateSpace::Reg { - let spirv_type = - SpirvType::pointer_to(arg.v_type.clone(), space_to_spirv(arg.state_space)); - (arg.name, spirv_type) - } else { - (arg.name, SpirvType::new(arg.v_type.clone())) - } - }) -} - -fn emit_implicit_conversion( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - cv: &ImplicitConversion, -) -> Result<(), TranslateError> { - let from_parts = to_parts(&cv.from_type); - let to_parts = to_parts(&cv.to_type); - match (from_parts.kind, to_parts.kind, &cv.kind) { - (_, _, &ConversionKind::BitToPtr) => { - let dst_type = map.get_or_add( - builder, - SpirvType::pointer_to(cv.to_type.clone(), space_to_spirv(cv.to_space)), - ); - builder.convert_u_to_ptr(dst_type.0, Some(cv.dst.0), cv.src.0)?; - } - (TypeKind::Scalar, TypeKind::Scalar, &ConversionKind::Default) => { - if from_parts.width == to_parts.width { - let dst_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone())); - if from_parts.scalar_kind != ast::ScalarKind::Float - && to_parts.scalar_kind != ast::ScalarKind::Float - { - // It is noop, but another instruction expects result of this conversion - builder.copy_object(dst_type.0, Some(cv.dst.0), cv.src.0)?; - } else { - builder.bitcast(dst_type.0, Some(cv.dst.0), cv.src.0)?; - } - } else { - // This block is safe because it's illegal to implictly convert between floating point values - let same_width_bit_type = map.get_or_add( - builder, - SpirvType::new(type_from_parts(TypeParts { - scalar_kind: ast::ScalarKind::Bit, - ..from_parts - })), - ); - let same_width_bit_value = - builder.bitcast(same_width_bit_type.0, None, cv.src.0)?; - let wide_bit_type = type_from_parts(TypeParts { - scalar_kind: ast::ScalarKind::Bit, - ..to_parts - }); - let wide_bit_type_spirv = - map.get_or_add(builder, SpirvType::new(wide_bit_type.clone())); - if to_parts.scalar_kind == ast::ScalarKind::Unsigned - || to_parts.scalar_kind == ast::ScalarKind::Bit - { - builder.u_convert( - wide_bit_type_spirv.0, - Some(cv.dst.0), - same_width_bit_value, - )?; - } else { - let conversion_fn = if from_parts.scalar_kind == ast::ScalarKind::Signed - && to_parts.scalar_kind == ast::ScalarKind::Signed - { - dr::Builder::s_convert - } else { - dr::Builder::u_convert - }; - let wide_bit_value = - conversion_fn(builder, wide_bit_type_spirv.0, None, same_width_bit_value)?; - emit_implicit_conversion( - builder, - map, - &ImplicitConversion { - src: SpirvWord(wide_bit_value), - dst: cv.dst, - from_type: wide_bit_type, - from_space: cv.from_space, - to_type: cv.to_type.clone(), - to_space: cv.to_space, - kind: ConversionKind::Default, - }, - )?; - } - } - } - (TypeKind::Scalar, TypeKind::Scalar, &ConversionKind::SignExtend) => { - let result_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone())); - builder.s_convert(result_type.0, Some(cv.dst.0), cv.src.0)?; - } - (TypeKind::Vector, TypeKind::Scalar, &ConversionKind::Default) - | (TypeKind::Scalar, TypeKind::Array, &ConversionKind::Default) - | (TypeKind::Array, TypeKind::Scalar, &ConversionKind::Default) => { - let into_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone())); - builder.bitcast(into_type.0, Some(cv.dst.0), cv.src.0)?; - } - (_, _, &ConversionKind::PtrToPtr) => { - let result_type = map.get_or_add( - builder, - SpirvType::Pointer( - Box::new(SpirvType::new(cv.to_type.clone())), - space_to_spirv(cv.to_space), - ), - ); - if cv.to_space == ast::StateSpace::Generic && cv.from_space != ast::StateSpace::Generic - { - let src = if cv.from_type != cv.to_type { - let temp_type = map.get_or_add( - builder, - SpirvType::Pointer( - Box::new(SpirvType::new(cv.to_type.clone())), - space_to_spirv(cv.from_space), - ), - ); - builder.bitcast(temp_type.0, None, cv.src.0)? - } else { - cv.src.0 - }; - builder.ptr_cast_to_generic(result_type.0, Some(cv.dst.0), src)?; - } else if cv.from_space == ast::StateSpace::Generic - && cv.to_space != ast::StateSpace::Generic - { - let src = if cv.from_type != cv.to_type { - let temp_type = map.get_or_add( - builder, - SpirvType::Pointer( - Box::new(SpirvType::new(cv.to_type.clone())), - space_to_spirv(cv.from_space), - ), - ); - builder.bitcast(temp_type.0, None, cv.src.0)? - } else { - cv.src.0 - }; - builder.generic_cast_to_ptr(result_type.0, Some(cv.dst.0), src)?; - } else { - builder.bitcast(result_type.0, Some(cv.dst.0), cv.src.0)?; - } - } - (_, _, &ConversionKind::AddressOf) => { - let dst_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone())); - builder.convert_ptr_to_u(dst_type.0, Some(cv.dst.0), cv.src.0)?; - } - (TypeKind::Pointer, TypeKind::Scalar, &ConversionKind::Default) => { - let result_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone())); - builder.convert_ptr_to_u(result_type.0, Some(cv.dst.0), cv.src.0)?; - } - (TypeKind::Scalar, TypeKind::Pointer, &ConversionKind::Default) => { - let result_type = map.get_or_add(builder, SpirvType::new(cv.to_type.clone())); - builder.convert_u_to_ptr(result_type.0, Some(cv.dst.0), cv.src.0)?; - } - _ => unreachable!(), - } - Ok(()) -} - -fn vec_repr(t: T) -> Vec { - let mut result = vec![0; mem::size_of::()]; - unsafe { std::ptr::copy_nonoverlapping(&t, result.as_mut_ptr() as *mut _, 1) }; - result -} - -fn emit_abs( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - opencl: spirv::Word, - d: &ast::TypeFtz, - arg: &ast::AbsArgs, -) -> Result<(), dr::Error> { - let scalar_t = ast::ScalarType::from(d.type_); - let result_type = map.get_or_add(builder, SpirvType::from(scalar_t)); - let cl_abs = if scalar_t.kind() == ast::ScalarKind::Signed { - spirv::CLOp::s_abs - } else { - spirv::CLOp::fabs - }; - builder.ext_inst( - result_type.0, - Some(arg.dst.0), - opencl, - cl_abs as spirv::Word, - [dr::Operand::IdRef(arg.src.0)].iter().cloned(), - )?; - Ok(()) -} - -fn emit_mul_int( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - opencl: spirv::Word, - type_: ast::ScalarType, - control: ast::MulIntControl, - arg: &ast::MulArgs, -) -> Result<(), dr::Error> { - let inst_type = map.get_or_add(builder, SpirvType::from(type_)); - match control { - ast::MulIntControl::Low => { - builder.i_mul(inst_type.0, Some(arg.dst.0), arg.src1.0, arg.src2.0)?; - } - ast::MulIntControl::High => { - let opencl_inst = if type_.kind() == ast::ScalarKind::Signed { - spirv::CLOp::s_mul_hi - } else { - spirv::CLOp::u_mul_hi - }; - builder.ext_inst( - inst_type.0, - Some(arg.dst.0), - opencl, - opencl_inst as spirv::Word, - [ - dr::Operand::IdRef(arg.src1.0), - dr::Operand::IdRef(arg.src2.0), - ] - .iter() - .cloned(), - )?; - } - ast::MulIntControl::Wide => { - let instr_width = type_.size_of(); - let instr_kind = type_.kind(); - let dst_type = scalar_from_parts(instr_width * 2, instr_kind); - let dst_type_id = map.get_or_add_scalar(builder, dst_type); - let (src1, src2) = if type_.kind() == ast::ScalarKind::Signed { - let src1 = builder.s_convert(dst_type_id.0, None, arg.src1.0)?; - let src2 = builder.s_convert(dst_type_id.0, None, arg.src2.0)?; - (src1, src2) - } else { - let src1 = builder.u_convert(dst_type_id.0, None, arg.src1.0)?; - let src2 = builder.u_convert(dst_type_id.0, None, arg.src2.0)?; - (src1, src2) - }; - builder.i_mul(dst_type_id.0, Some(arg.dst.0), src1, src2)?; - builder.decorate(arg.dst.0, spirv::Decoration::NoSignedWrap, iter::empty()); - } - } - Ok(()) -} - -fn emit_mul_float( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - ctr: &ast::ArithFloat, - arg: &ast::MulArgs, -) -> Result<(), dr::Error> { - if ctr.saturate { - todo!() - } - let result_type = map.get_or_add_scalar(builder, ctr.type_.into()); - builder.f_mul(result_type.0, Some(arg.dst.0), arg.src1.0, arg.src2.0)?; - emit_rounding_decoration(builder, arg.dst, ctr.rounding); - Ok(()) -} - -fn scalar_from_parts(width: u8, kind: ast::ScalarKind) -> ast::ScalarType { - match kind { - ast::ScalarKind::Float => match width { - 2 => ast::ScalarType::F16, - 4 => ast::ScalarType::F32, - 8 => ast::ScalarType::F64, - _ => unreachable!(), - }, - ast::ScalarKind::Bit => match width { - 1 => ast::ScalarType::B8, - 2 => ast::ScalarType::B16, - 4 => ast::ScalarType::B32, - 8 => ast::ScalarType::B64, - _ => unreachable!(), - }, - ast::ScalarKind::Signed => match width { - 1 => ast::ScalarType::S8, - 2 => ast::ScalarType::S16, - 4 => ast::ScalarType::S32, - 8 => ast::ScalarType::S64, - _ => unreachable!(), - }, - ast::ScalarKind::Unsigned => match width { - 1 => ast::ScalarType::U8, - 2 => ast::ScalarType::U16, - 4 => ast::ScalarType::U32, - 8 => ast::ScalarType::U64, - _ => unreachable!(), - }, - ast::ScalarKind::Pred => ast::ScalarType::Pred, - } -} - -fn emit_rounding_decoration( - builder: &mut dr::Builder, - dst: SpirvWord, - rounding: Option, -) { - if let Some(rounding) = rounding { - builder.decorate( - dst.0, - spirv::Decoration::FPRoundingMode, - [rounding_to_spirv(rounding)].iter().cloned(), - ); - } -} - -fn rounding_to_spirv(this: ast::RoundingMode) -> rspirv::dr::Operand { - let mode = match this { - ast::RoundingMode::NearestEven => spirv::FPRoundingMode::RTE, - ast::RoundingMode::Zero => spirv::FPRoundingMode::RTZ, - ast::RoundingMode::PositiveInf => spirv::FPRoundingMode::RTP, - ast::RoundingMode::NegativeInf => spirv::FPRoundingMode::RTN, - }; - rspirv::dr::Operand::FPRoundingMode(mode) -} - -fn emit_add_int( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - typ: ast::ScalarType, - saturate: bool, - arg: &ast::AddArgs, -) -> Result<(), dr::Error> { - if saturate { - todo!() - } - let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(typ))); - builder.i_add(inst_type.0, Some(arg.dst.0), arg.src1.0, arg.src2.0)?; - Ok(()) -} - -fn emit_add_float( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - desc: &ast::ArithFloat, - arg: &ast::AddArgs, -) -> Result<(), dr::Error> { - let inst_type = map.get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.type_))); - builder.f_add(inst_type.0, Some(arg.dst.0), arg.src1.0, arg.src2.0)?; - emit_rounding_decoration(builder, arg.dst, desc.rounding); - Ok(()) -} - -fn emit_setp( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - setp: &ast::SetpData, - arg: &ast::SetpArgs, -) -> Result<(), dr::Error> { - let result_type = map - .get_or_add(builder, SpirvType::Base(SpirvScalarKey::Pred)) - .0; - let result_id = Some(arg.dst1.0); - let operand_1 = arg.src1.0; - let operand_2 = arg.src2.0; - match setp.cmp_op { - ast::SetpCompareOp::Integer(ast::SetpCompareInt::Eq) => { - builder.i_equal(result_type, result_id, operand_1, operand_2) - } - ast::SetpCompareOp::Float(ast::SetpCompareFloat::Eq) => { - builder.f_ord_equal(result_type, result_id, operand_1, operand_2) - } - ast::SetpCompareOp::Integer(ast::SetpCompareInt::NotEq) => { - builder.i_not_equal(result_type, result_id, operand_1, operand_2) - } - ast::SetpCompareOp::Float(ast::SetpCompareFloat::NotEq) => { - builder.f_ord_not_equal(result_type, result_id, operand_1, operand_2) - } - ast::SetpCompareOp::Integer(ast::SetpCompareInt::UnsignedLess) => { - builder.u_less_than(result_type, result_id, operand_1, operand_2) - } - ast::SetpCompareOp::Integer(ast::SetpCompareInt::SignedLess) => { - builder.s_less_than(result_type, result_id, operand_1, operand_2) - } - ast::SetpCompareOp::Float(ast::SetpCompareFloat::Less) => { - builder.f_ord_less_than(result_type, result_id, operand_1, operand_2) - } - ast::SetpCompareOp::Integer(ast::SetpCompareInt::UnsignedLessOrEq) => { - builder.u_less_than_equal(result_type, result_id, operand_1, operand_2) - } - ast::SetpCompareOp::Integer(ast::SetpCompareInt::SignedLessOrEq) => { - builder.s_less_than_equal(result_type, result_id, operand_1, operand_2) - } - ast::SetpCompareOp::Float(ast::SetpCompareFloat::LessOrEq) => { - builder.f_ord_less_than_equal(result_type, result_id, operand_1, operand_2) - } - ast::SetpCompareOp::Integer(ast::SetpCompareInt::UnsignedGreater) => { - builder.u_greater_than(result_type, result_id, operand_1, operand_2) - } - ast::SetpCompareOp::Integer(ast::SetpCompareInt::SignedGreater) => { - builder.s_greater_than(result_type, result_id, operand_1, operand_2) - } - ast::SetpCompareOp::Float(ast::SetpCompareFloat::Greater) => { - builder.f_ord_greater_than(result_type, result_id, operand_1, operand_2) - } - ast::SetpCompareOp::Integer(ast::SetpCompareInt::UnsignedGreaterOrEq) => { - builder.u_greater_than_equal(result_type, result_id, operand_1, operand_2) - } - ast::SetpCompareOp::Integer(ast::SetpCompareInt::SignedGreaterOrEq) => { - builder.s_greater_than_equal(result_type, result_id, operand_1, operand_2) - } - ast::SetpCompareOp::Float(ast::SetpCompareFloat::GreaterOrEq) => { - builder.f_ord_greater_than_equal(result_type, result_id, operand_1, operand_2) - } - ast::SetpCompareOp::Float(ast::SetpCompareFloat::NanEq) => { - builder.f_unord_equal(result_type, result_id, operand_1, operand_2) - } - ast::SetpCompareOp::Float(ast::SetpCompareFloat::NanNotEq) => { - builder.f_unord_not_equal(result_type, result_id, operand_1, operand_2) - } - ast::SetpCompareOp::Float(ast::SetpCompareFloat::NanLess) => { - builder.f_unord_less_than(result_type, result_id, operand_1, operand_2) - } - ast::SetpCompareOp::Float(ast::SetpCompareFloat::NanLessOrEq) => { - builder.f_unord_less_than_equal(result_type, result_id, operand_1, operand_2) - } - ast::SetpCompareOp::Float(ast::SetpCompareFloat::NanGreater) => { - builder.f_unord_greater_than(result_type, result_id, operand_1, operand_2) - } - ast::SetpCompareOp::Float(ast::SetpCompareFloat::NanGreaterOrEq) => { - builder.f_unord_greater_than_equal(result_type, result_id, operand_1, operand_2) - } - ast::SetpCompareOp::Float(ast::SetpCompareFloat::IsAnyNan) => { - let temp1 = builder.is_nan(result_type, None, operand_1)?; - let temp2 = builder.is_nan(result_type, None, operand_2)?; - builder.logical_or(result_type, result_id, temp1, temp2) - } - ast::SetpCompareOp::Float(ast::SetpCompareFloat::IsNotNan) => { - let temp1 = builder.is_nan(result_type, None, operand_1)?; - let temp2 = builder.is_nan(result_type, None, operand_2)?; - let any_nan = builder.logical_or(result_type, None, temp1, temp2)?; - logical_not(builder, result_type, result_id, any_nan) - } - _ => todo!(), - }?; - Ok(()) -} - -// HACK ALERT -// Temporary workaround until IGC gets its shit together -// Currently IGC carries two copies of SPIRV-LLVM translator -// a new one in /llvm-spirv/ and old one in /IGC/AdaptorOCL/SPIRV/. -// Obviously, old and buggy one is used for compiling L0 SPIRV -// https://github.com/intel/intel-graphics-compiler/issues/148 -fn logical_not( - builder: &mut dr::Builder, - result_type: spirv::Word, - result_id: Option, - operand: spirv::Word, -) -> Result { - let const_true = builder.constant_true(result_type, None); - let const_false = builder.constant_false(result_type, None); - builder.select(result_type, result_id, operand, const_false, const_true) -} - -// HACK ALERT -// For some reason IGC fails linking if the value and shift size are of different type -fn insert_shift_hack( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - offset_var: spirv::Word, - size_of: usize, -) -> Result { - let result_type = match size_of { - 2 => map.get_or_add_scalar(builder, ast::ScalarType::B16), - 8 => map.get_or_add_scalar(builder, ast::ScalarType::B64), - 4 => return Ok(offset_var), - _ => return Err(error_unreachable()), - }; - Ok(builder.u_convert(result_type.0, None, offset_var)?) -} - -fn emit_cvt( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - opencl: spirv::Word, - dets: &ast::CvtDetails, - arg: &ast::CvtArgs, -) -> Result<(), TranslateError> { - match dets.mode { - ptx_parser::CvtMode::SignExtend => { - let cv = ImplicitConversion { - src: arg.src, - dst: arg.dst, - from_type: dets.from.into(), - from_space: ast::StateSpace::Reg, - to_type: dets.to.into(), - to_space: ast::StateSpace::Reg, - kind: ConversionKind::SignExtend, - }; - emit_implicit_conversion(builder, map, &cv)?; - } - ptx_parser::CvtMode::ZeroExtend - | ptx_parser::CvtMode::Truncate - | ptx_parser::CvtMode::Bitcast => { - let cv = ImplicitConversion { - src: arg.src, - dst: arg.dst, - from_type: dets.from.into(), - from_space: ast::StateSpace::Reg, - to_type: dets.to.into(), - to_space: ast::StateSpace::Reg, - kind: ConversionKind::Default, - }; - emit_implicit_conversion(builder, map, &cv)?; - } - ptx_parser::CvtMode::SaturateUnsignedToSigned => { - let result_type = map.get_or_add(builder, SpirvType::from(dets.to)); - builder.sat_convert_u_to_s(result_type.0, Some(arg.dst.0), arg.src.0)?; - } - ptx_parser::CvtMode::SaturateSignedToUnsigned => { - let result_type = map.get_or_add(builder, SpirvType::from(dets.to)); - builder.sat_convert_s_to_u(result_type.0, Some(arg.dst.0), arg.src.0)?; - } - ptx_parser::CvtMode::FPExtend { flush_to_zero } => { - let result_type = map.get_or_add(builder, SpirvType::from(dets.to)); - builder.f_convert(result_type.0, Some(arg.dst.0), arg.src.0)?; - } - ptx_parser::CvtMode::FPTruncate { - rounding, - flush_to_zero, - } => { - let result_type = map.get_or_add(builder, SpirvType::from(dets.to)); - builder.f_convert(result_type.0, Some(arg.dst.0), arg.src.0)?; - emit_rounding_decoration(builder, arg.dst, Some(rounding)); - } - ptx_parser::CvtMode::FPRound { - integer_rounding, - flush_to_zero, - } => { - if flush_to_zero == Some(true) { - todo!() - } - let result_type = map.get_or_add(builder, SpirvType::from(dets.to)); - match integer_rounding { - Some(ast::RoundingMode::NearestEven) => { - builder.ext_inst( - result_type.0, - Some(arg.dst.0), - opencl, - spirv::CLOp::rint as u32, - [dr::Operand::IdRef(arg.src.0)].iter().cloned(), - )?; - } - Some(ast::RoundingMode::Zero) => { - builder.ext_inst( - result_type.0, - Some(arg.dst.0), - opencl, - spirv::CLOp::trunc as u32, - [dr::Operand::IdRef(arg.src.0)].iter().cloned(), - )?; - } - Some(ast::RoundingMode::NegativeInf) => { - builder.ext_inst( - result_type.0, - Some(arg.dst.0), - opencl, - spirv::CLOp::floor as u32, - [dr::Operand::IdRef(arg.src.0)].iter().cloned(), - )?; - } - Some(ast::RoundingMode::PositiveInf) => { - builder.ext_inst( - result_type.0, - Some(arg.dst.0), - opencl, - spirv::CLOp::ceil as u32, - [dr::Operand::IdRef(arg.src.0)].iter().cloned(), - )?; - } - None => { - builder.copy_object(result_type.0, Some(arg.dst.0), arg.src.0)?; - } - } - } - ptx_parser::CvtMode::SignedFromFP { - rounding, - flush_to_zero, - } => { - let dest_t: ast::ScalarType = dets.to.into(); - let result_type = map.get_or_add(builder, SpirvType::from(dest_t)); - builder.convert_f_to_s(result_type.0, Some(arg.dst.0), arg.src.0)?; - emit_rounding_decoration(builder, arg.dst, Some(rounding)); - } - ptx_parser::CvtMode::UnsignedFromFP { - rounding, - flush_to_zero, - } => { - let dest_t: ast::ScalarType = dets.to.into(); - let result_type = map.get_or_add(builder, SpirvType::from(dest_t)); - builder.convert_f_to_u(result_type.0, Some(arg.dst.0), arg.src.0)?; - emit_rounding_decoration(builder, arg.dst, Some(rounding)); - } - ptx_parser::CvtMode::FPFromSigned(rounding) => { - let dest_t: ast::ScalarType = dets.to.into(); - let result_type = map.get_or_add(builder, SpirvType::from(dest_t)); - builder.convert_s_to_f(result_type.0, Some(arg.dst.0), arg.src.0)?; - emit_rounding_decoration(builder, arg.dst, Some(rounding)); - } - ptx_parser::CvtMode::FPFromUnsigned(rounding) => { - let dest_t: ast::ScalarType = dets.to.into(); - let result_type = map.get_or_add(builder, SpirvType::from(dest_t)); - builder.convert_u_to_f(result_type.0, Some(arg.dst.0), arg.src.0)?; - emit_rounding_decoration(builder, arg.dst, Some(rounding)); - } - } - Ok(()) -} - -fn emit_mad_uint( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - opencl: spirv::Word, - type_: ast::ScalarType, - control: ast::MulIntControl, - arg: &ast::MadArgs, -) -> Result<(), dr::Error> { - let inst_type = map - .get_or_add(builder, SpirvType::from(ast::ScalarType::from(type_))) - .0; - match control { - ast::MulIntControl::Low => { - let mul_result = builder.i_mul(inst_type, None, arg.src1.0, arg.src2.0)?; - builder.i_add(inst_type, Some(arg.dst.0), arg.src3.0, mul_result)?; - } - ast::MulIntControl::High => { - builder.ext_inst( - inst_type, - Some(arg.dst.0), - opencl, - spirv::CLOp::u_mad_hi as spirv::Word, - [ - dr::Operand::IdRef(arg.src1.0), - dr::Operand::IdRef(arg.src2.0), - dr::Operand::IdRef(arg.src3.0), - ] - .iter() - .cloned(), - )?; - } - ast::MulIntControl::Wide => todo!(), - }; - Ok(()) -} - -fn emit_mad_sint( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - opencl: spirv::Word, - type_: ast::ScalarType, - control: ast::MulIntControl, - arg: &ast::MadArgs, -) -> Result<(), dr::Error> { - let inst_type = map.get_or_add(builder, SpirvType::from(type_)).0; - match control { - ast::MulIntControl::Low => { - let mul_result = builder.i_mul(inst_type, None, arg.src1.0, arg.src2.0)?; - builder.i_add(inst_type, Some(arg.dst.0), arg.src3.0, mul_result)?; - } - ast::MulIntControl::High => { - builder.ext_inst( - inst_type, - Some(arg.dst.0), - opencl, - spirv::CLOp::s_mad_hi as spirv::Word, - [ - dr::Operand::IdRef(arg.src1.0), - dr::Operand::IdRef(arg.src2.0), - dr::Operand::IdRef(arg.src3.0), - ] - .iter() - .cloned(), - )?; - } - ast::MulIntControl::Wide => todo!(), - }; - Ok(()) -} - -fn emit_mad_float( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - opencl: spirv::Word, - desc: &ast::ArithFloat, - arg: &ast::MadArgs, -) -> Result<(), dr::Error> { - let inst_type = map - .get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.type_))) - .0; - builder.ext_inst( - inst_type, - Some(arg.dst.0), - opencl, - spirv::CLOp::mad as spirv::Word, - [ - dr::Operand::IdRef(arg.src1.0), - dr::Operand::IdRef(arg.src2.0), - dr::Operand::IdRef(arg.src3.0), - ] - .iter() - .cloned(), - )?; - Ok(()) -} - -fn emit_fma_float( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - opencl: spirv::Word, - desc: &ast::ArithFloat, - arg: &ast::FmaArgs, -) -> Result<(), dr::Error> { - let inst_type = map - .get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.type_))) - .0; - builder.ext_inst( - inst_type, - Some(arg.dst.0), - opencl, - spirv::CLOp::fma as spirv::Word, - [ - dr::Operand::IdRef(arg.src1.0), - dr::Operand::IdRef(arg.src2.0), - dr::Operand::IdRef(arg.src3.0), - ] - .iter() - .cloned(), - )?; - Ok(()) -} - -fn emit_sub_int( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - typ: ast::ScalarType, - saturate: bool, - arg: &ast::SubArgs, -) -> Result<(), dr::Error> { - if saturate { - todo!() - } - let inst_type = map - .get_or_add(builder, SpirvType::from(ast::ScalarType::from(typ))) - .0; - builder.i_sub(inst_type, Some(arg.dst.0), arg.src1.0, arg.src2.0)?; - Ok(()) -} - -fn emit_sub_float( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - desc: &ast::ArithFloat, - arg: &ast::SubArgs, -) -> Result<(), dr::Error> { - let inst_type = map - .get_or_add(builder, SpirvType::from(ast::ScalarType::from(desc.type_))) - .0; - builder.f_sub(inst_type, Some(arg.dst.0), arg.src1.0, arg.src2.0)?; - emit_rounding_decoration(builder, arg.dst, desc.rounding); - Ok(()) -} - -fn emit_min( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - opencl: spirv::Word, - desc: &ast::MinMaxDetails, - arg: &ast::MinArgs, -) -> Result<(), dr::Error> { - let cl_op = match desc { - ast::MinMaxDetails::Signed(_) => spirv::CLOp::s_min, - ast::MinMaxDetails::Unsigned(_) => spirv::CLOp::u_min, - ast::MinMaxDetails::Float(_) => spirv::CLOp::fmin, - }; - let inst_type = map.get_or_add(builder, SpirvType::from(desc.type_())); - builder.ext_inst( - inst_type.0, - Some(arg.dst.0), - opencl, - cl_op as spirv::Word, - [ - dr::Operand::IdRef(arg.src1.0), - dr::Operand::IdRef(arg.src2.0), - ] - .iter() - .cloned(), - )?; - Ok(()) -} - -fn emit_max( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - opencl: spirv::Word, - desc: &ast::MinMaxDetails, - arg: &ast::MaxArgs, -) -> Result<(), dr::Error> { - let cl_op = match desc { - ast::MinMaxDetails::Signed(_) => spirv::CLOp::s_max, - ast::MinMaxDetails::Unsigned(_) => spirv::CLOp::u_max, - ast::MinMaxDetails::Float(_) => spirv::CLOp::fmax, - }; - let inst_type = map.get_or_add(builder, SpirvType::from(desc.type_())); - builder.ext_inst( - inst_type.0, - Some(arg.dst.0), - opencl, - cl_op as spirv::Word, - [ - dr::Operand::IdRef(arg.src1.0), - dr::Operand::IdRef(arg.src2.0), - ] - .iter() - .cloned(), - )?; - Ok(()) -} - -fn emit_rcp( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - opencl: spirv::Word, - desc: &ast::RcpData, - arg: &ast::RcpArgs, -) -> Result<(), TranslateError> { - let is_f64 = desc.type_ == ast::ScalarType::F64; - let (instr_type, constant) = if is_f64 { - (ast::ScalarType::F64, vec_repr(1.0f64)) - } else { - (ast::ScalarType::F32, vec_repr(1.0f32)) - }; - let result_type = map.get_or_add_scalar(builder, instr_type); - let rounding = match desc.kind { - ptx_parser::RcpKind::Approx => { - builder.ext_inst( - result_type.0, - Some(arg.dst.0), - opencl, - spirv::CLOp::native_recip as u32, - [dr::Operand::IdRef(arg.src.0)].iter().cloned(), - )?; - return Ok(()); - } - ptx_parser::RcpKind::Compliant(rounding) => rounding, - }; - let one = map.get_or_add_constant(builder, &ast::Type::Scalar(instr_type), &constant)?; - builder.f_div(result_type.0, Some(arg.dst.0), one.0, arg.src.0)?; - emit_rounding_decoration(builder, arg.dst, Some(rounding)); - builder.decorate( - arg.dst.0, - spirv::Decoration::FPFastMathMode, - [dr::Operand::FPFastMathMode( - spirv::FPFastMathMode::ALLOW_RECIP, - )] - .iter() - .cloned(), - ); - Ok(()) -} - -fn emit_atom( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - details: &ast::AtomDetails, - arg: &ast::AtomArgs, -) -> Result<(), TranslateError> { - let spirv_op = match details.op { - ptx_parser::AtomicOp::And => dr::Builder::atomic_and, - ptx_parser::AtomicOp::Or => dr::Builder::atomic_or, - ptx_parser::AtomicOp::Xor => dr::Builder::atomic_xor, - ptx_parser::AtomicOp::Exchange => dr::Builder::atomic_exchange, - ptx_parser::AtomicOp::Add => dr::Builder::atomic_i_add, - ptx_parser::AtomicOp::IncrementWrap | ptx_parser::AtomicOp::DecrementWrap => { - return Err(error_unreachable()) - } - ptx_parser::AtomicOp::SignedMin => dr::Builder::atomic_s_min, - ptx_parser::AtomicOp::UnsignedMin => dr::Builder::atomic_u_min, - ptx_parser::AtomicOp::SignedMax => dr::Builder::atomic_s_max, - ptx_parser::AtomicOp::UnsignedMax => dr::Builder::atomic_u_max, - ptx_parser::AtomicOp::FloatAdd => dr::Builder::atomic_f_add_ext, - ptx_parser::AtomicOp::FloatMin => todo!(), - ptx_parser::AtomicOp::FloatMax => todo!(), - }; - let result_type = map.get_or_add(builder, SpirvType::new(details.type_.clone())); - let memory_const = map.get_or_add_constant( - builder, - &ast::Type::Scalar(ast::ScalarType::U32), - &vec_repr(scope_to_spirv(details.scope) as u32), - )?; - let semantics_const = map.get_or_add_constant( - builder, - &ast::Type::Scalar(ast::ScalarType::U32), - &vec_repr(semantics_to_spirv(details.semantics).bits()), - )?; - spirv_op( - builder, - result_type.0, - Some(arg.dst.0), - arg.src1.0, - memory_const.0, - semantics_const.0, - arg.src2.0, - )?; - Ok(()) -} - -fn scope_to_spirv(this: ast::MemScope) -> spirv::Scope { - match this { - ast::MemScope::Cta => spirv::Scope::Workgroup, - ast::MemScope::Gpu => spirv::Scope::Device, - ast::MemScope::Sys => spirv::Scope::CrossDevice, - ptx_parser::MemScope::Cluster => todo!(), - } -} - -fn semantics_to_spirv(this: ast::AtomSemantics) -> spirv::MemorySemantics { - match this { - ast::AtomSemantics::Relaxed => spirv::MemorySemantics::RELAXED, - ast::AtomSemantics::Acquire => spirv::MemorySemantics::ACQUIRE, - ast::AtomSemantics::Release => spirv::MemorySemantics::RELEASE, - ast::AtomSemantics::AcqRel => spirv::MemorySemantics::ACQUIRE_RELEASE, - } -} - -fn emit_float_div_decoration(builder: &mut dr::Builder, dst: SpirvWord, kind: ast::DivFloatKind) { - match kind { - ast::DivFloatKind::Approx => { - builder.decorate( - dst.0, - spirv::Decoration::FPFastMathMode, - [dr::Operand::FPFastMathMode( - spirv::FPFastMathMode::ALLOW_RECIP, - )] - .iter() - .cloned(), - ); - } - ast::DivFloatKind::Rounding(rnd) => { - emit_rounding_decoration(builder, dst, Some(rnd)); - } - ast::DivFloatKind::ApproxFull => {} - } -} - -fn emit_sqrt( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - opencl: spirv::Word, - details: &ast::RcpData, - a: &ast::SqrtArgs, -) -> Result<(), TranslateError> { - let result_type = map.get_or_add_scalar(builder, details.type_.into()); - let (ocl_op, rounding) = match details.kind { - ast::RcpKind::Approx => (spirv::CLOp::sqrt, None), - ast::RcpKind::Compliant(rnd) => (spirv::CLOp::sqrt, Some(rnd)), - }; - builder.ext_inst( - result_type.0, - Some(a.dst.0), - opencl, - ocl_op as spirv::Word, - [dr::Operand::IdRef(a.src.0)].iter().cloned(), - )?; - emit_rounding_decoration(builder, a.dst, rounding); - Ok(()) -} - -// TODO: check what kind of assembly do we emit -fn emit_logical_xor_spirv( - builder: &mut dr::Builder, - result_type: spirv::Word, - result_id: Option, - op1: spirv::Word, - op2: spirv::Word, -) -> Result { - let temp_or = builder.logical_or(result_type, None, op1, op2)?; - let temp_and = builder.logical_and(result_type, None, op1, op2)?; - let temp_neg = logical_not(builder, result_type, None, temp_and)?; - builder.logical_and(result_type, result_id, temp_or, temp_neg) -} - -fn emit_load_var( - builder: &mut dr::Builder, - map: &mut TypeWordMap, - details: &LoadVarDetails, -) -> Result<(), TranslateError> { - let result_type = map.get_or_add(builder, SpirvType::new(details.typ.clone())); - match details.member_index { - Some((index, Some(width))) => { - let vector_type = match details.typ { - ast::Type::Scalar(scalar_t) => ast::Type::Vector(width, scalar_t), - _ => return Err(error_mismatched_type()), - }; - let vector_type_spirv = map.get_or_add(builder, SpirvType::new(vector_type)); - let vector_temp = builder.load( - vector_type_spirv.0, - None, - details.arg.src.0, - None, - iter::empty(), - )?; - builder.composite_extract( - result_type.0, - Some(details.arg.dst.0), - vector_temp, - [index as u32].iter().copied(), - )?; - } - Some((index, None)) => { - let result_ptr_type = map.get_or_add( - builder, - SpirvType::pointer_to(details.typ.clone(), spirv::StorageClass::Function), - ); - let index_spirv = map.get_or_add_constant( - builder, - &ast::Type::Scalar(ast::ScalarType::U32), - &vec_repr(index as u32), - )?; - let src = builder.in_bounds_access_chain( - result_ptr_type.0, - None, - details.arg.src.0, - [index_spirv.0].iter().copied(), - )?; - builder.load( - result_type.0, - Some(details.arg.dst.0), - src, - None, - iter::empty(), - )?; - } - None => { - builder.load( - result_type.0, - Some(details.arg.dst.0), - details.arg.src.0, - None, - iter::empty(), - )?; - } - }; - Ok(()) -} - -fn to_parts(this: &ast::Type) -> TypeParts { - match this { - ast::Type::Scalar(scalar) => TypeParts { - kind: TypeKind::Scalar, - state_space: ast::StateSpace::Reg, - scalar_kind: scalar.kind(), - width: scalar.size_of(), - components: Vec::new(), - }, - ast::Type::Vector(components, scalar) => TypeParts { - kind: TypeKind::Vector, - state_space: ast::StateSpace::Reg, - scalar_kind: scalar.kind(), - width: scalar.size_of(), - components: vec![*components as u32], - }, - ast::Type::Array(_, scalar, components) => TypeParts { - kind: TypeKind::Array, - state_space: ast::StateSpace::Reg, - scalar_kind: scalar.kind(), - width: scalar.size_of(), - components: components.clone(), - }, - ast::Type::Pointer(scalar, space) => TypeParts { - kind: TypeKind::Pointer, - state_space: *space, - scalar_kind: scalar.kind(), - width: scalar.size_of(), - components: Vec::new(), - }, - } -} - -fn type_from_parts(t: TypeParts) -> ast::Type { - match t.kind { - TypeKind::Scalar => ast::Type::Scalar(scalar_from_parts(t.width, t.scalar_kind)), - TypeKind::Vector => ast::Type::Vector( - t.components[0] as u8, - scalar_from_parts(t.width, t.scalar_kind), - ), - TypeKind::Array => ast::Type::Array( - None, - scalar_from_parts(t.width, t.scalar_kind), - t.components, - ), - TypeKind::Pointer => { - ast::Type::Pointer(scalar_from_parts(t.width, t.scalar_kind), t.state_space) - } - } -} - -#[derive(Eq, PartialEq, Clone)] -struct TypeParts { - kind: TypeKind, - scalar_kind: ast::ScalarKind, - width: u8, - state_space: ast::StateSpace, - components: Vec, -} - -#[derive(Eq, PartialEq, Copy, Clone)] -enum TypeKind { - Scalar, - Vector, - Array, - Pointer, -} -- cgit v1.2.3